PicoBot/src/command/handlers/save_session.rs

862 lines
29 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
use crate::bus::InboundMessage;
use crate::command::context::CommandContext;
use crate::command::handler::{CommandHandler, CommandMetadata, InChatCommandHandler};
use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command;
use crate::storage::{SessionRecord, SessionStore};
use crate::tools::task::repository::TaskRepository;
use crate::agent::AgentError;
use async_trait::async_trait;
use chrono::{Local, TimeZone};
use std::path::PathBuf;
use std::sync::Arc;
/// Saves a session to a Markdown file.
///
/// Shared helper used by the command handlers in this file and reusable by other modules.
///
/// # Arguments
/// * `session_id` - ID of the session to export
/// * `filepath` - optional output path; when `None`, `resolve_filepath` chooses a default
/// * `include_all` - when true, load messages via `load_all_messages` (including those before
///   the cutoff); otherwise via `load_messages`
/// * `include_subagents` - when true, append sub-agent tasks and their messages
/// * `store` - session store
/// * `task_repository` - optional task repository used to look up sub-agent tasks
/// * `system_prompt_provider` - provider used to rebuild the system prompt
///
/// # Returns
/// The path the Markdown file was written to.
///
/// # Errors
/// Returns a human-readable message when the session lookup fails or finds nothing, when
/// messages cannot be loaded, or when the output directory or file cannot be written.
pub async fn save_session_to_file(
    session_id: &str,
    filepath: Option<String>,
    include_all: bool,
    include_subagents: bool,
    store: &SessionStore,
    task_repository: Option<&dyn TaskRepository>,
    system_prompt_provider: &dyn SystemPromptProvider,
) -> Result<PathBuf, String> {
    let Some(record) = store
        .get_session(session_id)
        .map_err(|e| format!("Failed to get session: {}", e))?
    else {
        return Err("Session not found".to_string());
    };

    // The message range depends on whether pre-cutoff history was requested.
    let loaded = if include_all {
        store.load_all_messages(session_id)
    } else {
        store.load_messages(session_id)
    };
    let messages = loaded.map_err(|e| format!("Failed to load messages: {}", e))?;

    let subagent_data = if include_subagents {
        load_subagent_data(session_id, None, store, task_repository).await
    } else {
        Vec::new()
    };

    // The system prompt builder takes the number of user turns in the exported history.
    let user_turns = messages.iter().filter(|msg| msg.role == "user").count();
    let prompt = build_system_prompt(system_prompt_provider, &record, user_turns);
    let document = generate_markdown_with_subagents(&record, &prompt, &messages, &subagent_data);

    let destination = resolve_filepath(filepath, &record);
    let missing_dir = destination
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty() && !dir.exists());
    if let Some(dir) = missing_dir {
        std::fs::create_dir_all(dir).map_err(|e| format!("Failed to create directory: {}", e))?;
    }
    std::fs::write(&destination, document).map_err(|e| format!("Failed to write file: {}", e))?;

    Ok(destination)
}
/// Command handler for `/save-session`.
///
/// Exports the current session (system prompt and message history) to a Markdown file.
pub struct SaveSessionCommandHandler {
    /// Store the session record and its messages are read from.
    store: Arc<SessionStore>,
    /// Task repository used to look up sub-agent tasks.
    task_repository: Arc<dyn TaskRepository>,
    /// Provider used to rebuild the system prompt for the export.
    system_prompt_provider: Arc<dyn SystemPromptProvider>,
}
impl SaveSessionCommandHandler {
    /// Creates a save-session handler.
    ///
    /// # Arguments
    /// * `store` - session store
    /// * `task_repository` - task repository (used to query sub-agents)
    /// * `system_prompt_provider` - system prompt provider
    pub fn new(
        store: Arc<SessionStore>,
        task_repository: Arc<dyn TaskRepository>,
        system_prompt_provider: Arc<dyn SystemPromptProvider>,
    ) -> Self {
        Self { system_prompt_provider, task_repository, store }
    }

