From 49475783a2085361481cc54f9b5048edf0e0d7e8 Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Wed, 20 May 2026 17:52:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=AD=90=E6=99=BA?= =?UTF-8?q?=E8=83=BD=E4=BD=93=E6=94=AF=E6=8C=81=E5=88=B0=E4=BF=9D=E5=AD=98?= =?UTF-8?q?=E8=AF=9D=E9=A2=98=E5=92=8C=E4=BC=9A=E8=AF=9D=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE=E6=8C=81=E4=B9=85?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/command/adapters/channel.rs | 60 ++++++-- src/command/adapters/cli.rs | 60 ++++++-- src/command/handlers/mod.rs | 1 + src/command/handlers/save_session.rs | 222 +++++++++++++++++++++++++-- src/command/handlers/save_topic.rs | 37 ++++- src/command/mod.rs | 6 +- src/gateway/mod.rs | 5 +- src/gateway/processor.rs | 2 + src/gateway/runtime.rs | 19 ++- src/gateway/session.rs | 9 ++ src/gateway/ws.rs | 1 + 11 files changed, 365 insertions(+), 57 deletions(-) diff --git a/src/command/adapters/channel.rs b/src/command/adapters/channel.rs index 27724e7..54f00da 100644 --- a/src/command/adapters/channel.rs +++ b/src/command/adapters/channel.rs @@ -44,13 +44,30 @@ impl InputAdapter for ChannelInputAdapter { if trimmed == "/save" { return Ok(Some(Command::SaveTopic { filepath: None, + include_subagents: false, })); } - if let Some(filepath) = trimmed.strip_prefix("/save ") { - let filepath = filepath.trim(); + if let Some(args) = trimmed.strip_prefix("/save ") { + let args = args.trim(); + let parts: Vec<&str> = args.split_whitespace().collect(); + + // 解析参数 + let mut include_subagents = false; + let mut filepath = None; + + for part in parts { + if part == "+sub" { + include_subagents = true; + } else if !part.is_empty() { + // 非特殊参数视为文件路径 + filepath = Some(part.to_string()); + } + } + return Ok(Some(Command::SaveTopic { - filepath: Some(filepath.to_string()), + filepath, + include_subagents, })); } @@ -59,24 +76,35 @@ impl InputAdapter for ChannelInputAdapter { return Ok(Some(Command::SaveSession { filepath: None, include_all: false, + include_subagents: false, })); } if let Some(args) = trimmed.strip_prefix("/save-session ") { let args = args.trim(); - // 解析参数:可能是 "all"、路径、或 "all 路径" - let (include_all, filepath) = if args == "all" { - // /save-session all - 保存全部消息 - (true, None) - } else if args.starts_with("all ") { - // /save-session all - 保存全部消息到指定路径 - let path = args[4..].trim(); - (true, Some(path.to_string())) - } else { - // /save-session - 保存活跃消息到指定路径 - (false, Some(args.to_string())) - }; - return Ok(Some(Command::SaveSession { filepath, include_all })); + let parts: Vec<&str> = args.split_whitespace().collect(); + + // 解析参数 + let mut include_all = false; + let mut include_subagents = false; + let mut filepath = None; + + for part in parts { + if part == "all" { + include_all = true; + } else if part == "+sub" { + include_subagents = true; + } else if !part.is_empty() { + // 非特殊参数视为文件路径 + filepath = Some(part.to_string()); + } + } + + return Ok(Some(Command::SaveSession { + filepath, + include_all, + include_subagents, + })); } // 解析 /list 命令 diff --git a/src/command/adapters/cli.rs b/src/command/adapters/cli.rs index 4f88bae..2ab4628 100644 --- a/src/command/adapters/cli.rs +++ b/src/command/adapters/cli.rs @@ -45,13 +45,30 @@ impl InputAdapter for CliInputAdapter { if trimmed == "/save" { return Ok(Some(Command::SaveTopic { filepath: None, + include_subagents: false, })); } - if let Some(filepath) = trimmed.strip_prefix("/save ") { - let filepath = filepath.trim(); + if let Some(args) = trimmed.strip_prefix("/save ") { + let args = args.trim(); + let parts: Vec<&str> = args.split_whitespace().collect(); + + // 解析参数 + let mut include_subagents = false; + let mut filepath = None; + + for part in parts { + if part == "+sub" { + include_subagents = true; + } else if !part.is_empty() { + // 非特殊参数视为文件路径 + filepath = Some(part.to_string()); + } + } + return Ok(Some(Command::SaveTopic { - filepath: Some(filepath.to_string()), + filepath, + include_subagents, })); } @@ -60,24 +77,35 @@ impl InputAdapter for CliInputAdapter { return Ok(Some(Command::SaveSession { filepath: None, include_all: false, + include_subagents: false, })); } if let Some(args) = trimmed.strip_prefix("/save-session ") { let args = args.trim(); - // 解析参数:可能是 "all"、路径、或 "all 路径" - let (include_all, filepath) = if args == "all" { - // /save-session all - 保存全部消息 - (true, None) - } else if args.starts_with("all ") { - // /save-session all - 保存全部消息到指定路径 - let path = args[4..].trim(); - (true, Some(path.to_string())) - } else { - // /save-session - 保存活跃消息到指定路径 - (false, Some(args.to_string())) - }; - return Ok(Some(Command::SaveSession { filepath, include_all })); + let parts: Vec<&str> = args.split_whitespace().collect(); + + // 解析参数 + let mut include_all = false; + let mut include_subagents = false; + let mut filepath = None; + + for part in parts { + if part == "all" { + include_all = true; + } else if part == "+sub" { + include_subagents = true; + } else if !part.is_empty() { + // 非特殊参数视为文件路径 + filepath = Some(part.to_string()); + } + } + + return Ok(Some(Command::SaveSession { + filepath, + include_all, + include_subagents, + })); } // 解析 /list 命令 diff --git a/src/command/handlers/mod.rs b/src/command/handlers/mod.rs index b0d4f93..87f173e 100644 --- a/src/command/handlers/mod.rs +++ b/src/command/handlers/mod.rs @@ -11,6 +11,7 @@ pub mod switch_session; pub use save_session::{ escape_yaml_string, format_message_content, format_timestamp, generate_messages_markdown, generate_system_prompt_markdown, + generate_subagent_tasks_markdown, load_subagent_data, SubagentTaskData, }; use crate::bus::ChatMessage; diff --git a/src/command/handlers/save_session.rs b/src/command/handlers/save_session.rs index 8236ee4..4ceb4e5 100644 --- a/src/command/handlers/save_session.rs +++ b/src/command/handlers/save_session.rs @@ -5,6 +5,7 @@ use crate::command::handler::{CommandHandler, CommandMetadata, InChatCommandHand use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; use crate::storage::{SessionRecord, SessionStore}; +use crate::tools::task::repository::TaskRepository; use crate::agent::AgentError; use async_trait::async_trait; use chrono::{Local, TimeZone}; @@ -17,8 +18,10 @@ use std::sync::Arc; /// * `session_id` - 会话ID /// * `filepath` - 可选的文件路径 /// * `include_all` - 是否包含 cutoff 之前的所有消息 +/// * `include_subagents` - 是否包含子智能体消息 /// * `store` - 会话存储 -/// * `provider_config` - LLM提供者配置 +/// * `task_repository` - 任务存储(可选,用于查询子智能体) +/// * `system_prompt_provider` - 系统提示词提供者 /// /// # Returns /// 返回保存的文件路径 @@ -26,7 +29,9 @@ pub async fn save_session_to_file( session_id: &str, filepath: Option, include_all: bool, + include_subagents: bool, store: &SessionStore, + task_repository: Option<&dyn TaskRepository>, system_prompt_provider: &dyn SystemPromptProvider, ) -> Result { // 获取会话记录 @@ -46,6 +51,13 @@ pub async fn save_session_to_file( .map_err(|e| format!("Failed to load messages: {}", e))? }; + // 加载子智能体消息(如果启用) + let subagent_data = if include_subagents { + load_subagent_data(session_id, store, task_repository).await + } else { + Vec::new() + }; + // 计算用户消息数(用于系统提示词构建) let user_message_count = messages.iter().filter(|m| m.role == "user").count(); @@ -53,7 +65,7 @@ pub async fn save_session_to_file( let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count); // 生成 Markdown 内容 - let markdown = generate_markdown(&record, &system_prompt, &messages); + let markdown = generate_markdown_with_subagents(&record, &system_prompt, &messages, &subagent_data); // 确定输出路径 let output_path = resolve_filepath(filepath, &record); @@ -78,6 +90,7 @@ pub async fn save_session_to_file( /// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件 pub struct SaveSessionCommandHandler { store: Arc, + task_repository: Arc, system_prompt_provider: Arc, } @@ -86,10 +99,16 @@ impl SaveSessionCommandHandler { /// /// # Arguments /// * `store` - 会话存储 - /// * `system_prompt_provider` - 系统提示词提供者(负责构建完整的系统提示词) - pub fn new(store: Arc, system_prompt_provider: Arc) -> Self { + /// * `task_repository` - 任务存储(用于查询子智能体) + /// * `system_prompt_provider` - 系统提示词提供者 + pub fn new( + store: Arc, + task_repository: Arc, + system_prompt_provider: Arc, + ) -> Self { Self { store, + task_repository, system_prompt_provider, } } @@ -121,8 +140,8 @@ impl CommandHandler for SaveSessionCommandHandler { ctx: CommandContext, ) -> Result { match cmd { - Command::SaveSession { filepath, include_all } => { - handle_save_session(self, filepath, include_all, ctx).await + Command::SaveSession { filepath, include_all, include_subagents } => { + handle_save_session(self, filepath, include_all, include_subagents, ctx).await } _ => unreachable!(), } @@ -134,12 +153,14 @@ async fn handle_save_session( handler: &SaveSessionCommandHandler, filepath: Option, include_all: bool, + include_subagents: bool, ctx: CommandContext, ) -> Result { tracing::debug!( ctx_session_id = ?ctx.session_id, ctx_chat_id = ?ctx.chat_id, channel = %ctx.channel_name, + include_subagents = include_subagents, "SaveSession command received" ); @@ -174,7 +195,9 @@ async fn handle_save_session( session_id, filepath, include_all, + include_subagents, &*handler.store, + Some(handler.task_repository.as_ref()), &*handler.system_prompt_provider, ) .await @@ -202,6 +225,174 @@ async fn handle_save_session( .with_metadata("message_count", &message_count.to_string())) } +/// 子智能体任务数据 +#[derive(Debug)] +pub struct SubagentTaskData { + pub task_id: String, + pub session_id: String, + pub description: String, + pub subagent_type: String, + pub state: String, + pub created_at: i64, + pub messages: Vec, +} + +/// 加载子智能体数据 +pub async fn load_subagent_data( + parent_session_id: &str, + store: &SessionStore, + task_repository: Option<&dyn TaskRepository>, +) -> Vec { + let Some(repo) = task_repository else { + return Vec::new(); + }; + + // 获取所有子任务 + let tasks = match repo.list_tasks_for_session(parent_session_id).await { + Ok(tasks) => tasks, + Err(e) => { + tracing::warn!(error = %e, "Failed to list tasks for session"); + return Vec::new(); + } + }; + + let mut result = Vec::new(); + for task in tasks { + // 加载子智能体的消息 + let messages = match store.load_all_messages(&task.session_id) { + Ok(msgs) => msgs, + Err(e) => { + tracing::warn!(error = %e, task_id = %task.id, "Failed to load subagent messages"); + Vec::new() + } + }; + + result.push(SubagentTaskData { + task_id: task.id, + session_id: task.session_id, + description: task.description, + subagent_type: task.subagent_type.as_str().to_string(), + state: format!("{:?}", task.state), + created_at: task.created_at, + messages, + }); + } + + result +} + +/// 生成 Markdown 内容(包含子智能体) +pub fn generate_markdown_with_subagents( + record: &SessionRecord, + system_prompt: &Option, + messages: &[crate::bus::ChatMessage], + subagent_data: &[SubagentTaskData], +) -> String { + let mut output = String::new(); + + // YAML frontmatter + output.push_str("---\n"); + output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title))); + output.push_str(&format!("session_id: {}\n", record.id)); + output.push_str(&format!("channel: {}\n", record.channel_name)); + output.push_str(&format!("chat_id: {}\n", record.chat_id)); + output.push_str(&format!( + "created_at: {}\n", + format_timestamp(record.created_at) + )); + output.push_str(&format!( + "updated_at: {}\n", + format_timestamp(record.updated_at) + )); + output.push_str(&format!( + "last_active_at: {}\n", + format_timestamp(record.last_active_at) + )); + output.push_str(&format!("message_count: {}\n", messages.len())); + if !subagent_data.is_empty() { + output.push_str(&format!("subagent_count: {}\n", subagent_data.len())); + } + output.push_str("---\n\n"); + + // 系统提示词 + output.push_str(&generate_system_prompt_markdown(system_prompt)); + + // 子智能体任务(如果有) + if !subagent_data.is_empty() { + output.push_str(&generate_subagent_tasks_markdown(subagent_data)); + } + + // 主会话消息历史 + output.push_str(&generate_messages_markdown(messages)); + + output +} + +/// 生成子智能体任务 Markdown +pub fn generate_subagent_tasks_markdown(subagent_data: &[SubagentTaskData]) -> String { + let mut output = String::new(); + + output.push_str("# Subagent Tasks\n\n"); + + for task in subagent_data { + output.push_str(&format!("## Task: {} ({})", task.description, task.subagent_type)); + output.push('\n'); + output.push_str(&format!("**Task ID:** `{}`\n\n", task.task_id)); + output.push_str(&format!("**Session ID:** `{}`\n\n", task.session_id)); + output.push_str(&format!("**Status:** {}\n\n", task.state)); + output.push_str(&format!("**Created:** {}\n\n", format_timestamp(task.created_at))); + output.push_str(&format!("**Message Count:** {}\n\n", task.messages.len())); + + // 子智能体消息 + if !task.messages.is_empty() { + output.push_str("### Messages\n\n"); + for (idx, msg) in task.messages.iter().enumerate() { + output.push_str(&format!("#### Message {}\n\n", idx + 1)); + output.push_str(&format!("**Role:** {}\n\n", msg.role)); + output.push_str(&format!("**Time:** {}\n\n", format_timestamp(msg.timestamp))); + + if let Some(ref reasoning) = msg.reasoning_content { + output.push_str("**Reasoning:**\n"); + output.push_str("```\n"); + output.push_str(reasoning); + output.push_str("\n```\n\n"); + } + + output.push_str("**Content:**\n\n"); + if msg.content.is_empty() { + output.push_str("*empty*\n\n"); + } else { + output.push_str(&format!("{}\n\n", format_message_content(&msg.content))); + } + + // 工具调用 + if let Some(ref calls) = msg.tool_calls { + if !calls.is_empty() { + output.push_str("**Tool Calls:**\n\n"); + for call in calls { + output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id)); + output.push_str(" ```json\n"); + let args_json = serde_json::to_string_pretty(&call.arguments) + .unwrap_or_else(|_| call.arguments.to_string()); + for line in args_json.lines() { + output.push_str(&format!(" {}\n", line)); + } + output.push_str(" ```\n"); + } + output.push('\n'); + } + } + + output.push_str("---\n\n"); + } + } + + output.push_str("---\n\n"); + } + + output +} + /// 构建系统提示词 fn build_system_prompt( provider: &dyn SystemPromptProvider, @@ -440,14 +631,20 @@ pub fn resolve_filepath(filepath: Option, record: &SessionRecord) -> Pat /// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令 pub struct SaveSessionInChatHandler { store: Arc, + task_repository: Arc, system_prompt_provider: Arc, } impl SaveSessionInChatHandler { /// 创建新的 InChat 保存会话命令处理器 - pub fn new(store: Arc, system_prompt_provider: Arc) -> Self { + pub fn new( + store: Arc, + task_repository: Arc, + system_prompt_provider: Arc, + ) -> Self { Self { store, + task_repository, system_prompt_provider, } } @@ -465,7 +662,7 @@ impl InChatCommandHandler for SaveSessionInChatHandler { inbound: &InboundMessage, session_manager: &crate::gateway::session::SessionManager, ) -> Result, AgentError> { - let Command::SaveSession { filepath, include_all } = cmd else { + let Command::SaveSession { filepath, include_all, include_subagents } = cmd else { return Ok(None); }; @@ -486,7 +683,9 @@ impl InChatCommandHandler for SaveSessionInChatHandler { &session_id, filepath, include_all, + include_subagents, &*self.store, + Some(self.task_repository.as_ref()), &*self.system_prompt_provider, ) .await; @@ -623,11 +822,12 @@ mod tests { #[test] fn test_can_handle() { let store = Arc::new(SessionStore::in_memory().unwrap()); + let task_repository = Arc::new(crate::tools::task::repository::InMemoryTaskRepository::new()); let provider = Arc::new(TestSystemPromptProvider); - let handler = SaveSessionCommandHandler::new(store, provider); + let handler = SaveSessionCommandHandler::new(store, task_repository, provider); - assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false })); - assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true })); + assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false, include_subagents: false })); + assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true, include_subagents: false })); assert!(!handler.can_handle(&Command::CreateSession { title: None })); assert!(!handler.can_handle(&Command::SaveTopic { filepath: None })); } diff --git a/src/command/handlers/save_topic.rs b/src/command/handlers/save_topic.rs index 07dfa64..9981eaf 100644 --- a/src/command/handlers/save_topic.rs +++ b/src/command/handlers/save_topic.rs @@ -4,12 +4,14 @@ use crate::command::context::CommandContext; use crate::command::handler::{CommandHandler, CommandMetadata}; use crate::command::handlers::{ escape_yaml_string, format_timestamp, generate_messages_markdown, - generate_system_prompt_markdown, get_messages_from_session, + generate_subagent_tasks_markdown, generate_system_prompt_markdown, + get_messages_from_session, load_subagent_data, SubagentTaskData, }; use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; use crate::gateway::session::SessionManager; use crate::storage::{SessionStore, TopicRecord}; +use crate::tools::task::repository::TaskRepository; use async_trait::async_trait; use chrono::Local; use std::path::PathBuf; @@ -19,9 +21,11 @@ use std::sync::Arc; pub async fn save_topic_to_file( topic_id: &str, filepath: Option, + include_subagents: bool, store: &SessionStore, + task_repository: Option<&dyn TaskRepository>, system_prompt_provider: &dyn SystemPromptProvider, - messages: &[ChatMessage], // ← 从外部传入的消息(已压缩的 active history) + messages: &[ChatMessage], ) -> Result { // 获取话题记录 let topic = store @@ -38,8 +42,15 @@ pub async fn save_topic_to_file( let user_message_count = messages.iter().filter(|m| m.role == "user").count(); let system_prompt = build_system_prompt(system_prompt_provider, &session, user_message_count); + // 加载子智能体消息(如果启用) + let subagent_data = if include_subagents { + load_subagent_data(&topic.session_id, store, task_repository).await + } else { + Vec::new() + }; + // 生成 Markdown 内容 - let markdown = generate_topic_markdown(&topic, &system_prompt, messages); + let markdown = generate_topic_markdown(&topic, &system_prompt, messages, &subagent_data); // 确定输出路径 let output_path = resolve_topic_filepath(filepath, &topic); @@ -79,6 +90,7 @@ fn generate_topic_markdown( topic: &TopicRecord, system_prompt: &Option, messages: &[crate::bus::ChatMessage], + subagent_data: &[SubagentTaskData], ) -> String { let mut output = String::new(); @@ -103,11 +115,19 @@ fn generate_topic_markdown( format_timestamp(topic.last_active_at) )); output.push_str(&format!("message_count: {}\n", messages.len())); + if !subagent_data.is_empty() { + output.push_str(&format!("subagent_count: {}\n", subagent_data.len())); + } output.push_str("---\n\n"); // 系统提示词(复用公共函数) output.push_str(&generate_system_prompt_markdown(system_prompt)); + // 子智能体任务(如果有) + if !subagent_data.is_empty() { + output.push_str(&generate_subagent_tasks_markdown(subagent_data)); + } + // 消息历史(复用公共函数) output.push_str(&generate_messages_markdown(messages)); @@ -153,6 +173,7 @@ fn resolve_topic_filepath(filepath: Option, topic: &TopicRecord) -> Path /// 保存话题命令处理器 pub struct SaveTopicCommandHandler { store: Arc, + task_repository: Arc, system_prompt_provider: Arc, session_manager: Option, } @@ -160,10 +181,12 @@ pub struct SaveTopicCommandHandler { impl SaveTopicCommandHandler { pub fn new( store: Arc, + task_repository: Arc, system_prompt_provider: Arc, ) -> Self { Self { store, + task_repository, system_prompt_provider, session_manager: None, } @@ -195,7 +218,9 @@ impl CommandHandler for SaveTopicCommandHandler { ctx: CommandContext, ) -> Result { match cmd { - Command::SaveTopic { filepath } => handle_save_topic(self, filepath, ctx).await, + Command::SaveTopic { filepath, include_subagents } => { + handle_save_topic(self, filepath, include_subagents, ctx).await + } _ => unreachable!(), } } @@ -204,12 +229,14 @@ impl CommandHandler for SaveTopicCommandHandler { async fn handle_save_topic( handler: &SaveTopicCommandHandler, filepath: Option, + include_subagents: bool, ctx: CommandContext, ) -> Result { tracing::debug!( ctx_topic_id = ?ctx.topic_id, ctx_session_id = ?ctx.session_id, channel = %ctx.channel_name, + include_subagents = include_subagents, "SaveTopic command received" ); @@ -238,7 +265,9 @@ async fn handle_save_topic( let output_path = save_topic_to_file( topic_id, filepath, + include_subagents, &*handler.store, + Some(handler.task_repository.as_ref()), &*handler.system_prompt_provider, &messages, ) diff --git a/src/command/mod.rs b/src/command/mod.rs index a906cc5..88d0d72 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -14,11 +14,15 @@ pub enum Command { /// 创建新话题(在同一个 Session 内) CreateSession { title: Option }, /// 保存当前话题内容到 Markdown 文件 - SaveTopic { filepath: Option }, + SaveTopic { + filepath: Option, + include_subagents: bool, + }, /// 保存会话内容到 Markdown 文件 SaveSession { filepath: Option, include_all: bool, + include_subagents: bool, }, /// 列出当前 Session 的所有话题 ListSessions { include_archived: bool }, diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 71fcfc5..4d0b0e2 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -38,6 +38,7 @@ use crate::config::LLMProviderConfig; use crate::logging; use crate::scheduler::Scheduler; use crate::skills::SkillRuntime; +use crate::tools::task::repository::TaskRepository; use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService}; use outbound_dispatcher::OutboundDispatcher; use processor::InboundProcessor; @@ -50,6 +51,7 @@ pub struct GatewayState { pub session_manager: SessionManager, pub channel_manager: ChannelManager, pub bus: Arc, + pub task_repository: Arc, } impl GatewayState { @@ -72,7 +74,7 @@ impl GatewayState { let channel_manager = ChannelManager::new(); let bus = channel_manager.bus(); - let session_manager = build_session_manager_with_sender( + let (session_manager, task_repository) = build_session_manager_with_sender( agent_prompt_reinject_every, show_tool_results, config.time.timezone.clone(), @@ -91,6 +93,7 @@ impl GatewayState { session_manager, channel_manager, bus, + task_repository, }) } diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 5424268..7e65b56 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -80,12 +80,14 @@ impl InboundProcessor { // 注册 save_session 处理器 command_router.register(Box::new(SaveSessionCommandHandler::new( store.clone(), + session_manager.task_repository(), system_prompt_provider.clone(), ))); // 注册 save_topic 处理器 command_router.register(Box::new(SaveTopicCommandHandler::new( store.clone(), + session_manager.task_repository(), system_prompt_provider, ).with_session_manager(session_manager.clone()))); diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index 110319e..9d4c5ee 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -13,6 +13,7 @@ use crate::tools::{ DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender, SessionMessageSender, SubAgentRuntimeConfig, ToolRegistry, }; +use crate::tools::task::repository::TaskRepository; use super::agent_factory::AgentFactory; use super::cli_session::CliSessionService; @@ -35,7 +36,7 @@ pub(crate) fn build_session_manager( task_config: TaskConfig, chat_history_ttl_hours: Option, session_ttl_hours: Option, -) -> Result { +) -> Result<(SessionManager, Arc), AgentError> { build_session_manager_with_sender( agent_prompt_reinject_every, show_tool_results, @@ -63,7 +64,7 @@ pub(crate) fn build_session_manager_with_sender( task_config: TaskConfig, chat_history_ttl_hours: Option, session_ttl_hours: Option, -) -> Result { +) -> Result<(SessionManager, Arc), AgentError> { let store = Arc::new( SessionStore::new() .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?, @@ -96,7 +97,7 @@ pub(crate) fn build_session_manager_with_sender( ); // 创建 SubAgentRuntime(如果 task 工具启用) - let factory = if task_config.enabled { + let (factory, task_repository): (_, Arc) = if task_config.enabled { let task_repository = Arc::new(InMemoryTaskRepository::new()); let subagent_tools = Arc::new(factory.build_subagent_tools()); @@ -111,15 +112,16 @@ pub(crate) fn build_session_manager_with_sender( let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new( runtime_config, - task_repository, + task_repository.clone(), conversations.clone(), subagent_tools, provider_config.clone(), )); - factory.with_subagent_runtime(subagent_runtime) + (factory.with_subagent_runtime(subagent_runtime), task_repository) } else { - factory + // 如果 task 工具未启用,创建一个空的内存仓库 + (factory, Arc::new(InMemoryTaskRepository::new())) }; let tools = Arc::new(factory.build()); @@ -151,7 +153,7 @@ pub(crate) fn build_session_manager_with_sender( let memory_maintenance = MemoryMaintenanceCoordinator::new(store.clone(), provider_configs.clone()); - Ok(SessionManager::from_services(SessionManagerServices { + Ok((SessionManager::from_services(SessionManagerServices { tools: tools as Arc, skills, store, @@ -161,5 +163,6 @@ pub(crate) fn build_session_manager_with_sender( messages, scheduled_tasks, memory_maintenance, - })) + task_repository: task_repository.clone(), + }), task_repository)) } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 8f87532..2c9a71e 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -8,6 +8,7 @@ use crate::scheduler::ScheduledAgentTaskOptions; use crate::skills::SkillRuntime; use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository}; use crate::tools::ToolRegistry; +use crate::tools::task::repository::TaskRepository; use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; @@ -458,6 +459,7 @@ pub struct SessionManager { messages: SessionMessageService, scheduled_tasks: ScheduledAgentTaskService, memory_maintenance: MemoryMaintenanceCoordinator, + task_repository: Arc, } pub(crate) struct SessionManagerServices { @@ -470,6 +472,7 @@ pub(crate) struct SessionManagerServices { pub(crate) messages: SessionMessageService, pub(crate) scheduled_tasks: ScheduledAgentTaskService, pub(crate) memory_maintenance: MemoryMaintenanceCoordinator, + pub(crate) task_repository: Arc, } impl SessionManager { @@ -484,6 +487,7 @@ impl SessionManager { messages: services.messages, scheduled_tasks: services.scheduled_tasks, memory_maintenance: services.memory_maintenance, + task_repository: services.task_repository, } } @@ -511,6 +515,7 @@ impl SessionManager { chat_history_ttl_hours, session_ttl_hours, ) + .map(|(session_manager, _)| session_manager) } pub fn tools(&self) -> Arc { @@ -525,6 +530,10 @@ impl SessionManager { self.show_tool_results } + pub fn task_repository(&self) -> Arc { + self.task_repository.clone() + } + pub fn skills(&self) -> Arc { self.skills.clone() } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 50ffc44..65a6af6 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -244,6 +244,7 @@ async fn handle_inbound( router.register(Box::new(LoadSessionCommandHandler::new(store.clone()))); router.register(Box::new(SaveSessionCommandHandler::new( store.clone(), + state.task_repository.clone(), system_prompt_provider.clone(), ))); // 注册 help 处理器