feat: enhance WebSocket session management and storage

- Added SessionSummary struct for session metadata.
- Updated ws_handler to create and manage CLI sessions more robustly.
- Implemented session creation, loading, renaming, archiving, and deletion via WebSocket messages.
- Introduced SessionStore for persistent session storage using SQLite.
- Enhanced error handling and logging for session operations.
- Updated protocol definitions for new session-related WebSocket messages.
- Refactored tests to cover new session functionalities and ensure proper serialization.
This commit is contained in:
ooodc 2026-04-18 13:09:14 +08:00
parent c971bc3639
commit 8bb32fa066
14 changed files with 1204 additions and 186 deletions

View File

@ -27,3 +27,4 @@ mime_guess = "2.0"
base64 = "0.22" base64 = "0.22"
tempfile = "3" tempfile = "3"
meval = "0.2" meval = "0.2"
rusqlite = { version = "0.32", features = ["bundled"] }

View File

@ -2,6 +2,23 @@ use crate::bus::ChatMessage;
use super::channel::CliChannel; use super::channel::CliChannel;
pub enum InputEvent {
Message(ChatMessage),
Command(InputCommand),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InputCommand {
Exit,
Clear,
New(Option<String>),
Sessions,
Use(String),
Rename(String),
Archive,
Delete,
}
pub struct InputHandler { pub struct InputHandler {
channel: CliChannel, channel: CliChannel,
} }
@ -13,7 +30,7 @@ impl InputHandler {
} }
} }
pub async fn read_input(&mut self, prompt: &str) -> Result<Option<ChatMessage>, InputError> { pub async fn read_input(&mut self, prompt: &str) -> Result<Option<InputEvent>, InputError> {
match self.channel.read_line(prompt).await { match self.channel.read_line(prompt).await {
Ok(Some(line)) => { Ok(Some(line)) => {
if line.trim().is_empty() { if line.trim().is_empty() {
@ -21,10 +38,10 @@ impl InputHandler {
} }
if let Some(cmd) = self.handle_special_commands(&line) { if let Some(cmd) = self.handle_special_commands(&line) {
return Ok(Some(cmd)); return Ok(Some(InputEvent::Command(cmd)));
} }
Ok(Some(ChatMessage::user(line))) Ok(Some(InputEvent::Message(ChatMessage::user(line))))
} }
Ok(None) => Ok(None), Ok(None) => Ok(None),
Err(e) => Err(InputError::IoError(e)), Err(e) => Err(InputError::IoError(e)),
@ -39,10 +56,21 @@ impl InputHandler {
self.channel.write_response(content).await.map_err(InputError::IoError) self.channel.write_response(content).await.map_err(InputError::IoError)
} }
fn handle_special_commands(&self, line: &str) -> Option<ChatMessage> { fn handle_special_commands(&self, line: &str) -> Option<InputCommand> {
match line.trim() { let trimmed = line.trim();
"/quit" | "/exit" | "/q" => Some(ChatMessage::system("__EXIT__")), let mut parts = trimmed.splitn(2, char::is_whitespace);
"/clear" => Some(ChatMessage::system("__CLEAR__")), let command = parts.next()?;
let arg = parts.next().map(str::trim).filter(|value| !value.is_empty());
match command {
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
"/clear" => Some(InputCommand::Clear),
"/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))),
"/sessions" => Some(InputCommand::Sessions),
"/use" => arg.map(|value| InputCommand::Use(value.to_string())),
"/rename" => arg.map(|value| InputCommand::Rename(value.to_string())),
"/archive" => Some(InputCommand::Archive),
"/delete" => Some(InputCommand::Delete),
_ => None, _ => None,
} }
} }
@ -68,3 +96,34 @@ impl std::fmt::Display for InputError {
} }
impl std::error::Error for InputError {} impl std::error::Error for InputError {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_special_command_parsing() {
let handler = InputHandler::new();
assert_eq!(handler.handle_special_commands("/quit"), Some(InputCommand::Exit));
assert_eq!(handler.handle_special_commands("/clear"), Some(InputCommand::Clear));
assert_eq!(handler.handle_special_commands("/new"), Some(InputCommand::New(None)));
assert_eq!(
handler.handle_special_commands("/new planning"),
Some(InputCommand::New(Some("planning".to_string())))
);
assert_eq!(handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions));
assert_eq!(
handler.handle_special_commands("/use abc123"),
Some(InputCommand::Use("abc123".to_string()))
);
assert_eq!(
handler.handle_special_commands("/rename project alpha"),
Some(InputCommand::Rename("project alpha".to_string()))
);
assert_eq!(handler.handle_special_commands("/archive"), Some(InputCommand::Archive));
assert_eq!(handler.handle_special_commands("/delete"), Some(InputCommand::Delete));
assert_eq!(handler.handle_special_commands("/unknown"), None);
assert_eq!(handler.handle_special_commands("/use"), None);
}
}

View File

@ -2,4 +2,4 @@ pub mod channel;
pub mod input; pub mod input;
pub use channel::CliChannel; pub use channel::CliChannel;
pub use input::InputHandler; pub use input::{InputCommand, InputEvent, InputHandler};

View File

@ -3,7 +3,38 @@ pub use crate::protocol::{WsInbound, WsOutbound, serialize_inbound, serialize_ou
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{connect_async, tungstenite::Message}; use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::cli::InputHandler; use crate::cli::{InputCommand, InputEvent, InputHandler};
fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_session_id: Option<&str>) -> String {
if sessions.is_empty() {
return "No sessions found.".to_string();
}
let mut lines = Vec::with_capacity(sessions.len() + 1);
lines.push("Sessions:".to_string());
for session in sessions {
let marker = if current_session_id == Some(session.session_id.as_str()) {
"*"
} else {
"-"
};
let archived = if session.archived_at.is_some() {
" [archived]"
} else {
""
};
lines.push(format!(
"{} {} | {} | {} messages{}",
marker,
session.session_id,
session.title,
session.message_count,
archived,
));
}
lines.join("\n")
}
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> { fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
serde_json::from_str(raw) serde_json::from_str(raw)
@ -16,7 +47,8 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
let (mut sender, mut receiver) = ws_stream.split(); let (mut sender, mut receiver) = ws_stream.split();
let mut input = InputHandler::new(); let mut input = InputHandler::new();
input.write_output("picobot CLI - Type /quit to exit, /clear to clear history\n").await?; let mut current_session_id: Option<String> = None;
input.write_output("picobot CLI - Commands: /new [title], /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
// Main loop: poll both stdin and WebSocket // Main loop: poll both stdin and WebSocket
loop { loop {
@ -35,10 +67,38 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
input.write_output(&format!("Error: {}", message)).await?; input.write_output(&format!("Error: {}", message)).await?;
} }
WsOutbound::SessionEstablished { session_id } => { WsOutbound::SessionEstablished { session_id } => {
current_session_id = Some(session_id.clone());
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id, "Session established"); tracing::debug!(session_id = %session_id, "Session established");
input.write_output(&format!("Session: {}\n", session_id)).await?; input.write_output(&format!("Session: {}\n", session_id)).await?;
} }
WsOutbound::SessionCreated { session_id, title } => {
current_session_id = Some(session_id.clone());
input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?;
}
WsOutbound::SessionList { sessions, current_session_id: listed_current } => {
let display = format_session_list(&sessions, listed_current.as_deref());
input.write_output(&format!("{}\n", display)).await?;
}
WsOutbound::SessionLoaded { session_id, title, message_count } => {
current_session_id = Some(session_id.clone());
input.write_output(&format!("Loaded session: {} ({}, {} messages)\n", session_id, title, message_count)).await?;
}
WsOutbound::SessionRenamed { session_id, title } => {
input.write_output(&format!("Renamed session: {} -> {}\n", session_id, title)).await?;
}
WsOutbound::SessionArchived { session_id } => {
input.write_output(&format!("Archived session: {}\n", session_id)).await?;
}
WsOutbound::SessionDeleted { session_id } => {
if current_session_id.as_deref() == Some(session_id.as_str()) {
current_session_id = None;
}
input.write_output(&format!("Deleted session: {}\n", session_id)).await?;
}
WsOutbound::HistoryCleared { session_id } => {
input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?;
}
_ => {} _ => {}
} }
} }
@ -54,32 +114,86 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
// Handle stdin input // Handle stdin input
result = input.read_input("> ") => { result = input.read_input("> ") => {
match result { match result {
Ok(Some(msg)) => { Ok(Some(event)) => {
match msg.content.as_str() { match event {
"__EXIT__" => { InputEvent::Command(InputCommand::Exit) => {
input.write_output("Goodbye!").await?; input.write_output("Goodbye!").await?;
break; break;
} }
"__CLEAR__" => { InputEvent::Command(InputCommand::Clear) => {
let inbound = WsInbound::ClearHistory { chat_id: None }; let inbound = WsInbound::ClearHistory {
chat_id: None,
session_id: current_session_id.clone(),
};
if let Ok(text) = serialize_inbound(&inbound) { if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await; let _ = sender.send(Message::Text(text.into())).await;
} }
continue; continue;
} }
_ => {} InputEvent::Command(InputCommand::New(title)) => {
} let inbound = WsInbound::CreateSession { title };
if let Ok(text) = serialize_inbound(&inbound) {
let inbound = WsInbound::UserInput { let _ = sender.send(Message::Text(text.into())).await;
content: msg.content, }
channel: None, continue;
chat_id: None, }
sender_id: None, InputEvent::Command(InputCommand::Sessions) => {
}; let inbound = WsInbound::ListSessions {
if let Ok(text) = serialize_inbound(&inbound) { include_archived: true,
if sender.send(Message::Text(text.into())).await.is_err() { };
tracing::error!("Failed to send message to gateway"); if let Ok(text) = serialize_inbound(&inbound) {
break; let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::Use(session_id)) => {
let inbound = WsInbound::LoadSession { session_id };
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::Rename(title)) => {
let inbound = WsInbound::RenameSession {
session_id: current_session_id.clone(),
title,
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::Archive) => {
let inbound = WsInbound::ArchiveSession {
session_id: current_session_id.clone(),
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::Delete) => {
let inbound = WsInbound::DeleteSession {
session_id: current_session_id.clone(),
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Message(msg) => {
let inbound = WsInbound::UserInput {
content: msg.content,
channel: None,
chat_id: current_session_id.clone(),
sender_id: None,
};
if let Ok(text) = serialize_inbound(&inbound) {
if sender.send(Message::Text(text.into())).await.is_err() {
tracing::error!("Failed to send message to gateway");
break;
}
}
} }
} }
} }

View File

@ -267,9 +267,54 @@ fn resolve_env_placeholders(content: &str) -> String {
mod tests { mod tests {
use super::*; use super::*;
fn write_test_config() -> tempfile::NamedTempFile {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
},
"volcengine": {
"type": "openai",
"base_url": "https://example.invalid/volc",
"api_key": "test-key-2",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus",
"temperature": 0.0
},
"doubao-seed-2-0-lite-260215": {
"model_id": "doubao-seed-2-0-lite-260215"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"gateway": {
"host": "0.0.0.0",
"port": 19876
}
}"#,
)
.unwrap();
file
}
#[test] #[test]
fn test_config_load() { fn test_config_load() {
let config = Config::load("config.json").unwrap(); let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
// Check providers // Check providers
assert!(config.providers.contains_key("volcengine")); assert!(config.providers.contains_key("volcengine"));
@ -285,7 +330,8 @@ mod tests {
#[test] #[test]
fn test_get_provider_config() { fn test_get_provider_config() {
let config = Config::load("config.json").unwrap(); let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap(); let provider_config = config.get_provider_config("default").unwrap();
assert_eq!(provider_config.provider_type, "openai"); assert_eq!(provider_config.provider_type, "openai");
@ -296,7 +342,8 @@ mod tests {
#[test] #[test]
fn test_default_gateway_config() { fn test_default_gateway_config() {
let config = Config::load("config.json").unwrap(); let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.gateway.host, "0.0.0.0"); assert_eq!(config.gateway.host, "0.0.0.0");
assert_eq!(config.gateway.port, 19876); assert_eq!(config.gateway.port, 19876);
} }

View File

@ -29,7 +29,7 @@ impl GatewayState {
// Session TTL from config (default 4 hours) // Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
let session_manager = SessionManager::new(session_ttl_hours, provider_config); let session_manager = SessionManager::new(session_ttl_hours, provider_config)?;
let channel_manager = ChannelManager::new(); let channel_manager = ChannelManager::new();
let bus = channel_manager.bus(); let bus = channel_manager.bus();

View File

@ -7,6 +7,7 @@ use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::tools::{ use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, ToolRegistry, WebFetchTool, HttpRequestTool, ToolRegistry, WebFetchTool,
@ -23,6 +24,7 @@ pub struct Session {
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
compressor: ContextCompressor, compressor: ContextCompressor,
store: Arc<SessionStore>,
} }
impl Session { impl Session {
@ -31,6 +33,7 @@ impl Session {
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>, user_tx: mpsc::Sender<WsOutbound>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
store: Arc<SessionStore>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
Ok(Self { Ok(Self {
id: Uuid::new_v4(), id: Uuid::new_v4(),
@ -40,9 +43,33 @@ impl Session {
provider_config: provider_config.clone(), provider_config: provider_config.clone(),
tools, tools,
compressor: ContextCompressor::new(provider_config.token_limit), compressor: ContextCompressor::new(provider_config.token_limit),
store,
}) })
} }
pub fn persistent_session_id(&self, chat_id: &str) -> String {
persistent_session_id(&self.channel_name, chat_id)
}
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
self.store
.ensure_channel_session(&self.channel_name, chat_id)
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
}
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
if self.chat_histories.contains_key(chat_id) {
return Ok(());
}
let history = self
.store
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
Ok(())
}
/// 获取或创建指定 chat_id 的会话历史 /// 获取或创建指定 chat_id 的会话历史
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> { pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
self.chat_histories self.chat_histories
@ -55,41 +82,62 @@ impl Session {
self.chat_histories.get(chat_id) self.chat_histories.get(chat_id)
} }
/// 添加用户消息到指定 chat_id 的历史 /// 使用完整消息追加到历史
pub fn add_user_message(&mut self, chat_id: &str, content: &str) { pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
let history = self.get_or_create_history(chat_id); let history = self.get_or_create_history(chat_id);
history.push(ChatMessage::user(content)); history.push(message);
} }
/// 添加带媒体的用户消息到指定 chat_id 的历史 pub fn remove_history(&mut self, chat_id: &str) {
pub fn add_user_message_with_media(&mut self, chat_id: &str, content: &str, media_refs: Vec<String>) { self.chat_histories.remove(chat_id);
let history = self.get_or_create_history(chat_id);
history.push(ChatMessage::user_with_media(content, media_refs));
} }
/// 添加助手响应到指定 chat_id 的历史 pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
history.push(message);
}
}
/// 清除指定 chat_id 的历史
pub fn clear_chat_history(&mut self, chat_id: &str) {
if let Some(history) = self.chat_histories.get_mut(chat_id) { if let Some(history) = self.chat_histories.get_mut(chat_id) {
let len = history.len(); let len = history.len();
history.clear(); history.clear();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared"); tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
} }
self.store
.clear_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
}
/// 将消息写入内存与持久化层
pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> {
let session_id = self.persistent_session_id(chat_id);
self.store
.append_message(&session_id, &message)
.map_err(|err| AgentError::Other(format!("append message persistence error: {}", err)))?;
self.add_message(chat_id, message);
Ok(())
}
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
if media_refs.is_empty() {
ChatMessage::user(content)
} else {
ChatMessage::user_with_media(content, media_refs)
}
} }
/// 清除所有历史 /// 清除所有历史
pub fn clear_all_history(&mut self) { pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
let total: usize = self.chat_histories.values().map(|h| h.len()).sum(); let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
self.chat_histories.clear(); self.chat_histories.clear();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(previous_total = total, "All chat histories cleared"); tracing::debug!(previous_total = total, "All chat histories cleared");
for chat_id in chat_ids {
self.store
.clear_messages(&self.persistent_session_id(&chat_id))
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?;
}
Ok(())
} }
pub async fn send(&self, msg: WsOutbound) { pub async fn send(&self, msg: WsOutbound) {
@ -118,6 +166,7 @@ pub struct SessionManager {
inner: Arc<Mutex<SessionManagerInner>>, inner: Arc<Mutex<SessionManagerInner>>,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
store: Arc<SessionStore>,
} }
struct SessionManagerInner { struct SessionManagerInner {
@ -144,8 +193,13 @@ fn default_tools() -> ToolRegistry {
} }
impl SessionManager { impl SessionManager {
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Self { pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
Self { let store = Arc::new(
SessionStore::new()
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
);
Ok(Self {
inner: Arc::new(Mutex::new(SessionManagerInner { inner: Arc::new(Mutex::new(SessionManagerInner {
sessions: HashMap::new(), sessions: HashMap::new(),
session_timestamps: HashMap::new(), session_timestamps: HashMap::new(),
@ -153,13 +207,66 @@ impl SessionManager {
})), })),
provider_config, provider_config,
tools: Arc::new(default_tools()), tools: Arc::new(default_tools()),
} store,
})
} }
pub fn tools(&self) -> Arc<ToolRegistry> { pub fn tools(&self) -> Arc<ToolRegistry> {
self.tools.clone() self.tools.clone()
} }
pub fn store(&self) -> Arc<SessionStore> {
self.store.clone()
}
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
self.store
.create_cli_session(title)
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
pub fn get_session_record(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
self.store
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
}
pub fn list_cli_sessions(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> {
self.store
.list_sessions("cli", include_archived)
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
}
pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
self.store
.rename_session(session_id, title)
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
}
pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.archive_session(session_id)
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
}
pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.delete_session(session_id)
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
}
pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.clear_messages(session_id)
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
}
pub fn load_session_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, AgentError> {
self.store
.load_messages(session_id)
.map_err(|err| AgentError::Other(format!("load messages error: {}", err)))
}
/// 确保 session 存在且未超时,超时则重建 /// 确保 session 存在且未超时,超时则重建
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await; let mut inner = self.inner.lock().await;
@ -189,6 +296,7 @@ impl SessionManager {
self.provider_config.clone(), self.provider_config.clone(),
user_tx, user_tx,
self.tools.clone(), self.tools.clone(),
self.store.clone(),
) )
.await?; .await?;
let arc = Arc::new(Mutex::new(session)); let arc = Arc::new(Mutex::new(session));
@ -251,15 +359,17 @@ impl SessionManager {
let response = { let response = {
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
session_guard.ensure_persistent_session(chat_id)?;
session_guard.ensure_chat_loaded(chat_id)?;
// 添加用户消息到历史 // 添加用户消息到历史
if media.is_empty() { let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
session_guard.add_user_message(chat_id, content); #[cfg(debug_assertions)]
} else { if !media_refs.is_empty() {
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
#[cfg(debug_assertions)]
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media"); tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
session_guard.add_user_message_with_media(chat_id, content, media_refs);
} }
let user_message = session_guard.create_user_message(content, media_refs);
session_guard.append_persisted_message(chat_id, user_message)?;
// 获取完整历史 // 获取完整历史
let history = session_guard.get_or_create_history(chat_id).clone(); let history = session_guard.get_or_create_history(chat_id).clone();
@ -274,7 +384,7 @@ impl SessionManager {
let response = agent.process(history).await?; let response = agent.process(history).await?;
// 添加助手响应到历史 // 添加助手响应到历史
session_guard.add_assistant_message(chat_id, response.clone()); session_guard.append_persisted_message(chat_id, response.clone())?;
response response
}; };
@ -294,7 +404,7 @@ impl SessionManager {
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> { pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
if let Some(session) = self.get(channel_name).await { if let Some(session) = self.get(channel_name).await {
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
session_guard.clear_all_history(); session_guard.clear_all_history()?;
} }
Ok(()) Ok(())
} }

View File

@ -4,7 +4,7 @@ use axum::extract::State;
use axum::response::Response; use axum::response::Response;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound}; use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::Session}; use super::{GatewayState, session::Session};
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response { pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
@ -24,8 +24,15 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
} }
}; };
// CLI 使用独立的 sessionchannel_name = "cli-{uuid}" let initial_record = match state.session_manager.create_cli_session(None) {
let channel_name = format!("cli-{}", uuid::Uuid::new_v4()); Ok(record) => record,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session");
return;
}
};
let channel_name = "cli".to_string();
// 创建 CLI session // 创建 CLI session
let session = match Session::new( let session = match Session::new(
@ -33,6 +40,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
provider_config, provider_config,
sender, sender,
state.session_manager.tools(), state.session_manager.tools(),
state.session_manager.store(),
) )
.await .await
{ {
@ -43,21 +51,27 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
} }
}; };
let session_id = session.lock().await.id; if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
tracing::info!(session_id = %session_id, "CLI session established"); tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
return;
}
let runtime_session_id = session.lock().await.id;
let mut current_session_id = initial_record.id.clone();
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
let _ = session let _ = session
.lock() .lock()
.await .await
.send(WsOutbound::SessionEstablished { .send(WsOutbound::SessionEstablished {
session_id: session_id.to_string(), session_id: current_session_id.clone(),
}) })
.await; .await;
let (mut ws_sender, mut ws_receiver) = ws.split(); let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver; let mut receiver = receiver;
let session_id_for_sender = session_id; let session_id_for_sender = runtime_session_id;
tokio::spawn(async move { tokio::spawn(async move {
while let Some(msg) = receiver.recv().await { while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) { if let Ok(text) = serialize_outbound(&msg) {
@ -76,7 +90,17 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let text = text.to_string(); let text = text.to_string();
match parse_inbound(&text) { match parse_inbound(&text) {
Ok(inbound) => { Ok(inbound) => {
handle_inbound(&session, inbound).await; if let Err(e) = handle_inbound(&state, &session, &mut current_session_id, inbound).await {
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = session
.lock()
.await
.send(WsOutbound::Error {
code: "SESSION_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
} }
Err(e) => { Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message"); tracing::warn!(error = %e, "Failed to parse inbound message");
@ -93,92 +117,203 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
} }
Ok(WsMessage::Close(_)) | Err(_) => { Ok(WsMessage::Close(_)) | Err(_) => {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id, "WebSocket closed"); tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
break; break;
} }
_ => {} _ => {}
} }
} }
tracing::info!(session_id = %session_id, "CLI session ended"); tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
} }
async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) { fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
let inbound_clone = inbound.clone(); SessionSummary {
session_id: record.id,
title: record.title,
channel_name: record.channel_name,
chat_id: record.chat_id,
message_count: record.message_count,
last_active_at: record.last_active_at,
archived_at: record.archived_at,
}
}
// 提取 content 和 chat_idCLI 使用 session id 作为 chat_id async fn handle_inbound(
let (content, chat_id) = match inbound_clone { state: &Arc<GatewayState>,
WsInbound::UserInput { session: &Arc<Mutex<Session>>,
content, current_session_id: &mut String,
channel: _, inbound: WsInbound,
chat_id, ) -> Result<(), crate::agent::AgentError> {
sender_id: _, match inbound {
} => { WsInbound::UserInput { content, chat_id, .. } => {
// CLI 使用 session 中的 channel_name 作为标识 let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
// chat_id 使用传入的或使用默认 let mut session_guard = session.lock().await;
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
(content, chat_id) session_guard.ensure_persistent_session(&chat_id)?;
session_guard.ensure_chat_loaded(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new());
session_guard.append_persisted_message(&chat_id, user_message)?;
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
let history = match session_guard
.compressor()
.compress_if_needed(raw_history, session_guard.provider_config())
.await
{
Ok(history) => history,
Err(error) => {
tracing::warn!(chat_id = %chat_id, error = %error, "Compression failed, using original history");
session_guard.get_or_create_history(&chat_id).clone()
}
};
let agent = session_guard.create_agent()?;
match agent.process(history).await {
Ok(response) => {
session_guard.append_persisted_message(&chat_id, response.clone())?;
let _ = session_guard
.send(WsOutbound::AssistantResponse {
id: response.id,
content: response.content,
role: response.role,
})
.await;
}
Err(error) => {
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
let _ = session_guard
.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: error.to_string(),
})
.await;
}
}
Ok(())
} }
_ => return, WsInbound::ClearHistory { session_id, chat_id } => {
}; let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone());
state.session_manager.clear_session_messages(&target)?;
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
session_guard.remove_history(&target);
// 添加用户消息到历史
session_guard.add_user_message(&chat_id, &content);
// 获取完整历史
let history = session_guard.get_or_create_history(&chat_id).clone();
// 压缩历史(如果需要)
let history = match session_guard.compressor()
.compress_if_needed(history, session_guard.provider_config())
.await
{
Ok(h) => h,
Err(e) => {
tracing::warn!(chat_id = %chat_id, error = %e, "Compression failed, using original history");
session_guard.get_or_create_history(&chat_id).clone()
}
};
// 创建 agent 并处理
let agent = match session_guard.create_agent() {
Ok(a) => a,
Err(e) => {
tracing::error!(chat_id = %chat_id, error = %e, "Failed to create agent");
let _ = session_guard let _ = session_guard
.send(WsOutbound::Error { .send(WsOutbound::HistoryCleared {
code: "AGENT_ERROR".to_string(), session_id: target,
message: e.to_string(),
}) })
.await; .await;
return; Ok(())
} }
}; WsInbound::CreateSession { title } => {
let record = state.session_manager.create_cli_session(title.as_deref())?;
*current_session_id = record.id.clone();
match agent.process(history).await { let mut session_guard = session.lock().await;
Ok(response) => { session_guard.ensure_chat_loaded(&record.id)?;
#[cfg(debug_assertions)]
tracing::debug!(chat_id = %chat_id, "Agent response sent");
// 添加助手响应到历史
session_guard.add_assistant_message(&chat_id, response.clone());
let _ = session_guard let _ = session_guard
.send(WsOutbound::AssistantResponse { .send(WsOutbound::SessionCreated {
id: response.id, session_id: record.id,
content: response.content, title: record.title,
role: response.role,
}) })
.await; .await;
Ok(())
} }
Err(e) => { WsInbound::ListSessions { include_archived } => {
tracing::error!(chat_id = %chat_id, error = %e, "Agent process error"); let records = state.session_manager.list_cli_sessions(include_archived)?;
let summaries = records.into_iter().map(to_session_summary).collect();
let session_guard = session.lock().await;
let _ = session_guard let _ = session_guard
.send(WsOutbound::Error { .send(WsOutbound::SessionList {
code: "LLM_ERROR".to_string(), sessions: summaries,
message: e.to_string(), current_session_id: Some(current_session_id.clone()),
}) })
.await; .await;
Ok(())
}
WsInbound::LoadSession { session_id } => {
let Some(record) = state.session_manager.get_session_record(&session_id)? else {
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::Error {
code: "SESSION_NOT_FOUND".to_string(),
message: format!("Session not found: {}", session_id),
})
.await;
return Ok(());
};
*current_session_id = record.id.clone();
let mut session_guard = session.lock().await;
session_guard.ensure_chat_loaded(&record.id)?;
let _ = session_guard
.send(WsOutbound::SessionLoaded {
session_id: record.id,
title: record.title,
message_count: record.message_count,
})
.await;
Ok(())
}
WsInbound::RenameSession { session_id, title } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.rename_session(&target, &title)?;
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionRenamed {
session_id: target,
title,
})
.await;
Ok(())
}
WsInbound::ArchiveSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.archive_session(&target)?;
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionArchived { session_id: target })
.await;
Ok(())
}
WsInbound::DeleteSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.delete_session(&target)?;
let replacement = if target == *current_session_id {
Some(state.session_manager.create_cli_session(None)?)
} else {
None
};
let mut session_guard = session.lock().await;
session_guard.remove_history(&target);
let _ = session_guard
.send(WsOutbound::SessionDeleted {
session_id: target.clone(),
})
.await;
if let Some(record) = replacement {
*current_session_id = record.id.clone();
session_guard.ensure_chat_loaded(&record.id)?;
let _ = session_guard
.send(WsOutbound::SessionCreated {
session_id: record.id,
title: record.title,
})
.await;
}
Ok(())
}
WsInbound::Ping => {
let session_guard = session.lock().await;
let _ = session_guard.send(WsOutbound::Pong).await;
Ok(())
} }
} }
} }

