From 393d98074249b9f42795c318012522a339d5d955 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Mon, 20 Apr 2026 22:35:13 +0800 Subject: [PATCH] =?UTF-8?q?feat(session):=20=E6=B7=BB=E5=8A=A0=E9=80=BB?= =?UTF-8?q?=E8=BE=91=E9=87=8D=E7=BD=AE=E5=8A=9F=E8=83=BD=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E4=BC=9A=E8=AF=9D=E5=8E=86=E5=8F=B2=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PERSISTENCE.md | 28 ++++- src/client/mod.rs | 2 +- src/gateway/session.rs | 114 ++++++++++++++++++ src/gateway/ws.rs | 13 ++- src/storage/mod.rs | 260 ++++++++++++++++++++++++++++++++--------- 5 files changed, 359 insertions(+), 58 deletions(-) diff --git a/PERSISTENCE.md b/PERSISTENCE.md index 70e5007..5a15a8b 100644 --- a/PERSISTENCE.md +++ b/PERSISTENCE.md @@ -85,6 +85,7 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 | `archived_at` | `INTEGER` | 归档时间 | 非空表示会话已归档 | | `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 | | `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 | +| `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 | 索引: @@ -172,9 +173,14 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 ### 6.3 读取历史 -`load_messages(session_id)` 会按 `seq ASC` 读取整个消息历史,并把 JSON 字段反序列化回 `ChatMessage`。 +`load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是: -因此它恢复的是“逻辑顺序”,而不是简单按创建时间排序。只要 `seq` 连续,重放顺序就稳定。 +- 只返回 `seq > sessions.reset_cutoff_seq` 的消息 +- 因此 `/reset` 之后,旧消息仍然保留在数据库中,但不会默认回灌到运行时上下文 + +如果需要审计、导出或查看完整历史,应使用全量读取接口 `load_all_messages(session_id)`。 + +因此运行态恢复的是“当前活动段的逻辑顺序”,而不是简单按创建时间排序。只要 `seq` 连续,重放顺序就稳定。 ## 7. 典型时序 @@ -229,12 +235,24 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 - 删除该会话在 `messages` 中的所有记录 - 将 `sessions.message_count` 重置为 0 +- 将 `sessions.reset_cutoff_seq` 重置为 0 - 更新 `updated_at` 和 `last_active_at` - 保留会话本身 这适合“保留会话入口,但丢弃聊天内容”的场景。 -### 8.4 删除会话 +### 8.4 逻辑重置 + +`reset_session(session_id)`: + +- 不删除 `messages` 中的任何记录 +- 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq` +- 更新 `updated_at` 和 `last_active_at` +- 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息 + +这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。 + +### 8.5 删除会话 `delete_session(session_id)`: @@ -276,6 +294,9 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 - `sessions.deleted_at` - 当前查询逻辑兼容软删除 - 当前删除实现仍然是物理删除 +- `sessions.reset_cutoff_seq` + - 当前已用于实现 `/reset` 的非破坏性逻辑重置 + - 只影响默认恢复的活动段,不影响数据库中的全量历史 这说明当前 schema 已经为“会话摘要”和“软删除”预留了演进空间,但并未完全落地。 @@ -285,6 +306,7 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 - 会话查不到:先看 `persistent_session_id` 是否和实际 `channel_name/chat_id` 一致 - 重启后没历史:检查 `ensure_chat_loaded()` 调用链,以及数据库文件路径是否正确 +- `/reset` 后重启又带回旧上下文:检查 `sessions.reset_cutoff_seq` 是否已写入,以及恢复路径是否走了活动段读取而不是全量读取 - 消息顺序不对:检查 `messages.seq` - 工具调用上下文异常:同时检查 `tool_calls_json` 和 `tool_call_id` - 会话列表里看不到记录:检查 `archived_at` 和 `include_archived` 参数 diff --git a/src/client/mod.rs b/src/client/mod.rs index c4e7730..a3178f9 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -48,7 +48,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { let mut input = InputHandler::new(); let mut current_session_id: Option = None; - input.write_output("picobot CLI - Commands: /new [title], /sessions, /use , /rename , /archive, /delete, /clear, /quit\n").await?; + input.write_output("picobot CLI - Commands: /new [title], /reset, /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?; // Main loop: poll both stdin and WebSocket loop { diff --git a/src/gateway/session.rs b/src/gateway/session.rs index c1b2a26..b5bf8f3 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -105,6 +105,19 @@ impl Session { .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err))) } + pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> { + if let Some(history) = self.chat_histories.get_mut(chat_id) { + let len = history.len(); + history.clear(); + #[cfg(debug_assertions)] + tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory"); + } + + self.store + .reset_session(&self.persistent_session_id(chat_id)) + .map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err))) + } + /// 将消息写入内存与持久化层 pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> { let session_id = self.persistent_session_id(chat_id); @@ -202,6 +215,32 @@ fn default_tools() -> ToolRegistry { registry } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum InChatCommand { + FreshConversation, +} + +fn parse_in_chat_command(content: &str) -> Option<InChatCommand> { + match content.trim() { + "/new" | "/reset" => Some(InChatCommand::FreshConversation), + _ => None, + } +} + +pub(crate) fn handle_in_chat_command( + session: &mut Session, + chat_id: &str, + content: &str, +) -> Result<Option<String>, AgentError> { + match parse_in_chat_command(content) { + Some(InChatCommand::FreshConversation) => { + session.reset_chat_context(chat_id)?; + Ok(Some("Started a fresh conversation.".to_string())) + } + None => Ok(None), + } +} + impl SessionManager { pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> { let store = Arc::new( @@ -372,6 +411,10 @@ impl SessionManager { session_guard.ensure_persistent_session(chat_id)?; session_guard.ensure_chat_loaded(chat_id)?; + if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? { + return Ok(command_response); + } + // 添加用户消息到历史 let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect(); #[cfg(debug_assertions)] @@ -419,3 +462,74 @@ impl SessionManager { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + use tokio::sync::mpsc; + + fn test_provider_config() -> LLMProviderConfig { + LLMProviderConfig { + provider_type: "openai".to_string(), + name: "test".to_string(), + base_url: "http://localhost".to_string(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + model_id: "test-model".to_string(), + temperature: Some(0.0), + max_tokens: Some(32), + model_extra: HashMap::new(), + max_tool_iterations: 1, + token_limit: 4096, + } + } + + #[test] + fn test_parse_in_chat_command_aliases() { + assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation)); + assert_eq!(parse_in_chat_command(" /reset \n"), Some(InChatCommand::FreshConversation)); + assert_eq!(parse_in_chat_command("/new planning"), None); + assert_eq!(parse_in_chat_command("please /reset"), None); + } + + #[tokio::test] + async fn test_handle_in_chat_command_resets_active_history_only() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let (user_tx, _user_rx) = mpsc::channel(4); + let tools = Arc::new(default_tools()); + let mut session = Session::new( + "feishu".to_string(), + test_provider_config(), + user_tx, + tools, + store.clone(), + ) + .await + .unwrap(); + + session.ensure_persistent_session("chat-1").unwrap(); + session.ensure_chat_loaded("chat-1").unwrap(); + session + .append_persisted_message("chat-1", ChatMessage::user("hello")) + .unwrap(); + + let response = handle_in_chat_command(&mut session, "chat-1", "/reset") + .unwrap() + .unwrap(); + + assert_eq!(response, "Started a fresh conversation."); + assert!(session.get_history("chat-1").unwrap().is_empty()); + assert!(store + .load_messages(&session.persistent_session_id("chat-1")) + .unwrap() + .is_empty()); + assert_eq!( + store + .load_all_messages(&session.persistent_session_id("chat-1")) + .unwrap() + .len(), + 1, + ); + } +} diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 18c979d..7c545ff 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -5,7 +5,7 @@ use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use tokio::sync::{mpsc, Mutex}; use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; -use super::{GatewayState, session::Session}; +use super::{GatewayState, session::{Session, handle_in_chat_command}}; pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response { ws.on_upgrade(|socket| async { @@ -153,6 +153,17 @@ async fn handle_inbound( session_guard.ensure_persistent_session(&chat_id)?; session_guard.ensure_chat_loaded(&chat_id)?; + if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? { + let _ = session_guard + .send(WsOutbound::AssistantResponse { + id: uuid::Uuid::new_v4().to_string(), + content: command_response, + role: "assistant".to_string(), + }) + .await; + return Ok(()); + } + let user_message = session_guard.create_user_message(&content, Vec::new()); session_guard.append_persisted_message(&chat_id, user_message)?; diff --git a/src/storage/mod.rs b/src/storage/mod.rs index cf5c58e..b8a2bfa 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -29,6 +29,7 @@ pub struct SessionRecord { pub archived_at: Option<i64>, pub deleted_at: Option<i64>, pub message_count: i64, + pub reset_cutoff_seq: i64, } #[derive(Clone)] @@ -68,7 +69,8 @@ impl SessionStore { last_active_at INTEGER NOT NULL, archived_at INTEGER, deleted_at INTEGER, - message_count INTEGER NOT NULL DEFAULT 0 + message_count INTEGER NOT NULL DEFAULT 0, + reset_cutoff_seq INTEGER NOT NULL DEFAULT 0 ); CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived @@ -98,13 +100,15 @@ impl SessionStore { ", )?; + ensure_sessions_schema(&conn)?; + Ok(Self { conn: Arc::new(Mutex::new(conn)), }) } #[cfg(test)] - fn in_memory() -> Result<Self, StorageError> { + pub(crate) fn in_memory() -> Result<Self, StorageError> { Self::from_connection(Connection::open_in_memory()?) } @@ -165,7 +169,7 @@ impl SessionStore { " SELECT id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, - archived_at, deleted_at, message_count + archived_at, deleted_at, message_count, reset_cutoff_seq FROM sessions WHERE id = ?1 AND deleted_at IS NULL ", @@ -186,7 +190,7 @@ impl SessionStore { " SELECT id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, - archived_at, deleted_at, message_count + archived_at, deleted_at, message_count, reset_cutoff_seq FROM sessions WHERE channel_name = ?1 AND deleted_at IS NULL @@ -242,7 +246,7 @@ impl SessionStore { conn.execute( " UPDATE sessions - SET message_count = 0, updated_at = ?2, last_active_at = ?2 + SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0 WHERE id = ?1 AND deleted_at IS NULL ", params![session_id, now], @@ -250,6 +254,33 @@ impl SessionStore { Ok(()) } + pub fn reset_session(&self, session_id: &str) -> Result<(), StorageError> { + let now = current_timestamp(); + let conn = self.conn.lock().expect("session db mutex poisoned"); + let tx = conn.unchecked_transaction()?; + + let cutoff_seq: i64 = tx.query_row( + "SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1", + params![session_id], + |row| row.get(0), + )?; + + tx.execute( + " + UPDATE sessions + SET reset_cutoff_seq = ?2, + updated_at = ?3, + last_active_at = ?3, + archived_at = NULL + WHERE id = ?1 AND deleted_at IS NULL + ", + params![session_id, cutoff_seq, now], + )?; + + tx.commit()?; + Ok(()) + } + pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let tx = conn.unchecked_transaction()?; @@ -302,55 +333,13 @@ impl SessionStore { pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); - let mut stmt = conn.prepare( - " - SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json - FROM messages - WHERE session_id = ?1 - ORDER BY seq ASC - ", - )?; + let cutoff_seq = active_reset_cutoff(&conn, session_id)?; + load_messages_after(&conn, session_id, cutoff_seq) + } - let rows = stmt.query_map(params![session_id], |row| { - let media_refs_json: String = row.get(3)?; - let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| { - rusqlite::Error::FromSqlConversionFailure( - media_refs_json.len(), - rusqlite::types::Type::Text, - Box::new(err), - ) - })?; - - let tool_calls_json: Option<String> = row.get(7)?; - let tool_calls = tool_calls_json - .as_deref() - .map(serde_json::from_str) - .transpose() - .map_err(|err| { - rusqlite::Error::FromSqlConversionFailure( - 7, - rusqlite::types::Type::Text, - Box::new(err), - ) - })?; - - Ok(ChatMessage { - id: row.get(0)?, - role: row.get(1)?, - content: row.get(2)?, - media_refs, - timestamp: row.get(4)?, - tool_call_id: row.get(5)?, - tool_name: row.get(6)?, - tool_calls, - }) - })?; - - let mut messages = Vec::new(); - for row in rows { - messages.push(row?); - } - Ok(messages) + pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + load_messages_after(&conn, session_id, 0) } } @@ -380,9 +369,104 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord archived_at: row.get(8)?, deleted_at: row.get(9)?, message_count: row.get(10)?, + reset_cutoff_seq: row.get(11)?, }) } +fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> { + if !has_column(conn, "sessions", "reset_cutoff_seq")? { + conn.execute( + "ALTER TABLE sessions ADD COLUMN reset_cutoff_seq INTEGER NOT NULL DEFAULT 0", + [], + )?; + } + + Ok(()) +} + +fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result<bool, StorageError> { + let pragma = format!("PRAGMA table_info({})", table_name); + let mut stmt = conn.prepare(&pragma)?; + let mut rows = stmt.query([])?; + + while let Some(row) = rows.next()? { + let existing_name: String = row.get(1)?; + if existing_name == column_name { + return Ok(true); + } + } + + Ok(false) +} + +fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, StorageError> { + let cutoff = conn + .query_row( + "SELECT reset_cutoff_seq FROM sessions WHERE id = ?1 AND deleted_at IS NULL", + params![session_id], + |row| row.get(0), + ) + .optional()?; + + Ok(cutoff.unwrap_or(0)) +} + +fn load_messages_after( + conn: &Connection, + session_id: &str, + cutoff_seq: i64, +) -> Result<Vec<ChatMessage>, StorageError> { + let mut stmt = conn.prepare( + " + SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json + FROM messages + WHERE session_id = ?1 AND seq > ?2 + ORDER BY seq ASC + ", + )?; + + let rows = stmt.query_map(params![session_id, cutoff_seq], |row| { + let media_refs_json: String = row.get(3)?; + let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + media_refs_json.len(), + rusqlite::types::Type::Text, + Box::new(err), + ) + })?; + + let tool_calls_json: Option<String> = row.get(7)?; + let tool_calls = tool_calls_json + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + 7, + rusqlite::types::Type::Text, + Box::new(err), + ) + })?; + + Ok(ChatMessage { + id: row.get(0)?, + role: row.get(1)?, + content: row.get(2)?, + media_refs, + timestamp: row.get(4)?, + tool_call_id: row.get(5)?, + tool_name: row.get(6)?, + tool_calls, + }) + })?; + + let mut messages = Vec::new(); + for row in rows { + messages.push(row?); + } + Ok(messages) +} + fn current_timestamp() -> i64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -410,6 +494,7 @@ mod tests { assert_eq!(session.channel_name, "cli"); assert_eq!(session.chat_id, session.id); assert_eq!(session.message_count, 0); + assert_eq!(session.reset_cutoff_seq, 0); let first = ChatMessage::user("hello"); let second = ChatMessage::assistant("world"); @@ -419,6 +504,7 @@ mod tests { let stored = store.get_session(&session.id).unwrap().unwrap(); assert_eq!(stored.message_count, 2); assert!(stored.archived_at.is_none()); + assert_eq!(stored.reset_cutoff_seq, 0); let messages = store.load_messages(&session.id).unwrap(); assert_eq!(messages.len(), 2); @@ -487,6 +573,74 @@ mod tests { assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator"); } + #[test] + fn test_reset_session_preserves_full_history_and_hides_active_history() { + let store = SessionStore::in_memory().unwrap(); + let session = store.create_cli_session(Some("reset")).unwrap(); + + store.append_message(&session.id, &ChatMessage::user("before")).unwrap(); + store.append_message(&session.id, &ChatMessage::assistant("context")).unwrap(); + store.reset_session(&session.id).unwrap(); + + let stored = store.get_session(&session.id).unwrap().unwrap(); + assert_eq!(stored.reset_cutoff_seq, 2); + + let active_messages = store.load_messages(&session.id).unwrap(); + assert!(active_messages.is_empty()); + + let all_messages = store.load_all_messages(&session.id).unwrap(); + assert_eq!(all_messages.len(), 2); + assert_eq!(all_messages[0].content, "before"); + assert_eq!(all_messages[1].content, "context"); + + store.append_message(&session.id, &ChatMessage::user("after")).unwrap(); + let active_messages = store.load_messages(&session.id).unwrap(); + assert_eq!(active_messages.len(), 1); + assert_eq!(active_messages[0].content, "after"); + } + + #[test] + fn test_schema_migration_adds_reset_cutoff_column() { + let conn = Connection::open_in_memory().unwrap(); + conn.execute_batch( + " + CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + channel_name TEXT NOT NULL, + chat_id TEXT NOT NULL, + summary TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + last_active_at INTEGER NOT NULL, + archived_at INTEGER, + deleted_at INTEGER, + message_count INTEGER NOT NULL DEFAULT 0 + ); + + CREATE TABLE messages ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + media_refs_json TEXT NOT NULL, + tool_call_id TEXT, + tool_name TEXT, + tool_calls_json TEXT, + created_at INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE, + UNIQUE(session_id, seq) + ); + ", + ) + .unwrap(); + + let store = SessionStore::from_connection(conn).unwrap(); + let session = store.create_cli_session(Some("migrated")).unwrap(); + assert_eq!(session.reset_cutoff_seq, 0); + } + #[test] fn test_tool_result_roundtrip() { let store = SessionStore::in_memory().unwrap();