Compare commits

..

No commits in common. "2fe953cdad89352168387d5df764dcb32a7c7629" and "98eb7bea3d697e7324c63cb33db255ab45acc1e8" have entirely different histories.

21 changed files with 136 additions and 5152 deletions

View File

@ -26,8 +26,6 @@ anyhow = "1.0"
mime_guess = "2.0" mime_guess = "2.0"
base64 = "0.22" base64 = "0.22"
tempfile = "3" tempfile = "3"
cron = "0.15"
chrono-tz = "0.10"
meval = "0.2" meval = "0.2"
ratatui = "0.27" ratatui = "0.27"
crossterm = { version = "0.28", features = ["event-stream"] } crossterm = { version = "0.28", features = ["event-stream"] }

File diff suppressed because it is too large Load Diff

View File

@ -226,7 +226,6 @@ pub struct AgentLoop {
max_iterations: usize, max_iterations: usize,
workspace_dir: PathBuf, workspace_dir: PathBuf,
model_name: String, model_name: String,
notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -248,7 +247,6 @@ impl AgentLoop {
provider: Arc::from(provider), provider: Arc::from(provider),
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
observer: None, observer: None,
notify_tx: None,
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
@ -267,7 +265,6 @@ impl AgentLoop {
provider: Arc::from(provider), provider: Arc::from(provider),
tools, tools,
observer: None, observer: None,
notify_tx: None,
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
@ -280,7 +277,6 @@ impl AgentLoop {
provider, provider,
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
observer: None, observer: None,
notify_tx: None,
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
@ -299,7 +295,6 @@ impl AgentLoop {
provider, provider,
tools, tools,
observer: None, observer: None,
notify_tx: None,
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
@ -318,11 +313,6 @@ impl AgentLoop {
self self
} }
pub fn with_notify(mut self, tx: tokio::sync::mpsc::UnboundedSender<String>) -> Self {
self.notify_tx = Some(tx);
self
}
pub fn tools(&self) -> &Arc<ToolRegistry> { pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools &self.tools
} }
@ -341,7 +331,7 @@ impl AgentLoop {
// Build and inject system prompt if not present // Build and inject system prompt if not present
let has_system = messages.first().map_or(false, |m| m.role == "system"); let has_system = messages.first().map_or(false, |m| m.role == "system");
if !has_system { if !has_system {
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None); let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!("System prompt injected:\n{}", system_prompt); tracing::debug!("System prompt injected:\n{}", system_prompt);
messages.insert(0, ChatMessage::system(system_prompt)); messages.insert(0, ChatMessage::system(system_prompt));
@ -400,16 +390,12 @@ impl AgentLoop {
}); });
} }
// Execute tool calls — log and notify immediately // Execute tool calls — log tool names and args before execution
{ {
let tools_info: Vec<String> = response.tool_calls.iter() let tools_info: Vec<String> = response.tool_calls.iter()
.map(|tc| { .map(|tc| {
let args = serde_json::to_string(&tc.arguments).unwrap_or_default(); let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
let s = format!("{}:{}", tc.name, args); format!("{}:{}", tc.name, args)
if let Some(ref tx) = self.notify_tx {
let _ = tx.send(format!("调用工具 {}", s));
}
s
}) })
.collect(); .collect();
tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools"); tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools");

View File

@ -18,7 +18,6 @@ pub struct PromptContext<'a> {
pub workspace_dir: &'a Path, pub workspace_dir: &'a Path,
pub model_name: &'a str, pub model_name: &'a str,
pub tools: &'a ToolRegistry, pub tools: &'a ToolRegistry,
pub session_id: Option<&'a str>,
} }
/// Trait for system prompt sections. /// Trait for system prompt sections.
@ -39,6 +38,7 @@ impl SystemPromptBuilder {
Self { Self {
sections: vec![ sections: vec![
Box::new(ToolHonestySection), Box::new(ToolHonestySection),
Box::new(NoToolNarrationSection),
Box::new(YourTaskSection), Box::new(YourTaskSection),
Box::new(SafetySection), Box::new(SafetySection),
Box::new(WorkspaceSection), Box::new(WorkspaceSection),
@ -82,11 +82,30 @@ impl PromptSection for ToolHonestySection {
} }
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 关键规则:工具诚实性 "## CRITICAL: Tool Honesty
- \"没有找到结果\" - NEVER fabricate, invent, or guess tool results. If a tool returns empty results, say \"No results found.\"
- - If a tool call fails, report the error - never make up data to fill the gap.
- " - When unsure whether a tool call succeeded, ask the user rather than guessing."
.to_string()
}
}
/// Critical rule: never narrate tool usage.
pub struct NoToolNarrationSection;
impl PromptSection for NoToolNarrationSection {
fn name(&self) -> &str {
"no_narration"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## CRITICAL: No Tool Narration
NEVER narrate, announce, describe, or explain your tool usage to the user.
Do NOT say things like \"Let me check...\", \"I will use bash to...\", \"I'll fetch that for you\", \"Searching now...\", or similar.
The user must ONLY see the final answer. Tool calls are invisible infrastructure - never reference them.
If you catch yourself starting a sentence about what tool you are about to use or just used, DELETE it and give the answer directly."
.to_string() .to_string()
} }
} }
@ -104,7 +123,7 @@ impl PromptSection for ToolsSection {
return String::new(); return String::new();
} }
let mut output = String::from("## 工具\n\n你可以使用以下工具:\n\n"); let mut output = String::from("## Tools\n\nYou have access to the following tools:\n\n");
for (name, tool) in ctx.tools.iter() { for (name, tool) in ctx.tools.iter() {
let _ = writeln!(output, "- **{}**: {}", name, tool.description()); let _ = writeln!(output, "- **{}**: {}", name, tool.description());
} }
@ -121,11 +140,11 @@ impl PromptSection for YourTaskSection {
} }
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 你的任务 "## Your Task
使 When the user sends a message, ACT on it. Use the tools to fulfill their request.
Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions.
使" Instead: use tools directly when needed, and give the final answer when done."
.to_string() .to_string()
} }
} }
@ -139,13 +158,13 @@ impl PromptSection for SafetySection {
} }
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 安全规则 "## Safety
- - Do not exfiltrate private data.
- - Do not run destructive commands without asking.
- - Do not bypass oversight or approval mechanisms.
- - Prefer safe operations over risky ones.
- " - When in doubt, ask before acting externally."
.to_string() .to_string()
} }
} }
@ -165,7 +184,7 @@ impl PromptSection for WorkspaceSection {
.canonicalize() .canonicalize()
.unwrap_or_else(|_| ctx.workspace_dir.to_path_buf()); .unwrap_or_else(|_| ctx.workspace_dir.to_path_buf());
format!( format!(
"## 工作目录\n\n工作目录:`{}`\n\n### 文件存储规范\n\n- **生成的文件**:将所有生成的文件(代码、文档、制品)存放在工作目录或其子目录中。\n- **下载的文件**:将下载的文件保存到工作目录,按任务整理。\n- **一个任务一个文件夹**:为每个任务或项目创建专用的子文件夹(如 `task_2024_01_01/`)。\n- **临时文件**:如果文件仅在处理期间需要且不保留,使用 `/tmp/` 或创建临时文件夹(如 `/tmp/picobot_task_xxx/`),以免弄乱工作目录。\n\n### 目录结构\n\n工作目录是你在本会话中的操作大本营。通过为不同任务创建子目录来保持整洁。", "## Workspace\n\nWorking directory: `{}`\n\n### File Storage Guidelines\n\n- **Generated files**: Store all generated files (code, documents, artifacts) in the workspace directory or its subdirectories.\n- **Downloaded files**: Save downloaded files to the workspace directory, organized by task.\n- **One task, one folder**: Create a dedicated subfolder for each task or project (e.g., `task_2024_01_01/`).\n- **Temporary files**: If files are only needed during processing and won't be kept, use `/tmp/` or create a temp folder (e.g., `/tmp/picobot_task_xxx/`) instead of cluttering the workspace.\n\n### Working Directory Structure\n\nThe workspace is your home base for this session. Keep it organized by creating subdirectories for different tasks.",
abs_path.display() abs_path.display()
) )
} }
@ -180,7 +199,7 @@ impl PromptSection for UserProfileSection {
} }
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
let mut output = String::from("## 用户配置\n\n"); let mut output = String::from("## User Profile\n\n");
// Load USER.md from ~/.picobot/USER.md // Load USER.md from ~/.picobot/USER.md
if let Some(user_config_dir) = get_user_config_dir() { if let Some(user_config_dir) = get_user_config_dir() {
@ -208,7 +227,7 @@ impl PromptSection for DateTimeSection {
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
let now = chrono::Local::now(); let now = chrono::Local::now();
format!( format!(
"## 当前日期与时间\n\n{} ({})", "## Current Date & Time\n\n{} ({})",
now.format("%Y-%m-%d %H:%M:%S"), now.format("%Y-%m-%d %H:%M:%S"),
now.format("%Z") now.format("%Z")
) )
@ -223,43 +242,37 @@ impl PromptSection for CrossChannelSection {
"cross_channel" "cross_channel"
} }
fn build(&self, ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
let session_line = if let Some(id) = ctx.session_id { r#"## 关于跨渠道消息和系统通知
format!("当前会话的 ID 是 `{}`。\n", id)
} else {
String::new()
};
format!( `source`
r#"## 关于会话和跨渠道消息
### ID ### source.kind = "system_notification"
session ID<channel>:<chat_id>:<dialog_id>
- channel: "cli_chat""feishu" - `system_name`:
- chat_id: / - `task_id`: ID
- dialog_id: chat dialog
{}### ### source.kind = "cross_channel"
`[message from X to Y]` assistant
send_message - `from_channel`: "feishu"
- X: ID "unknown" - `from_user_id`: ID
- Y: session ID (<channel>:<chat_id>:<dialog_id>)
### send_message ### send_message
- target_chat_id: <channel>:<chat_id> <channel>:<chat_id>:<dialog_id>
- content:
### chat_manager 使 `send_message`
- `target_chat_id`: ID
- action = "list_sessions" 1. `<channel>:<chat_id>`
- action = "list_channels" 2. `<channel>:<chat_id>:<dialog_id>`
- action = "list_messages" session session_id count"#, - `content`:
session_line - `origin`: 使 session_id
)
`[message from X to Y]`
LLM /
###
-
- "#
.to_string()
} }
} }
@ -276,7 +289,7 @@ impl PromptSection for RuntimeSection {
.map(|h| h.to_string_lossy().to_string()) .map(|h| h.to_string_lossy().to_string())
.unwrap_or_else(|_| "unknown".to_string()); .unwrap_or_else(|_| "unknown".to_string());
format!( format!(
"## 运行环境\n\n主机: {} | 操作系统: {} | 模型: {}", "## Runtime\n\nHost: {} | OS: {} | Model: {}",
host, host,
std::env::consts::OS, std::env::consts::OS,
ctx.model_name ctx.model_name
@ -308,7 +321,7 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option<St
.unwrap_or(trimmed) .unwrap_or(trimmed)
.to_string() .to_string()
+ &format!( + &format!(
"\n\n[... 已截断至 {} 字符 - 使用 file_read 获取完整文件]", "\n\n[... truncated at {} characters - use file_read for full file]",
max_chars max_chars
) )
} else { } else {
@ -321,12 +334,11 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option<St
} }
/// Build a complete system prompt with default configuration. /// Build a complete system prompt with default configuration.
pub fn build_system_prompt(workspace_dir: &Path, model_name: &str, tools: &ToolRegistry, session_id: Option<&str>) -> String { pub fn build_system_prompt(workspace_dir: &Path, model_name: &str, tools: &ToolRegistry) -> String {
let ctx = PromptContext { let ctx = PromptContext {
workspace_dir, workspace_dir,
model_name, model_name,
tools, tools,
session_id,
}; };
SystemPromptBuilder::with_defaults().build(&ctx) SystemPromptBuilder::with_defaults().build(&ctx)
} }
@ -345,16 +357,16 @@ mod tests {
workspace_dir: &temp_dir, workspace_dir: &temp_dir,
model_name: "test-model", model_name: "test-model",
tools: &tools, tools: &tools,
session_id: None,
}; };
let prompt = SystemPromptBuilder::with_defaults().build(&ctx); let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
assert!(prompt.contains("## 关键规则:工具诚实性")); assert!(prompt.contains("## CRITICAL: Tool Honesty"));
assert!(prompt.contains("## 安全规则")); assert!(prompt.contains("## CRITICAL: No Tool Narration"));
assert!(prompt.contains("## 工作目录")); assert!(prompt.contains("## Safety"));
assert!(prompt.contains("## 当前日期与时间")); assert!(prompt.contains("## Workspace"));
assert!(prompt.contains("## 运行环境")); assert!(prompt.contains("## Current Date & Time"));
assert!(prompt.contains("## Runtime"));
} }
#[test] #[test]
@ -375,7 +387,7 @@ mod tests {
let temp_dir = std::env::temp_dir(); let temp_dir = std::env::temp_dir();
let tools = ToolRegistry::new(); let tools = ToolRegistry::new();
let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None); let prompt = build_system_prompt(&temp_dir, "test-model", &tools);
assert!(!prompt.is_empty()); assert!(!prompt.is_empty());
assert!(prompt.contains("test-model")); assert!(prompt.contains("test-model"));