View File

@ -9,4 +9,5 @@ pub mod protocol;
pub mod channels; pub mod channels;
pub mod logging; pub mod logging;
pub mod observability; pub mod observability;
pub mod storage;
pub mod tools; pub mod tools;

View File

@ -1,5 +1,17 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSummary {
pub session_id: String,
pub title: String,
pub channel_name: String,
pub chat_id: String,
pub message_count: i64,
pub last_active_at: i64,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub archived_at: Option<i64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum WsInbound { pub enum WsInbound {
@ -17,6 +29,38 @@ pub enum WsInbound {
ClearHistory { ClearHistory {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
chat_id: Option<String>, chat_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
},
#[serde(rename = "create_session")]
CreateSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
title: Option<String>,
},
#[serde(rename = "list_sessions")]
ListSessions {
#[serde(default)]
include_archived: bool,
},
#[serde(rename = "load_session")]
LoadSession {
session_id: String,
},
#[serde(rename = "rename_session")]
RenameSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
title: String,
},
#[serde(rename = "archive_session")]
ArchiveSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
},
#[serde(rename = "delete_session")]
DeleteSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
}, },
#[serde(rename = "ping")] #[serde(rename = "ping")]
Ping, Ping,
@ -31,6 +75,28 @@ pub enum WsOutbound {
Error { code: String, message: String }, Error { code: String, message: String },
#[serde(rename = "session_established")] #[serde(rename = "session_established")]
SessionEstablished { session_id: String }, SessionEstablished { session_id: String },
#[serde(rename = "session_created")]
SessionCreated { session_id: String, title: String },
#[serde(rename = "session_list")]
SessionList {
sessions: Vec<SessionSummary>,
#[serde(default, skip_serializing_if = "Option::is_none")]
current_session_id: Option<String>,
},
#[serde(rename = "session_loaded")]
SessionLoaded {
session_id: String,
title: String,
message_count: i64,
},
#[serde(rename = "session_renamed")]
SessionRenamed { session_id: String, title: String },
#[serde(rename = "session_archived")]
SessionArchived { session_id: String },
#[serde(rename = "session_deleted")]
SessionDeleted { session_id: String },
#[serde(rename = "history_cleared")]
HistoryCleared { session_id: String },
#[serde(rename = "pong")] #[serde(rename = "pong")]
Pong, Pong,
} }

