PicoBot/src/command/handlers/save_session.rs

546 lines
18 KiB
Rust

use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
use crate::bus::InboundMessage;
use crate::command::context::CommandContext;
use crate::command::handler::{CommandHandler, InChatCommandHandler};
use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command;
use crate::storage::{SessionRecord, SessionStore};
use crate::agent::AgentError;
use async_trait::async_trait;
use chrono::{Local, TimeZone};
use std::path::PathBuf;
use std::sync::Arc;
/// Exports a stored session to a Markdown file and returns the path that was written.
///
/// This is a shared helper: both the command handler and the in-chat handler call it, and
/// other modules may reuse it.
///
/// # Arguments
/// * `session_id` - ID of the session to export.
/// * `filepath` - Destination path; when `None`, a path is generated by [`resolve_filepath`].
/// * `include_all` - When `true`, export every stored message, including those before the
///   reset cutoff (`load_all_messages`); otherwise only `load_messages` is exported.
/// * `store` - Session store the record and messages are read from.
/// * `system_prompt_provider` - Builds the system prompt embedded at the top of the file.
///
/// # Returns
/// The path of the written file.
///
/// # Errors
/// Returns a human-readable message when the session lookup fails, the session does not
/// exist, loading messages fails, or creating the parent directory / writing the file fails.
///
/// NOTE(review): despite being `async`, this performs blocking `std::fs` I/O on the calling
/// task; consider `spawn_blocking` or async file APIs if exports can be large.
pub async fn save_session_to_file(
    session_id: &str,
    filepath: Option<String>,
    include_all: bool,
    store: &SessionStore,
    system_prompt_provider: &dyn SystemPromptProvider,
) -> Result<PathBuf, String> {
    let record = store
        .get_session(session_id)
        .map_err(|e| format!("Failed to get session: {}", e))?
        .ok_or_else(|| "Session not found".to_string())?;

    // `include_all` selects whether messages before the reset cutoff are exported as well.
    let messages = if include_all {
        store.load_all_messages(session_id)
    } else {
        store.load_messages(session_id)
    }
    .map_err(|e| format!("Failed to load messages: {}", e))?;

    // The provider is told how many user turns are in the exported range.
    let user_message_count = messages.iter().filter(|m| m.role == "user").count();
    let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count);
    let markdown = generate_markdown(&record, &system_prompt, &messages);

    let output_path = resolve_filepath(filepath, &record);

    // A bare file name has an empty parent, so there is nothing to create. `create_dir_all`
    // already succeeds when the directory exists, so no separate (racy) `exists()` check.
    if let Some(parent) = output_path.parent().filter(|p| !p.as_os_str().is_empty()) {
        std::fs::create_dir_all(parent)
            .map_err(|e| format!("Failed to create directory: {}", e))?;
    }

    std::fs::write(&output_path, markdown)
        .map_err(|e| format!("Failed to write file: {}", e))?;

    Ok(output_path)
}
/// Command handler for `Command::SaveSession`.
///
/// Exports the current session to a Markdown file. The file contains the system prompt
/// produced by the configured [`SystemPromptProvider`] and the stored message history.
pub struct SaveSessionCommandHandler {
    /// Store the session record and its messages are read from.
    store: Arc<SessionStore>,
    /// Builds the system prompt embedded at the top of the exported file.
    system_prompt_provider: Arc<dyn SystemPromptProvider>,
}

impl SaveSessionCommandHandler {
    /// Creates a handler that reads sessions from `store` and renders their system prompt
    /// through `system_prompt_provider`.
    pub fn new(store: Arc<SessionStore>, system_prompt_provider: Arc<dyn SystemPromptProvider>) -> Self {
        Self { system_prompt_provider, store }
    }

