Compare commits
16 Commits
98eb7bea3d
...
2fe953cdad
| Author | SHA1 | Date | |
|---|---|---|---|
| 2fe953cdad | |||
| 61d2fe9ef0 | |||
| db609342f7 | |||
| 62f4326131 | |||
| 0056bfbd23 | |||
| 5746668e36 | |||
| f7b0a33e66 | |||
| 205b814933 | |||
| 3a94b9718f | |||
| 0757638c6f | |||
| 8415e85026 | |||
| eccae20a0a | |||
| 4e5f412c2d | |||
| 75b8f7b8a5 | |||
| 46527edb7b | |||
| 0e146a8f2a |
@ -26,6 +26,8 @@ anyhow = "1.0"
|
|||||||
mime_guess = "2.0"
|
mime_guess = "2.0"
|
||||||
base64 = "0.22"
|
base64 = "0.22"
|
||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
|
cron = "0.15"
|
||||||
|
chrono-tz = "0.10"
|
||||||
meval = "0.2"
|
meval = "0.2"
|
||||||
ratatui = "0.27"
|
ratatui = "0.27"
|
||||||
crossterm = { version = "0.28", features = ["event-stream"] }
|
crossterm = { version = "0.28", features = ["event-stream"] }
|
||||||
|
|||||||
2356
docs/superpowers/plans/2026-05-04-scheduled-tasks.md
Normal file
2356
docs/superpowers/plans/2026-05-04-scheduled-tasks.md
Normal file
File diff suppressed because it is too large
Load Diff
@ -226,6 +226,7 @@ pub struct AgentLoop {
|
|||||||
max_iterations: usize,
|
max_iterations: usize,
|
||||||
workspace_dir: PathBuf,
|
workspace_dir: PathBuf,
|
||||||
model_name: String,
|
model_name: String,
|
||||||
|
notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -247,6 +248,7 @@ impl AgentLoop {
|
|||||||
provider: Arc::from(provider),
|
provider: Arc::from(provider),
|
||||||
tools: Arc::new(ToolRegistry::new()),
|
tools: Arc::new(ToolRegistry::new()),
|
||||||
observer: None,
|
observer: None,
|
||||||
|
notify_tx: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
@ -265,6 +267,7 @@ impl AgentLoop {
|
|||||||
provider: Arc::from(provider),
|
provider: Arc::from(provider),
|
||||||
tools,
|
tools,
|
||||||
observer: None,
|
observer: None,
|
||||||
|
notify_tx: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
@ -277,6 +280,7 @@ impl AgentLoop {
|
|||||||
provider,
|
provider,
|
||||||
tools: Arc::new(ToolRegistry::new()),
|
tools: Arc::new(ToolRegistry::new()),
|
||||||
observer: None,
|
observer: None,
|
||||||
|
notify_tx: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
@ -295,6 +299,7 @@ impl AgentLoop {
|
|||||||
provider,
|
provider,
|
||||||
tools,
|
tools,
|
||||||
observer: None,
|
observer: None,
|
||||||
|
notify_tx: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
@ -313,6 +318,11 @@ impl AgentLoop {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_notify(mut self, tx: tokio::sync::mpsc::UnboundedSender<String>) -> Self {
|
||||||
|
self.notify_tx = Some(tx);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
||||||
&self.tools
|
&self.tools
|
||||||
}
|
}
|
||||||
@ -331,7 +341,7 @@ impl AgentLoop {
|
|||||||
// Build and inject system prompt if not present
|
// Build and inject system prompt if not present
|
||||||
let has_system = messages.first().map_or(false, |m| m.role == "system");
|
let has_system = messages.first().map_or(false, |m| m.role == "system");
|
||||||
if !has_system {
|
if !has_system {
|
||||||
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools);
|
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None);
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
||||||
messages.insert(0, ChatMessage::system(system_prompt));
|
messages.insert(0, ChatMessage::system(system_prompt));
|
||||||
@ -390,12 +400,16 @@ impl AgentLoop {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute tool calls — log tool names and args before execution
|
// Execute tool calls — log and notify immediately
|
||||||
{
|
{
|
||||||
let tools_info: Vec<String> = response.tool_calls.iter()
|
let tools_info: Vec<String> = response.tool_calls.iter()
|
||||||
.map(|tc| {
|
.map(|tc| {
|
||||||
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
||||||
format!("{}:{}", tc.name, args)
|
let s = format!("{}:{}", tc.name, args);
|
||||||
|
if let Some(ref tx) = self.notify_tx {
|
||||||
|
let _ = tx.send(format!("调用工具 {}", s));
|
||||||
|
}
|
||||||
|
s
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools");
|
tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools");
|
||||||
|
|||||||
@ -18,6 +18,7 @@ pub struct PromptContext<'a> {
|
|||||||
pub workspace_dir: &'a Path,
|
pub workspace_dir: &'a Path,
|
||||||
pub model_name: &'a str,
|
pub model_name: &'a str,
|
||||||
pub tools: &'a ToolRegistry,
|
pub tools: &'a ToolRegistry,
|
||||||
|
pub session_id: Option<&'a str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for system prompt sections.
|
/// Trait for system prompt sections.
|
||||||
@ -38,7 +39,6 @@ impl SystemPromptBuilder {
|
|||||||
Self {
|
Self {
|
||||||
sections: vec![
|
sections: vec![
|
||||||
Box::new(ToolHonestySection),
|
Box::new(ToolHonestySection),
|
||||||
Box::new(NoToolNarrationSection),
|
|
||||||
Box::new(YourTaskSection),
|
Box::new(YourTaskSection),
|
||||||
Box::new(SafetySection),
|
Box::new(SafetySection),
|
||||||
Box::new(WorkspaceSection),
|
Box::new(WorkspaceSection),
|
||||||
@ -82,30 +82,11 @@ impl PromptSection for ToolHonestySection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||||
"## CRITICAL: Tool Honesty
|
"## 关键规则:工具诚实性
|
||||||
|
|
||||||
- NEVER fabricate, invent, or guess tool results. If a tool returns empty results, say \"No results found.\"
|
- 绝对不要编造、虚构或猜测工具结果。如果工具返回空结果,说\"没有找到结果\"。
|
||||||
- If a tool call fails, report the error - never make up data to fill the gap.
|
- 如果工具调用失败,报告错误——绝不要编造数据来填补空白。
|
||||||
- When unsure whether a tool call succeeded, ask the user rather than guessing."
|
- 当不确定工具调用是否成功时,询问用户而不是猜测。"
|
||||||
.to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Critical rule: never narrate tool usage.
|
|
||||||
pub struct NoToolNarrationSection;
|
|
||||||
|
|
||||||
impl PromptSection for NoToolNarrationSection {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"no_narration"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
|
||||||
"## CRITICAL: No Tool Narration
|
|
||||||
|
|
||||||
NEVER narrate, announce, describe, or explain your tool usage to the user.
|
|
||||||
Do NOT say things like \"Let me check...\", \"I will use bash to...\", \"I'll fetch that for you\", \"Searching now...\", or similar.
|
|
||||||
The user must ONLY see the final answer. Tool calls are invisible infrastructure - never reference them.
|
|
||||||
If you catch yourself starting a sentence about what tool you are about to use or just used, DELETE it and give the answer directly."
|
|
||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -123,7 +104,7 @@ impl PromptSection for ToolsSection {
|
|||||||
return String::new();
|
return String::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut output = String::from("## Tools\n\nYou have access to the following tools:\n\n");
|
let mut output = String::from("## 工具\n\n你可以使用以下工具:\n\n");
|
||||||
for (name, tool) in ctx.tools.iter() {
|
for (name, tool) in ctx.tools.iter() {
|
||||||
let _ = writeln!(output, "- **{}**: {}", name, tool.description());
|
let _ = writeln!(output, "- **{}**: {}", name, tool.description());
|
||||||
}
|
}
|
||||||
@ -140,11 +121,11 @@ impl PromptSection for YourTaskSection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||||
"## Your Task
|
"## 你的任务
|
||||||
|
|
||||||
When the user sends a message, ACT on it. Use the tools to fulfill their request.
|
当用户发送消息时,立即行动。使用工具来完成他们的请求。
|
||||||
Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions.
|
不要:总结此配置、描述你的能力、用元评论回复、或输出逐步指令。
|
||||||
Instead: use tools directly when needed, and give the final answer when done."
|
而是:在需要时直接使用工具,完成后给出最终答案。"
|
||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -158,13 +139,13 @@ impl PromptSection for SafetySection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||||
"## Safety
|
"## 安全规则
|
||||||
|
|
||||||
- Do not exfiltrate private data.
|
- 不要泄露隐私数据。
|
||||||
- Do not run destructive commands without asking.
|
- 未经询问不要执行破坏性命令。
|
||||||
- Do not bypass oversight or approval mechanisms.
|
- 不要绕过监督或审批机制。
|
||||||
- Prefer safe operations over risky ones.
|
- 优先选择安全操作而非风险操作。
|
||||||
- When in doubt, ask before acting externally."
|
- 不确定时,在外部操作前先询问。"
|
||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -184,7 +165,7 @@ impl PromptSection for WorkspaceSection {
|
|||||||
.canonicalize()
|
.canonicalize()
|
||||||
.unwrap_or_else(|_| ctx.workspace_dir.to_path_buf());
|
.unwrap_or_else(|_| ctx.workspace_dir.to_path_buf());
|
||||||
format!(
|
format!(
|
||||||
"## Workspace\n\nWorking directory: `{}`\n\n### File Storage Guidelines\n\n- **Generated files**: Store all generated files (code, documents, artifacts) in the workspace directory or its subdirectories.\n- **Downloaded files**: Save downloaded files to the workspace directory, organized by task.\n- **One task, one folder**: Create a dedicated subfolder for each task or project (e.g., `task_2024_01_01/`).\n- **Temporary files**: If files are only needed during processing and won't be kept, use `/tmp/` or create a temp folder (e.g., `/tmp/picobot_task_xxx/`) instead of cluttering the workspace.\n\n### Working Directory Structure\n\nThe workspace is your home base for this session. Keep it organized by creating subdirectories for different tasks.",
|
"## 工作目录\n\n工作目录:`{}`\n\n### 文件存储规范\n\n- **生成的文件**:将所有生成的文件(代码、文档、制品)存放在工作目录或其子目录中。\n- **下载的文件**:将下载的文件保存到工作目录,按任务整理。\n- **一个任务一个文件夹**:为每个任务或项目创建专用的子文件夹(如 `task_2024_01_01/`)。\n- **临时文件**:如果文件仅在处理期间需要且不保留,使用 `/tmp/` 或创建临时文件夹(如 `/tmp/picobot_task_xxx/`),以免弄乱工作目录。\n\n### 目录结构\n\n工作目录是你在本会话中的操作大本营。通过为不同任务创建子目录来保持整洁。",
|
||||||
abs_path.display()
|
abs_path.display()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -199,7 +180,7 @@ impl PromptSection for UserProfileSection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||||
let mut output = String::from("## User Profile\n\n");
|
let mut output = String::from("## 用户配置\n\n");
|
||||||
|
|
||||||
// Load USER.md from ~/.picobot/USER.md
|
// Load USER.md from ~/.picobot/USER.md
|
||||||
if let Some(user_config_dir) = get_user_config_dir() {
|
if let Some(user_config_dir) = get_user_config_dir() {
|
||||||
@ -227,7 +208,7 @@ impl PromptSection for DateTimeSection {
|
|||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||||
let now = chrono::Local::now();
|
let now = chrono::Local::now();
|
||||||
format!(
|
format!(
|
||||||
"## Current Date & Time\n\n{} ({})",
|
"## 当前日期与时间\n\n{} ({})",
|
||||||
now.format("%Y-%m-%d %H:%M:%S"),
|
now.format("%Y-%m-%d %H:%M:%S"),
|
||||||
now.format("%Z")
|
now.format("%Z")
|
||||||
)
|
)
|
||||||
@ -242,37 +223,43 @@ impl PromptSection for CrossChannelSection {
|
|||||||
"cross_channel"
|
"cross_channel"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
||||||
r#"## 关于跨渠道消息和系统通知
|
let session_line = if let Some(id) = ctx.session_id {
|
||||||
|
format!("当前会话的 ID 是 `{}`。\n", id)
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
当前对话中可能出现带有 `source` 标记的消息,这些消息不是用户直接输入:
|
format!(
|
||||||
|
r#"## 关于会话和跨渠道消息
|
||||||
|
|
||||||
### 系统通知(source.kind = "system_notification")
|
### 会话 ID 格式
|
||||||
来自机器人内部系统(如定时任务、后台任务)的通知。
|
每个会话都有唯一的 session ID,由三部分组成:<channel>:<chat_id>:<dialog_id>
|
||||||
- `system_name`: 发出通知的系统名称
|
- channel: 消息渠道(如 "cli_chat"、"feishu")
|
||||||
- `task_id`: 关联的任务 ID
|
- chat_id: 聊天/群组标识
|
||||||
|
- dialog_id: 对话标识,同一 chat 下可以有多个 dialog
|
||||||
|
|
||||||
### 跨渠道消息(source.kind = "cross_channel")
|
{}### 跨会话消息
|
||||||
来自其他渠道的消息被写入当前对话。
|
对话历史中可能出现带有 `[message from X to Y]` 前缀的 assistant 消息,
|
||||||
- `from_channel`: 来源渠道(如 "feishu")
|
表示此消息由 send_message 工具从别处发送过来。
|
||||||
- `from_user_id`: 来源用户 ID
|
- X: 来源标识,可能是会话 ID、工具名或其他标识字符串;未指定时为 "unknown"
|
||||||
|
- Y: 目标会话的完整 session ID (<channel>:<chat_id>:<dialog_id>)
|
||||||
|
|
||||||
|
收到此类消息时一般不需要主动处理,只需知晓。如果用户问及相关信息,
|
||||||
|
可以尝试从来源处获取更多详情。
|
||||||
|
|
||||||
### send_message 工具
|
### send_message 工具
|
||||||
|
向指定会话发送消息。参数:
|
||||||
|
- target_chat_id: 格式 <channel>:<chat_id> 或 <channel>:<chat_id>:<dialog_id>
|
||||||
|
- content: 消息内容
|
||||||
|
|
||||||
使用 `send_message` 向其他渠道发送消息。参数:
|
### chat_manager 工具
|
||||||
- `target_chat_id`: 目标会话ID,支持两种格式:
|
管理会话和查看消息。参数:
|
||||||
1. `<channel>:<chat_id>` — 发送到该聊天下最新活跃的会话,若没有活跃会话则自动创建
|
- action = "list_sessions" — 列出最近活跃的会话
|
||||||
2. `<channel>:<chat_id>:<dialog_id>` — 发送到指定会话,若会话已过期则自动激活
|
- action = "list_channels" — 列出所有可用渠道
|
||||||
- `content`: 要发送的消息内容
|
- action = "list_messages" — 查看指定 session 的最近消息,需提供 session_id 和 count"#,
|
||||||
- `origin`(可选): 消息来源标识,不填则自动使用当前会话的完整 session_id
|
session_line
|
||||||
|
)
|
||||||
跨渠道消息到达目标会话时,内容前会带有 `[message from X to Y]` 标记,
|
|
||||||
表示该消息的来源和目标。目标会话的 LLM 应将此理解为来自其他渠道/会话的消息。
|
|
||||||
|
|
||||||
### 处理建议
|
|
||||||
- 系统通知:可以提及但不建议以此为由改变对话主题
|
|
||||||
- 跨渠道消息:当用户提及相关事务时可关联这些消息"#
|
|
||||||
.to_string()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,7 +276,7 @@ impl PromptSection for RuntimeSection {
|
|||||||
.map(|h| h.to_string_lossy().to_string())
|
.map(|h| h.to_string_lossy().to_string())
|
||||||
.unwrap_or_else(|_| "unknown".to_string());
|
.unwrap_or_else(|_| "unknown".to_string());
|
||||||
format!(
|
format!(
|
||||||
"## Runtime\n\nHost: {} | OS: {} | Model: {}",
|
"## 运行环境\n\n主机: {} | 操作系统: {} | 模型: {}",
|
||||||
host,
|
host,
|
||||||
std::env::consts::OS,
|
std::env::consts::OS,
|
||||||
ctx.model_name
|
ctx.model_name
|
||||||
@ -321,7 +308,7 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option<St
|
|||||||
.unwrap_or(trimmed)
|
.unwrap_or(trimmed)
|
||||||
.to_string()
|
.to_string()
|
||||||
+ &format!(
|
+ &format!(
|
||||||
"\n\n[... truncated at {} characters - use file_read for full file]",
|
"\n\n[... 已截断至 {} 字符 - 使用 file_read 获取完整文件]",
|
||||||
max_chars
|
max_chars
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@ -334,11 +321,12 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option<St
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Build a complete system prompt with default configuration.
|
/// Build a complete system prompt with default configuration.
|
||||||
pub fn build_system_prompt(workspace_dir: &Path, model_name: &str, tools: &ToolRegistry) -> String {
|
pub fn build_system_prompt(workspace_dir: &Path, model_name: &str, tools: &ToolRegistry, session_id: Option<&str>) -> String {
|
||||||
let ctx = PromptContext {
|
let ctx = PromptContext {
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
tools,
|
tools,
|
||||||
|
session_id,
|
||||||
};
|
};
|
||||||
SystemPromptBuilder::with_defaults().build(&ctx)
|
SystemPromptBuilder::with_defaults().build(&ctx)
|
||||||
}
|
}
|
||||||
@ -357,16 +345,16 @@ mod tests {
|
|||||||
workspace_dir: &temp_dir,
|
workspace_dir: &temp_dir,
|
||||||
model_name: "test-model",
|
model_name: "test-model",
|
||||||
tools: &tools,
|
tools: &tools,
|
||||||
|
session_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
||||||
|
|
||||||
assert!(prompt.contains("## CRITICAL: Tool Honesty"));
|
assert!(prompt.contains("## 关键规则:工具诚实性"));
|
||||||
assert!(prompt.contains("## CRITICAL: No Tool Narration"));
|
assert!(prompt.contains("## 安全规则"));
|
||||||
assert!(prompt.contains("## Safety"));
|
assert!(prompt.contains("## 工作目录"));
|
||||||
assert!(prompt.contains("## Workspace"));
|
assert!(prompt.contains("## 当前日期与时间"));
|
||||||
assert!(prompt.contains("## Current Date & Time"));
|
assert!(prompt.contains("## 运行环境"));
|
||||||
assert!(prompt.contains("## Runtime"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -387,7 +375,7 @@ mod tests {
|
|||||||
let temp_dir = std::env::temp_dir();
|
let temp_dir = std::env::temp_dir();
|
||||||
let tools = ToolRegistry::new();
|
let tools = ToolRegistry::new();
|
||||||
|
|
||||||
let prompt = build_system_prompt(&temp_dir, "test-model", &tools);
|
let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None);
|
||||||
|
|
||||||
assert!(!prompt.is_empty());
|
assert!(!prompt.is_empty());
|
||||||
assert!(prompt.contains("test-model"));
|
assert!(prompt.contains("test-model"));
|
||||||
|
|||||||
@ -195,6 +195,20 @@ impl ChatMessage {
|
|||||||
source: None,
|
source: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn user_with_source(content: impl Into<String>, source: MessageSource) -> Self {
|
||||||
|
Self {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: content.into(),
|
||||||
|
media_refs: Vec::new(),
|
||||||
|
timestamp: current_timestamp(),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
source: Some(source),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|||||||
@ -479,10 +479,16 @@ impl Channel for CliChatChannel {
|
|||||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||||
let clients = self.clients.lock().await.clone();
|
let clients = self.clients.lock().await.clone();
|
||||||
for client in clients {
|
for client in clients {
|
||||||
let outbound = WsOutbound::AssistantResponse {
|
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") {
|
||||||
id: short_id(),
|
WsOutbound::SystemNotification {
|
||||||
content: msg.content.clone(),
|
content: msg.content.clone(),
|
||||||
role: "assistant".to_string(),
|
}
|
||||||
|
} else {
|
||||||
|
WsOutbound::AssistantResponse {
|
||||||
|
id: short_id(),
|
||||||
|
content: msg.content.clone(),
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let _ = client.sender.send(outbound).await;
|
let _ = client.sender.send(outbound).await;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -140,6 +140,43 @@ pub struct GatewayConfig {
|
|||||||
pub cleanup_interval_minutes: Option<u64>,
|
pub cleanup_interval_minutes: Option<u64>,
|
||||||
#[serde(default, rename = "session_db_path")]
|
#[serde(default, rename = "session_db_path")]
|
||||||
pub session_db_path: Option<String>,
|
pub session_db_path: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub scheduler: Option<SchedulerConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct SchedulerConfig {
|
||||||
|
/// Whether the scheduler is enabled
|
||||||
|
#[serde(default = "default_scheduler_enabled")]
|
||||||
|
pub enabled: bool,
|
||||||
|
/// Poll interval in seconds (how often to check for due jobs)
|
||||||
|
#[serde(default = "default_poll_interval_secs")]
|
||||||
|
pub poll_interval_secs: u64,
|
||||||
|
/// Maximum concurrent job executions (currently sequential, reserved for future)
|
||||||
|
#[serde(default = "default_max_concurrent")]
|
||||||
|
pub max_concurrent: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_scheduler_enabled() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_poll_interval_secs() -> u64 {
|
||||||
|
60
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_max_concurrent() -> usize {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SchedulerConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: true,
|
||||||
|
poll_interval_secs: 60,
|
||||||
|
max_concurrent: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@ -168,6 +205,7 @@ impl Default for GatewayConfig {
|
|||||||
session_ttl_hours: None,
|
session_ttl_hours: None,
|
||||||
cleanup_interval_minutes: None,
|
cleanup_interval_minutes: None,
|
||||||
session_db_path: None,
|
session_db_path: None,
|
||||||
|
scheduler: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,12 +11,14 @@ use crate::channels::base::{Channel, ChannelError};
|
|||||||
use crate::config::{Config, expand_path, ensure_workspace_dir};
|
use crate::config::{Config, expand_path, ensure_workspace_dir};
|
||||||
use crate::logging;
|
use crate::logging;
|
||||||
use crate::session::SessionManager;
|
use crate::session::SessionManager;
|
||||||
|
use crate::scheduler::Scheduler;
|
||||||
|
|
||||||
pub struct GatewayState {
|
pub struct GatewayState {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
pub workspace_dir: std::path::PathBuf,
|
pub workspace_dir: std::path::PathBuf,
|
||||||
pub session_manager: Arc<SessionManager>,
|
pub session_manager: Arc<SessionManager>,
|
||||||
pub channel_manager: ChannelManager,
|
pub channel_manager: ChannelManager,
|
||||||
|
pub storage: Arc<crate::storage::Storage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GatewayState {
|
impl GatewayState {
|
||||||
@ -73,13 +75,45 @@ impl GatewayState {
|
|||||||
|
|
||||||
// Register send_message tool with available channel names
|
// Register send_message tool with available channel names
|
||||||
let available_channels = channel_manager.list_channel_names().await;
|
let available_channels = channel_manager.list_channel_names().await;
|
||||||
|
let valid_channels = available_channels.clone();
|
||||||
session_manager.register_outbound_tool(available_channels);
|
session_manager.register_outbound_tool(available_channels);
|
||||||
|
|
||||||
|
// Register chat_manager tool
|
||||||
|
session_manager.tools().register(
|
||||||
|
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Initialize scheduler if enabled in config
|
||||||
|
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
|
||||||
|
if scheduler_config.enabled {
|
||||||
|
// Register cron tools
|
||||||
|
session_manager.tools().register(
|
||||||
|
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels),
|
||||||
|
);
|
||||||
|
session_manager.tools().register(
|
||||||
|
crate::tools::cron::CronListTool::new(storage.clone()),
|
||||||
|
);
|
||||||
|
session_manager.tools().register(
|
||||||
|
crate::tools::cron::CronRemoveTool::new(storage.clone()),
|
||||||
|
);
|
||||||
|
session_manager.tools().register(
|
||||||
|
crate::tools::cron::CronEnableTool::new(storage.clone()),
|
||||||
|
);
|
||||||
|
session_manager.tools().register(
|
||||||
|
crate::tools::cron::CronDisableTool::new(storage.clone()),
|
||||||
|
);
|
||||||
|
session_manager.tools().register(
|
||||||
|
crate::tools::cron::CronUpdateTool::new(storage.clone()),
|
||||||
|
);
|
||||||
|
tracing::info!("Cron tools registered");
|
||||||
|
}
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
config,
|
config,
|
||||||
workspace_dir: workspace_path,
|
workspace_dir: workspace_path,
|
||||||
session_manager: session_manager.clone(),
|
session_manager: session_manager.clone(),
|
||||||
channel_manager,
|
channel_manager,
|
||||||
|
storage,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,6 +202,21 @@ impl GatewayState {
|
|||||||
tracing::info!("Outbound dispatcher started");
|
tracing::info!("Outbound dispatcher started");
|
||||||
dispatcher.run().await;
|
dispatcher.run().await;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Spawn scheduler background task if enabled
|
||||||
|
let scheduler_config = self.config.gateway.scheduler.clone().unwrap_or_default();
|
||||||
|
if scheduler_config.enabled {
|
||||||
|
let sched = Arc::new(Scheduler::new(
|
||||||
|
self.storage.clone(),
|
||||||
|
self.session_manager.clone(),
|
||||||
|
self.channel_manager.bus(),
|
||||||
|
scheduler_config,
|
||||||
|
));
|
||||||
|
tokio::spawn(async move {
|
||||||
|
sched.run().await;
|
||||||
|
});
|
||||||
|
tracing::info!("Scheduler background task spawned");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle control messages (session management operations)
|
/// Handle control messages (session management operations)
|
||||||
|
|||||||
@ -9,6 +9,7 @@ pub mod protocol;
|
|||||||
pub mod channels;
|
pub mod channels;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod observability;
|
pub mod observability;
|
||||||
|
pub mod scheduler;
|
||||||
pub mod skills;
|
pub mod skills;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
|||||||
@ -6,13 +6,8 @@ use std::collections::HashMap;
|
|||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
|
use std::sync::Arc;
|
||||||
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
|
use crate::storage::Storage;
|
||||||
where
|
|
||||||
S: serde::Serializer,
|
|
||||||
{
|
|
||||||
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||||
blocks.iter().map(|b| match b {
|
blocks.iter().map(|b| match b {
|
||||||
@ -62,6 +57,7 @@ pub struct AnthropicProvider {
|
|||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
model_extra: HashMap<String, serde_json::Value>,
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
|
storage: Option<Arc<Storage>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicProvider {
|
impl AnthropicProvider {
|
||||||
@ -85,8 +81,13 @@ impl AnthropicProvider {
|
|||||||
temperature,
|
temperature,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
model_extra,
|
model_extra,
|
||||||
|
storage: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_storage(&mut self, storage: Arc<Storage>) {
|
||||||
|
self.storage = Some(storage);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
@ -104,7 +105,6 @@ struct AnthropicRequest {
|
|||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct AnthropicMessage {
|
struct AnthropicMessage {
|
||||||
role: String,
|
role: String,
|
||||||
#[serde(serialize_with = "serialize_content_blocks")]
|
|
||||||
content: Vec<serde_json::Value>,
|
content: Vec<serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,14 +128,23 @@ struct AnthropicResponse {
|
|||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
enum AnthropicContent {
|
enum AnthropicContent {
|
||||||
Text { text: String },
|
Text {
|
||||||
Thinking { thinking: String },
|
#[serde(alias = "content")]
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
Thinking {
|
||||||
|
#[serde(alias = "content")]
|
||||||
|
thinking: String,
|
||||||
|
},
|
||||||
#[serde(rename = "tool_use")]
|
#[serde(rename = "tool_use")]
|
||||||
ToolUse {
|
ToolUse {
|
||||||
id: String,
|
id: String,
|
||||||
name: String,
|
name: String,
|
||||||
|
#[serde(alias = "arguments")]
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
},
|
},
|
||||||
|
#[serde(other)]
|
||||||
|
Unknown,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -152,6 +161,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
&self,
|
&self,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
let url = format!("{}/v1/messages", self.base_url);
|
let url = format!("{}/v1/messages", self.base_url);
|
||||||
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
|
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
|
||||||
|
|
||||||
@ -190,7 +200,19 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
"content": output,
|
"content": output,
|
||||||
})]
|
})]
|
||||||
} else {
|
} else {
|
||||||
convert_content_blocks(&m.content)
|
let mut blocks = convert_content_blocks(&m.content);
|
||||||
|
// Append tool_use blocks from assistant messages with tool calls
|
||||||
|
if let Some(ref tool_calls) = m.tool_calls {
|
||||||
|
for tc in tool_calls {
|
||||||
|
blocks.push(serde_json::json!({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": tc.id,
|
||||||
|
"name": tc.name,
|
||||||
|
"input": tc.arguments,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
blocks
|
||||||
};
|
};
|
||||||
AnthropicMessage { role, content }
|
AnthropicMessage { role, content }
|
||||||
})
|
})
|
||||||
@ -212,10 +234,14 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||||
|
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await?;
|
let resp = req_builder.json(&body).send().await?;
|
||||||
|
|
||||||
let status = resp.status();
|
let status = resp.status();
|
||||||
let body_text = resp.text().await?;
|
let body_text = resp.text().await?;
|
||||||
|
tracing::debug!(status = %status, resp_body = %body_text, "LLM response");
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
|
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
|
||||||
@ -227,11 +253,33 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
})
|
})
|
||||||
.unwrap_or_else(|| body_text.clone());
|
.unwrap_or_else(|| body_text.clone());
|
||||||
|
if let Some(ref storage) = self.storage {
|
||||||
|
let _ = storage.append_llm_call(
|
||||||
|
&self.name, &self.model_id, &req_body_str,
|
||||||
|
Some(&body_text), Some(&error_msg),
|
||||||
|
start.elapsed().as_millis() as u64,
|
||||||
|
).await;
|
||||||
|
}
|
||||||
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
||||||
.map_err(|e| format!("decode error: {} | body: {}", e, &body_text))?;
|
.map_err(|e| {
|
||||||
|
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
||||||
|
if let Some(ref storage) = self.storage {
|
||||||
|
let name = self.name.clone();
|
||||||
|
let model = self.model_id.clone();
|
||||||
|
let req = req_body_str.clone();
|
||||||
|
let resp_body = body_text.clone();
|
||||||
|
let dur = start.elapsed().as_millis() as u64;
|
||||||
|
let err = err_msg.clone();
|
||||||
|
let s = storage.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
err_msg
|
||||||
|
})?;
|
||||||
|
|
||||||
let mut content = String::new();
|
let mut content = String::new();
|
||||||
let mut tool_calls = Vec::new();
|
let mut tool_calls = Vec::new();
|
||||||
@ -247,6 +295,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnthropicContent::Thinking { .. } => {}
|
AnthropicContent::Thinking { .. } => {}
|
||||||
|
AnthropicContent::Unknown => {}
|
||||||
AnthropicContent::ToolUse { id, name, input } => {
|
AnthropicContent::ToolUse { id, name, input } => {
|
||||||
tool_calls.push(ToolCall {
|
tool_calls.push(ToolCall {
|
||||||
id: id.clone(),
|
id: id.clone(),
|
||||||
@ -257,7 +306,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(ChatCompletionResponse {
|
let response = ChatCompletionResponse {
|
||||||
id: anthropic_resp.id.unwrap_or_default(),
|
id: anthropic_resp.id.unwrap_or_default(),
|
||||||
model: anthropic_resp.model.unwrap_or_default(),
|
model: anthropic_resp.model.unwrap_or_default(),
|
||||||
content,
|
content,
|
||||||
@ -267,7 +316,20 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
||||||
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
||||||
},
|
},
|
||||||
})
|
};
|
||||||
|
|
||||||
|
if let Some(ref storage) = self.storage {
|
||||||
|
let _ = storage.append_llm_call(
|
||||||
|
&self.name,
|
||||||
|
&self.model_id,
|
||||||
|
&req_body_str,
|
||||||
|
Some(&body_text),
|
||||||
|
None,
|
||||||
|
start.elapsed().as_millis() as u64,
|
||||||
|
).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptype(&self) -> &str {
|
fn ptype(&self) -> &str {
|
||||||
|
|||||||
@ -7,6 +7,8 @@ use std::collections::HashMap;
|
|||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use crate::storage::Storage;
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||||
if blocks.len() == 1 {
|
if blocks.len() == 1 {
|
||||||
@ -32,6 +34,7 @@ pub struct OpenAIProvider {
|
|||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
model_extra: HashMap<String, serde_json::Value>,
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
|
storage: Option<Arc<Storage>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAIProvider {
|
impl OpenAIProvider {
|
||||||
@ -55,9 +58,14 @@ impl OpenAIProvider {
|
|||||||
temperature,
|
temperature,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
model_extra,
|
model_extra,
|
||||||
|
storage: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_storage(&mut self, storage: Arc<Storage>) {
|
||||||
|
self.storage = Some(storage);
|
||||||
|
}
|
||||||
|
|
||||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
@ -162,6 +170,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
&self,
|
&self,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
let body = self.build_request_body(&request);
|
let body = self.build_request_body(&request);
|
||||||
@ -200,24 +209,44 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||||
|
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await?;
|
let resp = req_builder.json(&body).send().await?;
|
||||||
|
|
||||||
let status = resp.status();
|
let status = resp.status();
|
||||||
let text = resp.text().await?;
|
let text = resp.text().await?;
|
||||||
|
tracing::debug!(status = %status, resp_body = %text, "LLM response");
|
||||||
// Debug: Log LLM response (only in debug builds)
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
let resp_preview: String = text.chars().take(100).collect();
|
|
||||||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
|
|
||||||
}
|
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
return Err(format!("API error {}: {}", status, text).into());
|
let error = format!("API error {}: {}", status, text);
|
||||||
|
if let Some(ref storage) = self.storage {
|
||||||
|
let _ = storage.append_llm_call(
|
||||||
|
&self.name, &self.model_id, &req_body_str,
|
||||||
|
Some(&text), Some(&error),
|
||||||
|
start.elapsed().as_millis() as u64,
|
||||||
|
).await;
|
||||||
|
}
|
||||||
|
return Err(error.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
||||||
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
|
.map_err(|e| {
|
||||||
|
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
||||||
|
if let Some(ref storage) = self.storage {
|
||||||
|
let name = self.name.clone();
|
||||||
|
let model = self.model_id.clone();
|
||||||
|
let req = req_body_str.clone();
|
||||||
|
let resp = text.clone();
|
||||||
|
let dur = start.elapsed().as_millis() as u64;
|
||||||
|
let err = err_msg.clone();
|
||||||
|
let s = storage.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
err_msg
|
||||||
|
})?;
|
||||||
|
|
||||||
let content = openai_resp.choices[0]
|
let content = openai_resp.choices[0]
|
||||||
.message
|
.message
|
||||||
@ -237,7 +266,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Ok(ChatCompletionResponse {
|
let response = ChatCompletionResponse {
|
||||||
id: openai_resp.id,
|
id: openai_resp.id,
|
||||||
model: openai_resp.model,
|
model: openai_resp.model,
|
||||||
content,
|
content,
|
||||||
@ -247,7 +276,17 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
completion_tokens: openai_resp.usage.completion_tokens,
|
completion_tokens: openai_resp.usage.completion_tokens,
|
||||||
total_tokens: openai_resp.usage.total_tokens,
|
total_tokens: openai_resp.usage.total_tokens,
|
||||||
},
|
},
|
||||||
})
|
};
|
||||||
|
|
||||||
|
if let Some(ref storage) = self.storage {
|
||||||
|
let _ = storage.append_llm_call(
|
||||||
|
&self.name, &self.model_id, &req_body_str,
|
||||||
|
Some(&text), None,
|
||||||
|
start.elapsed().as_millis() as u64,
|
||||||
|
).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptype(&self) -> &str {
|
fn ptype(&self) -> &str {
|
||||||
|
|||||||
@ -123,4 +123,6 @@ pub trait LLMProvider: Send + Sync {
|
|||||||
fn name(&self) -> &str;
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
fn model_id(&self) -> &str;
|
fn model_id(&self) -> &str;
|
||||||
|
|
||||||
|
fn set_storage(&mut self, _storage: std::sync::Arc<crate::storage::Storage>) {}
|
||||||
}
|
}
|
||||||
|
|||||||
310
src/scheduler/mod.rs
Normal file
310
src/scheduler/mod.rs
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
pub mod types;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::time;
|
||||||
|
|
||||||
|
use crate::bus::MessageBus;
|
||||||
|
use crate::config::SchedulerConfig;
|
||||||
|
use crate::session::session::HandleResult;
|
||||||
|
use crate::session::SessionManager;
|
||||||
|
use crate::storage::{JobRun, ScheduledJob, Storage};
|
||||||
|
|
||||||
|
pub use types::Schedule;
|
||||||
|
|
||||||
|
/// Compute the next execution time (Unix ms) for a schedule, given `from` (Unix ms).
|
||||||
|
/// Returns `None` if no next time can be determined (e.g., invalid cron expression).
|
||||||
|
pub fn next_run_for_schedule(schedule: &Schedule, from: i64) -> Option<i64> {
|
||||||
|
use chrono::{TimeZone, Utc};
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
match schedule {
|
||||||
|
Schedule::At { at } => Some(*at),
|
||||||
|
Schedule::Every { every_ms } => Some(from + *every_ms as i64),
|
||||||
|
Schedule::Cron { expr, tz } => {
|
||||||
|
let cron_schedule = cron::Schedule::from_str(expr.as_str()).ok()?;
|
||||||
|
let from_secs = from / 1000;
|
||||||
|
let from_nanos = ((from % 1000) * 1_000_000) as u32;
|
||||||
|
let from_dt = Utc.timestamp_opt(from_secs, from_nanos).single()?;
|
||||||
|
|
||||||
|
let next_utc = if let Some(tz_str) = tz {
|
||||||
|
let tz: chrono_tz::Tz = tz_str.parse().ok()?;
|
||||||
|
let _from_local = from_dt.with_timezone(&tz);
|
||||||
|
let next_local = cron_schedule.upcoming(tz).next()?;
|
||||||
|
next_local.with_timezone(&Utc)
|
||||||
|
} else {
|
||||||
|
cron_schedule.upcoming(Utc).next()?
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(next_utc.timestamp_millis())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn now_ms() -> i64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The scheduler runs as a background tokio task, periodically checking for due jobs
|
||||||
|
/// and executing them via `SessionManager::handle_cron_message`.
|
||||||
|
pub struct Scheduler {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
session_manager: Arc<SessionManager>,
|
||||||
|
bus: Arc<MessageBus>,
|
||||||
|
config: SchedulerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Scheduler {
|
||||||
|
pub fn new(
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
session_manager: Arc<SessionManager>,
|
||||||
|
bus: Arc<MessageBus>,
|
||||||
|
config: SchedulerConfig,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
storage,
|
||||||
|
session_manager,
|
||||||
|
bus,
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the scheduler loop. This is a long-running async function meant to be
|
||||||
|
/// spawned as a tokio background task.
|
||||||
|
pub async fn run(self: Arc<Self>) {
|
||||||
|
let poll_duration = time::Duration::from_secs(self.config.poll_interval_secs);
|
||||||
|
let mut interval = time::interval(poll_duration);
|
||||||
|
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
"Scheduler started (poll interval: {}s, max concurrent: {})",
|
||||||
|
self.config.poll_interval_secs,
|
||||||
|
self.config.max_concurrent,
|
||||||
|
);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
let now = now_ms();
|
||||||
|
|
||||||
|
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await {
|
||||||
|
Ok(jobs) => jobs,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("scheduler: failed to query due jobs: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if due.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("scheduler: found {} due job(s)", due.len());
|
||||||
|
|
||||||
|
for job in &due {
|
||||||
|
let start = Instant::now();
|
||||||
|
let started_at = now_ms();
|
||||||
|
|
||||||
|
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await {
|
||||||
|
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
job_id = %job.id,
|
||||||
|
job_name = %job.name,
|
||||||
|
"scheduler: executing cron job"
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = self
|
||||||
|
.session_manager
|
||||||
|
.handle_cron_message(
|
||||||
|
&job.channel,
|
||||||
|
&job.chat_id,
|
||||||
|
&job.prompt,
|
||||||
|
&job.id,
|
||||||
|
&job.name,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let finished_at = now_ms();
|
||||||
|
let duration_ms = start.elapsed().as_millis() as i64;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(HandleResult::AgentResponse(output)) => {
|
||||||
|
let outbound = crate::bus::OutboundMessage {
|
||||||
|
channel: job.channel.clone(),
|
||||||
|
chat_id: job.chat_id.clone(),
|
||||||
|
content: output.clone(),
|
||||||
|
reply_to: None,
|
||||||
|
media: vec![],
|
||||||
|
metadata: std::collections::HashMap::new(),
|
||||||
|
};
|
||||||
|
let _ = self.bus.publish_outbound(outbound).await;
|
||||||
|
|
||||||
|
let output_truncated = if output.len() > 8000 {
|
||||||
|
format!("{}...[truncated]", &output[..8000])
|
||||||
|
} else {
|
||||||
|
output.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let run = JobRun {
|
||||||
|
id: 0,
|
||||||
|
job_id: job.id.clone(),
|
||||||
|
started_at,
|
||||||
|
finished_at,
|
||||||
|
status: "ok".to_string(),
|
||||||
|
output: Some(output_truncated),
|
||||||
|
error: None,
|
||||||
|
duration_ms,
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(e) = self.storage.record_scheduled_job_run(&run).await {
|
||||||
|
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await {
|
||||||
|
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
job_id = %job.id,
|
||||||
|
duration_ms = %duration_ms,
|
||||||
|
"scheduler: job completed successfully"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(HandleResult::CommandOutput(output)) => {
|
||||||
|
let outbound = crate::bus::OutboundMessage {
|
||||||
|
channel: job.channel.clone(),
|
||||||
|
chat_id: job.chat_id.clone(),
|
||||||
|
content: output.clone(),
|
||||||
|
reply_to: None,
|
||||||
|
media: vec![],
|
||||||
|
metadata: std::collections::HashMap::new(),
|
||||||
|
};
|
||||||
|
let _ = self.bus.publish_outbound(outbound).await;
|
||||||
|
|
||||||
|
let run = JobRun {
|
||||||
|
id: 0,
|
||||||
|
job_id: job.id.clone(),
|
||||||
|
started_at,
|
||||||
|
finished_at,
|
||||||
|
status: "ok".to_string(),
|
||||||
|
output: Some(output),
|
||||||
|
error: None,
|
||||||
|
duration_ms,
|
||||||
|
};
|
||||||
|
|
||||||
|
let _ = self.storage.record_scheduled_job_run(&run).await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let error_str = e.to_string();
|
||||||
|
let run = JobRun {
|
||||||
|
id: 0,
|
||||||
|
job_id: job.id.clone(),
|
||||||
|
started_at,
|
||||||
|
finished_at,
|
||||||
|
status: "error".to_string(),
|
||||||
|
output: None,
|
||||||
|
error: Some(error_str.clone()),
|
||||||
|
duration_ms,
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(e2) = self.storage.record_scheduled_job_run(&run).await {
|
||||||
|
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e2) = self.storage.set_scheduled_job_last_status(
|
||||||
|
&job.id, "error", Some(&error_str),
|
||||||
|
).await {
|
||||||
|
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::error!(
|
||||||
|
job_id = %job.id,
|
||||||
|
duration_ms = %duration_ms,
|
||||||
|
error = %error_str,
|
||||||
|
"scheduler: job failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = self.reschedule_after_run(job).await {
|
||||||
|
tracing::error!(job_id = %job.id, "scheduler: failed to reschedule: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// After a job runs, compute its next execution time or disable/delete it.
|
||||||
|
async fn reschedule_after_run(&self, job: &ScheduledJob) -> anyhow::Result<()> {
|
||||||
|
let now = now_ms();
|
||||||
|
|
||||||
|
match &job.schedule {
|
||||||
|
Schedule::At { .. } => {
|
||||||
|
if job.delete_after_run {
|
||||||
|
self.storage.remove_scheduled_job(&job.id).await?;
|
||||||
|
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
|
||||||
|
} else {
|
||||||
|
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
||||||
|
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Schedule::Every { .. } | Schedule::Cron { .. } => {
|
||||||
|
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
|
||||||
|
self.storage.set_scheduled_job_next_run(&job.id, next).await?;
|
||||||
|
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
|
||||||
|
} else {
|
||||||
|
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
|
||||||
|
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_at_schedule() {
|
||||||
|
let now = 1000000;
|
||||||
|
let next = next_run_for_schedule(&Schedule::At { at: 2000000 }, now);
|
||||||
|
assert_eq!(next, Some(2000000));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_every_schedule() {
|
||||||
|
let now = 1000000;
|
||||||
|
let next = next_run_for_schedule(&Schedule::Every { every_ms: 5000 }, now);
|
||||||
|
assert_eq!(next, Some(1005000));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_cron_every_minute() {
|
||||||
|
let expr = "0 * * * * *".to_string();
|
||||||
|
let schedule = Schedule::Cron { expr, tz: None };
|
||||||
|
let now = 1000000;
|
||||||
|
let next = next_run_for_schedule(&schedule, now);
|
||||||
|
assert!(next.is_some());
|
||||||
|
assert!(next.unwrap() > now);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_cron_every_day_at_9am() {
|
||||||
|
let expr = "0 0 9 * * *".to_string();
|
||||||
|
let schedule = Schedule::Cron { expr, tz: None };
|
||||||
|
let now = 1000000;
|
||||||
|
let next = next_run_for_schedule(&schedule, now);
|
||||||
|
assert!(next.is_some());
|
||||||
|
let next_ms = next.unwrap();
|
||||||
|
assert!(next_ms > now);
|
||||||
|
}
|
||||||
|
}
|
||||||
16
src/scheduler/types.rs
Normal file
16
src/scheduler/types.rs
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// How a job is scheduled. Serialized as JSON in the database `schedule` column.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum Schedule {
|
||||||
|
/// One-shot: fires once at a specific Unix millisecond timestamp, then disables.
|
||||||
|
#[serde(rename = "at")]
|
||||||
|
At { at: i64 },
|
||||||
|
/// Recurring: fires every `every_ms` milliseconds.
|
||||||
|
#[serde(rename = "every")]
|
||||||
|
Every { every_ms: u64 },
|
||||||
|
/// Recurring: fires on a cron schedule with optional timezone.
|
||||||
|
#[serde(rename = "cron")]
|
||||||
|
Cron { expr: String, tz: Option<String> },
|
||||||
|
}
|
||||||
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::Mutex;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
|
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
|
||||||
@ -21,7 +21,6 @@ use crate::config::LLMProviderConfig;
|
|||||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||||
use crate::agent::system_prompt::build_system_prompt;
|
use crate::agent::system_prompt::build_system_prompt;
|
||||||
use crate::agent::context_compressor::ContextCompressionConfig;
|
use crate::agent::context_compressor::ContextCompressionConfig;
|
||||||
use crate::protocol::WsOutbound;
|
|
||||||
use crate::providers::{create_provider, LLMProvider};
|
use crate::providers::{create_provider, LLMProvider};
|
||||||
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
|
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
|
||||||
use crate::session::events::DialogInfo;
|
use crate::session::events::DialogInfo;
|
||||||
@ -49,7 +48,6 @@ pub struct Session {
|
|||||||
messages: Vec<ChatMessage>,
|
messages: Vec<ChatMessage>,
|
||||||
seq_counter: i64,
|
seq_counter: i64,
|
||||||
|
|
||||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
provider: Arc<dyn LLMProvider>,
|
provider: Arc<dyn LLMProvider>,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
@ -63,14 +61,16 @@ impl Session {
|
|||||||
pub async fn new(
|
pub async fn new(
|
||||||
id: UnifiedSessionId,
|
id: UnifiedSessionId,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
user_tx: mpsc::Sender<WsOutbound>,
|
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
storage: Option<StdArc<Storage>>,
|
storage: Option<StdArc<Storage>>,
|
||||||
routing_info: String,
|
routing_info: String,
|
||||||
title: String,
|
title: String,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let provider_box = create_provider(provider_config.clone())
|
let mut provider_box = create_provider(provider_config.clone())
|
||||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||||
|
if let Some(ref s) = storage {
|
||||||
|
provider_box.set_storage(s.clone());
|
||||||
|
}
|
||||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||||
|
|
||||||
let compressor_config = ContextCompressionConfig {
|
let compressor_config = ContextCompressionConfig {
|
||||||
@ -89,7 +89,6 @@ impl Session {
|
|||||||
total_message_count: 0,
|
total_message_count: 0,
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
seq_counter: 1,
|
seq_counter: 1,
|
||||||
user_tx,
|
|
||||||
provider_config: provider_config.clone(),
|
provider_config: provider_config.clone(),
|
||||||
provider: provider.clone(),
|
provider: provider.clone(),
|
||||||
tools,
|
tools,
|
||||||
@ -103,7 +102,6 @@ impl Session {
|
|||||||
pub async fn from_storage(
|
pub async fn from_storage(
|
||||||
id: UnifiedSessionId,
|
id: UnifiedSessionId,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
user_tx: mpsc::Sender<WsOutbound>,
|
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
storage: StdArc<Storage>,
|
storage: StdArc<Storage>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
@ -113,8 +111,9 @@ impl Session {
|
|||||||
let messages = storage.load_messages(&id.to_string(), 0).await
|
let messages = storage.load_messages(&id.to_string(), 0).await
|
||||||
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
|
||||||
|
|
||||||
let provider_box = create_provider(provider_config.clone())
|
let mut provider_box = create_provider(provider_config.clone())
|
||||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||||
|
provider_box.set_storage(storage.clone());
|
||||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||||
|
|
||||||
let compressor_config = ContextCompressionConfig {
|
let compressor_config = ContextCompressionConfig {
|
||||||
@ -123,6 +122,7 @@ impl Session {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Convert MessageMeta to ChatMessage
|
// Convert MessageMeta to ChatMessage
|
||||||
|
// Clear tool_call_id/tool_name — they're not valid across API sessions
|
||||||
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
|
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
|
||||||
ChatMessage {
|
ChatMessage {
|
||||||
id: m.id,
|
id: m.id,
|
||||||
@ -130,8 +130,8 @@ impl Session {
|
|||||||
content: m.content,
|
content: m.content,
|
||||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||||
timestamp: m.created_at,
|
timestamp: m.created_at,
|
||||||
tool_call_id: m.tool_call_id,
|
tool_call_id: None,
|
||||||
tool_name: m.tool_name,
|
tool_name: None,
|
||||||
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
|
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
|
||||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||||
}
|
}
|
||||||
@ -149,7 +149,6 @@ impl Session {
|
|||||||
total_message_count,
|
total_message_count,
|
||||||
messages: chat_messages,
|
messages: chat_messages,
|
||||||
seq_counter,
|
seq_counter,
|
||||||
user_tx,
|
|
||||||
provider_config: provider_config.clone(),
|
provider_config: provider_config.clone(),
|
||||||
provider: provider.clone(),
|
provider: provider.clone(),
|
||||||
tools,
|
tools,
|
||||||
@ -252,16 +251,17 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send(&self, msg: WsOutbound) {
|
pub fn create_user_message_with_source(
|
||||||
let _ = self.user_tx.send(msg).await;
|
&self,
|
||||||
}
|
content: &str,
|
||||||
|
media_refs: Vec<String>,
|
||||||
/// 发送系统通知(不记录进 session 历史)
|
source: MessageSource,
|
||||||
pub async fn send_system_notification(&self, content: &str) {
|
) -> ChatMessage {
|
||||||
let msg = WsOutbound::SystemNotification {
|
if media_refs.is_empty() {
|
||||||
content: content.to_string(),
|
ChatMessage::user_with_source(content, source)
|
||||||
};
|
} else {
|
||||||
let _ = self.user_tx.send(msg).await;
|
ChatMessage::user_with_source(content, source)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 将 session 元数据写回 Storage
|
/// 将 session 元数据写回 Storage
|
||||||
@ -364,12 +364,21 @@ impl Session {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 创建一个附通知通道的 AgentLoop 实例
|
||||||
|
pub fn create_agent_with_notify(
|
||||||
|
&self,
|
||||||
|
notify_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||||||
|
) -> Result<AgentLoop, AgentError> {
|
||||||
|
Ok(self.create_agent()?.with_notify(notify_tx))
|
||||||
|
}
|
||||||
|
|
||||||
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills)
|
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills)
|
||||||
pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
|
pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
|
||||||
let base_prompt = build_system_prompt(
|
let base_prompt = build_system_prompt(
|
||||||
&self.provider_config.workspace_dir,
|
&self.provider_config.workspace_dir,
|
||||||
&self.provider_config.model_id,
|
&self.provider_config.model_id,
|
||||||
&self.tools,
|
&self.tools,
|
||||||
|
Some(&self.id.to_string()),
|
||||||
);
|
);
|
||||||
|
|
||||||
if skills_prompt.trim().is_empty() {
|
if skills_prompt.trim().is_empty() {
|
||||||
@ -874,11 +883,9 @@ impl SessionManager {
|
|||||||
self.storage.upsert_session(&meta).await
|
self.storage.upsert_session(&meta).await
|
||||||
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
|
||||||
|
|
||||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
|
||||||
let session = Session::new(
|
let session = Session::new(
|
||||||
unified_id.clone(),
|
unified_id.clone(),
|
||||||
self.provider_config.clone(),
|
self.provider_config.clone(),
|
||||||
user_tx,
|
|
||||||
self.tools.clone(),
|
self.tools.clone(),
|
||||||
Some(self.storage.clone()),
|
Some(self.storage.clone()),
|
||||||
routing_info,
|
routing_info,
|
||||||
@ -909,11 +916,9 @@ impl SessionManager {
|
|||||||
match self.storage.get_session(&session_id_str).await {
|
match self.storage.get_session(&session_id_str).await {
|
||||||
Ok(meta) => {
|
Ok(meta) => {
|
||||||
tracing::debug!(session_id = %session_id_str, last_active_at = %meta.last_active_at, message_count = %meta.message_count, "Restoring session from Storage");
|
tracing::debug!(session_id = %session_id_str, last_active_at = %meta.last_active_at, message_count = %meta.message_count, "Restoring session from Storage");
|
||||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
|
||||||
let session = Session::from_storage(
|
let session = Session::from_storage(
|
||||||
unified_id.clone(),
|
unified_id.clone(),
|
||||||
self.provider_config.clone(),
|
self.provider_config.clone(),
|
||||||
user_tx,
|
|
||||||
self.tools.clone(),
|
self.tools.clone(),
|
||||||
self.storage.clone(),
|
self.storage.clone(),
|
||||||
).await?;
|
).await?;
|
||||||
@ -932,11 +937,9 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create new session
|
// Create new session
|
||||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
|
||||||
let session = Session::new(
|
let session = Session::new(
|
||||||
unified_id.clone(),
|
unified_id.clone(),
|
||||||
self.provider_config.clone(),
|
self.provider_config.clone(),
|
||||||
user_tx,
|
|
||||||
self.tools.clone(),
|
self.tools.clone(),
|
||||||
Some(self.storage.clone()),
|
Some(self.storage.clone()),
|
||||||
String::new(),
|
String::new(),
|
||||||
@ -1175,6 +1178,30 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Normal message handling through LLM
|
// Normal message handling through LLM
|
||||||
|
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
// Spawn notification publisher — sends immediately when tools are detected
|
||||||
|
{
|
||||||
|
let bus = self.bus.clone();
|
||||||
|
let ch = channel.to_string();
|
||||||
|
let cid = chat_id.to_string();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(notif) = notify_rx.recv().await {
|
||||||
|
let mut metadata = HashMap::new();
|
||||||
|
metadata.insert("_type".to_string(), "notification".to_string());
|
||||||
|
let outbound = OutboundMessage {
|
||||||
|
channel: ch.clone(),
|
||||||
|
chat_id: cid.clone(),
|
||||||
|
content: notif,
|
||||||
|
reply_to: None,
|
||||||
|
media: vec![],
|
||||||
|
metadata,
|
||||||
|
};
|
||||||
|
let _ = bus.publish_outbound(outbound).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let response: String = {
|
let response: String = {
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
|
|
||||||
@ -1202,7 +1229,7 @@ impl SessionManager {
|
|||||||
.compress_if_needed(history)
|
.compress_if_needed(history)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let agent = session_guard.create_agent()?;
|
let agent = session_guard.create_agent_with_notify(notify_tx)?;
|
||||||
let result = agent.process(history).await?;
|
let result = agent.process(history).await?;
|
||||||
|
|
||||||
for msg in result.emitted_messages {
|
for msg in result.emitted_messages {
|
||||||
@ -1233,6 +1260,142 @@ impl SessionManager {
|
|||||||
Ok(HandleResult::AgentResponse(response))
|
Ok(HandleResult::AgentResponse(response))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Handle a message triggered by a scheduled cron job.
|
||||||
|
///
|
||||||
|
/// This is similar to `handle_message`, but the user message is created with
|
||||||
|
/// `SourceKind::ExternalTrigger` source metadata so that the cron job identity
|
||||||
|
/// is preserved in the conversation history and database.
|
||||||
|
pub async fn handle_cron_message(
|
||||||
|
&self,
|
||||||
|
channel: &str,
|
||||||
|
chat_id: &str,
|
||||||
|
prompt: &str,
|
||||||
|
job_id: &str,
|
||||||
|
job_name: &str,
|
||||||
|
) -> Result<HandleResult, AgentError> {
|
||||||
|
use crate::bus::{MessageSource, SourceKind};
|
||||||
|
|
||||||
|
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
|
||||||
|
*self.current_source_session.lock().await = Some(unified_id.to_string());
|
||||||
|
tracing::debug!(unified_id = %unified_id, job_id = %job_id, "handle_cron_message resolved");
|
||||||
|
|
||||||
|
let session = self.get_or_create_session(&unified_id).await?;
|
||||||
|
|
||||||
|
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
{
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use crate::bus::OutboundMessage;
|
||||||
|
let bus = self.bus.clone();
|
||||||
|
let ch = channel.to_string();
|
||||||
|
let cid = chat_id.to_string();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(notif) = notify_rx.recv().await {
|
||||||
|
let mut metadata = HashMap::new();
|
||||||
|
metadata.insert("_type".to_string(), "notification".to_string());
|
||||||
|
let outbound = OutboundMessage {
|
||||||
|
channel: ch.clone(),
|
||||||
|
chat_id: cid.clone(),
|
||||||
|
content: notif,
|
||||||
|
reply_to: None,
|
||||||
|
media: vec![],
|
||||||
|
metadata,
|
||||||
|
};
|
||||||
|
let _ = bus.publish_outbound(outbound).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let response: String = {
|
||||||
|
let mut session_guard = session.lock().await;
|
||||||
|
|
||||||
|
let source = MessageSource {
|
||||||
|
kind: SourceKind::ExternalTrigger,
|
||||||
|
from_channel: Some(channel.to_string()),
|
||||||
|
from_session: None,
|
||||||
|
from_user_id: None,
|
||||||
|
system_name: Some(job_name.to_string()),
|
||||||
|
task_id: Some(job_id.to_string()),
|
||||||
|
};
|
||||||
|
let user_message = session_guard.create_user_message_with_source(prompt, vec![], source);
|
||||||
|
session_guard.add_message(user_message, true).await
|
||||||
|
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||||
|
|
||||||
|
let mut history = session_guard.get_history().to_vec();
|
||||||
|
|
||||||
|
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||||||
|
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
|
||||||
|
let cron_context = format!(
|
||||||
|
"\n\n## 定时任务执行\n\n\
|
||||||
|
你正在执行定时任务「{}」({})。\n\
|
||||||
|
目标渠道: {}:{}\n\n\
|
||||||
|
定时任务执行规则:\n\
|
||||||
|
- 这不是聊天对话,没有人会回复你,不要等待用户输入\n\
|
||||||
|
- 你的职责是根据任务指令直接生成要发送的消息内容\n\
|
||||||
|
- 只输出最终消息,不要输出中间思考过程或分析\n\
|
||||||
|
- 系统会自动将你的回复推送到目标渠道,不要使用 send_message 工具\n\
|
||||||
|
- 你的最终回复就是推送给用户的消息原文",
|
||||||
|
job_name, job_id, channel, chat_id
|
||||||
|
);
|
||||||
|
let full_system_prompt = format!("{}{}", system_prompt, cron_context);
|
||||||
|
history.insert(0, ChatMessage::system(full_system_prompt));
|
||||||
|
|
||||||
|
let history = session_guard.compressor
|
||||||
|
.compress_if_needed(history)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let agent = session_guard.create_agent_with_notify(notify_tx)?;
|
||||||
|
let result = agent.process(history).await?;
|
||||||
|
|
||||||
|
for msg in result.emitted_messages {
|
||||||
|
session_guard.add_message(msg, true).await
|
||||||
|
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if session_guard.should_generate_title() {
|
||||||
|
if let Err(e) = session_guard.generate_title().await {
|
||||||
|
tracing::warn!("failed to generate title: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw_response = result.final_response.content;
|
||||||
|
|
||||||
|
let target_id = unified_id.to_string();
|
||||||
|
let prefix = format!(
|
||||||
|
"[message from cron:{}({}) to {}]\n",
|
||||||
|
job_name, job_id, target_id
|
||||||
|
);
|
||||||
|
let prefixed_response = format!("{}{}", prefix, raw_response);
|
||||||
|
|
||||||
|
let source = MessageSource {
|
||||||
|
kind: SourceKind::CrossChannel,
|
||||||
|
from_channel: Some("cron".to_string()),
|
||||||
|
from_session: Some(format!("{}:{}", job_name, job_id)),
|
||||||
|
from_user_id: None,
|
||||||
|
system_name: Some(job_name.to_string()),
|
||||||
|
task_id: Some(job_id.to_string()),
|
||||||
|
};
|
||||||
|
let msg = ChatMessage::assistant_with_source(prefixed_response.clone(), source);
|
||||||
|
session_guard.add_message(msg, true).await
|
||||||
|
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||||
|
|
||||||
|
prefixed_response
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(
|
||||||
|
channel = %channel,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
job_id = %job_id,
|
||||||
|
response_len = %response.len(),
|
||||||
|
"Cron agent response received"
|
||||||
|
);
|
||||||
|
|
||||||
|
*self.current_source_session.lock().await = None;
|
||||||
|
|
||||||
|
Ok(HandleResult::AgentResponse(response))
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||||||
let session = self.get_or_create_session(unified_id).await?;
|
let session = self.get_or_create_session(unified_id).await?;
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
@ -1322,7 +1485,6 @@ impl OutboundMessenger for SessionManager {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::sync::mpsc;
|
|
||||||
|
|
||||||
fn test_provider_config() -> LLMProviderConfig {
|
fn test_provider_config() -> LLMProviderConfig {
|
||||||
LLMProviderConfig {
|
LLMProviderConfig {
|
||||||
|
|||||||
@ -1,15 +1,17 @@
|
|||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod session;
|
|
||||||
pub mod message;
|
pub mod message;
|
||||||
|
pub mod scheduler;
|
||||||
|
pub mod session;
|
||||||
|
|
||||||
pub use error::StorageError;
|
pub use error::StorageError;
|
||||||
|
pub use scheduler::{JobRun, ScheduledJob};
|
||||||
|
|
||||||
use sqlx::{Pool, Row, Sqlite, SqlitePool};
|
use sqlx::{Pool, Row, Sqlite, SqlitePool};
|
||||||
use tokio::time::{sleep, Duration};
|
use tokio::time::{sleep, Duration};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
pub struct Storage {
|
pub struct Storage {
|
||||||
pool: Pool<Sqlite>,
|
pub(crate) pool: Pool<Sqlite>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Storage {
|
impl Storage {
|
||||||
@ -92,6 +94,130 @@ impl Storage {
|
|||||||
.await
|
.await
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS llm_calls (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
provider TEXT NOT NULL,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
request_body TEXT NOT NULL,
|
||||||
|
response_body TEXT,
|
||||||
|
error TEXT,
|
||||||
|
duration_ms INTEGER
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
|
||||||
|
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialize the scheduler tables (idempotent).
|
||||||
|
pub(crate) async fn init_scheduler_schema(pool: &Pool<Sqlite>) -> Result<(), StorageError> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS scheduled_jobs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
schedule TEXT NOT NULL,
|
||||||
|
prompt TEXT NOT NULL,
|
||||||
|
channel TEXT NOT NULL,
|
||||||
|
chat_id TEXT NOT NULL,
|
||||||
|
model TEXT,
|
||||||
|
enabled INTEGER NOT NULL DEFAULT 1,
|
||||||
|
delete_after_run INTEGER NOT NULL DEFAULT 0,
|
||||||
|
next_run_at INTEGER NOT NULL,
|
||||||
|
last_run_at INTEGER,
|
||||||
|
last_status TEXT,
|
||||||
|
last_error TEXT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
updated_at INTEGER NOT NULL
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS job_runs (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
job_id TEXT NOT NULL REFERENCES scheduled_jobs(id) ON DELETE CASCADE,
|
||||||
|
started_at INTEGER NOT NULL,
|
||||||
|
finished_at INTEGER NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
output TEXT,
|
||||||
|
error TEXT,
|
||||||
|
duration_ms INTEGER NOT NULL
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_jobs_next_run ON scheduled_jobs(enabled, next_run_at)",
|
||||||
|
)
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query("CREATE INDEX IF NOT EXISTS idx_runs_job_id ON job_runs(job_id)")
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn append_llm_call(
|
||||||
|
&self,
|
||||||
|
provider: &str,
|
||||||
|
model: &str,
|
||||||
|
request_body: &str,
|
||||||
|
response_body: Option<&str>,
|
||||||
|
error: Option<&str>,
|
||||||
|
duration_ms: u64,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO llm_calls (created_at, provider, model, request_body, response_body, error, duration_ms)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(now)
|
||||||
|
.bind(provider)
|
||||||
|
.bind(model)
|
||||||
|
.bind(request_body)
|
||||||
|
.bind(response_body)
|
||||||
|
.bind(error)
|
||||||
|
.bind(duration_ms as i64)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Prune to keep last 1000 records
|
||||||
|
self.prune_llm_calls(1000).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn prune_llm_calls(&self, max_records: i64) -> Result<(), StorageError> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
DELETE FROM llm_calls WHERE id <= (
|
||||||
|
SELECT COALESCE(MAX(id), 0) - ? FROM llm_calls
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(max_records)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -339,6 +465,79 @@ impl Storage {
|
|||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn list_all_active_sessions(
|
||||||
|
&self,
|
||||||
|
limit: i64,
|
||||||
|
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
||||||
|
FROM sessions
|
||||||
|
WHERE deleted_at IS NULL
|
||||||
|
ORDER BY last_active_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(limit)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| crate::storage::session::SessionMeta {
|
||||||
|
id: row.get("id"),
|
||||||
|
channel: row.get("channel"),
|
||||||
|
chat_id: row.get("chat_id"),
|
||||||
|
dialog_id: row.get("dialog_id"),
|
||||||
|
title: row.get("title"),
|
||||||
|
created_at: row.get("created_at"),
|
||||||
|
last_active_at: row.get("last_active_at"),
|
||||||
|
message_count: row.get("message_count"),
|
||||||
|
routing_info: row.get("routing_info"),
|
||||||
|
deleted_at: row.get("deleted_at"),
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_recent_messages(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
count: i64,
|
||||||
|
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ?
|
||||||
|
ORDER BY seq DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(count)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut messages: Vec<_> = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| crate::storage::message::MessageMeta {
|
||||||
|
id: row.get("id"),
|
||||||
|
session_id: row.get("session_id"),
|
||||||
|
seq: row.get("seq"),
|
||||||
|
role: row.get("role"),
|
||||||
|
content: row.get("content"),
|
||||||
|
media_refs: row.get("media_refs"),
|
||||||
|
tool_call_id: row.get("tool_call_id"),
|
||||||
|
tool_name: row.get("tool_name"),
|
||||||
|
tool_calls: row.get("tool_calls"),
|
||||||
|
source: row.get("source"),
|
||||||
|
created_at: row.get("created_at"),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
messages.reverse();
|
||||||
|
Ok(messages)
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
|
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
|
||||||
.bind(session_id)
|
.bind(session_id)
|
||||||
|
|||||||
551
src/storage/scheduler.rs
Normal file
551
src/storage/scheduler.rs
Normal file
@ -0,0 +1,551 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sqlx::Row;
|
||||||
|
|
||||||
|
use crate::scheduler::Schedule;
|
||||||
|
|
||||||
|
/// A scheduled job stored in the database.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ScheduledJob {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
/// JSON-serialized `Schedule` stored as TEXT in SQLite.
|
||||||
|
pub schedule: Schedule,
|
||||||
|
pub prompt: String,
|
||||||
|
pub channel: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub model: Option<String>,
|
||||||
|
pub enabled: bool,
|
||||||
|
pub delete_after_run: bool,
|
||||||
|
pub next_run_at: i64,
|
||||||
|
pub last_run_at: Option<i64>,
|
||||||
|
pub last_status: Option<String>,
|
||||||
|
pub last_error: Option<String>,
|
||||||
|
pub created_at: i64,
|
||||||
|
pub updated_at: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single execution record for a job.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct JobRun {
|
||||||
|
pub id: i64,
|
||||||
|
pub job_id: String,
|
||||||
|
pub started_at: i64,
|
||||||
|
pub finished_at: i64,
|
||||||
|
pub status: String,
|
||||||
|
pub output: Option<String>,
|
||||||
|
pub error: Option<String>,
|
||||||
|
pub duration_ms: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::storage::Storage {
|
||||||
|
/// Insert a new scheduled job.
|
||||||
|
pub async fn add_scheduled_job(&self, job: &ScheduledJob) -> anyhow::Result<()> {
|
||||||
|
let schedule_json = serde_json::to_string(&job.schedule)?;
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO scheduled_jobs
|
||||||
|
(id, name, schedule, prompt, channel, chat_id, model,
|
||||||
|
enabled, delete_after_run, next_run_at, last_run_at,
|
||||||
|
last_status, last_error, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&job.id)
|
||||||
|
.bind(&job.name)
|
||||||
|
.bind(&schedule_json)
|
||||||
|
.bind(&job.prompt)
|
||||||
|
.bind(&job.channel)
|
||||||
|
.bind(&job.chat_id)
|
||||||
|
.bind(&job.model)
|
||||||
|
.bind(job.enabled as i32)
|
||||||
|
.bind(job.delete_after_run as i32)
|
||||||
|
.bind(job.next_run_at)
|
||||||
|
.bind(job.last_run_at)
|
||||||
|
.bind(&job.last_status)
|
||||||
|
.bind(&job.last_error)
|
||||||
|
.bind(job.created_at)
|
||||||
|
.bind(job.updated_at)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch a single scheduled job by ID.
|
||||||
|
pub async fn get_scheduled_job(&self, id: &str) -> anyhow::Result<ScheduledJob> {
|
||||||
|
let row = sqlx::query("SELECT * FROM scheduled_jobs WHERE id = ?")
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(self.pool())
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("job not found: {id}"))?;
|
||||||
|
row_to_job(&row)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all scheduled jobs, ordered by next_run_at ascending.
|
||||||
|
pub async fn list_scheduled_jobs(&self) -> anyhow::Result<Vec<ScheduledJob>> {
|
||||||
|
let rows = sqlx::query("SELECT * FROM scheduled_jobs ORDER BY next_run_at ASC")
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
rows.iter().map(row_to_job).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete a scheduled job (cascades to job_runs).
|
||||||
|
pub async fn remove_scheduled_job(&self, id: &str) -> anyhow::Result<()> {
|
||||||
|
sqlx::query("DELETE FROM scheduled_jobs WHERE id = ?")
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Enable or disable a scheduled job.
|
||||||
|
pub async fn set_scheduled_job_enabled(&self, id: &str, enabled: bool) -> anyhow::Result<()> {
|
||||||
|
sqlx::query("UPDATE scheduled_jobs SET enabled = ?, updated_at = ? WHERE id = ?")
|
||||||
|
.bind(enabled as i32)
|
||||||
|
.bind(now_ms())
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update selective fields on a scheduled job.
|
||||||
|
pub async fn update_scheduled_job(
|
||||||
|
&self,
|
||||||
|
id: &str,
|
||||||
|
prompt: Option<String>,
|
||||||
|
schedule: Option<Schedule>,
|
||||||
|
channel: Option<String>,
|
||||||
|
chat_id: Option<String>,
|
||||||
|
model: Option<String>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let now = now_ms();
|
||||||
|
|
||||||
|
if let Some(p) = prompt {
|
||||||
|
sqlx::query("UPDATE scheduled_jobs SET prompt = ?, updated_at = ? WHERE id = ?")
|
||||||
|
.bind(&p)
|
||||||
|
.bind(now)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
if let Some(s) = schedule {
|
||||||
|
let json = serde_json::to_string(&s)?;
|
||||||
|
sqlx::query("UPDATE scheduled_jobs SET schedule = ?, updated_at = ? WHERE id = ?")
|
||||||
|
.bind(&json)
|
||||||
|
.bind(now)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
if let Some(c) = channel {
|
||||||
|
sqlx::query("UPDATE scheduled_jobs SET channel = ?, updated_at = ? WHERE id = ?")
|
||||||
|
.bind(&c)
|
||||||
|
.bind(now)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
if let Some(c) = chat_id {
|
||||||
|
sqlx::query("UPDATE scheduled_jobs SET chat_id = ?, updated_at = ? WHERE id = ?")
|
||||||
|
.bind(&c)
|
||||||
|
.bind(now)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
if let Some(m) = model {
|
||||||
|
sqlx::query("UPDATE scheduled_jobs SET model = ?, updated_at = ? WHERE id = ?")
|
||||||
|
.bind(&m)
|
||||||
|
.bind(now)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update next_run_at and last_run_at for a job.
|
||||||
|
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> {
|
||||||
|
let now = now_ms();
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
|
||||||
|
)
|
||||||
|
.bind(next_run_at)
|
||||||
|
.bind(now)
|
||||||
|
.bind(now)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set last_run_at for a job (used when starting execution).
|
||||||
|
pub async fn touch_scheduled_job_last_run(&self, id: &str, at: i64) -> anyhow::Result<()> {
|
||||||
|
sqlx::query("UPDATE scheduled_jobs SET last_run_at = ?, updated_at = ? WHERE id = ?")
|
||||||
|
.bind(at)
|
||||||
|
.bind(at)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set last_status and last_error after job completion.
|
||||||
|
pub async fn set_scheduled_job_last_status(
|
||||||
|
&self,
|
||||||
|
id: &str,
|
||||||
|
status: &str,
|
||||||
|
error: Option<&str>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let now = now_ms();
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE scheduled_jobs SET last_status = ?, last_error = ?, updated_at = ? WHERE id = ?",
|
||||||
|
)
|
||||||
|
.bind(status)
|
||||||
|
.bind(error)
|
||||||
|
.bind(now)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch enabled jobs whose next_run_at <= now, up to `limit`.
|
||||||
|
pub async fn due_scheduled_jobs(
|
||||||
|
&self,
|
||||||
|
now: i64,
|
||||||
|
limit: usize,
|
||||||
|
) -> anyhow::Result<Vec<ScheduledJob>> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
"SELECT * FROM scheduled_jobs WHERE enabled = 1 AND next_run_at <= ? ORDER BY next_run_at ASC LIMIT ?",
|
||||||
|
)
|
||||||
|
.bind(now)
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
rows.iter().map(row_to_job).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record a job execution run.
|
||||||
|
pub async fn record_scheduled_job_run(&self, run: &JobRun) -> anyhow::Result<()> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO job_runs (job_id, started_at, finished_at, status, output, error, duration_ms)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&run.job_id)
|
||||||
|
.bind(run.started_at)
|
||||||
|
.bind(run.finished_at)
|
||||||
|
.bind(&run.status)
|
||||||
|
.bind(&run.output)
|
||||||
|
.bind(&run.error)
|
||||||
|
.bind(run.duration_ms)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List recent runs for a job, newest first.
|
||||||
|
pub async fn list_scheduled_job_runs(
|
||||||
|
&self,
|
||||||
|
job_id: &str,
|
||||||
|
limit: usize,
|
||||||
|
) -> anyhow::Result<Vec<JobRun>> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
"SELECT * FROM job_runs WHERE job_id = ? ORDER BY finished_at DESC LIMIT ?",
|
||||||
|
)
|
||||||
|
.bind(job_id)
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
rows.iter()
|
||||||
|
.map(|r| {
|
||||||
|
Ok(JobRun {
|
||||||
|
id: r.try_get("id")?,
|
||||||
|
job_id: r.try_get("job_id")?,
|
||||||
|
started_at: r.try_get("started_at")?,
|
||||||
|
finished_at: r.try_get("finished_at")?,
|
||||||
|
status: r.try_get("status")?,
|
||||||
|
output: r.try_get("output")?,
|
||||||
|
error: r.try_get("error")?,
|
||||||
|
duration_ms: r.try_get("duration_ms")?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete disabled jobs whose updated_at is before `before`.
|
||||||
|
pub async fn cleanup_disabled_scheduled_jobs(&self, before: i64) -> anyhow::Result<()> {
|
||||||
|
sqlx::query("DELETE FROM scheduled_jobs WHERE enabled = 0 AND updated_at < ?")
|
||||||
|
.bind(before)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn now_ms() -> i64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
fn row_to_job(row: &sqlx::sqlite::SqliteRow) -> anyhow::Result<ScheduledJob> {
|
||||||
|
let schedule_json: String = row.try_get("schedule")?;
|
||||||
|
let schedule: Schedule = serde_json::from_str(&schedule_json)?;
|
||||||
|
Ok(ScheduledJob {
|
||||||
|
id: row.try_get("id")?,
|
||||||
|
name: row.try_get("name")?,
|
||||||
|
schedule,
|
||||||
|
prompt: row.try_get("prompt")?,
|
||||||
|
channel: row.try_get("channel")?,
|
||||||
|
chat_id: row.try_get("chat_id")?,
|
||||||
|
model: row.try_get("model")?,
|
||||||
|
enabled: row.try_get::<i32, _>("enabled")? != 0,
|
||||||
|
delete_after_run: row.try_get::<i32, _>("delete_after_run")? != 0,
|
||||||
|
next_run_at: row.try_get("next_run_at")?,
|
||||||
|
last_run_at: row.try_get("last_run_at")?,
|
||||||
|
last_status: row.try_get("last_status")?,
|
||||||
|
last_error: row.try_get("last_error")?,
|
||||||
|
created_at: row.try_get("created_at")?,
|
||||||
|
updated_at: row.try_get("updated_at")?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::ScheduledJob;
|
||||||
|
use crate::scheduler::Schedule;
|
||||||
|
use crate::storage::Storage;
|
||||||
|
use sqlx::SqlitePool;
|
||||||
|
|
||||||
|
fn now() -> i64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_millis() as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn setup_storage() -> Storage {
|
||||||
|
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||||
|
let storage = Storage { pool };
|
||||||
|
Storage::init_scheduler_schema(storage.pool()).await.unwrap();
|
||||||
|
storage
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_init_creates_tables() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM scheduled_jobs")
|
||||||
|
.fetch_one(storage.pool())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(row.0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_add_and_get_job() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: "job-1".into(),
|
||||||
|
name: "test job".into(),
|
||||||
|
schedule: Schedule::Every { every_ms: 3600000 },
|
||||||
|
prompt: "say hello".into(),
|
||||||
|
channel: "cli_chat".into(),
|
||||||
|
chat_id: "conn-1".into(),
|
||||||
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t + 3600000,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
let got = storage.get_scheduled_job("job-1").await.unwrap();
|
||||||
|
assert_eq!(got.id, "job-1");
|
||||||
|
assert_eq!(got.name, "test job");
|
||||||
|
assert_eq!(got.prompt, "say hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_jobs() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
for i in 0..3 {
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: format!("job-{}", i),
|
||||||
|
name: format!("job {}", i),
|
||||||
|
schedule: Schedule::Every { every_ms: 3600000 },
|
||||||
|
prompt: "ping".into(),
|
||||||
|
channel: "cli_chat".into(),
|
||||||
|
chat_id: "conn-1".into(),
|
||||||
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t + 1000,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
}
|
||||||
|
let jobs = storage.list_scheduled_jobs().await.unwrap();
|
||||||
|
assert_eq!(jobs.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_remove_job() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: "job-rm".into(),
|
||||||
|
name: "remove me".into(),
|
||||||
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
|
prompt: "hi".into(),
|
||||||
|
channel: "cli_chat".into(),
|
||||||
|
chat_id: "c".into(),
|
||||||
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
storage.remove_scheduled_job("job-rm").await.unwrap();
|
||||||
|
let result = storage.get_scheduled_job("job-rm").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_set_enabled() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: "job-toggle".into(),
|
||||||
|
name: "toggle".into(),
|
||||||
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
|
prompt: "hi".into(),
|
||||||
|
channel: "cli_chat".into(),
|
||||||
|
chat_id: "c".into(),
|
||||||
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap();
|
||||||
|
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
|
||||||
|
assert!(!got.enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_due_jobs_only_returns_enabled_and_overdue() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let jobs = vec![
|
||||||
|
ScheduledJob {
|
||||||
|
id: "due".into(), name: "due".into(),
|
||||||
|
schedule: Schedule::At { at: t }, prompt: "1".into(),
|
||||||
|
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||||
|
model: None, enabled: true, delete_after_run: false,
|
||||||
|
next_run_at: t - 1000, last_run_at: None,
|
||||||
|
last_status: None, last_error: None,
|
||||||
|
created_at: t, updated_at: t,
|
||||||
|
},
|
||||||
|
ScheduledJob {
|
||||||
|
id: "future".into(), name: "future".into(),
|
||||||
|
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(),
|
||||||
|
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||||
|
model: None, enabled: true, delete_after_run: false,
|
||||||
|
next_run_at: t + 99999999, last_run_at: None,
|
||||||
|
last_status: None, last_error: None,
|
||||||
|
created_at: t, updated_at: t,
|
||||||
|
},
|
||||||
|
ScheduledJob {
|
||||||
|
id: "disabled-due".into(), name: "disabled due".into(),
|
||||||
|
schedule: Schedule::At { at: t }, prompt: "3".into(),
|
||||||
|
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||||
|
model: None, enabled: false, delete_after_run: false,
|
||||||
|
next_run_at: t - 1000, last_run_at: None,
|
||||||
|
last_status: None, last_error: None,
|
||||||
|
created_at: t, updated_at: t,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
for j in &jobs {
|
||||||
|
storage.add_scheduled_job(j).await.unwrap();
|
||||||
|
}
|
||||||
|
let due = storage.due_scheduled_jobs(t, 10).await.unwrap();
|
||||||
|
assert_eq!(due.len(), 1);
|
||||||
|
assert_eq!(due[0].id, "due");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_record_run_and_list_runs() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: "job-run".into(), name: "run test".into(),
|
||||||
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
|
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(),
|
||||||
|
model: None, enabled: true, delete_after_run: false,
|
||||||
|
next_run_at: t, last_run_at: None,
|
||||||
|
last_status: None, last_error: None,
|
||||||
|
created_at: t, updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
|
||||||
|
let run = super::JobRun {
|
||||||
|
id: 0, job_id: "job-run".into(),
|
||||||
|
started_at: t, finished_at: t + 500,
|
||||||
|
status: "ok".into(), output: Some("hello".into()),
|
||||||
|
error: None, duration_ms: 500,
|
||||||
|
};
|
||||||
|
storage.record_scheduled_job_run(&run).await.unwrap();
|
||||||
|
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap();
|
||||||
|
assert_eq!(runs.len(), 1);
|
||||||
|
assert_eq!(runs[0].status, "ok");
|
||||||
|
assert_eq!(runs[0].output.as_deref(), Some("hello"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_job() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: "job-update".into(), name: "old name".into(),
|
||||||
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
|
prompt: "old prompt".into(), channel: "feishu".into(),
|
||||||
|
chat_id: "oc_1".into(), model: None,
|
||||||
|
enabled: true, delete_after_run: false,
|
||||||
|
next_run_at: t, last_run_at: None,
|
||||||
|
last_status: None, last_error: None,
|
||||||
|
created_at: t, updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
storage.update_scheduled_job(
|
||||||
|
"job-update",
|
||||||
|
Some("new prompt".into()),
|
||||||
|
Some(Schedule::Every { every_ms: 60000 }),
|
||||||
|
None, None, None,
|
||||||
|
).await.unwrap();
|
||||||
|
let got = storage.get_scheduled_job("job-update").await.unwrap();
|
||||||
|
assert_eq!(got.prompt, "new prompt");
|
||||||
|
}
|
||||||
|
}
|
||||||
343
src/tools/chat_manager.rs
Normal file
343
src/tools/chat_manager.rs
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::storage::Storage;
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
pub struct ChatManagerTool {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
available_channels: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatManagerTool {
|
||||||
|
pub fn new(storage: Arc<Storage>, available_channels: Vec<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
storage,
|
||||||
|
available_channels,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for ChatManagerTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"chat_manager"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"聊天管理工具。可以列出当前活跃的 session、可用的 channel、以及查看指定 session 的最近消息内容。\
|
||||||
|
action 可选值: list_sessions (列出最近活跃会话), list_channels (列出可用渠道), list_messages (查看最近消息)"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"action": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["list_sessions", "list_channels", "list_messages"],
|
||||||
|
"description": "操作类型: list_sessions 列出最近活跃会话, list_channels 列出可用渠道, list_messages 查看指定会话的最近消息"
|
||||||
|
},
|
||||||
|
"session_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "会话 ID,格式 channel:chat_id:dialog_id,仅在 action 为 list_messages 时必填"
|
||||||
|
},
|
||||||
|
"count": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "获取最近消息的数量,仅在 action 为 list_messages 时有效,默认 20"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["action"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn concurrency_safe(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let action = args["action"]
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?;
|
||||||
|
|
||||||
|
match action {
|
||||||
|
"list_channels" => self.list_channels().await,
|
||||||
|
"list_sessions" => self.list_sessions().await,
|
||||||
|
"list_messages" => self.list_messages(&args).await,
|
||||||
|
_ => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"Unknown action: {}. Supported: list_sessions, list_channels, list_messages",
|
||||||
|
action
|
||||||
|
)),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatManagerTool {
|
||||||
|
async fn list_channels(&self) -> anyhow::Result<ToolResult> {
|
||||||
|
let channels = self.available_channels.join(", ");
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("可用渠道 ({}): {}", self.available_channels.len(), channels),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_sessions(&self) -> anyhow::Result<ToolResult> {
|
||||||
|
let sessions = self
|
||||||
|
.storage
|
||||||
|
.list_all_active_sessions(20)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to list sessions: {}", e))?;
|
||||||
|
|
||||||
|
if sessions.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: "当前没有活跃的会话".to_string(),
|
||||||
|
error: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let now_ms = chrono::Utc::now().timestamp_millis();
|
||||||
|
let mut output = format!("活跃会话 (共 {} 个):\n", sessions.len());
|
||||||
|
|
||||||
|
for s in &sessions {
|
||||||
|
let ago = format_duration_ago(now_ms - s.last_active_at);
|
||||||
|
output.push_str(&format!(
|
||||||
|
"- {}\n title={} | channel={} chat_id={} | {}条消息 | 最后活动: {}前\n",
|
||||||
|
s.id, s.title, s.channel, s.chat_id, s.message_count, ago
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output,
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_messages(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let session_id = args["session_id"]
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("missing required parameter: session_id"))?;
|
||||||
|
|
||||||
|
let count = args["count"].as_i64().unwrap_or(20).clamp(1, 100);
|
||||||
|
|
||||||
|
let session = self
|
||||||
|
.storage
|
||||||
|
.get_session(session_id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("Session not found: {}", e))?;
|
||||||
|
|
||||||
|
let messages = self
|
||||||
|
.storage
|
||||||
|
.list_recent_messages(session_id, count)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to load messages: {}", e))?;
|
||||||
|
|
||||||
|
let mut output = format!(
|
||||||
|
"会话: {} ({})\n--- 最近 {} 条消息 (共 {} 条) ---\n",
|
||||||
|
session_id, session.title, messages.len(), session.message_count
|
||||||
|
);
|
||||||
|
|
||||||
|
if messages.is_empty() {
|
||||||
|
output.push_str("(暂无消息)\n");
|
||||||
|
} else {
|
||||||
|
for m in &messages {
|
||||||
|
let time = format_timestamp(m.created_at);
|
||||||
|
let role_tag = match m.role.as_str() {
|
||||||
|
"user" => "user ",
|
||||||
|
"assistant" => "assistant",
|
||||||
|
"tool" => "tool ",
|
||||||
|
"system" => "system ",
|
||||||
|
other => other,
|
||||||
|
};
|
||||||
|
let preview = truncate_content(&m.content, 200);
|
||||||
|
output.push_str(&format!(
|
||||||
|
"[{}] {} | {} | {}\n",
|
||||||
|
m.seq, time, role_tag, preview
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output,
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_duration_ago(millis: i64) -> String {
|
||||||
|
let secs = millis / 1000;
|
||||||
|
if secs < 60 {
|
||||||
|
format!("{}秒", secs)
|
||||||
|
} else if secs < 3600 {
|
||||||
|
format!("{}分钟", secs / 60)
|
||||||
|
} else if secs < 86400 {
|
||||||
|
format!("{}小时", secs / 3600)
|
||||||
|
} else {
|
||||||
|
format!("{}天", secs / 86400)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_timestamp(ms: i64) -> String {
|
||||||
|
if let Some(dt) = chrono::DateTime::from_timestamp_millis(ms) {
|
||||||
|
dt.format("%m-%d %H:%M").to_string()
|
||||||
|
} else {
|
||||||
|
ms.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_content(content: &str, max_len: usize) -> String {
|
||||||
|
let content = content.replace('\n', " ");
|
||||||
|
if content.chars().count() > max_len {
|
||||||
|
format!("{}...", content.chars().take(max_len).collect::<String>())
|
||||||
|
} else {
|
||||||
|
content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
async fn create_test_storage() -> (Arc<Storage>, TempDir) {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let db_path = dir.path().join("test.db");
|
||||||
|
let storage = Storage::new(&db_path).await.unwrap();
|
||||||
|
(Arc::new(storage), dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_channels() {
|
||||||
|
let (storage, _dir) = create_test_storage().await;
|
||||||
|
let tool = ChatManagerTool::new(storage, vec!["cli_chat".into(), "feishu".into()]);
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "action": "list_channels" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("cli_chat"));
|
||||||
|
assert!(result.output.contains("feishu"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_sessions_empty() {
|
||||||
|
let (storage, _dir) = create_test_storage().await;
|
||||||
|
let tool = ChatManagerTool::new(storage, vec![]);
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "action": "list_sessions" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("没有"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_sessions_with_data() {
|
||||||
|
let (storage, _dir) = create_test_storage().await;
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
for i in 0..3 {
|
||||||
|
let meta = crate::storage::session::SessionMeta {
|
||||||
|
id: format!("cli_chat:sid{}:dialog{}", i, i),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: format!("sid{}", i),
|
||||||
|
dialog_id: format!("dialog{}", i),
|
||||||
|
title: format!("会话{}", i),
|
||||||
|
created_at: now - i * 3600_000,
|
||||||
|
last_active_at: now - i * 3600_000,
|
||||||
|
message_count: i * 5,
|
||||||
|
routing_info: None,
|
||||||
|
deleted_at: None,
|
||||||
|
};
|
||||||
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let tool = ChatManagerTool::new(storage, vec![]);
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "action": "list_sessions" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("会话0"));
|
||||||
|
assert!(result.output.contains("会话1"));
|
||||||
|
assert!(result.output.contains("会话2"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_messages() {
|
||||||
|
let (storage, _dir) = create_test_storage().await;
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
let session_id = "cli_chat:sid0:dialog0";
|
||||||
|
let meta = crate::storage::session::SessionMeta {
|
||||||
|
id: session_id.to_string(),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: "sid0".to_string(),
|
||||||
|
dialog_id: "dialog0".to_string(),
|
||||||
|
title: "测试会话".to_string(),
|
||||||
|
created_at: now,
|
||||||
|
last_active_at: now,
|
||||||
|
message_count: 3,
|
||||||
|
routing_info: None,
|
||||||
|
deleted_at: None,
|
||||||
|
};
|
||||||
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
|
||||||
|
for i in 0..3 {
|
||||||
|
let msg = crate::storage::message::MessageMeta {
|
||||||
|
id: format!("msg{}", i),
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
seq: i as i64 + 1,
|
||||||
|
role: if i == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||||
|
content: format!("消息内容 {}", i),
|
||||||
|
media_refs: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
source: None,
|
||||||
|
created_at: now + i * 1000,
|
||||||
|
};
|
||||||
|
storage.append_message(session_id, &msg).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let tool = ChatManagerTool::new(storage, vec![]);
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "action": "list_messages", "session_id": session_id }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("消息内容 0"));
|
||||||
|
assert!(result.output.contains("消息内容 2"));
|
||||||
|
assert!(result.output.contains("测试会话"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unknown_action() {
|
||||||
|
let (storage, _dir) = create_test_storage().await;
|
||||||
|
let tool = ChatManagerTool::new(storage, vec![]);
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "action": "unknown" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("Unknown action"));
|
||||||
|
}
|
||||||
|
}
|
||||||
800
src/tools/cron.rs
Normal file
800
src/tools/cron.rs
Normal file
@ -0,0 +1,800 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::scheduler::{next_run_for_schedule, Schedule};
|
||||||
|
use crate::storage::{ScheduledJob, Storage};
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
fn now_ms() -> i64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CronAddTool {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
valid_channels: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CronAddTool {
|
||||||
|
pub fn new(storage: Arc<Storage>, valid_channels: Vec<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
storage,
|
||||||
|
valid_channels,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CronAddTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"cron_add"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Create a new scheduled task (cron job). The task will execute an AI prompt on a schedule \
|
||||||
|
and deliver the result to the specified channel/chat. \
|
||||||
|
Important: the execution environment is a fresh session with no access to your current \
|
||||||
|
conversation history. The prompt parameter MUST include all necessary context: \
|
||||||
|
what to do, the target audience, required output format, and any background information. \
|
||||||
|
Schedule formats: \
|
||||||
|
- 'every': {\"type\":\"every\",\"every_ms\":3600000} for every hour, \
|
||||||
|
- 'at': {\"type\":\"at\",\"at\":<unix_timestamp_ms>} for one-shot, \
|
||||||
|
- 'cron': {\"type\":\"cron\",\"expr\":\"0 0 9 * * *\"} for cron expressions (6-field: sec min hour dom month dow)."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"schedule": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Schedule definition. One of: {\"type\":\"every\",\"every_ms\":<ms>}, {\"type\":\"at\",\"at\":<unix_ms>}, or {\"type\":\"cron\",\"expr\":\"<cron_expr>\",\"tz\":\"<tz>\"}",
|
||||||
|
"required": ["type"]
|
||||||
|
},
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The AI prompt to execute on each trigger"
|
||||||
|
},
|
||||||
|
"channel": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Target channel for delivering results (e.g., 'feishu', 'cli_chat')"
|
||||||
|
},
|
||||||
|
"chat_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Target chat ID within the channel"
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Human-readable name for the job (optional, defaults to truncated prompt)"
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional model override for this job"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["schedule", "prompt", "channel", "chat_id"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let schedule_json = args
|
||||||
|
.get("schedule")
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("missing 'schedule'"))?;
|
||||||
|
let schedule: Schedule = serde_json::from_value(schedule_json.clone())
|
||||||
|
.map_err(|e| anyhow::anyhow!("invalid schedule: {}", e))?;
|
||||||
|
|
||||||
|
let prompt = args
|
||||||
|
.get("prompt")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
if prompt.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("prompt is required".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let channel = args
|
||||||
|
.get("channel")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
if !self.valid_channels.contains(&channel) {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: format!(
|
||||||
|
"Unknown channel '{}'. Available: {}",
|
||||||
|
channel,
|
||||||
|
self.valid_channels.join(", ")
|
||||||
|
),
|
||||||
|
error: Some(format!("Unknown channel: {}", channel)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_id = args
|
||||||
|
.get("chat_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
if chat_id.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("chat_id is required".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let name = args
|
||||||
|
.get("name")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
// char-boundary-safe truncation to 50 bytes
|
||||||
|
let limit = 50;
|
||||||
|
if prompt.len() <= limit {
|
||||||
|
prompt.as_str()
|
||||||
|
} else {
|
||||||
|
let mut end = limit;
|
||||||
|
while !prompt.is_char_boundary(end) {
|
||||||
|
end -= 1;
|
||||||
|
}
|
||||||
|
&prompt[..end]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string();
|
||||||
|
let model = args
|
||||||
|
.get("model")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
let now = now_ms();
|
||||||
|
let next_run_at = next_run_for_schedule(&schedule, now)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("could not compute next run time from schedule"))?;
|
||||||
|
|
||||||
|
let id = Uuid::new_v4().to_string()[..8].to_string();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: id.clone(),
|
||||||
|
name: name.clone(),
|
||||||
|
schedule,
|
||||||
|
prompt,
|
||||||
|
channel,
|
||||||
|
chat_id,
|
||||||
|
model,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.storage.add_scheduled_job(&job).await?;
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!(
|
||||||
|
"Scheduled job created: id={}, name=\"{}\", next_run_at={}",
|
||||||
|
id, name, next_run_at
|
||||||
|
),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── CronListTool ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct CronListTool {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CronListTool {
|
||||||
|
pub fn new(storage: Arc<Storage>) -> Self {
|
||||||
|
Self { storage }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CronListTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"cron_list"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"List all scheduled tasks (cron jobs) with their status and next run time."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"status": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["all", "enabled", "disabled"],
|
||||||
|
"description": "Filter by job status (default: all)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let filter = args
|
||||||
|
.get("status")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("all");
|
||||||
|
let jobs = self.storage.list_scheduled_jobs().await?;
|
||||||
|
|
||||||
|
let filtered: Vec<&ScheduledJob> = match filter {
|
||||||
|
"enabled" => jobs.iter().filter(|j| j.enabled).collect(),
|
||||||
|
"disabled" => jobs.iter().filter(|j| !j.enabled).collect(),
|
||||||
|
_ => jobs.iter().collect(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if filtered.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: "No scheduled jobs found.".into(),
|
||||||
|
error: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut lines = Vec::new();
|
||||||
|
for j in &filtered {
|
||||||
|
let status = if j.enabled { "enabled" } else { "disabled" };
|
||||||
|
let last = match (&j.last_status, &j.last_error) {
|
||||||
|
(Some(s), _) if s == "ok" => " last:ok".to_string(),
|
||||||
|
(Some(_), Some(e)) => format!(" last:err({})", &e[..e.len().min(40)]),
|
||||||
|
_ => String::new(),
|
||||||
|
};
|
||||||
|
let model = j.model.as_deref().unwrap_or("default");
|
||||||
|
lines.push(format!(
|
||||||
|
"[{}] id={} name=\"{}\" channel={} chat={} model={} next={}{}",
|
||||||
|
status, j.id, j.name, j.channel, j.chat_id, model, j.next_run_at, last
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: lines.join("\n"),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── CronRemoveTool ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct CronRemoveTool {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CronRemoveTool {
|
||||||
|
pub fn new(storage: Arc<Storage>) -> Self {
|
||||||
|
Self { storage }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CronRemoveTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"cron_remove"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Delete a scheduled task permanently by its job ID. Use cron_list first to find the ID."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"job_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The ID of the job to delete"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["job_id"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let job_id = args
|
||||||
|
.get("job_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
if job_id.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("job_id is required".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.storage.get_scheduled_job(&job_id).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(_) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: format!("Job {} not found.", job_id),
|
||||||
|
error: Some("not found".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.storage.remove_scheduled_job(&job_id).await?;
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Job {} deleted.", job_id),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── CronEnableTool ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct CronEnableTool {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CronEnableTool {
|
||||||
|
pub fn new(storage: Arc<Storage>) -> Self {
|
||||||
|
Self { storage }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CronEnableTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"cron_enable"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Enable a disabled scheduled task by its job ID."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"job_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The ID of the job to enable"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["job_id"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let job_id = args
|
||||||
|
.get("job_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
if job_id.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("job_id is required".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let job = self
|
||||||
|
.storage
|
||||||
|
.get_scheduled_job(&job_id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||||
|
|
||||||
|
let next = next_run_for_schedule(&job.schedule, now_ms());
|
||||||
|
self.storage.set_scheduled_job_enabled(&job_id, true).await?;
|
||||||
|
if let Some(n) = next {
|
||||||
|
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Job {} enabled.", job_id),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── CronDisableTool ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct CronDisableTool {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CronDisableTool {
|
||||||
|
pub fn new(storage: Arc<Storage>) -> Self {
|
||||||
|
Self { storage }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CronDisableTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"cron_disable"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Disable a scheduled task by its job ID without deleting it."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"job_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The ID of the job to disable"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["job_id"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let job_id = args
|
||||||
|
.get("job_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
if job_id.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("job_id is required".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = self
|
||||||
|
.storage
|
||||||
|
.get_scheduled_job(&job_id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||||
|
self.storage.set_scheduled_job_enabled(&job_id, false).await?;
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Job {} disabled.", job_id),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── CronUpdateTool ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct CronUpdateTool {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CronUpdateTool {
|
||||||
|
pub fn new(storage: Arc<Storage>) -> Self {
|
||||||
|
Self { storage }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CronUpdateTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"cron_update"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Update fields of an existing scheduled task. Only specified fields are changed."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"job_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The ID of the job to update"
|
||||||
|
},
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "New AI prompt"
|
||||||
|
},
|
||||||
|
"schedule": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "New schedule definition"
|
||||||
|
},
|
||||||
|
"channel": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "New target channel"
|
||||||
|
},
|
||||||
|
"chat_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "New target chat ID"
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "New model override"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["job_id"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let job_id = args
|
||||||
|
.get("job_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
if job_id.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("job_id is required".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = self
|
||||||
|
.storage
|
||||||
|
.get_scheduled_job(&job_id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||||
|
|
||||||
|
let prompt = args
|
||||||
|
.get("prompt")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let schedule: Option<Schedule> = match args.get("schedule") {
|
||||||
|
Some(s) => Some(
|
||||||
|
serde_json::from_value(s.clone())
|
||||||
|
.map_err(|e| anyhow::anyhow!("invalid schedule: {}", e))?,
|
||||||
|
),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let channel = args
|
||||||
|
.get("channel")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let chat_id = args
|
||||||
|
.get("chat_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let model = args
|
||||||
|
.get("model")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
self.storage
|
||||||
|
.update_scheduled_job(&job_id, prompt, schedule, channel, chat_id, model)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if args.get("schedule").is_some() {
|
||||||
|
let job = self.storage.get_scheduled_job(&job_id).await?;
|
||||||
|
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
|
||||||
|
self.storage.set_scheduled_job_next_run(&job_id, next).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Job {} updated.", job_id),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::scheduler::Schedule;
|
||||||
|
use crate::storage::{ScheduledJob, Storage};
|
||||||
|
use serde_json::json;
|
||||||
|
use sqlx::SqlitePool;
|
||||||
|
|
||||||
|
async fn setup_storage() -> Arc<Storage> {
|
||||||
|
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||||
|
Storage::init_scheduler_schema(&pool).await.unwrap();
|
||||||
|
Arc::new(Storage { pool })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn now() -> i64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_millis() as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_cron_add_tool() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let tool = CronAddTool::new(storage.clone(), vec!["cli_chat".to_string()]);
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"schedule": {"type": "every", "every_ms": 3600000},
|
||||||
|
"prompt": "report status",
|
||||||
|
"channel": "cli_chat",
|
||||||
|
"chat_id": "test-chat-1",
|
||||||
|
"name": "hourly report"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("hourly report"));
|
||||||
|
|
||||||
|
let jobs = storage.list_scheduled_jobs().await.unwrap();
|
||||||
|
assert_eq!(jobs.len(), 1);
|
||||||
|
assert_eq!(jobs[0].name, "hourly report");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_cron_add_invalid_channel() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let tool = CronAddTool::new(storage.clone(), vec!["cli_chat".to_string()]);
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"schedule": {"type": "every", "every_ms": 3600000},
|
||||||
|
"prompt": "test",
|
||||||
|
"channel": "nonexistent",
|
||||||
|
"chat_id": "x",
|
||||||
|
"name": "test"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.as_ref().unwrap().contains("Unknown channel"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_cron_list_tool() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
name: "list-test".into(),
|
||||||
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
|
prompt: "hi".into(),
|
||||||
|
channel: "cli_chat".into(),
|
||||||
|
chat_id: "c".into(),
|
||||||
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t + 1000,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
|
||||||
|
let tool = CronListTool::new(storage.clone());
|
||||||
|
let result = tool.execute(json!({})).await.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("list-test"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_cron_remove_tool() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: "job-rm-tool".into(),
|
||||||
|
name: "rm me".into(),
|
||||||
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
|
prompt: "hi".into(),
|
||||||
|
channel: "cli_chat".into(),
|
||||||
|
chat_id: "c".into(),
|
||||||
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
|
||||||
|
let tool = CronRemoveTool::new(storage.clone());
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({"job_id": "job-rm-tool"}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(storage.get_scheduled_job("job-rm-tool").await.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_cron_enable_disable_tools() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: "job-toggle-tool".into(),
|
||||||
|
name: "toggle".into(),
|
||||||
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
|
prompt: "hi".into(),
|
||||||
|
channel: "cli_chat".into(),
|
||||||
|
chat_id: "c".into(),
|
||||||
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
|
||||||
|
let disable_tool = CronDisableTool::new(storage.clone());
|
||||||
|
let result = disable_tool
|
||||||
|
.execute(json!({"job_id": "job-toggle-tool"}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
|
||||||
|
let got = storage.get_scheduled_job("job-toggle-tool").await.unwrap();
|
||||||
|
assert!(!got.enabled);
|
||||||
|
|
||||||
|
let enable_tool = CronEnableTool::new(storage.clone());
|
||||||
|
let result = enable_tool
|
||||||
|
.execute(json!({"job_id": "job-toggle-tool"}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
|
||||||
|
let got = storage.get_scheduled_job("job-toggle-tool").await.unwrap();
|
||||||
|
assert!(got.enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_cron_update_tool() {
|
||||||
|
let storage = setup_storage().await;
|
||||||
|
let t = now();
|
||||||
|
let job = ScheduledJob {
|
||||||
|
id: "job-update-tool".into(),
|
||||||
|
name: "old".into(),
|
||||||
|
schedule: Schedule::Every {
|
||||||
|
every_ms: 3600000,
|
||||||
|
},
|
||||||
|
prompt: "old prompt".into(),
|
||||||
|
channel: "feishu".into(),
|
||||||
|
chat_id: "oc_1".into(),
|
||||||
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t + 1000,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
|
};
|
||||||
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
|
||||||
|
let tool = CronUpdateTool::new(storage.clone());
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"job_id": "job-update-tool",
|
||||||
|
"prompt": "new prompt",
|
||||||
|
"schedule": {"type": "every", "every_ms": 60000}
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
|
||||||
|
let got = storage.get_scheduled_job("job-update-tool").await.unwrap();
|
||||||
|
assert_eq!(got.prompt, "new prompt");
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,5 +1,7 @@
|
|||||||
pub mod bash;
|
pub mod bash;
|
||||||
pub mod calculator;
|
pub mod calculator;
|
||||||
|
pub mod chat_manager;
|
||||||
|
pub mod cron;
|
||||||
pub mod file_edit;
|
pub mod file_edit;
|
||||||
pub mod file_read;
|
pub mod file_read;
|
||||||
pub mod file_write;
|
pub mod file_write;
|
||||||
@ -13,6 +15,7 @@ pub mod web_fetch;
|
|||||||
|
|
||||||
pub use bash::BashTool;
|
pub use bash::BashTool;
|
||||||
pub use calculator::CalculatorTool;
|
pub use calculator::CalculatorTool;
|
||||||
|
pub use chat_manager::ChatManagerTool;
|
||||||
pub use file_edit::FileEditTool;
|
pub use file_edit::FileEditTool;
|
||||||
pub use file_read::FileReadTool;
|
pub use file_read::FileReadTool;
|
||||||
pub use file_write::FileWriteTool;
|
pub use file_write::FileWriteTool;
|
||||||
|
|||||||
61
tests/test_scheduler.rs
Normal file
61
tests/test_scheduler.rs
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
/// Integration tests for the scheduled tasks (cron) system.
|
||||||
|
/// Run with: cargo test --test test_scheduler
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
/// Verify that Schedule types (de)serialize correctly.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_scheduler_types_roundtrip() {
|
||||||
|
use picobot::scheduler::Schedule;
|
||||||
|
|
||||||
|
let s1 = Schedule::Every { every_ms: 3600000 };
|
||||||
|
let json = serde_json::to_string(&s1).unwrap();
|
||||||
|
let s2: Schedule = serde_json::from_str(&json).unwrap();
|
||||||
|
match s2 {
|
||||||
|
Schedule::Every { every_ms } => assert_eq!(every_ms, 3600000),
|
||||||
|
_ => panic!("expected Every"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let s1 = Schedule::At { at: 1000000 };
|
||||||
|
let json = serde_json::to_string(&s1).unwrap();
|
||||||
|
let s2: Schedule = serde_json::from_str(&json).unwrap();
|
||||||
|
match s2 {
|
||||||
|
Schedule::At { at } => assert_eq!(at, 1000000),
|
||||||
|
_ => panic!("expected At"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let s1 = Schedule::Cron {
|
||||||
|
expr: "0 0 9 * * *".into(),
|
||||||
|
tz: None,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&s1).unwrap();
|
||||||
|
let s2: Schedule = serde_json::from_str(&json).unwrap();
|
||||||
|
match s2 {
|
||||||
|
Schedule::Cron { expr, tz } => {
|
||||||
|
assert_eq!(expr, "0 0 9 * * *");
|
||||||
|
assert!(tz.is_none());
|
||||||
|
}
|
||||||
|
_ => panic!("expected Cron"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verify that next_run_for_schedule produces valid future timestamps.
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_always_future() {
|
||||||
|
use picobot::scheduler::{next_run_for_schedule, Schedule};
|
||||||
|
|
||||||
|
let now = 1700000000000_i64;
|
||||||
|
|
||||||
|
let schedules = vec![
|
||||||
|
Schedule::Every { every_ms: 60000 },
|
||||||
|
Schedule::Cron {
|
||||||
|
expr: "0 0 9 * * *".into(),
|
||||||
|
tz: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
for s in &schedules {
|
||||||
|
let next = next_run_for_schedule(s, now);
|
||||||
|
assert!(next.is_some(), "expected next run for {:?}", s);
|
||||||
|
assert!(next.unwrap() > now, "next run should be after now for {:?}", s);
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user