862 lines
29 KiB
Rust
862 lines
29 KiB
Rust
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
|
||
use crate::bus::InboundMessage;
|
||
use crate::command::context::CommandContext;
|
||
use crate::command::handler::{CommandHandler, CommandMetadata, InChatCommandHandler};
|
||
use crate::command::response::{CommandError, CommandResponse, MessageKind};
|
||
use crate::command::Command;
|
||
use crate::storage::{SessionRecord, SessionStore};
|
||
use crate::tools::task::repository::TaskRepository;
|
||
use crate::agent::AgentError;
|
||
use async_trait::async_trait;
|
||
use chrono::{Local, TimeZone};
|
||
use std::path::PathBuf;
|
||
use std::sync::Arc;
|
||
|
||
/// 保存会话到文件(公共函数,可被命令处理器和其他模块复用)
|
||
///
|
||
/// # Arguments
|
||
/// * `session_id` - 会话ID
|
||
/// * `filepath` - 可选的文件路径
|
||
/// * `include_all` - 是否包含 cutoff 之前的所有消息
|
||
/// * `include_subagents` - 是否包含子智能体消息
|
||
/// * `store` - 会话存储
|
||
/// * `task_repository` - 任务存储(可选,用于查询子智能体)
|
||
/// * `system_prompt_provider` - 系统提示词提供者
|
||
///
|
||
/// # Returns
|
||
/// 返回保存的文件路径
|
||
pub async fn save_session_to_file(
|
||
session_id: &str,
|
||
filepath: Option<String>,
|
||
include_all: bool,
|
||
include_subagents: bool,
|
||
store: &SessionStore,
|
||
task_repository: Option<&dyn TaskRepository>,
|
||
system_prompt_provider: &dyn SystemPromptProvider,
|
||
) -> Result<PathBuf, String> {
|
||
// 获取会话记录
|
||
let record = store
|
||
.get_session(session_id)
|
||
.map_err(|e| format!("Failed to get session: {}", e))?
|
||
.ok_or_else(|| "Session not found".to_string())?;
|
||
|
||
// 根据 include_all 决定加载消息范围
|
||
let messages = if include_all {
|
||
store
|
||
.load_all_messages(session_id)
|
||
.map_err(|e| format!("Failed to load messages: {}", e))?
|
||
} else {
|
||
store
|
||
.load_messages(session_id)
|
||
.map_err(|e| format!("Failed to load messages: {}", e))?
|
||
};
|
||
|
||
// 加载子智能体消息(如果启用)
|
||
let subagent_data = if include_subagents {
|
||
load_subagent_data(session_id, None, store, task_repository).await
|
||
} else {
|
||
Vec::new()
|
||
};
|
||
|
||
// 计算用户消息数(用于系统提示词构建)
|
||
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
|
||
|
||
// 构建系统提示词(使用外部传入的提供者)
|
||
let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count);
|
||
|
||
// 生成 Markdown 内容
|
||
let markdown = generate_markdown_with_subagents(&record, &system_prompt, &messages, &subagent_data);
|
||
|
||
// 确定输出路径
|
||
let output_path = resolve_filepath(filepath, &record);
|
||
|
||
// 创建父目录
|
||
if let Some(parent) = output_path.parent() {
|
||
if !parent.as_os_str().is_empty() && !parent.exists() {
|
||
std::fs::create_dir_all(parent)
|
||
.map_err(|e| format!("Failed to create directory: {}", e))?;
|
||
}
|
||
}
|
||
|
||
// 写入文件
|
||
std::fs::write(&output_path, markdown)
|
||
.map_err(|e| format!("Failed to write file: {}", e))?;
|
||
|
||
Ok(output_path)
|
||
}
|
||
|
||
/// 保存会话命令处理器
|
||
///
|
||
/// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件
|
||
pub struct SaveSessionCommandHandler {
|
||
store: Arc<SessionStore>,
|
||
task_repository: Arc<dyn TaskRepository>,
|
||
system_prompt_provider: Arc<dyn SystemPromptProvider>,
|
||
}
|
||
|
||
impl SaveSessionCommandHandler {
|
||
/// 创建新的保存会话命令处理器
|
||
///
|
||
/// # Arguments
|
||
/// * `store` - 会话存储
|
||
/// * `task_repository` - 任务存储(用于查询子智能体)
|
||
/// * `system_prompt_provider` - 系统提示词提供者
|
||
pub fn new(
|
||
store: Arc<SessionStore>,
|
||
task_repository: Arc<dyn TaskRepository>,
|
||
system_prompt_provider: Arc<dyn SystemPromptProvider>,
|
||
) -> Self {
|
||
Self {
|
||
store,
|
||
task_repository,
|
||
system_prompt_provider,
|
||
}
|
||
}
|
||
|
||
/// 从会话记录获取存储(用于测试)
|
||
#[cfg(test)]
|
||
fn store(&self) -> &Arc<SessionStore> {
|
||
&self.store
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl CommandHandler for SaveSessionCommandHandler {
|
||
fn can_handle(&self, cmd: &Command) -> bool {
|
||
matches!(cmd, Command::SaveSession { .. })
|
||
}
|
||
|
||
fn metadata(&self) -> Option<CommandMetadata> {
|
||
Some(CommandMetadata {
|
||
name: "save-session",
|
||
description: "保存当前会话到 Markdown 文件",
|
||
usage: "/save-session [all] [filepath]",
|
||
})
|
||
}
|
||
|
||
async fn handle(
|
||
&self,
|
||
cmd: Command,
|
||
ctx: CommandContext,
|
||
) -> Result<CommandResponse, CommandError> {
|
||
match cmd {
|
||
Command::SaveSession { filepath, include_all, include_subagents } => {
|
||
handle_save_session(self, filepath, include_all, include_subagents, ctx).await
|
||
}
|
||
_ => unreachable!(),
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 处理保存会话命令
|
||
async fn handle_save_session(
|
||
handler: &SaveSessionCommandHandler,
|
||
filepath: Option<String>,
|
||
include_all: bool,
|
||
include_subagents: bool,
|
||
ctx: CommandContext,
|
||
) -> Result<CommandResponse, CommandError> {
|
||
tracing::debug!(
|
||
ctx_session_id = ?ctx.session_id,
|
||
ctx_chat_id = ?ctx.chat_id,
|
||
channel = %ctx.channel_name,
|
||
include_subagents = include_subagents,
|
||
"SaveSession command received"
|
||
);
|
||
|
||
let session_id = ctx
|
||
.session_id
|
||
.as_deref()
|
||
.ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?;
|
||
|
||
tracing::debug!(session_id = %session_id, "Attempting to save session");
|
||
|
||
// 先检查会话是否存在
|
||
match handler.store.get_session(session_id) {
|
||
Ok(Some(record)) => {
|
||
tracing::debug!(
|
||
session_id = %session_id,
|
||
title = %record.title,
|
||
chat_id = %record.chat_id,
|
||
message_count = record.message_count,
|
||
"Session found for saving"
|
||
);
|
||
}
|
||
Ok(None) => {
|
||
tracing::warn!(session_id = %session_id, "Session not found in store");
|
||
}
|
||
Err(e) => {
|
||
tracing::error!(session_id = %session_id, error = %e, "Error querying session");
|
||
}
|
||
}
|
||
|
||
// 调用公共函数
|
||
let output_path = save_session_to_file(
|
||
session_id,
|
||
filepath,
|
||
include_all,
|
||
include_subagents,
|
||
&*handler.store,
|
||
Some(handler.task_repository.as_ref()),
|
||
&*handler.system_prompt_provider,
|
||
)
|
||
.await
|
||
.map_err(|e| CommandError::new("SAVE_ERROR", e))?;
|
||
|
||
// 根据 include_all 获取消息数量
|
||
let message_count = if include_all {
|
||
handler
|
||
.store
|
||
.load_all_messages(session_id)
|
||
} else {
|
||
handler
|
||
.store
|
||
.load_messages(session_id)
|
||
}
|
||
.map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?
|
||
.len();
|
||
|
||
Ok(CommandResponse::success(ctx.request_id)
|
||
.with_message(
|
||
MessageKind::Notification,
|
||
&format!("Session saved to: {}", output_path.display()),
|
||
)
|
||
.with_metadata("filepath", output_path.to_string_lossy().as_ref())
|
||
.with_metadata("message_count", &message_count.to_string()))
|
||
}
|
||
|
||
/// Data for one subagent task, as included in a session export.
///
/// Built by `load_subagent_data` from a task-repository entry plus the messages stored
/// under the task's own session.
#[derive(Debug)]
pub struct SubagentTaskData {
    /// Task id from the task repository.
    pub task_id: String,
    /// Session id under which the subagent's messages are stored.
    pub session_id: String,
    /// Task description.
    pub description: String,
    /// Subagent type, as produced by the type's `as_str()`.
    pub subagent_type: String,
    /// Task state rendered with its `Debug` representation.
    pub state: String,
    /// Creation timestamp. Rendered by `format_timestamp`, which treats it as Unix milliseconds.
    pub created_at: i64,
    /// The subagent's messages (empty if they could not be loaded).
    pub messages: Vec<crate::bus::ChatMessage>,
}
|
||
|
||
/// 加载子智能体数据
|
||
///
|
||
/// # Arguments
|
||
/// * `parent_session_id` - 父会话 ID
|
||
/// * `parent_topic_id` - 可选的父话题 ID,如果提供则只加载该话题下的子智能体
|
||
/// * `store` - 会话存储
|
||
/// * `task_repository` - 任务存储(可选)
|
||
pub async fn load_subagent_data(
|
||
parent_session_id: &str,
|
||
parent_topic_id: Option<&str>,
|
||
store: &SessionStore,
|
||
task_repository: Option<&dyn TaskRepository>,
|
||
) -> Vec<SubagentTaskData> {
|
||
let Some(repo) = task_repository else {
|
||
return Vec::new();
|
||
};
|
||
|
||
// 获取子任务:如果提供了 topic_id,则按 topic 查询;否则按 session 查询
|
||
let tasks = match parent_topic_id {
|
||
Some(topic_id) => match repo.list_tasks_for_topic(topic_id).await {
|
||
Ok(tasks) => tasks,
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "Failed to list tasks for topic");
|
||
return Vec::new();
|
||
}
|
||
},
|
||
None => match repo.list_tasks_for_session(parent_session_id).await {
|
||
Ok(tasks) => tasks,
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "Failed to list tasks for session");
|
||
return Vec::new();
|
||
}
|
||
},
|
||
};
|
||
|
||
let mut result = Vec::new();
|
||
for task in tasks {
|
||
// 加载子智能体的消息
|
||
let messages = match store.load_all_messages(&task.session_id) {
|
||
Ok(msgs) => msgs,
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, task_id = %task.id, "Failed to load subagent messages");
|
||
Vec::new()
|
||
}
|
||
};
|
||
|
||
result.push(SubagentTaskData {
|
||
task_id: task.id,
|
||
session_id: task.session_id,
|
||
description: task.description,
|
||
subagent_type: task.subagent_type.as_str().to_string(),
|
||
state: format!("{:?}", task.state),
|
||
created_at: task.created_at,
|
||
messages,
|
||
});
|
||
}
|
||
|
||
result
|
||
}
|
||
|
||
/// 生成 Markdown 内容(包含子智能体)
|
||
pub fn generate_markdown_with_subagents(
|
||
record: &SessionRecord,
|
||
system_prompt: &Option<SystemPrompt>,
|
||
messages: &[crate::bus::ChatMessage],
|
||
subagent_data: &[SubagentTaskData],
|
||
) -> String {
|
||
let mut output = String::new();
|
||
|
||
// YAML frontmatter
|
||
output.push_str("---\n");
|
||
output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title)));
|
||
output.push_str(&format!("session_id: {}\n", record.id));
|
||
output.push_str(&format!("channel: {}\n", record.channel_name));
|
||
output.push_str(&format!("chat_id: {}\n", record.chat_id));
|
||
output.push_str(&format!(
|
||
"created_at: {}\n",
|
||
format_timestamp(record.created_at)
|
||
));
|
||
output.push_str(&format!(
|
||
"updated_at: {}\n",
|
||
format_timestamp(record.updated_at)
|
||
));
|
||
output.push_str(&format!(
|
||
"last_active_at: {}\n",
|
||
format_timestamp(record.last_active_at)
|
||
));
|
||
output.push_str(&format!("message_count: {}\n", messages.len()));
|
||
if !subagent_data.is_empty() {
|
||
output.push_str(&format!("subagent_count: {}\n", subagent_data.len()));
|
||
}
|
||
output.push_str("---\n\n");
|
||
|
||
// 系统提示词
|
||
output.push_str(&generate_system_prompt_markdown(system_prompt));
|
||
|
||
// 子智能体任务(如果有)
|
||
if !subagent_data.is_empty() {
|
||
output.push_str(&generate_subagent_tasks_markdown(subagent_data));
|
||
}
|
||
|
||
// 主会话消息历史
|
||
output.push_str(&generate_messages_markdown(messages));
|
||
|
||
output
|
||
}
|
||
|
||
/// 生成子智能体任务 Markdown
|
||
pub fn generate_subagent_tasks_markdown(subagent_data: &[SubagentTaskData]) -> String {
|
||
let mut output = String::new();
|
||
|
||
output.push_str("# Subagent Tasks\n\n");
|
||
|
||
for task in subagent_data {
|
||
output.push_str(&format!("## Task: {} ({})", task.description, task.subagent_type));
|
||
output.push('\n');
|
||
output.push_str(&format!("**Task ID:** `{}`\n\n", task.task_id));
|
||
output.push_str(&format!("**Session ID:** `{}`\n\n", task.session_id));
|
||
output.push_str(&format!("**Status:** {}\n\n", task.state));
|
||
output.push_str(&format!("**Created:** {}\n\n", format_timestamp(task.created_at)));
|
||
output.push_str(&format!("**Message Count:** {}\n\n", task.messages.len()));
|
||
|
||
// 子智能体消息
|
||
if !task.messages.is_empty() {
|
||
output.push_str("### Messages\n\n");
|
||
for (idx, msg) in task.messages.iter().enumerate() {
|
||
output.push_str(&format!("#### Message {}\n\n", idx + 1));
|
||
output.push_str(&format!("**Role:** {}\n\n", msg.role));
|
||
output.push_str(&format!("**Time:** {}\n\n", format_timestamp(msg.timestamp)));
|
||
|
||
if let Some(ref reasoning) = msg.reasoning_content {
|
||
output.push_str("**Reasoning:**\n");
|
||
output.push_str("```\n");
|
||
output.push_str(reasoning);
|
||
output.push_str("\n```\n\n");
|
||
}
|
||
|
||
output.push_str("**Content:**\n\n");
|
||
if msg.content.is_empty() {
|
||
output.push_str("*empty*\n\n");
|
||
} else {
|
||
output.push_str(&format!("{}\n\n", format_message_content(&msg.content)));
|
||
}
|
||
|
||
// 工具调用
|
||
if let Some(ref calls) = msg.tool_calls {
|
||
if !calls.is_empty() {
|
||
output.push_str("**Tool Calls:**\n\n");
|
||
for call in calls {
|
||
output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id));
|
||
output.push_str(" ```json\n");
|
||
let args_json = serde_json::to_string_pretty(&call.arguments)
|
||
.unwrap_or_else(|_| call.arguments.to_string());
|
||
for line in args_json.lines() {
|
||
output.push_str(&format!(" {}\n", line));
|
||
}
|
||
output.push_str(" ```\n");
|
||
}
|
||
output.push('\n');
|
||
}
|
||
}
|
||
|
||
output.push_str("---\n\n");
|
||
}
|
||
}
|
||
|
||
output.push_str("---\n\n");
|
||
}
|
||
|
||
output
|
||
}
|
||
|
||
/// 构建系统提示词
|
||
fn build_system_prompt(
|
||
provider: &dyn SystemPromptProvider,
|
||
record: &SessionRecord,
|
||
user_message_count: usize,
|
||
) -> Option<SystemPrompt> {
|
||
let context = SystemPromptContext {
|
||
session_id: Some(record.id.clone()),
|
||
chat_id: record.chat_id.clone(),
|
||
user_message_count,
|
||
};
|
||
provider.build(&context)
|
||
}
|
||
|
||
/// 生成 Markdown 内容
|
||
pub fn generate_markdown(
|
||
record: &SessionRecord,
|
||
system_prompt: &Option<SystemPrompt>,
|
||
messages: &[crate::bus::ChatMessage],
|
||
) -> String {
|
||
let mut output = String::new();
|
||
|
||
// YAML frontmatter
|
||
output.push_str("---\n");
|
||
output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title)));
|
||
output.push_str(&format!("session_id: {}\n", record.id));
|
||
output.push_str(&format!("channel: {}\n", record.channel_name));
|
||
output.push_str(&format!("chat_id: {}\n", record.chat_id));
|
||
output.push_str(&format!(
|
||
"created_at: {}\n",
|
||
format_timestamp(record.created_at)
|
||
));
|
||
output.push_str(&format!(
|
||
"updated_at: {}\n",
|
||
format_timestamp(record.updated_at)
|
||
));
|
||
output.push_str(&format!(
|
||
"last_active_at: {}\n",
|
||
format_timestamp(record.last_active_at)
|
||
));
|
||
output.push_str(&format!("message_count: {}\n", messages.len()));
|
||
output.push_str("---\n\n");
|
||
|
||
// 系统提示词(复用公共函数)
|
||
output.push_str(&generate_system_prompt_markdown(system_prompt));
|
||
|
||
// 消息历史(复用公共函数)
|
||
output.push_str(&generate_messages_markdown(messages));
|
||
|
||
output
|
||
}
|
||
|
||
/// Formats message content for the Markdown export.
///
/// Content that contains a table marker (`"| "`) or a newline is wrapped in a fenced
/// code block so it is shown verbatim. All other content is returned unchanged.
///
/// The fence is one backtick longer than the longest run of backticks anywhere in the
/// content, including a run at the very end, and is at least four backticks. This keeps
/// embedded fences from closing the wrapper early.
pub fn format_message_content(content: &str) -> String {
    if !(content.contains("| ") || content.contains('\n')) {
        return content.to_string();
    }

    // Splitting on every non-backtick character leaves exactly the backtick runs
    // (plus empty pieces), so the longest piece is the longest run. Unlike a fold that
    // only records a run when it is followed by another character, this also counts a
    // trailing run.
    let longest_run = content
        .split(|c: char| c != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);
    let fence = "`".repeat(longest_run.max(3) + 1);

    format!("{}\n{}\n{}", fence, content, fence)
}
|
||
|
||
/// Formats `s` as a YAML scalar for the session frontmatter.
///
/// Strings that are safe as plain scalars are returned unchanged. Anything a YAML parser
/// could misread is emitted as a double-quoted scalar. That covers:
/// - empty strings, and leading or trailing whitespace;
/// - a leading indicator character (`-`, `?`, `:`, `,`, brackets, braces, `#`, `&`, `*`,
///   `!`, `|`, `>`, quotes, `%`, `@`, backtick);
/// - `:` or `#` anywhere, double quotes, backslashes, and control characters such as
///   newlines;
/// - values that would be read as booleans, null, or numbers.
///
/// Inside the quotes, `\` and `"` are escaped, and newline/CR/tab and other control
/// characters are emitted as YAML escape sequences.
pub fn escape_yaml_string(s: &str) -> String {
    const LEADING_INDICATORS: &[char] = &[
        '-', '?', ':', ',', '[', ']', '{', '}', '#', '&', '*', '!', '|', '>', '\'', '"', '%',
        '@', '`',
    ];
    // Plain scalars that YAML 1.1/1.2 parsers resolve to booleans or null.
    const RESERVED_WORDS: &[&str] =
        &["true", "false", "yes", "no", "on", "off", "y", "n", "null", "~"];

    let needs_quotes = s.is_empty()
        || s.starts_with(char::is_whitespace)
        || s.ends_with(char::is_whitespace)
        || s.starts_with(LEADING_INDICATORS)
        || s.chars().any(|c| c.is_control() || matches!(c, '"' | '\\' | ':' | '#'))
        || RESERVED_WORDS.iter().any(|word| s.eq_ignore_ascii_case(word))
        || s.parse::<f64>().is_ok();

    if !needs_quotes {
        return s.to_string();
    }

    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            // Backslash first: it is the escape character of double-quoted scalars.
            '\\' => quoted.push_str("\\\\"),
            '"' => quoted.push_str("\\\""),
            // A raw line break inside a double-quoted scalar is folded into a space,
            // so it must be written as an escape to round-trip.
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            c if c.is_control() => quoted.push_str(&format!("\\u{:04x}", c as u32)),
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}
|
||
|
||
/// 格式化时间戳
|
||
pub fn format_timestamp(ts: i64) -> String {
|
||
Local
|
||
.timestamp_millis_opt(ts)
|
||
.single()
|
||
.map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
|
||
.unwrap_or_else(|| format!("{}", ts))
|
||
}
|
||
|
||
/// 生成消息历史的 Markdown 内容
|
||
///
|
||
/// 这是一个通用函数,可被 Session 和 Topic 的保存逻辑复用
|
||
pub fn generate_messages_markdown(messages: &[crate::bus::ChatMessage]) -> String {
|
||
let mut output = String::new();
|
||
|
||
output.push_str("# Message History\n\n");
|
||
|
||
for (idx, msg) in messages.iter().enumerate() {
|
||
output.push_str(&format!("## Message {}\n\n", idx + 1));
|
||
output.push_str(&format!("**Role:** {}\n\n", msg.role));
|
||
output.push_str(&format!("**ID:** {}\n\n", msg.id));
|
||
output.push_str(&format!(
|
||
"**Time:** {}\n\n",
|
||
format_timestamp(msg.timestamp)
|
||
));
|
||
|
||
if let Some(ref ctx) = msg.system_context {
|
||
output.push_str(&format!("**System Context:** `{}`\n\n", ctx));
|
||
}
|
||
|
||
if let Some(ref tool_name) = msg.tool_name {
|
||
output.push_str(&format!("**Tool Name:** `{}`\n\n", tool_name));
|
||
}
|
||
|
||
if let Some(ref tool_call_id) = msg.tool_call_id {
|
||
output.push_str(&format!("**Tool Call ID:** `{}`\n\n", tool_call_id));
|
||
}
|
||
|
||
if let Some(ref reasoning) = msg.reasoning_content {
|
||
output.push_str("### Reasoning\n\n");
|
||
output.push_str("```\n");
|
||
output.push_str(reasoning);
|
||
output.push_str("\n```\n\n");
|
||
}
|
||
|
||
// Content
|
||
output.push_str("### Content\n\n");
|
||
if msg.content.is_empty() {
|
||
output.push_str("*empty*\n\n");
|
||
} else {
|
||
output.push_str(&format!("{}\n\n", format_message_content(&msg.content)));
|
||
}
|
||
|
||
// Tool calls
|
||
if let Some(ref calls) = msg.tool_calls {
|
||
if !calls.is_empty() {
|
||
output.push_str("### Tool Calls\n\n");
|
||
for call in calls {
|
||
output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id));
|
||
output.push_str(" ```json\n");
|
||
let args_json = serde_json::to_string_pretty(&call.arguments)
|
||
.unwrap_or_else(|_| call.arguments.to_string());
|
||
for line in args_json.lines() {
|
||
output.push_str(&format!(" {}\n", line));
|
||
}
|
||
output.push_str(" ```\n");
|
||
}
|
||
output.push('\n');
|
||
}
|
||
}
|
||
|
||
// Media refs
|
||
if !msg.media_refs.is_empty() {
|
||
output.push_str("### Media References\n\n");
|
||
for media_ref in &msg.media_refs {
|
||
output.push_str(&format!("- `{}`\n", media_ref));
|
||
}
|
||
output.push('\n');
|
||
}
|
||
|
||
output.push_str("---\n\n");
|
||
}
|
||
|
||
output
|
||
}
|
||
|
||
/// 生成系统提示词部分的 Markdown
|
||
pub fn generate_system_prompt_markdown(system_prompt: &Option<SystemPrompt>) -> String {
|
||
let mut output = String::new();
|
||
|
||
output.push_str("# System Prompt\n\n");
|
||
if let Some(prompt) = system_prompt {
|
||
output.push_str("```\n");
|
||
output.push_str(&prompt.content);
|
||
output.push_str("\n```\n\n");
|
||
} else {
|
||
output.push_str("*No system prompt available*\n\n");
|
||
}
|
||
|
||
output
|
||
}
|
||
|
||
/// 解析文件路径
|
||
///
|
||
/// 如果未提供路径,自动生成基于会话标题和时间戳的文件名,
|
||
/// 保存到用户主目录下的 .picobot/sessions/ 目录
|
||
pub fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf {
|
||
match filepath {
|
||
Some(path) => PathBuf::from(path),
|
||
None => {
|
||
// 生成安全标题(替换特殊字符)
|
||
let safe_title = record
|
||
.title
|
||
.replace(' ', "_")
|
||
.replace('/', "_")
|
||
.replace('\\', "_")
|
||
.replace(':', "_")
|
||
.replace('<', "_")
|
||
.replace('>', "_")
|
||
.replace('|', "_")
|
||
.replace('?', "_")
|
||
.replace('*', "_")
|
||
.replace('"', "_");
|
||
|
||
// 使用标题或 session_id 作为文件名
|
||
let base_name = if safe_title.is_empty() {
|
||
format!("session_{}", &record.id[..8.min(record.id.len())])
|
||
} else {
|
||
safe_title
|
||
};
|
||
|
||
// 添加时间戳
|
||
let timestamp = Local::now().format("%Y%m%d_%H%M%S");
|
||
let filename = format!("{}_{}.md", base_name, timestamp);
|
||
|
||
// 保存到用户主目录下的 .picobot/sessions/ 目录
|
||
dirs::home_dir()
|
||
.unwrap_or_else(|| PathBuf::from("."))
|
||
.join(".picobot")
|
||
.join("sessions")
|
||
.join(filename)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// InChat 保存会话命令处理器
|
||
///
|
||
/// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令
|
||
pub struct SaveSessionInChatHandler {
|
||
store: Arc<SessionStore>,
|
||
task_repository: Arc<dyn TaskRepository>,
|
||
system_prompt_provider: Arc<dyn SystemPromptProvider>,
|
||
}
|
||
|
||
impl SaveSessionInChatHandler {
|
||
/// 创建新的 InChat 保存会话命令处理器
|
||
pub fn new(
|
||
store: Arc<SessionStore>,
|
||
task_repository: Arc<dyn TaskRepository>,
|
||
system_prompt_provider: Arc<dyn SystemPromptProvider>,
|
||
) -> Self {
|
||
Self {
|
||
store,
|
||
task_repository,
|
||
system_prompt_provider,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl InChatCommandHandler for SaveSessionInChatHandler {
|
||
fn can_handle(&self, cmd: &Command) -> bool {
|
||
matches!(cmd, Command::SaveSession { .. })
|
||
}
|
||
|
||
async fn handle(
|
||
&self,
|
||
cmd: Command,
|
||
inbound: &InboundMessage,
|
||
session_manager: &crate::gateway::session::SessionManager,
|
||
) -> Result<Option<String>, AgentError> {
|
||
let Command::SaveSession { filepath, include_all, include_subagents } = cmd else {
|
||
return Ok(None);
|
||
};
|
||
|
||
// 通过 session_manager 获取 session
|
||
let session = match session_manager.get(&inbound.channel).await {
|
||
Some(s) => s,
|
||
None => {
|
||
tracing::error!("Session not found for channel: {}", inbound.channel);
|
||
return Ok(Some("Session not found".to_string()));
|
||
}
|
||
};
|
||
|
||
let session_guard = session.lock().await;
|
||
let session_id = session_guard.persistent_session_id(&inbound.chat_id);
|
||
|
||
// 调用公共函数
|
||
let result = save_session_to_file(
|
||
&session_id,
|
||
filepath,
|
||
include_all,
|
||
include_subagents,
|
||
&*self.store,
|
||
Some(self.task_repository.as_ref()),
|
||
&*self.system_prompt_provider,
|
||
)
|
||
.await;
|
||
|
||
// 返回成功或失败消息
|
||
match result {
|
||
Ok(output_path) => {
|
||
let msg = format!("Session saved to: {}", output_path.display());
|
||
tracing::info!("{}", msg);
|
||
Ok(Some(msg))
|
||
}
|
||
Err(error) => {
|
||
let msg = format!("Failed to save session: {}", error);
|
||
tracing::error!("{}", msg);
|
||
Ok(Some(msg))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::{SessionRecord, SessionStore};

    /// Builds a `SessionRecord` with fixed millisecond timestamps; `chat_id` mirrors `id`.
    fn create_test_record(id: &str, title: &str) -> SessionRecord {
        SessionRecord {
            id: id.to_string(),
            title: title.to_string(),
            channel_name: "cli".to_string(),
            chat_id: id.to_string(),
            summary: None,
            created_at: 1705312800000, // 2024-01-15 10:00:00 UTC (Unix milliseconds)
            updated_at: 1705316400000, // 2024-01-15 11:00:00 UTC (Unix milliseconds)
            last_active_at: 1705316400000,
            archived_at: None,
            deleted_at: None,
            message_count: 0,
            user_turn_count: 0,
            agent_prompt_reinjection_count: 0,
        }
    }

    // An explicit path is returned unchanged.
    #[test]
    fn test_resolve_filepath_with_custom_path() {
        let record = create_test_record("test-123", "My Session");
        let path = resolve_filepath(Some("/custom/path/file.md".to_string()), &record);
        assert_eq!(path, PathBuf::from("/custom/path/file.md"));
    }

    // Without a path, the file name is derived from the title (spaces become `_`).
    #[test]
    fn test_resolve_filepath_generates_filename_with_title() {
        let record = create_test_record("test-123", "My Session");
        let path = resolve_filepath(None, &record);
        let filename = path.file_name().unwrap().to_str().unwrap();

        assert!(filename.starts_with("My_Session_"));
        assert!(filename.ends_with(".md"));
    }

    // With an empty title, the file name falls back to a prefix of the session id.
    #[test]
    fn test_resolve_filepath_generates_filename_with_id_when_title_empty() {
        let record = create_test_record("abc12345-xyz", "");
        let path = resolve_filepath(None, &record);
        let filename = path.file_name().unwrap().to_str().unwrap();

        assert!(filename.starts_with("session_abc123"));
        assert!(filename.ends_with(".md"));
    }

    #[test]
    fn test_escape_yaml_string() {
        assert_eq!(escape_yaml_string("simple"), "simple");
        assert_eq!(escape_yaml_string("with: colon"), "\"with: colon\"");
        assert_eq!(escape_yaml_string("with \"quote\""), "\"with \\\"quote\\\"\"");
    }

    #[test]
    fn test_format_message_content() {
        // Plain single-line text is returned unchanged.
        assert_eq!(format_message_content("hello"), "hello");

        // A single line with backticks but no table marker is returned unchanged.
        assert_eq!(format_message_content("`code`"), "`code`");

        // Content with newlines is wrapped in 4 backticks (the minimum fence).
        assert_eq!(
            format_message_content("line1\nline2\nline3"),
            "````\nline1\nline2\nline3\n````"
        );

        // Content with a table marker is wrapped in 4 backticks.
        assert_eq!(
            format_message_content("| col1 | col2 |"),
            "````\n| col1 | col2 |\n````"
        );

        // Multi-line content with a 3-backtick run (a code-fence marker) is wrapped in 4 backticks.
        assert_eq!(
            format_message_content("```code```\nmore"),
            "````\n```code```\nmore\n````"
        );

        // Multi-line content containing a complete fenced code block.
        assert_eq!(
            format_message_content("```\ncode\n```\nmore"),
            "````\n```\ncode\n```\nmore\n````"
        );

        // Multi-line content with a 4-backtick run is wrapped in 5 backticks.
        assert_eq!(
            format_message_content("````code````\nmore"),
            "`````\n````code````\nmore\n`````"
        );
    }

    // Frontmatter, system prompt and message history sections are all present.
    #[test]
    fn test_generate_markdown_structure() {
        let record = create_test_record("test-123", "Test Session");
        let messages = vec![crate::bus::ChatMessage::system("System prompt here")];

        let markdown = generate_markdown(&record, &None, &messages);

        assert!(markdown.contains("---"));
        assert!(markdown.contains("title:"));
        assert!(markdown.contains("session_id: test-123"));
        assert!(markdown.contains("# System Prompt"));
        assert!(markdown.contains("# Message History"));
        assert!(markdown.contains("## Message 1"));
        assert!(markdown.contains("**Role:** system"));
    }

    // The handler accepts every SaveSession variant and rejects other commands.
    #[test]
    fn test_can_handle() {
        let store = Arc::new(SessionStore::in_memory().unwrap());
        let task_repository = Arc::new(crate::tools::task::repository::InMemoryTaskRepository::new());
        let provider = Arc::new(TestSystemPromptProvider);
        let handler = SaveSessionCommandHandler::new(store, task_repository, provider);

        assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false, include_subagents: false }));
        assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true, include_subagents: false }));
        assert!(!handler.can_handle(&Command::CreateSession { title: None }));
        assert!(!handler.can_handle(&Command::SaveTopic { filepath: None, include_subagents: false }));
    }

    /// System prompt provider used by tests; always returns the same fixed prompt.
    struct TestSystemPromptProvider;

    impl SystemPromptProvider for TestSystemPromptProvider {
        fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
            Some(SystemPrompt {
                content: "Test system prompt".to_string(),
                context: Some("test".to_string()),
            })
        }
    }
}
|