    /// Test-only accessor for the underlying session store.
    #[cfg(test)]
    fn store(&self) -> &Arc<SessionStore> {
        &self.store
    }
}
#[async_trait]
impl CommandHandler for SaveSessionCommandHandler {
fn can_handle(&self, cmd: &Command) -> bool {
matches!(cmd, Command::SaveSession { .. })
}
async fn handle(
&self,
cmd: Command,
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
match cmd {
Command::SaveSession { filepath, include_all } => {
handle_save_session(self, filepath, include_all, ctx).await
}
_ => unreachable!(),
}
}
}
/// Executes a `SaveSession` command for the session bound to `ctx`.
///
/// The session is written to Markdown by `save_session_to_file`. The same message range is
/// then reloaded to report `message_count` in the response metadata, next to `filepath`.
///
/// # Errors
/// * `NO_SESSION` - `ctx.session_id` is `None`.
/// * `SAVE_ERROR` - the export itself failed.
/// * `LOAD_MESSAGES_ERROR` - reloading the messages for the count failed.
async fn handle_save_session(
    handler: &SaveSessionCommandHandler,
    filepath: Option<String>,
    include_all: bool,
    ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
    let Some(session_id) = ctx.session_id.as_deref() else {
        return Err(CommandError::new("NO_SESSION", "No active session".to_string()));
    };

    let store = &*handler.store;
    let output_path = save_session_to_file(
        session_id,
        filepath,
        include_all,
        store,
        &*handler.system_prompt_provider,
    )
    .await
    .map_err(|e| CommandError::new("SAVE_ERROR", e))?;

    // NOTE(review): this loads the same messages a second time only to count them;
    // the count may also differ from what was written if messages arrived in between.
    let reloaded = if include_all {
        store.load_all_messages(session_id)
    } else {
        store.load_messages(session_id)
    };
    let message_count = reloaded
        .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?
        .len();

    let notice = format!("Session saved to: {}", output_path.display());
    let count_text = message_count.to_string();
    Ok(CommandResponse::success(ctx.request_id)
        .with_message(MessageKind::Notification, &notice)
        .with_metadata("filepath", output_path.to_string_lossy().as_ref())
        .with_metadata("message_count", &count_text))
}
/// Asks `provider` for the system prompt of the session described by `record`.
///
/// The context passed to the provider carries the session ID, the session's chat ID, and
/// the caller-supplied number of user messages. Whatever the provider returns is passed
/// through unchanged (`None` when it has no prompt).
fn build_system_prompt(
    provider: &dyn SystemPromptProvider,
    record: &SessionRecord,
    user_message_count: usize,
) -> Option<SystemPrompt> {
    let session_id = Some(record.id.clone());
    let chat_id = record.chat_id.clone();
    provider.build(&SystemPromptContext {
        session_id,
        chat_id,
        user_message_count,
    })
}
/// 生成 Markdown 内容
pub fn generate_markdown(
record: &SessionRecord,
system_prompt: &Option<SystemPrompt>,
messages: &[crate::bus::ChatMessage],
) -> String {
let mut output = String::new();
// YAML frontmatter
output.push_str("---\n");
output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title)));
output.push_str(&format!("session_id: {}\n", record.id));
output.push_str(&format!("channel: {}\n", record.channel_name));
output.push_str(&format!("chat_id: {}\n", record.chat_id));
output.push_str(&format!(
"created_at: {}\n",
format_timestamp(record.created_at)
));
output.push_str(&format!(
"updated_at: {}\n",
format_timestamp(record.updated_at)
));
output.push_str(&format!(
"last_active_at: {}\n",
format_timestamp(record.last_active_at)
));
output.push_str(&format!("message_count: {}\n", messages.len()));
output.push_str("---\n\n");
// 系统提示词
output.push_str("# System Prompt\n\n");
if let Some(prompt) = system_prompt {
output.push_str("```\n");
output.push_str(&prompt.content);
output.push_str("\n```\n\n");
} else {
output.push_str("*No system prompt available*\n\n");
}
// 消息历史
output.push_str("# Message History\n\n");
for (idx, msg) in messages.iter().enumerate() {
output.push_str(&format!("## Message {}\n\n", idx + 1));
output.push_str(&format!("**Role:** {}\n\n", msg.role));
output.push_str(&format!("**ID:** {}\n\n", msg.id));
output.push_str(&format!(
"**Time:** {}\n\n",
format_timestamp(msg.timestamp)
));
if let Some(ref ctx) = msg.system_context {
output.push_str(&format!("**System Context:** `{}`\n\n", ctx));
}
if let Some(ref tool_name) = msg.tool_name {
output.push_str(&format!("**Tool Name:** `{}`\n\n", tool_name));
}
if let Some(ref tool_call_id) = msg.tool_call_id {
output.push_str(&format!("**Tool Call ID:** `{}`\n\n", tool_call_id));
}
if let Some(ref reasoning) = msg.reasoning_content {
output.push_str("### Reasoning\n\n");
output.push_str("```\n");
output.push_str(reasoning);
output.push_str("\n```\n\n");
}
// Content
output.push_str("### Content\n\n");
if msg.content.is_empty() {
output.push_str("*empty*\n\n");
} else {
output.push_str(&format!("{}\n\n", format_message_content(&msg.content)));
}
// Tool calls
if let Some(ref calls) = msg.tool_calls {
if !calls.is_empty() {
output.push_str("### Tool Calls\n\n");
for call in calls {
output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id));
output.push_str(" ```json\n");
let args_json = serde_json::to_string_pretty(&call.arguments)
.unwrap_or_else(|_| call.arguments.to_string());
for line in args_json.lines() {
output.push_str(&format!(" {}\n", line));
}
output.push_str(" ```\n");
}
output.push('\n');
}
}
// Media refs
if !msg.media_refs.is_empty() {
output.push_str("### Media References\n\n");
for media_ref in &msg.media_refs {
output.push_str(&format!("- `{}`\n", media_ref));
}
output.push('\n');
}
output.push_str("---\n\n");
}
output
}
/// Prepares a message's content for the Markdown body.
///
/// Content that contains a ``` fence marker or a `| ` table separator is wrapped in a plain
/// ``` block so it is shown verbatim. Anything else is returned unchanged.
///
/// NOTE(review): wrapping text that itself contains ``` in a three-backtick fence does not
/// protect it, because the first inner ``` closes the outer fence. A fence longer than the
/// longest backtick run would fix this. `tests::test_format_message_content` currently pins
/// the three-backtick output, so both must change together.
fn format_message_content(content: &str) -> String {
    let looks_structured = content.contains("```") || content.contains("| ");
    if !looks_structured {
        return content.to_string();
    }
    format!("```\n{}\n```", content)
}
/// Renders `s` as a YAML scalar for the export's front matter.
///
/// Unambiguous values are emitted plain. Otherwise the value becomes a double-quoted
/// scalar in which `\`, `"`, newline, carriage return, tab and other control characters
/// are escaped, so it parses back as the same string.
///
/// Quoting is triggered by any of:
/// * an empty value;
/// * a `:`, `"`, `#` or control character anywhere;
/// * leading or trailing whitespace;
/// * a leading YAML indicator character;
/// * a common boolean/null keyword (`true`, `no`, `null`, `~`, ...);
/// * anything that parses as a number (`f64`).
fn escape_yaml_string(s: &str) -> String {
    // Characters that cannot safely start a YAML plain scalar.
    const LEADING_INDICATORS: &[char] = &[
        '-', '?', ':', ',', '[', ']', '{', '}', '#', '&', '*', '!', '|', '>', '\'', '"', '%',
        '@', '`',
    ];

    // Plain values a YAML parser would not read back as a string.
    let looks_non_string = matches!(
        s.to_ascii_lowercase().as_str(),
        "true" | "false" | "yes" | "no" | "on" | "off" | "null" | "~"
    ) || s.parse::<f64>().is_ok();

    let needs_quotes = s.is_empty()
        || s.contains(|c: char| c == ':' || c == '"' || c == '#' || c.is_control())
        || s.starts_with(char::is_whitespace)
        || s.ends_with(char::is_whitespace)
        || s.starts_with(LEADING_INDICATORS)
        || looks_non_string;

    if !needs_quotes {
        return s.to_string();
    }

    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            // Backslash must be escaped first-class: inside double quotes YAML treats it
            // as the start of an escape sequence (e.g. `\n` in `C:\new`).
            '\\' => quoted.push_str("\\\\"),
            '"' => quoted.push_str("\\\""),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", c as u32)),
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}
/// Formats a Unix timestamp in milliseconds as local time, `YYYY-MM-DD HH:MM:SS`.
///
/// If `ts` does not map to exactly one local date-time (for example, it is outside the
/// range chrono can represent), the raw number is returned as a string instead.
pub fn format_timestamp(ts: i64) -> String {
    match Local.timestamp_millis_opt(ts).single() {
        Some(local_dt) => local_dt.format("%Y-%m-%d %H:%M:%S").to_string(),
        None => ts.to_string(),
    }
}
/// Resolves where an exported session should be written.
///
/// * `Some(path)` is used verbatim.
/// * `None` produces `<home>/.picobot/sessions/<name>_<YYYYmmdd_HHMMSS>.md`.
///   * `<name>` is the sanitized session title (see `sanitize_title`).
///   * When the title is empty, `<name>` is `session_` followed by the first 8 characters
///     of the session ID.
///   * The current directory (`.`) stands in for `<home>` when the home directory cannot
///     be determined.
pub fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf {
    if let Some(path) = filepath {
        return PathBuf::from(path);
    }

    let safe_title = sanitize_title(&record.title);
    let base_name = if safe_title.is_empty() {
        // Take the prefix by characters, not bytes: byte slicing panics when the 8th byte
        // falls inside a multi-byte UTF-8 character.
        let short_id: String = record.id.chars().take(8).collect();
        format!("session_{}", short_id)
    } else {
        safe_title
    };

    // A timestamp suffix keeps repeated exports of the same session from overwriting
    // each other.
    let timestamp = Local::now().format("%Y%m%d_%H%M%S");
    let filename = format!("{}_{}.md", base_name, timestamp);

    dirs::home_dir()
        .unwrap_or_else(|| PathBuf::from("."))
        .join(".picobot")
        .join("sessions")
        .join(filename)
}

