use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::agent::system_prompt::build_system_prompt; use crate::agent::context_compressor::ContextCompressionConfig; use crate::protocol::WsOutbound; use crate::providers::{create_provider, LLMProvider}; use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID}; use crate::session::events::DialogInfo; use crate::skills::SkillsLoader; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool, }; /// Generate a short ID (8 characters) from a UUID fn short_id() -> String { Uuid::new_v4().to_string()[..8].to_string() } /// Session = 一个 dialog /// 每个 Session 对应一个 UnifiedSessionId,有独立的 messages history pub struct Session { pub id: UnifiedSessionId, messages: Vec, pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, provider: Arc, tools: Arc, compressor: ContextCompressor, } impl Session { pub async fn new( id: UnifiedSessionId, provider_config: LLMProviderConfig, user_tx: mpsc::Sender, tools: Arc, ) -> Result { let provider_box = create_provider(provider_config.clone()) .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; let provider: Arc = Arc::from(provider_box); let compressor_config = ContextCompressionConfig { protect_first_n: 2, ..Default::default() }; Ok(Self { id, messages: Vec::new(), user_tx, provider_config: provider_config.clone(), provider: provider.clone(), tools, compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config), }) } /// 获取 session ID pub fn session_id(&self) -> String { self.id.to_string() } /// 添加消息到历史 pub fn add_message(&mut self, message: ChatMessage) { self.messages.push(message); } /// 获取消息历史 pub fn get_history(&self) -> &[ChatMessage] { &self.messages } /// 清除历史消息 pub fn clear_history(&mut self) { let len = self.messages.len(); self.messages.clear(); #[cfg(debug_assertions)] tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared"); } /// 重置对话上下文 pub fn reset_context(&mut self) { let len = self.messages.len(); self.messages.clear(); #[cfg(debug_assertions)] tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory"); } pub fn create_user_message(&self, content: &str, media_refs: Vec) -> ChatMessage { if media_refs.is_empty() { ChatMessage::user(content) } else { ChatMessage::user_with_media(content, media_refs) } } pub async fn send(&self, msg: WsOutbound) { let _ = self.user_tx.send(msg).await; } /// 获取 provider_config 引用 pub fn provider_config(&self) -> &LLMProviderConfig { &self.provider_config } /// 获取 compressor 引用 pub fn compressor(&self) -> &ContextCompressor { &self.compressor } /// 创建一个临时的 AgentLoop 实例来处理消息 pub fn create_agent(&self) -> Result { Ok(AgentLoop::with_provider_and_tools( self.provider.clone(), self.tools.clone(), self.provider_config.max_tool_iterations, self.provider_config.model_id.clone(), self.provider_config.workspace_dir.clone(), )) } /// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills) pub fn build_system_prompt(&self, skills_prompt: &str) -> String { let base_prompt = build_system_prompt( &self.provider_config.workspace_dir, &self.provider_config.model_id, &self.tools, ); if skills_prompt.trim().is_empty() { base_prompt } else { format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt) } } /// 将当前 session 导出为 markdown 文档 pub fn dump_as_markdown(&self) -> String { use chrono::{DateTime, Local}; let now = Local::now().format("%Y-%m-%d %H:%M:%S"); let mut md = String::new(); md.push_str(&format!("# Session Dump\n\n")); md.push_str(&format!("- **Session ID**: `{}`\n", self.id)); md.push_str(&format!("- **Channel**: `{}`\n", self.id.channel)); md.push_str(&format!("- **Chat ID**: `{}`\n", self.id.chat_id)); md.push_str(&format!("- **Dialog ID**: `{}`\n", self.id.dialog_id)); md.push_str(&format!("- **Message Count**: {}\n", self.messages.len())); md.push_str(&format!("- **Model**: `{}`\n", self.provider_config.model_id)); md.push_str(&format!("- **Exported At**: {}\n", now)); md.push_str("\n---\n\n"); md.push_str("## Conversation History\n\n"); for (i, msg) in self.messages.iter().enumerate() { let role = match msg.role.as_str() { "system" => "System", "user" => "User", "assistant" => "Assistant", "tool" => "Tool", r => r, }; let timestamp = if msg.timestamp > 0 { DateTime::from_timestamp_millis(msg.timestamp) .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string()) .unwrap_or_default() } else { String::new() }; md.push_str(&format!("### [{:03}] {} {}\n\n", i + 1, role, timestamp)); md.push_str("```\n"); if let Some(ref tool_calls) = msg.tool_calls { md.push_str(&format!("[Tool Calls]\n")); for tc in tool_calls { md.push_str(&format!("- {}: {:?}\n", tc.name, tc.arguments)); } } if let Some(ref tool_name) = msg.tool_name { md.push_str(&format!("[Tool: {}]\n", tool_name)); } if let Some(ref tool_call_id) = msg.tool_call_id { md.push_str(&format!("[Tool Call ID: {}]\n", tool_call_id)); } md.push_str(&msg.content); md.push_str("\n```\n\n"); if !msg.media_refs.is_empty() { md.push_str(&format!("**Media**: {:?}\n\n", msg.media_refs)); } } md } } /// SessionManager 管理所有 Session,按 channel_name 路由 #[derive(Clone)] pub struct SessionManager { inner: Arc>, provider_config: LLMProviderConfig, tools: Arc, skills_loader: Arc, } struct SessionManagerInner { /// Sessions keyed by UnifiedSessionId.to_string() sessions: HashMap>>, session_timestamps: HashMap, session_ttl: Duration, } fn create_default_tools(skills_loader: Arc) -> ToolRegistry { let mut registry = ToolRegistry::new(); registry.register(CalculatorTool::new()); registry.register(FileReadTool::new()); registry.register(FileWriteTool::new()); registry.register(FileEditTool::new()); registry.register(BashTool::new()); registry.register(HttpRequestTool::new( vec!["*".to_string()], 1_000_000, 30, false, )); registry.register(WebFetchTool::new(50_000, 30)); registry.register(GetSkillTool::new(skills_loader)); registry } /// 斜杠命令定义 #[derive(Debug, Clone)] pub struct SlashCommand { /// 命令名称 pub name: &'static str, /// 命令描述 pub description: &'static str, /// 命令别名(触发词) pub aliases: &'static [&'static str], } impl SlashCommand { /// 检查给定内容是否匹配此命令 pub fn matches(&self, content: &str) -> bool { let trimmed = content.trim(); self.aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias))) } } /// Session 支持的斜杠命令列表 pub static SLASH_COMMANDS: &[SlashCommand] = &[ SlashCommand { name: "new", description: "Create a new conversation", aliases: &["/new"], }, SlashCommand { name: "delete", description: "Delete current conversation and start a new one", aliases: &["/delete"], }, SlashCommand { name: "compact", description: "Manually trigger context compression", aliases: &["/compact"], }, SlashCommand { name: "info", description: "Print current session information", aliases: &["/info"], }, SlashCommand { name: "dump", description: "Save current session as markdown document", aliases: &["/dump"], }, ]; impl SessionManager { pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result { let skills_loader = SkillsLoader::new(); skills_loader.load_skills(); let skills_loader = Arc::new(skills_loader); let tools = Arc::new(create_default_tools(skills_loader.clone())); Ok(Self { inner: Arc::new(Mutex::new(SessionManagerInner { sessions: HashMap::new(), session_timestamps: HashMap::new(), session_ttl: Duration::from_secs(session_ttl_hours * 3600), })), provider_config, tools, skills_loader, }) } pub fn tools(&self) -> Arc { self.tools.clone() } /// 获取所有可用的斜杠命令 pub fn get_slash_commands(&self) -> &[SlashCommand] { SLASH_COMMANDS } /// 执行斜杠命令 /// 返回 (新session_id, 响应消息) pub async fn execute_slash_command( &self, command: &str, args: Option<&str>, channel: &str, chat_id: &str, current_session_id: Option<&UnifiedSessionId>, ) -> Result<(Option, String), AgentError> { let cmd = SLASH_COMMANDS .iter() .find(|c| c.name == command) .ok_or_else(|| AgentError::Other(format!("Unknown command: {}", command)))?; tracing::info!(cmd = %cmd.name, args = ?args, "Executing slash command"); match cmd.name { "new" => { let title = args.map(|s| s.to_string()); let (new_id, title) = self.create_session(channel, chat_id, title.as_deref()).await?; Ok((Some(new_id), format!("New conversation '{}' created.", title))) } "delete" => { let (new_id, _title) = self.create_session(channel, chat_id, None).await?; Ok((Some(new_id), "Conversation deleted. New conversation created.".to_string())) } "compact" => { if let Some(sid) = current_session_id { let session = self.get_or_create_session(sid).await?; let mut session_guard = session.lock().await; let original_count = session_guard.get_history().len(); let history = session_guard.get_history().to_vec(); let compressed = session_guard.compressor .compress_if_needed(history) .await?; let compressed_count = compressed.len(); session_guard.clear_history(); for msg in compressed { session_guard.add_message(msg); } Ok((None, format!( "Context compressed: {} → {} messages.", original_count, compressed_count ))) } else { Ok((None, "No active conversation to compress.".to_string())) } } "info" => { if let Some(sid) = current_session_id { let session = self.get_or_create_session(sid).await?; let session_guard = session.lock().await; let message_count = session_guard.get_history().len(); let session_id_str = session_guard.session_id(); Ok((None, format!( "Session ID: {}\nMessage count: {}", session_id_str, message_count ))) } else { Ok((None, "No active session.".to_string())) } } "dump" => { if let Some(sid) = current_session_id { let session = self.get_or_create_session(sid).await?; let session_guard = session.lock().await; let md = session_guard.dump_as_markdown(); Ok((None, md)) } else { Ok((None, "No active session.".to_string())) } } _ => Err(AgentError::Other(format!("Command not implemented: {}", cmd.name))), } } pub async fn create_session( &self, channel: &str, chat_id: &str, title: Option<&str>, ) -> Result<(UnifiedSessionId, String), AgentError> { let dialog_id = short_id(); let unified_id = UnifiedSessionId::new(channel, chat_id, &dialog_id); let session_id_str = unified_id.to_string(); let title = title .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned) .unwrap_or_else(|| format!("Dialog {}", &dialog_id)); let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( unified_id.clone(), self.provider_config.clone(), user_tx, self.tools.clone(), ).await?; let arc = Arc::new(Mutex::new(session)); let inner = &mut *self.inner.lock().await; inner.sessions.insert(session_id_str.clone(), arc.clone()); inner.session_timestamps.insert(session_id_str, Instant::now()); Ok((unified_id, title)) } pub async fn get_or_create_session(&self, unified_id: &UnifiedSessionId) -> Result>, AgentError> { let session_id_str = unified_id.to_string(); let inner = &mut *self.inner.lock().await; if let Some(session) = inner.sessions.get(&session_id_str) { inner.session_timestamps.insert(session_id_str, Instant::now()); return Ok(session.clone()); } // Create new session let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( unified_id.clone(), self.provider_config.clone(), user_tx, self.tools.clone(), ).await?; let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(session_id_str.clone(), arc.clone()); inner.session_timestamps.insert(session_id_str, Instant::now()); Ok(arc) } pub async fn create_dialog( &self, channel: &str, chat_id: &str, title: Option<&str>, ) -> Result<(UnifiedSessionId, String), AgentError> { self.create_session(channel, chat_id, title).await } pub async fn get_current_dialog( &self, _channel: &str, _chat_id: &str, ) -> Result, AgentError> { Ok(None) } pub async fn switch_dialog( &self, _channel: &str, _chat_id: &str, _dialog_id: &str, ) -> Result { Err(AgentError::Other("switch_dialog not applicable in new architecture".to_string())) } pub async fn list_dialogs( &self, _channel: &str, _chat_id: &str, _include_archived: bool, ) -> Result<(Vec, Option), AgentError> { Ok((vec![], None)) } pub fn rename_dialog(&self, _session_id: &UnifiedSessionId, _title: &str) -> Result<(), AgentError> { Err(AgentError::Other("rename_dialog not available".to_string())) } pub fn archive_dialog(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> { Err(AgentError::Other("archive_dialog not available".to_string())) } pub fn delete_dialog(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> { Err(AgentError::Other("delete_dialog not available".to_string())) } pub fn clear_dialog_history(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> { Err(AgentError::Other("clear_dialog_history not available".to_string())) } pub async fn handle_message( &self, channel: &str, _sender_id: &str, chat_id: &str, dialog_id: Option<&str>, content: &str, media: Vec, ) -> Result { let dialog_id = dialog_id.unwrap_or(DEFAULT_DIALOG_ID); let unified_id = UnifiedSessionId::new(channel, chat_id, dialog_id); let session = self.get_or_create_session(&unified_id).await?; let response: String = { let mut session_guard = session.lock().await; let media_refs: Vec = media.iter().map(|m| m.path.clone()).collect(); #[cfg(debug_assertions)] if !media_refs.is_empty() { tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media"); } let user_message = session_guard.create_user_message(content, media_refs); session_guard.add_message(user_message); let mut history = session_guard.get_history().to_vec(); // Build skills prompt let skills_prompt = self.skills_loader.build_skills_prompt(); // Build combined system prompt and inject at position 0 // This ensures AgentLoop.process() sees a system message and doesn't inject its own let system_prompt = session_guard.build_system_prompt(&skills_prompt); history.insert(0, ChatMessage::system(system_prompt)); let history = session_guard.compressor .compress_if_needed(history) .await?; let agent = session_guard.create_agent()?; let result = agent.process(history).await?; for msg in result.emitted_messages { session_guard.add_message(msg); } result.final_response.content }; #[cfg(debug_assertions)] tracing::debug!( channel = %channel, chat_id = %chat_id, response_len = %response.len(), "Agent response received" ); Ok(response) } pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> { let session = self.get_or_create_session(unified_id).await?; let mut session_guard = session.lock().await; session_guard.clear_history(); Ok(()) } } #[cfg(test)] mod tests { use super::*; use std::collections::HashMap; use tokio::sync::mpsc; fn test_provider_config() -> LLMProviderConfig { LLMProviderConfig { provider_type: "openai".to_string(), name: "test".to_string(), base_url: "http://localhost".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), model_extra: HashMap::new(), max_tool_iterations: 1, token_limit: 4096, workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"), } } }