From 102a4a63c59f769ac8dd447094b628ac2ea0b75f Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Thu, 14 May 2026 10:07:58 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=20/save=20=E5=91=BD?= =?UTF-8?q?=E4=BB=A4=EF=BC=8C=E6=94=AF=E6=8C=81=E4=BF=9D=E5=AD=98=E5=85=A8?= =?UTF-8?q?=E9=83=A8=E6=B6=88=E6=81=AF=E5=88=B0=E6=8C=87=E5=AE=9A=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=EF=BC=9B=E4=BF=AE=E6=94=B9=E5=91=BD=E4=BB=A4=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=E4=BB=A5=E5=8C=85=E5=90=AB=E6=96=B0?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/command/adapters/cli.rs | 60 +++++- src/command/handler.rs | 17 +- src/command/handlers/save_session.rs | 67 ++++--- src/command/mod.rs | 5 +- src/gateway/processor.rs | 54 ++++- src/gateway/ws.rs | 2 +- src/tools/scheduler_manage.rs | 289 ++++++++++++++++++++++++++- 7 files changed, 439 insertions(+), 55 deletions(-) diff --git a/src/command/adapters/cli.rs b/src/command/adapters/cli.rs index 9c256d8..1f2c723 100644 --- a/src/command/adapters/cli.rs +++ b/src/command/adapters/cli.rs @@ -43,14 +43,27 @@ impl InputAdapter for CliInputAdapter { // 解析 /save 命令 if trimmed == "/save" { - return Ok(Some(Command::SaveSession { filepath: None })); + return Ok(Some(Command::SaveSession { + filepath: None, + include_all: false, + })); } - if let Some(path) = trimmed.strip_prefix("/save ") { - let path = path.trim(); - return Ok(Some(Command::SaveSession { - filepath: Some(path.to_string()), - })); + if let Some(args) = trimmed.strip_prefix("/save ") { + let args = args.trim(); + // 解析参数:可能是 "all"、路径、或 "all 路径" + let (include_all, filepath) = if args == "all" { + // /save all - 保存全部消息 + (true, None) + } else if args.starts_with("all ") { + // /save all - 保存全部消息到指定路径 + let path = args[4..].trim(); + (true, Some(path.to_string())) + } else { + // /save - 保存活跃消息到指定路径 + (false, Some(args.to_string())) + }; + return Ok(Some(Command::SaveSession { filepath, include_all })); } // 不是命令,返回 None @@ -192,7 +205,7 @@ mod tests { assert!(result.is_some()); let cmd = result.unwrap(); - assert!(matches!(cmd, Command::SaveSession { filepath: None })); + assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false })); } #[test] @@ -207,7 +220,38 @@ mod tests { assert!(matches!( cmd, Command::SaveSession { - filepath: Some(ref p) + filepath: Some(ref p), + include_all: false, + } if p == "./debug/session.md" + )); + } + + #[test] + fn test_cli_input_adapter_save_all() { + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("test"); + + let result = adapter.try_parse("/save all", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: true })); + } + + #[test] + fn test_cli_input_adapter_save_all_with_path() { + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("test"); + + let result = adapter.try_parse("/save all ./debug/session.md", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!( + cmd, + Command::SaveSession { + filepath: Some(ref p), + include_all: true, } if p == "./debug/session.md" )); } diff --git a/src/command/handler.rs b/src/command/handler.rs index d6b3572..f50cae2 100644 --- a/src/command/handler.rs +++ b/src/command/handler.rs @@ -48,14 +48,15 @@ pub trait InChatCommandHandler: Send + Sync { /// * `session_manager` - 会话管理器(用于获取 session) /// /// # Returns - /// * `Ok(())` - 命令执行成功 + /// * `Ok(Some(msg))` - 命令执行成功,返回要发送给用户的消息 + /// * `Ok(None)` - 命令执行成功,无需发送消息 /// * `Err(AgentError)` - 命令执行失败 async fn handle( &self, cmd: Command, inbound: &InboundMessage, session_manager: &SessionManager, - ) -> Result<(), AgentError>; + ) -> Result, AgentError>; } /// 命令路由器 @@ -163,25 +164,25 @@ impl InChatCommandRouter { /// * `session_manager` - 会话管理器 /// /// # Returns - /// * `Ok(true)` - 命令被处理 - /// * `Ok(false)` - 没有合适的处理器 + /// * `Ok(Some(msg))` - 命令被处理,返回成功消息 + /// * `Ok(None)` - 没有合适的处理器 /// * `Err(AgentError)` - 执行失败 pub async fn dispatch( &self, cmd: Command, inbound: &InboundMessage, session_manager: &SessionManager, - ) -> Result { + ) -> Result, AgentError> { // 查找能处理此命令的处理器 for handler in &self.handlers { if handler.can_handle(&cmd) { - handler.handle(cmd, inbound, session_manager).await?; - return Ok(true); + let result = handler.handle(cmd, inbound, session_manager).await?; + return Ok(result); } } // 没有找到合适的处理器 - Ok(false) + Ok(None) } } diff --git a/src/command/handlers/save_session.rs b/src/command/handlers/save_session.rs index 72cd7e1..d4fd394 100644 --- a/src/command/handlers/save_session.rs +++ b/src/command/handlers/save_session.rs @@ -6,10 +6,8 @@ use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; use crate::config::LLMProviderConfig; use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider; -use crate::gateway::session::SessionManager; use crate::storage::{SessionRecord, SessionStore}; use crate::agent::AgentError; -use crate::bus::OutboundMessage; use async_trait::async_trait; use chrono::{Local, TimeZone}; use std::path::PathBuf; @@ -20,6 +18,7 @@ use std::sync::Arc; /// # Arguments /// * `session_id` - 会话ID /// * `filepath` - 可选的文件路径 +/// * `include_all` - 是否包含 cutoff 之前的所有消息 /// * `store` - 会话存储 /// * `provider_config` - LLM提供者配置 /// @@ -28,6 +27,7 @@ use std::sync::Arc; pub async fn save_session_to_file( session_id: &str, filepath: Option, + include_all: bool, store: &SessionStore, provider_config: &LLMProviderConfig, ) -> Result { @@ -37,10 +37,16 @@ pub async fn save_session_to_file( .map_err(|e| format!("Failed to get session: {}", e))? .ok_or_else(|| "Session not found".to_string())?; - // 获取所有消息(包括历史) - let messages = store - .load_all_messages(session_id) - .map_err(|e| format!("Failed to load messages: {}", e))?; + // 根据 include_all 决定加载消息范围 + let messages = if include_all { + store + .load_all_messages(session_id) + .map_err(|e| format!("Failed to load messages: {}", e))? + } else { + store + .load_messages(session_id) + .map_err(|e| format!("Failed to load messages: {}", e))? + }; // 计算用户消息数(用于系统提示词构建) let user_message_count = messages.iter().filter(|m| m.role == "user").count(); @@ -109,8 +115,8 @@ impl CommandHandler for SaveSessionCommandHandler { ctx: CommandContext, ) -> Result { match cmd { - Command::SaveSession { filepath } => { - handle_save_session(self, filepath, ctx).await + Command::SaveSession { filepath, include_all } => { + handle_save_session(self, filepath, include_all, ctx).await } _ => unreachable!(), } @@ -121,6 +127,7 @@ impl CommandHandler for SaveSessionCommandHandler { async fn handle_save_session( handler: &SaveSessionCommandHandler, filepath: Option, + include_all: bool, ctx: CommandContext, ) -> Result { let session_id = ctx @@ -132,18 +139,25 @@ async fn handle_save_session( let output_path = save_session_to_file( session_id, filepath, + include_all, &*handler.store, &handler.provider_config, ) .await .map_err(|e| CommandError::new("SAVE_ERROR", e))?; - // 获取消息数量用于返回 - let message_count = handler - .store - .load_all_messages(session_id) - .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))? - .len(); + // 根据 include_all 获取消息数量 + let message_count = if include_all { + handler + .store + .load_all_messages(session_id) + } else { + handler + .store + .load_messages(session_id) + } + .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))? + .len(); Ok(CommandResponse::success(ctx.request_id) .with_message( @@ -377,9 +391,9 @@ impl InChatCommandHandler for SaveSessionInChatHandler { cmd: Command, inbound: &InboundMessage, session_manager: &crate::gateway::session::SessionManager, - ) -> Result<(), AgentError> { - let Command::SaveSession { filepath } = cmd else { - return Ok(()); + ) -> Result, AgentError> { + let Command::SaveSession { filepath, include_all } = cmd else { + return Ok(None); }; // 通过 session_manager 获取 session @@ -387,7 +401,7 @@ impl InChatCommandHandler for SaveSessionInChatHandler { Some(s) => s, None => { tracing::error!("Session not found for channel: {}", inbound.channel); - return Ok(()); + return Ok(Some("Session not found".to_string())); } }; @@ -398,21 +412,23 @@ impl InChatCommandHandler for SaveSessionInChatHandler { let result = save_session_to_file( &session_id, filepath, + include_all, &*self.store, &self.provider_config, ) .await; - // 结果通过返回 Ok(()) 表示成功 - // 实际输出由调用者通过消息总线发送 + // 返回成功或失败消息 match result { Ok(output_path) => { - tracing::info!("Session saved to: {}", output_path.display()); - Ok(()) + let msg = format!("Session saved to: {}", output_path.display()); + tracing::info!("{}", msg); + Ok(Some(msg)) } Err(error) => { - tracing::error!("Failed to save session: {}", error); - Ok(()) + let msg = format!("Failed to save session: {}", error); + tracing::error!("{}", msg); + Ok(Some(msg)) } } } @@ -527,7 +543,8 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let handler = SaveSessionCommandHandler::new(store, test_config()); - assert!(handler.can_handle(&Command::SaveSession { filepath: None })); + assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false })); + assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true })); assert!(!handler.can_handle(&Command::CreateSession { title: None })); } } diff --git a/src/command/mod.rs b/src/command/mod.rs index e4b02fb..81e344f 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -14,7 +14,10 @@ pub enum Command { /// 创建新会话 CreateSession { title: Option }, /// 保存会话内容到 Markdown 文件 - SaveSession { filepath: Option }, + SaveSession { + filepath: Option, + include_all: bool, + }, } impl Command { diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 4f362ed..30cd637 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -4,7 +4,7 @@ use tokio::sync::Semaphore; use crate::agent::AgentError; use crate::bus::{InboundMessage, MessageBus, OutboundMessage}; -use crate::command::handler::{InChatCommandHandler, InChatCommandRouter}; +use crate::command::handler::InChatCommandRouter; use crate::command::Command; use crate::config::LLMProviderConfig; @@ -94,9 +94,28 @@ impl InboundProcessor { // 尝试解析为命令 if let Some(cmd) = parse_in_chat_command(&inbound.content) { // 使用命令路由器处理 - let handled = self.command_router.dispatch(cmd, &inbound, &self.session_manager).await?; - if handled { - return Ok(()); + match self.command_router.dispatch(cmd, &inbound, &self.session_manager).await? { + Some(response_msg) => { + // 发送命令执行结果给用户 + if let Err(error) = self + .bus + .publish_outbound(OutboundMessage::assistant( + inbound.channel.clone(), + inbound.chat_id.clone(), + response_msg, + None, + inbound.forwarded_metadata.clone(), + )) + .await + { + tracing::error!(error = %error, "Failed to publish command response"); + } + return Ok(()); + } + None => { + // 命令已处理但没有返回消息 + return Ok(()); + } } } @@ -156,7 +175,10 @@ impl InboundProcessor { /// 解析聊天中的命令 /// /// 支持格式: -/// - `/save [filepath]` - 保存会话 +/// - `/save` - 保存活跃会话消息(到 cutoff) +/// - `/save all` - 保存全部会话消息(包括 cutoff 之前) +/// - `/save ` - 保存活跃消息到指定路径 +/// - `/save all ` - 保存全部消息到指定路径 /// /// 返回 Some(Command) 如果是命令 /// 返回 None 如果不是命令 @@ -164,13 +186,25 @@ fn parse_in_chat_command(content: &str) -> Option { let trimmed = content.trim(); if trimmed.starts_with("/save") { - let path = trimmed[5..].trim(); - let filepath = if path.is_empty() { - None + let args = trimmed[5..].trim(); + + // 解析参数 + let (include_all, filepath) = if args.is_empty() { + // /save 无参数 - 只保存活跃消息 + (false, None) + } else if args == "all" { + // /save all - 保存全部消息 + (true, None) + } else if args.starts_with("all ") { + // /save all - 保存全部消息到指定路径 + let path = args[4..].trim(); + (true, Some(path.to_string())) } else { - Some(path.to_string()) + // /save - 保存活跃消息到指定路径 + (false, Some(args.to_string())) }; - Some(Command::SaveSession { filepath }) + + Some(Command::SaveSession { filepath, include_all }) } else { None } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 9d9839d..08b07bd 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -366,7 +366,7 @@ async fn handle_inbound( }; // 构建命令 - let cmd = crate::command::Command::SaveSession { filepath }; + let cmd = crate::command::Command::SaveSession { filepath, include_all: true }; let cmd_ctx = CommandContext::new("websocket") .with_session_id(&target_session_id); diff --git a/src/tools/scheduler_manage.rs b/src/tools/scheduler_manage.rs index ec741d6..f1411ea 100644 --- a/src/tools/scheduler_manage.rs +++ b/src/tools/scheduler_manage.rs @@ -31,7 +31,7 @@ impl Tool for SchedulerManageTool { } fn description(&self) -> &str { - "Manage repository-backed scheduled jobs. Supports actions: list, get, put, delete, pause, resume. Jobs are persisted by the configured scheduler job repository and executed by the scheduler runtime. \ + "Manage repository-backed scheduled jobs. Supports actions: list, get, put, update, delete, pause, resume. Jobs are persisted by the configured scheduler job repository and executed by the scheduler runtime. \ \ When creating agent_task or silent_agent_task jobs, keep prompt/system_prompt focused on the work to perform; do not restate execution times unless the task logic truly depends on them, because the trigger already controls timing. For cron schedules, standard cron syntax is supported: use 1-5 for Monday-Friday, 0 or 7 for Sunday. \ \ @@ -63,7 +63,7 @@ impl Tool for SchedulerManageTool { "properties": { "action": { "type": "string", - "enum": ["list", "get", "put", "delete", "pause", "resume"] + "enum": ["list", "get", "put", "update", "delete", "pause", "resume"] }, "id": { "type": "string", @@ -179,6 +179,16 @@ impl Tool for SchedulerManageTool { let saved = self.jobs.upsert_scheduler_job(&input)?; record_to_json(&saved) } + "update" => { + let id = require_str(&args, "id")?; + let record = self + .jobs + .get_scheduler_job(id)? + .ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?; + let input = build_update_upsert(context, &args, &self.known_agents, &record)?; + let saved = self.jobs.upsert_scheduler_job(&input)?; + record_to_json(&saved) + } _ => return Ok(error_result("Unsupported action")), }; @@ -257,6 +267,72 @@ fn build_upsert( }) } +fn build_update_upsert( + context: &crate::tools::ToolContext, + args: &serde_json::Value, + known_agents: &HashSet, + existing: &crate::storage::SchedulerJobRecord, +) -> anyhow::Result { + let mut upsert = record_to_upsert(existing); + + if let Some(schedule_value) = args.get("schedule") { + let schedule: SchedulerSchedule = serde_json::from_value(schedule_value.clone())?; + schedule.validate(&upsert.id)?; + upsert.schedule = serde_json::to_value(&schedule)?; + let (interval_secs, startup_delay_secs) = match &schedule { + SchedulerSchedule::Interval { + seconds, + startup_delay_secs, + } => (*seconds as i64, *startup_delay_secs as i64), + _ => (0, 0), + }; + upsert.interval_secs = interval_secs; + upsert.startup_delay_secs = startup_delay_secs; + upsert.next_fire_at = None; + } + + if args.get("target").is_some() { + upsert.target = enrich_target_from_context( + args.get("target").cloned().unwrap_or_else(|| json!({})), + context, + ); + } + + if let Some(payload) = args.get("payload") { + upsert.payload = payload.clone(); + } + + if let Some(enabled) = args.get("enabled").and_then(|value| value.as_bool()) { + upsert.enabled = enabled; + upsert.state = if enabled { + SchedulerJobState::Scheduled + } else { + SchedulerJobState::Paused + }; + if !enabled { + upsert.paused_at = Some(current_timestamp()); + } else { + upsert.paused_at = None; + upsert.completed_at = None; + } + upsert.next_fire_at = None; + } + + if args.get("max_runs").is_some() { + upsert.max_runs = args.get("max_runs").and_then(|value| value.as_i64()); + } + + if upsert.kind == "agent_task" || upsert.kind == "silent_agent_task" { + validate_agent_task_payload(&upsert.payload, known_agents)?; + validate_target_fields(&upsert.target, &["channel", "chat_id"], &upsert.kind)?; + } else if upsert.kind == "outbound_message" { + validate_outbound_message_payload(&upsert.payload)?; + validate_target_fields(&upsert.target, &["channel", "chat_id"], "outbound_message")?; + } + + Ok(upsert) +} + fn enrich_target_from_context( target: serde_json::Value, context: &crate::tools::ToolContext, @@ -713,4 +789,213 @@ mod tests { assert!(payload_description.contains("每天9点")); assert!(payload_description.contains("每小时")); } + + #[tokio::test] + async fn test_scheduler_manage_update_partial_fields() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store.clone(), HashSet::new()); + + // First, create a job + let put_result = tool + .execute(json!({ + "action": "put", + "id": "test_update_job", + "kind": "outbound_message", + "schedule": { + "type": "interval", + "seconds": 60 + }, + "target": { + "channel": "test-channel", + "chat_id": "oc_demo" + }, + "payload": { + "content": "original message" + }, + "max_runs": 10 + })) + .await + .unwrap(); + assert!(put_result.success); + + // Update only payload + let update_result = tool + .execute(json!({ + "action": "update", + "id": "test_update_job", + "payload": { + "content": "updated message" + } + })) + .await + .unwrap(); + assert!(update_result.success); + assert!(update_result.output.contains("updated message")); + assert!(update_result.output.contains("test_update_job")); + + // Verify other fields preserved + let get_result = tool + .execute(json!({ + "action": "get", + "id": "test_update_job" + })) + .await + .unwrap(); + assert!(get_result.success); + assert!(get_result.output.contains("interval")); + assert!(get_result.output.contains("test-channel")); + assert!(get_result.output.contains("max_runs\": 10")); + } + + #[tokio::test] + async fn test_scheduler_manage_update_schedule() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store.clone(), HashSet::new()); + + // Create job + let _ = tool + .execute(json!({ + "action": "put", + "id": "test_update_schedule", + "kind": "outbound_message", + "schedule": { + "type": "interval", + "seconds": 60 + }, + "target": { + "channel": "test", + "chat_id": "oc_demo" + }, + "payload": { "content": "ping" } + })) + .await + .unwrap(); + + // Update schedule + let update_result = tool + .execute(json!({ + "action": "update", + "id": "test_update_schedule", + "schedule": { + "type": "cron", + "expression": "0 9 * * *" + } + })) + .await + .unwrap(); + assert!(update_result.success); + assert!(update_result.output.contains("cron")); + assert!(update_result.output.contains("0 9 * * *")); + } + + #[tokio::test] + async fn test_scheduler_manage_update_enabled() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store.clone(), HashSet::new()); + + // Create enabled job + let _ = tool + .execute(json!({ + "action": "put", + "id": "test_update_enabled", + "kind": "outbound_message", + "schedule": { "type": "interval", "seconds": 60 }, + "target": { "channel": "test", "chat_id": "oc_demo" }, + "payload": { "content": "ping" }, + "enabled": true + })) + .await + .unwrap(); + + // Disable it + let update_result = tool + .execute(json!({ + "action": "update", + "id": "test_update_enabled", + "enabled": false + })) + .await + .unwrap(); + assert!(update_result.success); + assert!(update_result.output.contains("\"enabled\": false")); + assert!(update_result.output.contains("paused")); + + // Re-enable it + let update_result = tool + .execute(json!({ + "action": "update", + "id": "test_update_enabled", + "enabled": true + })) + .await + .unwrap(); + assert!(update_result.success); + assert!(update_result.output.contains("\"enabled\": true")); + assert!(update_result.output.contains("scheduled")); + } + + #[tokio::test] + async fn test_scheduler_manage_update_job_not_found() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store, HashSet::new()); + + let result = tool + .execute(json!({ + "action": "update", + "id": "nonexistent_job", + "payload": { "content": "new" } + })) + .await; + + assert!(result.is_err()); + let error = result.err().unwrap().to_string(); + assert!(error.contains("scheduler job 'nonexistent_job' not found")); + } + + #[tokio::test] + async fn test_scheduler_manage_update_preserves_agent_task_agent() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store.clone(), HashSet::from(["planner".to_string()])); + + // Create agent_task + let _ = tool + .execute(json!({ + "action": "put", + "id": "test_update_agent", + "kind": "agent_task", + "schedule": { "type": "cron", "expression": "0 9 * * *" }, + "target": { "channel": "test", "chat_id": "oc_demo" }, + "payload": { + "prompt": "original task", + "agent": "planner" + } + })) + .await + .unwrap(); + + // Update only prompt + let update_result = tool + .execute(json!({ + "action": "update", + "id": "test_update_agent", + "payload": { + "prompt": "updated task", + "agent": "planner" + } + })) + .await + .unwrap(); + assert!(update_result.success); + + // Verify agent preserved (when explicitly provided) + let get_result = tool + .execute(json!({ + "action": "get", + "id": "test_update_agent" + })) + .await + .unwrap(); + assert!(get_result.output.contains("planner")); + assert!(get_result.output.contains("updated task")); + } }