diff --git a/src/cli/input.rs b/src/cli/input.rs index 7a0820d..3e4bb18 100644 --- a/src/cli/input.rs +++ b/src/cli/input.rs @@ -10,14 +10,10 @@ pub enum InputEvent { #[derive(Debug, Clone, PartialEq, Eq)] pub enum InputCommand { Exit, - Clear, New(Option), Save(Option), Sessions, Use(String), - Rename(String), - Archive, - Delete, } pub struct InputHandler { @@ -74,14 +70,10 @@ impl InputHandler { match command { "/quit" | "/exit" | "/q" => Some(InputCommand::Exit), - "/clear" => Some(InputCommand::Clear), "/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))), "/save" => Some(InputCommand::Save(arg.map(ToOwned::to_owned))), - "/sessions" => Some(InputCommand::Sessions), + "/sessions" | "/list" => Some(InputCommand::Sessions), "/use" => arg.map(|value| InputCommand::Use(value.to_string())), - "/rename" => arg.map(|value| InputCommand::Rename(value.to_string())), - "/archive" => Some(InputCommand::Archive), - "/delete" => Some(InputCommand::Delete), _ => None, } } @@ -120,10 +112,6 @@ mod tests { handler.handle_special_commands("/quit"), Some(InputCommand::Exit) ); - assert_eq!( - handler.handle_special_commands("/clear"), - Some(InputCommand::Clear) - ); assert_eq!( handler.handle_special_commands("/new"), Some(InputCommand::New(None)) @@ -140,6 +128,10 @@ mod tests { handler.handle_special_commands("/save ./debug/session.md"), Some(InputCommand::Save(Some("./debug/session.md".to_string()))) ); + assert_eq!( + handler.handle_special_commands("/list"), + Some(InputCommand::Sessions) + ); assert_eq!( handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions) @@ -148,18 +140,6 @@ mod tests { handler.handle_special_commands("/use abc123"), Some(InputCommand::Use("abc123".to_string())) ); - assert_eq!( - handler.handle_special_commands("/rename project alpha"), - Some(InputCommand::Rename("project alpha".to_string())) - ); - assert_eq!( - handler.handle_special_commands("/archive"), - Some(InputCommand::Archive) - ); - assert_eq!( - handler.handle_special_commands("/delete"), - Some(InputCommand::Delete) - ); assert_eq!(handler.handle_special_commands("/unknown"), None); assert_eq!(handler.handle_special_commands("/use"), None); } diff --git a/src/client/mod.rs b/src/client/mod.rs index a86b259..8ae48af 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,36 +8,6 @@ use tokio_tungstenite::{connect_async, tungstenite::Message}; use crate::cli::{InputCommand, InputEvent, InputHandler}; -fn format_session_list( - sessions: &[crate::protocol::SessionSummary], - current_session_id: Option<&str>, -) -> String { - if sessions.is_empty() { - return "No sessions found.".to_string(); - } - - let mut lines = Vec::with_capacity(sessions.len() + 1); - lines.push("Sessions:".to_string()); - for session in sessions { - let marker = if current_session_id == Some(session.session_id.as_str()) { - "*" - } else { - "-" - }; - let archived = if session.archived_at.is_some() { - " [archived]" - } else { - "" - }; - lines.push(format!( - "{} {} | {} | {} messages{}", - marker, session.session_id, session.title, session.message_count, archived, - )); - } - - lines.join("\n") -} - fn parse_message(raw: &str) -> Result { serde_json::from_str(raw) } @@ -54,7 +24,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { let mut input = InputHandler::new(); let mut current_session_id: Option = None; - input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /reset, /sessions, /use , /rename , /archive, /delete, /clear, /quit\n").await?; + input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /quit\n").await?; // Main loop: poll both stdin and WebSocket loop { @@ -91,29 +61,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { current_session_id = Some(session_id.clone()); input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?; } - WsOutbound::SessionList { sessions, current_session_id: listed_current } => { - let display = format_session_list(&sessions, listed_current.as_deref()); - input.write_output(&format!("{}\n", display)).await?; - } - WsOutbound::SessionLoaded { session_id, title, message_count } => { - current_session_id = Some(session_id.clone()); - input.write_output(&format!("Loaded session: {} ({}, {} messages)\n", session_id, title, message_count)).await?; - } - WsOutbound::SessionRenamed { session_id, title } => { - input.write_output(&format!("Renamed session: {} -> {}\n", session_id, title)).await?; - } - WsOutbound::SessionArchived { session_id } => { - input.write_output(&format!("Archived session: {}\n", session_id)).await?; - } - WsOutbound::SessionDeleted { session_id } => { - if current_session_id.as_deref() == Some(session_id.as_str()) { - current_session_id = None; - } - input.write_output(&format!("Deleted session: {}\n", session_id)).await?; - } - WsOutbound::HistoryCleared { session_id } => { - input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?; - } WsOutbound::SessionSaved { session_id, filepath } => { input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?; } @@ -138,39 +85,25 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { input.write_output("Goodbye!").await?; break; } - InputEvent::Command(InputCommand::Clear) => { - let inbound = WsInbound::ClearHistory { - chat_id: None, - session_id: current_session_id.clone(), - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } InputEvent::Command(InputCommand::New(title)) => { - // 使用新的命令层:通过 CliInputAdapter 构建 Command + // 使用 CliInputAdapter 构建 Command let adapter = CliInputAdapter::new(); let ctx = AdapterContext::new("cli") .with_session_id(current_session_id.as_deref().unwrap_or("")); // 构建输入字符串 - let input = match title { + let input_str = match title { Some(t) => format!("/new {}", t), None => "/new".to_string(), }; // 解析为 Command - match adapter.try_parse(&input, ctx) { + match adapter.try_parse(&input_str, ctx) { Ok(Some(command)) => { - // 序列化为 JSON 通过 WebSocket 发送 + // 序列化为 JSON let json = serde_json::to_string(&command).unwrap_or_default(); - let inbound = WsInbound::UserInput { - content: json, - channel: None, - chat_id: current_session_id.clone(), - sender_id: None, - }; + // 通过 Command 消息发送 + let inbound = WsInbound::Command { payload: json }; if let Ok(text) = serialize_inbound(&inbound) { let _ = sender.send(Message::Text(text.into())).await; } @@ -184,62 +117,97 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { } continue; } - InputEvent::Command(InputCommand::Sessions) => { - let inbound = WsInbound::ListSessions { - include_archived: true, + InputEvent::Command(InputCommand::Save(filepath)) => { + // 使用 CliInputAdapter 构建 Command + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("cli") + .with_session_id(current_session_id.as_deref().unwrap_or("")); + + // 构建输入字符串 + let input_str = match filepath { + Some(p) => format!("/save {}", p), + None => "/save".to_string(), }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; + + // 解析为 Command + match adapter.try_parse(&input_str, ctx) { + Ok(Some(command)) => { + // 序列化为 JSON + let json = serde_json::to_string(&command).unwrap_or_default(); + // 通过 Command 消息发送 + let inbound = WsInbound::Command { payload: json }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + } + Ok(None) => { + tracing::warn!("Failed to parse /save command"); + } + Err(e) => { + tracing::error!(error = %e, "Error parsing /save command"); + } + } + continue; + } + InputEvent::Command(InputCommand::Sessions) => { + // 使用 CliInputAdapter 构建 Command + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("cli") + .with_session_id(current_session_id.as_deref().unwrap_or("")); + + // 解析为 Command + match adapter.try_parse("/list", ctx) { + Ok(Some(command)) => { + // 序列化为 JSON + let json = serde_json::to_string(&command).unwrap_or_default(); + // 通过 Command 消息发送 + let inbound = WsInbound::Command { payload: json }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + } + Ok(None) => { + tracing::warn!("Failed to parse /list command"); + } + Err(e) => { + tracing::error!(error = %e, "Error parsing /list command"); + } } continue; } InputEvent::Command(InputCommand::Use(session_id)) => { - let inbound = WsInbound::LoadSession { session_id }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } - InputEvent::Command(InputCommand::Rename(title)) => { - let inbound = WsInbound::RenameSession { - session_id: current_session_id.clone(), - title, - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } - InputEvent::Command(InputCommand::Archive) => { - let inbound = WsInbound::ArchiveSession { - session_id: current_session_id.clone(), - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } - InputEvent::Command(InputCommand::Delete) => { - let inbound = WsInbound::DeleteSession { - session_id: current_session_id.clone(), - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; - } - continue; - } - InputEvent::Command(InputCommand::Save(filepath)) => { - let inbound = WsInbound::SaveSession { - filepath, - session_id: current_session_id.clone(), - }; - if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(Message::Text(text.into())).await; + // 使用 CliInputAdapter 构建 Command + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("cli") + .with_session_id(current_session_id.as_deref().unwrap_or("")); + + // 构建输入字符串 + let input_str = format!("/use {}", session_id); + + // 解析为 Command + match adapter.try_parse(&input_str, ctx) { + Ok(Some(command)) => { + // 序列化为 JSON + let json = serde_json::to_string(&command).unwrap_or_default(); + // 通过 Command 消息发送 + let inbound = WsInbound::Command { payload: json }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + // 更新当前会话 ID + current_session_id = Some(session_id.clone()); + } + Ok(None) => { + tracing::warn!("Failed to parse /use command"); + } + Err(e) => { + tracing::error!(error = %e, "Error parsing /use command"); + } } continue; } InputEvent::Message(msg) => { - let inbound = WsInbound::UserInput { + let inbound = WsInbound::Message { content: msg.content, channel: None, chat_id: current_session_id.clone(), diff --git a/src/command/adapters/channel.rs b/src/command/adapters/channel.rs new file mode 100644 index 0000000..cd9e55b --- /dev/null +++ b/src/command/adapters/channel.rs @@ -0,0 +1,92 @@ +use crate::command::adapter::{AdapterError, InputAdapter}; +use crate::command::context::AdapterContext; +use crate::command::Command; + +/// Channel 输入适配器 +/// +/// 将 Channel 消息中的文本命令(如 "/new", "/save")转换为 Command +pub struct ChannelInputAdapter; + +impl ChannelInputAdapter { + /// 创建新的 Channel 输入适配器 + pub fn new() -> Self { + Self + } +} + +impl Default for ChannelInputAdapter { + fn default() -> Self { + Self::new() + } +} + +impl InputAdapter for ChannelInputAdapter { + fn try_parse( + &self, + input: &str, + _ctx: AdapterContext, + ) -> Result<Option<Command>, AdapterError> { + let trimmed = input.trim(); + + // 解析 /new 命令 + if trimmed == "/new" { + return Ok(Some(Command::CreateSession { title: None })); + } + + if let Some(title) = trimmed.strip_prefix("/new ") { + let title = title.trim(); + return Ok(Some(Command::CreateSession { + title: Some(title.to_string()), + })); + } + + // 解析 /save 命令 + if trimmed == "/save" { + return Ok(Some(Command::SaveSession { + filepath: None, + include_all: false, + })); + } + + if let Some(args) = trimmed.strip_prefix("/save ") { + let args = args.trim(); + // 解析参数:可能是 "all"、路径、或 "all 路径" + let (include_all, filepath) = if args == "all" { + // /save all - 保存全部消息 + (true, None) + } else if args.starts_with("all ") { + // /save all <filepath> - 保存全部消息到指定路径 + let path = args[4..].trim(); + (true, Some(path.to_string())) + } else { + // /save <filepath> - 保存活跃消息到指定路径 + (false, Some(args.to_string())) + }; + return Ok(Some(Command::SaveSession { filepath, include_all })); + } + + // 解析 /list 命令 + if trimmed == "/list" { + return Ok(Some(Command::ListSessions { + include_archived: false, + })); + } + + if trimmed == "/list all" { + return Ok(Some(Command::ListSessions { + include_archived: true, + })); + } + + // 解析 /use 命令 + if let Some(session_id) = trimmed.strip_prefix("/use ") { + let session_id = session_id.trim(); + return Ok(Some(Command::LoadSession { + session_id: session_id.to_string(), + })); + } + + // 不是命令,返回 None + Ok(None) + } +} diff --git a/src/command/adapters/cli.rs b/src/command/adapters/cli.rs index 1f2c723..809286b 100644 --- a/src/command/adapters/cli.rs +++ b/src/command/adapters/cli.rs @@ -66,6 +66,27 @@ impl InputAdapter for CliInputAdapter { return Ok(Some(Command::SaveSession { filepath, include_all })); } + // 解析 /list 命令 + if trimmed == "/list" { + return Ok(Some(Command::ListSessions { + include_archived: false, + })); + } + + if trimmed == "/list all" { + return Ok(Some(Command::ListSessions { + include_archived: true, + })); + } + + // 解析 /use 命令 + if let Some(session_id) = trimmed.strip_prefix("/use ") { + let session_id = session_id.trim(); + return Ok(Some(Command::LoadSession { + session_id: session_id.to_string(), + })); + } + // 不是命令,返回 None Ok(None) } diff --git a/src/command/adapters/mod.rs b/src/command/adapters/mod.rs index 3ea0590..d487429 100644 --- a/src/command/adapters/mod.rs +++ b/src/command/adapters/mod.rs @@ -1,2 +1,3 @@ +pub mod channel; pub mod cli; pub mod websocket; diff --git a/src/command/context.rs b/src/command/context.rs index dbcc551..9602bfa 100644 --- a/src/command/context.rs +++ b/src/command/context.rs @@ -12,18 +12,21 @@ pub struct CommandContext { pub chat_id: Option<String>, /// 发送者ID pub sender_id: String, + /// 通道名称(如 "cli", "feishu", "wechat") + pub channel_name: String, /// 额外元数据 pub metadata: HashMap<String, String>, } impl CommandContext { /// 创建新的命令上下文 - pub fn new(sender_id: impl Into<String>) -> Self { + pub fn new(sender_id: impl Into<String>, channel_name: impl Into<String>) -> Self { Self { request_id: Uuid::new_v4(), session_id: None, chat_id: None, sender_id: sender_id.into(), + channel_name: channel_name.into(), metadata: HashMap::new(), } } diff --git a/src/command/handler.rs b/src/command/handler.rs index f50cae2..2962f3a 100644 --- a/src/command/handler.rs +++ b/src/command/handler.rs @@ -238,7 +238,7 @@ mod tests { router.register(Box::new(TestHandler)); router.register(Box::new(NoOpHandler)); - let ctx = CommandContext::new("test"); + let ctx = CommandContext::new("test", "test"); let cmd = Command::CreateSession { title: None }; let result = router.dispatch(cmd, ctx).await; @@ -252,7 +252,7 @@ mod tests { async fn test_router_no_handler() { let router = CommandRouter::new(); - let ctx = CommandContext::new("test"); + let ctx = CommandContext::new("test", "test"); let cmd = Command::CreateSession { title: None }; let result = router.dispatch(cmd, ctx).await; diff --git a/src/command/handlers/mod.rs b/src/command/handlers/mod.rs index 0df37c6..3f4f14f 100644 --- a/src/command/handlers/mod.rs +++ b/src/command/handlers/mod.rs @@ -1,2 +1,3 @@ pub mod save_session; pub mod session; +pub mod session_query; diff --git a/src/command/handlers/session.rs b/src/command/handlers/session.rs index 0e44ce0..a4840e9 100644 --- a/src/command/handlers/session.rs +++ b/src/command/handlers/session.rs @@ -36,6 +36,7 @@ impl CommandHandler for SessionCommandHandler { match cmd { Command::CreateSession { title } => handle_create_session(self, title, ctx).await, Command::SaveSession { .. } => unreachable!("SaveSession should be handled by SaveSessionCommandHandler"), + _ => unreachable!("Other commands should be handled by other handlers"), } } } @@ -48,7 +49,7 @@ async fn handle_create_session( ) -> Result<CommandResponse, CommandError> { let record = handler .cli_sessions - .create(title.as_deref()) + .create_with_channel(&ctx.channel_name, title.as_deref()) .map_err(|e| CommandError::new("CREATE_SESSION_ERROR", e.to_string()))?; Ok(CommandResponse::success(ctx.request_id) @@ -74,7 +75,7 @@ mod tests { async fn test_create_session_with_title() { let service = create_test_service(); let handler = SessionCommandHandler::new(service); - let ctx = CommandContext::new("test"); + let ctx = CommandContext::new("test", "test"); let cmd = Command::CreateSession { title: Some("my session".to_string()), }; @@ -93,7 +94,7 @@ mod tests { async fn test_create_session_without_title() { let service = create_test_service(); let handler = SessionCommandHandler::new(service); - let ctx = CommandContext::new("test"); + let ctx = CommandContext::new("test", "test"); let cmd = Command::CreateSession { title: None }; let result = handler.handle(cmd, ctx).await; diff --git a/src/command/handlers/session_query.rs b/src/command/handlers/session_query.rs new file mode 100644 index 0000000..d63c88d --- /dev/null +++ b/src/command/handlers/session_query.rs @@ -0,0 +1,202 @@ +use crate::command::context::CommandContext; +use crate::command::handler::CommandHandler; +use crate::command::response::{CommandError, CommandResponse, MessageKind}; +use crate::command::Command; +use crate::gateway::cli_session::CliSessionService; +use crate::protocol::SessionSummary; +use async_trait::async_trait; + +/// 会话查询命令处理器 +/// +/// 处理 ListSessions 和 LoadSession 命令 +pub struct SessionQueryCommandHandler { + cli_sessions: CliSessionService, +} + +impl SessionQueryCommandHandler { + /// 创建新的会话查询命令处理器 + pub fn new(cli_sessions: CliSessionService) -> Self { + Self { cli_sessions } + } +} + +#[async_trait] +impl CommandHandler for SessionQueryCommandHandler { + fn can_handle(&self, cmd: &Command) -> bool { + matches!(cmd, Command::ListSessions { .. } | Command::LoadSession { .. }) + } + + async fn handle( + &self, + cmd: Command, + ctx: CommandContext, + ) -> Result<CommandResponse, CommandError> { + match cmd { + Command::ListSessions { include_archived } => { + handle_list_sessions(self, include_archived, ctx).await + } + Command::LoadSession { session_id } => { + handle_load_session(self, session_id, ctx).await + } + _ => unreachable!(), + } + } +} + +/// 处理列出会话命令 +async fn handle_list_sessions( + handler: &SessionQueryCommandHandler, + include_archived: bool, + ctx: CommandContext, +) -> Result<CommandResponse, CommandError> { + let records = handler + .cli_sessions + .list(include_archived) + .map_err(|e| CommandError::new("LIST_SESSIONS_ERROR", e.to_string()))?; + + let summaries: Vec<SessionSummary> = records + .into_iter() + .map(|r| SessionSummary { + session_id: r.id, + title: r.title, + channel_name: r.channel_name, + chat_id: r.chat_id, + message_count: r.message_count, + last_active_at: r.last_active_at, + archived_at: r.archived_at, + }) + .collect(); + + // 将会话列表序列化为 JSON 存储在 metadata 中 + let sessions_json = + serde_json::to_string(&summaries).map_err(|e| CommandError::new("SERIALIZE_ERROR", e.to_string()))?; + + // 构建可读的会话列表消息 + let message = if summaries.is_empty() { + "No sessions found.".to_string() + } else { + let mut lines = vec![format!("Found {} session(s):", summaries.len())]; + for summary in &summaries { + let archived_info = summary + .archived_at + .map(|_| " [archived]") + .unwrap_or(""); + lines.push(format!( + " - {}: {}{}", + summary.session_id, summary.title, archived_info + )); + } + lines.push("".to_string()); + lines.push("Use /use <session_id> to switch to a session".to_string()); + lines.join("\n") + }; + + Ok(CommandResponse::success(ctx.request_id) + .with_message(MessageKind::Notification, &message) + .with_metadata("sessions", &sessions_json) + .with_metadata("count", &summaries.len().to_string())) +} + +/// 处理加载会话命令 +async fn handle_load_session( + handler: &SessionQueryCommandHandler, + session_id: String, + ctx: CommandContext, +) -> Result<CommandResponse, CommandError> { + let record = handler + .cli_sessions + .get(&session_id) + .map_err(|e| CommandError::new("LOAD_SESSION_ERROR", e.to_string()))? + .ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", format!("Session not found: {}", session_id)))?; + + Ok(CommandResponse::success(ctx.request_id) + .with_message(MessageKind::Notification, &record.title) + .with_metadata("session_id", &record.id) + .with_metadata("title", &record.title) + .with_metadata("message_count", &record.message_count.to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::SessionStore; + use std::sync::Arc; + + fn create_test_service() -> CliSessionService { + let store = Arc::new(SessionStore::in_memory().unwrap()); + CliSessionService::new(store) + } + + #[tokio::test] + async fn test_list_sessions_empty() { + let service = create_test_service(); + let handler = SessionQueryCommandHandler::new(service); + let ctx = CommandContext::new("test", "test"); + let cmd = Command::ListSessions { + include_archived: false, + }; + + let result = handler.handle(cmd, ctx).await; + + assert!(result.is_ok()); + let resp = result.unwrap(); + assert!(resp.success); + assert!(resp.messages[0].content.contains("No sessions")); + } + + #[tokio::test] + async fn test_list_sessions_with_items() { + let service = create_test_service(); + let handler = SessionQueryCommandHandler::new(service.clone()); + + // 创建一些会话 + service.create(Some("test session")).unwrap(); + + let ctx = CommandContext::new("test", "test"); + let cmd = Command::ListSessions { + include_archived: false, + }; + + let result = handler.handle(cmd, ctx).await; + + assert!(result.is_ok()); + let resp = result.unwrap(); + assert!(resp.success); + assert!(resp.metadata.contains_key("sessions")); + } + + #[tokio::test] + async fn test_load_session_not_found() { + let service = create_test_service(); + let handler = SessionQueryCommandHandler::new(service); + let ctx = CommandContext::new("test", "test"); + let cmd = Command::LoadSession { + session_id: "nonexistent".to_string(), + }; + + let result = handler.handle(cmd, ctx).await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_load_session_success() { + let service = create_test_service(); + let handler = SessionQueryCommandHandler::new(service.clone()); + + // 创建会话 + let record = service.create(Some("test session")).unwrap(); + + let ctx = CommandContext::new("test", "test"); + let cmd = Command::LoadSession { + session_id: record.id.clone(), + }; + + let result = handler.handle(cmd, ctx).await; + + assert!(result.is_ok()); + let resp = result.unwrap(); + assert!(resp.success); + assert_eq!(resp.metadata.get("session_id").unwrap(), &record.id); + } +} diff --git a/src/command/mod.rs b/src/command/mod.rs index 81e344f..935abc5 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -18,6 +18,10 @@ pub enum Command { filepath: Option<String>, include_all: bool, }, + /// 列出会话 + ListSessions { include_archived: bool }, + /// 加载指定会话 + LoadSession { session_id: String }, } impl Command { @@ -26,6 +30,8 @@ impl Command { match self { Command::CreateSession { .. } => "create_session", Command::SaveSession { .. } => "save_session", + Command::ListSessions { .. } => "list_sessions", + Command::LoadSession { .. } => "load_session", } } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 2a7fb86..0bff93c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1472,7 +1472,7 @@ mod tests { assert!(config.scheduler.jobs.is_empty()); let effective_jobs = config.scheduler.effective_jobs(&config.time); - assert_eq!(effective_jobs.len(), 1); + assert_eq!(effective_jobs.len(), 2); assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID); assert_eq!(effective_jobs[0].kind, SchedulerJobKind::InternalEvent); assert_eq!( @@ -1481,6 +1481,8 @@ mod tests { expression: "0 */4 * * *".to_string(), } ); + // 第二个内置作业是会话清理 + assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID); } #[test] @@ -1516,7 +1518,8 @@ mod tests { let effective_jobs = scheduler.effective_jobs(&TimeConfig { timezone: "Asia/Shanghai".to_string(), }); - assert_eq!(effective_jobs.len(), 2); + assert_eq!(effective_jobs.len(), 3); // 2个内置 + 1个自定义 + // 第一个作业:内存维护(被覆盖为禁用) assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID); assert!(!effective_jobs[0].enabled); assert_eq!( @@ -1525,7 +1528,11 @@ mod tests { expression: "15 2 * * *".to_string(), } ); - assert_eq!(effective_jobs[1].id, "custom.reminder"); + // 第二个作业:会话清理(保持默认) + assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID); + assert!(effective_jobs[1].enabled); + // 第三个作业:自定义提醒 + assert_eq!(effective_jobs[2].id, "custom.reminder"); } #[test] diff --git a/src/gateway/cli_session.rs b/src/gateway/cli_session.rs index c07d8e1..35fb4ac 100644 --- a/src/gateway/cli_session.rs +++ b/src/gateway/cli_session.rs @@ -19,6 +19,17 @@ impl CliSessionService { .map_err(|err| AgentError::Other(format!("create session error: {}", err))) } + /// 创建指定通道的会话 + pub(crate) fn create_with_channel( + &self, + channel_name: &str, + title: Option<&str>, + ) -> Result<SessionRecord, AgentError> { + self.store + .create_session(channel_name, title) + .map_err(|err| AgentError::Other(format!("create session error: {}", err))) + } + pub(crate) fn get(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> { self.store .get_session(session_id) diff --git a/src/gateway/command.rs b/src/gateway/command.rs index 03ad8f0..c0278b3 100644 --- a/src/gateway/command.rs +++ b/src/gateway/command.rs @@ -1,165 +1,2 @@ -use crate::agent::AgentError; - -use super::session::Session; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum InChatCommand { - FreshConversation, -} - -fn parse_in_chat_command(content: &str) -> Option<InChatCommand> { - match content.trim() { - "/new" | "/reset" => Some(InChatCommand::FreshConversation), - _ => None, - } -} - -pub(crate) fn handle_in_chat_command( - session: &mut Session, - chat_id: &str, - content: &str, -) -> Result<Option<String>, AgentError> { - match parse_in_chat_command(content) { - Some(InChatCommand::FreshConversation) => { - session.reset_chat_context(chat_id)?; - Ok(Some("Started a fresh conversation.".to_string())) - } - None => Ok(None), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::bus::ChatMessage; - use crate::config::LLMProviderConfig; - use crate::skills::SkillRuntime; - use crate::storage::SessionStore; - use crate::tools::ToolRegistry; - use std::collections::HashMap; - use std::sync::Arc; - use tokio::sync::mpsc; - - const TEST_CHANNEL: &str = "test-channel"; - - fn test_provider_config() -> LLMProviderConfig { - LLMProviderConfig { - provider_type: "openai".to_string(), - name: "test".to_string(), - base_url: "http://localhost".to_string(), - api_key: "test-key".to_string(), - extra_headers: HashMap::new(), - llm_timeout_secs: 120, - memory_maintenance_timeout_secs: 600, - model_id: "test-model".to_string(), - temperature: Some(0.0), - max_tokens: Some(32), - context_window_tokens: None, - model_extra: HashMap::new(), - max_tool_iterations: 1, - tool_result_max_chars: 20_000, - context_tool_result_trim_chars: 20_000, - } - } - - #[test] - fn test_parse_in_chat_command_aliases() { - assert_eq!( - parse_in_chat_command("/new"), - Some(InChatCommand::FreshConversation) - ); - assert_eq!( - parse_in_chat_command(" /reset \n"), - Some(InChatCommand::FreshConversation) - ); - assert_eq!(parse_in_chat_command("/new planning"), None); - assert_eq!(parse_in_chat_command("please /reset"), None); - } - - #[tokio::test] - async fn test_handle_in_chat_command_resets_active_history_only() { - let store = Arc::new(SessionStore::in_memory().unwrap()); - let (user_tx, _user_rx) = mpsc::channel(4); - let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(ToolRegistry::new()); - let mut session = Session::new( - TEST_CHANNEL.to_string(), - test_provider_config(), - user_tx, - tools, - skills, - store.clone(), - 100, - Some(4), - ) - .await - .unwrap(); - - session.ensure_persistent_session("chat-1").unwrap(); - session.ensure_chat_loaded("chat-1").unwrap(); - session - .append_persisted_message("chat-1", ChatMessage::user("hello")) - .unwrap(); - - let response = handle_in_chat_command(&mut session, "chat-1", "/reset") - .unwrap() - .unwrap(); - - assert_eq!(response, "Started a fresh conversation."); - assert!(session.get_history("chat-1").unwrap().is_empty()); - assert!( - store - .load_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .is_empty() - ); - assert_eq!( - store - .load_all_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .len(), - // 新设计:系统提示词不再持久化,只有 1 条用户消息 - 1, - ); - - session.ensure_chat_loaded("chat-1").unwrap(); - let history = session.get_history("chat-1").unwrap(); - // 新设计:系统提示词不再持久化到历史记录 - assert_eq!(history.len(), 0); - } - - #[tokio::test] - async fn test_reset_reinjects_agent_prompt_before_next_user_message() { - let store = Arc::new(SessionStore::in_memory().unwrap()); - let (user_tx, _user_rx) = mpsc::channel(4); - let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(ToolRegistry::new()); - let mut session = Session::new( - TEST_CHANNEL.to_string(), - test_provider_config(), - user_tx, - tools, - skills, - store, - 100, - Some(4), - ) - .await - .unwrap(); - - session.ensure_persistent_session("chat-1").unwrap(); - session.ensure_chat_loaded("chat-1").unwrap(); - session - .append_persisted_message("chat-1", ChatMessage::user("hello")) - .unwrap(); - - handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap(); - session - .ensure_agent_prompt_before_user_message("chat-1") - .unwrap(); - - // 新设计:系统提示词不再持久化到历史记录 - let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 0); - } -} +// 此文件已废弃,InChatCommand 功能已合并到 Command 系统 +// 保留文件以避免破坏现有 import,但内容为空 diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index d1624ae..265527a 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -7,7 +7,6 @@ use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDUL use crate::config::LLMProviderConfig; use tokio::sync::Mutex; -use super::command::handle_in_chat_command; use super::compaction::schedule_background_history_compaction; use super::message_prepare::enrich_user_content_with_media_refs; use super::session::Session; @@ -138,18 +137,6 @@ impl AgentExecutionService { session_guard.ensure_persistent_session(request.chat_id)?; session_guard.ensure_chat_loaded(request.chat_id)?; - if let Some(command_response) = - handle_in_chat_command(&mut session_guard, request.chat_id, request.content)? - { - return Ok(vec![OutboundMessage::assistant( - request.channel_name.to_string(), - request.chat_id.to_string(), - command_response, - None, - HashMap::new(), - )]); - } - session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?; let media_refs: Vec<String> = request diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 5c21705..b945977 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -4,7 +4,12 @@ use tokio::sync::Semaphore; use crate::agent::{AgentError, CompositeSystemPromptProvider}; use crate::bus::{InboundMessage, MessageBus, OutboundMessage}; -use crate::command::handler::InChatCommandRouter; +use crate::command::adapter::InputAdapter; +use crate::command::adapters::channel::ChannelInputAdapter; +use crate::command::handler::CommandRouter; +use crate::command::handlers::save_session::SaveSessionCommandHandler; +use crate::command::handlers::session::SessionCommandHandler; +use crate::command::handlers::session_query::SessionQueryCommandHandler; use crate::command::Command; use crate::config::LLMProviderConfig; use crate::gateway::agent_prompt_provider::AgentPromptProvider; @@ -18,7 +23,7 @@ pub struct InboundProcessor { session_manager: SessionManager, semaphore: Arc<Semaphore>, _provider_config: LLMProviderConfig, - command_router: Arc<InChatCommandRouter>, + command_router: Arc<CommandRouter>, } impl InboundProcessor { @@ -29,7 +34,14 @@ impl InboundProcessor { provider_config: LLMProviderConfig, ) -> Self { // 创建命令路由器并注册处理器 - let mut command_router = InChatCommandRouter::new(); + let mut command_router = CommandRouter::new(); + + // 注册 Session 处理器 + let cli_sessions = session_manager.cli_sessions(); + command_router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone()))); + + // 注册 session_query 处理器 + command_router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions))); // 注册 save_session 处理器 let store = session_manager.store(); @@ -43,7 +55,7 @@ impl InboundProcessor { )), Box::new(SkillPromptProvider::new(skills)), ])); - command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new( + command_router.register(Box::new(SaveSessionCommandHandler::new( store, system_prompt_provider, ))); @@ -103,18 +115,43 @@ impl InboundProcessor { } async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> { - // 尝试解析为命令 - if let Some(cmd) = parse_in_chat_command(&inbound.content) { + // 使用 ChannelInputAdapter 尝试解析命令 + let adapter = ChannelInputAdapter::new(); + let ctx = crate::command::context::AdapterContext::new(&inbound.channel) + .with_session_id(&inbound.chat_id); + + if let Ok(Some(cmd)) = adapter.try_parse(&inbound.content, ctx) { // 使用命令路由器处理 - match self.command_router.dispatch(cmd, &inbound, &self.session_manager).await? { - Some(response_msg) => { - // 发送命令执行结果给用户 + let cmd_ctx = crate::command::context::CommandContext::new(&inbound.channel, &inbound.channel) + .with_session_id(&inbound.chat_id); + + // 记录是否是创建会话命令(用于后续自动切换) + let is_create_session = matches!(cmd, Command::CreateSession { .. }); + + let response = self.command_router.dispatch_with_response(cmd, cmd_ctx).await; + + // 发送响应给用户 + if response.success { + // 如果是创建会话,更新 chat_id 到新会话 + let target_chat_id = if let Some(session_id) = response.metadata.get("session_id") { + if is_create_session { + // 自动切换到新会话 + session_id.clone() + } else { + inbound.chat_id.clone() + } + } else { + inbound.chat_id.clone() + }; + + // 提取响应消息 + for msg in &response.messages { if let Err(error) = self .bus .publish_outbound(OutboundMessage::assistant( inbound.channel.clone(), - inbound.chat_id.clone(), - response_msg, + target_chat_id.clone(), + msg.content.clone(), None, inbound.forwarded_metadata.clone(), )) @@ -122,13 +159,23 @@ impl InboundProcessor { { tracing::error!(error = %error, "Failed to publish command response"); } - return Ok(()); } - None => { - // 命令已处理但没有返回消息 - return Ok(()); + } else if let Some(error) = response.error { + if let Err(e) = self + .bus + .publish_outbound(OutboundMessage::assistant( + inbound.channel.clone(), + inbound.chat_id.clone(), + format!("Error [{}]: {}", error.code, error.message), + None, + inbound.forwarded_metadata.clone(), + )) + .await + { + tracing::error!(error = %e, "Failed to publish error response"); } } + return Ok(()); } // 普通消息进入 AgentLoop @@ -183,41 +230,3 @@ impl InboundProcessor { Ok(()) } } - -/// 解析聊天中的命令 -/// -/// 支持格式: -/// - `/save` - 保存活跃会话消息(到 cutoff) -/// - `/save all` - 保存全部会话消息(包括 cutoff 之前) -/// - `/save <filepath>` - 保存活跃消息到指定路径 -/// - `/save all <filepath>` - 保存全部消息到指定路径 -/// -/// 返回 Some(Command) 如果是命令 -/// 返回 None 如果不是命令 -fn parse_in_chat_command(content: &str) -> Option<Command> { - let trimmed = content.trim(); - - if trimmed.starts_with("/save") { - let args = trimmed[5..].trim(); - - // 解析参数 - let (include_all, filepath) = if args.is_empty() { - // /save 无参数 - 只保存活跃消息 - (false, None) - } else if args == "all" { - // /save all - 保存全部消息 - (true, None) - } else if args.starts_with("all ") { - // /save all <filepath> - 保存全部消息到指定路径 - let path = args[4..].trim(); - (true, Some(path.to_string())) - } else { - // /save <filepath> - 保存活跃消息到指定路径 - (false, Some(args.to_string())) - }; - - Some(Command::SaveSession { filepath, include_all }) - } else { - None - } -} diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index ee7249b..a193a7e 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,14 +1,15 @@ use super::GatewayState; use crate::agent::{AgentError, CompositeSystemPromptProvider}; use crate::bus::InboundMessage; -use crate::command::adapter::OutputAdapter; +use crate::command::adapter::{InputAdapter, OutputAdapter}; use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter}; use crate::command::context::CommandContext; use crate::command::handler::CommandRouter; use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; +use crate::command::handlers::session_query::SessionQueryCommandHandler; use crate::gateway::agent_prompt_provider::AgentPromptProvider; -use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; +use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use crate::skills::SkillPromptProvider; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; @@ -125,17 +126,6 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) { tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); } -fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary { - SessionSummary { - session_id: record.id, - title: record.title, - channel_name: record.channel_name, - chat_id: record.chat_id, - message_count: record.message_count, - last_active_at: record.last_active_at, - archived_at: record.archived_at, - } -} async fn handle_inbound( state: &Arc<GatewayState>, @@ -143,9 +133,9 @@ async fn handle_inbound( runtime_session_id: &str, current_session_id: &mut String, inbound: WsInbound, -) -> Result<(), crate::agent::AgentError> { +) -> Result<(), AgentError> { match inbound { - WsInbound::UserInput { + WsInbound::Message { content, chat_id, sender_id, @@ -181,53 +171,74 @@ async fn handle_inbound( Ok(()) } - WsInbound::ClearHistory { - session_id, - chat_id, - } => { - let target = session_id - .or(chat_id) - .unwrap_or_else(|| current_session_id.clone()); - state - .session_manager - .cli_sessions() - .clear_messages(&target)?; - - if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await { - session.lock().await.remove_history(&target); - } - - let _ = sender - .send(WsOutbound::HistoryCleared { session_id: target }) - .await; - Ok(()) - } - WsInbound::CreateSession { title } => { - // 使用新的命令层处理 - let _input_adapter = WebSocketInputAdapter::new(); + WsInbound::Command { payload } => { + // 使用 Command 系统处理命令 + let input_adapter = WebSocketInputAdapter::new(); let output_adapter = WebSocketOutputAdapter::new(); - let cli_sessions = state.session_manager.cli_sessions(); - let handler = SessionCommandHandler::new(cli_sessions); - let router = { - let mut r = CommandRouter::new(); - r.register(Box::new(handler)); - r + + // 解析命令 + let adapter_ctx = crate::command::context::AdapterContext::new("websocket") + .with_session_id(current_session_id.as_str()); + + let cmd = match input_adapter.try_parse(&payload, adapter_ctx) { + Ok(Some(cmd)) => cmd, + Ok(None) => { + // 不是命令,返回错误 + let _ = sender + .send(WsOutbound::Error { + code: "INVALID_COMMAND".to_string(), + message: "Invalid command payload".to_string(), + }) + .await; + return Ok(()); + } + Err(e) => { + let _ = sender + .send(WsOutbound::Error { + code: "PARSE_ERROR".to_string(), + message: e.to_string(), + }) + .await; + return Ok(()); + } }; - // 构建命令 - let cmd = crate::command::Command::CreateSession { title }; - let cmd_ctx = CommandContext::new("websocket") + // 创建命令路由器 + let cli_sessions = state.session_manager.cli_sessions(); + let store = state.session_manager.store(); + let skills = state.session_manager.skills(); + let provider_config = state.config.get_provider_config("default") + .map_err(|e| AgentError::Other(e.to_string()))?; + let prompt_repository = state.session_manager.store().clone(); + + let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![ + Box::new(AgentPromptProvider::new( + 0, + provider_config.clone(), + prompt_repository.clone(), + )), + Box::new(SkillPromptProvider::new(skills)), + ])); + + let mut router = CommandRouter::new(); + router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone()))); + router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions))); + router.register(Box::new(SaveSessionCommandHandler::new( + store, + system_prompt_provider, + ))); + + // 构建命令上下文 + let cmd_ctx = CommandContext::new("websocket", "cli") .with_session_id(current_session_id.as_str()); // 执行命令 let response = router.dispatch_with_response(cmd, cmd_ctx).await; - // 适配输出 - let outbounds = output_adapter.adapt(response); - // 处理响应 - for msg in outbounds { - if let WsOutbound::SessionCreated { session_id, title: _ } = &msg { + if response.success { + // 更新当前会话 ID(如果是创建会话) + if let Some(session_id) = response.metadata.get("session_id") { *current_session_id = session_id.clone(); state .channel_manager @@ -239,178 +250,14 @@ async fn handle_inbound( ) .await; } + } + + // 适配并发送响应 + let outbounds = output_adapter.adapt(response); + for msg in outbounds { let _ = sender.send(msg).await; } - Ok(()) - } - WsInbound::ListSessions { include_archived } => { - let records = state - .session_manager - .cli_sessions() - .list(include_archived)?; - let summaries = records.into_iter().map(to_session_summary).collect(); - let _ = sender - .send(WsOutbound::SessionList { - sessions: summaries, - current_session_id: Some(current_session_id.clone()), - }) - .await; - Ok(()) - } - WsInbound::LoadSession { session_id } => { - let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else { - let _ = sender - .send(WsOutbound::Error { - code: "SESSION_NOT_FOUND".to_string(), - message: format!("Session not found: {}", session_id), - }) - .await; - return Ok(()); - }; - - *current_session_id = record.id.clone(); - state - .channel_manager - .cli_channel() - .register_connection( - record.id.clone(), - runtime_session_id.to_string(), - sender.clone(), - ) - .await; - let _ = sender - .send(WsOutbound::SessionLoaded { - session_id: record.id, - title: record.title, - message_count: record.message_count, - }) - .await; - Ok(()) - } - WsInbound::RenameSession { session_id, title } => { - let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state - .session_manager - .cli_sessions() - .rename(&target, &title)?; - let _ = sender - .send(WsOutbound::SessionRenamed { - session_id: target, - title, - }) - .await; - Ok(()) - } - WsInbound::ArchiveSession { session_id } => { - let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state.session_manager.cli_sessions().archive(&target)?; - let _ = sender - .send(WsOutbound::SessionArchived { session_id: target }) - .await; - Ok(()) - } - WsInbound::DeleteSession { session_id } => { - let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state.session_manager.cli_sessions().delete(&target)?; - - let replacement = if target == *current_session_id { - Some(state.session_manager.cli_sessions().create(None)?) - } else { - None - }; - - if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await { - session.lock().await.remove_history(&target); - } - - let _ = sender - .send(WsOutbound::SessionDeleted { - session_id: target.clone(), - }) - .await; - - if let Some(record) = replacement { - *current_session_id = record.id.clone(); - state - .channel_manager - .cli_channel() - .register_connection( - record.id.clone(), - runtime_session_id.to_string(), - sender.clone(), - ) - .await; - let _ = sender - .send(WsOutbound::SessionCreated { - session_id: record.id, - title: record.title, - }) - .await; - } - - Ok(()) - } - WsInbound::SaveSession { filepath, session_id } => { - let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone()); - - // 获取所需依赖 - let store = state.session_manager.store(); - let skills = state.session_manager.skills(); - let provider_config = state.config.get_provider_config("default") - .map_err(|e| AgentError::Other(e.to_string()))?; - let prompt_repository = state.session_manager.store().clone(); - - // 构建组合系统提示词提供者(与运行时一致) - let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![ - Box::new(AgentPromptProvider::new( - 0, // save_session 不需要 reinject 逻辑 - provider_config.clone(), - prompt_repository, - )), - Box::new(SkillPromptProvider::new(skills)), - ])); - - // 构建处理器 - let handler = SaveSessionCommandHandler::new(store, system_prompt_provider); - let router = { - let mut r = CommandRouter::new(); - r.register(Box::new(handler)); - r - }; - - // 构建命令 - let cmd = crate::command::Command::SaveSession { filepath, include_all: true }; - let cmd_ctx = CommandContext::new("websocket") - .with_session_id(&target_session_id); - - // 执行命令 - let response = router.dispatch_with_response(cmd, cmd_ctx).await; - - // 处理响应 - if response.success { - let filepath = response - .metadata - .get("filepath") - .cloned() - .unwrap_or_default(); - let _ = sender - .send(WsOutbound::SessionSaved { - session_id: target_session_id, - filepath, - }) - .await; - } else { - let error = response.error.unwrap_or_else(|| { - crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error") - }); - let _ = sender - .send(WsOutbound::Error { - code: error.code, - message: error.message, - }) - .await; - } Ok(()) } WsInbound::Ping => { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index baad9d2..2882b06 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -17,8 +17,9 @@ pub struct SessionSummary { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum WsInbound { - #[serde(rename = "user_input")] - UserInput { + /// 普通用户消息 + #[serde(rename = "message")] + Message { content: String, #[serde(default, skip_serializing_if = "Option::is_none")] channel: Option<String>, @@ -27,48 +28,9 @@ pub enum WsInbound { #[serde(default, skip_serializing_if = "Option::is_none")] sender_id: Option<String>, }, - #[serde(rename = "clear_history")] - ClearHistory { - #[serde(default, skip_serializing_if = "Option::is_none")] - chat_id: Option<String>, - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - }, - #[serde(rename = "create_session")] - CreateSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - title: Option<String>, - }, - #[serde(rename = "list_sessions")] - ListSessions { - #[serde(default)] - include_archived: bool, - }, - #[serde(rename = "load_session")] - LoadSession { session_id: String }, - #[serde(rename = "rename_session")] - RenameSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - title: String, - }, - #[serde(rename = "archive_session")] - ArchiveSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - }, - #[serde(rename = "delete_session")] - DeleteSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - }, - #[serde(rename = "save_session")] - SaveSession { - #[serde(default, skip_serializing_if = "Option::is_none")] - filepath: Option<String>, - #[serde(default, skip_serializing_if = "Option::is_none")] - session_id: Option<String>, - }, + /// 命令(JSON 格式) + #[serde(rename = "command")] + Command { payload: String }, #[serde(rename = "ping")] Ping, } @@ -126,14 +88,6 @@ pub enum WsOutbound { title: String, message_count: i64, }, - #[serde(rename = "session_renamed")] - SessionRenamed { session_id: String, title: String }, - #[serde(rename = "session_archived")] - SessionArchived { session_id: String }, - #[serde(rename = "session_deleted")] - SessionDeleted { session_id: String }, - #[serde(rename = "history_cleared")] - HistoryCleared { session_id: String }, #[serde(rename = "session_saved")] SessionSaved { session_id: String, filepath: String }, #[serde(rename = "pong")] diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 3dded44..13f9515 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -204,14 +204,24 @@ impl SessionStore { Self::from_connection(Connection::open_in_memory()?) } - pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> { + pub fn create_session( + &self, + channel_name: &str, + title: Option<&str>, + ) -> Result<SessionRecord, StorageError> { let now = current_timestamp(); let id = uuid::Uuid::new_v4().to_string(); let title = title .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned) - .unwrap_or_else(|| format!("CLI Session {}", &id[..8])); + .unwrap_or_else(|| { + if channel_name == "cli" { + format!("CLI Session {}", &id[..8]) + } else { + format!("Session {}", &id[..8]) + } + }); let conn = self.conn.lock().expect("session db mutex poisoned"); conn.execute( @@ -220,9 +230,9 @@ impl SessionStore { id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, archived_at, deleted_at, message_count, reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count - ) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0, 0, 0, 0) + ) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0) ", - params![id, title, id, now], + params![id, title, channel_name, id, now], )?; drop(conn); @@ -230,6 +240,10 @@ impl SessionStore { .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) } + pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> { + self.create_session("cli", title) + } + pub fn ensure_channel_session( &self, channel_name: &str, diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index de665d5..fa0cdf1 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -1,4 +1,4 @@ -use picobot::protocol::{SessionSummary, WsInbound, WsOutbound}; +use picobot::protocol::{WsInbound, WsOutbound}; use picobot::providers::{ChatCompletionRequest, Message}; /// Test that message with special characters is properly escaped @@ -53,70 +53,65 @@ fn test_chat_request_serialization() { } #[test] -fn test_session_inbound_serialization() { - let msg = WsInbound::CreateSession { - title: Some("demo".to_string()), +fn test_command_inbound_serialization() { + // Command is now sent as payload in WsInbound::Command + let command_json = r#"{"type":"create_session","title":"demo"}"#; + let msg = WsInbound::Command { + payload: command_json.to_string(), }; let json = serde_json::to_string(&msg).unwrap(); - assert!(json.contains(r#""type":"create_session""#)); - assert!(json.contains(r#""title":"demo""#)); + assert!(json.contains(r#""type":"command""#)); + assert!(json.contains(r#""payload":""#)); + assert!(json.contains(r#"create_session"#)); +} + +#[test] +fn test_message_inbound_serialization() { + let msg = WsInbound::Message { + content: "Hello world".to_string(), + channel: None, + chat_id: Some("session-1".to_string()), + sender_id: Some("user-1".to_string()), + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains(r#""type":"message""#)); + assert!(json.contains(r#""content":"Hello world""#)); + assert!(json.contains(r#""chat_id":"session-1""#)); let decoded: WsInbound = serde_json::from_str(&json).unwrap(); match decoded { - WsInbound::CreateSession { title } => { - assert_eq!(title.as_deref(), Some("demo")); + WsInbound::Message { content, chat_id, .. } => { + assert_eq!(content, "Hello world"); + assert_eq!(chat_id.as_deref(), Some("session-1")); } other => panic!("unexpected decoded variant: {:?}", other), } } #[test] -fn test_session_list_outbound_serialization() { - let msg = WsOutbound::SessionList { - sessions: vec![SessionSummary { - session_id: "session-1".to_string(), - title: "demo".to_string(), - channel_name: "cli".to_string(), - chat_id: "session-1".to_string(), - message_count: 2, - last_active_at: 123, - archived_at: None, - }], - current_session_id: Some("session-1".to_string()), +fn test_session_created_outbound_serialization() { + let msg = WsOutbound::SessionCreated { + session_id: "session-1".to_string(), + title: "demo".to_string(), }; let json = serde_json::to_string(&msg).unwrap(); - assert!(json.contains(r#""type":"session_list""#)); + assert!(json.contains(r#""type":"session_created""#)); assert!(json.contains(r#""session_id":"session-1""#)); - assert!(json.contains(r#""message_count":2"#)); + assert!(json.contains(r#""title":"demo""#)); let decoded: WsOutbound = serde_json::from_str(&json).unwrap(); match decoded { - WsOutbound::SessionList { - sessions, - current_session_id, - } => { - assert_eq!(sessions.len(), 1); - assert_eq!(sessions[0].title, "demo"); - assert_eq!(current_session_id.as_deref(), Some("session-1")); + WsOutbound::SessionCreated { session_id, title } => { + assert_eq!(session_id, "session-1"); + assert_eq!(title, "demo"); } other => panic!("unexpected decoded variant: {:?}", other), } } -#[test] -fn test_clear_history_with_session_id_serialization() { - let msg = WsInbound::ClearHistory { - chat_id: None, - session_id: Some("session-1".to_string()), - }; - - let json = serde_json::to_string(&msg).unwrap(); - assert!(json.contains(r#""type":"clear_history""#)); - assert!(json.contains(r#""session_id":"session-1""#)); -} - #[test] fn test_tool_call_outbound_serialization() { let msg = WsOutbound::ToolCall {