/// Maximum number of bytes kept from a session title when building a file name.
///
/// Many filesystems (e.g. ext4) limit one file name to 255 bytes. This leaves room for the
/// `_YYYYmmdd_HHMMSS.md` suffix (19 bytes).
const MAX_TITLE_FILENAME_BYTES: usize = 200;

/// Turns a session title into a file-name-safe component.
///
/// Whitespace, control characters and the characters `/ \ : < > | ? * "` (reserved on
/// common filesystems) are replaced with `_`. The result is truncated at a character
/// boundary to at most [`MAX_TITLE_FILENAME_BYTES`] bytes.
fn sanitize_title(title: &str) -> String {
    let mut safe = String::with_capacity(title.len().min(MAX_TITLE_FILENAME_BYTES));
    for ch in title.chars() {
        let mapped = if ch.is_whitespace()
            || ch.is_control()
            || matches!(ch, '/' | '\\' | ':' | '<' | '>' | '|' | '?' | '*' | '"')
        {
            '_'
        } else {
            ch
        };
        if safe.len() + mapped.len_utf8() > MAX_TITLE_FILENAME_BYTES {
            break;
        }
        safe.push(mapped);
    }
    safe
}
/// In-chat variant of the save-session command handler.
///
/// Handles a `/save` typed directly into a channel conversation (e.g. Feishu or WeChat)
/// and answers with a plain-text status message.
pub struct SaveSessionInChatHandler {
    /// Store the session is exported from.
    store: Arc<SessionStore>,
    /// Builds the system prompt embedded in the export.
    system_prompt_provider: Arc<dyn SystemPromptProvider>,
}

