From e7097734647bb5a5fa01627358b045293f9dfb67 Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Fri, 15 May 2026 15:28:07 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E8=AF=9D=E9=A2=98?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=88=87=E6=8D=A2=E5=92=8C=E6=8C=81=E4=B9=85=E5=8C=96=E8=AF=9D?= =?UTF-8?q?=E9=A2=98=E5=8E=86=E5=8F=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/command/handlers/session_query.rs | 15 +++++- src/gateway/runtime.rs | 1 + src/gateway/session.rs | 70 +++++++++++++++++++++++++-- src/gateway/session_factory.rs | 6 ++- src/gateway/session_history.rs | 4 ++ src/gateway/ws.rs | 18 ++++++- 6 files changed, 106 insertions(+), 8 deletions(-) diff --git a/src/command/handlers/session_query.rs b/src/command/handlers/session_query.rs index 704c639..94ed6fe 100644 --- a/src/command/handlers/session_query.rs +++ b/src/command/handlers/session_query.rs @@ -3,7 +3,7 @@ use crate::command::handler::CommandHandler; use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; use crate::gateway::session::SessionManager; -use crate::storage::{SessionStore, TopicRecord}; +use crate::storage::SessionStore; use async_trait::async_trait; use std::sync::Arc; @@ -219,9 +219,11 @@ async fn handle_switch_session( topic_id: String, ctx: CommandContext, ) -> Result { - // 获取当前 session_id + // 获取当前 session_id 和 chat_id let session_id = ctx.session_id.as_deref() .ok_or_else(|| CommandError::new("NO_SESSION", "No active session"))?; + let chat_id = ctx.chat_id.as_deref() + .ok_or_else(|| CommandError::new("NO_CHAT_ID", "No chat_id in context"))?; // 尝试解析为序号 let target_topic_id = if let Ok(index) = topic_id.parse::() { @@ -249,6 +251,15 @@ async fn handle_switch_session( .map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))? .ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", target_topic_id)))?; + // 如果有 SessionManager,实际切换话题历史 + if let Some(ref session_manager) = handler.session_manager { + if let Some(session) = session_manager.get(&ctx.channel_name).await { + let mut session_guard = session.lock().await; + session_guard.switch_topic(chat_id, &target_topic_id) + .map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))?; + } + } + // 返回切换成功响应 let message = format!( "✓ Switched to topic: {} ({} messages)", diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index 5c0f81b..6664af1 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -102,6 +102,7 @@ pub(crate) fn build_session_manager_with_sender( agent_factory, conversations, skill_events, + store.clone(), chat_history_ttl_hours, ); let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours); diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 60426de..b4260d5 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -41,6 +41,8 @@ pub struct Session { agent_factory: AgentFactory, compressor: ContextCompressor, history: SessionHistory, + store: Arc, + current_topic_id: Option, } pub struct BusToolCallEmitter { @@ -120,6 +122,7 @@ impl Session { conversations, skill_events, chat_history_ttl_hours, + store, ) .await } @@ -133,6 +136,7 @@ impl Session { conversations: Arc, skill_events: Arc, chat_history_ttl_hours: Option, + store: Arc, ) -> Result { Ok(Self { id: Uuid::new_v4(), @@ -148,6 +152,8 @@ impl Session { skill_events, chat_history_ttl_hours, ), + store, + current_topic_id: None, }) } @@ -155,6 +161,38 @@ impl Session { self.history.persistent_session_id(chat_id) } + /// 设置当前话题 ID + pub fn set_current_topic(&mut self, topic_id: Option) { + self.current_topic_id = topic_id; + } + + /// 获取当前话题 ID + pub fn current_topic(&self) -> Option<&str> { + self.current_topic_id.as_deref() + } + + /// 切换话题 - 清除当前历史并加载新话题的历史 + pub fn switch_topic(&mut self, chat_id: &str, topic_id: &str) -> Result<(), AgentError> { + // 清除当前历史 + self.history.remove_history(chat_id); + + // 加载新话题的历史 + let messages = self + .store + .load_messages_for_topic(topic_id) + .map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?; + + self.history.set_history(chat_id, messages); + self.current_topic_id = Some(topic_id.to_string()); + + tracing::info!( + topic_id = %topic_id, + chat_id = %chat_id, + "Switched to topic" + ); + Ok(()) + } + pub fn ensure_persistent_session(&self, chat_id: &str) -> Result { self.history.ensure_persistent_session(chat_id) } @@ -198,13 +236,28 @@ impl Session { self.history.reset_chat_context(chat_id) } - /// 将消息写入内存与持久化层 + /// 将消息写入内存与持久化层(使用当前 topic) pub fn append_persisted_message( &mut self, chat_id: &str, message: ChatMessage, ) -> Result<(), AgentError> { - self.history.append_persisted_message(chat_id, message) + let session_id = self.persistent_session_id(chat_id); + self.store + .append_message_with_topic(&session_id, self.current_topic_id.as_deref(), &message) + .map_err(|err| { + AgentError::Other(format!("append message persistence error: {}", err)) + })?; + self.add_message(chat_id, message); + + // 更新 topic 的最后活跃时间 + if let Some(topic_id) = &self.current_topic_id { + if let Err(e) = self.store.touch_topic(topic_id) { + tracing::warn!(error = %e, topic_id = %topic_id, "Failed to touch topic"); + } + } + + Ok(()) } pub fn append_persisted_messages( @@ -283,7 +336,18 @@ impl Session { } pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { - self.history.reload_chat_history(chat_id) + // 如果当前有 topic,加载该 topic 的消息 + if let Some(topic_id) = &self.current_topic_id { + let messages = self + .store + .load_messages_for_topic(topic_id) + .map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?; + self.history.set_history(chat_id, messages); + } else { + // 否则加载 session 的所有消息 + self.history.reload_chat_history(chat_id)?; + } + Ok(()) } pub(crate) fn store(&self) -> Arc { diff --git a/src/gateway/session_factory.rs b/src/gateway/session_factory.rs index e9670d3..c03eeb5 100644 --- a/src/gateway/session_factory.rs +++ b/src/gateway/session_factory.rs @@ -6,7 +6,7 @@ use crate::agent::AgentError; use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; use crate::skills::SkillRuntime; -use crate::storage::{ConversationRepository, SkillEventRepository}; +use crate::storage::{ConversationRepository, SessionStore, SkillEventRepository}; use super::agent_factory::AgentFactory; use super::session::Session; @@ -18,6 +18,7 @@ pub(crate) struct SessionFactory { agent_factory: AgentFactory, conversations: Arc, skill_events: Arc, + store: Arc, chat_history_ttl_hours: Option, } @@ -28,6 +29,7 @@ impl SessionFactory { agent_factory: AgentFactory, conversations: Arc, skill_events: Arc, + store: Arc, chat_history_ttl_hours: Option, ) -> Self { Self { @@ -36,6 +38,7 @@ impl SessionFactory { agent_factory, conversations, skill_events, + store, chat_history_ttl_hours, } } @@ -54,6 +57,7 @@ impl SessionFactory { self.conversations.clone(), self.skill_events.clone(), self.chat_history_ttl_hours, + self.store.clone(), ) .await } diff --git a/src/gateway/session_history.rs b/src/gateway/session_history.rs index e52af3c..a997488 100644 --- a/src/gateway/session_history.rs +++ b/src/gateway/session_history.rs @@ -113,6 +113,10 @@ impl SessionHistory { self.chat_histories.get(chat_id) } + pub(crate) fn set_history(&mut self, chat_id: &str, history: Vec) { + self.chat_histories.insert(chat_id.to_string(), history); + } + pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) { self.get_or_create_history(chat_id).push(message); } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 6310dc9..ec4cbb4 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -41,6 +41,7 @@ async fn handle_socket(ws: WebSocket, state: Arc) { let runtime_session_id = uuid::Uuid::new_v4().to_string(); let mut current_session_id = initial_record.id.clone(); + let mut current_topic_id: Option = None; state .channel_manager .cli_channel() @@ -85,6 +86,7 @@ async fn handle_socket(ws: WebSocket, state: Arc) { &sender, &runtime_session_id, &mut current_session_id, + &mut current_topic_id, inbound, ) .await @@ -132,6 +134,7 @@ async fn handle_inbound( sender: &mpsc::Sender, runtime_session_id: &str, current_session_id: &mut String, + current_topic_id: &mut Option, inbound: WsInbound, ) -> Result<(), AgentError> { match inbound { @@ -204,7 +207,7 @@ async fn handle_inbound( }; // 创建命令路由器 - let cli_sessions = state.session_manager.cli_sessions(); + let _cli_sessions = state.session_manager.cli_sessions(); let store = state.session_manager.store(); let skills = state.session_manager.skills(); let provider_config = state.config.get_provider_config("default") @@ -234,11 +237,13 @@ async fn handle_inbound( // 构建命令上下文 tracing::debug!( current_session_id = %current_session_id, + current_topic_id = ?current_topic_id, "Building CommandContext for WebSocket command" ); let cmd_ctx = CommandContext::new("websocket", "cli") .with_session_id(current_session_id.as_str()) - .with_chat_id(current_session_id.as_str()); + .with_chat_id(current_session_id.as_str()) + .with_topic_id(current_topic_id.as_deref().unwrap_or("")); // 执行命令 let response = router.dispatch_with_response(cmd, cmd_ctx).await; @@ -263,6 +268,15 @@ async fn handle_inbound( ) .await; } + // 更新当前话题 ID(如果是创建话题或切换话题) + if let Some(topic_id) = response.metadata.get("topic_id") { + tracing::info!( + old_topic_id = ?current_topic_id, + new_topic_id = %topic_id, + "Updating current_topic_id" + ); + *current_topic_id = Some(topic_id.clone()); + } } else if let Some(ref error) = response.error { tracing::warn!( error_code = %error.code,