diff --git a/src/cli/input.rs b/src/cli/input.rs index 7a0820d..f6b0417 100644 --- a/src/cli/input.rs +++ b/src/cli/input.rs @@ -10,14 +10,8 @@ pub enum InputEvent { #[derive(Debug, Clone, PartialEq, Eq)] pub enum InputCommand { Exit, - Clear, New(Option), Save(Option), - Sessions, - Use(String), - Rename(String), - Archive, - Delete, } pub struct InputHandler { @@ -74,14 +68,8 @@ impl InputHandler { match command { "/quit" | "/exit" | "/q" => Some(InputCommand::Exit), - "/clear" => Some(InputCommand::Clear), "/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))), "/save" => Some(InputCommand::Save(arg.map(ToOwned::to_owned))), - "/sessions" => Some(InputCommand::Sessions), - "/use" => arg.map(|value| InputCommand::Use(value.to_string())), - "/rename" => arg.map(|value| InputCommand::Rename(value.to_string())), - "/archive" => Some(InputCommand::Archive), - "/delete" => Some(InputCommand::Delete), _ => None, } } @@ -120,10 +108,6 @@ mod tests { handler.handle_special_commands("/quit"), Some(InputCommand::Exit) ); - assert_eq!( - handler.handle_special_commands("/clear"), - Some(InputCommand::Clear) - ); assert_eq!( handler.handle_special_commands("/new"), Some(InputCommand::New(None)) @@ -140,27 +124,6 @@ mod tests { handler.handle_special_commands("/save ./debug/session.md"), Some(InputCommand::Save(Some("./debug/session.md".to_string()))) ); - assert_eq!( - handler.handle_special_commands("/sessions"), - Some(InputCommand::Sessions) - ); - assert_eq!( - handler.handle_special_commands("/use abc123"), - Some(InputCommand::Use("abc123".to_string())) - ); - assert_eq!( - handler.handle_special_commands("/rename project alpha"), - Some(InputCommand::Rename("project alpha".to_string())) - ); - assert_eq!( - handler.handle_special_commands("/archive"), - Some(InputCommand::Archive) - ); - assert_eq!( - handler.handle_special_commands("/delete"), - Some(InputCommand::Delete) - ); assert_eq!(handler.handle_special_commands("/unknown"), None); - assert_eq!(handler.handle_special_commands("/use"), None); } } diff --git a/src/client/mod.rs b/src/client/mod.rs index a86b259..3cf9f62 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,36 +8,6 @@ use tokio_tungstenite::{connect_async, tungstenite::Message}; use crate::cli::{InputCommand, InputEvent, InputHandler}; -fn format_session_list( - sessions: &[crate::protocol::SessionSummary], - current_session_id: Option<&str>, -) -> String { - if sessions.is_empty() { - return "No sessions found.".to_string(); - } - - let mut lines = Vec::with_capacity(sessions.len() + 1); - lines.push("Sessions:".to_string()); - for session in sessions { - let marker = if current_session_id == Some(session.session_id.as_str()) { - "*" - } else { - "-" - }; - let archived = if session.archived_at.is_some() { - " [archived]" - } else { - "" - }; - lines.push(format!( - "{} {} | {} | {} messages{}", - marker, session.session_id, session.title, session.message_count, archived, - )); - } - - lines.join("\n") -} - fn parse_message(raw: &str) -> Result { serde_json::from_str(raw) } @@ -54,7 +24,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { let mut input = InputHandler::new(); let mut current_session_id: Option = None; - input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /reset, /sessions, /use , /rename , /archive, /delete, /clear, /quit\n").await?; + input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /quit\n").await?; // Main loop: poll both stdin and WebSocket loop { @@ -91,29 +61,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { current_session_id = Some(session_id.clone()); input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?; } - WsOutbound::SessionList { sessions, current_session_id: listed_current } => { - let display = format_session_list(&sessions, listed_current.as_deref()); - input.write_output(&format!("{}\n", display)).await?; - } - WsOutbound::SessionLoaded { session_id, title, message_count } => { - current_session_id = Some(session_id.clone()); - input.write_output(&format!("Loaded session: {} ({}, {} messages)\n", session_id, title, message_count)).await?; - } - WsOutbound::SessionRenamed { session_id, title } => { - input.write_output(&format!("Renamed session: {} -> {}\n", session_id, title)).await?; - } - WsOutbound::SessionArchived { session_id } => { - input.write_output(&format!("Archived session: {}\n", session_id)).await?; - } - WsOutbound::SessionDeleted { session_id } => { - if current_session_id.as_deref() == Some(session_id.as_str()) { - current_session_id = None; - } - input.write_output(&format!("Deleted session: {}\n", session_id)).await?; - } - WsOutbound::HistoryCleared { session_id } => { - input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?; - } WsOutbound::SessionSaved { session_id, filepath } => { input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?; } @@ -138,39 +85,25 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { input.write_output("Goodbye!").await?; break; } - InputEvent::Command(InputCommand::Clear) => { - let inbound = WsInbound::ClearHistory { - chat_id: None, - session_id: current_session_id.clone(), - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } InputEvent::Command(InputCommand::New(title)) => { - // 使用新的命令层:通过 CliInputAdapter 构建 Command + // 使用 CliInputAdapter 构建 Command let adapter = CliInputAdapter::new(); let ctx = AdapterContext::new("cli") .with_session_id(current_session_id.as_deref().unwrap_or("")); // 构建输入字符串 - let input = match title { + let input_str = match title { Some(t) => format!("/new {}", t), None => "/new".to_string(), }; // 解析为 Command - match adapter.try_parse(&input, ctx) { + match adapter.try_parse(&input_str, ctx) { Ok(Some(command)) => { - // 序列化为 JSON 通过 WebSocket 发送 + // 序列化为 JSON let json = serde_json::to_string(&command).unwrap_or_default(); - let inbound = WsInbound::UserInput { - content: json, - channel: None, - chat_id: current_session_id.clone(), - sender_id: None, - }; + // 通过 Command 消息发送 + let inbound = WsInbound::Command { payload: json }; if let Ok(text) = serialize_inbound(&inbound) { let _ = sender.send(Message::Text(text.into())).await; } @@ -184,62 +117,40 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { } continue; } - InputEvent::Command(InputCommand::Sessions) => { - let inbound = WsInbound::ListSessions { - include_archived: true, - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } - InputEvent::Command(InputCommand::Use(session_id)) => { - let inbound = WsInbound::LoadSession { session_id }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } - InputEvent::Command(InputCommand::Rename(title)) => { - let inbound = WsInbound::RenameSession { - session_id: current_session_id.clone(), - title, - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } - InputEvent::Command(InputCommand::Archive) => { - let inbound = WsInbound::ArchiveSession { - session_id: current_session_id.clone(), - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } - InputEvent::Command(InputCommand::Delete) => { - let inbound = WsInbound::DeleteSession { - session_id: current_session_id.clone(), - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } InputEvent::Command(InputCommand::Save(filepath)) => { - let inbound = WsInbound::SaveSession { - filepath, - session_id: current_session_id.clone(), + // 使用 CliInputAdapter 构建 Command + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("cli") + .with_session_id(current_session_id.as_deref().unwrap_or("")); + + // 构建输入字符串 + let input_str = match filepath { + Some(p) => format!("/save {}", p), + None => "/save".to_string(), }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; + + // 解析为 Command + match adapter.try_parse(&input_str, ctx) { + Ok(Some(command)) => { + // 序列化为 JSON + let json = serde_json::to_string(&command).unwrap_or_default(); + // 通过 Command 消息发送 + let inbound = WsInbound::Command { payload: json }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + } + Ok(None) => { + tracing::warn!("Failed to parse /save command"); + } + Err(e) => { + tracing::error!(error = %e, "Error parsing /save command"); + } } continue; } InputEvent::Message(msg) => { - let inbound = WsInbound::UserInput { + let inbound = WsInbound::Message { content: msg.content, channel: None, chat_id: current_session_id.clone(), diff --git a/src/command/adapters/channel.rs b/src/command/adapters/channel.rs new file mode 100644 index 0000000..e9af6e7 --- /dev/null +++ b/src/command/adapters/channel.rs @@ -0,0 +1,187 @@ +use crate::command::adapter::{AdapterError, InputAdapter}; +use crate::command::context::AdapterContext; +use crate::command::Command; + +/// Channel 输入适配器 +/// +/// 将 Channel 消息中的文本命令(如 "/new", "/save")转换为 Command +pub struct ChannelInputAdapter; + +impl ChannelInputAdapter { + /// 创建新的 Channel 输入适配器 + pub fn new() -> Self { + Self + } +} + +impl Default for ChannelInputAdapter { + fn default() -> Self { + Self::new() + } +} + +impl InputAdapter for ChannelInputAdapter { + fn try_parse( + &self, + input: &str, + _ctx: AdapterContext, + ) -> Result<Option<Command>, AdapterError> { + let trimmed = input.trim(); + + // 解析 /new 命令 + if trimmed == "/new" { + return Ok(Some(Command::CreateSession { title: None })); + } + + if let Some(title) = trimmed.strip_prefix("/new ") { + let title = title.trim(); + return Ok(Some(Command::CreateSession { + title: Some(title.to_string()), + })); + } + + // 解析 /save 命令 + if trimmed == "/save" { + return Ok(Some(Command::SaveSession { + filepath: None, + include_all: false, + })); + } + + if let Some(args) = trimmed.strip_prefix("/save ") { + let args = args.trim(); + // 解析参数:可能是 "all"、路径、或 "all 路径" + let (include_all, filepath) = if args == "all" { + // /save all - 保存全部消息 + (true, None) + } else if args.starts_with("all ") { + // /save all <filepath> - 保存全部消息到指定路径 + let path = args[4..].trim(); + (true, Some(path.to_string())) + } else { + // /save <filepath> - 保存活跃消息到指定路径 + (false, Some(args.to_string())) + }; + return Ok(Some(Command::SaveSession { filepath, include_all })); + } + + // 不是命令,返回 None + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_channel_adapter_new_without_title() { + let adapter = ChannelInputAdapter::new(); + let ctx = AdapterContext::new("channel"); + + let result = adapter.try_parse("/new", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!(cmd, Command::CreateSession { title: None })); + } + + #[test] + fn test_channel_adapter_new_with_title() { + let adapter = ChannelInputAdapter::new(); + let ctx = AdapterContext::new("channel"); + + let result = adapter.try_parse("/new planning session", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!( + cmd, + Command::CreateSession { + title: Some(ref t) + } if t == "planning session" + )); + } + + #[test] + fn test_channel_adapter_save_without_path() { + let adapter = ChannelInputAdapter::new(); + let ctx = AdapterContext::new("channel"); + + let result = adapter.try_parse("/save", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!( + cmd, + Command::SaveSession { + filepath: None, + include_all: false, + } + )); + } + + #[test] + fn test_channel_adapter_save_with_path() { + let adapter = ChannelInputAdapter::new(); + let ctx = AdapterContext::new("channel"); + + let result = adapter.try_parse("/save ./session.md", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!( + cmd, + Command::SaveSession { + filepath: Some(ref p), + include_all: false, + } if p == "./session.md" + )); + } + + #[test] + fn test_channel_adapter_save_all() { + let adapter = ChannelInputAdapter::new(); + let ctx = AdapterContext::new("channel"); + + let result = adapter.try_parse("/save all", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!( + cmd, + Command::SaveSession { + filepath: None, + include_all: true, + } + )); + } + + #[test] + fn test_channel_adapter_save_all_with_path() { + let adapter = ChannelInputAdapter::new(); + let ctx = AdapterContext::new("channel"); + + let result = adapter.try_parse("/save all ./session.md", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!( + cmd, + Command::SaveSession { + filepath: Some(ref p), + include_all: true, + } if p == "./session.md" + )); + } + + #[test] + fn test_channel_adapter_not_command() { + let adapter = ChannelInputAdapter::new(); + let ctx = AdapterContext::new("channel"); + + let result = adapter.try_parse("hello world", ctx).unwrap(); + + assert!(result.is_none()); + } +} diff --git a/src/command/adapters/mod.rs b/src/command/adapters/mod.rs index 3ea0590..d487429 100644 --- a/src/command/adapters/mod.rs +++ b/src/command/adapters/mod.rs @@ -1,2 +1,3 @@ +pub mod channel; pub mod cli; pub mod websocket; diff --git a/src/config/mod.rs b/src/config/mod.rs index 2a7fb86..0bff93c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1472,7 +1472,7 @@ mod tests { assert!(config.scheduler.jobs.is_empty()); let effective_jobs = config.scheduler.effective_jobs(&config.time); - assert_eq!(effective_jobs.len(), 1); + assert_eq!(effective_jobs.len(), 2); assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID); assert_eq!(effective_jobs[0].kind, SchedulerJobKind::InternalEvent); assert_eq!( @@ -1481,6 +1481,8 @@ mod tests { expression: "0 */4 * * *".to_string(), } ); + // 第二个内置作业是会话清理 + assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID); } #[test] @@ -1516,7 +1518,8 @@ mod tests { let effective_jobs = scheduler.effective_jobs(&TimeConfig { timezone: "Asia/Shanghai".to_string(), }); - assert_eq!(effective_jobs.len(), 2); + assert_eq!(effective_jobs.len(), 3); // 2个内置 + 1个自定义 + // 第一个作业:内存维护(被覆盖为禁用) assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID); assert!(!effective_jobs[0].enabled); assert_eq!( @@ -1525,7 +1528,11 @@ mod tests { expression: "15 2 * * *".to_string(), } ); - assert_eq!(effective_jobs[1].id, "custom.reminder"); + // 第二个作业:会话清理(保持默认) + assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID); + assert!(effective_jobs[1].enabled); + // 第三个作业:自定义提醒 + assert_eq!(effective_jobs[2].id, "custom.reminder"); } #[test] diff --git a/src/gateway/command.rs b/src/gateway/command.rs index 03ad8f0..c0278b3 100644 --- a/src/gateway/command.rs +++ b/src/gateway/command.rs @@ -1,165 +1,2 @@ -use crate::agent::AgentError; - -use super::session::Session; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum InChatCommand { - FreshConversation, -} - -fn parse_in_chat_command(content: &str) -> Option<InChatCommand> { - match content.trim() { - "/new" | "/reset" => Some(InChatCommand::FreshConversation), - _ => None, - } -} - -pub(crate) fn handle_in_chat_command( - session: &mut Session, - chat_id: &str, - content: &str, -) -> Result<Option<String>, AgentError> { - match parse_in_chat_command(content) { - Some(InChatCommand::FreshConversation) => { - session.reset_chat_context(chat_id)?; - Ok(Some("Started a fresh conversation.".to_string())) - } - None => Ok(None), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::bus::ChatMessage; - use crate::config::LLMProviderConfig; - use crate::skills::SkillRuntime; - use crate::storage::SessionStore; - use crate::tools::ToolRegistry; - use std::collections::HashMap; - use std::sync::Arc; - use tokio::sync::mpsc; - - const TEST_CHANNEL: &str = "test-channel"; - - fn test_provider_config() -> LLMProviderConfig { - LLMProviderConfig { - provider_type: "openai".to_string(), - name: "test".to_string(), - base_url: "http://localhost".to_string(), - api_key: "test-key".to_string(), - extra_headers: HashMap::new(), - llm_timeout_secs: 120, - memory_maintenance_timeout_secs: 600, - model_id: "test-model".to_string(), - temperature: Some(0.0), - max_tokens: Some(32), - context_window_tokens: None, - model_extra: HashMap::new(), - max_tool_iterations: 1, - tool_result_max_chars: 20_000, - context_tool_result_trim_chars: 20_000, - } - } - - #[test] - fn test_parse_in_chat_command_aliases() { - assert_eq!( - parse_in_chat_command("/new"), - Some(InChatCommand::FreshConversation) - ); - assert_eq!( - parse_in_chat_command(" /reset \n"), - Some(InChatCommand::FreshConversation) - ); - assert_eq!(parse_in_chat_command("/new planning"), None); - assert_eq!(parse_in_chat_command("please /reset"), None); - } - - #[tokio::test] - async fn test_handle_in_chat_command_resets_active_history_only() { - let store = Arc::new(SessionStore::in_memory().unwrap()); - let (user_tx, _user_rx) = mpsc::channel(4); - let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(ToolRegistry::new()); - let mut session = Session::new( - TEST_CHANNEL.to_string(), - test_provider_config(), - user_tx, - tools, - skills, - store.clone(), - 100, - Some(4), - ) - .await - .unwrap(); - - session.ensure_persistent_session("chat-1").unwrap(); - session.ensure_chat_loaded("chat-1").unwrap(); - session - .append_persisted_message("chat-1", ChatMessage::user("hello")) - .unwrap(); - - let response = handle_in_chat_command(&mut session, "chat-1", "/reset") - .unwrap() - .unwrap(); - - assert_eq!(response, "Started a fresh conversation."); - assert!(session.get_history("chat-1").unwrap().is_empty()); - assert!( - store - .load_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .is_empty() - ); - assert_eq!( - store - .load_all_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .len(), - // 新设计:系统提示词不再持久化,只有 1 条用户消息 - 1, - ); - - session.ensure_chat_loaded("chat-1").unwrap(); - let history = session.get_history("chat-1").unwrap(); - // 新设计:系统提示词不再持久化到历史记录 - assert_eq!(history.len(), 0); - } - - #[tokio::test] - async fn test_reset_reinjects_agent_prompt_before_next_user_message() { - let store = Arc::new(SessionStore::in_memory().unwrap()); - let (user_tx, _user_rx) = mpsc::channel(4); - let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(ToolRegistry::new()); - let mut session = Session::new( - TEST_CHANNEL.to_string(), - test_provider_config(), - user_tx, - tools, - skills, - store, - 100, - Some(4), - ) - .await - .unwrap(); - - session.ensure_persistent_session("chat-1").unwrap(); - session.ensure_chat_loaded("chat-1").unwrap(); - session - .append_persisted_message("chat-1", ChatMessage::user("hello")) - .unwrap(); - - handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap(); - session - .ensure_agent_prompt_before_user_message("chat-1") - .unwrap(); - - // 新设计:系统提示词不再持久化到历史记录 - let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 0); - } -} +// 此文件已废弃,InChatCommand 功能已合并到 Command 系统 +// 保留文件以避免破坏现有 import,但内容为空 diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index d1624ae..265527a 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -7,7 +7,6 @@ use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDUL use crate::config::LLMProviderConfig; use tokio::sync::Mutex; -use super::command::handle_in_chat_command; use super::compaction::schedule_background_history_compaction; use super::message_prepare::enrich_user_content_with_media_refs; use super::session::Session; @@ -138,18 +137,6 @@ impl AgentExecutionService { session_guard.ensure_persistent_session(request.chat_id)?; session_guard.ensure_chat_loaded(request.chat_id)?; - if let Some(command_response) = - handle_in_chat_command(&mut session_guard, request.chat_id, request.content)? - { - return Ok(vec![OutboundMessage::assistant( - request.channel_name.to_string(), - request.chat_id.to_string(), - command_response, - None, - HashMap::new(), - )]); - } - session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?; let media_refs: Vec<String> = request diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 5c21705..fd80335 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -4,8 +4,11 @@ use tokio::sync::Semaphore; use crate::agent::{AgentError, CompositeSystemPromptProvider}; use crate::bus::{InboundMessage, MessageBus, OutboundMessage}; -use crate::command::handler::InChatCommandRouter; -use crate::command::Command; +use crate::command::adapter::InputAdapter; +use crate::command::adapters::channel::ChannelInputAdapter; +use crate::command::handler::CommandRouter; +use crate::command::handlers::save_session::SaveSessionCommandHandler; +use crate::command::handlers::session::SessionCommandHandler; use crate::config::LLMProviderConfig; use crate::gateway::agent_prompt_provider::AgentPromptProvider; use crate::skills::SkillPromptProvider; @@ -18,7 +21,7 @@ pub struct InboundProcessor { session_manager: SessionManager, semaphore: Arc<Semaphore>, _provider_config: LLMProviderConfig, - command_router: Arc<InChatCommandRouter>, + command_router: Arc<CommandRouter>, } impl InboundProcessor { @@ -29,7 +32,11 @@ impl InboundProcessor { provider_config: LLMProviderConfig, ) -> Self { // 创建命令路由器并注册处理器 - let mut command_router = InChatCommandRouter::new(); + let mut command_router = CommandRouter::new(); + + // 注册 Session 处理器 + let cli_sessions = session_manager.cli_sessions(); + command_router.register(Box::new(SessionCommandHandler::new(cli_sessions))); // 注册 save_session 处理器 let store = session_manager.store(); @@ -43,7 +50,7 @@ impl InboundProcessor { )), Box::new(SkillPromptProvider::new(skills)), ])); - command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new( + command_router.register(Box::new(SaveSessionCommandHandler::new( store, system_prompt_provider, ))); @@ -103,18 +110,28 @@ impl InboundProcessor { } async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> { - // 尝试解析为命令 - if let Some(cmd) = parse_in_chat_command(&inbound.content) { + // 使用 ChannelInputAdapter 尝试解析命令 + let adapter = ChannelInputAdapter::new(); + let ctx = crate::command::context::AdapterContext::new(&inbound.channel) + .with_session_id(&inbound.chat_id); + + if let Ok(Some(cmd)) = adapter.try_parse(&inbound.content, ctx) { // 使用命令路由器处理 - match self.command_router.dispatch(cmd, &inbound, &self.session_manager).await? { - Some(response_msg) => { - // 发送命令执行结果给用户 + let cmd_ctx = crate::command::context::CommandContext::new(&inbound.channel) + .with_session_id(&inbound.chat_id); + + let response = self.command_router.dispatch_with_response(cmd, cmd_ctx).await; + + // 发送响应给用户 + if response.success { + // 提取响应消息 + for msg in &response.messages { if let Err(error) = self .bus .publish_outbound(OutboundMessage::assistant( inbound.channel.clone(), inbound.chat_id.clone(), - response_msg, + msg.content.clone(), None, inbound.forwarded_metadata.clone(), )) @@ -122,13 +139,23 @@ impl InboundProcessor { { tracing::error!(error = %error, "Failed to publish command response"); } - return Ok(()); } - None => { - // 命令已处理但没有返回消息 - return Ok(()); + } else if let Some(error) = response.error { + if let Err(e) = self + .bus + .publish_outbound(OutboundMessage::assistant( + inbound.channel.clone(), + inbound.chat_id.clone(), + format!("Error [{}]: {}", error.code, error.message), + None, + inbound.forwarded_metadata.clone(), + )) + .await + { + tracing::error!(error = %e, "Failed to publish error response"); } } + return Ok(()); } // 普通消息进入 AgentLoop @@ -183,41 +210,3 @@ impl InboundProcessor { Ok(()) } } - -/// 解析聊天中的命令 -/// -/// 支持格式: -/// - `/save` - 保存活跃会话消息(到 cutoff) -/// - `/save all` - 保存全部会话消息(包括 cutoff 之前) -/// - `/save <filepath>` - 保存活跃消息到指定路径 -/// - `/save all <filepath>` - 保存全部消息到指定路径 -/// -/// 返回 Some(Command) 如果是命令 -/// 返回 None 如果不是命令 -fn parse_in_chat_command(content: &str) -> Option<Command> { - let trimmed = content.trim(); - - if trimmed.starts_with("/save") { - let args = trimmed[5..].trim(); - - // 解析参数 - let (include_all, filepath) = if args.is_empty() { - // /save 无参数 - 只保存活跃消息 - (false, None) - } else if args == "all" { - // /save all - 保存全部消息 - (true, None) - } else if args.starts_with("all ") { - // /save all <filepath> - 保存全部消息到指定路径 - let path = args[4..].trim(); - (true, Some(path.to_string())) - } else { - // /save <filepath> - 保存活跃消息到指定路径 - (false, Some(args.to_string())) - }; - - Some(Command::SaveSession { filepath, include_all }) - } else { - None - } -} diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index ee7249b..c9c0a8b 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,14 +1,14 @@ use super::GatewayState; use crate::agent::{AgentError, CompositeSystemPromptProvider}; use crate::bus::InboundMessage; -use crate::command::adapter::OutputAdapter; +use crate::command::adapter::{InputAdapter, OutputAdapter}; use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter}; use crate::command::context::CommandContext; use crate::command::handler::CommandRouter; use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; use crate::gateway::agent_prompt_provider::AgentPromptProvider; -use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; +use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use crate::skills::SkillPromptProvider; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; @@ -125,17 +125,6 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) { tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); } -fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary { - SessionSummary { - session_id: record.id, - title: record.title, - channel_name: record.channel_name, - chat_id: record.chat_id, - message_count: record.message_count, - last_active_at: record.last_active_at, - archived_at: record.archived_at, - } -} async fn handle_inbound( state: &Arc<GatewayState>, @@ -143,9 +132,9 @@ async fn handle_inbound( runtime_session_id: &str, current_session_id: &mut String, inbound: WsInbound, -) -> Result<(), crate::agent::AgentError> { +) -> Result<(), AgentError> { match inbound { - WsInbound::UserInput { + WsInbound::Message { content, chat_id, sender_id, @@ -181,53 +170,73 @@ async fn handle_inbound( Ok(()) } - WsInbound::ClearHistory { - session_id, - chat_id, - } => { - let target = session_id - .or(chat_id) - .unwrap_or_else(|| current_session_id.clone()); - state - .session_manager - .cli_sessions() - .clear_messages(&target)?; - - if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await { - session.lock().await.remove_history(&target); - } - - let _ = sender - .send(WsOutbound::HistoryCleared { session_id: target }) - .await; - Ok(()) - } - WsInbound::CreateSession { title } => { - // 使用新的命令层处理 - let _input_adapter = WebSocketInputAdapter::new(); + WsInbound::Command { payload } => { + // 使用 Command 系统处理命令 + let input_adapter = WebSocketInputAdapter::new(); let output_adapter = WebSocketOutputAdapter::new(); - let cli_sessions = state.session_manager.cli_sessions(); - let handler = SessionCommandHandler::new(cli_sessions); - let router = { - let mut r = CommandRouter::new(); - r.register(Box::new(handler)); - r + + // 解析命令 + let adapter_ctx = crate::command::context::AdapterContext::new("websocket") + .with_session_id(current_session_id.as_str()); + + let cmd = match input_adapter.try_parse(&payload, adapter_ctx) { + Ok(Some(cmd)) => cmd, + Ok(None) => { + // 不是命令,返回错误 + let _ = sender + .send(WsOutbound::Error { + code: "INVALID_COMMAND".to_string(), + message: "Invalid command payload".to_string(), + }) + .await; + return Ok(()); + } + Err(e) => { + let _ = sender + .send(WsOutbound::Error { + code: "PARSE_ERROR".to_string(), + message: e.to_string(), + }) + .await; + return Ok(()); + } }; - // 构建命令 - let cmd = crate::command::Command::CreateSession { title }; + // 创建命令路由器 + let cli_sessions = state.session_manager.cli_sessions(); + let store = state.session_manager.store(); + let skills = state.session_manager.skills(); + let provider_config = state.config.get_provider_config("default") + .map_err(|e| AgentError::Other(e.to_string()))?; + let prompt_repository = state.session_manager.store().clone(); + + let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![ + Box::new(AgentPromptProvider::new( + 0, + provider_config.clone(), + prompt_repository.clone(), + )), + Box::new(SkillPromptProvider::new(skills)), + ])); + + let mut router = CommandRouter::new(); + router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone()))); + router.register(Box::new(SaveSessionCommandHandler::new( + store, + system_prompt_provider, + ))); + + // 构建命令上下文 let cmd_ctx = CommandContext::new("websocket") .with_session_id(current_session_id.as_str()); // 执行命令 let response = router.dispatch_with_response(cmd, cmd_ctx).await; - // 适配输出 - let outbounds = output_adapter.adapt(response); - // 处理响应 - for msg in outbounds { - if let WsOutbound::SessionCreated { session_id, title: _ } = &msg { + if response.success { + // 更新当前会话 ID(如果是创建会话) + if let Some(session_id) = response.metadata.get("session_id") { *current_session_id = session_id.clone(); state .channel_manager @@ -239,178 +248,14 @@ async fn handle_inbound( ) .await; } + } + + // 适配并发送响应 + let outbounds = output_adapter.adapt(response); + for msg in outbounds { let _ = sender.send(msg).await; } - Ok(()) - } - WsInbound::ListSessions { include_archived } => { - let records = state - .session_manager - .cli_sessions() - .list(include_archived)?; - let summaries = records.into_iter().map(to_session_summary).collect(); - let _ = sender - .send(WsOutbound::SessionList { - sessions: summaries, - current_session_id: Some(current_session_id.clone()), - }) - .await; - Ok(()) - } - WsInbound::LoadSession { session_id } => { - let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else { - let _ = sender - .send(WsOutbound::Error { - code: "SESSION_NOT_FOUND".to_string(), - message: format!("Session not found: {}", session_id), - }) - .await; - return Ok(()); - }; - - *current_session_id = record.id.clone(); - state - .channel_manager - .cli_channel() - .register_connection( - record.id.clone(), - runtime_session_id.to_string(), - sender.clone(), - ) - .await; - let _ = sender - .send(WsOutbound::SessionLoaded { - session_id: record.id, - title: record.title, - message_count: record.message_count, - }) - .await; - Ok(()) - } - WsInbound::RenameSession { session_id, title } => { - let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state - .session_manager - .cli_sessions() - .rename(&target, &title)?; - let _ = sender - .send(WsOutbound::SessionRenamed { - session_id: target, - title, - }) - .await; - Ok(()) - } - WsInbound::ArchiveSession { session_id } => { - let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state.session_manager.cli_sessions().archive(&target)?; - let _ = sender - .send(WsOutbound::SessionArchived { session_id: target }) - .await; - Ok(()) - } - WsInbound::DeleteSession { session_id } => { - let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state.session_manager.cli_sessions().delete(&target)?; - - let replacement = if target == *current_session_id { - Some(state.session_manager.cli_sessions().create(None)?) - } else { - None - }; - - if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await { - session.lock().await.remove_history(&target); - } - - let _ = sender - .send(WsOutbound::SessionDeleted { - session_id: target.clone(), - }) - .await; - - if let Some(record) = replacement { - *current_session_id = record.id.clone(); - state - .channel_manager - .cli_channel() - .register_connection( - record.id.clone(), - runtime_session_id.to_string(), - sender.clone(), - ) - .await; - let _ = sender - .send(WsOutbound::SessionCreated { - session_id: record.id, - title: record.title, - }) - .await; - } - - Ok(()) - } - WsInbound::SaveSession { filepath, session_id } => { - let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone()); - - // 获取所需依赖 - let store = state.session_manager.store(); - let skills = state.session_manager.skills(); - let provider_config = state.config.get_provider_config("default") - .map_err(|e| AgentError::Other(e.to_string()))?; - let prompt_repository = state.session_manager.store().clone(); - - // 构建组合系统提示词提供者(与运行时一致) - let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![ - Box::new(AgentPromptProvider::new( - 0, // save_session 不需要 reinject 逻辑 - provider_config.clone(), - prompt_repository, - )), - Box::new(SkillPromptProvider::new(skills)), - ])); - - // 构建处理器 - let handler = SaveSessionCommandHandler::new(store, system_prompt_provider); - let router = { - let mut r = CommandRouter::new(); - r.register(Box::new(handler)); - r - }; - - // 构建命令 - let cmd = crate::command::Command::SaveSession { filepath, include_all: true }; - let cmd_ctx = CommandContext::new("websocket") - .with_session_id(&target_session_id); - - // 执行命令 - let response = router.dispatch_with_response(cmd, cmd_ctx).await; - - // 处理响应 - if response.success { - let filepath = response - .metadata - .get("filepath") - .cloned() - .unwrap_or_default(); - let _ = sender - .send(WsOutbound::SessionSaved { - session_id: target_session_id, - filepath, - }) - .await; - } else { - let error = response.error.unwrap_or_else(|| { - crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error") - }); - let _ = sender - .send(WsOutbound::Error { - code: error.code, - message: error.message, - }) - .await; - } Ok(()) } WsInbound::Ping => { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index baad9d2..547c052 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -17,8 +17,9 @@ pub struct SessionSummary { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum WsInbound { - #[serde(rename = "user_input")] - UserInput { + /// 普通用户消息 + #[serde(rename = "message")] + Message { content: String, #[serde(default, skip_serializing_if = "Option::is_none")] channel: Option<String>, @@ -27,48 +28,9 @@ pub enum WsInbound { #[serde(default, skip_serializing_if = "Option::is_none")] sender_id: Option<String>, }, - #[serde(rename = "clear_history")] - ClearHistory { - #[serde(default, skip_serializing_if = "Option::is_none")] - chat_id: Option<String>, - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - }, - #[serde(rename = "create_session")] - CreateSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - title: Option<String>, - }, - #[serde(rename = "list_sessions")] - ListSessions { - #[serde(default)] - include_archived: bool, - }, - #[serde(rename = "load_session")] - LoadSession { session_id: String }, - #[serde(rename = "rename_session")] - RenameSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - title: String, - }, - #[serde(rename = "archive_session")] - ArchiveSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - }, - #[serde(rename = "delete_session")] - DeleteSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - }, - #[serde(rename = "save_session")] - SaveSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - filepath: Option<String>, - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - }, + /// 命令(JSON 格式) + #[serde(rename = "command")] + Command { payload: String }, #[serde(rename = "ping")] Ping, } @@ -114,26 +76,6 @@ pub enum WsOutbound { SessionEstablished { session_id: String }, #[serde(rename = "session_created")] SessionCreated { session_id: String, title: String }, - #[serde(rename = "session_list")] - SessionList { - sessions: Vec<SessionSummary>, - #[serde(default, skip_serializing_if = "Option::is_none")] - current_session_id: Option<String>, - }, - #[serde(rename = "session_loaded")] - SessionLoaded { - session_id: String, - title: String, - message_count: i64, - }, - #[serde(rename = "session_renamed")] - SessionRenamed { session_id: String, title: String }, - #[serde(rename = "session_archived")] - SessionArchived { session_id: String }, - #[serde(rename = "session_deleted")] - SessionDeleted { session_id: String }, - #[serde(rename = "history_cleared")] - HistoryCleared { session_id: String }, #[serde(rename = "session_saved")] SessionSaved { session_id: String, filepath: String }, #[serde(rename = "pong")] diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index de665d5..fa0cdf1 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -1,4 +1,4 @@ -use picobot::protocol::{SessionSummary, WsInbound, WsOutbound}; +use picobot::protocol::{WsInbound, WsOutbound}; use picobot::providers::{ChatCompletionRequest, Message}; /// Test that message with special characters is properly escaped @@ -53,70 +53,65 @@ fn test_chat_request_serialization() { } #[test] -fn test_session_inbound_serialization() { - let msg = WsInbound::CreateSession { - title: Some("demo".to_string()), +fn test_command_inbound_serialization() { + // Command is now sent as payload in WsInbound::Command + let command_json = r#"{"type":"create_session","title":"demo"}"#; + let msg = WsInbound::Command { + payload: command_json.to_string(), }; let json = serde_json::to_string(&msg).unwrap(); - assert!(json.contains(r#""type":"create_session""#)); - assert!(json.contains(r#""title":"demo""#)); + assert!(json.contains(r#""type":"command""#)); + assert!(json.contains(r#""payload":""#)); + assert!(json.contains(r#"create_session"#)); +} + +#[test] +fn test_message_inbound_serialization() { + let msg = WsInbound::Message { + content: "Hello world".to_string(), + channel: None, + chat_id: Some("session-1".to_string()), + sender_id: Some("user-1".to_string()), + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains(r#""type":"message""#)); + assert!(json.contains(r#""content":"Hello world""#)); + assert!(json.contains(r#""chat_id":"session-1""#)); let decoded: WsInbound = serde_json::from_str(&json).unwrap(); match decoded { - WsInbound::CreateSession { title } => { - assert_eq!(title.as_deref(), Some("demo")); + WsInbound::Message { content, chat_id, .. } => { + assert_eq!(content, "Hello world"); + assert_eq!(chat_id.as_deref(), Some("session-1")); } other => panic!("unexpected decoded variant: {:?}", other), } } #[test] -fn test_session_list_outbound_serialization() { - let msg = WsOutbound::SessionList { - sessions: vec![SessionSummary { - session_id: "session-1".to_string(), - title: "demo".to_string(), - channel_name: "cli".to_string(), - chat_id: "session-1".to_string(), - message_count: 2, - last_active_at: 123, - archived_at: None, - }], - current_session_id: Some("session-1".to_string()), +fn test_session_created_outbound_serialization() { + let msg = WsOutbound::SessionCreated { + session_id: "session-1".to_string(), + title: "demo".to_string(), }; let json = serde_json::to_string(&msg).unwrap(); - assert!(json.contains(r#""type":"session_list""#)); + assert!(json.contains(r#""type":"session_created""#)); assert!(json.contains(r#""session_id":"session-1""#)); - assert!(json.contains(r#""message_count":2"#)); + assert!(json.contains(r#""title":"demo""#)); let decoded: WsOutbound = serde_json::from_str(&json).unwrap(); match decoded { - WsOutbound::SessionList { - sessions, - current_session_id, - } => { - assert_eq!(sessions.len(), 1); - assert_eq!(sessions[0].title, "demo"); - assert_eq!(current_session_id.as_deref(), Some("session-1")); + WsOutbound::SessionCreated { session_id, title } => { + assert_eq!(session_id, "session-1"); + assert_eq!(title, "demo"); } other => panic!("unexpected decoded variant: {:?}", other), } } -#[test] -fn test_clear_history_with_session_id_serialization() { - let msg = WsInbound::ClearHistory { - chat_id: None, - session_id: Some("session-1".to_string()), - }; - - let json = serde_json::to_string(&msg).unwrap(); - assert!(json.contains(r#""type":"clear_history""#)); - assert!(json.contains(r#""session_id":"session-1""#)); -} - #[test] fn test_tool_call_outbound_serialization() { let msg = WsOutbound::ToolCall {