447
src/storage/mod.rs Normal file
View File

@ -0,0 +1,447 @@
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use rusqlite::{Connection, OptionalExtension, params};
use serde::{Deserialize, Serialize};
use crate::bus::ChatMessage;
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRecord {
pub id: String,
pub title: String,
pub channel_name: String,
pub chat_id: String,
pub summary: Option<String>,
pub created_at: i64,
pub updated_at: i64,
pub last_active_at: i64,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>,
pub message_count: i64,
}
#[derive(Clone)]
pub struct SessionStore {
conn: Arc<Mutex<Connection>>,
}
impl SessionStore {
pub fn new() -> Result<Self, StorageError> {
let db_path = default_session_db_path()?;
Self::open_at_path(&db_path)
}
fn open_at_path(path: &Path) -> Result<Self, StorageError> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let conn = Connection::open(path)?;
Self::from_connection(conn)
}
fn from_connection(conn: Connection) -> Result<Self, StorageError> {
conn.execute_batch(
"
PRAGMA journal_mode = WAL;
PRAGMA foreign_keys = ON;
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
channel_name TEXT NOT NULL,
chat_id TEXT NOT NULL,
summary TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
last_active_at INTEGER NOT NULL,
archived_at INTEGER,
deleted_at INTEGER,
message_count INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
ON sessions(channel_name, archived_at, last_active_at DESC);
CREATE INDEX IF NOT EXISTS idx_sessions_updated_at
ON sessions(updated_at DESC);
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
seq INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
media_refs_json TEXT NOT NULL,
tool_call_id TEXT,
tool_name TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
UNIQUE(session_id, seq)
);
CREATE INDEX IF NOT EXISTS idx_messages_session_seq
ON messages(session_id, seq);
CREATE INDEX IF NOT EXISTS idx_messages_session_created
ON messages(session_id, created_at);
",
)?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
})
}
#[cfg(test)]
fn in_memory() -> Result<Self, StorageError> {
Self::from_connection(Connection::open_in_memory()?)
}
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> {
let now = current_timestamp();
let id = uuid::Uuid::new_v4().to_string();
let title = title
.map(str::trim)
.filter(|value| !value.is_empty())
.map(ToOwned::to_owned)
.unwrap_or_else(|| format!("CLI Session {}", &id[..8]));
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
"
INSERT INTO sessions (
id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0)
",
params![id, title, id, now],
)?;
drop(conn);
self.get_session(&id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
}
pub fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError> {
let session_id = persistent_session_id(channel_name, chat_id);
if let Some(record) = self.get_session(&session_id)? {
return Ok(record);
}
let now = current_timestamp();
let title = format!("{}:{}", channel_name, chat_id);
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
"
INSERT INTO sessions (
id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0)
",
params![session_id, title, channel_name, chat_id, now],
)?;
drop(conn);
self.get_session(&session_id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
}
pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let mut stmt = conn.prepare(
"
SELECT id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at,
archived_at, deleted_at, message_count
FROM sessions
WHERE id = ?1 AND deleted_at IS NULL
",
)?;
stmt.query_row(params![session_id], map_session_record)
.optional()
.map_err(StorageError::from)
}
pub fn list_sessions(
&self,
channel_name: &str,
include_archived: bool,
) -> Result<Vec<SessionRecord>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let mut sql = String::from(
"
SELECT id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at,
archived_at, deleted_at, message_count
FROM sessions
WHERE channel_name = ?1
AND deleted_at IS NULL
",
);
if !include_archived {
sql.push_str(" AND archived_at IS NULL");
}
sql.push_str(" ORDER BY last_active_at DESC, created_at DESC");
let mut stmt = conn.prepare(&sql)?;
let rows = stmt.query_map(params![channel_name], map_session_record)?;
let mut sessions = Vec::new();
for row in rows {
sessions.push(row?);
}
Ok(sessions)
}
pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
"UPDATE sessions SET title = ?2, updated_at = ?3 WHERE id = ?1 AND deleted_at IS NULL",
params![session_id, title.trim(), now],
)?;
Ok(())
}
pub fn archive_session(&self, session_id: &str) -> Result<(), StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
"UPDATE sessions SET archived_at = ?2, updated_at = ?2 WHERE id = ?1 AND deleted_at IS NULL",
params![session_id, now],
)?;
Ok(())
}
pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
Ok(())
}
pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
conn.execute(
"
UPDATE sessions
SET message_count = 0, updated_at = ?2, last_active_at = ?2
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, now],
)?;
Ok(())
}
pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let tx = conn.unchecked_transaction()?;
let seq: i64 = tx.query_row(
"SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE session_id = ?1",
params![session_id],
|row| row.get(0),
)?;
let media_refs_json = serde_json::to_string(&message.media_refs)?;
tx.execute(
"
INSERT INTO messages (
id, session_id, seq, role, content,
media_refs_json, tool_call_id, tool_name, created_at
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
",
params![
message.id,
session_id,
seq,
message.role,
message.content,
media_refs_json,
message.tool_call_id,
message.tool_name,
message.timestamp,
],
)?;
let now = current_timestamp();
tx.execute(
"
UPDATE sessions
SET message_count = message_count + 1,
updated_at = ?2,
last_active_at = ?2,
archived_at = NULL
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, now],
)?;
tx.commit()?;
Ok(())
}
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let mut stmt = conn.prepare(
"
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name
FROM messages
WHERE session_id = ?1
ORDER BY seq ASC
",
)?;
let rows = stmt.query_map(params![session_id], |row| {
let media_refs_json: String = row.get(3)?;
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(),
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
Ok(ChatMessage {
id: row.get(0)?,
role: row.get(1)?,
content: row.get(2)?,
media_refs,
timestamp: row.get(4)?,
tool_call_id: row.get(5)?,
tool_name: row.get(6)?,
})
})?;
let mut messages = Vec::new();
for row in rows {
messages.push(row?);
}
Ok(messages)
}
}
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
if channel_name == "cli" {
chat_id.to_string()
} else {
format!("{}:{}", channel_name, chat_id)
}
}
fn default_session_db_path() -> Result<PathBuf, std::io::Error> {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
Ok(home.join(".picobot").join("storage").join("sessions.db"))
}
fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord> {
Ok(SessionRecord {
id: row.get(0)?,
title: row.get(1)?,
channel_name: row.get(2)?,
chat_id: row.get(3)?,
summary: row.get(4)?,
created_at: row.get(5)?,
updated_at: row.get(6)?,
last_active_at: row.get(7)?,
archived_at: row.get(8)?,
deleted_at: row.get(9)?,
message_count: row.get(10)?,
})
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("system clock before unix epoch")
.as_millis() as i64
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_persistent_session_id_for_cli_and_channel() {
assert_eq!(persistent_session_id("cli", "abc"), "abc");
assert_eq!(persistent_session_id("feishu", "abc"), "feishu:abc");
}
#[test]
fn test_session_store_roundtrip_and_lifecycle() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("demo")).unwrap();
assert_eq!(session.title, "demo");
assert_eq!(session.channel_name, "cli");
assert_eq!(session.chat_id, session.id);
assert_eq!(session.message_count, 0);
let first = ChatMessage::user("hello");
let second = ChatMessage::assistant("world");
store.append_message(&session.id, &first).unwrap();
store.append_message(&session.id, &second).unwrap();
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.message_count, 2);
assert!(stored.archived_at.is_none());
let messages = store.load_messages(&session.id).unwrap();
assert_eq!(messages.len(), 2);
assert_eq!(messages[0].role, "user");
assert_eq!(messages[0].content, "hello");
assert_eq!(messages[1].role, "assistant");
assert_eq!(messages[1].content, "world");
store.rename_session(&session.id, "renamed").unwrap();
let renamed = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(renamed.title, "renamed");
store.archive_session(&session.id).unwrap();
let archived = store.get_session(&session.id).unwrap().unwrap();
assert!(archived.archived_at.is_some());
let active_only = store.list_sessions("cli", false).unwrap();
assert!(active_only.is_empty());
let including_archived = store.list_sessions("cli", true).unwrap();
assert_eq!(including_archived.len(), 1);
store.clear_messages(&session.id).unwrap();
let cleared = store.load_messages(&session.id).unwrap();
assert!(cleared.is_empty());
let cleared_session = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(cleared_session.message_count, 0);
store.delete_session(&session.id).unwrap();
assert!(store.get_session(&session.id).unwrap().is_none());
}
#[test]
fn test_ensure_channel_session_is_stable() {
let store = SessionStore::in_memory().unwrap();
let first = store.ensure_channel_session("feishu", "chat-1").unwrap();
let second = store.ensure_channel_session("feishu", "chat-1").unwrap();
assert_eq!(first.id, second.id);
assert_eq!(first.chat_id, "chat-1");
assert_eq!(second.channel_name, "feishu");
}
}

