feat: 添加话题管理功能,支持切换和持久化话题历史

This commit is contained in:
oudecheng 2026-05-15 15:28:07 +08:00
parent 2e13f6932c
commit e709773464
6 changed files with 106 additions and 8 deletions

View File

@ -3,7 +3,7 @@ use crate::command::handler::CommandHandler;
use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command; use crate::command::Command;
use crate::gateway::session::SessionManager; use crate::gateway::session::SessionManager;
use crate::storage::{SessionStore, TopicRecord}; use crate::storage::SessionStore;
use async_trait::async_trait; use async_trait::async_trait;
use std::sync::Arc; use std::sync::Arc;
@ -219,9 +219,11 @@ async fn handle_switch_session(
topic_id: String, topic_id: String,
ctx: CommandContext, ctx: CommandContext,
) -> Result<CommandResponse, CommandError> { ) -> Result<CommandResponse, CommandError> {
// 获取当前 session_id // 获取当前 session_id 和 chat_id
let session_id = ctx.session_id.as_deref() let session_id = ctx.session_id.as_deref()
.ok_or_else(|| CommandError::new("NO_SESSION", "No active session"))?; .ok_or_else(|| CommandError::new("NO_SESSION", "No active session"))?;
let chat_id = ctx.chat_id.as_deref()
.ok_or_else(|| CommandError::new("NO_CHAT_ID", "No chat_id in context"))?;
// 尝试解析为序号 // 尝试解析为序号
let target_topic_id = if let Ok(index) = topic_id.parse::<usize>() { let target_topic_id = if let Ok(index) = topic_id.parse::<usize>() {
@ -249,6 +251,15 @@ async fn handle_switch_session(
.map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))? .map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))?
.ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", target_topic_id)))?; .ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", target_topic_id)))?;
// 如果有 SessionManager实际切换话题历史
if let Some(ref session_manager) = handler.session_manager {
if let Some(session) = session_manager.get(&ctx.channel_name).await {
let mut session_guard = session.lock().await;
session_guard.switch_topic(chat_id, &target_topic_id)
.map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))?;
}
}
// 返回切换成功响应 // 返回切换成功响应
let message = format!( let message = format!(
"✓ Switched to topic: {} ({} messages)", "✓ Switched to topic: {} ({} messages)",

View File

@ -102,6 +102,7 @@ pub(crate) fn build_session_manager_with_sender(
agent_factory, agent_factory,
conversations, conversations,
skill_events, skill_events,
store.clone(),
chat_history_ttl_hours, chat_history_ttl_hours,
); );
let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours); let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours);

View File

@ -41,6 +41,8 @@ pub struct Session {
agent_factory: AgentFactory, agent_factory: AgentFactory,
compressor: ContextCompressor, compressor: ContextCompressor,
history: SessionHistory, history: SessionHistory,
store: Arc<SessionStore>,
current_topic_id: Option<String>,
} }
pub struct BusToolCallEmitter { pub struct BusToolCallEmitter {
@ -120,6 +122,7 @@ impl Session {
conversations, conversations,
skill_events, skill_events,
chat_history_ttl_hours, chat_history_ttl_hours,
store,
) )
.await .await
} }
@ -133,6 +136,7 @@ impl Session {
conversations: Arc<dyn ConversationRepository>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>, skill_events: Arc<dyn SkillEventRepository>,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
store: Arc<SessionStore>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
Ok(Self { Ok(Self {
id: Uuid::new_v4(), id: Uuid::new_v4(),
@ -148,6 +152,8 @@ impl Session {
skill_events, skill_events,
chat_history_ttl_hours, chat_history_ttl_hours,
), ),
store,
current_topic_id: None,
}) })
} }
@ -155,6 +161,38 @@ impl Session {
self.history.persistent_session_id(chat_id) self.history.persistent_session_id(chat_id)
} }
/// 设置当前话题 ID
pub fn set_current_topic(&mut self, topic_id: Option<String>) {
self.current_topic_id = topic_id;
}
/// 获取当前话题 ID
pub fn current_topic(&self) -> Option<&str> {
self.current_topic_id.as_deref()
}
/// 切换话题 - 清除当前历史并加载新话题的历史
pub fn switch_topic(&mut self, chat_id: &str, topic_id: &str) -> Result<(), AgentError> {
// 清除当前历史
self.history.remove_history(chat_id);
// 加载新话题的历史
let messages = self
.store
.load_messages_for_topic(topic_id)
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
self.history.set_history(chat_id, messages);
self.current_topic_id = Some(topic_id.to_string());
tracing::info!(
topic_id = %topic_id,
chat_id = %chat_id,
"Switched to topic"
);
Ok(())
}
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> { pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
self.history.ensure_persistent_session(chat_id) self.history.ensure_persistent_session(chat_id)
} }
@ -198,13 +236,28 @@ impl Session {
self.history.reset_chat_context(chat_id) self.history.reset_chat_context(chat_id)
} }
/// 将消息写入内存与持久化层 /// 将消息写入内存与持久化层(使用当前 topic
pub fn append_persisted_message( pub fn append_persisted_message(
&mut self, &mut self,
chat_id: &str, chat_id: &str,
message: ChatMessage, message: ChatMessage,
) -> Result<(), AgentError> { ) -> Result<(), AgentError> {
self.history.append_persisted_message(chat_id, message) let session_id = self.persistent_session_id(chat_id);
self.store
.append_message_with_topic(&session_id, self.current_topic_id.as_deref(), &message)
.map_err(|err| {
AgentError::Other(format!("append message persistence error: {}", err))
})?;
self.add_message(chat_id, message);
// 更新 topic 的最后活跃时间
if let Some(topic_id) = &self.current_topic_id {
if let Err(e) = self.store.touch_topic(topic_id) {
tracing::warn!(error = %e, topic_id = %topic_id, "Failed to touch topic");
}
}
Ok(())
} }
pub fn append_persisted_messages<I>( pub fn append_persisted_messages<I>(
@ -283,7 +336,18 @@ impl Session {
} }
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
self.history.reload_chat_history(chat_id) // 如果当前有 topic加载该 topic 的消息
if let Some(topic_id) = &self.current_topic_id {
let messages = self
.store
.load_messages_for_topic(topic_id)
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
self.history.set_history(chat_id, messages);
} else {
// 否则加载 session 的所有消息
self.history.reload_chat_history(chat_id)?;
}
Ok(())
} }
pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> { pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {

View File

@ -6,7 +6,7 @@ use crate::agent::AgentError;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, SkillEventRepository}; use crate::storage::{ConversationRepository, SessionStore, SkillEventRepository};
use super::agent_factory::AgentFactory; use super::agent_factory::AgentFactory;
use super::session::Session; use super::session::Session;
@ -18,6 +18,7 @@ pub(crate) struct SessionFactory {
agent_factory: AgentFactory, agent_factory: AgentFactory,
conversations: Arc<dyn ConversationRepository>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>, skill_events: Arc<dyn SkillEventRepository>,
store: Arc<SessionStore>,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
} }
@ -28,6 +29,7 @@ impl SessionFactory {
agent_factory: AgentFactory, agent_factory: AgentFactory,
conversations: Arc<dyn ConversationRepository>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>, skill_events: Arc<dyn SkillEventRepository>,
store: Arc<SessionStore>,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
) -> Self { ) -> Self {
Self { Self {
@ -36,6 +38,7 @@ impl SessionFactory {
agent_factory, agent_factory,
conversations, conversations,
skill_events, skill_events,
store,
chat_history_ttl_hours, chat_history_ttl_hours,
} }
} }
@ -54,6 +57,7 @@ impl SessionFactory {
self.conversations.clone(), self.conversations.clone(),
self.skill_events.clone(), self.skill_events.clone(),
self.chat_history_ttl_hours, self.chat_history_ttl_hours,
self.store.clone(),
) )
.await .await
} }

