use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; use crate::bus::InboundMessage; use crate::command::context::CommandContext; use crate::command::handler::{CommandHandler, CommandMetadata, InChatCommandHandler}; use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; use crate::storage::{SessionRecord, SessionStore}; use crate::tools::task::repository::TaskRepository; use crate::agent::AgentError; use async_trait::async_trait; use chrono::{Local, TimeZone}; use std::path::PathBuf; use std::sync::Arc; /// 保存会话到文件(公共函数,可被命令处理器和其他模块复用) /// /// # Arguments /// * `session_id` - 会话ID /// * `filepath` - 可选的文件路径 /// * `include_all` - 是否包含 cutoff 之前的所有消息 /// * `include_subagents` - 是否包含子智能体消息 /// * `store` - 会话存储 /// * `task_repository` - 任务存储(可选,用于查询子智能体) /// * `system_prompt_provider` - 系统提示词提供者 /// /// # Returns /// 返回保存的文件路径 pub async fn save_session_to_file( session_id: &str, filepath: Option, include_all: bool, include_subagents: bool, store: &SessionStore, task_repository: Option<&dyn TaskRepository>, system_prompt_provider: &dyn SystemPromptProvider, ) -> Result { // 获取会话记录 let record = store .get_session(session_id) .map_err(|e| format!("Failed to get session: {}", e))? .ok_or_else(|| "Session not found".to_string())?; // 根据 include_all 决定加载消息范围 let messages = if include_all { store .load_all_messages(session_id) .map_err(|e| format!("Failed to load messages: {}", e))? } else { store .load_messages(session_id) .map_err(|e| format!("Failed to load messages: {}", e))? }; // 加载子智能体消息(如果启用) let subagent_data = if include_subagents { load_subagent_data(session_id, None, store, task_repository).await } else { Vec::new() }; // 计算用户消息数(用于系统提示词构建) let user_message_count = messages.iter().filter(|m| m.role == "user").count(); // 构建系统提示词(使用外部传入的提供者) let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count); // 生成 Markdown 内容 let markdown = generate_markdown_with_subagents(&record, &system_prompt, &messages, &subagent_data); // 确定输出路径 let output_path = resolve_filepath(filepath, &record); // 创建父目录 if let Some(parent) = output_path.parent() { if !parent.as_os_str().is_empty() && !parent.exists() { std::fs::create_dir_all(parent) .map_err(|e| format!("Failed to create directory: {}", e))?; } } // 写入文件 std::fs::write(&output_path, markdown) .map_err(|e| format!("Failed to write file: {}", e))?; Ok(output_path) } /// 保存会话命令处理器 /// /// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件 pub struct SaveSessionCommandHandler { store: Arc, task_repository: Arc, system_prompt_provider: Arc, } impl SaveSessionCommandHandler { /// 创建新的保存会话命令处理器 /// /// # Arguments /// * `store` - 会话存储 /// * `task_repository` - 任务存储(用于查询子智能体) /// * `system_prompt_provider` - 系统提示词提供者 pub fn new( store: Arc, task_repository: Arc, system_prompt_provider: Arc, ) -> Self { Self { store, task_repository, system_prompt_provider, } } /// 从会话记录获取存储(用于测试) #[cfg(test)] fn store(&self) -> &Arc { &self.store } } #[async_trait] impl CommandHandler for SaveSessionCommandHandler { fn can_handle(&self, cmd: &Command) -> bool { matches!(cmd, Command::SaveSession { .. }) } fn metadata(&self) -> Option { Some(CommandMetadata { name: "save-session", description: "保存当前会话到 Markdown 文件", usage: "/save-session [all] [filepath]", }) } async fn handle( &self, cmd: Command, ctx: CommandContext, ) -> Result { match cmd { Command::SaveSession { filepath, include_all, include_subagents } => { handle_save_session(self, filepath, include_all, include_subagents, ctx).await } _ => unreachable!(), } } } /// 处理保存会话命令 async fn handle_save_session( handler: &SaveSessionCommandHandler, filepath: Option, include_all: bool, include_subagents: bool, ctx: CommandContext, ) -> Result { tracing::debug!( ctx_session_id = ?ctx.session_id, ctx_chat_id = ?ctx.chat_id, channel = %ctx.channel_name, include_subagents = include_subagents, "SaveSession command received" ); let session_id = ctx .session_id .as_deref() .ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?; tracing::debug!(session_id = %session_id, "Attempting to save session"); // 先检查会话是否存在 match handler.store.get_session(session_id) { Ok(Some(record)) => { tracing::debug!( session_id = %session_id, title = %record.title, chat_id = %record.chat_id, message_count = record.message_count, "Session found for saving" ); } Ok(None) => { tracing::warn!(session_id = %session_id, "Session not found in store"); } Err(e) => { tracing::error!(session_id = %session_id, error = %e, "Error querying session"); } } // 调用公共函数 let output_path = save_session_to_file( session_id, filepath, include_all, include_subagents, &*handler.store, Some(handler.task_repository.as_ref()), &*handler.system_prompt_provider, ) .await .map_err(|e| CommandError::new("SAVE_ERROR", e))?; // 根据 include_all 获取消息数量 let message_count = if include_all { handler .store .load_all_messages(session_id) } else { handler .store .load_messages(session_id) } .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))? .len(); Ok(CommandResponse::success(ctx.request_id) .with_message( MessageKind::Notification, &format!("Session saved to: {}", output_path.display()), ) .with_metadata("filepath", output_path.to_string_lossy().as_ref()) .with_metadata("message_count", &message_count.to_string())) } /// 子智能体任务数据 #[derive(Debug)] pub struct SubagentTaskData { pub task_id: String, pub session_id: String, pub description: String, pub subagent_type: String, pub state: String, pub created_at: i64, pub messages: Vec, } /// 加载子智能体数据 /// /// # Arguments /// * `parent_session_id` - 父会话 ID /// * `parent_topic_id` - 可选的父话题 ID,如果提供则只加载该话题下的子智能体 /// * `store` - 会话存储 /// * `task_repository` - 任务存储(可选) pub async fn load_subagent_data( parent_session_id: &str, parent_topic_id: Option<&str>, store: &SessionStore, task_repository: Option<&dyn TaskRepository>, ) -> Vec { let Some(repo) = task_repository else { return Vec::new(); }; // 获取子任务:如果提供了 topic_id,则按 topic 查询;否则按 session 查询 let tasks = match parent_topic_id { Some(topic_id) => match repo.list_tasks_for_topic(topic_id).await { Ok(tasks) => tasks, Err(e) => { tracing::warn!(error = %e, "Failed to list tasks for topic"); return Vec::new(); } }, None => match repo.list_tasks_for_session(parent_session_id).await { Ok(tasks) => tasks, Err(e) => { tracing::warn!(error = %e, "Failed to list tasks for session"); return Vec::new(); } }, }; let mut result = Vec::new(); for task in tasks { // 加载子智能体的消息 let messages = match store.load_all_messages(&task.session_id) { Ok(msgs) => msgs, Err(e) => { tracing::warn!(error = %e, task_id = %task.id, "Failed to load subagent messages"); Vec::new() } }; result.push(SubagentTaskData { task_id: task.id, session_id: task.session_id, description: task.description, subagent_type: task.subagent_type.as_str().to_string(), state: format!("{:?}", task.state), created_at: task.created_at, messages, }); } result } /// 生成 Markdown 内容(包含子智能体) pub fn generate_markdown_with_subagents( record: &SessionRecord, system_prompt: &Option, messages: &[crate::bus::ChatMessage], subagent_data: &[SubagentTaskData], ) -> String { let mut output = String::new(); // YAML frontmatter output.push_str("---\n"); output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title))); output.push_str(&format!("session_id: {}\n", record.id)); output.push_str(&format!("channel: {}\n", record.channel_name)); output.push_str(&format!("chat_id: {}\n", record.chat_id)); output.push_str(&format!( "created_at: {}\n", format_timestamp(record.created_at) )); output.push_str(&format!( "updated_at: {}\n", format_timestamp(record.updated_at) )); output.push_str(&format!( "last_active_at: {}\n", format_timestamp(record.last_active_at) )); output.push_str(&format!("message_count: {}\n", messages.len())); if !subagent_data.is_empty() { output.push_str(&format!("subagent_count: {}\n", subagent_data.len())); } output.push_str("---\n\n"); // 系统提示词 output.push_str(&generate_system_prompt_markdown(system_prompt)); // 子智能体任务(如果有) if !subagent_data.is_empty() { output.push_str(&generate_subagent_tasks_markdown(subagent_data)); } // 主会话消息历史 output.push_str(&generate_messages_markdown(messages)); output } /// 生成子智能体任务 Markdown pub fn generate_subagent_tasks_markdown(subagent_data: &[SubagentTaskData]) -> String { let mut output = String::new(); output.push_str("# Subagent Tasks\n\n"); for task in subagent_data { output.push_str(&format!("## Task: {} ({})", task.description, task.subagent_type)); output.push('\n'); output.push_str(&format!("**Task ID:** `{}`\n\n", task.task_id)); output.push_str(&format!("**Session ID:** `{}`\n\n", task.session_id)); output.push_str(&format!("**Status:** {}\n\n", task.state)); output.push_str(&format!("**Created:** {}\n\n", format_timestamp(task.created_at))); output.push_str(&format!("**Message Count:** {}\n\n", task.messages.len())); // 子智能体消息 if !task.messages.is_empty() { output.push_str("### Messages\n\n"); for (idx, msg) in task.messages.iter().enumerate() { output.push_str(&format!("#### Message {}\n\n", idx + 1)); output.push_str(&format!("**Role:** {}\n\n", msg.role)); output.push_str(&format!("**Time:** {}\n\n", format_timestamp(msg.timestamp))); if let Some(ref reasoning) = msg.reasoning_content { output.push_str("**Reasoning:**\n"); output.push_str("```\n"); output.push_str(reasoning); output.push_str("\n```\n\n"); } output.push_str("**Content:**\n\n"); if msg.content.is_empty() { output.push_str("*empty*\n\n"); } else { output.push_str(&format!("{}\n\n", format_message_content(&msg.content))); } // 工具调用 if let Some(ref calls) = msg.tool_calls { if !calls.is_empty() { output.push_str("**Tool Calls:**\n\n"); for call in calls { output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id)); output.push_str(" ```json\n"); let args_json = serde_json::to_string_pretty(&call.arguments) .unwrap_or_else(|_| call.arguments.to_string()); for line in args_json.lines() { output.push_str(&format!(" {}\n", line)); } output.push_str(" ```\n"); } output.push('\n'); } } output.push_str("---\n\n"); } } output.push_str("---\n\n"); } output } /// 构建系统提示词 fn build_system_prompt( provider: &dyn SystemPromptProvider, record: &SessionRecord, user_message_count: usize, ) -> Option { let context = SystemPromptContext { session_id: Some(record.id.clone()), chat_id: record.chat_id.clone(), user_message_count, }; provider.build(&context) } /// 生成 Markdown 内容 pub fn generate_markdown( record: &SessionRecord, system_prompt: &Option, messages: &[crate::bus::ChatMessage], ) -> String { let mut output = String::new(); // YAML frontmatter output.push_str("---\n"); output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title))); output.push_str(&format!("session_id: {}\n", record.id)); output.push_str(&format!("channel: {}\n", record.channel_name)); output.push_str(&format!("chat_id: {}\n", record.chat_id)); output.push_str(&format!( "created_at: {}\n", format_timestamp(record.created_at) )); output.push_str(&format!( "updated_at: {}\n", format_timestamp(record.updated_at) )); output.push_str(&format!( "last_active_at: {}\n", format_timestamp(record.last_active_at) )); output.push_str(&format!("message_count: {}\n", messages.len())); output.push_str("---\n\n"); // 系统提示词(复用公共函数) output.push_str(&generate_system_prompt_markdown(system_prompt)); // 消息历史(复用公共函数) output.push_str(&generate_messages_markdown(messages)); output } /// 格式化消息内容 /// /// 如果内容包含特殊字符,使用代码块包装 /// 使用比内容中最大连续反引号数量多1的反引号来包裹,避免嵌套冲突 pub fn format_message_content(content: &str) -> String { // 如果内容包含表格标记或换行符,使用代码块包裹以保留格式 if content.contains("| ") || content.contains('\n') { // 计算内容中连续反引号的最大数量 let max_backticks = content .chars() .fold((0, 0), |(max_count, current_count), c| { if c == '`' { (max_count, current_count + 1) } else { (max_count.max(current_count), 0) } }) .0; // 使用比最大数量多1的反引号来包裹(至少3个) let fence = "`".repeat(max_backticks.max(3) + 1); format!("{}\n{}\n{}", fence, content, fence) } else { content.to_string() } } /// 转义 YAML 字符串 pub fn escape_yaml_string(s: &str) -> String { if s.contains('\n') || s.contains('"') || s.contains(':') || s.starts_with(' ') { // 使用双引号包裹并转义内部的双引号 format!("\"{}\"", s.replace('"', "\\\"")) } else { s.to_string() } } /// 格式化时间戳 pub fn format_timestamp(ts: i64) -> String { Local .timestamp_millis_opt(ts) .single() .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string()) .unwrap_or_else(|| format!("{}", ts)) } /// 生成消息历史的 Markdown 内容 /// /// 这是一个通用函数,可被 Session 和 Topic 的保存逻辑复用 pub fn generate_messages_markdown(messages: &[crate::bus::ChatMessage]) -> String { let mut output = String::new(); output.push_str("# Message History\n\n"); for (idx, msg) in messages.iter().enumerate() { output.push_str(&format!("## Message {}\n\n", idx + 1)); output.push_str(&format!("**Role:** {}\n\n", msg.role)); output.push_str(&format!("**ID:** {}\n\n", msg.id)); output.push_str(&format!( "**Time:** {}\n\n", format_timestamp(msg.timestamp) )); if let Some(ref ctx) = msg.system_context { output.push_str(&format!("**System Context:** `{}`\n\n", ctx)); } if let Some(ref tool_name) = msg.tool_name { output.push_str(&format!("**Tool Name:** `{}`\n\n", tool_name)); } if let Some(ref tool_call_id) = msg.tool_call_id { output.push_str(&format!("**Tool Call ID:** `{}`\n\n", tool_call_id)); } if let Some(ref reasoning) = msg.reasoning_content { output.push_str("### Reasoning\n\n"); output.push_str("```\n"); output.push_str(reasoning); output.push_str("\n```\n\n"); } // Content output.push_str("### Content\n\n"); if msg.content.is_empty() { output.push_str("*empty*\n\n"); } else { output.push_str(&format!("{}\n\n", format_message_content(&msg.content))); } // Tool calls if let Some(ref calls) = msg.tool_calls { if !calls.is_empty() { output.push_str("### Tool Calls\n\n"); for call in calls { output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id)); output.push_str(" ```json\n"); let args_json = serde_json::to_string_pretty(&call.arguments) .unwrap_or_else(|_| call.arguments.to_string()); for line in args_json.lines() { output.push_str(&format!(" {}\n", line)); } output.push_str(" ```\n"); } output.push('\n'); } } // Media refs if !msg.media_refs.is_empty() { output.push_str("### Media References\n\n"); for media_ref in &msg.media_refs { output.push_str(&format!("- `{}`\n", media_ref)); } output.push('\n'); } output.push_str("---\n\n"); } output } /// 生成系统提示词部分的 Markdown pub fn generate_system_prompt_markdown(system_prompt: &Option) -> String { let mut output = String::new(); output.push_str("# System Prompt\n\n"); if let Some(prompt) = system_prompt { output.push_str("```\n"); output.push_str(&prompt.content); output.push_str("\n```\n\n"); } else { output.push_str("*No system prompt available*\n\n"); } output } /// 解析文件路径 /// /// 如果未提供路径,自动生成基于会话标题和时间戳的文件名, /// 保存到用户主目录下的 .picobot/sessions/ 目录 pub fn resolve_filepath(filepath: Option, record: &SessionRecord) -> PathBuf { match filepath { Some(path) => PathBuf::from(path), None => { // 生成安全标题(替换特殊字符) let safe_title = record .title .replace(' ', "_") .replace('/', "_") .replace('\\', "_") .replace(':', "_") .replace('<', "_") .replace('>', "_") .replace('|', "_") .replace('?', "_") .replace('*', "_") .replace('"', "_"); // 使用标题或 session_id 作为文件名 let base_name = if safe_title.is_empty() { format!("session_{}", &record.id[..8.min(record.id.len())]) } else { safe_title }; // 添加时间戳 let timestamp = Local::now().format("%Y%m%d_%H%M%S"); let filename = format!("{}_{}.md", base_name, timestamp); // 保存到用户主目录下的 .picobot/sessions/ 目录 dirs::home_dir() .unwrap_or_else(|| PathBuf::from(".")) .join(".picobot") .join("sessions") .join(filename) } } } /// InChat 保存会话命令处理器 /// /// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令 pub struct SaveSessionInChatHandler { store: Arc, task_repository: Arc, system_prompt_provider: Arc, } impl SaveSessionInChatHandler { /// 创建新的 InChat 保存会话命令处理器 pub fn new( store: Arc, task_repository: Arc, system_prompt_provider: Arc, ) -> Self { Self { store, task_repository, system_prompt_provider, } } } #[async_trait] impl InChatCommandHandler for SaveSessionInChatHandler { fn can_handle(&self, cmd: &Command) -> bool { matches!(cmd, Command::SaveSession { .. }) } async fn handle( &self, cmd: Command, inbound: &InboundMessage, session_manager: &crate::gateway::session::SessionManager, ) -> Result, AgentError> { let Command::SaveSession { filepath, include_all, include_subagents } = cmd else { return Ok(None); }; // 通过 session_manager 获取 session let session = match session_manager.get(&inbound.channel).await { Some(s) => s, None => { tracing::error!("Session not found for channel: {}", inbound.channel); return Ok(Some("Session not found".to_string())); } }; let session_guard = session.lock().await; let session_id = session_guard.persistent_session_id(&inbound.chat_id); // 调用公共函数 let result = save_session_to_file( &session_id, filepath, include_all, include_subagents, &*self.store, Some(self.task_repository.as_ref()), &*self.system_prompt_provider, ) .await; // 返回成功或失败消息 match result { Ok(output_path) => { let msg = format!("Session saved to: {}", output_path.display()); tracing::info!("{}", msg); Ok(Some(msg)) } Err(error) => { let msg = format!("Failed to save session: {}", error); tracing::error!("{}", msg); Ok(Some(msg)) } } } } #[cfg(test)] mod tests { use super::*; use crate::storage::{SessionRecord, SessionStore}; fn create_test_record(id: &str, title: &str) -> SessionRecord { SessionRecord { id: id.to_string(), title: title.to_string(), channel_name: "cli".to_string(), chat_id: id.to_string(), summary: None, created_at: 1705312800000, // 2024-01-15 10:00:00 updated_at: 1705316400000, // 2024-01-15 11:00:00 last_active_at: 1705316400000, archived_at: None, deleted_at: None, message_count: 0, user_turn_count: 0, agent_prompt_reinjection_count: 0, } } #[test] fn test_resolve_filepath_with_custom_path() { let record = create_test_record("test-123", "My Session"); let path = resolve_filepath(Some("/custom/path/file.md".to_string()), &record); assert_eq!(path, PathBuf::from("/custom/path/file.md")); } #[test] fn test_resolve_filepath_generates_filename_with_title() { let record = create_test_record("test-123", "My Session"); let path = resolve_filepath(None, &record); let filename = path.file_name().unwrap().to_str().unwrap(); assert!(filename.starts_with("My_Session_")); assert!(filename.ends_with(".md")); } #[test] fn test_resolve_filepath_generates_filename_with_id_when_title_empty() { let record = create_test_record("abc12345-xyz", ""); let path = resolve_filepath(None, &record); let filename = path.file_name().unwrap().to_str().unwrap(); assert!(filename.starts_with("session_abc123")); assert!(filename.ends_with(".md")); } #[test] fn test_escape_yaml_string() { assert_eq!(escape_yaml_string("simple"), "simple"); assert_eq!(escape_yaml_string("with: colon"), "\"with: colon\""); assert_eq!(escape_yaml_string("with \"quote\""), "\"with \\\"quote\\\"\""); } #[test] fn test_format_message_content() { // 普通单行文本 - 原样返回 assert_eq!(format_message_content("hello"), "hello"); // 单行包含反引号 - 原样返回(单行不需要包裹) assert_eq!(format_message_content("`code`"), "`code`"); // 包含换行符 - 使用4个反引号包裹(最小) assert_eq!( format_message_content("line1\nline2\nline3"), "````\nline1\nline2\nline3\n````" ); // 包含表格标记 - 使用4个反引号包裹 assert_eq!( format_message_content("| col1 | col2 |"), "````\n| col1 | col2 |\n````" ); // 多行内容包含3个反引号(代码块标记)- 使用4个反引号包裹 assert_eq!( format_message_content("```code```\nmore"), "````\n```code```\nmore\n````" ); // 多行内容包含多行代码块 assert_eq!( format_message_content("```\ncode\n```\nmore"), "````\n```\ncode\n```\nmore\n````" ); // 多行内容包含4个反引号 - 使用5个反引号包裹 assert_eq!( format_message_content("````code````\nmore"), "`````\n````code````\nmore\n`````" ); } #[test] fn test_generate_markdown_structure() { let record = create_test_record("test-123", "Test Session"); let messages = vec![crate::bus::ChatMessage::system("System prompt here")]; let markdown = generate_markdown(&record, &None, &messages); assert!(markdown.contains("---")); assert!(markdown.contains("title:")); assert!(markdown.contains("session_id: test-123")); assert!(markdown.contains("# System Prompt")); assert!(markdown.contains("# Message History")); assert!(markdown.contains("## Message 1")); assert!(markdown.contains("**Role:** system")); } #[test] fn test_can_handle() { let store = Arc::new(SessionStore::in_memory().unwrap()); let task_repository = Arc::new(crate::tools::task::repository::InMemoryTaskRepository::new()); let provider = Arc::new(TestSystemPromptProvider); let handler = SaveSessionCommandHandler::new(store, task_repository, provider); assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false, include_subagents: false })); assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true, include_subagents: false })); assert!(!handler.can_handle(&Command::CreateSession { title: None })); assert!(!handler.can_handle(&Command::SaveTopic { filepath: None, include_subagents: false })); } /// 测试用的系统提示词提供者 struct TestSystemPromptProvider; impl SystemPromptProvider for TestSystemPromptProvider { fn build(&self, _context: &SystemPromptContext) -> Option { Some(SystemPrompt { content: "Test system prompt".to_string(), context: Some("test".to_string()), }) } } }