View File

@ -195,20 +195,6 @@ impl ChatMessage {
source: None, source: None,
} }
} }
pub fn user_with_source(content: impl Into<String>, source: MessageSource) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "user".to_string(),
content: content.into(),
media_refs: Vec::new(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
tool_calls: None,
source: Some(source),
}
}
} }
// ============================================================================ // ============================================================================

View File

@ -479,16 +479,10 @@ impl Channel for CliChatChannel {
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let clients = self.clients.lock().await.clone(); let clients = self.clients.lock().await.clone();
for client in clients { for client in clients {
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") { let outbound = WsOutbound::AssistantResponse {
WsOutbound::SystemNotification {
content: msg.content.clone(),
}
} else {
WsOutbound::AssistantResponse {
id: short_id(), id: short_id(),
content: msg.content.clone(), content: msg.content.clone(),
role: "assistant".to_string(), role: "assistant".to_string(),
}
}; };
let _ = client.sender.send(outbound).await; let _ = client.sender.send(outbound).await;
} }

View File

@ -140,43 +140,6 @@ pub struct GatewayConfig {
pub cleanup_interval_minutes: Option<u64>, pub cleanup_interval_minutes: Option<u64>,
#[serde(default, rename = "session_db_path")] #[serde(default, rename = "session_db_path")]
pub session_db_path: Option<String>, pub session_db_path: Option<String>,
#[serde(default)]
pub scheduler: Option<SchedulerConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig {
/// Whether the scheduler is enabled
#[serde(default = "default_scheduler_enabled")]
pub enabled: bool,
/// Poll interval in seconds (how often to check for due jobs)
#[serde(default = "default_poll_interval_secs")]
pub poll_interval_secs: u64,
/// Maximum concurrent job executions (currently sequential, reserved for future)
#[serde(default = "default_max_concurrent")]
pub max_concurrent: usize,
}
fn default_scheduler_enabled() -> bool {
true
}
fn default_poll_interval_secs() -> u64 {
60
}
fn default_max_concurrent() -> usize {
1
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self {
enabled: true,
poll_interval_secs: 60,
max_concurrent: 1,
}
}
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
@ -205,7 +168,6 @@ impl Default for GatewayConfig {
session_ttl_hours: None, session_ttl_hours: None,
cleanup_interval_minutes: None, cleanup_interval_minutes: None,
session_db_path: None, session_db_path: None,
scheduler: None,
} }
} }
} }

View File

@ -11,14 +11,12 @@ use crate::channels::base::{Channel, ChannelError};
use crate::config::{Config, expand_path, ensure_workspace_dir}; use crate::config::{Config, expand_path, ensure_workspace_dir};
use crate::logging; use crate::logging;
use crate::session::SessionManager; use crate::session::SessionManager;
use crate::scheduler::Scheduler;
pub struct GatewayState { pub struct GatewayState {
pub config: Config, pub config: Config,
pub workspace_dir: std::path::PathBuf, pub workspace_dir: std::path::PathBuf,
pub session_manager: Arc<SessionManager>, pub session_manager: Arc<SessionManager>,
pub channel_manager: ChannelManager, pub channel_manager: ChannelManager,
pub storage: Arc<crate::storage::Storage>,
} }
impl GatewayState { impl GatewayState {
@ -75,45 +73,13 @@ impl GatewayState {
// Register send_message tool with available channel names // Register send_message tool with available channel names
let available_channels = channel_manager.list_channel_names().await; let available_channels = channel_manager.list_channel_names().await;
let valid_channels = available_channels.clone();
session_manager.register_outbound_tool(available_channels); session_manager.register_outbound_tool(available_channels);
// Register chat_manager tool
session_manager.tools().register(
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()),
);
// Initialize scheduler if enabled in config
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
if scheduler_config.enabled {
// Register cron tools
session_manager.tools().register(
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels),
);
session_manager.tools().register(
crate::tools::cron::CronListTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronRemoveTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronEnableTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronDisableTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronUpdateTool::new(storage.clone()),
);
tracing::info!("Cron tools registered");
}
Ok(Self { Ok(Self {
config, config,
workspace_dir: workspace_path, workspace_dir: workspace_path,
session_manager: session_manager.clone(), session_manager: session_manager.clone(),
channel_manager, channel_manager,
storage,
}) })
} }
@ -202,21 +168,6 @@ impl GatewayState {
tracing::info!("Outbound dispatcher started"); tracing::info!("Outbound dispatcher started");
dispatcher.run().await; dispatcher.run().await;
}); });
// Spawn scheduler background task if enabled
let scheduler_config = self.config.gateway.scheduler.clone().unwrap_or_default();
if scheduler_config.enabled {
let sched = Arc::new(Scheduler::new(
self.storage.clone(),
self.session_manager.clone(),
self.channel_manager.bus(),
scheduler_config,
));
tokio::spawn(async move {
sched.run().await;
});
tracing::info!("Scheduler background task spawned");
}
} }
/// Handle control messages (session management operations) /// Handle control messages (session management operations)

View File

@ -9,7 +9,6 @@ pub mod protocol;
pub mod channels; pub mod channels;
pub mod logging; pub mod logging;
pub mod observability; pub mod observability;
pub mod scheduler;
pub mod skills; pub mod skills;
pub mod storage; pub mod storage;
pub mod tools; pub mod tools;

View File

