feat: 添加话题管理功能,支持切换和持久化话题历史
This commit is contained in:
parent
2e13f6932c
commit
e709773464
@ -3,7 +3,7 @@ use crate::command::handler::CommandHandler;
|
|||||||
use crate::command::response::{CommandError, CommandResponse, MessageKind};
|
use crate::command::response::{CommandError, CommandResponse, MessageKind};
|
||||||
use crate::command::Command;
|
use crate::command::Command;
|
||||||
use crate::gateway::session::SessionManager;
|
use crate::gateway::session::SessionManager;
|
||||||
use crate::storage::{SessionStore, TopicRecord};
|
use crate::storage::SessionStore;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@ -219,9 +219,11 @@ async fn handle_switch_session(
|
|||||||
topic_id: String,
|
topic_id: String,
|
||||||
ctx: CommandContext,
|
ctx: CommandContext,
|
||||||
) -> Result<CommandResponse, CommandError> {
|
) -> Result<CommandResponse, CommandError> {
|
||||||
// 获取当前 session_id
|
// 获取当前 session_id 和 chat_id
|
||||||
let session_id = ctx.session_id.as_deref()
|
let session_id = ctx.session_id.as_deref()
|
||||||
.ok_or_else(|| CommandError::new("NO_SESSION", "No active session"))?;
|
.ok_or_else(|| CommandError::new("NO_SESSION", "No active session"))?;
|
||||||
|
let chat_id = ctx.chat_id.as_deref()
|
||||||
|
.ok_or_else(|| CommandError::new("NO_CHAT_ID", "No chat_id in context"))?;
|
||||||
|
|
||||||
// 尝试解析为序号
|
// 尝试解析为序号
|
||||||
let target_topic_id = if let Ok(index) = topic_id.parse::<usize>() {
|
let target_topic_id = if let Ok(index) = topic_id.parse::<usize>() {
|
||||||
@ -249,6 +251,15 @@ async fn handle_switch_session(
|
|||||||
.map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))?
|
.map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))?
|
||||||
.ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", target_topic_id)))?;
|
.ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", target_topic_id)))?;
|
||||||
|
|
||||||
|
// 如果有 SessionManager,实际切换话题历史
|
||||||
|
if let Some(ref session_manager) = handler.session_manager {
|
||||||
|
if let Some(session) = session_manager.get(&ctx.channel_name).await {
|
||||||
|
let mut session_guard = session.lock().await;
|
||||||
|
session_guard.switch_topic(chat_id, &target_topic_id)
|
||||||
|
.map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 返回切换成功响应
|
// 返回切换成功响应
|
||||||
let message = format!(
|
let message = format!(
|
||||||
"✓ Switched to topic: {} ({} messages)",
|
"✓ Switched to topic: {} ({} messages)",
|
||||||
|
|||||||
@ -102,6 +102,7 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
agent_factory,
|
agent_factory,
|
||||||
conversations,
|
conversations,
|
||||||
skill_events,
|
skill_events,
|
||||||
|
store.clone(),
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
);
|
);
|
||||||
let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours);
|
let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours);
|
||||||
|
|||||||
@ -41,6 +41,8 @@ pub struct Session {
|
|||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
compressor: ContextCompressor,
|
compressor: ContextCompressor,
|
||||||
history: SessionHistory,
|
history: SessionHistory,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
current_topic_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct BusToolCallEmitter {
|
pub struct BusToolCallEmitter {
|
||||||
@ -120,6 +122,7 @@ impl Session {
|
|||||||
conversations,
|
conversations,
|
||||||
skill_events,
|
skill_events,
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
|
store,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
@ -133,6 +136,7 @@ impl Session {
|
|||||||
conversations: Arc<dyn ConversationRepository>,
|
conversations: Arc<dyn ConversationRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
@ -148,6 +152,8 @@ impl Session {
|
|||||||
skill_events,
|
skill_events,
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
),
|
),
|
||||||
|
store,
|
||||||
|
current_topic_id: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,6 +161,38 @@ impl Session {
|
|||||||
self.history.persistent_session_id(chat_id)
|
self.history.persistent_session_id(chat_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 设置当前话题 ID
|
||||||
|
pub fn set_current_topic(&mut self, topic_id: Option<String>) {
|
||||||
|
self.current_topic_id = topic_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取当前话题 ID
|
||||||
|
pub fn current_topic(&self) -> Option<&str> {
|
||||||
|
self.current_topic_id.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 切换话题 - 清除当前历史并加载新话题的历史
|
||||||
|
pub fn switch_topic(&mut self, chat_id: &str, topic_id: &str) -> Result<(), AgentError> {
|
||||||
|
// 清除当前历史
|
||||||
|
self.history.remove_history(chat_id);
|
||||||
|
|
||||||
|
// 加载新话题的历史
|
||||||
|
let messages = self
|
||||||
|
.store
|
||||||
|
.load_messages_for_topic(topic_id)
|
||||||
|
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||||||
|
|
||||||
|
self.history.set_history(chat_id, messages);
|
||||||
|
self.current_topic_id = Some(topic_id.to_string());
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
topic_id = %topic_id,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
"Switched to topic"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
|
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
|
||||||
self.history.ensure_persistent_session(chat_id)
|
self.history.ensure_persistent_session(chat_id)
|
||||||
}
|
}
|
||||||
@ -198,13 +236,28 @@ impl Session {
|
|||||||
self.history.reset_chat_context(chat_id)
|
self.history.reset_chat_context(chat_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 将消息写入内存与持久化层
|
/// 将消息写入内存与持久化层(使用当前 topic)
|
||||||
pub fn append_persisted_message(
|
pub fn append_persisted_message(
|
||||||
&mut self,
|
&mut self,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
message: ChatMessage,
|
message: ChatMessage,
|
||||||
) -> Result<(), AgentError> {
|
) -> Result<(), AgentError> {
|
||||||
self.history.append_persisted_message(chat_id, message)
|
let session_id = self.persistent_session_id(chat_id);
|
||||||
|
self.store
|
||||||
|
.append_message_with_topic(&session_id, self.current_topic_id.as_deref(), &message)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!("append message persistence error: {}", err))
|
||||||
|
})?;
|
||||||
|
self.add_message(chat_id, message);
|
||||||
|
|
||||||
|
// 更新 topic 的最后活跃时间
|
||||||
|
if let Some(topic_id) = &self.current_topic_id {
|
||||||
|
if let Err(e) = self.store.touch_topic(topic_id) {
|
||||||
|
tracing::warn!(error = %e, topic_id = %topic_id, "Failed to touch topic");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn append_persisted_messages<I>(
|
pub fn append_persisted_messages<I>(
|
||||||
@ -283,7 +336,18 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
self.history.reload_chat_history(chat_id)
|
// 如果当前有 topic,加载该 topic 的消息
|
||||||
|
if let Some(topic_id) = &self.current_topic_id {
|
||||||
|
let messages = self
|
||||||
|
.store
|
||||||
|
.load_messages_for_topic(topic_id)
|
||||||
|
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||||||
|
self.history.set_history(chat_id, messages);
|
||||||
|
} else {
|
||||||
|
// 否则加载 session 的所有消息
|
||||||
|
self.history.reload_chat_history(chat_id)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
|
pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
|
||||||
|
|||||||
@ -6,7 +6,7 @@ use crate::agent::AgentError;
|
|||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{ConversationRepository, SkillEventRepository};
|
use crate::storage::{ConversationRepository, SessionStore, SkillEventRepository};
|
||||||
|
|
||||||
use super::agent_factory::AgentFactory;
|
use super::agent_factory::AgentFactory;
|
||||||
use super::session::Session;
|
use super::session::Session;
|
||||||
@ -18,6 +18,7 @@ pub(crate) struct SessionFactory {
|
|||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
conversations: Arc<dyn ConversationRepository>,
|
conversations: Arc<dyn ConversationRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,6 +29,7 @@ impl SessionFactory {
|
|||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
conversations: Arc<dyn ConversationRepository>,
|
conversations: Arc<dyn ConversationRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -36,6 +38,7 @@ impl SessionFactory {
|
|||||||
agent_factory,
|
agent_factory,
|
||||||
conversations,
|
conversations,
|
||||||
skill_events,
|
skill_events,
|
||||||
|
store,
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -54,6 +57,7 @@ impl SessionFactory {
|
|||||||
self.conversations.clone(),
|
self.conversations.clone(),
|
||||||
self.skill_events.clone(),
|
self.skill_events.clone(),
|
||||||
self.chat_history_ttl_hours,
|
self.chat_history_ttl_hours,
|
||||||
|
self.store.clone(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|||||||
@ -113,6 +113,10 @@ impl SessionHistory {
|
|||||||
self.chat_histories.get(chat_id)
|
self.chat_histories.get(chat_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_history(&mut self, chat_id: &str, history: Vec<ChatMessage>) {
|
||||||
|
self.chat_histories.insert(chat_id.to_string(), history);
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||||||
self.get_or_create_history(chat_id).push(message);
|
self.get_or_create_history(chat_id).push(message);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -41,6 +41,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
|
|
||||||
let runtime_session_id = uuid::Uuid::new_v4().to_string();
|
let runtime_session_id = uuid::Uuid::new_v4().to_string();
|
||||||
let mut current_session_id = initial_record.id.clone();
|
let mut current_session_id = initial_record.id.clone();
|
||||||
|
let mut current_topic_id: Option<String> = None;
|
||||||
state
|
state
|
||||||
.channel_manager
|
.channel_manager
|
||||||
.cli_channel()
|
.cli_channel()
|
||||||
@ -85,6 +86,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
&sender,
|
&sender,
|
||||||
&runtime_session_id,
|
&runtime_session_id,
|
||||||
&mut current_session_id,
|
&mut current_session_id,
|
||||||
|
&mut current_topic_id,
|
||||||
inbound,
|
inbound,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@ -132,6 +134,7 @@ async fn handle_inbound(
|
|||||||
sender: &mpsc::Sender<WsOutbound>,
|
sender: &mpsc::Sender<WsOutbound>,
|
||||||
runtime_session_id: &str,
|
runtime_session_id: &str,
|
||||||
current_session_id: &mut String,
|
current_session_id: &mut String,
|
||||||
|
current_topic_id: &mut Option<String>,
|
||||||
inbound: WsInbound,
|
inbound: WsInbound,
|
||||||
) -> Result<(), AgentError> {
|
) -> Result<(), AgentError> {
|
||||||
match inbound {
|
match inbound {
|
||||||
@ -204,7 +207,7 @@ async fn handle_inbound(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// 创建命令路由器
|
// 创建命令路由器
|
||||||
let cli_sessions = state.session_manager.cli_sessions();
|
let _cli_sessions = state.session_manager.cli_sessions();
|
||||||
let store = state.session_manager.store();
|
let store = state.session_manager.store();
|
||||||
let skills = state.session_manager.skills();
|
let skills = state.session_manager.skills();
|
||||||
let provider_config = state.config.get_provider_config("default")
|
let provider_config = state.config.get_provider_config("default")
|
||||||
@ -234,11 +237,13 @@ async fn handle_inbound(
|
|||||||
// 构建命令上下文
|
// 构建命令上下文
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
current_session_id = %current_session_id,
|
current_session_id = %current_session_id,
|
||||||
|
current_topic_id = ?current_topic_id,
|
||||||
"Building CommandContext for WebSocket command"
|
"Building CommandContext for WebSocket command"
|
||||||
);
|
);
|
||||||
let cmd_ctx = CommandContext::new("websocket", "cli")
|
let cmd_ctx = CommandContext::new("websocket", "cli")
|
||||||
.with_session_id(current_session_id.as_str())
|
.with_session_id(current_session_id.as_str())
|
||||||
.with_chat_id(current_session_id.as_str());
|
.with_chat_id(current_session_id.as_str())
|
||||||
|
.with_topic_id(current_topic_id.as_deref().unwrap_or(""));
|
||||||
|
|
||||||
// 执行命令
|
// 执行命令
|
||||||
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
||||||
@ -263,6 +268,15 @@ async fn handle_inbound(
|
|||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
// 更新当前话题 ID(如果是创建话题或切换话题)
|
||||||
|
if let Some(topic_id) = response.metadata.get("topic_id") {
|
||||||
|
tracing::info!(
|
||||||
|
old_topic_id = ?current_topic_id,
|
||||||
|
new_topic_id = %topic_id,
|
||||||
|
"Updating current_topic_id"
|
||||||
|
);
|
||||||
|
*current_topic_id = Some(topic_id.clone());
|
||||||
|
}
|
||||||
} else if let Some(ref error) = response.error {
|
} else if let Some(ref error) = response.error {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
error_code = %error.code,
|
error_code = %error.code,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user