impl SaveSessionInChatHandler {
    /// Creates an in-chat handler backed by `store` and `system_prompt_provider`.
    pub fn new(store: Arc<SessionStore>, system_prompt_provider: Arc<dyn SystemPromptProvider>) -> Self {
        Self { system_prompt_provider, store }
    }
}
#[async_trait]
impl InChatCommandHandler for SaveSessionInChatHandler {
    fn can_handle(&self, cmd: &Command) -> bool {
        matches!(cmd, Command::SaveSession { .. })
    }

    /// Handles a `/save` typed in a channel chat.
    ///
    /// Returns `Ok(None)` for commands other than `SaveSession`. Otherwise it returns
    /// `Ok(Some(reply))`, where `reply` reports either the saved path or the failure
    /// reason. Failures are logged and reported to the user rather than propagated as
    /// `AgentError`.
    async fn handle(
        &self,
        cmd: Command,
        inbound: &InboundMessage,
        session_manager: &crate::gateway::session::SessionManager,
    ) -> Result<Option<String>, AgentError> {
        let Command::SaveSession { filepath, include_all } = cmd else {
            return Ok(None);
        };

        let Some(session) = session_manager.get(&inbound.channel).await else {
            tracing::error!("Session not found for channel: {}", inbound.channel);
            return Ok(Some("Session not found".to_string()));
        };

        // Only the persistent session ID is needed from the live session. Copy it out and
        // release the lock before the export below, which does blocking file I/O. Keeping
        // the guard alive across that `.await` would stall every other task waiting on
        // this session's lock for the whole save.
        let session_guard = session.lock().await;
        let session_id = session_guard
            .persistent_session_id(&inbound.chat_id)
            .to_string();
        drop(session_guard);

        let result = save_session_to_file(
            &session_id,
            filepath,
            include_all,
            &*self.store,
            &*self.system_prompt_provider,
        )
        .await;

        let reply = match result {
            Ok(output_path) => {
                let msg = format!("Session saved to: {}", output_path.display());
                tracing::info!("{}", msg);
                msg
            }
            Err(error) => {
                let msg = format!("Failed to save session: {}", error);
                tracing::error!("{}", msg);
                msg
            }
        };
        Ok(Some(reply))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::{SessionRecord, SessionStore};

    /// Builds a `SessionRecord` fixture. `chat_id` mirrors `id`, the timestamps are fixed
    /// (2024-01-15 10:00:00 / 11:00:00 UTC), and all counters are zero.
    fn create_test_record(id: &str, title: &str) -> SessionRecord {
        let created_ms = 1705312800000; // 2024-01-15 10:00:00 UTC
        let updated_ms = 1705316400000; // 2024-01-15 11:00:00 UTC
        SessionRecord {
            id: id.to_owned(),
            chat_id: id.to_owned(),
            title: title.to_owned(),
            channel_name: String::from("cli"),
            summary: None,
            created_at: created_ms,
            updated_at: updated_ms,
            last_active_at: updated_ms,
            archived_at: None,
            deleted_at: None,
            message_count: 0,
            reset_cutoff_seq: 0,
            user_turn_count: 0,
            agent_prompt_reinjection_count: 0,
        }
    }

    /// Returns the file-name component of the auto-generated export path for `record`.
    fn generated_filename(record: &SessionRecord) -> String {
        resolve_filepath(None, record)
            .file_name()
            .and_then(|name| name.to_str())
            .expect("generated path should end in a UTF-8 file name")
            .to_owned()
    }

    #[test]
    fn test_resolve_filepath_with_custom_path() {
        let record = create_test_record("test-123", "My Session");
        let custom = "/custom/path/file.md";
        assert_eq!(
            resolve_filepath(Some(custom.to_string()), &record),
            PathBuf::from(custom)
        );
    }

    #[test]
    fn test_resolve_filepath_generates_filename_with_title() {
        let name = generated_filename(&create_test_record("test-123", "My Session"));
        assert!(name.starts_with("My_Session_"));
        assert!(name.ends_with(".md"));
    }

    #[test]
    fn test_resolve_filepath_generates_filename_with_id_when_title_empty() {
        let name = generated_filename(&create_test_record("abc12345-xyz", ""));
        assert!(name.starts_with("session_abc123"));
        assert!(name.ends_with(".md"));
    }

    #[test]
    fn test_escape_yaml_string() {
        let cases = [
            ("simple", "simple"),
            ("with: colon", "\"with: colon\""),
            ("with \"quote\"", "\"with \\\"quote\\\"\""),
        ];
        for (input, expected) in cases {
            assert_eq!(escape_yaml_string(input), expected);
        }
    }

    #[test]
    fn test_format_message_content() {
        assert_eq!(format_message_content("hello"), "hello");
        let fenced = format_message_content("```code```");
        assert_eq!(fenced, "```\n```code```\n```");
    }

    #[test]
    fn test_generate_markdown_structure() {
        let record = create_test_record("test-123", "Test Session");
        let history = [crate::bus::ChatMessage::system("System prompt here")];
        let markdown = generate_markdown(&record, &None, &history);
        let expected_fragments = [
            "---",
            "title:",
            "session_id: test-123",
            "# System Prompt",
            "# Message History",
            "## Message 1",
            "**Role:** system",
        ];
        for fragment in expected_fragments {
            assert!(
                markdown.contains(fragment),
                "missing `{}` in:\n{}",
                fragment,
                markdown
            );
        }
    }

    #[test]
    fn test_can_handle() {
        let handler = SaveSessionCommandHandler::new(
            Arc::new(SessionStore::in_memory().unwrap()),
            Arc::new(TestSystemPromptProvider),
        );
        for include_all in [false, true] {
            assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all }));
        }
        assert!(!handler.can_handle(&Command::CreateSession { title: None }));
    }

    /// Stub provider that always returns a fixed prompt, so handlers can be constructed
    /// without any real agent configuration.
    struct TestSystemPromptProvider;

    impl SystemPromptProvider for TestSystemPromptProvider {
        fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
            Some(SystemPrompt {
                content: String::from("Test system prompt"),
                context: Some(String::from("test")),
            })
        }
    }
}