From 20f32a3f96030c13521ab9f03c2759d58119758c Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Sat, 16 May 2026 19:33:42 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=BF=9D=E5=AD=98?= =?UTF-8?q?=E8=AF=9D=E9=A2=98=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=B0=86=E5=BD=93=E5=89=8D=E8=AF=9D=E9=A2=98=E5=86=85=E5=AE=B9?= =?UTF-8?q?=E4=BF=9D=E5=AD=98=E4=B8=BA=20Markdown=20=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/command/adapters/channel.rs | 24 ++- src/command/adapters/cli.rs | 73 +++++++-- src/command/handlers/mod.rs | 7 + src/command/handlers/save_session.rs | 91 +++++++---- src/command/handlers/save_topic.rs | 234 +++++++++++++++++++++++++++ src/command/mod.rs | 5 +- src/gateway/processor.rs | 7 + 7 files changed, 385 insertions(+), 56 deletions(-) create mode 100644 src/command/handlers/save_topic.rs diff --git a/src/command/adapters/channel.rs b/src/command/adapters/channel.rs index 44fccf1..1a4cd04 100644 --- a/src/command/adapters/channel.rs +++ b/src/command/adapters/channel.rs @@ -40,26 +40,40 @@ impl InputAdapter for ChannelInputAdapter { })); } - // 解析 /save 命令 + // 解析 /save 命令 - 保存当前话题 if trimmed == "/save" { + return Ok(Some(Command::SaveTopic { + filepath: None, + })); + } + + if let Some(filepath) = trimmed.strip_prefix("/save ") { + let filepath = filepath.trim(); + return Ok(Some(Command::SaveTopic { + filepath: Some(filepath.to_string()), + })); + } + + // 解析 /save-session 命令 - 保存整个会话 + if trimmed == "/save-session" { return Ok(Some(Command::SaveSession { filepath: None, include_all: false, })); } - if let Some(args) = trimmed.strip_prefix("/save ") { + if let Some(args) = trimmed.strip_prefix("/save-session ") { let args = args.trim(); // 解析参数:可能是 "all"、路径、或 "all 路径" let (include_all, filepath) = if args == "all" { - // /save all - 保存全部消息 + // /save-session all - 保存全部消息 (true, None) } else if args.starts_with("all ") { - // /save all - 保存全部消息到指定路径 + // /save-session all - 保存全部消息到指定路径 let path = args[4..].trim(); (true, Some(path.to_string())) } else { - // /save - 保存活跃消息到指定路径 + // /save-session - 保存活跃消息到指定路径 (false, Some(args.to_string())) }; return Ok(Some(Command::SaveSession { filepath, include_all })); diff --git a/src/command/adapters/cli.rs b/src/command/adapters/cli.rs index f279bff..4a22f2e 100644 --- a/src/command/adapters/cli.rs +++ b/src/command/adapters/cli.rs @@ -41,26 +41,40 @@ impl InputAdapter for CliInputAdapter { })); } - // 解析 /save 命令 + // 解析 /save 命令 - 保存当前话题 if trimmed == "/save" { + return Ok(Some(Command::SaveTopic { + filepath: None, + })); + } + + if let Some(filepath) = trimmed.strip_prefix("/save ") { + let filepath = filepath.trim(); + return Ok(Some(Command::SaveTopic { + filepath: Some(filepath.to_string()), + })); + } + + // 解析 /save-session 命令 - 保存整个会话 + if trimmed == "/save-session" { return Ok(Some(Command::SaveSession { filepath: None, include_all: false, })); } - if let Some(args) = trimmed.strip_prefix("/save ") { + if let Some(args) = trimmed.strip_prefix("/save-session ") { let args = args.trim(); // 解析参数:可能是 "all"、路径、或 "all 路径" let (include_all, filepath) = if args == "all" { - // /save all - 保存全部消息 + // /save-session all - 保存全部消息 (true, None) } else if args.starts_with("all ") { - // /save all - 保存全部消息到指定路径 + // /save-session all - 保存全部消息到指定路径 let path = args[4..].trim(); (true, Some(path.to_string())) } else { - // /save - 保存活跃消息到指定路径 + // /save-session - 保存活跃消息到指定路径 (false, Some(args.to_string())) }; return Ok(Some(Command::SaveSession { filepath, include_all })); @@ -223,23 +237,52 @@ mod tests { } #[test] - fn test_cli_input_adapter_save_without_path() { + fn test_cli_input_adapter_save_topic_without_path() { let adapter = CliInputAdapter::new(); let ctx = AdapterContext::new("test"); let result = adapter.try_parse("/save", ctx).unwrap(); + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!(cmd, Command::SaveTopic { filepath: None })); + } + + #[test] + fn test_cli_input_adapter_save_topic_with_path() { + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("test"); + + let result = adapter.try_parse("/save ./debug/topic.md", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!( + cmd, + Command::SaveTopic { + filepath: Some(ref p), + } if p == "./debug/topic.md" + )); + } + + #[test] + fn test_cli_input_adapter_save_session_without_path() { + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("test"); + + let result = adapter.try_parse("/save-session", ctx).unwrap(); + assert!(result.is_some()); let cmd = result.unwrap(); assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false })); } #[test] - fn test_cli_input_adapter_save_with_path() { + fn test_cli_input_adapter_save_session_with_path() { let adapter = CliInputAdapter::new(); let ctx = AdapterContext::new("test"); - let result = adapter.try_parse("/save ./debug/session.md", ctx).unwrap(); + let result = adapter.try_parse("/save-session ./debug/session.md", ctx).unwrap(); assert!(result.is_some()); let cmd = result.unwrap(); @@ -253,11 +296,11 @@ mod tests { } #[test] - fn test_cli_input_adapter_save_all() { + fn test_cli_input_adapter_save_session_all() { let adapter = CliInputAdapter::new(); let ctx = AdapterContext::new("test"); - let result = adapter.try_parse("/save all", ctx).unwrap(); + let result = adapter.try_parse("/save-session all", ctx).unwrap(); assert!(result.is_some()); let cmd = result.unwrap(); @@ -265,11 +308,11 @@ mod tests { } #[test] - fn test_cli_input_adapter_save_all_with_path() { + fn test_cli_input_adapter_save_session_all_with_path() { let adapter = CliInputAdapter::new(); let ctx = AdapterContext::new("test"); - let result = adapter.try_parse("/save all ./debug/session.md", ctx).unwrap(); + let result = adapter.try_parse("/save-session all ./debug/session.md", ctx).unwrap(); assert!(result.is_some()); let cmd = result.unwrap(); @@ -287,11 +330,11 @@ mod tests { let adapter = CliOutputAdapter::new(); let request_id = uuid::Uuid::new_v4(); let response = CommandResponse::success(request_id) - .with_message(MessageKind::Notification, "Session saved to: session.md") - .with_metadata("filepath", "session.md"); + .with_message(MessageKind::Notification, "Topic saved to: topic.md") + .with_metadata("filepath", "topic.md"); let output = adapter.adapt(response); - assert!(output.contains("Session saved to: session.md")); + assert!(output.contains("Topic saved to: topic.md")); } } diff --git a/src/command/handlers/mod.rs b/src/command/handlers/mod.rs index 3f4f14f..93a0d6e 100644 --- a/src/command/handlers/mod.rs +++ b/src/command/handlers/mod.rs @@ -1,3 +1,10 @@ pub mod save_session; +pub mod save_topic; pub mod session; pub mod session_query; + +// 导出公共函数供其他模块复用 +pub use save_session::{ + escape_yaml_string, format_message_content, format_timestamp, + generate_messages_markdown, generate_system_prompt_markdown, +}; diff --git a/src/command/handlers/save_session.rs b/src/command/handlers/save_session.rs index 2f77192..20397dd 100644 --- a/src/command/handlers/save_session.rs +++ b/src/command/handlers/save_session.rs @@ -237,17 +237,52 @@ pub fn generate_markdown( output.push_str(&format!("message_count: {}\n", messages.len())); output.push_str("---\n\n"); - // 系统提示词 - output.push_str("# System Prompt\n\n"); - if let Some(prompt) = system_prompt { - output.push_str("```\n"); - output.push_str(&prompt.content); - output.push_str("\n```\n\n"); - } else { - output.push_str("*No system prompt available*\n\n"); - } + // 系统提示词(复用公共函数) + output.push_str(&generate_system_prompt_markdown(system_prompt)); + + // 消息历史(复用公共函数) + output.push_str(&generate_messages_markdown(messages)); + + output +} + +/// 格式化消息内容 +/// +/// 如果内容包含特殊字符,使用代码块包装 +pub fn format_message_content(content: &str) -> String { + // 如果内容包含代码块标记或表格标记,使用原始格式 + if content.contains("```") || content.contains("| ") { + format!("```\n{}\n```", content) + } else { + content.to_string() + } +} + +/// 转义 YAML 字符串 +pub fn escape_yaml_string(s: &str) -> String { + if s.contains('\n') || s.contains('"') || s.contains(':') || s.starts_with(' ') { + // 使用双引号包裹并转义内部的双引号 + format!("\"{}\"", s.replace('"', "\\\"")) + } else { + s.to_string() + } +} + +/// 格式化时间戳 +pub fn format_timestamp(ts: i64) -> String { + Local + .timestamp_millis_opt(ts) + .single() + .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string()) + .unwrap_or_else(|| format!("{}", ts)) +} + +/// 生成消息历史的 Markdown 内容 +/// +/// 这是一个通用函数,可被 Session 和 Topic 的保存逻辑复用 +pub fn generate_messages_markdown(messages: &[crate::bus::ChatMessage]) -> String { + let mut output = String::new(); - // 消息历史 output.push_str("# Message History\n\n"); for (idx, msg) in messages.iter().enumerate() { @@ -319,35 +354,20 @@ pub fn generate_markdown( output } -/// 格式化消息内容 -/// -/// 如果内容包含特殊字符,使用代码块包装 -fn format_message_content(content: &str) -> String { - // 如果内容包含代码块标记或表格标记,使用原始格式 - if content.contains("```") || content.contains("| ") { - format!("```\n{}\n```", content) - } else { - content.to_string() - } -} +/// 生成系统提示词部分的 Markdown +pub fn generate_system_prompt_markdown(system_prompt: &Option) -> String { + let mut output = String::new(); -/// 转义 YAML 字符串 -fn escape_yaml_string(s: &str) -> String { - if s.contains('\n') || s.contains('"') || s.contains(':') || s.starts_with(' ') { - // 使用双引号包裹并转义内部的双引号 - format!("\"{}\"", s.replace('"', "\\\"")) + output.push_str("# System Prompt\n\n"); + if let Some(prompt) = system_prompt { + output.push_str("```\n"); + output.push_str(&prompt.content); + output.push_str("\n```\n\n"); } else { - s.to_string() + output.push_str("*No system prompt available*\n\n"); } -} -/// 格式化时间戳 -pub fn format_timestamp(ts: i64) -> String { - Local - .timestamp_millis_opt(ts) - .single() - .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string()) - .unwrap_or_else(|| format!("{}", ts)) + output } /// 解析文件路径 @@ -557,6 +577,7 @@ mod tests { assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false })); assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true })); assert!(!handler.can_handle(&Command::CreateSession { title: None })); + assert!(!handler.can_handle(&Command::SaveTopic { filepath: None })); } /// 测试用的系统提示词提供者 diff --git a/src/command/handlers/save_topic.rs b/src/command/handlers/save_topic.rs new file mode 100644 index 0000000..2d0fa7a --- /dev/null +++ b/src/command/handlers/save_topic.rs @@ -0,0 +1,234 @@ +use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; +use crate::command::context::CommandContext; +use crate::command::handler::CommandHandler; +use crate::command::handlers::{ + escape_yaml_string, format_timestamp, generate_messages_markdown, + generate_system_prompt_markdown, +}; +use crate::command::response::{CommandError, CommandResponse, MessageKind}; +use crate::command::Command; +use crate::storage::{SessionStore, TopicRecord}; +use async_trait::async_trait; +use chrono::Local; +use std::path::PathBuf; +use std::sync::Arc; + +/// 保存话题到文件 +pub async fn save_topic_to_file( + topic_id: &str, + filepath: Option, + store: &SessionStore, + system_prompt_provider: &dyn SystemPromptProvider, +) -> Result { + // 获取话题记录 + let topic = store + .get_topic(topic_id) + .map_err(|e| format!("Failed to get topic: {}", e))? + .ok_or_else(|| "Topic not found".to_string())?; + + // 加载话题消息 + let messages = store + .load_messages_for_topic(topic_id) + .map_err(|e| format!("Failed to load messages: {}", e))?; + + // 获取 session 信息(用于系统提示词) + let session = store + .get_session(&topic.session_id) + .map_err(|e| format!("Failed to get session: {}", e))?; + + // 构建系统提示词 + let user_message_count = messages.iter().filter(|m| m.role == "user").count(); + let system_prompt = build_system_prompt(system_prompt_provider, &session, user_message_count); + + // 生成 Markdown 内容 + let markdown = generate_topic_markdown(&topic, &system_prompt, &messages); + + // 确定输出路径 + let output_path = resolve_topic_filepath(filepath, &topic); + + // 创建父目录 + if let Some(parent) = output_path.parent() { + if !parent.as_os_str().is_empty() && !parent.exists() { + std::fs::create_dir_all(parent) + .map_err(|e| format!("Failed to create directory: {}", e))?; + } + } + + // 写入文件 + std::fs::write(&output_path, markdown) + .map_err(|e| format!("Failed to write file: {}", e))?; + + Ok(output_path) +} + +/// 构建系统提示词 +fn build_system_prompt( + provider: &dyn SystemPromptProvider, + session: &Option, + user_message_count: usize, +) -> Option { + let session = session.as_ref()?; + let context = SystemPromptContext { + session_id: Some(session.id.clone()), + chat_id: session.chat_id.clone(), + user_message_count, + }; + provider.build(&context) +} + +/// 生成话题 Markdown 内容(复用公共函数) +fn generate_topic_markdown( + topic: &TopicRecord, + system_prompt: &Option, + messages: &[crate::bus::ChatMessage], +) -> String { + let mut output = String::new(); + + // YAML frontmatter(Topic 特有) + output.push_str("---\n"); + output.push_str(&format!("title: {}\n", escape_yaml_string(&topic.title))); + output.push_str(&format!("topic_id: {}\n", topic.id)); + output.push_str(&format!("session_id: {}\n", topic.session_id)); + if let Some(ref desc) = topic.description { + output.push_str(&format!("description: {}\n", escape_yaml_string(desc))); + } + output.push_str(&format!( + "created_at: {}\n", + format_timestamp(topic.created_at) + )); + output.push_str(&format!( + "updated_at: {}\n", + format_timestamp(topic.updated_at) + )); + output.push_str(&format!( + "last_active_at: {}\n", + format_timestamp(topic.last_active_at) + )); + output.push_str(&format!("message_count: {}\n", messages.len())); + output.push_str("---\n\n"); + + // 系统提示词(复用公共函数) + output.push_str(&generate_system_prompt_markdown(system_prompt)); + + // 消息历史(复用公共函数) + output.push_str(&generate_messages_markdown(messages)); + + output +} + +/// 解析话题文件路径(Topic 特有) +fn resolve_topic_filepath(filepath: Option, topic: &TopicRecord) -> PathBuf { + match filepath { + Some(path) => PathBuf::from(path), + None => { + let safe_title = topic + .title + .replace(' ', "_") + .replace('/', "_") + .replace('\\', "_") + .replace(':', "_") + .replace('<', "_") + .replace('>', "_") + .replace('|', "_") + .replace('?', "_") + .replace('*', "_") + .replace('"', "_"); + + let base_name = if safe_title.is_empty() { + format!("topic_{}", &topic.id[..8.min(topic.id.len())]) + } else { + safe_title + }; + + let timestamp = Local::now().format("%Y%m%d_%H%M%S"); + let filename = format!("{}_{}.md", base_name, timestamp); + + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".picobot") + .join("topics") + .join(filename) + } + } +} + +/// 保存话题命令处理器 +pub struct SaveTopicCommandHandler { + store: Arc, + system_prompt_provider: Arc, +} + +impl SaveTopicCommandHandler { + pub fn new( + store: Arc, + system_prompt_provider: Arc, + ) -> Self { + Self { + store, + system_prompt_provider, + } + } +} + +#[async_trait] +impl CommandHandler for SaveTopicCommandHandler { + fn can_handle(&self, cmd: &Command) -> bool { + matches!(cmd, Command::SaveTopic { .. }) + } + + async fn handle( + &self, + cmd: Command, + ctx: CommandContext, + ) -> Result { + match cmd { + Command::SaveTopic { filepath } => handle_save_topic(self, filepath, ctx).await, + _ => unreachable!(), + } + } +} + +async fn handle_save_topic( + handler: &SaveTopicCommandHandler, + filepath: Option, + ctx: CommandContext, +) -> Result { + tracing::debug!( + ctx_topic_id = ?ctx.topic_id, + ctx_session_id = ?ctx.session_id, + channel = %ctx.channel_name, + "SaveTopic command received" + ); + + let topic_id = ctx + .topic_id + .as_deref() + .ok_or_else(|| CommandError::new("NO_TOPIC", "No active topic".to_string()))?; + + tracing::debug!(topic_id = %topic_id, "Attempting to save topic"); + + // 调用保存函数 + let output_path = save_topic_to_file( + topic_id, + filepath, + &*handler.store, + &*handler.system_prompt_provider, + ) + .await + .map_err(|e| CommandError::new("SAVE_ERROR", e))?; + + // 获取消息数量 + let message_count = handler + .store + .load_messages_for_topic(topic_id) + .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))? + .len(); + + Ok(CommandResponse::success(ctx.request_id) + .with_message( + MessageKind::Notification, + &format!("Topic saved to: {}", output_path.display()), + ) + .with_metadata("filepath", output_path.to_string_lossy().as_ref()) + .with_metadata("message_count", &message_count.to_string())) +} \ No newline at end of file diff --git a/src/command/mod.rs b/src/command/mod.rs index f2867d1..e1a2b54 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -13,7 +13,9 @@ use serde::{Deserialize, Serialize}; pub enum Command { /// 创建新话题(在同一个 Session 内) CreateSession { title: Option }, - /// 保存话题内容到 Markdown 文件 + /// 保存当前话题内容到 Markdown 文件 + SaveTopic { filepath: Option }, + /// 保存会话内容到 Markdown 文件 SaveSession { filepath: Option, include_all: bool, @@ -33,6 +35,7 @@ impl Command { pub fn name(&self) -> &'static str { match self { Command::CreateSession { .. } => "create_session", + Command::SaveTopic { .. } => "save_topic", Command::SaveSession { .. } => "save_session", Command::ListSessions { .. } => "list_sessions", Command::LoadSession { .. } => "load_session", diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index d6f9aed..abc63ae 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -8,6 +8,7 @@ use crate::command::adapter::InputAdapter; use crate::command::adapters::channel::ChannelInputAdapter; use crate::command::handler::CommandRouter; use crate::command::handlers::save_session::SaveSessionCommandHandler; +use crate::command::handlers::save_topic::SaveTopicCommandHandler; use crate::command::handlers::session::SessionCommandHandler; use crate::command::handlers::session_query::SessionQueryCommandHandler; use crate::config::LLMProviderConfig; @@ -60,6 +61,12 @@ impl InboundProcessor { Box::new(SkillPromptProvider::new(skills)), ])); command_router.register(Box::new(SaveSessionCommandHandler::new( + store.clone(), + system_prompt_provider.clone(), + ))); + + // 注册 save_topic 处理器 + command_router.register(Box::new(SaveTopicCommandHandler::new( store, system_prompt_provider, )));