View File

@ -113,6 +113,10 @@ impl SessionHistory {
self.chat_histories.get(chat_id) self.chat_histories.get(chat_id)
} }
pub(crate) fn set_history(&mut self, chat_id: &str, history: Vec<ChatMessage>) {
self.chat_histories.insert(chat_id.to_string(), history);
}
pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) { pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
self.get_or_create_history(chat_id).push(message); self.get_or_create_history(chat_id).push(message);
} }

View File

@ -41,6 +41,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let runtime_session_id = uuid::Uuid::new_v4().to_string(); let runtime_session_id = uuid::Uuid::new_v4().to_string();
let mut current_session_id = initial_record.id.clone(); let mut current_session_id = initial_record.id.clone();
let mut current_topic_id: Option<String> = None;
state state
.channel_manager .channel_manager
.cli_channel() .cli_channel()
@ -85,6 +86,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
&sender, &sender,
&runtime_session_id, &runtime_session_id,
&mut current_session_id, &mut current_session_id,
&mut current_topic_id,
inbound, inbound,
) )
.await .await
@ -132,6 +134,7 @@ async fn handle_inbound(
sender: &mpsc::Sender<WsOutbound>, sender: &mpsc::Sender<WsOutbound>,
runtime_session_id: &str, runtime_session_id: &str,
current_session_id: &mut String, current_session_id: &mut String,
current_topic_id: &mut Option<String>,
inbound: WsInbound, inbound: WsInbound,
) -> Result<(), AgentError> { ) -> Result<(), AgentError> {
match inbound { match inbound {
@ -204,7 +207,7 @@ async fn handle_inbound(
}; };
// 创建命令路由器 // 创建命令路由器
let cli_sessions = state.session_manager.cli_sessions(); let _cli_sessions = state.session_manager.cli_sessions();
let store = state.session_manager.store(); let store = state.session_manager.store();
let skills = state.session_manager.skills(); let skills = state.session_manager.skills();
let provider_config = state.config.get_provider_config("default") let provider_config = state.config.get_provider_config("default")
@ -234,11 +237,13 @@ async fn handle_inbound(
// 构建命令上下文 // 构建命令上下文
tracing::debug!( tracing::debug!(
current_session_id = %current_session_id, current_session_id = %current_session_id,
current_topic_id = ?current_topic_id,
"Building CommandContext for WebSocket command" "Building CommandContext for WebSocket command"
); );
let cmd_ctx = CommandContext::new("websocket", "cli") let cmd_ctx = CommandContext::new("websocket", "cli")
.with_session_id(current_session_id.as_str()) .with_session_id(current_session_id.as_str())
.with_chat_id(current_session_id.as_str()); .with_chat_id(current_session_id.as_str())
.with_topic_id(current_topic_id.as_deref().unwrap_or(""));
// 执行命令 // 执行命令
let response = router.dispatch_with_response(cmd, cmd_ctx).await; let response = router.dispatch_with_response(cmd, cmd_ctx).await;
@ -263,6 +268,15 @@ async fn handle_inbound(
) )
.await; .await;
} }
// 更新当前话题 ID如果是创建话题或切换话题
if let Some(topic_id) = response.metadata.get("topic_id") {
tracing::info!(
old_topic_id = ?current_topic_id,
new_topic_id = %topic_id,
"Updating current_topic_id"
);
*current_topic_id = Some(topic_id.clone());
}
} else if let Some(ref error) = response.error { } else if let Some(ref error) = response.error {
tracing::warn!( tracing::warn!(
error_code = %error.code, error_code = %error.code,