Compare commits

...

3 Commits

14 changed files with 438 additions and 65 deletions

View File

@ -44,13 +44,30 @@ impl InputAdapter for ChannelInputAdapter {
if trimmed == "/save" { if trimmed == "/save" {
return Ok(Some(Command::SaveTopic { return Ok(Some(Command::SaveTopic {
filepath: None, filepath: None,
include_subagents: false,
})); }));
} }
if let Some(filepath) = trimmed.strip_prefix("/save ") { if let Some(args) = trimmed.strip_prefix("/save ") {
let filepath = filepath.trim(); let args = args.trim();
let parts: Vec<&str> = args.split_whitespace().collect();
// 解析参数
let mut include_subagents = false;
let mut filepath = None;
for part in parts {
if part == "+sub" {
include_subagents = true;
} else if !part.is_empty() {
// 非特殊参数视为文件路径
filepath = Some(part.to_string());
}
}
return Ok(Some(Command::SaveTopic { return Ok(Some(Command::SaveTopic {
filepath: Some(filepath.to_string()), filepath,
include_subagents,
})); }));
} }
@ -59,24 +76,35 @@ impl InputAdapter for ChannelInputAdapter {
return Ok(Some(Command::SaveSession { return Ok(Some(Command::SaveSession {
filepath: None, filepath: None,
include_all: false, include_all: false,
include_subagents: false,
})); }));
} }
if let Some(args) = trimmed.strip_prefix("/save-session ") { if let Some(args) = trimmed.strip_prefix("/save-session ") {
let args = args.trim(); let args = args.trim();
// 解析参数:可能是 "all"、路径、或 "all 路径" let parts: Vec<&str> = args.split_whitespace().collect();
let (include_all, filepath) = if args == "all" {
// /save-session all - 保存全部消息 // 解析参数
(true, None) let mut include_all = false;
} else if args.starts_with("all ") { let mut include_subagents = false;
// /save-session all <filepath> - 保存全部消息到指定路径 let mut filepath = None;
let path = args[4..].trim();
(true, Some(path.to_string())) for part in parts {
} else { if part == "all" {
// /save-session <filepath> - 保存活跃消息到指定路径 include_all = true;
(false, Some(args.to_string())) } else if part == "+sub" {
}; include_subagents = true;
return Ok(Some(Command::SaveSession { filepath, include_all })); } else if !part.is_empty() {
// 非特殊参数视为文件路径
filepath = Some(part.to_string());
}
}
return Ok(Some(Command::SaveSession {
filepath,
include_all,
include_subagents,
}));
} }
// 解析 /list 命令 // 解析 /list 命令

View File

@ -45,13 +45,30 @@ impl InputAdapter for CliInputAdapter {
if trimmed == "/save" { if trimmed == "/save" {
return Ok(Some(Command::SaveTopic { return Ok(Some(Command::SaveTopic {
filepath: None, filepath: None,
include_subagents: false,
})); }));
} }
if let Some(filepath) = trimmed.strip_prefix("/save ") { if let Some(args) = trimmed.strip_prefix("/save ") {
let filepath = filepath.trim(); let args = args.trim();
let parts: Vec<&str> = args.split_whitespace().collect();
// 解析参数
let mut include_subagents = false;
let mut filepath = None;
for part in parts {
if part == "+sub" {
include_subagents = true;
} else if !part.is_empty() {
// 非特殊参数视为文件路径
filepath = Some(part.to_string());
}
}
return Ok(Some(Command::SaveTopic { return Ok(Some(Command::SaveTopic {
filepath: Some(filepath.to_string()), filepath,
include_subagents,
})); }));
} }
@ -60,24 +77,35 @@ impl InputAdapter for CliInputAdapter {
return Ok(Some(Command::SaveSession { return Ok(Some(Command::SaveSession {
filepath: None, filepath: None,
include_all: false, include_all: false,
include_subagents: false,
})); }));
} }
if let Some(args) = trimmed.strip_prefix("/save-session ") { if let Some(args) = trimmed.strip_prefix("/save-session ") {
let args = args.trim(); let args = args.trim();
// 解析参数:可能是 "all"、路径、或 "all 路径" let parts: Vec<&str> = args.split_whitespace().collect();
let (include_all, filepath) = if args == "all" {
// /save-session all - 保存全部消息 // 解析参数
(true, None) let mut include_all = false;
} else if args.starts_with("all ") { let mut include_subagents = false;
// /save-session all <filepath> - 保存全部消息到指定路径 let mut filepath = None;
let path = args[4..].trim();
(true, Some(path.to_string())) for part in parts {
} else { if part == "all" {
// /save-session <filepath> - 保存活跃消息到指定路径 include_all = true;
(false, Some(args.to_string())) } else if part == "+sub" {
}; include_subagents = true;
return Ok(Some(Command::SaveSession { filepath, include_all })); } else if !part.is_empty() {
// 非特殊参数视为文件路径
filepath = Some(part.to_string());
}
}
return Ok(Some(Command::SaveSession {
filepath,
include_all,
include_subagents,
}));
} }
// 解析 /list 命令 // 解析 /list 命令

