From 35b9c42d073a9078dee743849f92f04d6dfb7eb0 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Wed, 13 May 2026 21:46:29 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20/save=20=E5=91=BD?= =?UTF-8?q?=E4=BB=A4=E4=BB=A5=E4=BF=9D=E5=AD=98=E4=BC=9A=E8=AF=9D=E5=86=85?= =?UTF-8?q?=E5=AE=B9=E5=88=B0=20Markdown=20=E6=96=87=E4=BB=B6=EF=BC=9B?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20SaveSessionCommandHandler=20=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cli/input.rs | 10 + src/client/mod.rs | 15 +- src/command/adapters/cli.rs | 54 ++++ src/command/handlers/mod.rs | 1 + src/command/handlers/save_session.rs | 433 +++++++++++++++++++++++++++ src/command/handlers/session.rs | 1 + src/command/mod.rs | 5 +- src/gateway/ws.rs | 51 ++++ src/protocol/mod.rs | 9 + 9 files changed, 577 insertions(+), 2 deletions(-) create mode 100644 src/command/handlers/save_session.rs diff --git a/src/cli/input.rs b/src/cli/input.rs index 10bf07b..7a0820d 100644 --- a/src/cli/input.rs +++ b/src/cli/input.rs @@ -12,6 +12,7 @@ pub enum InputCommand { Exit, Clear, New(Option), + Save(Option), Sessions, Use(String), Rename(String), @@ -75,6 +76,7 @@ impl InputHandler { "/quit" | "/exit" | "/q" => Some(InputCommand::Exit), "/clear" => Some(InputCommand::Clear), "/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))), + "/save" => Some(InputCommand::Save(arg.map(ToOwned::to_owned))), "/sessions" => Some(InputCommand::Sessions), "/use" => arg.map(|value| InputCommand::Use(value.to_string())), "/rename" => arg.map(|value| InputCommand::Rename(value.to_string())), @@ -130,6 +132,14 @@ mod tests { handler.handle_special_commands("/new planning"), Some(InputCommand::New(Some("planning".to_string()))) ); + assert_eq!( + handler.handle_special_commands("/save"), + Some(InputCommand::Save(None)) + ); + assert_eq!( + handler.handle_special_commands("/save ./debug/session.md"), + Some(InputCommand::Save(Some("./debug/session.md".to_string()))) + ); assert_eq!( handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions) diff --git a/src/client/mod.rs b/src/client/mod.rs index e16aef0..a86b259 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -54,7 +54,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { let mut input = InputHandler::new(); let mut current_session_id: Option = None; - input.write_output("picobot CLI - Commands: /new [title], /reset, /sessions, /use , /rename , /archive, /delete, /clear, /quit\n").await?; + input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /reset, /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?; // Main loop: poll both stdin and WebSocket loop { @@ -114,6 +114,9 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { WsOutbound::HistoryCleared { session_id } => { input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?; } + WsOutbound::SessionSaved { session_id, filepath } => { + input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?; + } _ => {} } } @@ -225,6 +228,16 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> { } continue; } + InputEvent::Command(InputCommand::Save(filepath)) => { + let inbound = WsInbound::SaveSession { + filepath, + session_id: current_session_id.clone(), + }; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + continue; + } InputEvent::Message(msg) => { let inbound = WsInbound::UserInput { content: msg.content, diff --git a/src/command/adapters/cli.rs b/src/command/adapters/cli.rs index 4d7adc8..9c256d8 100644 --- a/src/command/adapters/cli.rs +++ b/src/command/adapters/cli.rs @@ -41,6 +41,18 @@ impl InputAdapter for CliInputAdapter { })); } + // 解析 /save 命令 + if trimmed == "/save" { + return Ok(Some(Command::SaveSession { filepath: None })); + } + + if let Some(path) = trimmed.strip_prefix("/save ") { + let path = path.trim(); + return Ok(Some(Command::SaveSession { + filepath: Some(path.to_string()), + })); + } + // 不是命令,返回 None Ok(None) } @@ -170,4 +182,46 @@ mod tests { assert!(output.contains("Error [TEST_ERROR]")); assert!(output.contains("something failed")); } + + #[test] + fn test_cli_input_adapter_save_without_path() { + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("test"); + + let result = adapter.try_parse("/save", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!(cmd, Command::SaveSession { filepath: None })); + } + + #[test] + fn test_cli_input_adapter_save_with_path() { + let adapter = CliInputAdapter::new(); + let ctx = AdapterContext::new("test"); + + let result = adapter.try_parse("/save ./debug/session.md", ctx).unwrap(); + + assert!(result.is_some()); + let cmd = result.unwrap(); + assert!(matches!( + cmd, + Command::SaveSession { + filepath: Some(ref p) + } if p == "./debug/session.md" + )); + } + + #[test] + fn test_cli_output_adapter_save_success() { + let adapter = CliOutputAdapter::new(); + let request_id = uuid::Uuid::new_v4(); + let response = CommandResponse::success(request_id) + .with_message(MessageKind::Notification, "Session saved to: session.md") + .with_metadata("filepath", "session.md"); + + let output = adapter.adapt(response); + + assert!(output.contains("Session saved to: session.md")); + } } diff --git a/src/command/handlers/mod.rs b/src/command/handlers/mod.rs index f52f1c4..0df37c6 100644 --- a/src/command/handlers/mod.rs +++ b/src/command/handlers/mod.rs @@ -1 +1,2 @@ +pub mod save_session; pub mod session; diff --git a/src/command/handlers/save_session.rs b/src/command/handlers/save_session.rs new file mode 100644 index 0000000..5d038a8 --- /dev/null +++ b/src/command/handlers/save_session.rs @@ -0,0 +1,433 @@ +use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; +use crate::command::context::CommandContext; +use crate::command::handler::CommandHandler; +use crate::command::response::{CommandError, CommandResponse, MessageKind}; +use crate::command::Command; +use crate::config::LLMProviderConfig; +use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider; +use crate::storage::{SessionRecord, SessionStore}; +use async_trait::async_trait; +use chrono::{Local, TimeZone}; +use std::path::PathBuf; +use std::sync::Arc; + +/// 保存会话命令处理器 +/// +/// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件 +pub struct SaveSessionCommandHandler { + store: Arc<SessionStore>, + provider_config: LLMProviderConfig, +} + +impl SaveSessionCommandHandler { + /// 创建新的保存会话命令处理器 + /// + /// # Arguments + /// * `store` - 会话存储 + /// * `provider_config` - LLM 提供者配置(用于构建系统提示词) + pub fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self { + Self { + store, + provider_config, + } + } + + /// 从会话记录获取存储(用于测试) + #[cfg(test)] + fn store(&self) -> &Arc<SessionStore> { + &self.store + } +} + +#[async_trait] +impl CommandHandler for SaveSessionCommandHandler { + fn can_handle(&self, cmd: &Command) -> bool { + matches!(cmd, Command::SaveSession { .. }) + } + + async fn handle( + &self, + cmd: Command, + ctx: CommandContext, + ) -> Result<CommandResponse, CommandError> { + match cmd { + Command::SaveSession { filepath } => { + handle_save_session(self, filepath, ctx).await + } + _ => unreachable!(), + } + } +} + +/// 处理保存会话命令 +async fn handle_save_session( + handler: &SaveSessionCommandHandler, + filepath: Option<String>, + ctx: CommandContext, +) -> Result<CommandResponse, CommandError> { + let session_id = ctx + .session_id + .as_deref() + .ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?; + + // 获取会话记录 + let record = handler + .store + .get_session(session_id) + .map_err(|e| CommandError::new("SESSION_ERROR", e.to_string()))? + .ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", "Session not found".to_string()))?; + + // 获取所有消息(包括历史) + let messages = handler + .store + .load_all_messages(session_id) + .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?; + + // 计算用户消息数(用于系统提示词构建) + let user_message_count = messages.iter().filter(|m| m.role == "user").count(); + + // 构建系统提示词 + let system_prompt = build_system_prompt(&handler.provider_config, &record, user_message_count); + + // 生成 Markdown 内容 + let markdown = generate_markdown(&record, &system_prompt, &messages); + + // 确定输出路径 + let output_path = resolve_filepath(filepath, &record); + + // 创建父目录 + if let Some(parent) = output_path.parent() { + if !parent.as_os_str().is_empty() && !parent.exists() { + std::fs::create_dir_all(parent).map_err(|e| { + CommandError::new( + "CREATE_DIR_ERROR", + format!("Failed to create directory: {}", e), + ) + })?; + } + } + + // 写入文件 + std::fs::write(&output_path, markdown).map_err(|e| { + CommandError::new( + "WRITE_FILE_ERROR", + format!("Failed to write file: {}", e), + ) + })?; + + Ok(CommandResponse::success(ctx.request_id) + .with_message( + MessageKind::Notification, + &format!("Session saved to: {}", output_path.display()), + ) + .with_metadata("filepath", output_path.to_string_lossy().as_ref()) + .with_metadata("message_count", &messages.len().to_string())) +} + +/// 构建系统提示词 +fn build_system_prompt( + provider_config: &LLMProviderConfig, + record: &SessionRecord, + user_message_count: usize, +) -> Option<SystemPrompt> { + let provider = SimpleAgentPromptProvider::new(provider_config.clone()); + let context = SystemPromptContext { + session_id: Some(record.id.clone()), + chat_id: record.chat_id.clone(), + user_message_count, + }; + provider.build(&context) +} + +/// 生成 Markdown 内容 +fn generate_markdown( + record: &SessionRecord, + system_prompt: &Option<SystemPrompt>, + messages: &[crate::bus::ChatMessage], +) -> String { + let mut output = String::new(); + + // YAML frontmatter + output.push_str("---\n"); + output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title))); + output.push_str(&format!("session_id: {}\n", record.id)); + output.push_str(&format!("channel: {}\n", record.channel_name)); + output.push_str(&format!("chat_id: {}\n", record.chat_id)); + output.push_str(&format!( + "created_at: {}\n", + format_timestamp(record.created_at) + )); + output.push_str(&format!( + "updated_at: {}\n", + format_timestamp(record.updated_at) + )); + output.push_str(&format!( + "last_active_at: {}\n", + format_timestamp(record.last_active_at) + )); + output.push_str(&format!("message_count: {}\n", messages.len())); + output.push_str("---\n\n"); + + // 系统提示词 + output.push_str("# System Prompt\n\n"); + if let Some(prompt) = system_prompt { + output.push_str("```\n"); + output.push_str(&prompt.content); + output.push_str("\n```\n\n"); + } else { + output.push_str("*No system prompt available*\n\n"); + } + + // 消息历史 + output.push_str("# Message History\n\n"); + + for (idx, msg) in messages.iter().enumerate() { + output.push_str(&format!("## Message {}\n\n", idx + 1)); + output.push_str(&format!("**Role:** {}\n\n", msg.role)); + output.push_str(&format!("**ID:** {}\n\n", msg.id)); + output.push_str(&format!( + "**Time:** {}\n\n", + format_timestamp(msg.timestamp) + )); + + if let Some(ref ctx) = msg.system_context { + output.push_str(&format!("**System Context:** `{}`\n\n", ctx)); + } + + if let Some(ref tool_name) = msg.tool_name { + output.push_str(&format!("**Tool Name:** `{}`\n\n", tool_name)); + } + + if let Some(ref tool_call_id) = msg.tool_call_id { + output.push_str(&format!("**Tool Call ID:** `{}`\n\n", tool_call_id)); + } + + if let Some(ref reasoning) = msg.reasoning_content { + output.push_str("### Reasoning\n\n"); + output.push_str("```\n"); + output.push_str(reasoning); + output.push_str("\n```\n\n"); + } + + // Content + output.push_str("### Content\n\n"); + if msg.content.is_empty() { + output.push_str("*empty*\n\n"); + } else { + output.push_str(&format!("{}\n\n", format_message_content(&msg.content))); + } + + // Tool calls + if let Some(ref calls) = msg.tool_calls { + if !calls.is_empty() { + output.push_str("### Tool Calls\n\n"); + for call in calls { + output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id)); + output.push_str(" ```json\n"); + let args_json = serde_json::to_string_pretty(&call.arguments) + .unwrap_or_else(|_| call.arguments.to_string()); + for line in args_json.lines() { + output.push_str(&format!(" {}\n", line)); + } + output.push_str(" ```\n"); + } + output.push('\n'); + } + } + + // Media refs + if !msg.media_refs.is_empty() { + output.push_str("### Media References\n\n"); + for media_ref in &msg.media_refs { + output.push_str(&format!("- `{}`\n", media_ref)); + } + output.push('\n'); + } + + output.push_str("---\n\n"); + } + + output +} + +/// 格式化消息内容 +/// +/// 如果内容包含特殊字符,使用代码块包装 +fn format_message_content(content: &str) -> String { + // 如果内容包含代码块标记或表格标记,使用原始格式 + if content.contains("```") || content.contains("| ") { + format!("```\n{}\n```", content) + } else { + content.to_string() + } +} + +/// 转义 YAML 字符串 +fn escape_yaml_string(s: &str) -> String { + if s.contains('\n') || s.contains('"') || s.contains(':') || s.starts_with(' ') { + // 使用双引号包裹并转义内部的双引号 + format!("\"{}\"", s.replace('"', "\\\"")) + } else { + s.to_string() + } +} + +/// 格式化时间戳 +fn format_timestamp(ts: i64) -> String { + Local + .timestamp_millis_opt(ts) + .single() + .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string()) + .unwrap_or_else(|| format!("{}", ts)) +} + +/// 解析文件路径 +/// +/// 如果未提供路径,自动生成基于会话标题和时间戳的文件名 +fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf { + match filepath { + Some(path) => PathBuf::from(path), + None => { + // 生成安全标题(替换特殊字符) + let safe_title = record + .title + .replace(' ', "_") + .replace('/', "_") + .replace('\\', "_") + .replace(':', "_") + .replace('<', "_") + .replace('>', "_") + .replace('|', "_") + .replace('?', "_") + .replace('*', "_") + .replace('"', "_"); + + // 使用标题或 session_id 作为文件名 + let base_name = if safe_title.is_empty() { + format!("session_{}", &record.id[..8.min(record.id.len())]) + } else { + safe_title + }; + + // 添加时间戳 + let timestamp = Local::now().format("%Y%m%d_%H%M%S"); + let filename = format!("{}_{}.md", base_name, timestamp); + + PathBuf::from(filename) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::{SessionRecord, SessionStore}; + use std::collections::HashMap; + + fn test_config() -> LLMProviderConfig { + LLMProviderConfig { + provider_type: "openai".to_string(), + name: "test".to_string(), + base_url: "http://localhost".to_string(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, + model_id: "test-model".to_string(), + temperature: Some(0.0), + max_tokens: Some(32), + context_window_tokens: None, + tool_result_max_chars: 20_000, + context_tool_result_trim_chars: 20_000, + model_extra: HashMap::new(), + max_tool_iterations: 1, + } + } + + fn create_test_record(id: &str, title: &str) -> SessionRecord { + SessionRecord { + id: id.to_string(), + title: title.to_string(), + channel_name: "cli".to_string(), + chat_id: id.to_string(), + summary: None, + created_at: 1705312800000, // 2024-01-15 10:00:00 + updated_at: 1705316400000, // 2024-01-15 11:00:00 + last_active_at: 1705316400000, + archived_at: None, + deleted_at: None, + message_count: 0, + reset_cutoff_seq: 0, + user_turn_count: 0, + agent_prompt_reinjection_count: 0, + } + } + + #[test] + fn test_resolve_filepath_with_custom_path() { + let record = create_test_record("test-123", "My Session"); + let path = resolve_filepath(Some("/custom/path/file.md".to_string()), &record); + assert_eq!(path, PathBuf::from("/custom/path/file.md")); + } + + #[test] + fn test_resolve_filepath_generates_filename_with_title() { + let record = create_test_record("test-123", "My Session"); + let path = resolve_filepath(None, &record); + let filename = path.file_name().unwrap().to_str().unwrap(); + + assert!(filename.starts_with("My_Session_")); + assert!(filename.ends_with(".md")); + } + + #[test] + fn test_resolve_filepath_generates_filename_with_id_when_title_empty() { + let record = create_test_record("abc12345-xyz", ""); + let path = resolve_filepath(None, &record); + let filename = path.file_name().unwrap().to_str().unwrap(); + + assert!(filename.starts_with("session_abc123")); + assert!(filename.ends_with(".md")); + } + + #[test] + fn test_escape_yaml_string() { + assert_eq!(escape_yaml_string("simple"), "simple"); + assert_eq!(escape_yaml_string("with: colon"), "\"with: colon\""); + assert_eq!(escape_yaml_string("with \"quote\""), "\"with \\\"quote\\\"\""); + } + + #[test] + fn test_format_message_content() { + assert_eq!(format_message_content("hello"), "hello"); + assert_eq!( + format_message_content("```code```"), + "```\n```code```\n```" + ); + } + + #[test] + fn test_generate_markdown_structure() { + let record = create_test_record("test-123", "Test Session"); + let messages = vec![crate::bus::ChatMessage::system("System prompt here")]; + + let markdown = generate_markdown(&record, &None, &messages); + + assert!(markdown.contains("---")); + assert!(markdown.contains("title:")); + assert!(markdown.contains("session_id: test-123")); + assert!(markdown.contains("# System Prompt")); + assert!(markdown.contains("# Message History")); + assert!(markdown.contains("## Message 1")); + assert!(markdown.contains("**Role:** system")); + } + + #[test] + fn test_can_handle() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let handler = SaveSessionCommandHandler::new(store, test_config()); + + assert!(handler.can_handle(&Command::SaveSession { filepath: None })); + assert!(!handler.can_handle(&Command::CreateSession { title: None })); + } +} diff --git a/src/command/handlers/session.rs b/src/command/handlers/session.rs index 9a32d47..0e44ce0 100644 --- a/src/command/handlers/session.rs +++ b/src/command/handlers/session.rs @@ -35,6 +35,7 @@ impl CommandHandler for SessionCommandHandler { ) -> Result<CommandResponse, CommandError> { match cmd { Command::CreateSession { title } => handle_create_session(self, title, ctx).await, + Command::SaveSession { .. } => unreachable!("SaveSession should be handled by SaveSessionCommandHandler"), } } } diff --git a/src/command/mod.rs b/src/command/mod.rs index 2faa02b..e4b02fb 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -11,8 +11,10 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Command { - /// 目前仅实现 /new 命令 + /// 创建新会话 CreateSession { title: Option<String> }, + /// 保存会话内容到 Markdown 文件 + SaveSession { filepath: Option<String> }, } impl Command { @@ -20,6 +22,7 @@ impl Command { pub fn name(&self) -> &'static str { match self { Command::CreateSession { .. } => "create_session", + Command::SaveSession { .. } => "save_session", } } } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 8faeb70..9d9839d 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -5,6 +5,7 @@ use crate::command::adapter::OutputAdapter; use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter}; use crate::command::context::CommandContext; use crate::command::handler::CommandRouter; +use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use axum::extract::State; @@ -348,6 +349,56 @@ async fn handle_inbound( Ok(()) } + WsInbound::SaveSession { filepath, session_id } => { + let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone()); + + // 获取所需依赖 + let store = state.session_manager.store(); + let provider_config = state.config.get_provider_config("default") + .map_err(|e| AgentError::Other(e.to_string()))?; + + // 构建处理器 + let handler = SaveSessionCommandHandler::new(store, provider_config); + let router = { + let mut r = CommandRouter::new(); + r.register(Box::new(handler)); + r + }; + + // 构建命令 + let cmd = crate::command::Command::SaveSession { filepath }; + let cmd_ctx = CommandContext::new("websocket") + .with_session_id(&target_session_id); + + // 执行命令 + let response = router.dispatch_with_response(cmd, cmd_ctx).await; + + // 处理响应 + if response.success { + let filepath = response + .metadata + .get("filepath") + .cloned() + .unwrap_or_default(); + let _ = sender + .send(WsOutbound::SessionSaved { + session_id: target_session_id, + filepath, + }) + .await; + } else { + let error = response.error.unwrap_or_else(|| { + crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error") + }); + let _ = sender + .send(WsOutbound::Error { + code: error.code, + message: error.message, + }) + .await; + } + Ok(()) + } WsInbound::Ping => { let _ = sender.send(WsOutbound::Pong).await; Ok(()) diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 3382d12..baad9d2 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -62,6 +62,13 @@ pub enum WsInbound { #[serde(default, skip_serializing_if = "Option::is_none")] session_id: Option<String>, }, + #[serde(rename = "save_session")] + SaveSession { + #[serde(default, skip_serializing_if = "Option::is_none")] + filepath: Option<String>, + #[serde(default, skip_serializing_if = "Option::is_none")] + session_id: Option<String>, + }, #[serde(rename = "ping")] Ping, } @@ -127,6 +134,8 @@ pub enum WsOutbound { SessionDeleted { session_id: String }, #[serde(rename = "history_cleared")] HistoryCleared { session_id: String }, + #[serde(rename = "session_saved")] + SessionSaved { session_id: String, filepath: String }, #[serde(rename = "pong")] Pong, }