@ -6,8 +6,13 @@ use std::collections::HashMap;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage; use super::traits::Usage;
use std::sync::Arc;
use crate::storage::Storage; fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
}
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> { fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks.iter().map(|b| match b { blocks.iter().map(|b| match b {
@ -57,7 +62,6 @@ pub struct AnthropicProvider {
temperature: Option<f32>, temperature: Option<f32>,
max_tokens: Option<u32>, max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>, model_extra: HashMap<String, serde_json::Value>,
storage: Option<Arc<Storage>>,
} }
impl AnthropicProvider { impl AnthropicProvider {
@ -81,13 +85,8 @@ impl AnthropicProvider {
temperature, temperature,
max_tokens, max_tokens,
model_extra, model_extra,
storage: None,
} }
} }
pub fn set_storage(&mut self, storage: Arc<Storage>) {
self.storage = Some(storage);
}
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -105,6 +104,7 @@ struct AnthropicRequest {
#[derive(Serialize)] #[derive(Serialize)]
struct AnthropicMessage { struct AnthropicMessage {
role: String, role: String,
#[serde(serialize_with = "serialize_content_blocks")]
content: Vec<serde_json::Value>, content: Vec<serde_json::Value>,
} }
@ -128,23 +128,14 @@ struct AnthropicResponse {
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent { enum AnthropicContent {
Text { Text { text: String },
#[serde(alias = "content")] Thinking { thinking: String },
text: String,
},
Thinking {
#[serde(alias = "content")]
thinking: String,
},
#[serde(rename = "tool_use")] #[serde(rename = "tool_use")]
ToolUse { ToolUse {
id: String, id: String,
name: String, name: String,
#[serde(alias = "arguments")]
input: serde_json::Value, input: serde_json::Value,
}, },
#[serde(other)]
Unknown,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -161,7 +152,6 @@ impl LLMProvider for AnthropicProvider {
&self, &self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
let start = std::time::Instant::now();
let url = format!("{}/v1/messages", self.base_url); let url = format!("{}/v1/messages", self.base_url);
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024); let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
@ -200,19 +190,7 @@ impl LLMProvider for AnthropicProvider {
"content": output, "content": output,
})] })]
} else { } else {
let mut blocks = convert_content_blocks(&m.content); convert_content_blocks(&m.content)
// Append tool_use blocks from assistant messages with tool calls
if let Some(ref tool_calls) = m.tool_calls {
for tc in tool_calls {
blocks.push(serde_json::json!({
"type": "tool_use",
"id": tc.id,
"name": tc.name,
"input": tc.arguments,
}));
}
}
blocks
}; };
AnthropicMessage { role, content } AnthropicMessage { role, content }
}) })
@ -234,14 +212,10 @@ impl LLMProvider for AnthropicProvider {
req_builder = req_builder.header(key.as_str(), value.as_str()); req_builder = req_builder.header(key.as_str(), value.as_str());
} }
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await?; let resp = req_builder.json(&body).send().await?;
let status = resp.status(); let status = resp.status();
let body_text = resp.text().await?; let body_text = resp.text().await?;
tracing::debug!(status = %status, resp_body = %body_text, "LLM response");
if !status.is_success() { if !status.is_success() {
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text) let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
@ -253,33 +227,11 @@ impl LLMProvider for AnthropicProvider {
.map(|s| s.to_string()) .map(|s| s.to_string())
}) })
.unwrap_or_else(|| body_text.clone()); .unwrap_or_else(|| body_text.clone());
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&body_text), Some(&error_msg),
start.elapsed().as_millis() as u64,
).await;
}
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into()); return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
} }
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text) let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
.map_err(|e| { .map_err(|e| format!("decode error: {} | body: {}", e, &body_text))?;
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp_body = body_text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
});
}
err_msg
})?;
let mut content = String::new(); let mut content = String::new();
let mut tool_calls = Vec::new(); let mut tool_calls = Vec::new();
@ -295,7 +247,6 @@ impl LLMProvider for AnthropicProvider {
} }
} }
AnthropicContent::Thinking { .. } => {} AnthropicContent::Thinking { .. } => {}
AnthropicContent::Unknown => {}
AnthropicContent::ToolUse { id, name, input } => { AnthropicContent::ToolUse { id, name, input } => {
tool_calls.push(ToolCall { tool_calls.push(ToolCall {
id: id.clone(), id: id.clone(),
@ -306,7 +257,7 @@ impl LLMProvider for AnthropicProvider {
} }
} }
let response = ChatCompletionResponse { Ok(ChatCompletionResponse {
id: anthropic_resp.id.unwrap_or_default(), id: anthropic_resp.id.unwrap_or_default(),
model: anthropic_resp.model.unwrap_or_default(), model: anthropic_resp.model.unwrap_or_default(),
content, content,
@ -316,20 +267,7 @@ impl LLMProvider for AnthropicProvider {
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0), total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
}, },
}; })
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
None,
start.elapsed().as_millis() as u64,
).await;
}
Ok(response)
} }
fn ptype(&self) -> &str { fn ptype(&self) -> &str {

View File

@ -7,8 +7,6 @@ use std::collections::HashMap;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage; use super::traits::Usage;
use std::sync::Arc;
use crate::storage::Storage;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1 { if blocks.len() == 1 {
@ -34,7 +32,6 @@ pub struct OpenAIProvider {
temperature: Option<f32>, temperature: Option<f32>,
max_tokens: Option<u32>, max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>, model_extra: HashMap<String, serde_json::Value>,
storage: Option<Arc<Storage>>,
} }
impl OpenAIProvider { impl OpenAIProvider {
@ -58,14 +55,9 @@ impl OpenAIProvider {
temperature, temperature,
max_tokens, max_tokens,
model_extra, model_extra,
storage: None,
} }
} }
pub fn set_storage(&mut self, storage: Arc<Storage>) {
self.storage = Some(storage);
}
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
let mut body = json!({ let mut body = json!({
"model": self.model_id, "model": self.model_id,
@ -170,7 +162,6 @@ impl LLMProvider for OpenAIProvider {
&self, &self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
let start = std::time::Instant::now();
let url = format!("{}/chat/completions", self.base_url); let url = format!("{}/chat/completions", self.base_url);
let body = self.build_request_body(&request); let body = self.build_request_body(&request);
@ -209,44 +200,24 @@ impl LLMProvider for OpenAIProvider {
req_builder = req_builder.header(key.as_str(), value.as_str()); req_builder = req_builder.header(key.as_str(), value.as_str());
} }
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await?; let resp = req_builder.json(&body).send().await?;
let status = resp.status(); let status = resp.status();
let text = resp.text().await?; let text = resp.text().await?;
tracing::debug!(status = %status, resp_body = %text, "LLM response");
// Debug: Log LLM response (only in debug builds)
#[cfg(debug_assertions)]
{
let resp_preview: String = text.chars().take(100).collect();
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
}
if !status.is_success() { if !status.is_success() {
let error = format!("API error {}: {}", status, text); return Err(format!("API error {}: {}", status, text).into());
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), Some(&error),
start.elapsed().as_millis() as u64,
).await;
}
return Err(error.into());
} }
let openai_resp: OpenAIResponse = serde_json::from_str(&text) let openai_resp: OpenAIResponse = serde_json::from_str(&text)
.map_err(|e| { .map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
let err_msg = format!("decode error: {} | body: {}", e, &text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp = text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await;
});
}
err_msg
})?;
let content = openai_resp.choices[0] let content = openai_resp.choices[0]
.message .message
@ -266,7 +237,7 @@ impl LLMProvider for OpenAIProvider {
}) })
.collect(); .collect();
let response = ChatCompletionResponse { Ok(ChatCompletionResponse {
id: openai_resp.id, id: openai_resp.id,
model: openai_resp.model, model: openai_resp.model,
content, content,
@ -276,17 +247,7 @@ impl LLMProvider for OpenAIProvider {
completion_tokens: openai_resp.usage.completion_tokens, completion_tokens: openai_resp.usage.completion_tokens,
total_tokens: openai_resp.usage.total_tokens, total_tokens: openai_resp.usage.total_tokens,
}, },
}; })
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), None,
start.elapsed().as_millis() as u64,
).await;
}
Ok(response)
} }
fn ptype(&self) -> &str { fn ptype(&self) -> &str {

View File

@ -123,6 +123,4 @@ pub trait LLMProvider: Send + Sync {
fn name(&self) -> &str; fn name(&self) -> &str;
fn model_id(&self) -> &str; fn model_id(&self) -> &str;
fn set_storage(&mut self, _storage: std::sync::Arc<crate::storage::Storage>) {}
} }

View File

@ -1,310 +0,0 @@
pub mod types;
use std::sync::Arc;
use std::time::Instant;
use tokio::time;
use crate::bus::MessageBus;
use crate::config::SchedulerConfig;
use crate::session::session::HandleResult;
use crate::session::SessionManager;
use crate::storage::{JobRun, ScheduledJob, Storage};
pub use types::Schedule;
/// Compute the next execution time (Unix ms) for a schedule, given `from` (Unix ms).
/// Returns `None` if no next time can be determined (e.g., invalid cron expression).
pub fn next_run_for_schedule(schedule: &Schedule, from: i64) -> Option<i64> {
use chrono::{TimeZone, Utc};
use std::str::FromStr;
match schedule {
Schedule::At { at } => Some(*at),
Schedule::Every { every_ms } => Some(from + *every_ms as i64),
Schedule::Cron { expr, tz } => {
let cron_schedule = cron::Schedule::from_str(expr.as_str()).ok()?;
let from_secs = from / 1000;
let from_nanos = ((from % 1000) * 1_000_000) as u32;
let from_dt = Utc.timestamp_opt(from_secs, from_nanos).single()?;
let next_utc = if let Some(tz_str) = tz {
let tz: chrono_tz::Tz = tz_str.parse().ok()?;
let _from_local = from_dt.with_timezone(&tz);
let next_local = cron_schedule.upcoming(tz).next()?;
next_local.with_timezone(&Utc)
} else {
cron_schedule.upcoming(Utc).next()?
};
Some(next_utc.timestamp_millis())
}
}
}
fn now_ms() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as i64
}
/// The scheduler runs as a background tokio task, periodically checking for due jobs
/// and executing them via `SessionManager::handle_cron_message`.
pub struct Scheduler {
storage: Arc<Storage>,
session_manager: Arc<SessionManager>,
bus: Arc<MessageBus>,
config: SchedulerConfig,
}
impl Scheduler {
pub fn new(
storage: Arc<Storage>,
session_manager: Arc<SessionManager>,
bus: Arc<MessageBus>,
config: SchedulerConfig,
) -> Self {
Self {
storage,
session_manager,
bus,
config,
}
}
/// Run the scheduler loop. This is a long-running async function meant to be
/// spawned as a tokio background task.
pub async fn run(self: Arc<Self>) {
let poll_duration = time::Duration::from_secs(self.config.poll_interval_secs);
let mut interval = time::interval(poll_duration);
interval.tick().await;
tracing::info!(
"Scheduler started (poll interval: {}s, max concurrent: {})",
self.config.poll_interval_secs,
self.config.max_concurrent,
);
loop {
interval.tick().await;
let now = now_ms();
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await {
Ok(jobs) => jobs,
Err(e) => {
tracing::error!("scheduler: failed to query due jobs: {}", e);
continue;
}
};
if due.is_empty() {
continue;
}
tracing::info!("scheduler: found {} due job(s)", due.len());
for job in &due {
let start = Instant::now();
let started_at = now_ms();
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await {
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
continue;
}
tracing::info!(
job_id = %job.id,
job_name = %job.name,
"scheduler: executing cron job"
);
let result = self
.session_manager
.handle_cron_message(
&job.channel,
&job.chat_id,
&job.prompt,
&job.id,
&job.name,
)
.await;
let finished_at = now_ms();
let duration_ms = start.elapsed().as_millis() as i64;
match result {
Ok(HandleResult::AgentResponse(output)) => {
let outbound = crate::bus::OutboundMessage {
channel: job.channel.clone(),
chat_id: job.chat_id.clone(),
content: output.clone(),
reply_to: None,
media: vec![],
metadata: std::collections::HashMap::new(),
};
let _ = self.bus.publish_outbound(outbound).await;
let output_truncated = if output.len() > 8000 {
format!("{}...[truncated]", &output[..8000])
} else {
output.clone()
};
let run = JobRun {
id: 0,
job_id: job.id.clone(),
started_at,
finished_at,
status: "ok".to_string(),
output: Some(output_truncated),
error: None,
duration_ms,
};
if let Err(e) = self.storage.record_scheduled_job_run(&run).await {
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
}
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await {
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
}
tracing::info!(
job_id = %job.id,
duration_ms = %duration_ms,
"scheduler: job completed successfully"
);
}
Ok(HandleResult::CommandOutput(output)) => {
let outbound = crate::bus::OutboundMessage {
channel: job.channel.clone(),
chat_id: job.chat_id.clone(),
content: output.clone(),
reply_to: None,
media: vec![],
metadata: std::collections::HashMap::new(),
};
let _ = self.bus.publish_outbound(outbound).await;
let run = JobRun {
id: 0,
job_id: job.id.clone(),
started_at,
finished_at,
status: "ok".to_string(),
output: Some(output),
error: None,
duration_ms,
};
let _ = self.storage.record_scheduled_job_run(&run).await;
}
Err(e) => {
let error_str = e.to_string();
let run = JobRun {
id: 0,
job_id: job.id.clone(),
started_at,
finished_at,
status: "error".to_string(),
output: None,
error: Some(error_str.clone()),
duration_ms,
};
if let Err(e2) = self.storage.record_scheduled_job_run(&run).await {
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
}
if let Err(e2) = self.storage.set_scheduled_job_last_status(
&job.id, "error", Some(&error_str),
).await {
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
}
tracing::error!(
job_id = %job.id,
duration_ms = %duration_ms,
error = %error_str,
"scheduler: job failed"
);
}
}
if let Err(e) = self.reschedule_after_run(job).await {
tracing::error!(job_id = %job.id, "scheduler: failed to reschedule: {}", e);
}
}
}
}
/// After a job runs, compute its next execution time or disable/delete it.
async fn reschedule_after_run(&self, job: &ScheduledJob) -> anyhow::Result<()> {
let now = now_ms();
match &job.schedule {
Schedule::At { .. } => {
if job.delete_after_run {
self.storage.remove_scheduled_job(&job.id).await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
} else {
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
}
}
Schedule::Every { .. } | Schedule::Cron { .. } => {
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
self.storage.set_scheduled_job_next_run(&job.id, next).await?;
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
} else {
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_next_run_at_schedule() {
let now = 1000000;
let next = next_run_for_schedule(&Schedule::At { at: 2000000 }, now);
assert_eq!(next, Some(2000000));
}
#[test]
fn test_next_run_every_schedule() {
let now = 1000000;
let next = next_run_for_schedule(&Schedule::Every { every_ms: 5000 }, now);
assert_eq!(next, Some(1005000));
}
#[test]
fn test_next_run_cron_every_minute() {
let expr = "0 * * * * *".to_string();
let schedule = Schedule::Cron { expr, tz: None };
let now = 1000000;
let next = next_run_for_schedule(&schedule, now);
assert!(next.is_some());
assert!(next.unwrap() > now);
}
#[test]
fn test_next_run_cron_every_day_at_9am() {
let expr = "0 0 9 * * *".to_string();
let schedule = Schedule::Cron { expr, tz: None };
let now = 1000000;
let next = next_run_for_schedule(&schedule, now);
assert!(next.is_some());
let next_ms = next.unwrap();
assert!(next_ms > now);
}
}

View File

@ -1,16 +0,0 @@
use serde::{Deserialize, Serialize};
/// How a job is scheduled. Serialized as JSON in the database `schedule` column.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Schedule {
/// One-shot: fires once at a specific Unix millisecond timestamp, then disables.
#[serde(rename = "at")]
At { at: i64 },
/// Recurring: fires every `every_ms` milliseconds.
#[serde(rename = "every")]
Every { every_ms: u64 },
/// Recurring: fires on a cron schedule with optional timezone.
#[serde(rename = "cron")]
Cron { expr: String, tz: Option<String> },
}

View File

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::Mutex; use tokio::sync::{Mutex, mpsc};
use uuid::Uuid; use uuid::Uuid;
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
@ -21,6 +21,7 @@ use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::agent::system_prompt::build_system_prompt; use crate::agent::system_prompt::build_system_prompt;
use crate::agent::context_compressor::ContextCompressionConfig; use crate::agent::context_compressor::ContextCompressionConfig;
use crate::protocol::WsOutbound;
use crate::providers::{create_provider, LLMProvider}; use crate::providers::{create_provider, LLMProvider};
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID}; use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
use crate::session::events::DialogInfo; use crate::session::events::DialogInfo;
@ -48,6 +49,7 @@ pub struct Session {
messages: Vec<ChatMessage>, messages: Vec<ChatMessage>,
seq_counter: i64, seq_counter: i64,
pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
provider: Arc<dyn LLMProvider>, provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
@ -61,16 +63,14 @@ impl Session {
pub async fn new( pub async fn new(
id: UnifiedSessionId, id: UnifiedSessionId,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
storage: Option<StdArc<Storage>>, storage: Option<StdArc<Storage>>,
routing_info: String, routing_info: String,
title: String, title: String,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let mut provider_box = create_provider(provider_config.clone()) let provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
if let Some(ref s) = storage {
provider_box.set_storage(s.clone());
}
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box); let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
let compressor_config = ContextCompressionConfig { let compressor_config = ContextCompressionConfig {
@ -89,6 +89,7 @@ impl Session {
total_message_count: 0, total_message_count: 0,
messages: Vec::new(), messages: Vec::new(),
seq_counter: 1, seq_counter: 1,
user_tx,
provider_config: provider_config.clone(), provider_config: provider_config.clone(),
provider: provider.clone(), provider: provider.clone(),
tools, tools,
@ -102,6 +103,7 @@ impl Session {
pub async fn from_storage( pub async fn from_storage(
id: UnifiedSessionId, id: UnifiedSessionId,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
storage: StdArc<Storage>, storage: StdArc<Storage>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
@ -111,9 +113,8 @@ impl Session {
let messages = storage.load_messages(&id.to_string(), 0).await let messages = storage.load_messages(&id.to_string(), 0).await
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?; .map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
let mut provider_box = create_provider(provider_config.clone()) let provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
provider_box.set_storage(storage.clone());
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box); let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
let compressor_config = ContextCompressionConfig { let compressor_config = ContextCompressionConfig {
@ -122,7 +123,6 @@ impl Session {
}; };
// Convert MessageMeta to ChatMessage // Convert MessageMeta to ChatMessage
// Clear tool_call_id/tool_name — they're not valid across API sessions
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| { let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
ChatMessage { ChatMessage {
id: m.id, id: m.id,
@ -130,8 +130,8 @@ impl Session {
content: m.content, content: m.content,
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(), media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
timestamp: m.created_at, timestamp: m.created_at,
tool_call_id: None, tool_call_id: m.tool_call_id,
tool_name: None, tool_name: m.tool_name,
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()), tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
source: m.source.and_then(|s| serde_json::from_str(&s).ok()), source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
} }
@ -149,6 +149,7 @@ impl Session {
total_message_count, total_message_count,
messages: chat_messages, messages: chat_messages,
seq_counter, seq_counter,
user_tx,
provider_config: provider_config.clone(), provider_config: provider_config.clone(),
provider: provider.clone(), provider: provider.clone(),
tools, tools,
@ -251,17 +252,16 @@ impl Session {
} }
} }
pub fn create_user_message_with_source( pub async fn send(&self, msg: WsOutbound) {
&self, let _ = self.user_tx.send(msg).await;
content: &str,
media_refs: Vec<String>,
source: MessageSource,
) -> ChatMessage {
if media_refs.is_empty() {
ChatMessage::user_with_source(content, source)
} else {
ChatMessage::user_with_source(content, source)
} }
/// 发送系统通知(不记录进 session 历史)
pub async fn send_system_notification(&self, content: &str) {
let msg = WsOutbound::SystemNotification {
content: content.to_string(),
};
let _ = self.user_tx.send(msg).await;
} }
/// 将 session 元数据写回 Storage /// 将 session 元数据写回 Storage
@ -364,21 +364,12 @@ impl Session {
)) ))
} }
/// 创建一个附通知通道的 AgentLoop 实例
pub fn create_agent_with_notify(
&self,
notify_tx: tokio::sync::mpsc::UnboundedSender<String>,
) -> Result<AgentLoop, AgentError> {
Ok(self.create_agent()?.with_notify(notify_tx))
}
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills /// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills
pub fn build_system_prompt(&self, skills_prompt: &str) -> String { pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
let base_prompt = build_system_prompt( let base_prompt = build_system_prompt(
&self.provider_config.workspace_dir, &self.provider_config.workspace_dir,
&self.provider_config.model_id, &self.provider_config.model_id,
&self.tools, &self.tools,
Some(&self.id.to_string()),
); );
if skills_prompt.trim().is_empty() { if skills_prompt.trim().is_empty() {
@ -883,9 +874,11 @@ impl SessionManager {
self.storage.upsert_session(&meta).await self.storage.upsert_session(&meta).await
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?; .map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new( let session = Session::new(
unified_id.clone(), unified_id.clone(),
self.provider_config.clone(), self.provider_config.clone(),
user_tx,
self.tools.clone(), self.tools.clone(),
Some(self.storage.clone()), Some(self.storage.clone()),
routing_info, routing_info,
@ -916,9 +909,11 @@ impl SessionManager {
match self.storage.get_session(&session_id_str).await { match self.storage.get_session(&session_id_str).await {
Ok(meta) => { Ok(meta) => {
tracing::debug!(session_id = %session_id_str, last_active_at = %meta.last_active_at, message_count = %meta.message_count, "Restoring session from Storage"); tracing::debug!(session_id = %session_id_str, last_active_at = %meta.last_active_at, message_count = %meta.message_count, "Restoring session from Storage");
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::from_storage( let session = Session::from_storage(
unified_id.clone(), unified_id.clone(),
self.provider_config.clone(), self.provider_config.clone(),
user_tx,
self.tools.clone(), self.tools.clone(),
self.storage.clone(), self.storage.clone(),
).await?; ).await?;
@ -937,9 +932,11 @@ impl SessionManager {
} }
// Create new session // Create new session
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new( let session = Session::new(
unified_id.clone(), unified_id.clone(),
self.provider_config.clone(), self.provider_config.clone(),
user_tx,
self.tools.clone(), self.tools.clone(),
Some(self.storage.clone()), Some(self.storage.clone()),
String::new(), String::new(),
@ -1178,30 +1175,6 @@ impl SessionManager {
} }
// Normal message handling through LLM // Normal message handling through LLM
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
// Spawn notification publisher — sends immediately when tools are detected
{
let bus = self.bus.clone();
let ch = channel.to_string();
let cid = chat_id.to_string();
tokio::spawn(async move {
while let Some(notif) = notify_rx.recv().await {
let mut metadata = HashMap::new();
metadata.insert("_type".to_string(), "notification".to_string());
let outbound = OutboundMessage {
channel: ch.clone(),
chat_id: cid.clone(),
content: notif,
reply_to: None,
media: vec![],
metadata,
};
let _ = bus.publish_outbound(outbound).await;
}
});
}
let response: String = { let response: String = {
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
@ -1229,7 +1202,7 @@ impl SessionManager {
.compress_if_needed(history) .compress_if_needed(history)
.await?; .await?;
let agent = session_guard.create_agent_with_notify(notify_tx)?; let agent = session_guard.create_agent()?;
let result = agent.process(history).await?; let result = agent.process(history).await?;
for msg in result.emitted_messages { for msg in result.emitted_messages {
@ -1260,142 +1233,6 @@ impl SessionManager {
Ok(HandleResult::AgentResponse(response)) Ok(HandleResult::AgentResponse(response))
} }
/// Handle a message triggered by a scheduled cron job.
///
/// This is similar to `handle_message`, but the user message is created with
/// `SourceKind::ExternalTrigger` source metadata so that the cron job identity
/// is preserved in the conversation history and database.
pub async fn handle_cron_message(
&self,
channel: &str,
chat_id: &str,
prompt: &str,
job_id: &str,
job_name: &str,
) -> Result<HandleResult, AgentError> {
use crate::bus::{MessageSource, SourceKind};
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
*self.current_source_session.lock().await = Some(unified_id.to_string());
tracing::debug!(unified_id = %unified_id, job_id = %job_id, "handle_cron_message resolved");
let session = self.get_or_create_session(&unified_id).await?;
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
{
use std::collections::HashMap;
use crate::bus::OutboundMessage;
let bus = self.bus.clone();
let ch = channel.to_string();
let cid = chat_id.to_string();
tokio::spawn(async move {
while let Some(notif) = notify_rx.recv().await {
let mut metadata = HashMap::new();
metadata.insert("_type".to_string(), "notification".to_string());
let outbound = OutboundMessage {
channel: ch.clone(),
chat_id: cid.clone(),
content: notif,
reply_to: None,
media: vec![],
metadata,
};
let _ = bus.publish_outbound(outbound).await;
}
});
}
let response: String = {
let mut session_guard = session.lock().await;
let source = MessageSource {
kind: SourceKind::ExternalTrigger,
from_channel: Some(channel.to_string()),
from_session: None,
from_user_id: None,
system_name: Some(job_name.to_string()),
task_id: Some(job_id.to_string()),
};
let user_message = session_guard.create_user_message_with_source(prompt, vec![], source);
session_guard.add_message(user_message, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
let mut history = session_guard.get_history().to_vec();
let skills_prompt = self.skills_loader.build_skills_prompt();
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
let cron_context = format!(
"\n\n## 定时任务执行\n\n\
{}({})\n\
: {}:{}\n\n\
\n\
- \n\
- \n\
- \n\
- 使 send_message \n\
- ",
job_name, job_id, channel, chat_id
);
let full_system_prompt = format!("{}{}", system_prompt, cron_context);
history.insert(0, ChatMessage::system(full_system_prompt));
let history = session_guard.compressor
.compress_if_needed(history)
.await?;
let agent = session_guard.create_agent_with_notify(notify_tx)?;
let result = agent.process(history).await?;
for msg in result.emitted_messages {
session_guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
}
if session_guard.should_generate_title() {
if let Err(e) = session_guard.generate_title().await {
tracing::warn!("failed to generate title: {}", e);
}
}
let raw_response = result.final_response.content;
let target_id = unified_id.to_string();
let prefix = format!(
"[message from cron:{}({}) to {}]\n",
job_name, job_id, target_id
);
let prefixed_response = format!("{}{}", prefix, raw_response);
let source = MessageSource {
kind: SourceKind::CrossChannel,
from_channel: Some("cron".to_string()),
from_session: Some(format!("{}:{}", job_name, job_id)),
from_user_id: None,
system_name: Some(job_name.to_string()),
task_id: Some(job_id.to_string()),
};
let msg = ChatMessage::assistant_with_source(prefixed_response.clone(), source);
session_guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
prefixed_response
};
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel,
chat_id = %chat_id,
job_id = %job_id,
response_len = %response.len(),
"Cron agent response received"
);
*self.current_source_session.lock().await = None;
Ok(HandleResult::AgentResponse(response))
}
pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> { pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
let session = self.get_or_create_session(unified_id).await?; let session = self.get_or_create_session(unified_id).await?;
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
@ -1485,6 +1322,7 @@ impl OutboundMessenger for SessionManager {
mod tests { mod tests {
use super::*; use super::*;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::mpsc;
fn test_provider_config() -> LLMProviderConfig { fn test_provider_config() -> LLMProviderConfig {
LLMProviderConfig { LLMProviderConfig {

View File

@ -1,17 +1,15 @@
pub mod error; pub mod error;
pub mod message;
pub mod scheduler;
pub mod session; pub mod session;
pub mod message;
pub use error::StorageError; pub use error::StorageError;
pub use scheduler::{JobRun, ScheduledJob};
use sqlx::{Pool, Row, Sqlite, SqlitePool}; use sqlx::{Pool, Row, Sqlite, SqlitePool};
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
use std::path::Path; use std::path::Path;
pub struct Storage { pub struct Storage {
pub(crate) pool: Pool<Sqlite>, pool: Pool<Sqlite>,
} }
impl Storage { impl Storage {
@ -94,130 +92,6 @@ impl Storage {
.await .await
.ok(); .ok();
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS llm_calls (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at INTEGER NOT NULL,
provider TEXT NOT NULL,
model TEXT NOT NULL,
request_body TEXT NOT NULL,
response_body TEXT,
error TEXT,
duration_ms INTEGER
)
"#,
)
.execute(&self.pool)
.await?;
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e);
}
Ok(())
}
/// Initialize the scheduler tables (idempotent).
pub(crate) async fn init_scheduler_schema(pool: &Pool<Sqlite>) -> Result<(), StorageError> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS scheduled_jobs (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
schedule TEXT NOT NULL,
prompt TEXT NOT NULL,
channel TEXT NOT NULL,
chat_id TEXT NOT NULL,
model TEXT,
enabled INTEGER NOT NULL DEFAULT 1,
delete_after_run INTEGER NOT NULL DEFAULT 0,
next_run_at INTEGER NOT NULL,
last_run_at INTEGER,
last_status TEXT,
last_error TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS job_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id TEXT NOT NULL REFERENCES scheduled_jobs(id) ON DELETE CASCADE,
started_at INTEGER NOT NULL,
finished_at INTEGER NOT NULL,
status TEXT NOT NULL,
output TEXT,
error TEXT,
duration_ms INTEGER NOT NULL
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_jobs_next_run ON scheduled_jobs(enabled, next_run_at)",
)
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_runs_job_id ON job_runs(job_id)")
.execute(pool)
.await?;
Ok(())
}
pub async fn append_llm_call(
&self,
provider: &str,
model: &str,
request_body: &str,
response_body: Option<&str>,
error: Option<&str>,
duration_ms: u64,
) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(
r#"
INSERT INTO llm_calls (created_at, provider, model, request_body, response_body, error, duration_ms)
VALUES (?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(now)
.bind(provider)
.bind(model)
.bind(request_body)
.bind(response_body)
.bind(error)
.bind(duration_ms as i64)
.execute(self.pool())
.await?;
// Prune to keep last 1000 records
self.prune_llm_calls(1000).await?;
Ok(())
}
async fn prune_llm_calls(&self, max_records: i64) -> Result<(), StorageError> {
sqlx::query(
r#"
DELETE FROM llm_calls WHERE id <= (
SELECT COALESCE(MAX(id), 0) - ? FROM llm_calls
)
"#,
)
.bind(max_records)
.execute(self.pool())
.await?;
Ok(()) Ok(())
} }
@ -465,79 +339,6 @@ impl Storage {
.collect()) .collect())
} }
pub async fn list_all_active_sessions(
&self,
limit: i64,
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
FROM sessions
WHERE deleted_at IS NULL
ORDER BY last_active_at DESC
LIMIT ?
"#,
)
.bind(limit)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|row| crate::storage::session::SessionMeta {
id: row.get("id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
dialog_id: row.get("dialog_id"),
title: row.get("title"),
created_at: row.get("created_at"),
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
deleted_at: row.get("deleted_at"),
})
.collect())
}
pub async fn list_recent_messages(
&self,
session_id: &str,
count: i64,
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
FROM messages
WHERE session_id = ?
ORDER BY seq DESC
LIMIT ?
"#,
)
.bind(session_id)
.bind(count)
.fetch_all(self.pool())
.await?;
let mut messages: Vec<_> = rows
.into_iter()
.map(|row| crate::storage::message::MessageMeta {
id: row.get("id"),
session_id: row.get("session_id"),
seq: row.get("seq"),
role: row.get("role"),
content: row.get("content"),
media_refs: row.get("media_refs"),
tool_call_id: row.get("tool_call_id"),
tool_name: row.get("tool_name"),
tool_calls: row.get("tool_calls"),
source: row.get("source"),
created_at: row.get("created_at"),
})
.collect();
messages.reverse();
Ok(messages)
}
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> { pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#) sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
.bind(session_id) .bind(session_id)

View File

@ -1,551 +0,0 @@
use serde::{Deserialize, Serialize};
use sqlx::Row;
use crate::scheduler::Schedule;
/// A scheduled job stored in the database.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduledJob {
pub id: String,
pub name: String,
/// JSON-serialized `Schedule` stored as TEXT in SQLite.
pub schedule: Schedule,
pub prompt: String,
pub channel: String,
pub chat_id: String,
pub model: Option<String>,
pub enabled: bool,
pub delete_after_run: bool,
pub next_run_at: i64,
pub last_run_at: Option<i64>,
pub last_status: Option<String>,
pub last_error: Option<String>,
pub created_at: i64,
pub updated_at: i64,
}
/// A single execution record for a job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobRun {
pub id: i64,
pub job_id: String,
pub started_at: i64,
pub finished_at: i64,
pub status: String,
pub output: Option<String>,
pub error: Option<String>,
pub duration_ms: i64,
}
impl crate::storage::Storage {
/// Insert a new scheduled job.
pub async fn add_scheduled_job(&self, job: &ScheduledJob) -> anyhow::Result<()> {
let schedule_json = serde_json::to_string(&job.schedule)?;
sqlx::query(
r#"
INSERT INTO scheduled_jobs
(id, name, schedule, prompt, channel, chat_id, model,
enabled, delete_after_run, next_run_at, last_run_at,
last_status, last_error, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&job.id)
.bind(&job.name)
.bind(&schedule_json)
.bind(&job.prompt)
.bind(&job.channel)
.bind(&job.chat_id)
.bind(&job.model)
.bind(job.enabled as i32)
.bind(job.delete_after_run as i32)
.bind(job.next_run_at)
.bind(job.last_run_at)
.bind(&job.last_status)
.bind(&job.last_error)
.bind(job.created_at)
.bind(job.updated_at)
.execute(self.pool())
.await?;
Ok(())
}
/// Fetch a single scheduled job by ID.
pub async fn get_scheduled_job(&self, id: &str) -> anyhow::Result<ScheduledJob> {
let row = sqlx::query("SELECT * FROM scheduled_jobs WHERE id = ?")
.bind(id)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| anyhow::anyhow!("job not found: {id}"))?;
row_to_job(&row)
}
/// List all scheduled jobs, ordered by next_run_at ascending.
pub async fn list_scheduled_jobs(&self) -> anyhow::Result<Vec<ScheduledJob>> {
let rows = sqlx::query("SELECT * FROM scheduled_jobs ORDER BY next_run_at ASC")
.fetch_all(self.pool())
.await?;
rows.iter().map(row_to_job).collect()
}
/// Delete a scheduled job (cascades to job_runs).
pub async fn remove_scheduled_job(&self, id: &str) -> anyhow::Result<()> {
sqlx::query("DELETE FROM scheduled_jobs WHERE id = ?")
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
/// Enable or disable a scheduled job.
pub async fn set_scheduled_job_enabled(&self, id: &str, enabled: bool) -> anyhow::Result<()> {
sqlx::query("UPDATE scheduled_jobs SET enabled = ?, updated_at = ? WHERE id = ?")
.bind(enabled as i32)
.bind(now_ms())
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
/// Update selective fields on a scheduled job.
pub async fn update_scheduled_job(
&self,
id: &str,
prompt: Option<String>,
schedule: Option<Schedule>,
channel: Option<String>,
chat_id: Option<String>,
model: Option<String>,
) -> anyhow::Result<()> {
let now = now_ms();
if let Some(p) = prompt {
sqlx::query("UPDATE scheduled_jobs SET prompt = ?, updated_at = ? WHERE id = ?")
.bind(&p)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
}
if let Some(s) = schedule {
let json = serde_json::to_string(&s)?;
sqlx::query("UPDATE scheduled_jobs SET schedule = ?, updated_at = ? WHERE id = ?")
.bind(&json)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
}
if let Some(c) = channel {
sqlx::query("UPDATE scheduled_jobs SET channel = ?, updated_at = ? WHERE id = ?")
.bind(&c)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
}
if let Some(c) = chat_id {
sqlx::query("UPDATE scheduled_jobs SET chat_id = ?, updated_at = ? WHERE id = ?")
.bind(&c)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
}
if let Some(m) = model {
sqlx::query("UPDATE scheduled_jobs SET model = ?, updated_at = ? WHERE id = ?")
.bind(&m)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
}
Ok(())
}
/// Update next_run_at and last_run_at for a job.
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> {
let now = now_ms();
sqlx::query(
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
)
.bind(next_run_at)
.bind(now)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
/// Set last_run_at for a job (used when starting execution).
pub async fn touch_scheduled_job_last_run(&self, id: &str, at: i64) -> anyhow::Result<()> {
sqlx::query("UPDATE scheduled_jobs SET last_run_at = ?, updated_at = ? WHERE id = ?")
.bind(at)
.bind(at)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
/// Set last_status and last_error after job completion.
pub async fn set_scheduled_job_last_status(
&self,
id: &str,
status: &str,
error: Option<&str>,
) -> anyhow::Result<()> {
let now = now_ms();
sqlx::query(
"UPDATE scheduled_jobs SET last_status = ?, last_error = ?, updated_at = ? WHERE id = ?",
)
.bind(status)
.bind(error)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
/// Fetch enabled jobs whose next_run_at <= now, up to `limit`.
pub async fn due_scheduled_jobs(
&self,
now: i64,
limit: usize,
) -> anyhow::Result<Vec<ScheduledJob>> {
let rows = sqlx::query(
"SELECT * FROM scheduled_jobs WHERE enabled = 1 AND next_run_at <= ? ORDER BY next_run_at ASC LIMIT ?",
)
.bind(now)
.bind(limit as i64)
.fetch_all(self.pool())
.await?;
rows.iter().map(row_to_job).collect()
}
/// Record a job execution run.
pub async fn record_scheduled_job_run(&self, run: &JobRun) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO job_runs (job_id, started_at, finished_at, status, output, error, duration_ms)
VALUES (?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&run.job_id)
.bind(run.started_at)
.bind(run.finished_at)
.bind(&run.status)
.bind(&run.output)
.bind(&run.error)
.bind(run.duration_ms)
.execute(self.pool())
.await?;
Ok(())
}
/// List recent runs for a job, newest first.
pub async fn list_scheduled_job_runs(
&self,
job_id: &str,
limit: usize,
) -> anyhow::Result<Vec<JobRun>> {
let rows = sqlx::query(
"SELECT * FROM job_runs WHERE job_id = ? ORDER BY finished_at DESC LIMIT ?",
)
.bind(job_id)
.bind(limit as i64)
.fetch_all(self.pool())
.await?;
rows.iter()
.map(|r| {
Ok(JobRun {
id: r.try_get("id")?,
job_id: r.try_get("job_id")?,
started_at: r.try_get("started_at")?,
finished_at: r.try_get("finished_at")?,
status: r.try_get("status")?,
output: r.try_get("output")?,
error: r.try_get("error")?,
duration_ms: r.try_get("duration_ms")?,
})
})
.collect()
}
/// Delete disabled jobs whose updated_at is before `before`.
pub async fn cleanup_disabled_scheduled_jobs(&self, before: i64) -> anyhow::Result<()> {
sqlx::query("DELETE FROM scheduled_jobs WHERE enabled = 0 AND updated_at < ?")
.bind(before)
.execute(self.pool())
.await?;
Ok(())
}
}
fn now_ms() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as i64
}
fn row_to_job(row: &sqlx::sqlite::SqliteRow) -> anyhow::Result<ScheduledJob> {
let schedule_json: String = row.try_get("schedule")?;
let schedule: Schedule = serde_json::from_str(&schedule_json)?;
Ok(ScheduledJob {
id: row.try_get("id")?,
name: row.try_get("name")?,
schedule,
prompt: row.try_get("prompt")?,
channel: row.try_get("channel")?,
chat_id: row.try_get("chat_id")?,
model: row.try_get("model")?,
enabled: row.try_get::<i32, _>("enabled")? != 0,
delete_after_run: row.try_get::<i32, _>("delete_after_run")? != 0,
next_run_at: row.try_get("next_run_at")?,
last_run_at: row.try_get("last_run_at")?,
last_status: row.try_get("last_status")?,
last_error: row.try_get("last_error")?,
created_at: row.try_get("created_at")?,
updated_at: row.try_get("updated_at")?,
})
}
#[cfg(test)]
mod tests {
use super::ScheduledJob;
use crate::scheduler::Schedule;
use crate::storage::Storage;
use sqlx::SqlitePool;
fn now() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}
async fn setup_storage() -> Storage {
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
let storage = Storage { pool };
Storage::init_scheduler_schema(storage.pool()).await.unwrap();
storage
}
#[tokio::test]
async fn test_init_creates_tables() {
let storage = setup_storage().await;
let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM scheduled_jobs")
.fetch_one(storage.pool())
.await
.unwrap();
assert_eq!(row.0, 0);
}
#[tokio::test]
async fn test_add_and_get_job() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-1".into(),
name: "test job".into(),
schedule: Schedule::Every { every_ms: 3600000 },
prompt: "say hello".into(),
channel: "cli_chat".into(),
chat_id: "conn-1".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t + 3600000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
let got = storage.get_scheduled_job("job-1").await.unwrap();
assert_eq!(got.id, "job-1");
assert_eq!(got.name, "test job");
assert_eq!(got.prompt, "say hello");
}
#[tokio::test]
async fn test_list_jobs() {
let storage = setup_storage().await;
let t = now();
for i in 0..3 {
let job = ScheduledJob {
id: format!("job-{}", i),
name: format!("job {}", i),
schedule: Schedule::Every { every_ms: 3600000 },
prompt: "ping".into(),
channel: "cli_chat".into(),
chat_id: "conn-1".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t + 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
}
let jobs = storage.list_scheduled_jobs().await.unwrap();
assert_eq!(jobs.len(), 3);
}
#[tokio::test]
async fn test_remove_job() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-rm".into(),
name: "remove me".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
storage.remove_scheduled_job("job-rm").await.unwrap();
let result = storage.get_scheduled_job("job-rm").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_set_enabled() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-toggle".into(),
name: "toggle".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap();
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
assert!(!got.enabled);
}
#[tokio::test]
async fn test_due_jobs_only_returns_enabled_and_overdue() {
let storage = setup_storage().await;
let t = now();
let jobs = vec![
ScheduledJob {
id: "due".into(), name: "due".into(),
schedule: Schedule::At { at: t }, prompt: "1".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t - 1000, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
},
ScheduledJob {
id: "future".into(), name: "future".into(),
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t + 99999999, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
},
ScheduledJob {
id: "disabled-due".into(), name: "disabled due".into(),
schedule: Schedule::At { at: t }, prompt: "3".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: false, delete_after_run: false,
next_run_at: t - 1000, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
},
];
for j in &jobs {
storage.add_scheduled_job(j).await.unwrap();
}
let due = storage.due_scheduled_jobs(t, 10).await.unwrap();
assert_eq!(due.len(), 1);
assert_eq!(due[0].id, "due");
}
#[tokio::test]
async fn test_record_run_and_list_runs() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-run".into(), name: "run test".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
let run = super::JobRun {
id: 0, job_id: "job-run".into(),
started_at: t, finished_at: t + 500,
status: "ok".into(), output: Some("hello".into()),
error: None, duration_ms: 500,
};
storage.record_scheduled_job_run(&run).await.unwrap();
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap();
assert_eq!(runs.len(), 1);
assert_eq!(runs[0].status, "ok");
assert_eq!(runs[0].output.as_deref(), Some("hello"));
}
#[tokio::test]
async fn test_update_job() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-update".into(), name: "old name".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "old prompt".into(), channel: "feishu".into(),
chat_id: "oc_1".into(), model: None,
enabled: true, delete_after_run: false,
next_run_at: t, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
storage.update_scheduled_job(
"job-update",
Some("new prompt".into()),
Some(Schedule::Every { every_ms: 60000 }),
None, None, None,
).await.unwrap();
let got = storage.get_scheduled_job("job-update").await.unwrap();
assert_eq!(got.prompt, "new prompt");
}
}

View File

@ -1,343 +0,0 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::storage::Storage;
use crate::tools::traits::{Tool, ToolResult};
pub struct ChatManagerTool {
storage: Arc<Storage>,
available_channels: Vec<String>,
}
impl ChatManagerTool {
pub fn new(storage: Arc<Storage>, available_channels: Vec<String>) -> Self {
Self {
storage,
available_channels,
}
}
}
#[async_trait]
impl Tool for ChatManagerTool {
fn name(&self) -> &str {
"chat_manager"
}
fn description(&self) -> &str {
"聊天管理工具。可以列出当前活跃的 session、可用的 channel、以及查看指定 session 的最近消息内容。\
action : list_sessions (), list_channels (), list_messages ()"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["list_sessions", "list_channels", "list_messages"],
"description": "操作类型: list_sessions 列出最近活跃会话, list_channels 列出可用渠道, list_messages 查看指定会话的最近消息"
},
"session_id": {
"type": "string",
"description": "会话 ID格式 channel:chat_id:dialog_id仅在 action 为 list_messages 时必填"
},
"count": {
"type": "integer",
"description": "获取最近消息的数量,仅在 action 为 list_messages 时有效,默认 20"
}
},
"required": ["action"]
})
}
fn read_only(&self) -> bool {
true
}
fn concurrency_safe(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = args["action"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?;
match action {
"list_channels" => self.list_channels().await,
"list_sessions" => self.list_sessions().await,
"list_messages" => self.list_messages(&args).await,
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action: {}. Supported: list_sessions, list_channels, list_messages",
action
)),
}),
}
}
}
impl ChatManagerTool {
async fn list_channels(&self) -> anyhow::Result<ToolResult> {
let channels = self.available_channels.join(", ");
Ok(ToolResult {
success: true,
output: format!("可用渠道 ({}): {}", self.available_channels.len(), channels),
error: None,
})
}
async fn list_sessions(&self) -> anyhow::Result<ToolResult> {
let sessions = self
.storage
.list_all_active_sessions(20)
.await
.map_err(|e| anyhow::anyhow!("Failed to list sessions: {}", e))?;
if sessions.is_empty() {
return Ok(ToolResult {
success: true,
output: "当前没有活跃的会话".to_string(),
error: None,
});
}
let now_ms = chrono::Utc::now().timestamp_millis();
let mut output = format!("活跃会话 (共 {} 个):\n", sessions.len());
for s in &sessions {
let ago = format_duration_ago(now_ms - s.last_active_at);
output.push_str(&format!(
"- {}\n title={} | channel={} chat_id={} | {}条消息 | 最后活动: {}前\n",
s.id, s.title, s.channel, s.chat_id, s.message_count, ago
));
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
async fn list_messages(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let session_id = args["session_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: session_id"))?;
let count = args["count"].as_i64().unwrap_or(20).clamp(1, 100);
let session = self
.storage
.get_session(session_id)
.await
.map_err(|e| anyhow::anyhow!("Session not found: {}", e))?;
let messages = self
.storage
.list_recent_messages(session_id, count)
.await
.map_err(|e| anyhow::anyhow!("Failed to load messages: {}", e))?;
let mut output = format!(
"会话: {} ({})\n--- 最近 {} 条消息 (共 {} 条) ---\n",
session_id, session.title, messages.len(), session.message_count
);
if messages.is_empty() {
output.push_str("(暂无消息)\n");
} else {
for m in &messages {
let time = format_timestamp(m.created_at);
let role_tag = match m.role.as_str() {
"user" => "user ",
"assistant" => "assistant",
"tool" => "tool ",
"system" => "system ",
other => other,
};
let preview = truncate_content(&m.content, 200);
output.push_str(&format!(
"[{}] {} | {} | {}\n",
m.seq, time, role_tag, preview
));
}
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}
fn format_duration_ago(millis: i64) -> String {
let secs = millis / 1000;
if secs < 60 {
format!("{}", secs)
} else if secs < 3600 {
format!("{}分钟", secs / 60)
} else if secs < 86400 {
format!("{}小时", secs / 3600)
} else {
format!("{}", secs / 86400)
}
}
fn format_timestamp(ms: i64) -> String {
if let Some(dt) = chrono::DateTime::from_timestamp_millis(ms) {
dt.format("%m-%d %H:%M").to_string()
} else {
ms.to_string()
}
}
fn truncate_content(content: &str, max_len: usize) -> String {
let content = content.replace('\n', " ");
if content.chars().count() > max_len {
format!("{}...", content.chars().take(max_len).collect::<String>())
} else {
content
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
async fn create_test_storage() -> (Arc<Storage>, TempDir) {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("test.db");
let storage = Storage::new(&db_path).await.unwrap();
(Arc::new(storage), dir)
}
#[tokio::test]
async fn test_list_channels() {
let (storage, _dir) = create_test_storage().await;
let tool = ChatManagerTool::new(storage, vec!["cli_chat".into(), "feishu".into()]);
let result = tool
.execute(json!({ "action": "list_channels" }))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("cli_chat"));
assert!(result.output.contains("feishu"));
}
#[tokio::test]
async fn test_list_sessions_empty() {
let (storage, _dir) = create_test_storage().await;
let tool = ChatManagerTool::new(storage, vec![]);
let result = tool
.execute(json!({ "action": "list_sessions" }))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("没有"));
}
#[tokio::test]
async fn test_list_sessions_with_data() {
let (storage, _dir) = create_test_storage().await;
let now = chrono::Utc::now().timestamp_millis();
for i in 0..3 {
let meta = crate::storage::session::SessionMeta {
id: format!("cli_chat:sid{}:dialog{}", i, i),
channel: "cli_chat".to_string(),
chat_id: format!("sid{}", i),
dialog_id: format!("dialog{}", i),
title: format!("会话{}", i),
created_at: now - i * 3600_000,
last_active_at: now - i * 3600_000,
message_count: i * 5,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
}
let tool = ChatManagerTool::new(storage, vec![]);
let result = tool
.execute(json!({ "action": "list_sessions" }))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("会话0"));
assert!(result.output.contains("会话1"));
assert!(result.output.contains("会话2"));
}
#[tokio::test]
async fn test_list_messages() {
let (storage, _dir) = create_test_storage().await;
let now = chrono::Utc::now().timestamp_millis();
let session_id = "cli_chat:sid0:dialog0";
let meta = crate::storage::session::SessionMeta {
id: session_id.to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid0".to_string(),
dialog_id: "dialog0".to_string(),
title: "测试会话".to_string(),
created_at: now,
last_active_at: now,
message_count: 3,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
for i in 0..3 {
let msg = crate::storage::message::MessageMeta {
id: format!("msg{}", i),
session_id: session_id.to_string(),
seq: i as i64 + 1,
role: if i == 0 { "user".to_string() } else { "assistant".to_string() },
content: format!("消息内容 {}", i),
media_refs: None,
tool_call_id: None,
tool_name: None,
tool_calls: None,
source: None,
created_at: now + i * 1000,
};
storage.append_message(session_id, &msg).await.unwrap();
}
let tool = ChatManagerTool::new(storage, vec![]);
let result = tool
.execute(json!({ "action": "list_messages", "session_id": session_id }))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("消息内容 0"));
assert!(result.output.contains("消息内容 2"));
assert!(result.output.contains("测试会话"));
}
#[tokio::test]
async fn test_unknown_action() {
let (storage, _dir) = create_test_storage().await;
let tool = ChatManagerTool::new(storage, vec![]);
let result = tool
.execute(json!({ "action": "unknown" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown action"));
}
}

View File

@ -1,800 +0,0 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{json, Value};
use uuid::Uuid;
use crate::scheduler::{next_run_for_schedule, Schedule};
use crate::storage::{ScheduledJob, Storage};
use crate::tools::traits::{Tool, ToolResult};
fn now_ms() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as i64
}
pub struct CronAddTool {
storage: Arc<Storage>,
valid_channels: Vec<String>,
}
impl CronAddTool {
pub fn new(storage: Arc<Storage>, valid_channels: Vec<String>) -> Self {
Self {
storage,
valid_channels,
}
}
}
#[async_trait]
impl Tool for CronAddTool {
fn name(&self) -> &str {
"cron_add"
}
fn description(&self) -> &str {
"Create a new scheduled task (cron job). The task will execute an AI prompt on a schedule \
and deliver the result to the specified channel/chat. \
Important: the execution environment is a fresh session with no access to your current \
conversation history. The prompt parameter MUST include all necessary context: \
what to do, the target audience, required output format, and any background information. \
Schedule formats: \
- 'every': {\"type\":\"every\",\"every_ms\":3600000} for every hour, \
- 'at': {\"type\":\"at\",\"at\":<unix_timestamp_ms>} for one-shot, \
- 'cron': {\"type\":\"cron\",\"expr\":\"0 0 9 * * *\"} for cron expressions (6-field: sec min hour dom month dow)."
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"schedule": {
"type": "object",
"description": "Schedule definition. One of: {\"type\":\"every\",\"every_ms\":<ms>}, {\"type\":\"at\",\"at\":<unix_ms>}, or {\"type\":\"cron\",\"expr\":\"<cron_expr>\",\"tz\":\"<tz>\"}",
"required": ["type"]
},
"prompt": {
"type": "string",
"description": "The AI prompt to execute on each trigger"
},
"channel": {
"type": "string",
"description": "Target channel for delivering results (e.g., 'feishu', 'cli_chat')"
},
"chat_id": {
"type": "string",
"description": "Target chat ID within the channel"
},
"name": {
"type": "string",
"description": "Human-readable name for the job (optional, defaults to truncated prompt)"
},
"model": {
"type": "string",
"description": "Optional model override for this job"
}
},
"required": ["schedule", "prompt", "channel", "chat_id"]
})
}
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let schedule_json = args
.get("schedule")
.ok_or_else(|| anyhow::anyhow!("missing 'schedule'"))?;
let schedule: Schedule = serde_json::from_value(schedule_json.clone())
.map_err(|e| anyhow::anyhow!("invalid schedule: {}", e))?;
let prompt = args
.get("prompt")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if prompt.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("prompt is required".into()),
});
}
let channel = args
.get("channel")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if !self.valid_channels.contains(&channel) {
return Ok(ToolResult {
success: false,
output: format!(
"Unknown channel '{}'. Available: {}",
channel,
self.valid_channels.join(", ")
),
error: Some(format!("Unknown channel: {}", channel)),
});
}
let chat_id = args
.get("chat_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if chat_id.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("chat_id is required".into()),
});
}
let name = args
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_else(|| {
// char-boundary-safe truncation to 50 bytes
let limit = 50;
if prompt.len() <= limit {
prompt.as_str()
} else {
let mut end = limit;
while !prompt.is_char_boundary(end) {
end -= 1;
}
&prompt[..end]
}
})
.to_string();
let model = args
.get("model")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let now = now_ms();
let next_run_at = next_run_for_schedule(&schedule, now)
.ok_or_else(|| anyhow::anyhow!("could not compute next run time from schedule"))?;
let id = Uuid::new_v4().to_string()[..8].to_string();
let job = ScheduledJob {
id: id.clone(),
name: name.clone(),
schedule,
prompt,
channel,
chat_id,
model,
enabled: true,
delete_after_run: false,
next_run_at,
last_run_at: None,
last_status: None,
last_error: None,
created_at: now,
updated_at: now,
};
self.storage.add_scheduled_job(&job).await?;
Ok(ToolResult {
success: true,
output: format!(
"Scheduled job created: id={}, name=\"{}\", next_run_at={}",
id, name, next_run_at
),
error: None,
})
}
}
// ── CronListTool ─────────────────────────────────────────────────────────────
pub struct CronListTool {
storage: Arc<Storage>,
}
impl CronListTool {
pub fn new(storage: Arc<Storage>) -> Self {
Self { storage }
}
}
#[async_trait]
impl Tool for CronListTool {
fn name(&self) -> &str {
"cron_list"
}
fn description(&self) -> &str {
"List all scheduled tasks (cron jobs) with their status and next run time."
}
fn read_only(&self) -> bool {
true
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"status": {
"type": "string",
"enum": ["all", "enabled", "disabled"],
"description": "Filter by job status (default: all)"
}
}
})
}
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let filter = args
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("all");
let jobs = self.storage.list_scheduled_jobs().await?;
let filtered: Vec<&ScheduledJob> = match filter {
"enabled" => jobs.iter().filter(|j| j.enabled).collect(),
"disabled" => jobs.iter().filter(|j| !j.enabled).collect(),
_ => jobs.iter().collect(),
};
if filtered.is_empty() {
return Ok(ToolResult {
success: true,
output: "No scheduled jobs found.".into(),
error: None,
});
}
let mut lines = Vec::new();
for j in &filtered {
let status = if j.enabled { "enabled" } else { "disabled" };
let last = match (&j.last_status, &j.last_error) {
(Some(s), _) if s == "ok" => " last:ok".to_string(),
(Some(_), Some(e)) => format!(" last:err({})", &e[..e.len().min(40)]),
_ => String::new(),
};
let model = j.model.as_deref().unwrap_or("default");
lines.push(format!(
"[{}] id={} name=\"{}\" channel={} chat={} model={} next={}{}",
status, j.id, j.name, j.channel, j.chat_id, model, j.next_run_at, last
));
}
Ok(ToolResult {
success: true,
output: lines.join("\n"),
error: None,
})
}
}
// ── CronRemoveTool ───────────────────────────────────────────────────────────
pub struct CronRemoveTool {
storage: Arc<Storage>,
}
impl CronRemoveTool {
pub fn new(storage: Arc<Storage>) -> Self {
Self { storage }
}
}
#[async_trait]
impl Tool for CronRemoveTool {
fn name(&self) -> &str {
"cron_remove"
}
fn description(&self) -> &str {
"Delete a scheduled task permanently by its job ID. Use cron_list first to find the ID."
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"job_id": {
"type": "string",
"description": "The ID of the job to delete"
}
},
"required": ["job_id"]
})
}
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let job_id = args
.get("job_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if job_id.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("job_id is required".into()),
});
}
match self.storage.get_scheduled_job(&job_id).await {
Ok(_) => {}
Err(_) => {
return Ok(ToolResult {
success: false,
output: format!("Job {} not found.", job_id),
error: Some("not found".into()),
});
}
}
self.storage.remove_scheduled_job(&job_id).await?;
Ok(ToolResult {
success: true,
output: format!("Job {} deleted.", job_id),
error: None,
})
}
}
// ── CronEnableTool ───────────────────────────────────────────────────────────
pub struct CronEnableTool {
storage: Arc<Storage>,
}
impl CronEnableTool {
pub fn new(storage: Arc<Storage>) -> Self {
Self { storage }
}
}
#[async_trait]
impl Tool for CronEnableTool {
fn name(&self) -> &str {
"cron_enable"
}
fn description(&self) -> &str {
"Enable a disabled scheduled task by its job ID."
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"job_id": {
"type": "string",
"description": "The ID of the job to enable"
}
},
"required": ["job_id"]
})
}
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let job_id = args
.get("job_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if job_id.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("job_id is required".into()),
});
}
let job = self
.storage
.get_scheduled_job(&job_id)
.await
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
let next = next_run_for_schedule(&job.schedule, now_ms());
self.storage.set_scheduled_job_enabled(&job_id, true).await?;
if let Some(n) = next {
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
}
Ok(ToolResult {
success: true,
output: format!("Job {} enabled.", job_id),
error: None,
})
}
}
// ── CronDisableTool ──────────────────────────────────────────────────────────
pub struct CronDisableTool {
storage: Arc<Storage>,
}
impl CronDisableTool {
pub fn new(storage: Arc<Storage>) -> Self {
Self { storage }
}
}
#[async_trait]
impl Tool for CronDisableTool {
fn name(&self) -> &str {
"cron_disable"
}
fn description(&self) -> &str {
"Disable a scheduled task by its job ID without deleting it."
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"job_id": {
"type": "string",
"description": "The ID of the job to disable"
}
},
"required": ["job_id"]
})
}
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let job_id = args
.get("job_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if job_id.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("job_id is required".into()),
});
}
let _ = self
.storage
.get_scheduled_job(&job_id)
.await
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
self.storage.set_scheduled_job_enabled(&job_id, false).await?;
Ok(ToolResult {
success: true,
output: format!("Job {} disabled.", job_id),
error: None,
})
}
}
// ── CronUpdateTool ───────────────────────────────────────────────────────────
pub struct CronUpdateTool {
storage: Arc<Storage>,
}
impl CronUpdateTool {
pub fn new(storage: Arc<Storage>) -> Self {
Self { storage }
}
}
#[async_trait]
impl Tool for CronUpdateTool {
fn name(&self) -> &str {
"cron_update"
}
fn description(&self) -> &str {
"Update fields of an existing scheduled task. Only specified fields are changed."
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"job_id": {
"type": "string",
"description": "The ID of the job to update"
},
"prompt": {
"type": "string",
"description": "New AI prompt"
},
"schedule": {
"type": "object",
"description": "New schedule definition"
},
"channel": {
"type": "string",
"description": "New target channel"
},
"chat_id": {
"type": "string",
"description": "New target chat ID"
},
"model": {
"type": "string",
"description": "New model override"
}
},
"required": ["job_id"]
})
}
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let job_id = args
.get("job_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if job_id.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("job_id is required".into()),
});
}
let _ = self
.storage
.get_scheduled_job(&job_id)
.await
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
let prompt = args
.get("prompt")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let schedule: Option<Schedule> = match args.get("schedule") {
Some(s) => Some(
serde_json::from_value(s.clone())
.map_err(|e| anyhow::anyhow!("invalid schedule: {}", e))?,
),
None => None,
};
let channel = args
.get("channel")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let chat_id = args
.get("chat_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let model = args
.get("model")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
self.storage
.update_scheduled_job(&job_id, prompt, schedule, channel, chat_id, model)
.await?;
if args.get("schedule").is_some() {
let job = self.storage.get_scheduled_job(&job_id).await?;
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
self.storage.set_scheduled_job_next_run(&job_id, next).await?;
}
}
Ok(ToolResult {
success: true,
output: format!("Job {} updated.", job_id),
error: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::scheduler::Schedule;
use crate::storage::{ScheduledJob, Storage};
use serde_json::json;
use sqlx::SqlitePool;
async fn setup_storage() -> Arc<Storage> {
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
Storage::init_scheduler_schema(&pool).await.unwrap();
Arc::new(Storage { pool })
}
fn now() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}
#[tokio::test]
async fn test_cron_add_tool() {
let storage = setup_storage().await;
let tool = CronAddTool::new(storage.clone(), vec!["cli_chat".to_string()]);
let result = tool
.execute(json!({
"schedule": {"type": "every", "every_ms": 3600000},
"prompt": "report status",
"channel": "cli_chat",
"chat_id": "test-chat-1",
"name": "hourly report"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("hourly report"));
let jobs = storage.list_scheduled_jobs().await.unwrap();
assert_eq!(jobs.len(), 1);
assert_eq!(jobs[0].name, "hourly report");
}
#[tokio::test]
async fn test_cron_add_invalid_channel() {
let storage = setup_storage().await;
let tool = CronAddTool::new(storage.clone(), vec!["cli_chat".to_string()]);
let result = tool
.execute(json!({
"schedule": {"type": "every", "every_ms": 3600000},
"prompt": "test",
"channel": "nonexistent",
"chat_id": "x",
"name": "test"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("Unknown channel"));
}
#[tokio::test]
async fn test_cron_list_tool() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: uuid::Uuid::new_v4().to_string(),
name: "list-test".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t + 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
let tool = CronListTool::new(storage.clone());
let result = tool.execute(json!({})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("list-test"));
}
#[tokio::test]
async fn test_cron_remove_tool() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-rm-tool".into(),
name: "rm me".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
let tool = CronRemoveTool::new(storage.clone());
let result = tool
.execute(json!({"job_id": "job-rm-tool"}))
.await
.unwrap();
assert!(result.success);
assert!(storage.get_scheduled_job("job-rm-tool").await.is_err());
}
#[tokio::test]
async fn test_cron_enable_disable_tools() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-toggle-tool".into(),
name: "toggle".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
let disable_tool = CronDisableTool::new(storage.clone());
let result = disable_tool
.execute(json!({"job_id": "job-toggle-tool"}))
.await
.unwrap();
assert!(result.success);
let got = storage.get_scheduled_job("job-toggle-tool").await.unwrap();
assert!(!got.enabled);
let enable_tool = CronEnableTool::new(storage.clone());
let result = enable_tool
.execute(json!({"job_id": "job-toggle-tool"}))
.await
.unwrap();
assert!(result.success);
let got = storage.get_scheduled_job("job-toggle-tool").await.unwrap();
assert!(got.enabled);
}
#[tokio::test]
async fn test_cron_update_tool() {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-update-tool".into(),
name: "old".into(),
schedule: Schedule::Every {
every_ms: 3600000,
},
prompt: "old prompt".into(),
channel: "feishu".into(),
chat_id: "oc_1".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t + 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
let tool = CronUpdateTool::new(storage.clone());
let result = tool
.execute(json!({
"job_id": "job-update-tool",
"prompt": "new prompt",
"schedule": {"type": "every", "every_ms": 60000}
}))
.await
.unwrap();
assert!(result.success);
let got = storage.get_scheduled_job("job-update-tool").await.unwrap();
assert_eq!(got.prompt, "new prompt");
}
}

View File

@ -1,7 +1,5 @@
pub mod bash; pub mod bash;
pub mod calculator; pub mod calculator;
pub mod chat_manager;
pub mod cron;
pub mod file_edit; pub mod file_edit;
pub mod file_read; pub mod file_read;
pub mod file_write; pub mod file_write;
@ -15,7 +13,6 @@ pub mod web_fetch;
pub use bash::BashTool; pub use bash::BashTool;
pub use calculator::CalculatorTool; pub use calculator::CalculatorTool;
pub use chat_manager::ChatManagerTool;
pub use file_edit::FileEditTool; pub use file_edit::FileEditTool;
pub use file_read::FileReadTool; pub use file_read::FileReadTool;
pub use file_write::FileWriteTool; pub use file_write::FileWriteTool;

View File

@ -1,61 +0,0 @@
/// Integration tests for the scheduled tasks (cron) system.
/// Run with: cargo test --test test_scheduler
use serde_json::json;
/// Verify that Schedule types (de)serialize correctly.
#[tokio::test]
async fn test_scheduler_types_roundtrip() {
use picobot::scheduler::Schedule;
let s1 = Schedule::Every { every_ms: 3600000 };
let json = serde_json::to_string(&s1).unwrap();
let s2: Schedule = serde_json::from_str(&json).unwrap();
match s2 {
Schedule::Every { every_ms } => assert_eq!(every_ms, 3600000),
_ => panic!("expected Every"),
}
let s1 = Schedule::At { at: 1000000 };
let json = serde_json::to_string(&s1).unwrap();
let s2: Schedule = serde_json::from_str(&json).unwrap();
match s2 {
Schedule::At { at } => assert_eq!(at, 1000000),
_ => panic!("expected At"),
}
let s1 = Schedule::Cron {
expr: "0 0 9 * * *".into(),
tz: None,
};
let json = serde_json::to_string(&s1).unwrap();
let s2: Schedule = serde_json::from_str(&json).unwrap();
match s2 {
Schedule::Cron { expr, tz } => {
assert_eq!(expr, "0 0 9 * * *");
assert!(tz.is_none());
}
_ => panic!("expected Cron"),
}
}
/// Verify that next_run_for_schedule produces valid future timestamps.
#[test]
fn test_next_run_always_future() {
use picobot::scheduler::{next_run_for_schedule, Schedule};
let now = 1700000000000_i64;
let schedules = vec![
Schedule::Every { every_ms: 60000 },
Schedule::Cron {
expr: "0 0 9 * * *".into(),
tz: None,
},
];
for s in &schedules {
let next = next_run_for_schedule(s, now);
assert!(next.is_some(), "expected next run for {:?}", s);
assert!(next.unwrap() > now, "next run should be after now for {:?}", s);
}
}