// Session management: a per-channel `Session` that keeps one message history per chat_id,
// and a `SessionManager` that creates sessions, recreates them once their TTL has elapsed,
// and routes incoming messages to the session's agent.
// NOTE(review): this file appears to have lost its line breaks and the contents of its
// generic angle brackets (e.g. `HashMap>`, `Result {`, `Arc,`). As laid out here, every `//`
// comment swallows the code that follows it on the same physical line. Restore the original
// layout and type parameters from version control before relying on this file.
use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::protocol::WsOutbound; use crate::providers::{create_provider, LLMProvider}; use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, ToolRegistry, WebFetchTool, }; /// Sessions are isolated per channel: one Session per channel. /// Histories are isolated per chat_id and managed by the Session. pub struct Session { pub id: Uuid, pub channel_name: String, /// Routes each chat_id to its own conversation history, supporting multiple users and conversations. chat_histories: HashMap>, pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, provider: Arc, tools: Arc, compressor: ContextCompressor, store: Arc, } impl Session { /// Builds a session for channel_name: creates the provider from provider_config and starts with no chat histories loaded. pub async fn new( channel_name: String, provider_config: LLMProviderConfig, user_tx: mpsc::Sender, tools: Arc, store: Arc, ) -> Result { let provider_box = create_provider(provider_config.clone()) .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; let provider: Arc = Arc::from(provider_box); Ok(Self { id: Uuid::new_v4(), channel_name, chat_histories: HashMap::new(), user_tx, provider_config: provider_config.clone(), provider: provider.clone(), tools, compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit), store, }) } pub fn persistent_session_id(&self, chat_id: &str) -> String { persistent_session_id(&self.channel_name, chat_id) } pub fn ensure_persistent_session(&self, chat_id: &str) -> Result { self.store .ensure_channel_session(&self.channel_name, chat_id) .map_err(|err| AgentError::Other(format!("session persistence error: {}", err))) } /// Loads the stored messages for chat_id into memory unless that chat is already loaded. pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> { if self.chat_histories.contains_key(chat_id) { return Ok(()); } let history = self .store 
.load_messages(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?; self.chat_histories.insert(chat_id.to_string(), history); Ok(()) } /// Gets or creates the conversation history for the given chat_id. pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec { self.chat_histories .entry(chat_id.to_string()) .or_insert_with(Vec::new) } /// Gets the conversation history for the given chat_id (without creating it). pub fn get_history(&self, chat_id: &str) -> Option<&Vec> { self.chat_histories.get(chat_id) } /// Appends a complete message to the in-memory history. pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) { let history = self.get_or_create_history(chat_id); history.push(message); } /// Drops the in-memory history for chat_id; the store is not touched. pub fn remove_history(&mut self, chat_id: &str) { self.chat_histories.remove(chat_id); } /// Empties the in-memory history for chat_id (if loaded), then calls clear_messages on the store. pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { if let Some(history) = self.chat_histories.get_mut(chat_id) { let len = history.len(); history.clear(); #[cfg(debug_assertions)] tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared"); } self.store .clear_messages(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err))) } /// Empties the in-memory history for chat_id (if loaded), then calls reset_session on the store. pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> { if let Some(history) = self.chat_histories.get_mut(chat_id) { let len = history.len(); history.clear(); #[cfg(debug_assertions)] tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory"); } self.store .reset_session(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err))) } /// Writes the message to the persistence layer first and, if that succeeds, to memory. pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> { let session_id = self.persistent_session_id(chat_id); self.store .append_message(&session_id, &message) .map_err(|err| AgentError::Other(format!("append message persistence error: {}", err)))?; self.add_message(chat_id, 
message); Ok(()) } pub fn append_persisted_messages(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError> where I: IntoIterator, { for message in messages { self.append_persisted_message(chat_id, message)?; } Ok(()) } pub fn create_user_message(&self, content: &str, media_refs: Vec) -> ChatMessage { if media_refs.is_empty() { ChatMessage::user(content) } else { ChatMessage::user_with_media(content, media_refs) } } /// Clears all history: empties every in-memory chat, then calls clear_messages on the store for each chat id that was loaded in memory. pub fn clear_all_history(&mut self) -> Result<(), AgentError> { let chat_ids: Vec = self.chat_histories.keys().cloned().collect(); let total: usize = self.chat_histories.values().map(|h| h.len()).sum(); self.chat_histories.clear(); #[cfg(debug_assertions)] tracing::debug!(previous_total = total, "All chat histories cleared"); for chat_id in chat_ids { self.store .clear_messages(&self.persistent_session_id(&chat_id)) .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?; } Ok(()) } /// Sends msg on user_tx; a failed send is ignored. pub async fn send(&self, msg: WsOutbound) { let _ = self.user_tx.send(msg).await; } /// Returns a reference to provider_config. pub fn provider_config(&self) -> &LLMProviderConfig { &self.provider_config } /// Returns a reference to the compressor. pub fn compressor(&self) -> &ContextCompressor { &self.compressor } /// Creates a temporary AgentLoop instance to process messages. pub fn create_agent(&self) -> Result { Ok(AgentLoop::with_provider_and_tools( self.provider.clone(), self.tools.clone(), self.provider_config.max_tool_iterations, )) } } /// SessionManager manages all Sessions, routing by channel_name. #[derive(Clone)] pub struct SessionManager { inner: Arc>, provider_config: LLMProviderConfig, tools: Arc, store: Arc, } struct SessionManagerInner { sessions: HashMap>>, session_timestamps: HashMap, session_ttl: Duration, } /// Builds the tool registry that SessionManager::new installs. fn default_tools() -> ToolRegistry { let mut registry = ToolRegistry::new(); registry.register(CalculatorTool::new()); registry.register(FileReadTool::new()); registry.register(FileWriteTool::new()); registry.register(FileEditTool::new()); registry.register(BashTool::new()); 
registry.register(HttpRequestTool::new( vec!["*".to_string()], // allow all domains; restricting this is recommended in real use 1_000_000, // max_response_size 30, // timeout_secs false, // allow_private_hosts )); registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs registry } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum InChatCommand { FreshConversation, } /// Recognizes "/new" and "/reset" (after trimming whitespace) as a request for a fresh conversation; anything else is not a command. fn parse_in_chat_command(content: &str) -> Option { match content.trim() { "/new" | "/reset" => Some(InChatCommand::FreshConversation), _ => None, } } /// Runs the in-chat command found in content, if any; returns Ok(None) when content is not a command. pub(crate) fn handle_in_chat_command( session: &mut Session, chat_id: &str, content: &str, ) -> Result, AgentError> { match parse_in_chat_command(content) { Some(InChatCommand::FreshConversation) => { session.reset_chat_context(chat_id)?; Ok(Some("Started a fresh conversation.".to_string())) } None => Ok(None), } } impl SessionManager { pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result { let store = Arc::new( SessionStore::new() .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?, ); Ok(Self { inner: Arc::new(Mutex::new(SessionManagerInner { sessions: HashMap::new(), session_timestamps: HashMap::new(), session_ttl: Duration::from_secs(session_ttl_hours * 3600), })), provider_config, tools: Arc::new(default_tools()), store, }) } pub fn tools(&self) -> Arc { self.tools.clone() } pub fn store(&self) -> Arc { self.store.clone() } pub fn create_cli_session(&self, title: Option<&str>) -> Result { self.store .create_cli_session(title) .map_err(|err| AgentError::Other(format!("create session error: {}", err))) } pub fn get_session_record(&self, session_id: &str) -> Result, AgentError> { self.store .get_session(session_id) .map_err(|err| AgentError::Other(format!("get session error: {}", err))) } pub fn list_cli_sessions(&self, include_archived: bool) -> Result, AgentError> { self.store .list_sessions("cli", include_archived) .map_err(|err| AgentError::Other(format!("list sessions error: {}", err))) } pub fn rename_session(&self, 
session_id: &str, title: &str) -> Result<(), AgentError> { self.store .rename_session(session_id, title) .map_err(|err| AgentError::Other(format!("rename session error: {}", err))) } pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> { self.store .archive_session(session_id) .map_err(|err| AgentError::Other(format!("archive session error: {}", err))) } pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> { self.store .delete_session(session_id) .map_err(|err| AgentError::Other(format!("delete session error: {}", err))) } pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> { self.store .clear_messages(session_id) .map_err(|err| AgentError::Other(format!("clear session error: {}", err))) } pub fn load_session_messages(&self, session_id: &str) -> Result, AgentError> { self.store .load_messages(session_id) .map_err(|err| AgentError::Other(format!("load messages error: {}", err))) } /// Ensures the session exists and has not timed out; rebuilds it if it is missing or its last-active time is older than session_ttl. pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { let mut inner = self.inner.lock().await; let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) { let elapsed = last_active.elapsed(); if elapsed > inner.session_ttl { tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating"); true } else { false } } else { #[cfg(debug_assertions)] tracing::debug!(channel = %channel_name, "Creating new session"); true }; if should_recreate { // Remove the old session inner.sessions.remove(channel_name); // Create a new session (with a temporary user_tx, because Feishu does not go through WS); the receiving end _rx is dropped when this block ends let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( channel_name.to_string(), self.provider_config.clone(), user_tx, self.tools.clone(), self.store.clone(), ) .await?; let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(channel_name.to_string(), arc.clone()); inner.session_timestamps.insert(channel_name.to_string(), 
Instant::now()); } Ok(()) } /// Gets the session (without checking for timeout). pub async fn get(&self, channel_name: &str) -> Option>> { let inner = self.inner.lock().await; inner.sessions.get(channel_name).cloned() } /// Updates the last-active time. pub async fn touch(&self, channel_name: &str) { let mut inner = self.inner.lock().await; inner.session_timestamps.insert(channel_name.to_string(), Instant::now()); } /// Handles a message: routes it to the agent of the corresponding session and returns the content of the final response (or the reply to an in-chat command). pub async fn handle_message( &self, channel_name: &str, _sender_id: &str, chat_id: &str, content: &str, media: Vec, ) -> Result { #[cfg(debug_assertions)] { tracing::debug!( channel = %channel_name, chat_id = %chat_id, content_len = content.len(), media_count = %media.len(), "Routing message to agent" ); for (i, m) in media.iter().enumerate() { tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message"); } } // Ensure the session exists (it may need to be rebuilt) self.ensure_session(channel_name).await?; // Update the last-active time self.touch(channel_name).await; // Get the session let session = self .get(channel_name) .await .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; // Process the message; the session lock is held for the whole block, including the awaits inside it let response = { let mut session_guard = session.lock().await; session_guard.ensure_persistent_session(chat_id)?; session_guard.ensure_chat_loaded(chat_id)?; if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? 
{ return Ok(command_response); } // Add the user message to the history let media_refs: Vec = media.iter().map(|m| m.path.clone()).collect(); #[cfg(debug_assertions)] if !media_refs.is_empty() { tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media"); } let user_message = session_guard.create_user_message(content, media_refs); session_guard.append_persisted_message(chat_id, user_message)?; // Get the full history (cloned, so the compressor works on a copy) let history = session_guard.get_or_create_history(chat_id).clone(); // Compress the history (if needed) let history = session_guard.compressor .compress_if_needed(history) .await?; // Create the agent and process let agent = session_guard.create_agent()?; let result = agent.process(history).await?; // Persist the assistant tool_calls, tool results and the final assistant reply in their real order session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; result.final_response }; #[cfg(debug_assertions)] tracing::debug!( channel = %channel_name, chat_id = %chat_id, response_len = response.content.len(), "Agent response received" ); Ok(response.content) } /// Clears all history of the given session; does nothing if no session exists for channel_name. pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> { if let Some(session) = self.get(channel_name).await { let mut session_guard = session.lock().await; session_guard.clear_all_history()?; } Ok(()) } } #[cfg(test)] mod tests { use super::*; use std::collections::HashMap; use tokio::sync::mpsc; fn test_provider_config() -> LLMProviderConfig { LLMProviderConfig { provider_type: "openai".to_string(), name: "test".to_string(), base_url: "http://localhost".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), model_extra: HashMap::new(), max_tool_iterations: 1, token_limit: 4096, } } #[test] fn test_parse_in_chat_command_aliases() { assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation)); assert_eq!(parse_in_chat_command(" /reset \n"), 
Some(InChatCommand::FreshConversation)); assert_eq!(parse_in_chat_command("/new planning"), None); assert_eq!(parse_in_chat_command("please /reset"), None); } #[tokio::test] async fn test_handle_in_chat_command_resets_active_history_only() { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let tools = Arc::new(default_tools()); let mut session = Session::new( "feishu".to_string(), test_provider_config(), user_tx, tools, store.clone(), ) .await .unwrap(); session.ensure_persistent_session("chat-1").unwrap(); session.ensure_chat_loaded("chat-1").unwrap(); session .append_persisted_message("chat-1", ChatMessage::user("hello")) .unwrap(); let response = handle_in_chat_command(&mut session, "chat-1", "/reset") .unwrap() .unwrap(); assert_eq!(response, "Started a fresh conversation."); assert!(session.get_history("chat-1").unwrap().is_empty()); assert!(store .load_messages(&session.persistent_session_id("chat-1")) .unwrap() .is_empty()); assert_eq!( store .load_all_messages(&session.persistent_session_id("chat-1")) .unwrap() .len(), 1, ); } }