From 41b4895ff0e292c7e107d4cb40dcd056f0bfc84f Mon Sep 17 00:00:00 2001 From: xiaoski Date: Mon, 25 May 2026 23:23:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0sub-agent=E6=9C=BA=E5=88=B6?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 2 + src/agent/mod.rs | 2 + src/agent/sub_agent.rs | 611 +++++++++++++++++++++++++++++++++ src/agent/system_prompt.rs | 46 ++- src/config/mod.rs | 29 +- src/gateway/mod.rs | 1 + src/session/session.rs | 92 ++++- src/storage/background_task.rs | 19 + src/storage/mod.rs | 186 ++++++++++ src/tools/delegate.rs | 365 ++++++++++++++++++++ src/tools/mod.rs | 8 + src/tools/registry.rs | 16 + 12 files changed, 1357 insertions(+), 20 deletions(-) create mode 100644 src/agent/sub_agent.rs create mode 100644 src/storage/background_task.rs create mode 100644 src/tools/delegate.rs diff --git a/Cargo.toml b/Cargo.toml index bfe3209..4edf0e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ serde_json = "1.0" async-trait = "0.1" thiserror = "2.0.18" tokio = { version = "1.52", features = ["full"] } +tokio-util = { version = "0.7", features = ["rt"] } +dashmap = "6.1" uuid = { version = "1.23", features = ["v4"] } axum = { version = "0.8", features = ["ws"] } tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] } diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 5319eeb..f4af26b 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,8 +1,10 @@ pub mod agent_loop; pub mod context_compressor; pub mod media_handler; +pub mod sub_agent; pub mod system_prompt; pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; pub use context_compressor::{ContextCompressor, estimate_tokens}; +pub use sub_agent::{DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, TaskNotification, TaskStatus}; pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, 
SystemPromptBuilder}; diff --git a/src/agent/sub_agent.rs b/src/agent/sub_agent.rs new file mode 100644 index 0000000..4cd6590 --- /dev/null +++ b/src/agent/sub_agent.rs @@ -0,0 +1,611 @@ +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Instant; + +use dashmap::DashMap; +use tokio_util::sync::CancellationToken; +use uuid::Uuid; + +use crate::agent::AgentLoop; +use crate::agent::AgentError; +use crate::bus::ChatMessage; +use crate::config::LLMProviderConfig; +use crate::providers::{create_provider, LLMProvider}; +use crate::tools::ToolRegistry; + +tokio::task_local! { + pub(crate) static DELEGATE_CONTEXT: DelegateContext; +} + +/// Read the delegate context from the current task. Returns an error if not set. +pub fn get_delegate_context() -> Result { + DELEGATE_CONTEXT.try_with(|ctx| ctx.clone()) + .map_err(|_| "DELEGATE_CONTEXT not set".to_string()) +} + +const DEFAULT_MAX_ITERATIONS: usize = 99; +const DEFAULT_TIMEOUT_SECS: u64 = 3600; +const MAX_INLINE_RESULT_CHARS: usize = 8000; + +const DEFAULT_READONLY_TOOLS: &[&str] = &[ + "file_read", + "file_search", + "content_search", + "web_fetch", + "http_request", + "calculator", +]; + +#[derive(Debug, Clone)] +pub struct SubAgentConfig { + pub prompt: String, + pub mode: ExecutionMode, + pub allowed_tools: Option>, + pub max_iterations: Option, + pub timeout_secs: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ExecutionMode { + Inline, + Background, + Parallel, +} + +#[derive(Debug, Clone)] +pub struct SubAgentResult { + pub task_id: String, + pub content: String, + pub content_truncated: bool, + pub status: TaskStatus, + pub tool_calls_count: usize, + pub iterations: usize, + pub duration_ms: u64, +} + +#[derive(Debug, Clone)] +pub enum TaskStatus { + Completed, + Failed(String), + Cancelled, + TimedOut, +} + +#[derive(Debug, Clone)] +pub struct TaskNotification { + pub task_id: String, + pub session_id: String, + pub channel: String, + pub chat_id: String, + pub status: 
/// Errors returned by `SubAgentManager` when a delegated task cannot be
/// started or its bookkeeping fails.
///
/// The `Display` text is what callers forward to the agent/user, so the
/// messages below are part of the runtime behavior.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SubAgentError {
    /// The background-task limit has been reached; carries the configured limit.
    TooManyTasks(usize),
    /// The LLM provider for the sub-agent could not be created.
    ProviderCreation(String),
    /// Reading or writing a background-task record failed.
    Storage(String),
    /// Any other failure, carried as an already formatted message.
    Other(String),
}

impl std::fmt::Display for SubAgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TooManyTasks(max) => write!(f, "后台任务已达上限({}),请稍后重试", max),
            Self::ProviderCreation(e) => write!(f, "provider creation failed: {}", e),
            Self::Storage(e) => write!(f, "storage error: {}", e),
            Self::Other(e) => write!(f, "{}", e),
        }
    }
}

