From 8bb32fa06641f894a0d15abb4721e567c0f018cd Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Sat, 18 Apr 2026 13:09:14 +0800 Subject: [PATCH] feat: enhance WebSocket session management and storage - Added SessionSummary struct for session metadata. - Updated ws_handler to create and manage CLI sessions more robustly. - Implemented session creation, loading, renaming, archiving, and deletion via WebSocket messages. - Introduced SessionStore for persistent session storage using SQLite. - Enhanced error handling and logging for session operations. - Updated protocol definitions for new session-related WebSocket messages. - Refactored tests to cover new session functionalities and ensure proper serialization. --- Cargo.toml | 1 + src/cli/input.rs | 73 +++++- src/cli/mod.rs | 2 +- src/client/mod.rs | 154 ++++++++++-- src/config/mod.rs | 53 ++++- src/gateway/mod.rs | 2 +- src/gateway/session.rs | 166 ++++++++++--- src/gateway/ws.rs | 277 ++++++++++++++++------ src/lib.rs | 1 + src/protocol.rs | 66 ++++++ src/storage/mod.rs | 447 +++++++++++++++++++++++++++++++++++ tests/test_integration.rs | 16 +- tests/test_request_format.rs | 102 ++++++-- tests/test_tool_calling.rs | 30 +-- 14 files changed, 1204 insertions(+), 186 deletions(-) create mode 100644 src/storage/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 8dd2225..a77d084 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,3 +27,4 @@ mime_guess = "2.0" base64 = "0.22" tempfile = "3" meval = "0.2" +rusqlite = { version = "0.32", features = ["bundled"] } diff --git a/src/cli/input.rs b/src/cli/input.rs index 2e5fa91..d024374 100644 --- a/src/cli/input.rs +++ b/src/cli/input.rs @@ -2,6 +2,23 @@ use crate::bus::ChatMessage; use super::channel::CliChannel; +pub enum InputEvent { + Message(ChatMessage), + Command(InputCommand), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InputCommand { + Exit, + Clear, + New(Option), + Sessions, + Use(String), + Rename(String), + Archive, + Delete, +} + pub struct InputHandler { channel: CliChannel, } @@ -13,7 +30,7 @@ impl InputHandler { } } - pub async fn read_input(&mut self, prompt: &str) -> Result, InputError> { + pub async fn read_input(&mut self, prompt: &str) -> Result, InputError> { match self.channel.read_line(prompt).await { Ok(Some(line)) => { if line.trim().is_empty() { @@ -21,10 +38,10 @@ impl InputHandler { } if let Some(cmd) = self.handle_special_commands(&line) { - return Ok(Some(cmd)); + return Ok(Some(InputEvent::Command(cmd))); } - Ok(Some(ChatMessage::user(line))) + Ok(Some(InputEvent::Message(ChatMessage::user(line)))) } Ok(None) => Ok(None), Err(e) => Err(InputError::IoError(e)), @@ -39,10 +56,21 @@ impl InputHandler { self.channel.write_response(content).await.map_err(InputError::IoError) } - fn handle_special_commands(&self, line: &str) -> Option { - match line.trim() { - "/quit" | "/exit" | "/q" => Some(ChatMessage::system("__EXIT__")), - "/clear" => Some(ChatMessage::system("__CLEAR__")), + fn handle_special_commands(&self, line: &str) -> Option { + let trimmed = line.trim(); + let mut parts = trimmed.splitn(2, char::is_whitespace); + let command = parts.next()?; + let arg = parts.next().map(str::trim).filter(|value| !value.is_empty()); + + match command { + "/quit" | "/exit" | "/q" => Some(InputCommand::Exit), + "/clear" => Some(InputCommand::Clear), + "/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))), + "/sessions" => Some(InputCommand::Sessions), + "/use" => arg.map(|value| InputCommand::Use(value.to_string())), + "/rename" => arg.map(|value| InputCommand::Rename(value.to_string())), + "/archive" => Some(InputCommand::Archive), + "/delete" => Some(InputCommand::Delete), _ => None, } } @@ -68,3 +96,34 @@ impl std::fmt::Display for InputError { } impl std::error::Error for InputError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_special_command_parsing() { + let handler = InputHandler::new(); + + assert_eq!(handler.handle_special_commands("/quit"), Some(InputCommand::Exit)); + assert_eq!(handler.handle_special_commands("/clear"), Some(InputCommand::Clear)); + assert_eq!(handler.handle_special_commands("/new"), Some(InputCommand::New(None))); + assert_eq!( + handler.handle_special_commands("/new planning"), + Some(InputCommand::New(Some("planning".to_string()))) + ); + assert_eq!(handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions)); + assert_eq!( + handler.handle_special_commands("/use abc123"), + Some(InputCommand::Use("abc123".to_string())) + ); + assert_eq!( + handler.handle_special_commands("/rename project alpha"), + Some(InputCommand::Rename("project alpha".to_string())) + ); + assert_eq!(handler.handle_special_commands("/archive"), Some(InputCommand::Archive)); + assert_eq!(handler.handle_special_commands("/delete"), Some(InputCommand::Delete)); + assert_eq!(handler.handle_special_commands("/unknown"), None); + assert_eq!(handler.handle_special_commands("/use"), None); + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index fe628b9..e582818 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -2,4 +2,4 @@ pub mod channel; pub mod input; pub use channel::CliChannel; -pub use input::InputHandler; +pub use input::{InputCommand, InputEvent, InputHandler}; diff --git a/src/client/mod.rs b/src/client/mod.rs index c663cad..c4e7730 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,7 +3,38 @@ pub use crate::protocol::{WsInbound, WsOutbound, serialize_inbound, serialize_ou use futures_util::{SinkExt, StreamExt}; use tokio_tungstenite::{connect_async, tungstenite::Message}; -use crate::cli::InputHandler; +use crate::cli::{InputCommand, InputEvent, InputHandler}; + +fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_session_id: Option<&str>) -> String { + if sessions.is_empty() { + return "No sessions found.".to_string(); + } + + let mut lines = Vec::with_capacity(sessions.len() + 1); + lines.push("Sessions:".to_string()); + for session in sessions { + let marker = if current_session_id == Some(session.session_id.as_str()) { + "*" + } else { + "-" + }; + let archived = if session.archived_at.is_some() { + " [archived]" + } else { + "" + }; + lines.push(format!( + "{} {} | {} | {} messages{}", + marker, + session.session_id, + session.title, + session.message_count, + archived, + )); + } + + lines.join("\n") +} fn parse_message(raw: &str) -> Result { serde_json::from_str(raw) @@ -16,7 +47,8 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { let (mut sender, mut receiver) = ws_stream.split(); let mut input = InputHandler::new(); - input.write_output("picobot CLI - Type /quit to exit, /clear to clear history\n").await?; + let mut current_session_id: Option = None; + input.write_output("picobot CLI - Commands: /new [title], /sessions, /use , /rename , /archive, /delete, /clear, /quit\n").await?; // Main loop: poll both stdin and WebSocket loop { @@ -35,10 +67,38 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { input.write_output(&format!("Error: {}", message)).await?; } WsOutbound::SessionEstablished { session_id } => { + current_session_id = Some(session_id.clone()); #[cfg(debug_assertions)] tracing::debug!(session_id = %session_id, "Session established"); input.write_output(&format!("Session: {}\n", session_id)).await?; } + WsOutbound::SessionCreated { session_id, title } => { + current_session_id = Some(session_id.clone()); + input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?; + } + WsOutbound::SessionList { sessions, current_session_id: listed_current } => { + let display = format_session_list(&sessions, listed_current.as_deref()); + input.write_output(&format!("{}\n", display)).await?; + } + WsOutbound::SessionLoaded { session_id, title, message_count } => { + current_session_id = Some(session_id.clone()); + input.write_output(&format!("Loaded session: {} ({}, {} messages)\n", session_id, title, message_count)).await?; + } + WsOutbound::SessionRenamed { session_id, title } => { + input.write_output(&format!("Renamed session: {} -> {}\n", session_id, title)).await?; + } + WsOutbound::SessionArchived { session_id } => { + input.write_output(&format!("Archived session: {}\n", session_id)).await?; + } + WsOutbound::SessionDeleted { session_id } => { + if current_session_id.as_deref() == Some(session_id.as_str()) { + current_session_id = None; + } + input.write_output(&format!("Deleted session: {}\n", session_id)).await?; + } + WsOutbound::HistoryCleared { session_id } => { + input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?; + } _ => {} } } @@ -54,32 +114,86 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { // Handle stdin input result = input.read_input("> ") => { match result { - Ok(Some(msg)) => { - match msg.content.as_str() { - "__EXIT__" => { + Ok(Some(event)) => { + match event { + InputEvent::Command(InputCommand::Exit) => { input.write_output("Goodbye!").await?; break; } - "__CLEAR__" => { - let inbound = WsInbound::ClearHistory { chat_id: None }; + InputEvent::Command(InputCommand::Clear) => { + let inbound = WsInbound::ClearHistory { + chat_id: None, + session_id: current_session_id.clone(), + }; if let Ok(text) = serialize_inbound(&inbound) { let _ = sender.send(Message::Text(text.into())).await; } continue; } - _ => {} - } - - let inbound = WsInbound::UserInput { - content: msg.content, - channel: None, - chat_id: None, - sender_id: None, - }; - if let Ok(text) = serialize_inbound(&inbound) { - if sender.send(Message::Text(text.into())).await.is_err() { - tracing::error!("Failed to send message to gateway"); - break; + InputEvent::Command(InputCommand::New(title)) => { + let inbound = WsInbound::CreateSession { title }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + continue; + } + InputEvent::Command(InputCommand::Sessions) => { + let inbound = WsInbound::ListSessions { + include_archived: true, + }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + continue; + } + InputEvent::Command(InputCommand::Use(session_id)) => { + let inbound = WsInbound::LoadSession { session_id }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + continue; + } + InputEvent::Command(InputCommand::Rename(title)) => { + let inbound = WsInbound::RenameSession { + session_id: current_session_id.clone(), + title, + }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + continue; + } + InputEvent::Command(InputCommand::Archive) => { + let inbound = WsInbound::ArchiveSession { + session_id: current_session_id.clone(), + }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + continue; + } + InputEvent::Command(InputCommand::Delete) => { + let inbound = WsInbound::DeleteSession { + session_id: current_session_id.clone(), + }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + continue; + } + InputEvent::Message(msg) => { + let inbound = WsInbound::UserInput { + content: msg.content, + channel: None, + chat_id: current_session_id.clone(), + sender_id: None, + }; + if let Ok(text) = serialize_inbound(&inbound) { + if sender.send(Message::Text(text.into())).await.is_err() { + tracing::error!("Failed to send message to gateway"); + break; + } + } } } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 7268084..e2fecce 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -267,9 +267,54 @@ fn resolve_env_placeholders(content: &str) -> String { mod tests { use super::*; + fn write_test_config() -> tempfile::NamedTempFile { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {} + }, + "volcengine": { + "type": "openai", + "base_url": "https://example.invalid/volc", + "api_key": "test-key-2", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus", + "temperature": 0.0 + }, + "doubao-seed-2-0-lite-260215": { + "model_id": "doubao-seed-2-0-lite-260215" + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + }, + "gateway": { + "host": "0.0.0.0", + "port": 19876 + } +}"#, + ) + .unwrap(); + file + } + #[test] fn test_config_load() { - let config = Config::load("config.json").unwrap(); + let file = write_test_config(); + let config = Config::load(file.path().to_str().unwrap()).unwrap(); // Check providers assert!(config.providers.contains_key("volcengine")); @@ -285,7 +330,8 @@ mod tests { #[test] fn test_get_provider_config() { - let config = Config::load("config.json").unwrap(); + let file = write_test_config(); + let config = Config::load(file.path().to_str().unwrap()).unwrap(); let provider_config = config.get_provider_config("default").unwrap(); assert_eq!(provider_config.provider_type, "openai"); @@ -296,7 +342,8 @@ mod tests { #[test] fn test_default_gateway_config() { - let config = Config::load("config.json").unwrap(); + let file = write_test_config(); + let config = Config::load(file.path().to_str().unwrap()).unwrap(); assert_eq!(config.gateway.host, "0.0.0.0"); assert_eq!(config.gateway.port, 19876); } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 5ab521e..ba8331c 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -29,7 +29,7 @@ impl GatewayState { // Session TTL from config (default 4 hours) let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); - let session_manager = SessionManager::new(session_ttl_hours, provider_config); + let session_manager = SessionManager::new(session_ttl_hours, provider_config)?; let channel_manager = ChannelManager::new(); let bus = channel_manager.bus(); diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 42ca962..6b9876a 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -7,6 +7,7 @@ use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::protocol::WsOutbound; +use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, ToolRegistry, WebFetchTool, @@ -23,6 +24,7 @@ pub struct Session { provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>, compressor: ContextCompressor, + store: Arc<SessionStore>, } impl Session { @@ -31,6 +33,7 @@ impl Session { provider_config: LLMProviderConfig, user_tx: mpsc::Sender<WsOutbound>, tools: Arc<ToolRegistry>, + store: Arc<SessionStore>, ) -> Result<Self, AgentError> { Ok(Self { id: Uuid::new_v4(), @@ -40,9 +43,33 @@ impl Session { provider_config: provider_config.clone(), tools, compressor: ContextCompressor::new(provider_config.token_limit), + store, }) } + pub fn persistent_session_id(&self, chat_id: &str) -> String { + persistent_session_id(&self.channel_name, chat_id) + } + + pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> { + self.store + .ensure_channel_session(&self.channel_name, chat_id) + .map_err(|err| AgentError::Other(format!("session persistence error: {}", err))) + } + + pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> { + if self.chat_histories.contains_key(chat_id) { + return Ok(()); + } + + let history = self + .store + .load_messages(&self.persistent_session_id(chat_id)) + .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?; + self.chat_histories.insert(chat_id.to_string(), history); + Ok(()) + } + /// 获取或创建指定 chat_id 的会话历史 pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> { self.chat_histories @@ -55,41 +82,62 @@ impl Session { self.chat_histories.get(chat_id) } - /// 添加用户消息到指定 chat_id 的历史 - pub fn add_user_message(&mut self, chat_id: &str, content: &str) { + /// 使用完整消息追加到历史 + pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) { let history = self.get_or_create_history(chat_id); - history.push(ChatMessage::user(content)); + history.push(message); } - /// 添加带媒体的用户消息到指定 chat_id 的历史 - pub fn add_user_message_with_media(&mut self, chat_id: &str, content: &str, media_refs: Vec<String>) { - let history = self.get_or_create_history(chat_id); - history.push(ChatMessage::user_with_media(content, media_refs)); + pub fn remove_history(&mut self, chat_id: &str) { + self.chat_histories.remove(chat_id); } - /// 添加助手响应到指定 chat_id 的历史 - pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) { - if let Some(history) = self.chat_histories.get_mut(chat_id) { - history.push(message); - } - } - - /// 清除指定 chat_id 的历史 - pub fn clear_chat_history(&mut self, chat_id: &str) { + pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { if let Some(history) = self.chat_histories.get_mut(chat_id) { let len = history.len(); history.clear(); #[cfg(debug_assertions)] tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared"); } + + self.store + .clear_messages(&self.persistent_session_id(chat_id)) + .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err))) + } + + /// 将消息写入内存与持久化层 + pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> { + let session_id = self.persistent_session_id(chat_id); + self.store + .append_message(&session_id, &message) + .map_err(|err| AgentError::Other(format!("append message persistence error: {}", err)))?; + self.add_message(chat_id, message); + Ok(()) + } + + pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage { + if media_refs.is_empty() { + ChatMessage::user(content) + } else { + ChatMessage::user_with_media(content, media_refs) + } } /// 清除所有历史 - pub fn clear_all_history(&mut self) { + pub fn clear_all_history(&mut self) -> Result<(), AgentError> { + let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect(); let total: usize = self.chat_histories.values().map(|h| h.len()).sum(); self.chat_histories.clear(); #[cfg(debug_assertions)] tracing::debug!(previous_total = total, "All chat histories cleared"); + + for chat_id in chat_ids { + self.store + .clear_messages(&self.persistent_session_id(&chat_id)) + .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?; + } + + Ok(()) } pub async fn send(&self, msg: WsOutbound) { @@ -118,6 +166,7 @@ pub struct SessionManager { inner: Arc<Mutex<SessionManagerInner>>, provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>, + store: Arc<SessionStore>, } struct SessionManagerInner { @@ -144,8 +193,13 @@ fn default_tools() -> ToolRegistry { } impl SessionManager { - pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Self { - Self { + pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> { + let store = Arc::new( + SessionStore::new() + .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?, + ); + + Ok(Self { inner: Arc::new(Mutex::new(SessionManagerInner { sessions: HashMap::new(), session_timestamps: HashMap::new(), @@ -153,13 +207,66 @@ impl SessionManager { })), provider_config, tools: Arc::new(default_tools()), - } + store, + }) } pub fn tools(&self) -> Arc<ToolRegistry> { self.tools.clone() } + pub fn store(&self) -> Arc<SessionStore> { + self.store.clone() + } + + pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> { + self.store + .create_cli_session(title) + .map_err(|err| AgentError::Other(format!("create session error: {}", err))) + } + + pub fn get_session_record(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> { + self.store + .get_session(session_id) + .map_err(|err| AgentError::Other(format!("get session error: {}", err))) + } + + pub fn list_cli_sessions(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> { + self.store + .list_sessions("cli", include_archived) + .map_err(|err| AgentError::Other(format!("list sessions error: {}", err))) + } + + pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), AgentError> { + self.store + .rename_session(session_id, title) + .map_err(|err| AgentError::Other(format!("rename session error: {}", err))) + } + + pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> { + self.store + .archive_session(session_id) + .map_err(|err| AgentError::Other(format!("archive session error: {}", err))) + } + + pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> { + self.store + .delete_session(session_id) + .map_err(|err| AgentError::Other(format!("delete session error: {}", err))) + } + + pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> { + self.store + .clear_messages(session_id) + .map_err(|err| AgentError::Other(format!("clear session error: {}", err))) + } + + pub fn load_session_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, AgentError> { + self.store + .load_messages(session_id) + .map_err(|err| AgentError::Other(format!("load messages error: {}", err))) + } + /// 确保 session 存在且未超时,超时则重建 pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { let mut inner = self.inner.lock().await; @@ -189,6 +296,7 @@ impl SessionManager { self.provider_config.clone(), user_tx, self.tools.clone(), + self.store.clone(), ) .await?; let arc = Arc::new(Mutex::new(session)); @@ -251,15 +359,17 @@ impl SessionManager { let response = { let mut session_guard = session.lock().await; + session_guard.ensure_persistent_session(chat_id)?; + session_guard.ensure_chat_loaded(chat_id)?; + // 添加用户消息到历史 - if media.is_empty() { - session_guard.add_user_message(chat_id, content); - } else { - let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect(); - #[cfg(debug_assertions)] + let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect(); + #[cfg(debug_assertions)] + if !media_refs.is_empty() { tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media"); - session_guard.add_user_message_with_media(chat_id, content, media_refs); } + let user_message = session_guard.create_user_message(content, media_refs); + session_guard.append_persisted_message(chat_id, user_message)?; // 获取完整历史 let history = session_guard.get_or_create_history(chat_id).clone(); @@ -274,7 +384,7 @@ impl SessionManager { let response = agent.process(history).await?; // 添加助手响应到历史 - session_guard.add_assistant_message(chat_id, response.clone()); + session_guard.append_persisted_message(chat_id, response.clone())?; response }; @@ -294,7 +404,7 @@ impl SessionManager { pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> { if let Some(session) = self.get(channel_name).await { let mut session_guard = session.lock().await; - session_guard.clear_all_history(); + session_guard.clear_all_history()?; } Ok(()) } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 5507d62..18ba69f 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -4,7 +4,7 @@ use axum::extract::State; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use tokio::sync::{mpsc, Mutex}; -use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound}; +use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; use super::{GatewayState, session::Session}; pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response { @@ -24,8 +24,15 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) { } }; - // CLI 使用独立的 session,channel_name = "cli-{uuid}" - let channel_name = format!("cli-{}", uuid::Uuid::new_v4()); + let initial_record = match state.session_manager.create_cli_session(None) { + Ok(record) => record, + Err(e) => { + tracing::error!(error = %e, "Failed to create initial CLI session"); + return; + } + }; + + let channel_name = "cli".to_string(); // 创建 CLI session let session = match Session::new( @@ -33,6 +40,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) { provider_config, sender, state.session_manager.tools(), + state.session_manager.store(), ) .await { @@ -43,21 +51,27 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) { } }; - let session_id = session.lock().await.id; - tracing::info!(session_id = %session_id, "CLI session established"); + if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) { + tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history"); + return; + } + + let runtime_session_id = session.lock().await.id; + let mut current_session_id = initial_record.id.clone(); + tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established"); let _ = session .lock() .await .send(WsOutbound::SessionEstablished { - session_id: session_id.to_string(), + session_id: current_session_id.clone(), }) .await; let (mut ws_sender, mut ws_receiver) = ws.split(); let mut receiver = receiver; - let session_id_for_sender = session_id; + let session_id_for_sender = runtime_session_id; tokio::spawn(async move { while let Some(msg) = receiver.recv().await { if let Ok(text) = serialize_outbound(&msg) { @@ -76,7 +90,17 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) { let text = text.to_string(); match parse_inbound(&text) { Ok(inbound) => { - handle_inbound(&session, inbound).await; + if let Err(e) = handle_inbound(&state, &session, &mut current_session_id, inbound).await { + tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message"); + let _ = session + .lock() + .await + .send(WsOutbound::Error { + code: "SESSION_ERROR".to_string(), + message: e.to_string(), + }) + .await; + } } Err(e) => { tracing::warn!(error = %e, "Failed to parse inbound message"); @@ -93,92 +117,203 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) { } Ok(WsMessage::Close(_)) | Err(_) => { #[cfg(debug_assertions)] - tracing::debug!(session_id = %session_id, "WebSocket closed"); + tracing::debug!(session_id = %runtime_session_id, "WebSocket closed"); break; } _ => {} } } - tracing::info!(session_id = %session_id, "CLI session ended"); + tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); } -async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) { - let inbound_clone = inbound.clone(); +fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary { + SessionSummary { + session_id: record.id, + title: record.title, + channel_name: record.channel_name, + chat_id: record.chat_id, + message_count: record.message_count, + last_active_at: record.last_active_at, + archived_at: record.archived_at, + } +} - // 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id) - let (content, chat_id) = match inbound_clone { - WsInbound::UserInput { - content, - channel: _, - chat_id, - sender_id: _, - } => { - // CLI 使用 session 中的 channel_name 作为标识 - // chat_id 使用传入的或使用默认 - let chat_id = chat_id.unwrap_or_else(|| "default".to_string()); - (content, chat_id) +async fn handle_inbound( + state: &Arc<GatewayState>, + session: &Arc<Mutex<Session>>, + current_session_id: &mut String, + inbound: WsInbound, +) -> Result<(), crate::agent::AgentError> { + match inbound { + WsInbound::UserInput { content, chat_id, .. } => { + let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); + let mut session_guard = session.lock().await; + + session_guard.ensure_persistent_session(&chat_id)?; + session_guard.ensure_chat_loaded(&chat_id)?; + + let user_message = session_guard.create_user_message(&content, Vec::new()); + session_guard.append_persisted_message(&chat_id, user_message)?; + + let raw_history = session_guard.get_or_create_history(&chat_id).clone(); + let history = match session_guard + .compressor() + .compress_if_needed(raw_history, session_guard.provider_config()) + .await + { + Ok(history) => history, + Err(error) => { + tracing::warn!(chat_id = %chat_id, error = %error, "Compression failed, using original history"); + session_guard.get_or_create_history(&chat_id).clone() + } + }; + + let agent = session_guard.create_agent()?; + match agent.process(history).await { + Ok(response) => { + session_guard.append_persisted_message(&chat_id, response.clone())?; + let _ = session_guard + .send(WsOutbound::AssistantResponse { + id: response.id, + content: response.content, + role: response.role, + }) + .await; + } + Err(error) => { + tracing::error!(chat_id = %chat_id, error = %error, "Agent process error"); + let _ = session_guard + .send(WsOutbound::Error { + code: "LLM_ERROR".to_string(), + message: error.to_string(), + }) + .await; + } + } + + Ok(()) } - _ => return, - }; + WsInbound::ClearHistory { session_id, chat_id } => { + let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone()); + state.session_manager.clear_session_messages(&target)?; - let mut session_guard = session.lock().await; - - // 添加用户消息到历史 - session_guard.add_user_message(&chat_id, &content); - - // 获取完整历史 - let history = session_guard.get_or_create_history(&chat_id).clone(); - - // 压缩历史(如果需要) - let history = match session_guard.compressor() - .compress_if_needed(history, session_guard.provider_config()) - .await - { - Ok(h) => h, - Err(e) => { - tracing::warn!(chat_id = %chat_id, error = %e, "Compression failed, using original history"); - session_guard.get_or_create_history(&chat_id).clone() - } - }; - - // 创建 agent 并处理 - let agent = match session_guard.create_agent() { - Ok(a) => a, - Err(e) => { - tracing::error!(chat_id = %chat_id, error = %e, "Failed to create agent"); + let mut session_guard = session.lock().await; + session_guard.remove_history(&target); let _ = session_guard - .send(WsOutbound::Error { - code: "AGENT_ERROR".to_string(), - message: e.to_string(), + .send(WsOutbound::HistoryCleared { + session_id: target, }) .await; - return; + Ok(()) } - }; + WsInbound::CreateSession { title } => { + let record = state.session_manager.create_cli_session(title.as_deref())?; + *current_session_id = record.id.clone(); - match agent.process(history).await { - Ok(response) => { - #[cfg(debug_assertions)] - tracing::debug!(chat_id = %chat_id, "Agent response sent"); - // 添加助手响应到历史 - session_guard.add_assistant_message(&chat_id, response.clone()); + let mut session_guard = session.lock().await; + session_guard.ensure_chat_loaded(&record.id)?; let _ = session_guard - .send(WsOutbound::AssistantResponse { - id: response.id, - content: response.content, - role: response.role, + .send(WsOutbound::SessionCreated { + session_id: record.id, + title: record.title, }) .await; + Ok(()) } - Err(e) => { - tracing::error!(chat_id = %chat_id, error = %e, "Agent process error"); + WsInbound::ListSessions { include_archived } => { + let records = state.session_manager.list_cli_sessions(include_archived)?; + let summaries = records.into_iter().map(to_session_summary).collect(); + + let session_guard = session.lock().await; let _ = session_guard - .send(WsOutbound::Error { - code: "LLM_ERROR".to_string(), - message: e.to_string(), + .send(WsOutbound::SessionList { + sessions: summaries, + current_session_id: Some(current_session_id.clone()), }) .await; + Ok(()) + } + WsInbound::LoadSession { session_id } => { + let Some(record) = state.session_manager.get_session_record(&session_id)? else { + let session_guard = session.lock().await; + let _ = session_guard + .send(WsOutbound::Error { + code: "SESSION_NOT_FOUND".to_string(), + message: format!("Session not found: {}", session_id), + }) + .await; + return Ok(()); + }; + + *current_session_id = record.id.clone(); + let mut session_guard = session.lock().await; + session_guard.ensure_chat_loaded(&record.id)?; + let _ = session_guard + .send(WsOutbound::SessionLoaded { + session_id: record.id, + title: record.title, + message_count: record.message_count, + }) + .await; + Ok(()) + } + WsInbound::RenameSession { session_id, title } => { + let target = session_id.unwrap_or_else(|| current_session_id.clone()); + state.session_manager.rename_session(&target, &title)?; + let session_guard = session.lock().await; + let _ = session_guard + .send(WsOutbound::SessionRenamed { + session_id: target, + title, + }) + .await; + Ok(()) + } + WsInbound::ArchiveSession { session_id } => { + let target = session_id.unwrap_or_else(|| current_session_id.clone()); + state.session_manager.archive_session(&target)?; + let session_guard = session.lock().await; + let _ = session_guard + .send(WsOutbound::SessionArchived { session_id: target }) + .await; + Ok(()) + } + WsInbound::DeleteSession { session_id } => { + let target = session_id.unwrap_or_else(|| current_session_id.clone()); + state.session_manager.delete_session(&target)?; + + let replacement = if target == *current_session_id { + Some(state.session_manager.create_cli_session(None)?) + } else { + None + }; + + let mut session_guard = session.lock().await; + session_guard.remove_history(&target); + let _ = session_guard + .send(WsOutbound::SessionDeleted { + session_id: target.clone(), + }) + .await; + + if let Some(record) = replacement { + *current_session_id = record.id.clone(); + session_guard.ensure_chat_loaded(&record.id)?; + let _ = session_guard + .send(WsOutbound::SessionCreated { + session_id: record.id, + title: record.title, + }) + .await; + } + + Ok(()) + } + WsInbound::Ping => { + let session_guard = session.lock().await; + let _ = session_guard.send(WsOutbound::Pong).await; + Ok(()) } } } diff --git a/src/lib.rs b/src/lib.rs index d7778e3..8f78d5a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,4 +9,5 @@ pub mod protocol; pub mod channels; pub mod logging; pub mod observability; +pub mod storage; pub mod tools; diff --git a/src/protocol.rs b/src/protocol.rs index 8d24404..b301c0e 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,5 +1,17 @@ use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionSummary { + pub session_id: String, + pub title: String, + pub channel_name: String, + pub chat_id: String, + pub message_count: i64, + pub last_active_at: i64, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub archived_at: Option<i64>, +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum WsInbound { @@ -17,6 +29,38 @@ pub enum WsInbound { ClearHistory { #[serde(default, skip_serializing_if = "Option::is_none")] chat_id: Option<String>, + #[serde(default, skip_serializing_if = "Option::is_none")] + session_id: Option<String>, + }, + #[serde(rename = "create_session")] + CreateSession { + #[serde(default, skip_serializing_if = "Option::is_none")] + title: Option<String>, + }, + #[serde(rename = "list_sessions")] + ListSessions { + #[serde(default)] + include_archived: bool, + }, + #[serde(rename = "load_session")] + LoadSession { + session_id: String, + }, + #[serde(rename = "rename_session")] + RenameSession { + #[serde(default, skip_serializing_if = "Option::is_none")] + session_id: Option<String>, + title: String, + }, + #[serde(rename = "archive_session")] + ArchiveSession { + #[serde(default, skip_serializing_if = "Option::is_none")] + session_id: Option<String>, + }, + #[serde(rename = "delete_session")] + DeleteSession { + #[serde(default, skip_serializing_if = "Option::is_none")] + session_id: Option<String>, }, #[serde(rename = "ping")] Ping, @@ -31,6 +75,28 @@ pub enum WsOutbound { Error { code: String, message: String }, #[serde(rename = "session_established")] SessionEstablished { session_id: String }, + #[serde(rename = "session_created")] + SessionCreated { session_id: String, title: String }, + #[serde(rename = "session_list")] + SessionList { + sessions: Vec<SessionSummary>, + #[serde(default, skip_serializing_if = "Option::is_none")] + current_session_id: Option<String>, + }, + #[serde(rename = "session_loaded")] + SessionLoaded { + session_id: String, + title: String, + message_count: i64, + }, + #[serde(rename = "session_renamed")] + SessionRenamed { session_id: String, title: String }, + #[serde(rename = "session_archived")] + SessionArchived { session_id: String }, + #[serde(rename = "session_deleted")] + SessionDeleted { session_id: String }, + #[serde(rename = "history_cleared")] + HistoryCleared { session_id: String }, #[serde(rename = "pong")] Pong, } diff --git a/src/storage/mod.rs b/src/storage/mod.rs new file mode 100644 index 0000000..f341f80 --- /dev/null +++ b/src/storage/mod.rs @@ -0,0 +1,447 @@ +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; + +use rusqlite::{Connection, OptionalExtension, params}; +use serde::{Deserialize, Serialize}; + +use crate::bus::ChatMessage; + +#[derive(Debug, thiserror::Error)] +pub enum StorageError { + #[error("database error: {0}")] + Database(#[from] rusqlite::Error), + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("serialization error: {0}")] + Serialization(#[from] serde_json::Error), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionRecord { + pub id: String, + pub title: String, + pub channel_name: String, + pub chat_id: String, + pub summary: Option<String>, + pub created_at: i64, + pub updated_at: i64, + pub last_active_at: i64, + pub archived_at: Option<i64>, + pub deleted_at: Option<i64>, + pub message_count: i64, +} + +#[derive(Clone)] +pub struct SessionStore { + conn: Arc<Mutex<Connection>>, +} + +impl SessionStore { + pub fn new() -> Result<Self, StorageError> { + let db_path = default_session_db_path()?; + Self::open_at_path(&db_path) + } + + fn open_at_path(path: &Path) -> Result<Self, StorageError> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + + let conn = Connection::open(path)?; + Self::from_connection(conn) + } + + fn from_connection(conn: Connection) -> Result<Self, StorageError> { + conn.execute_batch( + " + PRAGMA journal_mode = WAL; + PRAGMA foreign_keys = ON; + + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + channel_name TEXT NOT NULL, + chat_id TEXT NOT NULL, + summary TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + last_active_at INTEGER NOT NULL, + archived_at INTEGER, + deleted_at INTEGER, + message_count INTEGER NOT NULL DEFAULT 0 + ); + + CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived + ON sessions(channel_name, archived_at, last_active_at DESC); + CREATE INDEX IF NOT EXISTS idx_sessions_updated_at + ON sessions(updated_at DESC); + + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + media_refs_json TEXT NOT NULL, + tool_call_id TEXT, + tool_name TEXT, + created_at INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE, + UNIQUE(session_id, seq) + ); + + CREATE INDEX IF NOT EXISTS idx_messages_session_seq + ON messages(session_id, seq); + CREATE INDEX IF NOT EXISTS idx_messages_session_created + ON messages(session_id, created_at); + ", + )?; + + Ok(Self { + conn: Arc::new(Mutex::new(conn)), + }) + } + + #[cfg(test)] + fn in_memory() -> Result<Self, StorageError> { + Self::from_connection(Connection::open_in_memory()?) + } + + pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> { + let now = current_timestamp(); + let id = uuid::Uuid::new_v4().to_string(); + let title = title + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) + .unwrap_or_else(|| format!("CLI Session {}", &id[..8])); + + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute( + " + INSERT INTO sessions ( + id, title, channel_name, chat_id, summary, + created_at, updated_at, last_active_at, archived_at, deleted_at, message_count + ) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0) + ", + params![id, title, id, now], + )?; + + drop(conn); + self.get_session(&id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) + } + + pub fn ensure_channel_session( + &self, + channel_name: &str, + chat_id: &str, + ) -> Result<SessionRecord, StorageError> { + let session_id = persistent_session_id(channel_name, chat_id); + if let Some(record) = self.get_session(&session_id)? { + return Ok(record); + } + + let now = current_timestamp(); + let title = format!("{}:{}", channel_name, chat_id); + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute( + " + INSERT INTO sessions ( + id, title, channel_name, chat_id, summary, + created_at, updated_at, last_active_at, archived_at, deleted_at, message_count + ) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0) + ", + params![session_id, title, channel_name, chat_id, now], + )?; + drop(conn); + + self.get_session(&session_id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) + } + + pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let mut stmt = conn.prepare( + " + SELECT id, title, channel_name, chat_id, summary, + created_at, updated_at, last_active_at, + archived_at, deleted_at, message_count + FROM sessions + WHERE id = ?1 AND deleted_at IS NULL + ", + )?; + + stmt.query_row(params![session_id], map_session_record) + .optional() + .map_err(StorageError::from) + } + + pub fn list_sessions( + &self, + channel_name: &str, + include_archived: bool, + ) -> Result<Vec<SessionRecord>, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let mut sql = String::from( + " + SELECT id, title, channel_name, chat_id, summary, + created_at, updated_at, last_active_at, + archived_at, deleted_at, message_count + FROM sessions + WHERE channel_name = ?1 + AND deleted_at IS NULL + ", + ); + + if !include_archived { + sql.push_str(" AND archived_at IS NULL"); + } + + sql.push_str(" ORDER BY last_active_at DESC, created_at DESC"); + + let mut stmt = conn.prepare(&sql)?; + let rows = stmt.query_map(params![channel_name], map_session_record)?; + let mut sessions = Vec::new(); + for row in rows { + sessions.push(row?); + } + Ok(sessions) + } + + pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), StorageError> { + let now = current_timestamp(); + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute( + "UPDATE sessions SET title = ?2, updated_at = ?3 WHERE id = ?1 AND deleted_at IS NULL", + params![session_id, title.trim(), now], + )?; + Ok(()) + } + + pub fn archive_session(&self, session_id: &str) -> Result<(), StorageError> { + let now = current_timestamp(); + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute( + "UPDATE sessions SET archived_at = ?2, updated_at = ?2 WHERE id = ?1 AND deleted_at IS NULL", + params![session_id, now], + )?; + Ok(()) + } + + pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?; + conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?; + Ok(()) + } + + pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> { + let now = current_timestamp(); + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?; + conn.execute( + " + UPDATE sessions + SET message_count = 0, updated_at = ?2, last_active_at = ?2 + WHERE id = ?1 AND deleted_at IS NULL + ", + params![session_id, now], + )?; + Ok(()) + } + + pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let tx = conn.unchecked_transaction()?; + + let seq: i64 = tx.query_row( + "SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE session_id = ?1", + params![session_id], + |row| row.get(0), + )?; + + let media_refs_json = serde_json::to_string(&message.media_refs)?; + tx.execute( + " + INSERT INTO messages ( + id, session_id, seq, role, content, + media_refs_json, tool_call_id, tool_name, created_at + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) + ", + params![ + message.id, + session_id, + seq, + message.role, + message.content, + media_refs_json, + message.tool_call_id, + message.tool_name, + message.timestamp, + ], + )?; + + let now = current_timestamp(); + tx.execute( + " + UPDATE sessions + SET message_count = message_count + 1, + updated_at = ?2, + last_active_at = ?2, + archived_at = NULL + WHERE id = ?1 AND deleted_at IS NULL + ", + params![session_id, now], + )?; + + tx.commit()?; + Ok(()) + } + + pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let mut stmt = conn.prepare( + " + SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name + FROM messages + WHERE session_id = ?1 + ORDER BY seq ASC + ", + )?; + + let rows = stmt.query_map(params![session_id], |row| { + let media_refs_json: String = row.get(3)?; + let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + media_refs_json.len(), + rusqlite::types::Type::Text, + Box::new(err), + ) + })?; + + Ok(ChatMessage { + id: row.get(0)?, + role: row.get(1)?, + content: row.get(2)?, + media_refs, + timestamp: row.get(4)?, + tool_call_id: row.get(5)?, + tool_name: row.get(6)?, + }) + })?; + + let mut messages = Vec::new(); + for row in rows { + messages.push(row?); + } + Ok(messages) + } +} + +pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String { + if channel_name == "cli" { + chat_id.to_string() + } else { + format!("{}:{}", channel_name, chat_id) + } +} + +fn default_session_db_path() -> Result<PathBuf, std::io::Error> { + let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); + Ok(home.join(".picobot").join("storage").join("sessions.db")) +} + +fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord> { + Ok(SessionRecord { + id: row.get(0)?, + title: row.get(1)?, + channel_name: row.get(2)?, + chat_id: row.get(3)?, + summary: row.get(4)?, + created_at: row.get(5)?, + updated_at: row.get(6)?, + last_active_at: row.get(7)?, + archived_at: row.get(8)?, + deleted_at: row.get(9)?, + message_count: row.get(10)?, + }) +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_millis() as i64 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_persistent_session_id_for_cli_and_channel() { + assert_eq!(persistent_session_id("cli", "abc"), "abc"); + assert_eq!(persistent_session_id("feishu", "abc"), "feishu:abc"); + } + + #[test] + fn test_session_store_roundtrip_and_lifecycle() { + let store = SessionStore::in_memory().unwrap(); + + let session = store.create_cli_session(Some("demo")).unwrap(); + assert_eq!(session.title, "demo"); + assert_eq!(session.channel_name, "cli"); + assert_eq!(session.chat_id, session.id); + assert_eq!(session.message_count, 0); + + let first = ChatMessage::user("hello"); + let second = ChatMessage::assistant("world"); + store.append_message(&session.id, &first).unwrap(); + store.append_message(&session.id, &second).unwrap(); + + let stored = store.get_session(&session.id).unwrap().unwrap(); + assert_eq!(stored.message_count, 2); + assert!(stored.archived_at.is_none()); + + let messages = store.load_messages(&session.id).unwrap(); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].role, "user"); + assert_eq!(messages[0].content, "hello"); + assert_eq!(messages[1].role, "assistant"); + assert_eq!(messages[1].content, "world"); + + store.rename_session(&session.id, "renamed").unwrap(); + let renamed = store.get_session(&session.id).unwrap().unwrap(); + assert_eq!(renamed.title, "renamed"); + + store.archive_session(&session.id).unwrap(); + let archived = store.get_session(&session.id).unwrap().unwrap(); + assert!(archived.archived_at.is_some()); + + let active_only = store.list_sessions("cli", false).unwrap(); + assert!(active_only.is_empty()); + + let including_archived = store.list_sessions("cli", true).unwrap(); + assert_eq!(including_archived.len(), 1); + + store.clear_messages(&session.id).unwrap(); + let cleared = store.load_messages(&session.id).unwrap(); + assert!(cleared.is_empty()); + let cleared_session = store.get_session(&session.id).unwrap().unwrap(); + assert_eq!(cleared_session.message_count, 0); + + store.delete_session(&session.id).unwrap(); + assert!(store.get_session(&session.id).unwrap().is_none()); + } + + #[test] + fn test_ensure_channel_session_is_stable() { + let store = SessionStore::in_memory().unwrap(); + + let first = store.ensure_channel_session("feishu", "chat-1").unwrap(); + let second = store.ensure_channel_session("feishu", "chat-1").unwrap(); + + assert_eq!(first.id, second.id); + assert_eq!(first.chat_id, "chat-1"); + assert_eq!(second.channel_name, "feishu"); + } +} \ No newline at end of file diff --git a/tests/test_integration.rs b/tests/test_integration.rs index 5a942e5..09f705e 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use PicoBot::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message}; -use PicoBot::config::{Config, LLMProviderConfig}; +use picobot::providers::{create_provider, ChatCompletionRequest, Message}; +use picobot::config::{Config, LLMProviderConfig}; fn load_config() -> Option<LLMProviderConfig> { dotenv::from_filename("tests/test.env").ok()?; @@ -24,15 +24,13 @@ fn load_config() -> Option<LLMProviderConfig> { max_tokens: Some(100), model_extra: HashMap::new(), max_tool_iterations: 20, + token_limit: 128_000, }) } fn create_request(content: &str) -> ChatCompletionRequest { ChatCompletionRequest { - messages: vec![Message { - role: "user".to_string(), - content: content.to_string(), - }], + messages: vec![Message::user(content)], temperature: Some(0.0), max_tokens: Some(100), tools: None, @@ -64,9 +62,9 @@ async fn test_openai_conversation() { let request = ChatCompletionRequest { messages: vec![ - Message { role: "user".to_string(), content: "My name is Alice".to_string() }, - Message { role: "assistant".to_string(), content: "Hello Alice!".to_string() }, - Message { role: "user".to_string(), content: "What is my name?".to_string() }, + Message::user("My name is Alice"), + Message::assistant("Hello Alice!"), + Message::user("What is my name?"), ], temperature: Some(0.0), max_tokens: Some(50), diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index 58f34f4..d73ce37 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -1,31 +1,26 @@ -use PicoBot::providers::{ChatCompletionRequest, Message}; +use picobot::providers::{ChatCompletionRequest, Message}; +use picobot::protocol::{SessionSummary, WsInbound, WsOutbound}; /// Test that message with special characters is properly escaped #[test] fn test_message_special_characters() { - let msg = Message { - role: "user".to_string(), - content: "Hello \"world\"\nNew line\tTab".to_string(), - }; + let msg = Message::user("Hello \"world\"\nNew line\tTab"); let json = serde_json::to_string(&msg).unwrap(); let deserialized: Message = serde_json::from_str(&json).unwrap(); - assert_eq!(deserialized.content, "Hello \"world\"\nNew line\tTab"); + assert_eq!(deserialized.role, "user"); + assert_eq!(deserialized.content.len(), 1); + let encoded = serde_json::to_string(&deserialized.content).unwrap(); + assert!(encoded.contains("Hello \\\"world\\\"\\nNew line\\tTab")); } /// Test that multi-line system prompt is preserved #[test] fn test_multiline_system_prompt() { let messages = vec![ - Message { - role: "system".to_string(), - content: "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate".to_string(), - }, - Message { - role: "user".to_string(), - content: "Hi".to_string(), - }, + Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"), + Message::user("Hi"), ]; let json = serde_json::to_string(&messages[0]).unwrap(); @@ -39,14 +34,8 @@ fn test_multiline_system_prompt() { fn test_chat_request_serialization() { let request = ChatCompletionRequest { messages: vec![ - Message { - role: "system".to_string(), - content: "You are helpful".to_string(), - }, - Message { - role: "user".to_string(), - content: "Hello".to_string(), - }, + Message::system("You are helpful"), + Message::user("Hello"), ], temperature: Some(0.7), max_tokens: Some(100), @@ -58,8 +47,73 @@ fn test_chat_request_serialization() { // Verify structure assert!(json.contains(r#""role":"system""#)); assert!(json.contains(r#""role":"user""#)); - assert!(json.contains(r#""content":"You are helpful""#)); - assert!(json.contains(r#""content":"Hello""#)); + assert!(json.contains("You are helpful")); + assert!(json.contains("Hello")); assert!(json.contains(r#""temperature":0.7"#)); assert!(json.contains(r#""max_tokens":100"#)); } + +#[test] +fn test_session_inbound_serialization() { + let msg = WsInbound::CreateSession { + title: Some("demo".to_string()), + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains(r#""type":"create_session""#)); + assert!(json.contains(r#""title":"demo""#)); + + let decoded: WsInbound = serde_json::from_str(&json).unwrap(); + match decoded { + WsInbound::CreateSession { title } => { + assert_eq!(title.as_deref(), Some("demo")); + } + other => panic!("unexpected decoded variant: {:?}", other), + } +} + +#[test] +fn test_session_list_outbound_serialization() { + let msg = WsOutbound::SessionList { + sessions: vec![SessionSummary { + session_id: "session-1".to_string(), + title: "demo".to_string(), + channel_name: "cli".to_string(), + chat_id: "session-1".to_string(), + message_count: 2, + last_active_at: 123, + archived_at: None, + }], + current_session_id: Some("session-1".to_string()), + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains(r#""type":"session_list""#)); + assert!(json.contains(r#""session_id":"session-1""#)); + assert!(json.contains(r#""message_count":2"#)); + + let decoded: WsOutbound = serde_json::from_str(&json).unwrap(); + match decoded { + WsOutbound::SessionList { + sessions, + current_session_id, + } => { + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].title, "demo"); + assert_eq!(current_session_id.as_deref(), Some("session-1")); + } + other => panic!("unexpected decoded variant: {:?}", other), + } +} + +#[test] +fn test_clear_history_with_session_id_serialization() { + let msg = WsInbound::ClearHistory { + chat_id: None, + session_id: Some("session-1".to_string()), + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains(r#""type":"clear_history""#)); + assert!(json.contains(r#""session_id":"session-1""#)); +} diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index 1421891..39ead1d 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use PicoBot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction}; -use PicoBot::config::LLMProviderConfig; +use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction}; +use picobot::config::LLMProviderConfig; fn load_openai_config() -> Option<LLMProviderConfig> { dotenv::from_filename("tests/test.env").ok()?; @@ -24,6 +24,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> { max_tokens: Some(100), model_extra: HashMap::new(), max_tool_iterations: 20, + token_limit: 128_000, }) } @@ -56,10 +57,7 @@ async fn test_openai_tool_call() { let provider = create_provider(config).expect("Failed to create provider"); let request = ChatCompletionRequest { - messages: vec![Message { - role: "user".to_string(), - content: "What is the weather in Tokyo?".to_string(), - }], + messages: vec![Message::user("What is the weather in Tokyo?")], temperature: Some(0.0), max_tokens: Some(200), tools: Some(vec![make_weather_tool()]), @@ -85,10 +83,7 @@ async fn test_openai_tool_call_with_manual_execution() { // First request with tool let request1 = ChatCompletionRequest { - messages: vec![Message { - role: "user".to_string(), - content: "What is the weather in Tokyo?".to_string(), - }], + messages: vec![Message::user("What is the weather in Tokyo?")], temperature: Some(0.0), max_tokens: Some(200), tools: Some(vec![make_weather_tool()]), @@ -102,14 +97,8 @@ async fn test_openai_tool_call_with_manual_execution() { // Second request with tool result let request2 = ChatCompletionRequest { messages: vec![ - Message { - role: "user".to_string(), - content: "What is the weather in Tokyo?".to_string(), - }, - Message { - role: "assistant".to_string(), - content: r#"I'll check the weather for you using the get_weather tool."#.to_string(), - }, + Message::user("What is the weather in Tokyo?"), + Message::assistant(r#"I'll check the weather for you using the get_weather tool."#), ], temperature: Some(0.0), max_tokens: Some(200), @@ -131,10 +120,7 @@ async fn test_openai_no_tool_when_not_provided() { let provider = create_provider(config).expect("Failed to create provider"); let request = ChatCompletionRequest { - messages: vec![Message { - role: "user".to_string(), - content: "Say hello in one word.".to_string(), - }], + messages: vec![Message::user("Say hello in one word.")], temperature: Some(0.0), max_tokens: Some(10), tools: None,