View File

@ -1,6 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use PicoBot::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message}; use picobot::providers::{create_provider, ChatCompletionRequest, Message};
use PicoBot::config::{Config, LLMProviderConfig}; use picobot::config::{Config, LLMProviderConfig};
fn load_config() -> Option<LLMProviderConfig> { fn load_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?; dotenv::from_filename("tests/test.env").ok()?;
@ -24,15 +24,13 @@ fn load_config() -> Option<LLMProviderConfig> {
max_tokens: Some(100), max_tokens: Some(100),
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 20, max_tool_iterations: 20,
token_limit: 128_000,
}) })
} }
fn create_request(content: &str) -> ChatCompletionRequest { fn create_request(content: &str) -> ChatCompletionRequest {
ChatCompletionRequest { ChatCompletionRequest {
messages: vec![Message { messages: vec![Message::user(content)],
role: "user".to_string(),
content: content.to_string(),
}],
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(100), max_tokens: Some(100),
tools: None, tools: None,
@ -64,9 +62,9 @@ async fn test_openai_conversation() {
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![ messages: vec![
Message { role: "user".to_string(), content: "My name is Alice".to_string() }, Message::user("My name is Alice"),
Message { role: "assistant".to_string(), content: "Hello Alice!".to_string() }, Message::assistant("Hello Alice!"),
Message { role: "user".to_string(), content: "What is my name?".to_string() }, Message::user("What is my name?"),
], ],
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(50), max_tokens: Some(50),

View File

@ -1,31 +1,26 @@
use PicoBot::providers::{ChatCompletionRequest, Message}; use picobot::providers::{ChatCompletionRequest, Message};
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
/// Test that message with special characters is properly escaped /// Test that message with special characters is properly escaped
#[test] #[test]
fn test_message_special_characters() { fn test_message_special_characters() {
let msg = Message { let msg = Message::user("Hello \"world\"\nNew line\tTab");
role: "user".to_string(),
content: "Hello \"world\"\nNew line\tTab".to_string(),
};
let json = serde_json::to_string(&msg).unwrap(); let json = serde_json::to_string(&msg).unwrap();
let deserialized: Message = serde_json::from_str(&json).unwrap(); let deserialized: Message = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.content, "Hello \"world\"\nNew line\tTab"); assert_eq!(deserialized.role, "user");
assert_eq!(deserialized.content.len(), 1);
let encoded = serde_json::to_string(&deserialized.content).unwrap();
assert!(encoded.contains("Hello \\\"world\\\"\\nNew line\\tTab"));
} }
/// Test that multi-line system prompt is preserved /// Test that multi-line system prompt is preserved
#[test] #[test]
fn test_multiline_system_prompt() { fn test_multiline_system_prompt() {
let messages = vec![ let messages = vec![
Message { Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
role: "system".to_string(), Message::user("Hi"),
content: "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate".to_string(),
},
Message {
role: "user".to_string(),
content: "Hi".to_string(),
},
]; ];
let json = serde_json::to_string(&messages[0]).unwrap(); let json = serde_json::to_string(&messages[0]).unwrap();
@ -39,14 +34,8 @@ fn test_multiline_system_prompt() {
fn test_chat_request_serialization() { fn test_chat_request_serialization() {
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![ messages: vec![
Message { Message::system("You are helpful"),
role: "system".to_string(), Message::user("Hello"),
content: "You are helpful".to_string(),
},
Message {
role: "user".to_string(),
content: "Hello".to_string(),
},
], ],
temperature: Some(0.7), temperature: Some(0.7),
max_tokens: Some(100), max_tokens: Some(100),
@ -58,8 +47,73 @@ fn test_chat_request_serialization() {
// Verify structure // Verify structure
assert!(json.contains(r#""role":"system""#)); assert!(json.contains(r#""role":"system""#));
assert!(json.contains(r#""role":"user""#)); assert!(json.contains(r#""role":"user""#));
assert!(json.contains(r#""content":"You are helpful""#)); assert!(json.contains("You are helpful"));
assert!(json.contains(r#""content":"Hello""#)); assert!(json.contains("Hello"));
assert!(json.contains(r#""temperature":0.7"#)); assert!(json.contains(r#""temperature":0.7"#));
assert!(json.contains(r#""max_tokens":100"#)); assert!(json.contains(r#""max_tokens":100"#));
} }
#[test]
fn test_session_inbound_serialization() {
let msg = WsInbound::CreateSession {
title: Some("demo".to_string()),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"create_session""#));
assert!(json.contains(r#""title":"demo""#));
let decoded: WsInbound = serde_json::from_str(&json).unwrap();
match decoded {
WsInbound::CreateSession { title } => {
assert_eq!(title.as_deref(), Some("demo"));
}
other => panic!("unexpected decoded variant: {:?}", other),
}
}
#[test]
fn test_session_list_outbound_serialization() {
let msg = WsOutbound::SessionList {
sessions: vec![SessionSummary {
session_id: "session-1".to_string(),
title: "demo".to_string(),
channel_name: "cli".to_string(),
chat_id: "session-1".to_string(),
message_count: 2,
last_active_at: 123,
archived_at: None,
}],
current_session_id: Some("session-1".to_string()),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"session_list""#));
assert!(json.contains(r#""session_id":"session-1""#));
assert!(json.contains(r#""message_count":2"#));
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
match decoded {
WsOutbound::SessionList {
sessions,
current_session_id,
} => {
assert_eq!(sessions.len(), 1);
assert_eq!(sessions[0].title, "demo");
assert_eq!(current_session_id.as_deref(), Some("session-1"));
}
other => panic!("unexpected decoded variant: {:?}", other),
}
}
#[test]
fn test_clear_history_with_session_id_serialization() {
let msg = WsInbound::ClearHistory {
chat_id: None,
session_id: Some("session-1".to_string()),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"clear_history""#));
assert!(json.contains(r#""session_id":"session-1""#));
}

View File

@ -1,6 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use PicoBot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction}; use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
use PicoBot::config::LLMProviderConfig; use picobot::config::LLMProviderConfig;
fn load_openai_config() -> Option<LLMProviderConfig> { fn load_openai_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?; dotenv::from_filename("tests/test.env").ok()?;
@ -24,6 +24,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
max_tokens: Some(100), max_tokens: Some(100),
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 20, max_tool_iterations: 20,
token_limit: 128_000,
}) })
} }
@ -56,10 +57,7 @@ async fn test_openai_tool_call() {
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![Message { messages: vec![Message::user("What is the weather in Tokyo?")],
role: "user".to_string(),
content: "What is the weather in Tokyo?".to_string(),
}],
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(200), max_tokens: Some(200),
tools: Some(vec![make_weather_tool()]), tools: Some(vec![make_weather_tool()]),
@ -85,10 +83,7 @@ async fn test_openai_tool_call_with_manual_execution() {
// First request with tool // First request with tool
let request1 = ChatCompletionRequest { let request1 = ChatCompletionRequest {
messages: vec![Message { messages: vec![Message::user("What is the weather in Tokyo?")],
role: "user".to_string(),
content: "What is the weather in Tokyo?".to_string(),
}],
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(200), max_tokens: Some(200),
tools: Some(vec![make_weather_tool()]), tools: Some(vec![make_weather_tool()]),
@ -102,14 +97,8 @@ async fn test_openai_tool_call_with_manual_execution() {
// Second request with tool result // Second request with tool result
let request2 = ChatCompletionRequest { let request2 = ChatCompletionRequest {
messages: vec![ messages: vec![
Message { Message::user("What is the weather in Tokyo?"),
role: "user".to_string(), Message::assistant(r#"I'll check the weather for you using the get_weather tool."#),
content: "What is the weather in Tokyo?".to_string(),
},
Message {
role: "assistant".to_string(),
content: r#"I'll check the weather for you using the get_weather tool."#.to_string(),
},
], ],
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(200), max_tokens: Some(200),
@ -131,10 +120,7 @@ async fn test_openai_no_tool_when_not_provided() {
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![Message { messages: vec![Message::user("Say hello in one word.")],
role: "user".to_string(),
content: "Say hello in one word.".to_string(),
}],
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(10), max_tokens: Some(10),
tools: None, tools: None,