View File

@ -11,6 +11,7 @@ pub mod switch_session;
pub use save_session::{ pub use save_session::{
escape_yaml_string, format_message_content, format_timestamp, escape_yaml_string, format_message_content, format_timestamp,
generate_messages_markdown, generate_system_prompt_markdown, generate_messages_markdown, generate_system_prompt_markdown,
generate_subagent_tasks_markdown, load_subagent_data, SubagentTaskData,
}; };
use crate::bus::ChatMessage; use crate::bus::ChatMessage;

View File

@ -5,6 +5,7 @@ use crate::command::handler::{CommandHandler, CommandMetadata, InChatCommandHand
use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command; use crate::command::Command;
use crate::storage::{SessionRecord, SessionStore}; use crate::storage::{SessionRecord, SessionStore};
use crate::tools::task::repository::TaskRepository;
use crate::agent::AgentError; use crate::agent::AgentError;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{Local, TimeZone}; use chrono::{Local, TimeZone};
@ -17,8 +18,10 @@ use std::sync::Arc;
/// * `session_id` - 会话ID /// * `session_id` - 会话ID
/// * `filepath` - 可选的文件路径 /// * `filepath` - 可选的文件路径
/// * `include_all` - 是否包含 cutoff 之前的所有消息 /// * `include_all` - 是否包含 cutoff 之前的所有消息
/// * `include_subagents` - 是否包含子智能体消息
/// * `store` - 会话存储 /// * `store` - 会话存储
/// * `provider_config` - LLM提供者配置 /// * `task_repository` - 任务存储(可选,用于查询子智能体)
/// * `system_prompt_provider` - 系统提示词提供者
/// ///
/// # Returns /// # Returns
/// 返回保存的文件路径 /// 返回保存的文件路径
@ -26,7 +29,9 @@ pub async fn save_session_to_file(
session_id: &str, session_id: &str,
filepath: Option<String>, filepath: Option<String>,
include_all: bool, include_all: bool,
include_subagents: bool,
store: &SessionStore, store: &SessionStore,
task_repository: Option<&dyn TaskRepository>,
system_prompt_provider: &dyn SystemPromptProvider, system_prompt_provider: &dyn SystemPromptProvider,
) -> Result<PathBuf, String> { ) -> Result<PathBuf, String> {
// 获取会话记录 // 获取会话记录
@ -46,6 +51,13 @@ pub async fn save_session_to_file(
.map_err(|e| format!("Failed to load messages: {}", e))? .map_err(|e| format!("Failed to load messages: {}", e))?
}; };
// 加载子智能体消息(如果启用)
let subagent_data = if include_subagents {
load_subagent_data(session_id, store, task_repository).await
} else {
Vec::new()
};
// 计算用户消息数(用于系统提示词构建) // 计算用户消息数(用于系统提示词构建)
let user_message_count = messages.iter().filter(|m| m.role == "user").count(); let user_message_count = messages.iter().filter(|m| m.role == "user").count();
@ -53,7 +65,7 @@ pub async fn save_session_to_file(
let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count); let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count);
// 生成 Markdown 内容 // 生成 Markdown 内容
let markdown = generate_markdown(&record, &system_prompt, &messages); let markdown = generate_markdown_with_subagents(&record, &system_prompt, &messages, &subagent_data);
// 确定输出路径 // 确定输出路径
let output_path = resolve_filepath(filepath, &record); let output_path = resolve_filepath(filepath, &record);
@ -78,6 +90,7 @@ pub async fn save_session_to_file(
/// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件 /// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件
pub struct SaveSessionCommandHandler { pub struct SaveSessionCommandHandler {
store: Arc<SessionStore>, store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>, system_prompt_provider: Arc<dyn SystemPromptProvider>,
} }
@ -86,10 +99,16 @@ impl SaveSessionCommandHandler {
/// ///
/// # Arguments /// # Arguments
/// * `store` - 会话存储 /// * `store` - 会话存储
/// * `system_prompt_provider` - 系统提示词提供者(负责构建完整的系统提示词) /// * `task_repository` - 任务存储(用于查询子智能体)
pub fn new(store: Arc<SessionStore>, system_prompt_provider: Arc<dyn SystemPromptProvider>) -> Self { /// * `system_prompt_provider` - 系统提示词提供者
pub fn new(
store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
) -> Self {
Self { Self {
store, store,
task_repository,
system_prompt_provider, system_prompt_provider,
} }
} }
@ -121,8 +140,8 @@ impl CommandHandler for SaveSessionCommandHandler {
ctx: CommandContext, ctx: CommandContext,
) -> Result<CommandResponse, CommandError> { ) -> Result<CommandResponse, CommandError> {
match cmd { match cmd {
Command::SaveSession { filepath, include_all } => { Command::SaveSession { filepath, include_all, include_subagents } => {
handle_save_session(self, filepath, include_all, ctx).await handle_save_session(self, filepath, include_all, include_subagents, ctx).await
} }
_ => unreachable!(), _ => unreachable!(),
} }
@ -134,12 +153,14 @@ async fn handle_save_session(
handler: &SaveSessionCommandHandler, handler: &SaveSessionCommandHandler,
filepath: Option<String>, filepath: Option<String>,
include_all: bool, include_all: bool,
include_subagents: bool,
ctx: CommandContext, ctx: CommandContext,
) -> Result<CommandResponse, CommandError> { ) -> Result<CommandResponse, CommandError> {
tracing::debug!( tracing::debug!(
ctx_session_id = ?ctx.session_id, ctx_session_id = ?ctx.session_id,
ctx_chat_id = ?ctx.chat_id, ctx_chat_id = ?ctx.chat_id,
channel = %ctx.channel_name, channel = %ctx.channel_name,
include_subagents = include_subagents,
"SaveSession command received" "SaveSession command received"
); );
@ -174,7 +195,9 @@ async fn handle_save_session(
session_id, session_id,
filepath, filepath,
include_all, include_all,
include_subagents,
&*handler.store, &*handler.store,
Some(handler.task_repository.as_ref()),
&*handler.system_prompt_provider, &*handler.system_prompt_provider,
) )
.await .await
@ -202,6 +225,174 @@ async fn handle_save_session(
.with_metadata("message_count", &message_count.to_string())) .with_metadata("message_count", &message_count.to_string()))
} }
/// 子智能体任务数据
#[derive(Debug)]
pub struct SubagentTaskData {
pub task_id: String,
pub session_id: String,
pub description: String,
pub subagent_type: String,
pub state: String,
pub created_at: i64,
pub messages: Vec<crate::bus::ChatMessage>,
}
/// 加载子智能体数据
pub async fn load_subagent_data(
parent_session_id: &str,
store: &SessionStore,
task_repository: Option<&dyn TaskRepository>,
) -> Vec<SubagentTaskData> {
let Some(repo) = task_repository else {
return Vec::new();
};
// 获取所有子任务
let tasks = match repo.list_tasks_for_session(parent_session_id).await {
Ok(tasks) => tasks,
Err(e) => {
tracing::warn!(error = %e, "Failed to list tasks for session");
return Vec::new();
}
};
let mut result = Vec::new();
for task in tasks {
// 加载子智能体的消息
let messages = match store.load_all_messages(&task.session_id) {
Ok(msgs) => msgs,
Err(e) => {
tracing::warn!(error = %e, task_id = %task.id, "Failed to load subagent messages");
Vec::new()
}
};
result.push(SubagentTaskData {
task_id: task.id,
session_id: task.session_id,
description: task.description,
subagent_type: task.subagent_type.as_str().to_string(),
state: format!("{:?}", task.state),
created_at: task.created_at,
messages,
});
}
result
}
/// 生成 Markdown 内容(包含子智能体)
pub fn generate_markdown_with_subagents(
record: &SessionRecord,
system_prompt: &Option<SystemPrompt>,
messages: &[crate::bus::ChatMessage],
subagent_data: &[SubagentTaskData],
) -> String {
let mut output = String::new();
// YAML frontmatter
output.push_str("---\n");
output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title)));
output.push_str(&format!("session_id: {}\n", record.id));
output.push_str(&format!("channel: {}\n", record.channel_name));
output.push_str(&format!("chat_id: {}\n", record.chat_id));
output.push_str(&format!(
"created_at: {}\n",
format_timestamp(record.created_at)
));
output.push_str(&format!(
"updated_at: {}\n",
format_timestamp(record.updated_at)
));
output.push_str(&format!(
"last_active_at: {}\n",
format_timestamp(record.last_active_at)
));
output.push_str(&format!("message_count: {}\n", messages.len()));
if !subagent_data.is_empty() {
output.push_str(&format!("subagent_count: {}\n", subagent_data.len()));
}
output.push_str("---\n\n");
// 系统提示词
output.push_str(&generate_system_prompt_markdown(system_prompt));
// 子智能体任务(如果有)
if !subagent_data.is_empty() {
output.push_str(&generate_subagent_tasks_markdown(subagent_data));
}
// 主会话消息历史
output.push_str(&generate_messages_markdown(messages));
output
}
/// 生成子智能体任务 Markdown
pub fn generate_subagent_tasks_markdown(subagent_data: &[SubagentTaskData]) -> String {
let mut output = String::new();
output.push_str("# Subagent Tasks\n\n");
for task in subagent_data {
output.push_str(&format!("## Task: {} ({})", task.description, task.subagent_type));
output.push('\n');
output.push_str(&format!("**Task ID:** `{}`\n\n", task.task_id));
output.push_str(&format!("**Session ID:** `{}`\n\n", task.session_id));
output.push_str(&format!("**Status:** {}\n\n", task.state));
output.push_str(&format!("**Created:** {}\n\n", format_timestamp(task.created_at)));
output.push_str(&format!("**Message Count:** {}\n\n", task.messages.len()));
// 子智能体消息
if !task.messages.is_empty() {
output.push_str("### Messages\n\n");
for (idx, msg) in task.messages.iter().enumerate() {
output.push_str(&format!("#### Message {}\n\n", idx + 1));
output.push_str(&format!("**Role:** {}\n\n", msg.role));
output.push_str(&format!("**Time:** {}\n\n", format_timestamp(msg.timestamp)));
if let Some(ref reasoning) = msg.reasoning_content {
output.push_str("**Reasoning:**\n");
output.push_str("```\n");
output.push_str(reasoning);
output.push_str("\n```\n\n");
}
output.push_str("**Content:**\n\n");
if msg.content.is_empty() {
output.push_str("*empty*\n\n");
} else {
output.push_str(&format!("{}\n\n", format_message_content(&msg.content)));
}
// 工具调用
if let Some(ref calls) = msg.tool_calls {
if !calls.is_empty() {
output.push_str("**Tool Calls:**\n\n");
for call in calls {
output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id));
output.push_str(" ```json\n");
let args_json = serde_json::to_string_pretty(&call.arguments)
.unwrap_or_else(|_| call.arguments.to_string());
for line in args_json.lines() {
output.push_str(&format!(" {}\n", line));
}
output.push_str(" ```\n");
}
output.push('\n');
}
}
output.push_str("---\n\n");
}
}
output.push_str("---\n\n");
}
output
}
/// 构建系统提示词 /// 构建系统提示词
fn build_system_prompt( fn build_system_prompt(
provider: &dyn SystemPromptProvider, provider: &dyn SystemPromptProvider,
@ -440,14 +631,20 @@ pub fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> Pat
/// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令 /// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令
pub struct SaveSessionInChatHandler { pub struct SaveSessionInChatHandler {
store: Arc<SessionStore>, store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>, system_prompt_provider: Arc<dyn SystemPromptProvider>,
} }
impl SaveSessionInChatHandler { impl SaveSessionInChatHandler {
/// 创建新的 InChat 保存会话命令处理器 /// 创建新的 InChat 保存会话命令处理器
pub fn new(store: Arc<SessionStore>, system_prompt_provider: Arc<dyn SystemPromptProvider>) -> Self { pub fn new(
store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
) -> Self {
Self { Self {
store, store,
task_repository,
system_prompt_provider, system_prompt_provider,
} }
} }
@ -465,7 +662,7 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
inbound: &InboundMessage, inbound: &InboundMessage,
session_manager: &crate::gateway::session::SessionManager, session_manager: &crate::gateway::session::SessionManager,
) -> Result<Option<String>, AgentError> { ) -> Result<Option<String>, AgentError> {
let Command::SaveSession { filepath, include_all } = cmd else { let Command::SaveSession { filepath, include_all, include_subagents } = cmd else {
return Ok(None); return Ok(None);
}; };
@ -486,7 +683,9 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
&session_id, &session_id,
filepath, filepath,
include_all, include_all,
include_subagents,
&*self.store, &*self.store,
Some(self.task_repository.as_ref()),
&*self.system_prompt_provider, &*self.system_prompt_provider,
) )
.await; .await;
@ -623,11 +822,12 @@ mod tests {
#[test] #[test]
fn test_can_handle() { fn test_can_handle() {
let store = Arc::new(SessionStore::in_memory().unwrap()); let store = Arc::new(SessionStore::in_memory().unwrap());
let task_repository = Arc::new(crate::tools::task::repository::InMemoryTaskRepository::new());
let provider = Arc::new(TestSystemPromptProvider); let provider = Arc::new(TestSystemPromptProvider);
let handler = SaveSessionCommandHandler::new(store, provider); let handler = SaveSessionCommandHandler::new(store, task_repository, provider);
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false })); assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false, include_subagents: false }));
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true })); assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true, include_subagents: false }));
assert!(!handler.can_handle(&Command::CreateSession { title: None })); assert!(!handler.can_handle(&Command::CreateSession { title: None }));
assert!(!handler.can_handle(&Command::SaveTopic { filepath: None })); assert!(!handler.can_handle(&Command::SaveTopic { filepath: None }));
} }