impl std::error::Error for SubAgentError {}
tool_descriptions = tools.describe_for_prompt(); + + let http_only_note = if config.allowed_tools.is_none() + || config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request")) + { + "- When using http_request, only the GET method is permitted. \ + Do NOT use POST, PUT, DELETE, or any other method." + } else { + "" + }; + + format!( + "You are a sub-agent working on a delegated task. Complete the task below \ + and return a single, self-contained result.\n\ + \n\ + ## Task\n\ + {task}\n\ + \n\ + ## Rules\n\ + - Focus ONLY on the task above. Do not explore unrelated topics.\n\ + - Use tools only when necessary for the task.\n\ + - Do NOT use the delegate tool — sub-agent recursion is forbidden.\n\ + - If the task cannot be completed, explain why clearly.\n\ + - Return only the final result. Do not describe your process.\n\ + {http_only}\n\ + - Timeout: {timeout_human}. If approaching the limit, return partial results.\n\ + \n\ + ## Available Tools\n\ + {tool_descriptions}\n\ + \n\ + ## Workspace\n\ + {workspace}", + task = config.prompt, + http_only = http_only_note, + timeout_human = timeout_human, + tool_descriptions = tool_descriptions, + workspace = self.provider_config.workspace_dir.display(), + ) + } + + pub fn build_sub_agent( + &self, + config: &SubAgentConfig, + tools: Arc, + ) -> Result { + let mut provider = create_provider(self.provider_config.clone()) + .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; + if let Some(ref s) = self.storage { + provider.set_storage(s.clone()); + } + let provider: Arc = Arc::from(provider); + + let max_iterations = config.max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS); + let workspace_dir = self.provider_config.workspace_dir.clone(); + let model_name = self.provider_config.model_id.clone(); + let input_types = self.provider_config.input_types.clone(); + + let agent = AgentLoop::with_provider_and_tools( + provider, + tools, + max_iterations, + model_name, + workspace_dir, + input_types, + ) + 
.with_context_window(self.provider_config.token_limit); + + Ok(agent) + } + + pub async fn run_inline( + &self, + config: SubAgentConfig, + ) -> Result { + let task_id = generate_task_id(); + let tools = self.filter_tools(&config.allowed_tools); + let system_prompt = self.build_system_prompt(&config, &tools); + let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS); + + let agent = self.build_sub_agent(&config, tools) + .map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?; + + let history = vec![ + ChatMessage::system(system_prompt), + ChatMessage::user(&config.prompt), + ]; + + let start = Instant::now(); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(timeout_secs), + agent.process(history), + ) + .await; + + let duration_ms = start.elapsed().as_millis() as u64; + + match result { + Ok(Ok(agent_result)) => { + let (content, truncated) = + truncate_sub_agent_result(&agent_result.final_response.content); + let tool_calls_count = agent_result.emitted_messages.iter() + .filter(|m| m.tool_calls.is_some()) + .count(); + let iterations = agent_result.emitted_messages.iter() + .filter(|m| m.role == "assistant" && m.tool_calls.is_some()) + .count(); + Ok(SubAgentResult { + task_id, + content, + content_truncated: truncated, + status: TaskStatus::Completed, + tool_calls_count, + iterations, + duration_ms, + }) + } + Ok(Err(e)) => Ok(SubAgentResult { + task_id, + content: String::new(), + content_truncated: false, + status: TaskStatus::Failed(e.to_string()), + tool_calls_count: 0, + iterations: 0, + duration_ms, + }), + Err(_elapsed) => Ok(SubAgentResult { + task_id, + content: String::new(), + content_truncated: false, + status: TaskStatus::TimedOut, + tool_calls_count: 0, + iterations: 0, + duration_ms, + }), + } + } + + pub async fn run_parallel( + &self, + configs: Vec, + ) -> Result, SubAgentError> { + let futures: Vec<_> = configs + .into_iter() + .map(|config| { + let mgr = self; // &self borrow, all tasks share the 
same manager + async move { mgr.run_inline(config).await } + }) + .collect(); + + let results = futures_util::future::join_all(futures).await; + Ok(results.into_iter().collect::, _>>()?) + } + + pub async fn run_background( + &self, + config: SubAgentConfig, + ctx: DelegateContext, + ) -> Result { + if self.active_tasks.len() >= self.max_concurrent_background_tasks { + return Err(SubAgentError::TooManyTasks( + self.max_concurrent_background_tasks, + )); + } + + let task_id = generate_task_id(); + let cancel_token = CancellationToken::new(); + + self.active_tasks + .insert(task_id.clone(), cancel_token.clone()); + + // Write DB: pending + if let Some(ref storage) = self.storage { + let allowed_tools_json = config + .allowed_tools + .as_ref() + .and_then(|v| serde_json::to_string(v).ok()); + let record = crate::storage::BackgroundTask { + id: task_id.clone(), + session_id: ctx.session_id.clone(), + channel: ctx.channel.clone(), + chat_id: ctx.chat_id.clone(), + prompt: config.prompt.clone(), + allowed_tools: allowed_tools_json, + status: "pending".to_string(), + result: None, + error: None, + tool_calls_count: 0, + iterations: 0, + started_at: None, + finished_at: None, + created_at: chrono::Utc::now().timestamp_millis(), + }; + storage + .create_background_task(&record) + .await + .map_err(|e| SubAgentError::Storage(e.to_string()))?; + } + + let tools = self.filter_tools(&config.allowed_tools); + let system_prompt = self.build_system_prompt(&config, &tools); + let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS); + let provider_config = self.provider_config.clone(); + let storage = self.storage.clone(); + let notify_tx = self.notify_tx.clone(); + let active_tasks = Arc::clone(&self.active_tasks); + + let tid = task_id.clone(); + let sess_id = ctx.session_id.clone(); + let ch = ctx.channel.clone(); + let cid = ctx.chat_id.clone(); + let prompt = config.prompt.clone(); + + tokio::spawn(async move { + let started_at = 
chrono::Utc::now().timestamp_millis(); + + // Update DB: running + if let Some(ref s) = storage { + let _ = s + .update_background_task_status( + &tid, "running", None, None, + Some(started_at), None, + ) + .await; + } + + let mut provider = create_provider(provider_config.clone()).ok(); + if let Some(ref mut p) = provider { + if let Some(ref s) = storage { + p.set_storage(s.clone()); + } + } + let provider_result: Option> = + provider.map(|p| Arc::from(p)); + + let result = match provider_result { + Some(provider) => { + let agent = AgentLoop::with_provider_and_tools( + provider, + tools, + DEFAULT_MAX_ITERATIONS, + provider_config.model_id.clone(), + provider_config.workspace_dir.clone(), + provider_config.input_types.clone(), + ) + .with_context_window(provider_config.token_limit); + + let history = vec![ + ChatMessage::system(system_prompt), + ChatMessage::user(&prompt), + ]; + + tokio::select! { + r = tokio::time::timeout( + std::time::Duration::from_secs(timeout_secs), + agent.process(history), + ) => { + match r { + Ok(Ok(agent_result)) => SubAgentResult { + task_id: tid.clone(), + content: agent_result.final_response.content, + content_truncated: false, + status: TaskStatus::Completed, + tool_calls_count: 0, + iterations: 0, + duration_ms: 0, + }, + Ok(Err(e)) => SubAgentResult { + task_id: tid.clone(), + content: String::new(), + content_truncated: false, + status: TaskStatus::Failed(e.to_string()), + tool_calls_count: 0, + iterations: 0, + duration_ms: 0, + }, + Err(_) => SubAgentResult { + task_id: tid.clone(), + content: String::new(), + content_truncated: false, + status: TaskStatus::TimedOut, + tool_calls_count: 0, + iterations: 0, + duration_ms: 0, + }, + } + } + _ = cancel_token.cancelled() => SubAgentResult { + task_id: tid.clone(), + content: String::new(), + content_truncated: false, + status: TaskStatus::Cancelled, + tool_calls_count: 0, + iterations: 0, + duration_ms: 0, + }, + } + } + None => SubAgentResult { + task_id: tid.clone(), + content: 
String::new(), + content_truncated: false, + status: TaskStatus::Failed("provider creation failed".into()), + tool_calls_count: 0, + iterations: 0, + duration_ms: 0, + }, + }; + + let finished_at = chrono::Utc::now().timestamp_millis(); + let duration_ms = (finished_at - started_at) as u64; + + let (status_str, error_val) = match &result.status { + TaskStatus::Completed => ("completed".to_string(), None), + TaskStatus::Failed(e) => ("failed".to_string(), Some(e.clone())), + TaskStatus::Cancelled => ("cancelled".to_string(), None), + TaskStatus::TimedOut => ("failed".to_string(), Some("timeout".to_string())), + }; + + if let Some(ref s) = storage { + let _ = s + .update_background_task_status( + &tid, &status_str, + Some(&result.content), error_val.as_deref(), + Some(started_at), Some(finished_at), + ) + .await; + } + + let _ = notify_tx.send(TaskNotification { + task_id: tid.clone(), + session_id: sess_id, + channel: ch, + chat_id: cid, + status: result.status, + result_summary: summarize_for_notification(&result.content, duration_ms), + }); + + active_tasks.remove(&tid); + }); + + Ok(task_id) + } + + pub async fn cancel_task(&self, task_id: &str) -> Result { + if let Some((_, token)) = self.active_tasks.remove(task_id) { + token.cancel(); + if let Some(ref s) = self.storage { + s.update_background_task_status( + task_id, + "cancelled", + None, + None, + None, + Some(chrono::Utc::now().timestamp_millis()), + ) + .await + .map_err(|e| SubAgentError::Storage(e.to_string()))?; + } + Ok(true) + } else if let Some(ref s) = self.storage { + match s.get_background_task(task_id).await { + Ok(task) => { + match task.status.as_str() { + "pending" | "running" => { + tracing::warn!(task_id, "task in DB but not in active_tasks"); + Ok(false) + } + _ => Ok(false), + } + } + Err(_) => Ok(false), + } + } else { + Ok(false) + } + } + + pub async fn check_task( + &self, + task_id: &str, + ) -> Option { + if let Some(ref s) = self.storage { + s.get_background_task(task_id).await.ok() 
+ } else { + None + } + } + + pub async fn list_tasks( + &self, + session_id: &str, + ) -> Vec { + if let Some(ref s) = self.storage { + s.list_background_tasks(session_id).await.unwrap_or_default() + } else { + vec![] + } + } + + pub async fn cancel_by_session(&self, session_id: &str) { + // Cancel all running tasks for a session by checking DB + if let Some(ref s) = self.storage { + if let Ok(tasks) = s.list_background_tasks(session_id).await { + for task in &tasks { + if task.status == "pending" || task.status == "running" { + let _ = self.cancel_task(&task.id).await; + } + } + } + } + } + + pub fn active_task_count(&self) -> usize { + self.active_tasks.len() + } +} + +fn generate_task_id() -> String { + Uuid::new_v4().to_string()[..8].to_string() +} + +fn format_duration(seconds: u64) -> String { + if seconds < 60 { + format!("{}s", seconds) + } else if seconds < 3600 { + format!("{}m", seconds / 60) + } else { + format!("{}h", seconds / 3600) + } +} + +fn truncate_sub_agent_result(content: &str) -> (String, bool) { + if content.len() <= MAX_INLINE_RESULT_CHARS { + (content.to_string(), false) + } else { + let truncate_at = content.floor_char_boundary(MAX_INLINE_RESULT_CHARS); + ( + format!( + "{}\n\n[... 结果已截断,共 {} 字符,完整结果请使用 check_task 查看 ...]", + &content[..truncate_at], + content.len() + ), + true, + ) + } +} + +fn summarize_for_notification(content: &str, _duration_ms: u64) -> String { + const MAX_SUMMARY_BYTES: usize = 500; + if content.len() <= MAX_SUMMARY_BYTES { + content.to_string() + } else { + let truncate_at = content.floor_char_boundary(MAX_SUMMARY_BYTES); + format!("{}...", &content[..truncate_at]) + } +} diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs index dcdb13b..bdcf90f 100644 --- a/src/agent/system_prompt.rs +++ b/src/agent/system_prompt.rs @@ -3,11 +3,7 @@ //! This module provides a modular framework for building system prompts //! using the SystemPromptBuilder pattern. //! -//! 
Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic -//! -//! Configuration files loaded from ~/.picobot/: -//! - AGENTS.md — agent identity and behavior -//! - USER.md — user preferences and profile +//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic → Delegation use crate::tools::ToolRegistry; use std::path::Path; @@ -55,6 +51,7 @@ impl SystemPromptBuilder { Box::new(CrossChannelSection), Box::new(MemorySection), Box::new(HistorySection), + Box::new(DelegationSection), ], } } @@ -353,6 +350,45 @@ impl PromptSection for HistorySection { } } +/// Sub-agent delegation principles. +pub struct DelegationSection; + +impl PromptSection for DelegationSection { + fn name(&self) -> &str { + "delegation" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> String { + "## 子 Agent 委托原则\n\n\ +当任务复杂需要拆解时,使用 delegate 工具创建子 Agent:\n\ +\n\ +### 何时委托\n\ +- 多个独立子任务可以并行处理时(使用 mode=\"parallel\")\n\ +- 长时间运行的任务需要后台执行时(使用 mode=\"background\")\n\ +- 需要以不同权限(受限工具集)执行时\n\ +\n\ +### 工具分配原则\n\ +- **最小权限**:只给子 Agent 完成其任务所需的最少工具\n\ +- **只读优先**:如果可以只用 file_read、file_search、web_fetch 完成,不要给写权限(bash、file_write、file_edit)\n\ +- **禁止递归**:永远不要把 delegate 工具分配给子 Agent\n\ +- **明确边界**:每个子 Agent 只负责一个清晰、独立的子任务\n\ +\n\ +### 任务描述\n\ +- 任务 prompt 要清晰、具体、有明确输出要求\n\ +- 如需额外约束,直接写在 prompt 中(例如:\"跳过 .tmp 文件\")\n\ +- 明确说明期望的输出格式\n\ +\n\ +### 并行模式\n\ +- 多个无依赖的子任务使用 mode=\"parallel\",任务定义在 tasks 数组中\n\ +- 并行任务之间不应有数据依赖\n\ +- 并行任务数建议不超过 5 个\n\ +\n\ +### 后台模式\n\ +- 预计执行时间超过 30s 的任务使用 mode=\"background\"\n\ +- 后台任务有全局并发上限,如果失败提示用户稍后重试".to_string() + } +} + // === Helper Functions === /// Get user config directory (~/.picobot/). 
// Built-in gateway settings, used when a `GatewayConfig` is created through
// `Default` rather than read from a configuration file.
impl Default for GatewayConfig {
    fn default() -> Self {
        Self {
            host: default_gateway_host(),
            port: default_gateway_port(),
            session_ttl_hours: None,
            cleanup_interval_minutes: None,
            session_db_path: None,
            // NOTE(review): the field is declared with a bare
            // `#[serde(default, ...)]`, which fills a missing key with
            // `usize::default()` (0), not with this 10. With a limit of 0,
            // the `len() >= max` check in `SubAgentManager::run_background`
            // is true from the start, so every background task is rejected.
            // Confirm whether the attribute should name a function
            // returning 10 instead.
            max_concurrent_background_tasks: 10,
            scheduler: None,
        }
    }
}
max_concurrent_background_tasks: usize, ) -> Result { let mut skills_loader = SkillsLoader::new(); skills_loader.load_skills(); @@ -856,9 +858,61 @@ impl SessionManager { let tools = Arc::new(create_default_tools( skills_loader.clone(), memory_manager.clone(), + None, // SubAgentManager created below browser_config.as_ref(), )); + // Create SubAgentManager and register DelegateTool + let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel(); + let sub_agent_manager = Arc::new(crate::agent::SubAgentManager::new( + provider_config.clone(), + tools.clone(), + Some(storage.clone()), + notify_tx, + max_concurrent_background_tasks, + )); + tools.register(crate::tools::DelegateTool::new(sub_agent_manager.clone())); + + // Start background task notification consumer + let sm_bus = bus.clone(); + tokio::spawn(async move { + while let Some(notif) = notify_rx.recv().await { + let content = format_task_notification( + ¬if.task_id, + ¬if.status, + ¬if.result_summary, + ); + let outbound = OutboundMessage { + channel: notif.channel, + chat_id: notif.chat_id, + content, + reply_to: None, + media: vec![], + metadata: std::collections::HashMap::new(), + }; + let _ = sm_bus.publish_outbound(outbound).await; + } + }); + + // Start periodic background task cleanup (every hour, TTL 24h) + let cleanup_storage = storage.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(3600)); + interval.tick().await; // skip immediate first tick + loop { + interval.tick().await; + match cleanup_storage.cleanup_old_tasks(86_400_000).await { + Ok(count) if count > 0 => { + tracing::info!(count, "Cleaned up old background tasks"); + } + Err(e) => { + tracing::warn!(error = %e, "Failed to clean up old background tasks"); + } + _ => {} + } + } + }); + Ok(Self { inner: Arc::new(Mutex::new(SessionManagerInner { sessions: HashMap::new(), @@ -870,6 +924,7 @@ impl SessionManager { storage, bus, memory_manager, + sub_agent_manager, }) } @@ 
-1073,6 +1128,8 @@ impl SessionManager { msgs.push("消息队列已清空。".to_string()); } guard.worker_generation = guard.worker_generation.wrapping_add(1); + // Cancel all running background sub-agent tasks for this session + self.sub_agent_manager.cancel_by_session(&sid.to_string()).await; let resp = if msgs.is_empty() { "没有正在执行的任务或队列。".to_string() } else { @@ -1469,7 +1526,8 @@ fn spawn_agent_worker( unified_str: String, ) { tokio::spawn(async move { - let _scope = CURRENT_SOURCE_SESSION.scope(Some(unified_str), async { + let unified_for_source = unified_str.clone(); + let _scope = CURRENT_SOURCE_SESSION.scope(Some(unified_for_source), async { while let Some(task) = task_rx.recv().await { let task_chan = task.channel.clone(); let task_cid = task.chat_id.clone(); @@ -1613,8 +1671,17 @@ fn spawn_agent_worker( let bus2 = bus.clone(); let chan2 = task_chan.clone(); let cid2 = task_cid.clone(); + let unified_str2 = unified_str.clone(); let process_future = async move { - let result = match agent.process(history_out.clone()).await { + let process_result = crate::agent::sub_agent::DELEGATE_CONTEXT.scope( + crate::agent::DelegateContext { + session_id: unified_str2, + channel: chan2.clone(), + chat_id: cid2.clone(), + }, + agent.process(history_out.clone()), + ).await; + let result = match process_result { Ok(r) => r, Err(AgentError::LlmError(ref msg)) if is_context_overflow_error(msg) => @@ -1911,3 +1978,24 @@ mod tests { } } } + +fn format_task_notification(task_id: &str, status: &crate::agent::TaskStatus, summary: &str) -> String { + match status { + crate::agent::TaskStatus::Completed => format!( + "📋 后台任务完成\n\n任务 ID: {}\n\n结果:\n{}", + task_id, summary + ), + crate::agent::TaskStatus::Failed(err) => format!( + "📋 后台任务失败\n\n任务 ID: {}\n错误: {}", + task_id, err + ), + crate::agent::TaskStatus::Cancelled => format!( + "📋 后台任务已取消\n\n任务 ID: {}", + task_id + ), + crate::agent::TaskStatus::TimedOut => format!( + "📋 后台任务超时\n\n任务 ID: {}", + task_id + ), + } +} diff --git 
a/src/storage/background_task.rs b/src/storage/background_task.rs new file mode 100644 index 0000000..a01d1ed --- /dev/null +++ b/src/storage/background_task.rs @@ -0,0 +1,19 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BackgroundTask { + pub id: String, + pub session_id: String, + pub channel: String, + pub chat_id: String, + pub prompt: String, + pub allowed_tools: Option, + pub status: String, + pub result: Option, + pub error: Option, + pub tool_calls_count: i64, + pub iterations: i64, + pub started_at: Option, + pub finished_at: Option, + pub created_at: i64, +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs index b23bd35..1a08a71 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,10 +1,12 @@ pub mod error; pub mod memory; pub mod message; +pub mod background_task; pub mod scheduler; pub mod session; pub use error::StorageError; +pub use background_task::BackgroundTask; pub use scheduler::{JobRun, ScheduledJob}; use sqlx::{Pool, Row, Sqlite, SqlitePool}; @@ -105,6 +107,48 @@ impl Storage { .await .ok(); + // Background tasks table — for async sub-agent tasks. + // Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL). + // Session and task association is maintained at the application level. 
+ sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS background_tasks ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + channel TEXT NOT NULL, + chat_id TEXT NOT NULL, + prompt TEXT NOT NULL, + allowed_tools TEXT, + status TEXT NOT NULL DEFAULT 'pending', + result TEXT, + error TEXT, + tool_calls_count INTEGER DEFAULT 0, + iterations INTEGER DEFAULT 0, + started_at INTEGER, + finished_at INTEGER, + created_at INTEGER NOT NULL + ) + "#, + ) + .execute(&self.pool) + .await?; + + sqlx::query( + r#" + CREATE INDEX IF NOT EXISTS idx_bg_tasks_session ON background_tasks(session_id) + "#, + ) + .execute(&self.pool) + .await?; + + sqlx::query( + r#" + CREATE INDEX IF NOT EXISTS idx_bg_tasks_status ON background_tasks(status) + "#, + ) + .execute(&self.pool) + .await?; + sqlx::query( r#" CREATE TABLE IF NOT EXISTS memories ( @@ -816,6 +860,148 @@ impl Storage { } unreachable!() } + + // ── Background Task CRUD ── + + pub async fn create_background_task( + &self, + task: &crate::storage::background_task::BackgroundTask, + ) -> Result<(), StorageError> { + sqlx::query( + r#" + INSERT INTO background_tasks (id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, tool_calls_count, iterations, started_at, finished_at, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ "#, + ) + .bind(&task.id) + .bind(&task.session_id) + .bind(&task.channel) + .bind(&task.chat_id) + .bind(&task.prompt) + .bind(&task.allowed_tools) + .bind(&task.status) + .bind(&task.result) + .bind(&task.error) + .bind(task.tool_calls_count) + .bind(task.iterations) + .bind(task.started_at) + .bind(task.finished_at) + .bind(task.created_at) + .execute(self.pool()) + .await?; + Ok(()) + } + + pub async fn update_background_task_status( + &self, + id: &str, + status: &str, + result: Option<&str>, + error: Option<&str>, + started_at: Option, + finished_at: Option, + ) -> Result<(), StorageError> { + sqlx::query( + r#" + UPDATE background_tasks + SET status = ?, result = COALESCE(?, result), error = COALESCE(?, error), + started_at = COALESCE(?, started_at), finished_at = COALESCE(?, finished_at) + WHERE id = ? + "#, + ) + .bind(status) + .bind(result) + .bind(error) + .bind(started_at) + .bind(finished_at) + .bind(id) + .execute(self.pool()) + .await?; + Ok(()) + } + + pub async fn get_background_task( + &self, + id: &str, + ) -> Result { + let row = sqlx::query( + r#" + SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, + tool_calls_count, iterations, started_at, finished_at, created_at + FROM background_tasks WHERE id = ? + "#, + ) + .bind(id) + .fetch_optional(self.pool()) + .await? 
+ .ok_or_else(|| StorageError::NotFound(id.to_string()))?; + + Ok(crate::storage::background_task::BackgroundTask { + id: row.get("id"), + session_id: row.get("session_id"), + channel: row.get("channel"), + chat_id: row.get("chat_id"), + prompt: row.get("prompt"), + allowed_tools: row.get("allowed_tools"), + status: row.get("status"), + result: row.get("result"), + error: row.get("error"), + tool_calls_count: row.get("tool_calls_count"), + iterations: row.get("iterations"), + started_at: row.get("started_at"), + finished_at: row.get("finished_at"), + created_at: row.get("created_at"), + }) + } + + pub async fn list_background_tasks( + &self, + session_id: &str, + ) -> Result, StorageError> { + let rows = sqlx::query( + r#" + SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, + tool_calls_count, iterations, started_at, finished_at, created_at + FROM background_tasks + WHERE session_id = ? + ORDER BY created_at DESC + "#, + ) + .bind(session_id) + .fetch_all(self.pool()) + .await?; + + Ok(rows + .into_iter() + .map(|row| crate::storage::background_task::BackgroundTask { + id: row.get("id"), + session_id: row.get("session_id"), + channel: row.get("channel"), + chat_id: row.get("chat_id"), + prompt: row.get("prompt"), + allowed_tools: row.get("allowed_tools"), + status: row.get("status"), + result: row.get("result"), + error: row.get("error"), + tool_calls_count: row.get("tool_calls_count"), + iterations: row.get("iterations"), + started_at: row.get("started_at"), + finished_at: row.get("finished_at"), + created_at: row.get("created_at"), + }) + .collect()) + } + + pub async fn cleanup_old_tasks(&self, ttl_ms: i64) -> Result { + let cutoff = chrono::Utc::now().timestamp_millis() - ttl_ms; + let result = sqlx::query( + "DELETE FROM background_tasks WHERE status IN ('completed', 'failed', 'cancelled') AND finished_at IS NOT NULL AND finished_at < ?", + ) + .bind(cutoff) + .execute(self.pool()) + .await?; + Ok(result.rows_affected() as 
usize) + } } #[cfg(test)] diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs new file mode 100644 index 0000000..2300cce --- /dev/null +++ b/src/tools/delegate.rs @@ -0,0 +1,365 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use serde_json::json; + +use crate::agent::{ExecutionMode, SubAgentConfig, SubAgentManager, TaskStatus}; +use crate::tools::traits::{Tool, ToolResult}; + +pub struct DelegateTool { + sub_agent_manager: Arc, +} + +impl DelegateTool { + pub fn new(sub_agent_manager: Arc) -> Self { + Self { sub_agent_manager } + } +} + +#[async_trait] +impl Tool for DelegateTool { + fn name(&self) -> &str { + "delegate" + } + + fn description(&self) -> &str { + "子任务委托工具。创建子 Agent 处理独立任务,支持三种模式:\ + inline (阻塞返回结果)、background (异步执行,完成后通知)、\ + parallel (多个子 Agent 并发执行,聚合结果)。\ + 也可用于查询、取消和列出后台任务。" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["run", "check_task", "cancel_task", "list_tasks"], + "description": "操作类型: run 创建子Agent执行任务, check_task 查询后台任务, cancel_task 取消后台任务, list_tasks 列出后台任务" + }, + "prompt": { + "type": "string", + "description": "子任务描述(可内含额外约束,如:跳过 .tmp 文件)。action=run 时必填" + }, + "mode": { + "type": "string", + "enum": ["inline", "background", "parallel"], + "description": "执行模式: inline=阻塞返回结果, background=异步执行+通知, parallel=多子Agent并发。默认 inline" + }, + "allowed_tools": { + "type": "array", + "items": { "type": "string" }, + "description": "允许子Agent使用的工具列表。不填使用默认只读集: file_read,file_search,content_search,web_fetch,http_request,calculator" + }, + "max_iterations": { + "type": "integer", + "description": "最大迭代次数,默认 99" + }, + "timeout_secs": { + "type": "integer", + "description": "超时秒数,默认 3600(1小时)" + }, + "tasks": { + "type": "array", + "description": "并行模式下的多个子任务(仅 mode=parallel 时使用)", + "items": { + "type": "object", + "properties": { + "prompt": { "type": "string", "description": "子任务描述" }, + "allowed_tools": { + "type": "array", 
+ "items": { "type": "string" }, + "description": "该子任务的工具列表" + } + }, + "required": ["prompt"] + } + }, + "task_id": { + "type": "string", + "description": "后台任务ID(action=check_task/cancel_task 时必填)" + } + }, + "required": ["action"] + }) + } + + fn read_only(&self) -> bool { + false + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = args["action"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?; + + match action { + "run" => self.handle_run(&args).await, + "check_task" => self.handle_check_task(&args).await, + "cancel_task" => self.handle_cancel_task(&args).await, + "list_tasks" => self.handle_list_tasks(&args).await, + _ => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", action)), + }), + } + } +} + +impl DelegateTool { + fn parse_config_from_args(&self, args: &serde_json::Value) -> anyhow::Result { + let prompt = args["prompt"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))? + .to_string(); + + let allowed_tools: Option> = args["allowed_tools"] + .as_array() + .map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); + + let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize); + let timeout_secs = args["timeout_secs"].as_u64(); + + Ok(SubAgentConfig { + prompt, + mode: ExecutionMode::Inline, + allowed_tools, + max_iterations, + timeout_secs, + }) + } + + async fn handle_run(&self, args: &serde_json::Value) -> anyhow::Result { + let mode_str = args["mode"].as_str().unwrap_or("inline"); + let mode = match mode_str { + "inline" => ExecutionMode::Inline, + "background" => ExecutionMode::Background, + "parallel" => ExecutionMode::Parallel, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("unknown mode: {}. 
Supported: inline, background, parallel", mode_str)), + }) + } + }; + + match mode { + ExecutionMode::Inline => { + let config = self.parse_config_from_args(args)?; + let result = self.sub_agent_manager.run_inline(config).await + .map_err(|e| anyhow::anyhow!("{}", e))?; + + match result.status { + TaskStatus::Completed => Ok(ToolResult { + success: true, + output: result.content, + error: None, + }), + TaskStatus::Failed(err) => Ok(ToolResult { + success: false, + output: result.content, + error: Some(err), + }), + TaskStatus::TimedOut => Ok(ToolResult { + success: false, + output: result.content, + error: Some("sub-agent timed out".into()), + }), + TaskStatus::Cancelled => Ok(ToolResult { + success: false, + output: result.content, + error: Some("sub-agent cancelled".into()), + }), + } + } + ExecutionMode::Background => { + let config = self.parse_config_from_args(args)?; + let ctx = crate::agent::sub_agent::get_delegate_context() + .map_err(|_| anyhow::anyhow!("delegate context not available: not in an agent worker"))?; + + let task_id = self.sub_agent_manager.run_background(config, ctx).await + .map_err(|e| anyhow::anyhow!("{}", e))?; + + Ok(ToolResult { + success: true, + output: format!("后台任务已启动。\ntask_id: {}", task_id), + error: None, + }) + } + ExecutionMode::Parallel => { + let tasks = args["tasks"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("parallel mode requires 'tasks' array"))?; + + let mut configs = Vec::new(); + for task in tasks { + let prompt = task["prompt"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))? 
+ .to_string(); + let allowed_tools: Option> = task["allowed_tools"] + .as_array() + .map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); + + configs.push(SubAgentConfig { + prompt, + mode: ExecutionMode::Inline, + allowed_tools, + max_iterations: args["max_iterations"].as_u64().map(|v| v as usize), + timeout_secs: args["timeout_secs"].as_u64(), + }); + } + + let has_args_allowed = args["allowed_tools"].as_array().is_some(); + for c in &mut configs { + if c.allowed_tools.is_none() && has_args_allowed { + c.allowed_tools = args["allowed_tools"] + .as_array() + .map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); + } + } + + let results = self.sub_agent_manager.run_parallel(configs).await + .map_err(|e| anyhow::anyhow!("{}", e))?; + + let mut output = String::new(); + for (i, r) in results.iter().enumerate() { + let status_icon = match r.status { + TaskStatus::Completed => "✅", + TaskStatus::Failed(_) => "❌", + TaskStatus::TimedOut => "⏱️ 超时", + TaskStatus::Cancelled => "🚫 已取消", + }; + output.push_str(&format!("[task_{}] {}\n", i + 1, status_icon)); + if !r.content.is_empty() { + output.push_str(&r.content); + output.push_str("\n\n"); + } + if let TaskStatus::Failed(ref err) = r.status { + output.push_str(&format!("错误: {}\n\n", err)); + } + } + + let all_success = results.iter().all(|r| matches!(r.status, TaskStatus::Completed)); + Ok(ToolResult { + success: all_success, + output: output.trim().to_string(), + error: None, + }) + } + } + } + + async fn handle_check_task(&self, args: &serde_json::Value) -> anyhow::Result { + let task_id = args["task_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?; + + match self.sub_agent_manager.check_task(task_id).await { + Some(task) => { + let status_icon = match task.status.as_str() { + "completed" => "✅ 已完成", + "failed" => "❌ 失败", + "cancelled" => "🚫 已取消", + "running" => "🔄 运行中", + "pending" => "⏳ 等待中", + _ => 
task.status.as_str(), + }; + let mut output = format!( + "任务 ID: {}\n状态: {}\n任务: {}", + task.id, status_icon, task.prompt + ); + if let Some(ref result) = task.result { + output.push_str(&format!("\n\n结果:\n{}", result)); + } + if let Some(ref error) = task.error { + output.push_str(&format!("\n错误: {}", error)); + } + if let Some(started) = task.started_at { + if let Some(finished) = task.finished_at { + let duration = (finished - started) as f64 / 1000.0; + output.push_str(&format!("\n耗时: {:.1}s", duration)); + } + } + Ok(ToolResult { + success: true, + output, + error: None, + }) + } + None => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("task not found: {}", task_id)), + }), + } + } + + async fn handle_cancel_task(&self, args: &serde_json::Value) -> anyhow::Result { + let task_id = args["task_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?; + + match self.sub_agent_manager.cancel_task(task_id).await { + Ok(true) => Ok(ToolResult { + success: true, + output: format!("后台任务 {} 已取消", task_id), + error: None, + }), + Ok(false) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("无法取消任务 {}(可能已完成或不存在)", task_id)), + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("取消失败: {}", e)), + }), + } + } + + async fn handle_list_tasks(&self, _args: &serde_json::Value) -> anyhow::Result { + let ctx = crate::agent::sub_agent::get_delegate_context() + .map_err(|_| anyhow::anyhow!("delegate context not available"))?; + let tasks = self.sub_agent_manager.list_tasks(&ctx.session_id).await; + + if tasks.is_empty() { + return Ok(ToolResult { + success: true, + output: "没有后台任务".to_string(), + error: None, + }); + } + + let mut output = String::from("后台任务列表:\n\n"); + for task in &tasks { + let status_icon = match task.status.as_str() { + "completed" => "✅", + "failed" => "❌", + "cancelled" => "🚫", + "running" => "🔄", + "pending" => 
"⏳", + _ => "❓", + }; + output.push_str(&format!( + "{} {} - {} - {} (created: {})\n", + status_icon, + &task.id[..std::cmp::min(8, task.id.len())], + task.prompt.chars().take(60).collect::(), + task.status, + task.created_at, + )); + } + output.push_str(&format!("\n共 {} 个任务", tasks.len())); + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index b1608e9..b94fec7 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -4,6 +4,7 @@ pub mod calculator; pub mod chat_manager; pub mod content_search; pub mod cron; +pub mod delegate; pub mod file_edit; pub mod file_read; pub mod file_search; @@ -23,6 +24,7 @@ pub use browser::BrowserTool; pub use calculator::CalculatorTool; pub use chat_manager::ChatManagerTool; pub use content_search::ContentSearchTool; +pub use delegate::DelegateTool; pub use file_edit::FileEditTool; pub use file_read::FileReadTool; pub use file_search::FileSearchTool; @@ -36,6 +38,7 @@ pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use web_fetch::WebFetchTool; use std::sync::Arc; +use crate::agent::SubAgentManager; use crate::config::BrowserConfig; use crate::memory::MemoryManager; use crate::skills::SkillsLoader; @@ -46,6 +49,7 @@ use crate::skills::SkillsLoader; pub fn create_default_tools( skills_loader: Arc, memory: Arc, + sub_agent_manager: Option>, browser_config: Option<&BrowserConfig>, ) -> ToolRegistry { let registry = ToolRegistry::new(); @@ -76,5 +80,9 @@ pub fn create_default_tools( } } + if let Some(mgr) = sub_agent_manager { + registry.register(DelegateTool::new(mgr)); + } + registry } diff --git a/src/tools/registry.rs b/src/tools/registry.rs index 3b2cc53..84df8b7 100644 --- a/src/tools/registry.rs +++ b/src/tools/registry.rs @@ -20,6 +20,11 @@ impl ToolRegistry { self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool)); } + /// Register an existing Arc-wrapped tool by name + pub fn register_raw(&self, name: String, tool: Arc) { + 
self.tools.lock().unwrap().insert(name, tool); + } + pub fn get(&self, name: &str) -> Option> { self.tools.lock().unwrap().get(name).cloned() } @@ -62,6 +67,17 @@ impl ToolRegistry { .map(|(k, v)| (k.clone(), v.clone())) .collect() } + + /// 生成工具列表描述,用于子 Agent 系统提示词 + pub fn describe_for_prompt(&self) -> String { + let mut entries: Vec = self + .iter() + .into_iter() + .map(|(name, tool)| format!("- {}: {}", name, tool.description())) + .collect(); + entries.sort(); + entries.join("\n") + } } impl Default for ToolRegistry {