    /// Test-only accessor for the underlying session store.
    #[cfg(test)]
    fn store(&self) -> &Arc<SessionStore> {
        &self.store
    }
}
#[async_trait]
impl CommandHandler for SaveSessionCommandHandler {
fn can_handle(&self, cmd: &Command) -> bool {
matches!(cmd, Command::SaveSession { .. })
}
fn metadata(&self) -> Option<CommandMetadata> {
Some(CommandMetadata {
name: "save-session",
description: "保存当前会话到 Markdown 文件",
usage: "/save-session [all] [filepath]",
})
}
async fn handle(
&self,
cmd: Command,
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
match cmd {
Command::SaveSession { filepath, include_all, include_subagents } => {
handle_save_session(self, filepath, include_all, include_subagents, ctx).await
}
_ => unreachable!(),
}
}
}
/// Handles `/save-session` for the session carried by `ctx`.
///
/// Exports the session via `save_session_to_file`. It then reloads the same message range to
/// report the message count, and attaches the output path and count as response metadata.
///
/// # Errors
/// * `NO_SESSION` when `ctx` carries no session ID
/// * `SAVE_ERROR` when the export fails
/// * `LOAD_MESSAGES_ERROR` when reloading messages for the count fails
async fn handle_save_session(
    handler: &SaveSessionCommandHandler,
    filepath: Option<String>,
    include_all: bool,
    include_subagents: bool,
    ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
    tracing::debug!(
        ctx_session_id = ?ctx.session_id,
        ctx_chat_id = ?ctx.chat_id,
        channel = %ctx.channel_name,
        include_subagents = include_subagents,
        "SaveSession command received"
    );

    let Some(session_id) = ctx.session_id.as_deref() else {
        return Err(CommandError::new("NO_SESSION", "No active session".to_string()));
    };
    tracing::debug!(session_id = %session_id, "Attempting to save session");

    // Diagnostic lookup only: the outcome is logged, and any real failure is reported by the
    // save call below.
    match handler.store.get_session(session_id) {
        Ok(None) => {
            tracing::warn!(session_id = %session_id, "Session not found in store");
        }
        Err(e) => {
            tracing::error!(session_id = %session_id, error = %e, "Error querying session");
        }
        Ok(Some(record)) => {
            tracing::debug!(
                session_id = %session_id,
                title = %record.title,
                chat_id = %record.chat_id,
                message_count = record.message_count,
                "Session found for saving"
            );
        }
    }

    let output_path = save_session_to_file(
        session_id,
        filepath,
        include_all,
        include_subagents,
        &*handler.store,
        Some(handler.task_repository.as_ref()),
        &*handler.system_prompt_provider,
    )
    .await
    .map_err(|e| CommandError::new("SAVE_ERROR", e))?;

    // The reported count uses the same message range that was exported.
    let reloaded = if include_all {
        handler.store.load_all_messages(session_id)
    } else {
        handler.store.load_messages(session_id)
    };
    let message_count = reloaded
        .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?
        .len();

    let notice = format!("Session saved to: {}", output_path.display());
    Ok(CommandResponse::success(ctx.request_id)
        .with_message(MessageKind::Notification, &notice)
        .with_metadata("filepath", output_path.to_string_lossy().as_ref())
        .with_metadata("message_count", &message_count.to_string()))
}
/// Snapshot of one sub-agent task plus its stored messages, collected for a session export.
#[derive(Debug)]
pub struct SubagentTaskData {
    /// ID of the task.
    pub task_id: String,
    /// Session ID under which the sub-agent's messages are stored.
    pub session_id: String,
    /// Task description (rendered in the task heading).
    pub description: String,
    /// Sub-agent type name.
    pub subagent_type: String,
    /// Task state. When populated by `load_subagent_data`, this is the `Debug` rendering of the
    /// task's state value.
    pub state: String,
    /// Task creation timestamp; `format_timestamp` interprets it as epoch milliseconds.
    pub created_at: i64,
    /// Messages of the sub-agent's session (`load_subagent_data` loads all of them).
    pub messages: Vec<crate::bus::ChatMessage>,
}
/// 加载子智能体数据
///
/// # Arguments
/// * `parent_session_id` - 父会话 ID
/// * `parent_topic_id` - 可选的父话题 ID如果提供则只加载该话题下的子智能体
/// * `store` - 会话存储
/// * `task_repository` - 任务存储(可选)
pub async fn load_subagent_data(
parent_session_id: &str,
parent_topic_id: Option<&str>,
store: &SessionStore,
task_repository: Option<&dyn TaskRepository>,
) -> Vec<SubagentTaskData> {
let Some(repo) = task_repository else {
return Vec::new();
};
// 获取子任务:如果提供了 topic_id则按 topic 查询;否则按 session 查询
let tasks = match parent_topic_id {
Some(topic_id) => match repo.list_tasks_for_topic(topic_id).await {
Ok(tasks) => tasks,
Err(e) => {
tracing::warn!(error = %e, "Failed to list tasks for topic");
return Vec::new();
}
},
None => match repo.list_tasks_for_session(parent_session_id).await {
Ok(tasks) => tasks,
Err(e) => {
tracing::warn!(error = %e, "Failed to list tasks for session");
return Vec::new();
}
},
};
let mut result = Vec::new();
for task in tasks {
// 加载子智能体的消息
let messages = match store.load_all_messages(&task.session_id) {
Ok(msgs) => msgs,
Err(e) => {
tracing::warn!(error = %e, task_id = %task.id, "Failed to load subagent messages");
Vec::new()
}
};
result.push(SubagentTaskData {
task_id: task.id,
session_id: task.session_id,
description: task.description,
subagent_type: task.subagent_type.as_str().to_string(),
state: format!("{:?}", task.state),
created_at: task.created_at,
messages,
});
}
result
}
/// Generates the full Markdown export of a session, including sub-agent tasks.
///
/// Layout: YAML frontmatter (title, IDs, channel, timestamps, message count, and
/// `subagent_count` when there are sub-agents), the system prompt section, the
/// "Subagent Tasks" section (only when `subagent_data` is non-empty), then the main message
/// history.
pub fn generate_markdown_with_subagents(
    record: &SessionRecord,
    system_prompt: &Option<SystemPrompt>,
    messages: &[crate::bus::ChatMessage],
    subagent_data: &[SubagentTaskData],
) -> String {
    let mut output = String::new();

    // YAML frontmatter. Every free-form string field goes through `escape_yaml_string`, so an
    // ID or channel name containing YAML-significant characters (`:`, ` #`, quotes, ...)
    // cannot corrupt the frontmatter.
    output.push_str("---\n");
    output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title)));
    output.push_str(&format!("session_id: {}\n", escape_yaml_string(&record.id)));
    output.push_str(&format!("channel: {}\n", escape_yaml_string(&record.channel_name)));
    output.push_str(&format!("chat_id: {}\n", escape_yaml_string(&record.chat_id)));
    output.push_str(&format!(
        "created_at: {}\n",
        format_timestamp(record.created_at)
    ));
    output.push_str(&format!(
        "updated_at: {}\n",
        format_timestamp(record.updated_at)
    ));
    output.push_str(&format!(
        "last_active_at: {}\n",
        format_timestamp(record.last_active_at)
    ));
    output.push_str(&format!("message_count: {}\n", messages.len()));
    if !subagent_data.is_empty() {
        output.push_str(&format!("subagent_count: {}\n", subagent_data.len()));
    }
    output.push_str("---\n\n");

    // System prompt
    output.push_str(&generate_system_prompt_markdown(system_prompt));

    // Sub-agent tasks, if any
    if !subagent_data.is_empty() {
        output.push_str(&generate_subagent_tasks_markdown(subagent_data));
    }

    // Main session message history
    output.push_str(&generate_messages_markdown(messages));
    output
}
/// Renders the "Subagent Tasks" section of a session export.
///
/// Each task gets a `## Task: <description> (<type>)` heading, followed by:
/// - its ID, session ID, status, creation time and message count;
/// - its messages (role, time, reasoning, content, tool calls).
///
/// Every task section ends with `---`.
pub fn generate_subagent_tasks_markdown(subagent_data: &[SubagentTaskData]) -> String {
    /// Returns a backtick fence longer than any backtick run in `text` (at least three
    /// backticks), so backticks inside `text` cannot terminate the fenced block early.
    fn code_fence_for(text: &str) -> String {
        let mut longest = 0usize;
        let mut run = 0usize;
        for c in text.chars() {
            if c == '`' {
                run += 1;
                longest = longest.max(run);
            } else {
                run = 0;
            }
        }
        "`".repeat(longest.max(2) + 1)
    }

    let mut output = String::new();
    output.push_str("# Subagent Tasks\n\n");
    for task in subagent_data {
        // A line break would end the Markdown heading early and spill the rest of the
        // description into the body, so the heading gets a single-line copy.
        let heading_description: String = task
            .description
            .chars()
            .map(|c| if c == '\n' || c == '\r' { ' ' } else { c })
            .collect();
        output.push_str(&format!("## Task: {} ({})", heading_description, task.subagent_type));
        output.push('\n');
        output.push_str(&format!("**Task ID:** `{}`\n\n", task.task_id));
        output.push_str(&format!("**Session ID:** `{}`\n\n", task.session_id));
        output.push_str(&format!("**Status:** {}\n\n", task.state));
        output.push_str(&format!("**Created:** {}\n\n", format_timestamp(task.created_at)));
        output.push_str(&format!("**Message Count:** {}\n\n", task.messages.len()));

        // Sub-agent messages
        if !task.messages.is_empty() {
            output.push_str("### Messages\n\n");
            for (idx, msg) in task.messages.iter().enumerate() {
                output.push_str(&format!("#### Message {}\n\n", idx + 1));
                output.push_str(&format!("**Role:** {}\n\n", msg.role));
                output.push_str(&format!("**Time:** {}\n\n", format_timestamp(msg.timestamp)));
                if let Some(ref reasoning) = msg.reasoning_content {
                    // Reasoning is free-form model output and may itself contain ``` fences.
                    let fence = code_fence_for(reasoning);
                    output.push_str("**Reasoning:**\n");
                    output.push_str(&fence);
                    output.push('\n');
                    output.push_str(reasoning);
                    output.push('\n');
                    output.push_str(&fence);
                    output.push_str("\n\n");
                }
                output.push_str("**Content:**\n\n");
                if msg.content.is_empty() {
                    output.push_str("*empty*\n\n");
                } else {
                    output.push_str(&format!("{}\n\n", format_message_content(&msg.content)));
                }
                // Tool calls
                if let Some(ref calls) = msg.tool_calls {
                    if !calls.is_empty() {
                        output.push_str("**Tool Calls:**\n\n");
                        for call in calls {
                            output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id));
                            output.push_str(" ```json\n");
                            let args_json = serde_json::to_string_pretty(&call.arguments)
                                .unwrap_or_else(|_| call.arguments.to_string());
                            for line in args_json.lines() {
                                output.push_str(&format!(" {}\n", line));
                            }
                            output.push_str(" ```\n");
                        }
                        output.push('\n');
                    }
                }
                output.push_str("---\n\n");
            }
        }
        output.push_str("---\n\n");
    }
    output
}
/// Builds the system prompt for `record` via `provider`.
///
/// The provider receives the session ID, the chat ID, and the number of user messages in
/// the exported history. Returns whatever the provider returns (`None` when it has no prompt).
fn build_system_prompt(
    provider: &dyn SystemPromptProvider,
    record: &SessionRecord,
    user_message_count: usize,
) -> Option<SystemPrompt> {
    provider.build(&SystemPromptContext {
        user_message_count,
        chat_id: record.chat_id.clone(),
        session_id: Some(record.id.clone()),
    })
}
/// 生成 Markdown 内容
pub fn generate_markdown(
record: &SessionRecord,
system_prompt: &Option<SystemPrompt>,
messages: &[crate::bus::ChatMessage],
) -> String {
let mut output = String::new();
// YAML frontmatter
output.push_str("---\n");
output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title)));
output.push_str(&format!("session_id: {}\n", record.id));
output.push_str(&format!("channel: {}\n", record.channel_name));
output.push_str(&format!("chat_id: {}\n", record.chat_id));
output.push_str(&format!(
"created_at: {}\n",
format_timestamp(record.created_at)
));
output.push_str(&format!(
"updated_at: {}\n",
format_timestamp(record.updated_at)
));
output.push_str(&format!(
"last_active_at: {}\n",
format_timestamp(record.last_active_at)
));
output.push_str(&format!("message_count: {}\n", messages.len()));
output.push_str("---\n\n");
// 系统提示词(复用公共函数)
output.push_str(&generate_system_prompt_markdown(system_prompt));
// 消息历史(复用公共函数)
output.push_str(&generate_messages_markdown(messages));
output
}
/// Formats message content for the Markdown export.
///
/// Content containing a table marker (`"| "`) or a line break is wrapped in a fenced code
/// block so its layout survives rendering. Anything else is returned unchanged.
///
/// The fence is one backtick longer than the longest run of consecutive backticks anywhere
/// in the content, and at least four backticks long. This guarantees no backtick run inside
/// the content can close the wrapper early.
pub fn format_message_content(content: &str) -> String {
    if !(content.contains("| ") || content.contains('\n')) {
        return content.to_string();
    }

    // `longest` is updated on every backtick, so a run at the very end of the content counts
    // as well as one followed by other characters.
    let mut longest = 0usize;
    let mut run = 0usize;
    for c in content.chars() {
        if c == '`' {
            run += 1;
            longest = longest.max(run);
        } else {
            run = 0;
        }
    }

    let fence = "`".repeat(longest.max(3) + 1);
    format!("{}\n{}\n{}", fence, content, fence)
}
/// Renders `s` as a YAML scalar for use in the export frontmatter.
///
/// Strings that YAML reads back unchanged as plain scalars are returned as-is. Anything else
/// is emitted as a double-quoted scalar with `\`, `"` and control characters escaped. That
/// covers:
/// - empty strings, and leading/trailing whitespace;
/// - a leading YAML indicator character;
/// - any `:` or `#`, or a double quote;
/// - control characters;
/// - words YAML resolves to booleans/null, and numbers.
pub fn escape_yaml_string(s: &str) -> String {
    // Characters that start YAML syntax (sequences, mappings, anchors, tags, quoting, ...)
    // when they appear first in a plain scalar.
    const LEADING_INDICATORS: &[char] = &[
        '-', '?', ':', ',', '[', ']', '{', '}', '#', '&', '*', '!', '|', '>', '\'', '"', '%',
        '@', '`',
    ];
    // Plain scalars that common YAML resolvers turn into booleans or null.
    const RESERVED: &[&str] = &[
        "true", "false", "yes", "no", "on", "off", "y", "n", "null", "~",
    ];

    let needs_quotes = s.is_empty()
        || s.starts_with(LEADING_INDICATORS)
        || s.starts_with(char::is_whitespace)
        || s.ends_with(char::is_whitespace)
        || s.contains(':')
        || s.contains('#')
        || s.contains('"')
        || s.chars().any(char::is_control)
        || RESERVED.iter().any(|word| s.eq_ignore_ascii_case(word))
        || s.parse::<f64>().is_ok();
    if !needs_quotes {
        return s.to_string();
    }

    // Inside double quotes YAML interprets backslash escapes, so backslashes must be escaped
    // too. Raw line breaks would be folded into spaces, so they are escaped as well.
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            '\\' => quoted.push_str("\\\\"),
            '"' => quoted.push_str("\\\""),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", u32::from(c))),
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}
/// Formats an epoch-milliseconds timestamp as local time (`YYYY-MM-DD HH:MM:SS`).
///
/// Falls back to the raw number when chrono does not yield exactly one local datetime for
/// the value (e.g. it is out of range).
pub fn format_timestamp(ts: i64) -> String {
    match Local.timestamp_millis_opt(ts).single() {
        Some(local) => local.format("%Y-%m-%d %H:%M:%S").to_string(),
        None => ts.to_string(),
    }
}
/// Renders a message history as Markdown.
///
/// Emits a `# Message History` heading, then one `## Message N` section (1-based) per
/// message. Each section contains:
/// - role, ID and time;
/// - optional system context, tool name and tool call ID;
/// - reasoning and content;
/// - tool calls and media references.
///
/// Each section is closed by `---`.
///
/// Generic helper shared by the Session and Topic save logic.
pub fn generate_messages_markdown(messages: &[crate::bus::ChatMessage]) -> String {
    /// Returns a backtick fence longer than any backtick run in `text` (at least three
    /// backticks), so backticks inside `text` cannot terminate the fenced block early.
    fn code_fence_for(text: &str) -> String {
        let mut longest = 0usize;
        let mut run = 0usize;
        for c in text.chars() {
            if c == '`' {
                run += 1;
                longest = longest.max(run);
            } else {
                run = 0;
            }
        }
        "`".repeat(longest.max(2) + 1)
    }

    let mut output = String::new();
    output.push_str("# Message History\n\n");
    for (idx, msg) in messages.iter().enumerate() {
        output.push_str(&format!("## Message {}\n\n", idx + 1));
        output.push_str(&format!("**Role:** {}\n\n", msg.role));
        output.push_str(&format!("**ID:** {}\n\n", msg.id));
        output.push_str(&format!(
            "**Time:** {}\n\n",
            format_timestamp(msg.timestamp)
        ));
        if let Some(ref ctx) = msg.system_context {
            output.push_str(&format!("**System Context:** `{}`\n\n", ctx));
        }
        if let Some(ref tool_name) = msg.tool_name {
            output.push_str(&format!("**Tool Name:** `{}`\n\n", tool_name));
        }
        if let Some(ref tool_call_id) = msg.tool_call_id {
            output.push_str(&format!("**Tool Call ID:** `{}`\n\n", tool_call_id));
        }
        if let Some(ref reasoning) = msg.reasoning_content {
            // Reasoning is free-form model output and may itself contain ``` fences.
            let fence = code_fence_for(reasoning);
            output.push_str("### Reasoning\n\n");
            output.push_str(&fence);
            output.push('\n');
            output.push_str(reasoning);
            output.push('\n');
            output.push_str(&fence);
            output.push_str("\n\n");
        }
        // Content
        output.push_str("### Content\n\n");
        if msg.content.is_empty() {
            output.push_str("*empty*\n\n");
        } else {
            output.push_str(&format!("{}\n\n", format_message_content(&msg.content)));
        }
        // Tool calls
        if let Some(ref calls) = msg.tool_calls {
            if !calls.is_empty() {
                output.push_str("### Tool Calls\n\n");
                for call in calls {
                    output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id));
                    output.push_str(" ```json\n");
                    let args_json = serde_json::to_string_pretty(&call.arguments)
                        .unwrap_or_else(|_| call.arguments.to_string());
                    for line in args_json.lines() {
                        output.push_str(&format!(" {}\n", line));
                    }
                    output.push_str(" ```\n");
                }
                output.push('\n');
            }
        }
        // Media refs
        if !msg.media_refs.is_empty() {
            output.push_str("### Media References\n\n");
            for media_ref in &msg.media_refs {
                output.push_str(&format!("- `{}`\n", media_ref));
            }
            output.push('\n');
        }
        output.push_str("---\n\n");
    }
    output
}
/// Renders the `# System Prompt` section of an export.
///
/// The prompt text is placed in a fenced code block. The fence is longer than any backtick
/// run in the prompt, because prompts commonly contain their own ``` code blocks, which
/// would otherwise close the fence early. When there is no prompt, a placeholder line is
/// emitted instead.
pub fn generate_system_prompt_markdown(system_prompt: &Option<SystemPrompt>) -> String {
    /// Returns a backtick fence longer than any backtick run in `text` (at least three
    /// backticks), so backticks inside `text` cannot terminate the fenced block early.
    fn code_fence_for(text: &str) -> String {
        let mut longest = 0usize;
        let mut run = 0usize;
        for c in text.chars() {
            if c == '`' {
                run += 1;
                longest = longest.max(run);
            } else {
                run = 0;
            }
        }
        "`".repeat(longest.max(2) + 1)
    }

    let mut output = String::from("# System Prompt\n\n");
    match system_prompt {
        Some(prompt) => {
            let fence = code_fence_for(&prompt.content);
            output.push_str(&fence);
            output.push('\n');
            output.push_str(&prompt.content);
            output.push('\n');
            output.push_str(&fence);
            output.push_str("\n\n");
        }
        None => output.push_str("*No system prompt available*\n\n"),
    }
    output
}
/// Resolves the output path for a session export.
///
/// An explicit `filepath` is used verbatim. Otherwise a file name is generated and placed
/// under `~/.picobot/sessions/` (or `./.picobot/sessions/` when no home directory is known).
/// The generated name has the form `<base>_<YYYYmmdd_HHMMSS>.md`, where `<base>` is chosen
/// as follows:
/// - normally it is the session title, truncated to 64 characters, with path-hostile
///   characters and control characters replaced by `_`;
/// - when the title is empty it is `session_<first 8 chars of the session ID>`.
pub fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf {
    if let Some(path) = filepath {
        return PathBuf::from(path);
    }

    // Keeps generated names well under the common 255-byte file-name limit, even for titles
    // made of multi-byte (e.g. CJK) characters plus the timestamp suffix.
    const MAX_TITLE_CHARS: usize = 64;

    let safe_title: String = record
        .title
        .chars()
        .take(MAX_TITLE_CHARS)
        .map(|c| match c {
            ' ' | '/' | '\\' | ':' | '<' | '>' | '|' | '?' | '*' | '"' => '_',
            // Newlines, tabs and other control characters are not valid in portable names.
            c if c.is_control() => '_',
            c => c,
        })
        .collect();

    let base_name = if safe_title.is_empty() {
        // Take whole characters: byte-slicing a non-ASCII ID could panic mid-character.
        let id_prefix: String = record.id.chars().take(8).collect();
        format!("session_{}", id_prefix)
    } else {
        safe_title
    };

    let timestamp = Local::now().format("%Y%m%d_%H%M%S");
    let filename = format!("{}_{}.md", base_name, timestamp);

    dirs::home_dir()
        .unwrap_or_else(|| PathBuf::from("."))
        .join(".picobot")
        .join("sessions")
        .join(filename)
}
/// In-chat save-session handler.
///
/// Handles a `/save` command typed directly in channels such as Feishu or WeChat.
pub struct SaveSessionInChatHandler {
    /// Store the session record and its messages are read from.
    store: Arc<SessionStore>,
    /// Task repository used to look up sub-agent tasks.
    task_repository: Arc<dyn TaskRepository>,
    /// Provider used to rebuild the system prompt for the export.
    system_prompt_provider: Arc<dyn SystemPromptProvider>,
}
impl SaveSessionInChatHandler {
    /// Creates an in-chat save-session handler from its store, task repository and
    /// system-prompt provider.
    pub fn new(
        store: Arc<SessionStore>,
        task_repository: Arc<dyn TaskRepository>,
        system_prompt_provider: Arc<dyn SystemPromptProvider>,
    ) -> Self {
        Self { system_prompt_provider, task_repository, store }
    }
}
#[async_trait]
impl InChatCommandHandler for SaveSessionInChatHandler {
    /// Accepts every `Command::SaveSession` variant.
    fn can_handle(&self, cmd: &Command) -> bool {
        matches!(cmd, Command::SaveSession { .. })
    }