View File

@ -4,12 +4,14 @@ use crate::command::context::CommandContext;
use crate::command::handler::{CommandHandler, CommandMetadata}; use crate::command::handler::{CommandHandler, CommandMetadata};
use crate::command::handlers::{ use crate::command::handlers::{
escape_yaml_string, format_timestamp, generate_messages_markdown, escape_yaml_string, format_timestamp, generate_messages_markdown,
generate_system_prompt_markdown, get_messages_from_session, generate_subagent_tasks_markdown, generate_system_prompt_markdown,
get_messages_from_session, load_subagent_data, SubagentTaskData,
}; };
use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command; use crate::command::Command;
use crate::gateway::session::SessionManager; use crate::gateway::session::SessionManager;
use crate::storage::{SessionStore, TopicRecord}; use crate::storage::{SessionStore, TopicRecord};
use crate::tools::task::repository::TaskRepository;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::Local; use chrono::Local;
use std::path::PathBuf; use std::path::PathBuf;
@ -19,9 +21,11 @@ use std::sync::Arc;
pub async fn save_topic_to_file( pub async fn save_topic_to_file(
topic_id: &str, topic_id: &str,
filepath: Option<String>, filepath: Option<String>,
include_subagents: bool,
store: &SessionStore, store: &SessionStore,
task_repository: Option<&dyn TaskRepository>,
system_prompt_provider: &dyn SystemPromptProvider, system_prompt_provider: &dyn SystemPromptProvider,
messages: &[ChatMessage], // ← 从外部传入的消息(已压缩的 active history messages: &[ChatMessage],
) -> Result<PathBuf, String> { ) -> Result<PathBuf, String> {
// 获取话题记录 // 获取话题记录
let topic = store let topic = store
@ -38,8 +42,15 @@ pub async fn save_topic_to_file(
let user_message_count = messages.iter().filter(|m| m.role == "user").count(); let user_message_count = messages.iter().filter(|m| m.role == "user").count();
let system_prompt = build_system_prompt(system_prompt_provider, &session, user_message_count); let system_prompt = build_system_prompt(system_prompt_provider, &session, user_message_count);
// 加载子智能体消息(如果启用)
let subagent_data = if include_subagents {
load_subagent_data(&topic.session_id, store, task_repository).await
} else {
Vec::new()
};
// 生成 Markdown 内容 // 生成 Markdown 内容
let markdown = generate_topic_markdown(&topic, &system_prompt, messages); let markdown = generate_topic_markdown(&topic, &system_prompt, messages, &subagent_data);
// 确定输出路径 // 确定输出路径
let output_path = resolve_topic_filepath(filepath, &topic); let output_path = resolve_topic_filepath(filepath, &topic);
@ -79,6 +90,7 @@ fn generate_topic_markdown(
topic: &TopicRecord, topic: &TopicRecord,
system_prompt: &Option<SystemPrompt>, system_prompt: &Option<SystemPrompt>,
messages: &[crate::bus::ChatMessage], messages: &[crate::bus::ChatMessage],
subagent_data: &[SubagentTaskData],
) -> String { ) -> String {
let mut output = String::new(); let mut output = String::new();
@ -103,11 +115,19 @@ fn generate_topic_markdown(
format_timestamp(topic.last_active_at) format_timestamp(topic.last_active_at)
)); ));
output.push_str(&format!("message_count: {}\n", messages.len())); output.push_str(&format!("message_count: {}\n", messages.len()));
if !subagent_data.is_empty() {
output.push_str(&format!("subagent_count: {}\n", subagent_data.len()));
}
output.push_str("---\n\n"); output.push_str("---\n\n");
// 系统提示词(复用公共函数) // 系统提示词(复用公共函数)
output.push_str(&generate_system_prompt_markdown(system_prompt)); output.push_str(&generate_system_prompt_markdown(system_prompt));
// 子智能体任务(如果有)
if !subagent_data.is_empty() {
output.push_str(&generate_subagent_tasks_markdown(subagent_data));
}
// 消息历史(复用公共函数) // 消息历史(复用公共函数)
output.push_str(&generate_messages_markdown(messages)); output.push_str(&generate_messages_markdown(messages));
@ -153,6 +173,7 @@ fn resolve_topic_filepath(filepath: Option<String>, topic: &TopicRecord) -> Path
/// 保存话题命令处理器 /// 保存话题命令处理器
pub struct SaveTopicCommandHandler { pub struct SaveTopicCommandHandler {
store: Arc<SessionStore>, store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>, system_prompt_provider: Arc<dyn SystemPromptProvider>,
session_manager: Option<SessionManager>, session_manager: Option<SessionManager>,
} }
@ -160,10 +181,12 @@ pub struct SaveTopicCommandHandler {
impl SaveTopicCommandHandler { impl SaveTopicCommandHandler {
pub fn new( pub fn new(
store: Arc<SessionStore>, store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>, system_prompt_provider: Arc<dyn SystemPromptProvider>,
) -> Self { ) -> Self {
Self { Self {
store, store,
task_repository,
system_prompt_provider, system_prompt_provider,
session_manager: None, session_manager: None,
} }
@ -195,7 +218,9 @@ impl CommandHandler for SaveTopicCommandHandler {
ctx: CommandContext, ctx: CommandContext,
) -> Result<CommandResponse, CommandError> { ) -> Result<CommandResponse, CommandError> {
match cmd { match cmd {
Command::SaveTopic { filepath } => handle_save_topic(self, filepath, ctx).await, Command::SaveTopic { filepath, include_subagents } => {
handle_save_topic(self, filepath, include_subagents, ctx).await
}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -204,12 +229,14 @@ impl CommandHandler for SaveTopicCommandHandler {
async fn handle_save_topic( async fn handle_save_topic(
handler: &SaveTopicCommandHandler, handler: &SaveTopicCommandHandler,
filepath: Option<String>, filepath: Option<String>,
include_subagents: bool,
ctx: CommandContext, ctx: CommandContext,
) -> Result<CommandResponse, CommandError> { ) -> Result<CommandResponse, CommandError> {
tracing::debug!( tracing::debug!(
ctx_topic_id = ?ctx.topic_id, ctx_topic_id = ?ctx.topic_id,
ctx_session_id = ?ctx.session_id, ctx_session_id = ?ctx.session_id,
channel = %ctx.channel_name, channel = %ctx.channel_name,
include_subagents = include_subagents,
"SaveTopic command received" "SaveTopic command received"
); );
@ -238,7 +265,9 @@ async fn handle_save_topic(
let output_path = save_topic_to_file( let output_path = save_topic_to_file(
topic_id, topic_id,
filepath, filepath,
include_subagents,
&*handler.store, &*handler.store,
Some(handler.task_repository.as_ref()),
&*handler.system_prompt_provider, &*handler.system_prompt_provider,
&messages, &messages,
) )

View File

@ -14,11 +14,15 @@ pub enum Command {
/// 创建新话题(在同一个 Session 内) /// 创建新话题(在同一个 Session 内)
CreateSession { title: Option<String> }, CreateSession { title: Option<String> },
/// 保存当前话题内容到 Markdown 文件 /// 保存当前话题内容到 Markdown 文件
SaveTopic { filepath: Option<String> }, SaveTopic {
filepath: Option<String>,
include_subagents: bool,
},
/// 保存会话内容到 Markdown 文件 /// 保存会话内容到 Markdown 文件
SaveSession { SaveSession {
filepath: Option<String>, filepath: Option<String>,
include_all: bool, include_all: bool,
include_subagents: bool,
}, },
/// 列出当前 Session 的所有话题 /// 列出当前 Session 的所有话题
ListSessions { include_archived: bool }, ListSessions { include_archived: bool },

View File

@ -38,6 +38,7 @@ use crate::config::LLMProviderConfig;
use crate::logging; use crate::logging;
use crate::scheduler::Scheduler; use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::tools::task::repository::TaskRepository;
use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService}; use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
use outbound_dispatcher::OutboundDispatcher; use outbound_dispatcher::OutboundDispatcher;
use processor::InboundProcessor; use processor::InboundProcessor;
@ -50,6 +51,7 @@ pub struct GatewayState {
pub session_manager: SessionManager, pub session_manager: SessionManager,
pub channel_manager: ChannelManager, pub channel_manager: ChannelManager,
pub bus: Arc<MessageBus>, pub bus: Arc<MessageBus>,
pub task_repository: Arc<dyn TaskRepository>,
} }
impl GatewayState { impl GatewayState {
@ -72,7 +74,7 @@ impl GatewayState {
let channel_manager = ChannelManager::new(); let channel_manager = ChannelManager::new();
let bus = channel_manager.bus(); let bus = channel_manager.bus();
let session_manager = build_session_manager_with_sender( let (session_manager, task_repository) = build_session_manager_with_sender(
agent_prompt_reinject_every, agent_prompt_reinject_every,
show_tool_results, show_tool_results,
config.time.timezone.clone(), config.time.timezone.clone(),
@ -91,6 +93,7 @@ impl GatewayState {
session_manager, session_manager,
channel_manager, channel_manager,
bus, bus,
task_repository,
}) })
} }

View File

@ -80,12 +80,14 @@ impl InboundProcessor {
// 注册 save_session 处理器 // 注册 save_session 处理器
command_router.register(Box::new(SaveSessionCommandHandler::new( command_router.register(Box::new(SaveSessionCommandHandler::new(
store.clone(), store.clone(),
session_manager.task_repository(),
system_prompt_provider.clone(), system_prompt_provider.clone(),
))); )));
// 注册 save_topic 处理器 // 注册 save_topic 处理器
command_router.register(Box::new(SaveTopicCommandHandler::new( command_router.register(Box::new(SaveTopicCommandHandler::new(
store.clone(), store.clone(),
session_manager.task_repository(),
system_prompt_provider, system_prompt_provider,
).with_session_manager(session_manager.clone()))); ).with_session_manager(session_manager.clone())));

View File

@ -13,6 +13,7 @@ use crate::tools::{
DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender, DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender,
SessionMessageSender, SubAgentRuntimeConfig, ToolRegistry, SessionMessageSender, SubAgentRuntimeConfig, ToolRegistry,
}; };
use crate::tools::task::repository::TaskRepository;
use super::agent_factory::AgentFactory; use super::agent_factory::AgentFactory;
use super::cli_session::CliSessionService; use super::cli_session::CliSessionService;
@ -35,7 +36,7 @@ pub(crate) fn build_session_manager(
task_config: TaskConfig, task_config: TaskConfig,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>, session_ttl_hours: Option<u64>,
) -> Result<SessionManager, AgentError> { ) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
build_session_manager_with_sender( build_session_manager_with_sender(
agent_prompt_reinject_every, agent_prompt_reinject_every,
show_tool_results, show_tool_results,
@ -63,7 +64,7 @@ pub(crate) fn build_session_manager_with_sender(
task_config: TaskConfig, task_config: TaskConfig,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>, session_ttl_hours: Option<u64>,
) -> Result<SessionManager, AgentError> { ) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
let store = Arc::new( let store = Arc::new(
SessionStore::new() SessionStore::new()
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?, .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
@ -96,7 +97,7 @@ pub(crate) fn build_session_manager_with_sender(
); );
// 创建 SubAgentRuntime如果 task 工具启用) // 创建 SubAgentRuntime如果 task 工具启用)
let factory = if task_config.enabled { let (factory, task_repository): (_, Arc<dyn TaskRepository>) = if task_config.enabled {
let task_repository = Arc::new(InMemoryTaskRepository::new()); let task_repository = Arc::new(InMemoryTaskRepository::new());
let subagent_tools = Arc::new(factory.build_subagent_tools()); let subagent_tools = Arc::new(factory.build_subagent_tools());
@ -111,15 +112,16 @@ pub(crate) fn build_session_manager_with_sender(
let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new( let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new(
runtime_config, runtime_config,
task_repository, task_repository.clone(),
conversations.clone(), conversations.clone(),
subagent_tools, subagent_tools,
provider_config.clone(), provider_config.clone(),
)); ));
factory.with_subagent_runtime(subagent_runtime) (factory.with_subagent_runtime(subagent_runtime), task_repository)
} else { } else {
factory // 如果 task 工具未启用,创建一个空的内存仓库
(factory, Arc::new(InMemoryTaskRepository::new()))
}; };
let tools = Arc::new(factory.build()); let tools = Arc::new(factory.build());
@ -151,7 +153,7 @@ pub(crate) fn build_session_manager_with_sender(
let memory_maintenance = let memory_maintenance =
MemoryMaintenanceCoordinator::new(store.clone(), provider_configs.clone()); MemoryMaintenanceCoordinator::new(store.clone(), provider_configs.clone());
Ok(SessionManager::from_services(SessionManagerServices { Ok((SessionManager::from_services(SessionManagerServices {
tools: tools as Arc<ToolRegistry>, tools: tools as Arc<ToolRegistry>,
skills, skills,
store, store,
@ -161,5 +163,6 @@ pub(crate) fn build_session_manager_with_sender(
messages, messages,
scheduled_tasks, scheduled_tasks,
memory_maintenance, memory_maintenance,
})) task_repository: task_repository.clone(),
}), task_repository))
} }

View File

@ -8,6 +8,7 @@ use crate::scheduler::ScheduledAgentTaskOptions;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository}; use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository};
use crate::tools::ToolRegistry; use crate::tools::ToolRegistry;
use crate::tools::task::repository::TaskRepository;
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@ -458,6 +459,7 @@ pub struct SessionManager {
messages: SessionMessageService, messages: SessionMessageService,
scheduled_tasks: ScheduledAgentTaskService, scheduled_tasks: ScheduledAgentTaskService,
memory_maintenance: MemoryMaintenanceCoordinator, memory_maintenance: MemoryMaintenanceCoordinator,
task_repository: Arc<dyn TaskRepository>,
} }
pub(crate) struct SessionManagerServices { pub(crate) struct SessionManagerServices {
@ -470,6 +472,7 @@ pub(crate) struct SessionManagerServices {
pub(crate) messages: SessionMessageService, pub(crate) messages: SessionMessageService,
pub(crate) scheduled_tasks: ScheduledAgentTaskService, pub(crate) scheduled_tasks: ScheduledAgentTaskService,
pub(crate) memory_maintenance: MemoryMaintenanceCoordinator, pub(crate) memory_maintenance: MemoryMaintenanceCoordinator,
pub(crate) task_repository: Arc<dyn TaskRepository>,
} }
impl SessionManager { impl SessionManager {
@ -484,6 +487,7 @@ impl SessionManager {
messages: services.messages, messages: services.messages,
scheduled_tasks: services.scheduled_tasks, scheduled_tasks: services.scheduled_tasks,
memory_maintenance: services.memory_maintenance, memory_maintenance: services.memory_maintenance,
task_repository: services.task_repository,
} }
} }
@ -511,6 +515,7 @@ impl SessionManager {
chat_history_ttl_hours, chat_history_ttl_hours,
session_ttl_hours, session_ttl_hours,
) )
.map(|(session_manager, _)| session_manager)
} }
pub fn tools(&self) -> Arc<ToolRegistry> { pub fn tools(&self) -> Arc<ToolRegistry> {
@ -525,6 +530,10 @@ impl SessionManager {
self.show_tool_results self.show_tool_results
} }
pub fn task_repository(&self) -> Arc<dyn TaskRepository> {
self.task_repository.clone()
}
pub fn skills(&self) -> Arc<SkillRuntime> { pub fn skills(&self) -> Arc<SkillRuntime> {
self.skills.clone() self.skills.clone()
} }

View File

@ -244,6 +244,7 @@ async fn handle_inbound(
router.register(Box::new(LoadSessionCommandHandler::new(store.clone()))); router.register(Box::new(LoadSessionCommandHandler::new(store.clone())));
router.register(Box::new(SaveSessionCommandHandler::new( router.register(Box::new(SaveSessionCommandHandler::new(
store.clone(), store.clone(),
state.task_repository.clone(),
system_prompt_provider.clone(), system_prompt_provider.clone(),
))); )));
// 注册 help 处理器 // 注册 help 处理器

View File

@ -269,12 +269,22 @@ impl SessionStore {
chat_id: &str, chat_id: &str,
) -> Result<SessionRecord, StorageError> { ) -> Result<SessionRecord, StorageError> {
let session_id = persistent_session_id(channel_name, chat_id); let session_id = persistent_session_id(channel_name, chat_id);
if let Some(record) = self.get_session(&session_id)? { self.ensure_session(&session_id, channel_name, chat_id, &format!("{}:{}", channel_name, chat_id))
}
/// 确保指定 session_id 的会话存在(如果不存在则创建)
pub fn ensure_session(
&self,
session_id: &str,
channel_name: &str,
chat_id: &str,
title: &str,
) -> Result<SessionRecord, StorageError> {
if let Some(record) = self.get_session(session_id)? {
return Ok(record); return Ok(record);
} }
let now = current_timestamp(); let now = current_timestamp();
let title = format!("{}:{}", channel_name, chat_id);
let conn = self.conn.lock().expect("session db mutex poisoned"); let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute( conn.execute(
" "
@ -288,7 +298,7 @@ impl SessionStore {
)?; )?;
drop(conn); drop(conn);
self.get_session(&session_id)? self.get_session(session_id)?
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
} }

View File

@ -11,6 +11,15 @@ pub trait ConversationRepository: Send + Sync + 'static {
chat_id: &str, chat_id: &str,
) -> Result<SessionRecord, StorageError>; ) -> Result<SessionRecord, StorageError>;
/// 确保指定 session_id 的会话存在(如果不存在则创建)
fn ensure_session(
&self,
session_id: &str,
channel_name: &str,
chat_id: &str,
title: &str,
) -> Result<SessionRecord, StorageError>;
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>; fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>; fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
@ -139,6 +148,16 @@ impl ConversationRepository for super::SessionStore {
super::SessionStore::ensure_channel_session(self, channel_name, chat_id) super::SessionStore::ensure_channel_session(self, channel_name, chat_id)
} }
fn ensure_session(
&self,
session_id: &str,
channel_name: &str,
chat_id: &str,
title: &str,
) -> Result<SessionRecord, StorageError> {
super::SessionStore::ensure_session(self, session_id, channel_name, chat_id, title)
}
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> { fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
super::SessionStore::load_messages(self, session_id) super::SessionStore::load_messages(self, session_id)
} }

View File

@ -188,6 +188,13 @@ impl DefaultSubAgentRuntime {
match result { match result {
Ok(Ok(process_result)) => { Ok(Ok(process_result)) => {
// 保存子智能体产生的所有消息到数据库
for message in &process_result.emitted_messages {
if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message");
}
}
let final_message = process_result.final_response; let final_message = process_result.final_response;
Ok(TaskToolResult { Ok(TaskToolResult {
status: "success".to_string(), status: "success".to_string(),
@ -232,6 +239,13 @@ impl DefaultSubAgentRuntime {
match result { match result {
Ok(Ok(process_result)) => { Ok(Ok(process_result)) => {
// 保存子智能体产生的所有消息到数据库
for message in &process_result.emitted_messages {
if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message");
}
}
let final_message = process_result.final_response; let final_message = process_result.final_response;
Ok(TaskToolResult { Ok(TaskToolResult {
status: "success".to_string(), status: "success".to_string(),
@ -276,7 +290,18 @@ impl SubAgentRuntime for DefaultSubAgentRuntime {
task.subagent_type, task.subagent_type,
); );
// 3. 保存会话 // 3. 在 sessions 表中创建子智能体会话(确保外键约束满足)
let session_title = format!("Subagent: {}", task.description);
if let Err(e) = self.conversation_repository.ensure_session(
&session.session_id,
&session.parent_channel_name,
&session.parent_chat_id,
&session_title,
) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session");
}
// 4. 保存任务会话
self.task_repository.save_task_session(&session).await?; self.task_repository.save_task_session(&session).await?;
// 4. 构建子代理系统提示词 // 4. 构建子代理系统提示词
@ -340,21 +365,32 @@ impl SubAgentRuntime for DefaultSubAgentRuntime {
return Err(TaskError::InvalidParentSession); return Err(TaskError::InvalidParentSession);
} }
// 3. 构建恢复提示词 // 3. 确保 sessions 表中存在子智能体会话记录
let session_title = format!("Subagent: {}", session.description);
if let Err(e) = self.conversation_repository.ensure_session(
&session.session_id,
&session.parent_channel_name,
&session.parent_chat_id,
&session_title,
) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session on resume");
}
// 4. 构建恢复提示词
let system_prompt = SubagentPromptBuilder::build_resume_prompt( let system_prompt = SubagentPromptBuilder::build_resume_prompt(
&session.description, &session.description,
&additional_prompt, &additional_prompt,
); );
// 4. 创建子代理 // 5. 创建子代理
let agent = self.create_subagent(&session, system_prompt)?; let agent = self.create_subagent(&session, system_prompt)?;
// 5. 使用历史继续执行 // 6. 使用历史继续执行
let result = self let result = self
.execute_task_with_history(agent, &session, additional_prompt) .execute_task_with_history(agent, &session, additional_prompt)
.await; .await;
// 6. 更新会话状态 // 7. 更新会话状态
match result { match result {
Ok(tool_result) => { Ok(tool_result) => {
let mut session = session; let mut session = session;