use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; use rusqlite::{Connection, OptionalExtension, params}; use serde::{Deserialize, Serialize}; use crate::bus::ChatMessage; #[derive(Debug, thiserror::Error)] pub enum StorageError { #[error("database error: {0}")] Database(#[from] rusqlite::Error), #[error("io error: {0}")] Io(#[from] std::io::Error), #[error("serialization error: {0}")] Serialization(#[from] serde_json::Error), } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SessionRecord { pub id: String, pub title: String, pub channel_name: String, pub chat_id: String, pub summary: Option, pub created_at: i64, pub updated_at: i64, pub last_active_at: i64, pub archived_at: Option, pub deleted_at: Option, pub message_count: i64, pub reset_cutoff_seq: i64, } #[derive(Clone)] pub struct SessionStore { conn: Arc>, } impl SessionStore { pub fn new() -> Result { let db_path = default_session_db_path()?; Self::open_at_path(&db_path) } fn open_at_path(path: &Path) -> Result { if let Some(parent) = path.parent() { std::fs::create_dir_all(parent)?; } let conn = Connection::open(path)?; Self::from_connection(conn) } fn from_connection(conn: Connection) -> Result { conn.execute_batch( " PRAGMA journal_mode = WAL; PRAGMA foreign_keys = ON; CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, title TEXT NOT NULL, channel_name TEXT NOT NULL, chat_id TEXT NOT NULL, summary TEXT, created_at INTEGER NOT NULL, updated_at INTEGER NOT NULL, last_active_at INTEGER NOT NULL, archived_at INTEGER, deleted_at INTEGER, message_count INTEGER NOT NULL DEFAULT 0, reset_cutoff_seq INTEGER NOT NULL DEFAULT 0 ); CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived ON sessions(channel_name, archived_at, last_active_at DESC); CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC); CREATE TABLE IF NOT EXISTS messages ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, seq INTEGER NOT NULL, role TEXT NOT NULL, content TEXT NOT NULL, media_refs_json TEXT NOT NULL, tool_call_id TEXT, tool_name TEXT, tool_calls_json TEXT, created_at INTEGER NOT NULL, FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE, UNIQUE(session_id, seq) ); CREATE INDEX IF NOT EXISTS idx_messages_session_seq ON messages(session_id, seq); CREATE INDEX IF NOT EXISTS idx_messages_session_created ON messages(session_id, created_at); ", )?; ensure_sessions_schema(&conn)?; Ok(Self { conn: Arc::new(Mutex::new(conn)), }) } #[cfg(test)] pub(crate) fn in_memory() -> Result { Self::from_connection(Connection::open_in_memory()?) } pub fn create_cli_session(&self, title: Option<&str>) -> Result { let now = crate::bus::message::current_timestamp(); let id = uuid::Uuid::new_v4().to_string(); let title = title .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned) .unwrap_or_else(|| format!("CLI Session {}", &id[..8])); let conn = self.conn.lock().expect("session db mutex poisoned"); conn.execute( " INSERT INTO sessions ( id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, archived_at, deleted_at, message_count ) VALUES (?1, ?2, 'cli_chat', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0) ", params![id, title, id, now], )?; drop(conn); self.get_session(&id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) } pub fn ensure_channel_session( &self, channel_name: &str, chat_id: &str, dialog_id: &str, ) -> Result { let session_id = persistent_session_id(channel_name, chat_id, dialog_id); if let Some(record) = self.get_session(&session_id)? { return Ok(record); } let now = crate::bus::message::current_timestamp(); let title = format!("{}:{}", channel_name, chat_id); let conn = self.conn.lock().expect("session db mutex poisoned"); conn.execute( " INSERT INTO sessions ( id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, archived_at, deleted_at, message_count ) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0) ", params![session_id, title, channel_name, chat_id, now], )?; drop(conn); self.get_session(&session_id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) } pub fn get_session(&self, session_id: &str) -> Result, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let mut stmt = conn.prepare( " SELECT id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, archived_at, deleted_at, message_count, reset_cutoff_seq FROM sessions WHERE id = ?1 AND deleted_at IS NULL ", )?; stmt.query_row(params![session_id], map_session_record) .optional() .map_err(StorageError::from) } pub fn list_sessions( &self, channel_name: &str, include_archived: bool, ) -> Result, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let mut sql = String::from( " SELECT id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, archived_at, deleted_at, message_count, reset_cutoff_seq FROM sessions WHERE channel_name = ?1 AND deleted_at IS NULL ", ); if !include_archived { sql.push_str(" AND archived_at IS NULL"); } sql.push_str(" ORDER BY last_active_at DESC, created_at DESC"); let mut stmt = conn.prepare(&sql)?; let rows = stmt.query_map(params![channel_name], map_session_record)?; let mut sessions = Vec::new(); for row in rows { sessions.push(row?); } Ok(sessions) } pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), StorageError> { let now = crate::bus::message::current_timestamp(); let conn = self.conn.lock().expect("session db mutex poisoned"); conn.execute( "UPDATE sessions SET title = ?2, updated_at = ?3 WHERE id = ?1 AND deleted_at IS NULL", params![session_id, title.trim(), now], )?; Ok(()) } pub fn archive_session(&self, session_id: &str) -> Result<(), StorageError> { let now = crate::bus::message::current_timestamp(); let conn = self.conn.lock().expect("session db mutex poisoned"); conn.execute( "UPDATE sessions SET archived_at = ?2, updated_at = ?2 WHERE id = ?1 AND deleted_at IS NULL", params![session_id, now], )?; Ok(()) } pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?; conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?; Ok(()) } pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> { let now = crate::bus::message::current_timestamp(); let conn = self.conn.lock().expect("session db mutex poisoned"); conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?; conn.execute( " UPDATE sessions SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0 WHERE id = ?1 AND deleted_at IS NULL ", params![session_id, now], )?; Ok(()) } pub fn reset_session(&self, session_id: &str) -> Result<(), StorageError> { let now = crate::bus::message::current_timestamp(); let conn = self.conn.lock().expect("session db mutex poisoned"); let tx = conn.unchecked_transaction()?; let cutoff_seq: i64 = tx.query_row( "SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1", params![session_id], |row| row.get(0), )?; tx.execute( " UPDATE sessions SET reset_cutoff_seq = ?2, updated_at = ?3, last_active_at = ?3, archived_at = NULL WHERE id = ?1 AND deleted_at IS NULL ", params![session_id, cutoff_seq, now], )?; tx.commit()?; Ok(()) } pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let tx = conn.unchecked_transaction()?; let seq: i64 = tx.query_row( "SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE session_id = ?1", params![session_id], |row| row.get(0), )?; let media_refs_json = serde_json::to_string(&message.media_refs)?; let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?; tx.execute( " INSERT INTO messages ( id, session_id, seq, role, content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) ", params![ message.id, session_id, seq, message.role, message.content, media_refs_json, message.tool_call_id, message.tool_name, tool_calls_json, message.timestamp, ], )?; let now = crate::bus::message::current_timestamp(); tx.execute( " UPDATE sessions SET message_count = message_count + 1, updated_at = ?2, last_active_at = ?2, archived_at = NULL WHERE id = ?1 AND deleted_at IS NULL ", params![session_id, now], )?; tx.commit()?; Ok(()) } pub fn load_messages(&self, session_id: &str) -> Result, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let cutoff_seq = active_reset_cutoff(&conn, session_id)?; load_messages_after(&conn, session_id, cutoff_seq) } pub fn load_all_messages(&self, session_id: &str) -> Result, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); load_messages_after(&conn, session_id, 0) } } pub fn persistent_session_id(channel_name: &str, chat_id: &str, dialog_id: &str) -> String { format!("{}:{}:{}", channel_name, chat_id, dialog_id) } fn default_session_db_path() -> Result { let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); Ok(home.join(".picobot").join("storage").join("sessions.db")) } fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result { Ok(SessionRecord { id: row.get(0)?, title: row.get(1)?, channel_name: row.get(2)?, chat_id: row.get(3)?, summary: row.get(4)?, created_at: row.get(5)?, updated_at: row.get(6)?, last_active_at: row.get(7)?, archived_at: row.get(8)?, deleted_at: row.get(9)?, message_count: row.get(10)?, reset_cutoff_seq: row.get(11)?, }) } fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> { if !has_column(conn, "sessions", "reset_cutoff_seq")? { conn.execute( "ALTER TABLE sessions ADD COLUMN reset_cutoff_seq INTEGER NOT NULL DEFAULT 0", [], )?; } Ok(()) } fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result { let pragma = format!("PRAGMA table_info({})", table_name); let mut stmt = conn.prepare(&pragma)?; let mut rows = stmt.query([])?; while let Some(row) = rows.next()? { let existing_name: String = row.get(1)?; if existing_name == column_name { return Ok(true); } } Ok(false) } fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result { let cutoff = conn .query_row( "SELECT reset_cutoff_seq FROM sessions WHERE id = ?1 AND deleted_at IS NULL", params![session_id], |row| row.get(0), ) .optional()?; Ok(cutoff.unwrap_or(0)) } fn load_messages_after( conn: &Connection, session_id: &str, cutoff_seq: i64, ) -> Result, StorageError> { let mut stmt = conn.prepare( " SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json FROM messages WHERE session_id = ?1 AND seq > ?2 ORDER BY seq ASC ", )?; let rows = stmt.query_map(params![session_id, cutoff_seq], |row| { let media_refs_json: String = row.get(3)?; let media_refs: Vec = serde_json::from_str(&media_refs_json).map_err(|err| { rusqlite::Error::FromSqlConversionFailure( media_refs_json.len(), rusqlite::types::Type::Text, Box::new(err), ) })?; let tool_calls_json: Option = row.get(7)?; let tool_calls = tool_calls_json .as_deref() .map(serde_json::from_str) .transpose() .map_err(|err| { rusqlite::Error::FromSqlConversionFailure( 7, rusqlite::types::Type::Text, Box::new(err), ) })?; Ok(ChatMessage { id: row.get(0)?, role: row.get(1)?, content: row.get(2)?, media_refs, timestamp: row.get(4)?, tool_call_id: row.get(5)?, tool_name: row.get(6)?, tool_calls, }) })?; let mut messages = Vec::new(); for row in rows { messages.push(row?); } Ok(messages) } #[cfg(test)] mod tests { use super::*; use crate::providers::ToolCall; #[test] fn test_persistent_session_id_for_cli_and_channel() { assert_eq!(persistent_session_id("cli", "abc", "default"), "cli:abc:default"); assert_eq!(persistent_session_id("cli_chat", "abc", "default"), "cli_chat:abc:default"); assert_eq!(persistent_session_id("feishu", "abc", "default"), "feishu:abc:default"); } #[test] fn test_session_store_roundtrip_and_lifecycle() { let store = SessionStore::in_memory().unwrap(); let session = store.create_cli_session(Some("demo")).unwrap(); assert_eq!(session.title, "demo"); assert_eq!(session.channel_name, "cli_chat"); assert_eq!(session.chat_id, session.id); assert_eq!(session.message_count, 0); assert_eq!(session.reset_cutoff_seq, 0); let first = ChatMessage::user("hello"); let second = ChatMessage::assistant("world"); store.append_message(&session.id, &first).unwrap(); store.append_message(&session.id, &second).unwrap(); let stored = store.get_session(&session.id).unwrap().unwrap(); assert_eq!(stored.message_count, 2); assert!(stored.archived_at.is_none()); assert_eq!(stored.reset_cutoff_seq, 0); let messages = store.load_messages(&session.id).unwrap(); assert_eq!(messages.len(), 2); assert_eq!(messages[0].role, "user"); assert_eq!(messages[0].content, "hello"); assert_eq!(messages[1].role, "assistant"); assert_eq!(messages[1].content, "world"); store.rename_session(&session.id, "renamed").unwrap(); let renamed = store.get_session(&session.id).unwrap().unwrap(); assert_eq!(renamed.title, "renamed"); store.archive_session(&session.id).unwrap(); let archived = store.get_session(&session.id).unwrap().unwrap(); assert!(archived.archived_at.is_some()); let active_only = store.list_sessions("cli_chat", false).unwrap(); assert!(active_only.is_empty()); let including_archived = store.list_sessions("cli_chat", true).unwrap(); assert_eq!(including_archived.len(), 1); store.clear_messages(&session.id).unwrap(); let cleared = store.load_messages(&session.id).unwrap(); assert!(cleared.is_empty()); let cleared_session = store.get_session(&session.id).unwrap().unwrap(); assert_eq!(cleared_session.message_count, 0); store.delete_session(&session.id).unwrap(); assert!(store.get_session(&session.id).unwrap().is_none()); } #[test] fn test_ensure_channel_session_is_stable() { let store = SessionStore::in_memory().unwrap(); let first = store.ensure_channel_session("feishu", "chat-1", "default").unwrap(); let second = store.ensure_channel_session("feishu", "chat-1", "default").unwrap(); assert_eq!(first.id, second.id); assert_eq!(first.chat_id, "chat-1"); assert_eq!(second.channel_name, "feishu"); } #[test] fn test_assistant_tool_calls_roundtrip() { let store = SessionStore::in_memory().unwrap(); let session = store.create_cli_session(Some("tools")).unwrap(); let assistant = ChatMessage::assistant_with_tool_calls( "calling tool", vec![ToolCall { id: "call_1".to_string(), name: "calculator".to_string(), arguments: serde_json::json!({ "expression": "3*7" }), }], ); store.append_message(&session.id, &assistant).unwrap(); let messages = store.load_messages(&session.id).unwrap(); assert_eq!(messages.len(), 1); assert_eq!(messages[0].role, "assistant"); assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1); assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1"); assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator"); } #[test] fn test_reset_session_preserves_full_history_and_hides_active_history() { let store = SessionStore::in_memory().unwrap(); let session = store.create_cli_session(Some("reset")).unwrap(); store.append_message(&session.id, &ChatMessage::user("before")).unwrap(); store.append_message(&session.id, &ChatMessage::assistant("context")).unwrap(); store.reset_session(&session.id).unwrap(); let stored = store.get_session(&session.id).unwrap().unwrap(); assert_eq!(stored.reset_cutoff_seq, 2); let active_messages = store.load_messages(&session.id).unwrap(); assert!(active_messages.is_empty()); let all_messages = store.load_all_messages(&session.id).unwrap(); assert_eq!(all_messages.len(), 2); assert_eq!(all_messages[0].content, "before"); assert_eq!(all_messages[1].content, "context"); store.append_message(&session.id, &ChatMessage::user("after")).unwrap(); let active_messages = store.load_messages(&session.id).unwrap(); assert_eq!(active_messages.len(), 1); assert_eq!(active_messages[0].content, "after"); } #[test] fn test_schema_migration_adds_reset_cutoff_column() { let conn = Connection::open_in_memory().unwrap(); conn.execute_batch( " CREATE TABLE sessions ( id TEXT PRIMARY KEY, title TEXT NOT NULL, channel_name TEXT NOT NULL, chat_id TEXT NOT NULL, summary TEXT, created_at INTEGER NOT NULL, updated_at INTEGER NOT NULL, last_active_at INTEGER NOT NULL, archived_at INTEGER, deleted_at INTEGER, message_count INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE messages ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, seq INTEGER NOT NULL, role TEXT NOT NULL, content TEXT NOT NULL, media_refs_json TEXT NOT NULL, tool_call_id TEXT, tool_name TEXT, tool_calls_json TEXT, created_at INTEGER NOT NULL, FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE, UNIQUE(session_id, seq) ); ", ) .unwrap(); let store = SessionStore::from_connection(conn).unwrap(); let session = store.create_cli_session(Some("migrated")).unwrap(); assert_eq!(session.reset_cutoff_seq, 0); } #[test] fn test_tool_result_roundtrip() { let store = SessionStore::in_memory().unwrap(); let session = store.create_cli_session(Some("tool-result")).unwrap(); let tool_message = ChatMessage::tool("call_9", "file_write", "saved to /tmp/output.txt"); store.append_message(&session.id, &tool_message).unwrap(); let messages = store.load_messages(&session.id).unwrap(); assert_eq!(messages.len(), 1); assert_eq!(messages[0].role, "tool"); assert_eq!(messages[0].content, "saved to /tmp/output.txt"); assert_eq!(messages[0].tool_call_id.as_deref(), Some("call_9")); assert_eq!(messages[0].tool_name.as_deref(), Some("file_write")); assert!(messages[0].tool_calls.is_none()); } }