    /// Exports the chat's persistent session to Markdown and returns a user-facing reply.
    ///
    /// Returns `Ok(None)` for commands other than `SaveSession`. Otherwise returns
    /// `Ok(Some(message))` describing either the saved path or the failure. No error is
    /// propagated through `Err`.
    async fn handle(
        &self,
        cmd: Command,
        inbound: &InboundMessage,
        session_manager: &crate::gateway::session::SessionManager,
    ) -> Result<Option<String>, AgentError> {
        let Command::SaveSession { filepath, include_all, include_subagents } = cmd else {
            return Ok(None);
        };

        // Look up the in-memory session for this channel via the session manager.
        let Some(session) = session_manager.get(&inbound.channel).await else {
            tracing::error!("Session not found for channel: {}", inbound.channel);
            return Ok(Some("Session not found".to_string()));
        };

        // Only the persistent session ID is needed from the in-memory session. Take it and
        // release the lock right away, so the export (store reads, task lookups, file write)
        // does not block other work on this channel's session.
        let session_id = {
            let session_guard = session.lock().await;
            session_guard.persistent_session_id(&inbound.chat_id).to_string()
        };

        let result = save_session_to_file(
            &session_id,
            filepath,
            include_all,
            include_subagents,
            &*self.store,
            Some(self.task_repository.as_ref()),
            &*self.system_prompt_provider,
        )
        .await;

        // Report success or failure back to the chat.
        match result {
            Ok(output_path) => {
                let msg = format!("Session saved to: {}", output_path.display());
                tracing::info!("{}", msg);
                Ok(Some(msg))
            }
            Err(error) => {
                let msg = format!("Failed to save session: {}", error);
                tracing::error!("{}", msg);
                Ok(Some(msg))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::{SessionRecord, SessionStore};

    /// Builds a minimal session record; `chat_id` mirrors `id`.
    fn create_test_record(id: &str, title: &str) -> SessionRecord {
        SessionRecord {
            id: id.to_string(),
            title: title.to_string(),
            channel_name: "cli".to_string(),
            chat_id: id.to_string(),
            summary: None,
            created_at: 1705312800000,     // 2024-01-15 10:00:00 UTC
            updated_at: 1705316400000,     // 2024-01-15 11:00:00 UTC
            last_active_at: 1705316400000, // 2024-01-15 11:00:00 UTC
            archived_at: None,
            deleted_at: None,
            message_count: 0,
            user_turn_count: 0,
            agent_prompt_reinjection_count: 0,
        }
    }

    /// Returns the UTF-8 file-name component of a resolved output path.
    fn file_name_of(path: &std::path::Path) -> &str {
        path.file_name()
            .and_then(|name| name.to_str())
            .expect("resolved path should end in a UTF-8 file name")
    }

    #[test]
    fn test_resolve_filepath_with_custom_path() {
        let record = create_test_record("test-123", "My Session");
        let resolved = resolve_filepath(Some("/custom/path/file.md".to_string()), &record);
        assert_eq!(resolved, PathBuf::from("/custom/path/file.md"));
    }

    #[test]
    fn test_resolve_filepath_generates_filename_with_title() {
        let resolved = resolve_filepath(None, &create_test_record("test-123", "My Session"));
        let name = file_name_of(&resolved);
        assert!(name.starts_with("My_Session_"));
        assert!(name.ends_with(".md"));
    }

    #[test]
    fn test_resolve_filepath_generates_filename_with_id_when_title_empty() {
        let resolved = resolve_filepath(None, &create_test_record("abc12345-xyz", ""));
        let name = file_name_of(&resolved);
        assert!(name.starts_with("session_abc123"));
        assert!(name.ends_with(".md"));
    }

    #[test]
    fn test_escape_yaml_string() {
        assert_eq!(escape_yaml_string("simple"), "simple");
        assert_eq!(escape_yaml_string("with: colon"), "\"with: colon\"");
        assert_eq!(escape_yaml_string("with \"quote\""), "\"with \\\"quote\\\"\"");
    }

    #[test]
    fn test_format_message_content() {
        let cases: &[(&str, &str)] = &[
            // Plain single-line text is returned unchanged.
            ("hello", "hello"),
            // Single-line backticks are not wrapped either.
            ("`code`", "`code`"),
            // Line breaks trigger wrapping with the minimum four-backtick fence.
            ("line1\nline2\nline3", "````\nline1\nline2\nline3\n````"),
            // A table marker triggers wrapping too.
            ("| col1 | col2 |", "````\n| col1 | col2 |\n````"),
            // Three-backtick runs inside multi-line content still get a four-backtick fence.
            ("```code```\nmore", "````\n```code```\nmore\n````"),
            // Embedded multi-line code blocks.
            ("```\ncode\n```\nmore", "````\n```\ncode\n```\nmore\n````"),
            // Four-backtick runs force a five-backtick fence.
            ("````code````\nmore", "`````\n````code````\nmore\n`````"),
        ];
        for (input, expected) in cases {
            assert_eq!(format_message_content(input), *expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_generate_markdown_structure() {
        let record = create_test_record("test-123", "Test Session");
        let history = vec![crate::bus::ChatMessage::system("System prompt here")];
        let markdown = generate_markdown(&record, &None, &history);

        let expected_fragments = [
            "---",
            "title:",
            "session_id: test-123",
            "# System Prompt",
            "# Message History",
            "## Message 1",
            "**Role:** system",
        ];
        for fragment in expected_fragments {
            assert!(markdown.contains(fragment), "missing fragment: {}", fragment);
        }
    }

    #[test]
    fn test_can_handle() {
        let store = Arc::new(SessionStore::in_memory().unwrap());
        let repository = Arc::new(crate::tools::task::repository::InMemoryTaskRepository::new());
        let prompt_provider = Arc::new(TestSystemPromptProvider);
        let handler = SaveSessionCommandHandler::new(store, repository, prompt_provider);

        let accepted = [
            Command::SaveSession { filepath: None, include_all: false, include_subagents: false },
            Command::SaveSession { filepath: None, include_all: true, include_subagents: false },
        ];
        for cmd in &accepted {
            assert!(handler.can_handle(cmd));
        }

        let rejected = [
            Command::CreateSession { title: None },
            Command::SaveTopic { filepath: None, include_subagents: false },
        ];
        for cmd in &rejected {
            assert!(!handler.can_handle(cmd));
        }
    }

    /// System prompt provider that always returns a fixed prompt, for tests.
    struct TestSystemPromptProvider;
    impl SystemPromptProvider for TestSystemPromptProvider {
        fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
            Some(SystemPrompt {
                content: "Test system prompt".to_string(),
                context: Some("test".to_string()),
            })
        }
    }
}