feat: 添加话题管理功能,支持切换和持久化话题历史
This commit is contained in:
parent
2e13f6932c
commit
e709773464
@ -3,7 +3,7 @@ use crate::command::handler::CommandHandler;
|
||||
use crate::command::response::{CommandError, CommandResponse, MessageKind};
|
||||
use crate::command::Command;
|
||||
use crate::gateway::session::SessionManager;
|
||||
use crate::storage::{SessionStore, TopicRecord};
|
||||
use crate::storage::SessionStore;
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -219,9 +219,11 @@ async fn handle_switch_session(
|
||||
topic_id: String,
|
||||
ctx: CommandContext,
|
||||
) -> Result<CommandResponse, CommandError> {
|
||||
// 获取当前 session_id
|
||||
// 获取当前 session_id 和 chat_id
|
||||
let session_id = ctx.session_id.as_deref()
|
||||
.ok_or_else(|| CommandError::new("NO_SESSION", "No active session"))?;
|
||||
let chat_id = ctx.chat_id.as_deref()
|
||||
.ok_or_else(|| CommandError::new("NO_CHAT_ID", "No chat_id in context"))?;
|
||||
|
||||
// 尝试解析为序号
|
||||
let target_topic_id = if let Ok(index) = topic_id.parse::<usize>() {
|
||||
@ -249,6 +251,15 @@ async fn handle_switch_session(
|
||||
.map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))?
|
||||
.ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", target_topic_id)))?;
|
||||
|
||||
// 如果有 SessionManager,实际切换话题历史
|
||||
if let Some(ref session_manager) = handler.session_manager {
|
||||
if let Some(session) = session_manager.get(&ctx.channel_name).await {
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.switch_topic(chat_id, &target_topic_id)
|
||||
.map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
// 返回切换成功响应
|
||||
let message = format!(
|
||||
"✓ Switched to topic: {} ({} messages)",
|
||||
|
||||
@ -102,6 +102,7 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
agent_factory,
|
||||
conversations,
|
||||
skill_events,
|
||||
store.clone(),
|
||||
chat_history_ttl_hours,
|
||||
);
|
||||
let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours);
|
||||
|
||||
@ -41,6 +41,8 @@ pub struct Session {
|
||||
agent_factory: AgentFactory,
|
||||
compressor: ContextCompressor,
|
||||
history: SessionHistory,
|
||||
store: Arc<SessionStore>,
|
||||
current_topic_id: Option<String>,
|
||||
}
|
||||
|
||||
pub struct BusToolCallEmitter {
|
||||
@ -120,6 +122,7 @@ impl Session {
|
||||
conversations,
|
||||
skill_events,
|
||||
chat_history_ttl_hours,
|
||||
store,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@ -133,6 +136,7 @@ impl Session {
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
store: Arc<SessionStore>,
|
||||
) -> Result<Self, AgentError> {
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
@ -148,6 +152,8 @@ impl Session {
|
||||
skill_events,
|
||||
chat_history_ttl_hours,
|
||||
),
|
||||
store,
|
||||
current_topic_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
@ -155,6 +161,38 @@ impl Session {
|
||||
self.history.persistent_session_id(chat_id)
|
||||
}
|
||||
|
||||
/// 设置当前话题 ID
|
||||
pub fn set_current_topic(&mut self, topic_id: Option<String>) {
|
||||
self.current_topic_id = topic_id;
|
||||
}
|
||||
|
||||
/// 获取当前话题 ID
|
||||
pub fn current_topic(&self) -> Option<&str> {
|
||||
self.current_topic_id.as_deref()
|
||||
}
|
||||
|
||||
/// 切换话题 - 清除当前历史并加载新话题的历史
|
||||
pub fn switch_topic(&mut self, chat_id: &str, topic_id: &str) -> Result<(), AgentError> {
|
||||
// 清除当前历史
|
||||
self.history.remove_history(chat_id);
|
||||
|
||||
// 加载新话题的历史
|
||||
let messages = self
|
||||
.store
|
||||
.load_messages_for_topic(topic_id)
|
||||
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||||
|
||||
self.history.set_history(chat_id, messages);
|
||||
self.current_topic_id = Some(topic_id.to_string());
|
||||
|
||||
tracing::info!(
|
||||
topic_id = %topic_id,
|
||||
chat_id = %chat_id,
|
||||
"Switched to topic"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
|
||||
self.history.ensure_persistent_session(chat_id)
|
||||
}
|
||||
@ -198,13 +236,28 @@ impl Session {
|
||||
self.history.reset_chat_context(chat_id)
|
||||
}
|
||||
|
||||
/// 将消息写入内存与持久化层
|
||||
/// 将消息写入内存与持久化层(使用当前 topic)
|
||||
pub fn append_persisted_message(
|
||||
&mut self,
|
||||
chat_id: &str,
|
||||
message: ChatMessage,
|
||||
) -> Result<(), AgentError> {
|
||||
self.history.append_persisted_message(chat_id, message)
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
self.store
|
||||
.append_message_with_topic(&session_id, self.current_topic_id.as_deref(), &message)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("append message persistence error: {}", err))
|
||||
})?;
|
||||
self.add_message(chat_id, message);
|
||||
|
||||
// 更新 topic 的最后活跃时间
|
||||
if let Some(topic_id) = &self.current_topic_id {
|
||||
if let Err(e) = self.store.touch_topic(topic_id) {
|
||||
tracing::warn!(error = %e, topic_id = %topic_id, "Failed to touch topic");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn append_persisted_messages<I>(
|
||||
@ -283,7 +336,18 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
self.history.reload_chat_history(chat_id)
|
||||
// 如果当前有 topic,加载该 topic 的消息
|
||||
if let Some(topic_id) = &self.current_topic_id {
|
||||
let messages = self
|
||||
.store
|
||||
.load_messages_for_topic(topic_id)
|
||||
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||||
self.history.set_history(chat_id, messages);
|
||||
} else {
|
||||
// 否则加载 session 的所有消息
|
||||
self.history.reload_chat_history(chat_id)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
|
||||
|
||||
@ -6,7 +6,7 @@ use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{ConversationRepository, SkillEventRepository};
|
||||
use crate::storage::{ConversationRepository, SessionStore, SkillEventRepository};
|
||||
|
||||
use super::agent_factory::AgentFactory;
|
||||
use super::session::Session;
|
||||
@ -18,6 +18,7 @@ pub(crate) struct SessionFactory {
|
||||
agent_factory: AgentFactory,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
store: Arc<SessionStore>,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
}
|
||||
|
||||
@ -28,6 +29,7 @@ impl SessionFactory {
|
||||
agent_factory: AgentFactory,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
store: Arc<SessionStore>,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
) -> Self {
|
||||
Self {
|
||||
@ -36,6 +38,7 @@ impl SessionFactory {
|
||||
agent_factory,
|
||||
conversations,
|
||||
skill_events,
|
||||
store,
|
||||
chat_history_ttl_hours,
|
||||
}
|
||||
}
|
||||
@ -54,6 +57,7 @@ impl SessionFactory {
|
||||
self.conversations.clone(),
|
||||
self.skill_events.clone(),
|
||||
self.chat_history_ttl_hours,
|
||||
self.store.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@ -113,6 +113,10 @@ impl SessionHistory {
|
||||
self.chat_histories.get(chat_id)
|
||||
}
|
||||
|
||||
pub(crate) fn set_history(&mut self, chat_id: &str, history: Vec<ChatMessage>) {
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
}
|
||||
|
||||
pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||||
self.get_or_create_history(chat_id).push(message);
|
||||
}
|
||||
|
||||
@ -41,6 +41,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
|
||||
let runtime_session_id = uuid::Uuid::new_v4().to_string();
|
||||
let mut current_session_id = initial_record.id.clone();
|
||||
let mut current_topic_id: Option<String> = None;
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
@ -85,6 +86,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
&sender,
|
||||
&runtime_session_id,
|
||||
&mut current_session_id,
|
||||
&mut current_topic_id,
|
||||
inbound,
|
||||
)
|
||||
.await
|
||||
@ -132,6 +134,7 @@ async fn handle_inbound(
|
||||
sender: &mpsc::Sender<WsOutbound>,
|
||||
runtime_session_id: &str,
|
||||
current_session_id: &mut String,
|
||||
current_topic_id: &mut Option<String>,
|
||||
inbound: WsInbound,
|
||||
) -> Result<(), AgentError> {
|
||||
match inbound {
|
||||
@ -204,7 +207,7 @@ async fn handle_inbound(
|
||||
};
|
||||
|
||||
// 创建命令路由器
|
||||
let cli_sessions = state.session_manager.cli_sessions();
|
||||
let _cli_sessions = state.session_manager.cli_sessions();
|
||||
let store = state.session_manager.store();
|
||||
let skills = state.session_manager.skills();
|
||||
let provider_config = state.config.get_provider_config("default")
|
||||
@ -234,11 +237,13 @@ async fn handle_inbound(
|
||||
// 构建命令上下文
|
||||
tracing::debug!(
|
||||
current_session_id = %current_session_id,
|
||||
current_topic_id = ?current_topic_id,
|
||||
"Building CommandContext for WebSocket command"
|
||||
);
|
||||
let cmd_ctx = CommandContext::new("websocket", "cli")
|
||||
.with_session_id(current_session_id.as_str())
|
||||
.with_chat_id(current_session_id.as_str());
|
||||
.with_chat_id(current_session_id.as_str())
|
||||
.with_topic_id(current_topic_id.as_deref().unwrap_or(""));
|
||||
|
||||
// 执行命令
|
||||
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
||||
@ -263,6 +268,15 @@ async fn handle_inbound(
|
||||
)
|
||||
.await;
|
||||
}
|
||||
// 更新当前话题 ID(如果是创建话题或切换话题)
|
||||
if let Some(topic_id) = response.metadata.get("topic_id") {
|
||||
tracing::info!(
|
||||
old_topic_id = ?current_topic_id,
|
||||
new_topic_id = %topic_id,
|
||||
"Updating current_topic_id"
|
||||
);
|
||||
*current_topic_id = Some(topic_id.clone());
|
||||
}
|
||||
} else if let Some(ref error) = response.error {
|
||||
tracing::warn!(
|
||||
error_code = %error.code,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user