diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 84ca3d8..949e8d7 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -492,7 +492,7 @@ async fn handle_inbound( if let Some(task_session_id) = response.metadata.get("task_session_id") { // 提前提取 task_id,用于给历史消息打标记 let task_id = response.metadata.get("task_id").cloned().unwrap_or_default(); - if let Err(e) = send_task_messages(&store, task_session_id, sender, Some(task_id.clone())).await { + if let Err(e) = send_task_messages(&store, task_session_id, sender, Some(task_id.clone()), Some(&state.task_repository)).await { tracing::warn!(error = %e, task_session_id = %task_session_id, "Failed to send task messages"); } @@ -570,7 +570,7 @@ async fn handle_inbound( &load_chat_channel, load_chat_id, ); - if let Err(e) = send_task_messages(&store, &session_id, sender, None).await { + if let Err(e) = send_task_messages(&store, &session_id, sender, None, None).await { tracing::warn!( error = %e, channel = %load_chat_channel, @@ -669,9 +669,13 @@ async fn send_topic_history( for task in running_tasks { if task.state == TaskSessionState::Running { + // 判断是否为孙智能体:parent_session_id 以 "sub:" 开头表示父会话是子智能体 + let parent_task_id = extract_parent_task_id(&task); + tracing::info!( task_id = %task.id, description = %task.description, + parent_task_id = ?parent_task_id, "Re-sending TaskStarted for running task after topic history load" ); let _ = sender @@ -680,7 +684,7 @@ async fn send_topic_history( description: task.description.clone(), subagent_type: task.subagent_type.clone(), topic_id: Some(topic_id.to_string()), - parent_task_id: None, + parent_task_id, tool_call_id: None, }) .await; @@ -696,6 +700,7 @@ async fn send_task_messages( session_id: &str, sender: &mpsc::Sender, subagent_task_id: Option, + task_repository: Option<&Arc>, ) -> Result<(), Box> { let messages = store.load_messages(session_id)?; @@ -713,6 +718,37 @@ async fn send_task_messages( } } + // 补发子任务(孙智能体)的 TaskStarted 事件 + // 解决重新进入子智能体视图后 navigateToTaskId 丢失的问题 + if let (Some(repo), Some(parent_task_id)) = (task_repository, &subagent_task_id) { + match repo.list_tasks_for_session(session_id).await { + Ok(child_tasks) => { + for child in child_tasks { + if child.state == TaskSessionState::Running { + tracing::info!( + child_task_id = %child.id, + parent_task_id = %parent_task_id, + "Re-sending TaskStarted for child task after sub-agent view re-enter" + ); + let _ = sender + .send(WsOutbound::TaskStarted { + task_id: child.id.clone(), + description: child.description.clone(), + subagent_type: child.subagent_type.clone(), + topic_id: child.parent_topic_id.clone(), + parent_task_id: Some(parent_task_id.clone()), + tool_call_id: None, + }) + .await; + } + } + } + Err(e) => { + tracing::warn!(error = %e, session_id = %session_id, "Failed to list child tasks for resend"); + } + } + } + Ok(()) } @@ -737,6 +773,20 @@ fn set_subagent_task_id(outbound: &mut WsOutbound, task_id: &str) { } } +/// 从 TaskSession 中提取父任务 ID(仅孙智能体有值)。 +/// 孙智能体的 parent_session_id 格式为 "sub:{grandparent_session}:task:{parent_task_uuid}", +/// 从中提取 "task:{parent_task_uuid}" 作为 parent_task_id。 +fn extract_parent_task_id(task: &crate::tools::task::types::TaskSession) -> Option { + let parent = &task.parent_session_id; + // 仅当父会话是子智能体会话时才提取(格式: "sub:...:task:{uuid}") + if parent.starts_with("sub:") { + if let Some(pos) = parent.find(":task:") { + return Some(parent[pos + 1..].to_string()); // "task:{uuid}" + } + } + None +} + /// 将 ChatMessage 转换为 WsOutbound 列表 fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Vec { use crate::bus::message::ToolMessageState; diff --git a/src/tools/registry.rs b/src/tools/registry.rs index 63ba82e..e2dc7f3 100644 --- a/src/tools/registry.rs +++ b/src/tools/registry.rs @@ -73,6 +73,20 @@ impl ToolRegistry { .cloned() .collect() } + + /// 创建一个排除指定工具的新 registry 副本 + pub fn without(&self, exclude: &[&str]) -> Self { + let exclude_set: std::collections::HashSet<&str> = exclude.iter().copied().collect(); + let tools = self.tools.read().expect("ToolRegistry lock poisoned"); + let filtered: HashMap> = tools + .iter() + .filter(|(name, _)| !exclude_set.contains(name.as_str())) + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let new_registry = ToolRegistry::new(); + *new_registry.tools.write().expect("ToolRegistry lock poisoned") = filtered; + new_registry + } } impl Default for ToolRegistry { diff --git a/src/tools/task/runtime.rs b/src/tools/task/runtime.rs index 9e836ec..55895b9 100644 --- a/src/tools/task/runtime.rs +++ b/src/tools/task/runtime.rs @@ -19,6 +19,7 @@ use crate::tools::{ToolContext, ToolRegistry}; use super::error::TaskError; use super::prompt::{extract_summary, SubagentPromptBuilder}; use super::repository::TaskRepository; +use super::tool::TaskTool; use super::types::{SubagentDef, SubagentSource, TaskDefinition, TaskSession, TaskToolResult}; /// 子代理运行时配置 @@ -332,9 +333,17 @@ impl DefaultSubAgentRuntime { ) -> Result { let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt)); + // 孙智能体(depth >= 2)不注册 task 工具,防止无限嵌套 + let child_depth = parent_nesting_depth + 1; + let tools = if child_depth >= 2 { + Arc::new(self.subagent_tools.without(&[TaskTool::TOOL_NAME])) + } else { + self.subagent_tools.clone() + }; + AgentLoop::with_tools_and_system_prompt_provider( AgentRuntimeConfig::from(self.provider_config.clone()), - self.subagent_tools.clone(), + tools, prompt_provider, None, // 子代理不需要 skill provider ) diff --git a/src/tools/task/tool.rs b/src/tools/task/tool.rs index 07385c0..f082613 100644 --- a/src/tools/task/tool.rs +++ b/src/tools/task/tool.rs @@ -15,6 +15,8 @@ pub struct TaskTool { } impl TaskTool { + pub const TOOL_NAME: &'static str = "task"; + /// 创建 TaskTool /// - `max_nesting_depth = None`:无深度限制(主 agent) /// - `max_nesting_depth = Some(N)`:允许最多 N 层嵌套(子 agent) @@ -29,7 +31,7 @@ impl TaskTool { #[async_trait] impl Tool for TaskTool { fn name(&self) -> &str { - "task" + Self::TOOL_NAME } fn description(&self) -> &str {