From bee1a39a063fee846fead0b8578dca167fc4e950 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Sat, 16 May 2026 16:12:28 +0800 Subject: [PATCH] feat: add task management tool with subagent support - Introduced `TaskConfig` struct to manage task-related configurations. - Implemented `TaskTool` for creating and managing subagents for complex tasks. - Added `TaskSession` and `TaskRepository` for handling task sessions and persistence. - Created `DefaultSubAgentRuntime` to execute tasks with timeout and history support. - Enhanced `ToolContext` to include `subagent_description` for better context tracking. - Implemented error handling for task execution and session management. - Updated `ToolRegistryFactory` to register task tools conditionally based on configuration. - Added prompt builders for subagent tasks to improve interaction clarity. --- src/config/mod.rs | 54 +++- src/gateway/agent_factory.rs | 1 + src/gateway/mod.rs | 1 + src/gateway/runtime.rs | 64 +++-- src/gateway/session.rs | 2 + src/gateway/tool_registry_factory.rs | 96 ++++++- src/tools/mod.rs | 5 + src/tools/task/error.rs | 48 ++++ src/tools/task/mod.rs | 13 + src/tools/task/prompt.rs | 80 ++++++ src/tools/task/repository.rs | 98 +++++++ src/tools/task/runtime.rs | 378 +++++++++++++++++++++++++++ src/tools/task/tool.rs | 153 +++++++++++ src/tools/task/types.rs | 190 ++++++++++++++ src/tools/traits.rs | 2 + 15 files changed, 1162 insertions(+), 23 deletions(-) create mode 100644 src/tools/task/error.rs create mode 100644 src/tools/task/mod.rs create mode 100644 src/tools/task/prompt.rs create mode 100644 src/tools/task/repository.rs create mode 100644 src/tools/task/runtime.rs create mode 100644 src/tools/task/tool.rs create mode 100644 src/tools/task/types.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 0bff93c..8f3e78b 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -100,15 +100,63 @@ impl Default for SkillsConfig { } } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct ToolsConfig { #[serde(default)] pub disabled: Vec, + #[serde(default)] + pub task: TaskConfig, } -impl Default for ToolsConfig { +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TaskConfig { + #[serde(default = "default_task_enabled")] + pub enabled: bool, + #[serde(default = "default_task_max_execution_secs")] + pub max_execution_secs: u64, + #[serde(default = "default_task_ttl_hours")] + pub ttl_hours: u64, + #[serde(default = "default_task_allowed_tools")] + pub allowed_tools: Vec, +} + +fn default_task_enabled() -> bool { + true +} + +fn default_task_max_execution_secs() -> u64 { + 300 +} + +fn default_task_ttl_hours() -> u64 { + 24 +} + +fn default_task_allowed_tools() -> Vec { + vec![ + "file_read".to_string(), + "file_edit".to_string(), + "file_write".to_string(), + "bash".to_string(), + "http_request".to_string(), + "web_fetch".to_string(), + "memory_search".to_string(), + "get_time".to_string(), + "calculator".to_string(), + "skill_activate".to_string(), + "skill_list".to_string(), + "send_session_message".to_string(), + ] +} + +impl Default for TaskConfig { fn default() -> Self { - Self { disabled: Vec::new() } + Self { + enabled: default_task_enabled(), + max_execution_secs: default_task_max_execution_secs(), + ttl_hours: default_task_ttl_hours(), + allowed_tools: default_task_allowed_tools(), + } } } diff --git a/src/gateway/agent_factory.rs b/src/gateway/agent_factory.rs index 167ec76..4f2e6aa 100644 --- a/src/gateway/agent_factory.rs +++ b/src/gateway/agent_factory.rs @@ -71,6 +71,7 @@ impl AgentFactory { session_id: Some(session_id), message_id: request.message_id.map(str::to_string), message_seq: None, + subagent_description: None, }) }) } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 4dcb9c1..71fcfc5 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -81,6 +81,7 @@ impl GatewayState { skills, Arc::new(BusSessionMessageSender::new(bus.clone())), std::collections::HashSet::new(), + config.tools.task.clone(), chat_history_ttl_hours, session_ttl_hours, )?; diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index 6664af1..6e7e151 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -2,14 +2,17 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::agent::AgentError; -use crate::config::LLMProviderConfig; +use crate::config::{LLMProviderConfig, TaskConfig}; use crate::gateway::tool_registry_factory::ToolRegistryFactory; use crate::skills::SkillRuntime; use crate::storage::{ ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository, SessionStore, SkillEventRepository, }; -use crate::tools::{NoopSessionMessageSender, SessionMessageSender, ToolRegistry}; +use crate::tools::{ + DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender, + SessionMessageSender, SubAgentRuntime, SubAgentRuntimeConfig, ToolRegistry, +}; use super::agent_factory::AgentFactory; use super::cli_session::CliSessionService; @@ -29,6 +32,7 @@ pub(crate) fn build_session_manager( provider_configs: HashMap, skills: Arc, disabled_tools: HashSet, + task_config: TaskConfig, chat_history_ttl_hours: Option, session_ttl_hours: Option, ) -> Result { @@ -41,6 +45,7 @@ pub(crate) fn build_session_manager( skills, Arc::new(NoopSessionMessageSender), disabled_tools, + task_config, chat_history_ttl_hours, session_ttl_hours, ) @@ -55,6 +60,7 @@ pub(crate) fn build_session_manager_with_sender( skills: Arc, session_message_sender: Arc, disabled_tools: HashSet, + task_config: TaskConfig, chat_history_ttl_hours: Option, session_ttl_hours: Option, ) -> Result { @@ -74,20 +80,49 @@ pub(crate) fn build_session_manager_with_sender( let memories: Arc = store.clone(); let scheduler_jobs: Arc = store.clone(); let skill_events: Arc = store.clone(); - let tools = Arc::new( - ToolRegistryFactory::new( - skills.clone(), - memories, - scheduler_jobs, - skill_events.clone(), - session_message_sender, - known_agents, - default_timezone, - disabled_tools, - ) - .build(), + let conversations: Arc = store.clone(); + + // 创建 ToolRegistryFactory + let factory = ToolRegistryFactory::new( + skills.clone(), + memories, + scheduler_jobs, + skill_events.clone(), + session_message_sender.clone(), + conversations.clone(), + known_agents, + default_timezone, + disabled_tools, + task_config.clone(), ); + // 创建 SubAgentRuntime(如果 task 工具启用) + let factory = if task_config.enabled { + let task_repository = Arc::new(InMemoryTaskRepository::new()); + let subagent_tools = Arc::new(factory.build_subagent_tools()); + + let runtime_config = SubAgentRuntimeConfig { + allowed_tools: task_config.allowed_tools.iter().cloned().collect(), + max_execution_secs: task_config.max_execution_secs, + explore_max_tool_calls: 20, + ttl_hours: task_config.ttl_hours, + }; + + let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new( + runtime_config, + task_repository, + conversations.clone(), + subagent_tools, + provider_config.clone(), + )); + + factory.with_subagent_runtime(subagent_runtime) + } else { + factory + }; + + let tools = Arc::new(factory.build()); + let prompt_repository: Arc = store.clone(); let agent_factory = AgentFactory::new( tools.clone(), @@ -95,7 +130,6 @@ pub(crate) fn build_session_manager_with_sender( agent_prompt_reinject_every as usize, prompt_repository.clone(), ); - let conversations: Arc = store.clone(); let session_factory = SessionFactory::new( provider_config.clone(), skills.clone(), diff --git a/src/gateway/session.rs b/src/gateway/session.rs index e454608..d45bdf0 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -485,6 +485,7 @@ impl SessionManager { provider_configs: HashMap, skills: Arc, disabled_tools: std::collections::HashSet, + task_config: crate::config::TaskConfig, chat_history_ttl_hours: Option, session_ttl_hours: Option, ) -> Result { @@ -496,6 +497,7 @@ impl SessionManager { provider_configs, skills, disabled_tools, + task_config, chat_history_ttl_hours, session_ttl_hours, ) diff --git a/src/gateway/tool_registry_factory.rs b/src/gateway/tool_registry_factory.rs index 8bbb047..0f984ce 100644 --- a/src/gateway/tool_registry_factory.rs +++ b/src/gateway/tool_registry_factory.rs @@ -1,13 +1,15 @@ use std::collections::HashSet; use std::sync::Arc; +use crate::config::TaskConfig; use crate::skills::SkillRuntime; -use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; +use crate::storage::{ConversationRepository, MemoryRepository, SchedulerJobRepository, SkillEventRepository}; use crate::tools::{ - BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, - MemoryManageTool, MemorySearchTool, SchedulerManageTool, SessionMessageSender, - SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TimeTool, ToolRegistry, - WebFetchTool, + BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, + HttpRequestTool, MemoryManageTool, MemorySearchTool, + SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, SkillListTool, + SkillManageTool, SubAgentRuntime, TaskTool, TimeTool, + ToolRegistry, WebFetchTool, }; pub(crate) struct ToolRegistryFactory { @@ -16,9 +18,12 @@ pub(crate) struct ToolRegistryFactory { scheduler_jobs: Arc, skill_events: Arc, session_message_sender: Arc, + conversations: Arc, known_agents: HashSet, default_timezone: String, disabled_tools: HashSet, + task_config: TaskConfig, + subagent_runtime: Option>, } impl ToolRegistryFactory { @@ -28,9 +33,11 @@ impl ToolRegistryFactory { scheduler_jobs: Arc, skill_events: Arc, session_message_sender: Arc, + conversations: Arc, known_agents: HashSet, default_timezone: String, disabled_tools: HashSet, + task_config: TaskConfig, ) -> Self { Self { skills, @@ -38,12 +45,23 @@ impl ToolRegistryFactory { scheduler_jobs, skill_events, session_message_sender, + conversations, known_agents, default_timezone, disabled_tools, + task_config, + subagent_runtime: None, } } + pub(crate) fn with_subagent_runtime( + mut self, + runtime: Arc, + ) -> Self { + self.subagent_runtime = Some(runtime); + self + } + fn is_enabled(&self, tool_name: &str) -> bool { !self.disabled_tools.contains(tool_name) } @@ -108,6 +126,74 @@ impl ToolRegistryFactory { registry.register(WebFetchTool::new(50_000, 30)); } + // 注册 Task 工具(如果启用且有 subagent_runtime) + if self.is_enabled("task") && self.task_config.enabled { + if let Some(runtime) = &self.subagent_runtime { + registry.register(TaskTool::new(runtime.clone())); + } + } + + registry + } + + /// 构建子代理专用工具集(不包含 task 工具防止递归) + pub(crate) fn build_subagent_tools(&self) -> ToolRegistry { + let mut registry = ToolRegistry::new(); + + // 基础工具 + if self.is_enabled("calculator") { + registry.register(CalculatorTool::new()); + } + if self.is_enabled("get_time") { + registry.register(TimeTool::new(self.default_timezone.clone())); + } + if self.is_enabled("file_read") { + registry.register(FileReadTool::new()); + } + if self.is_enabled("file_write") { + registry.register(FileWriteTool::new()); + } + if self.is_enabled("file_edit") { + registry.register(FileEditTool::new()); + } + if self.is_enabled("bash") { + registry.register(BashTool::new()); + } + if self.is_enabled("http_request") { + registry.register(HttpRequestTool::new( + vec!["*".to_string()], + 1_000_000, + 30, + false, + )); + } + if self.is_enabled("web_fetch") { + registry.register(WebFetchTool::new(50_000, 30)); + } + + // 记忆工具(只读) + if self.is_enabled("memory_search") { + registry.register(MemorySearchTool::new(self.memories.clone())); + } + + // Skill 工具 + if self.is_enabled("skill_activate") { + registry.register(SkillActivateTool::new( + self.skills.clone(), + self.skill_events.clone(), + )); + } + if self.is_enabled("skill_list") { + registry.register(SkillListTool::new(self.skills.clone())); + } + + // 进度通知工具 + if self.is_enabled("session_send") { + registry.register(SessionSendTool::new(self.session_message_sender.clone())); + } + + // 注意:不注册 task 工具,防止递归创建子代理 + registry } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index e78a313..b0fdee4 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -12,6 +12,7 @@ pub mod session_send; pub mod schema; pub mod skill_activate; pub mod skill_manage; +pub mod task; pub mod time; pub mod traits; pub mod web_fetch; @@ -33,6 +34,10 @@ pub use session_send::{ pub use schema::{CleaningStrategy, SchemaCleanr}; pub use skill_activate::SkillActivateTool; pub use skill_manage::{SkillListTool, SkillManageTool}; +pub use task::{ + DefaultSubAgentRuntime, InMemoryTaskRepository, SubAgentRuntime, SubAgentRuntimeConfig, + TaskError, TaskRepository, TaskTool, +}; pub use time::TimeTool; pub use traits::{Tool, ToolContext, ToolResult}; pub use web_fetch::WebFetchTool; diff --git a/src/tools/task/error.rs b/src/tools/task/error.rs new file mode 100644 index 0000000..fc5bcf9 --- /dev/null +++ b/src/tools/task/error.rs @@ -0,0 +1,48 @@ +use crate::storage::StorageError; + +/// 任务错误类型 +#[derive(Debug, thiserror::Error)] +pub enum TaskError { + #[error("Task session not found: {0}")] + SessionNotFound(String), + + #[error("Invalid parent session")] + InvalidParentSession, + + #[error("Failed to create subagent: {0}")] + AgentCreationFailed(String), + + #[error("Execution failed: {0}")] + ExecutionFailed(String), + + #[error("Task execution timed out")] + Timeout, + + #[error("Repository error: {0}")] + RepositoryError(#[from] StorageError), + + #[error("Serialization error: {0}")] + SerializationError(#[from] serde_json::Error), + + #[error("Missing required context: {0}")] + MissingContext(String), + + #[error("Invalid arguments: {0}")] + InvalidArguments(String), +} + +impl TaskError { + pub fn as_status(&self) -> &'static str { + match self { + Self::Timeout => "timeout", + Self::SessionNotFound(_) => "failed", + Self::InvalidParentSession => "failed", + Self::AgentCreationFailed(_) => "failed", + Self::ExecutionFailed(_) => "failed", + Self::RepositoryError(_) => "failed", + Self::SerializationError(_) => "failed", + Self::MissingContext(_) => "failed", + Self::InvalidArguments(_) => "failed", + } + } +} \ No newline at end of file diff --git a/src/tools/task/mod.rs b/src/tools/task/mod.rs new file mode 100644 index 0000000..37619fc --- /dev/null +++ b/src/tools/task/mod.rs @@ -0,0 +1,13 @@ +pub mod error; +pub mod prompt; +pub mod repository; +pub mod runtime; +pub mod tool; +pub mod types; + +pub use error::TaskError; +pub use prompt::SubagentPromptBuilder; +pub use repository::{InMemoryTaskRepository, TaskRepository}; +pub use runtime::{DefaultSubAgentRuntime, SubAgentRuntime, SubAgentRuntimeConfig, StaticSystemPromptProvider}; +pub use tool::TaskTool; +pub use types::{SubagentType, TaskDefinition, TaskHandle, TaskSession, TaskSessionState, TaskToolArgs, TaskToolResult}; \ No newline at end of file diff --git a/src/tools/task/prompt.rs b/src/tools/task/prompt.rs new file mode 100644 index 0000000..8e73353 --- /dev/null +++ b/src/tools/task/prompt.rs @@ -0,0 +1,80 @@ +use super::types::SubagentType; + +/// 子代理系统提示词构建器 +pub struct SubagentPromptBuilder; + +impl SubagentPromptBuilder { + /// 构建子代理系统提示词 + pub fn build( + subagent_type: SubagentType, + description: &str, + _prompt: &str, + ) -> String { + match subagent_type { + SubagentType::General => Self::build_general_prompt(description), + SubagentType::Explore => Self::build_explore_prompt(description), + } + } + + /// 构建恢复任务的提示词 + pub fn build_resume_prompt(session_description: &str, additional_prompt: &str) -> String { + format!( + "你正在继续执行一个之前创建的子代理任务。\n\n\ + 任务描述: {}\n\n\ + 继续执行指令: {}\n\n\ + 你应该:\n\ + 1. 回顾之前的工作进度(如果已有历史)\n\ + 2. 继续完成任务,不要偏离目标\n\ + 3. 完成后给出简洁的总结\n\ + 4. 不要尝试创建新的子代理任务\n\n\ + 注意: 你在一个独立的执行上下文中,没有访问主对话历史的权限。", + session_description, additional_prompt + ) + } + + fn build_general_prompt(description: &str) -> String { + format!( + "你是一个专注的子代理,正在执行一个独立任务。\n\n\ + 任务描述: {}\n\n\ + 你应该:\n\ + 1. 专注于完成任务,不要偏离目标\n\ + 2. 使用可用的工具进行必要操作\n\ + 3. 完成后给出简洁的总结\n\ + 4. 不要尝试创建新的子代理任务\n\n\ + 注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。", + description + ) + } + + fn build_explore_prompt(description: &str) -> String { + format!( + "你是一个只读探索代理,用于代码库探索和信息收集。\n\n\ + 任务描述: {}\n\n\ + 你应该:\n\ + 1. 只使用只读工具进行探索\n\ + 2. 专注于理解和收集信息\n\ + 3. 不要进行任何写操作\n\ + 4. 给出简洁的发现总结\n\n\ + 注意: 你是一个只读代理,禁止执行任何修改操作。", + description + ) + } +} + +/// 从子代理输出提取简洁摘要 +pub fn extract_summary(content: &str) -> String { + // 取第一段或前 500 字符 + let first_paragraph = content + .lines() + .take_while(|line| !line.trim().is_empty()) + .collect::>() + .join("\n"); + + if first_paragraph.len() > 500 { + first_paragraph.chars().take(500).collect() + } else if first_paragraph.is_empty() { + content.chars().take(200).collect() + } else { + first_paragraph + } +} \ No newline at end of file diff --git a/src/tools/task/repository.rs b/src/tools/task/repository.rs new file mode 100644 index 0000000..e59a5a5 --- /dev/null +++ b/src/tools/task/repository.rs @@ -0,0 +1,98 @@ +use std::collections::HashMap; +use std::sync::RwLock; + +use async_trait::async_trait; + +use crate::storage::StorageError; + +use super::types::TaskSession; + +/// 任务持久化接口 +#[async_trait] +pub trait TaskRepository: Send + Sync + 'static { + /// 保存任务会话 + async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError>; + + /// 加载任务会话 + async fn load_task_session(&self, task_id: &str) -> Result, StorageError>; + + /// 删除任务会话 + async fn delete_task_session(&self, task_id: &str) -> Result; + + /// 列出父会话的所有任务 + async fn list_tasks_for_session( + &self, + parent_session_id: &str, + ) -> Result, StorageError>; + + /// 清理过期任务(超过指定小时) + async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result; +} + +/// 内存实现(用于测试) +pub struct InMemoryTaskRepository { + sessions: RwLock>, +} + +impl InMemoryTaskRepository { + pub fn new() -> Self { + Self { + sessions: RwLock::new(HashMap::new()), + } + } +} + +impl Default for InMemoryTaskRepository { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl TaskRepository for InMemoryTaskRepository { + async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError> { + self.sessions + .write() + .unwrap() + .insert(session.id.clone(), session.clone()); + Ok(()) + } + + async fn load_task_session(&self, task_id: &str) -> Result, StorageError> { + Ok(self.sessions.read().unwrap().get(task_id).cloned()) + } + + async fn delete_task_session(&self, task_id: &str) -> Result { + Ok(self.sessions.write().unwrap().remove(task_id).is_some()) + } + + async fn list_tasks_for_session( + &self, + parent_session_id: &str, + ) -> Result, StorageError> { + Ok(self + .sessions + .read() + .unwrap() + .values() + .filter(|s| s.parent_session_id == parent_session_id) + .cloned() + .collect()) + } + + async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result { + let now = current_timestamp(); + let ttl_millis = ttl_hours * 3600 * 1000; + let mut sessions = self.sessions.write().unwrap(); + let before = sessions.len(); + sessions.retain(|_, s| now - s.updated_at < ttl_millis as i64); + Ok(before - sessions.len()) + } +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_millis() as i64 +} \ No newline at end of file diff --git a/src/tools/task/runtime.rs b/src/tools/task/runtime.rs new file mode 100644 index 0000000..1cd850b --- /dev/null +++ b/src/tools/task/runtime.rs @@ -0,0 +1,378 @@ +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; + +use crate::agent::{AgentLoop, AgentRuntimeConfig, SystemPrompt, SystemPromptContext, SystemPromptProvider}; +use crate::bus::ChatMessage; +use crate::config::LLMProviderConfig; +use crate::storage::ConversationRepository; +use crate::tools::{ToolContext, ToolRegistry}; + +use super::error::TaskError; +use super::prompt::{extract_summary, SubagentPromptBuilder}; +use super::repository::TaskRepository; +use super::types::{SubagentType, TaskDefinition, TaskHandle, TaskSession, TaskSessionState, TaskToolResult}; + +/// 子代理运行时配置 +#[derive(Debug, Clone)] +pub struct SubAgentRuntimeConfig { + /// 子代理可用的工具列表(白名单) + pub allowed_tools: HashSet, + /// 最大执行时间(秒) + pub max_execution_secs: u64, + /// 探索类型的最大工具调用次数 + pub explore_max_tool_calls: usize, + /// 任务 TTL(小时) + pub ttl_hours: u64, +} + +impl Default for SubAgentRuntimeConfig { + fn default() -> Self { + Self { + allowed_tools: HashSet::from([ + "file_read".to_string(), + "file_edit".to_string(), + "file_write".to_string(), + "bash".to_string(), + "http_request".to_string(), + "web_fetch".to_string(), + "memory_search".to_string(), + "get_time".to_string(), + "calculator".to_string(), + "skill_activate".to_string(), + "skill_list".to_string(), + "send_session_message".to_string(), // 用于进度通知 + ]), + max_execution_secs: 300, // 5分钟 + explore_max_tool_calls: 20, + ttl_hours: 24, + } + } +} + +/// 子代理运行时抽象接口 +#[async_trait] +pub trait SubAgentRuntime: Send + Sync + 'static { + /// 创建并执行子代理任务 + async fn spawn( + &self, + parent_context: &ToolContext, + task: TaskDefinition, + ) -> Result; + + /// 恢复现有任务 + async fn resume( + &self, + task_id: &str, + parent_context: &ToolContext, + additional_prompt: String, + ) -> Result; + + /// 发送消息给子代理(支持中断或补充指令) + async fn send_message(&self, task_id: &str, message: String) -> Result<(), TaskError>; + + /// 清理过期任务 + async fn cleanup_expired(&self) -> Result; +} + +/// 静态系统提示词提供者(用于子代理) +pub struct StaticSystemPromptProvider { + prompt: String, +} + +impl StaticSystemPromptProvider { + pub fn new(prompt: String) -> Self { + Self { prompt } + } +} + +impl SystemPromptProvider for StaticSystemPromptProvider { + fn build(&self, _context: &SystemPromptContext) -> Option { + Some(SystemPrompt { + content: self.prompt.clone(), + context: Some("subagent".to_string()), + }) + } +} + +/// 默认子代理运行时实现 +pub struct DefaultSubAgentRuntime { + config: SubAgentRuntimeConfig, + task_repository: Arc, + conversation_repository: Arc, + subagent_tools: Arc, + provider_config: LLMProviderConfig, +} + +impl DefaultSubAgentRuntime { + pub fn new( + config: SubAgentRuntimeConfig, + task_repository: Arc, + conversation_repository: Arc, + subagent_tools: Arc, + provider_config: LLMProviderConfig, + ) -> Self { + Self { + config, + task_repository, + conversation_repository, + subagent_tools, + provider_config, + } + } + + /// 创建子代理实例 + fn create_subagent( + &self, + session: &TaskSession, + system_prompt: String, + ) -> Result { + let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt)); + + AgentLoop::with_tools_and_system_prompt_provider( + AgentRuntimeConfig::from(self.provider_config.clone()), + self.subagent_tools.clone(), + prompt_provider, + None, // 子代理不需要 skill provider + ) + .map(|agent| { + agent.with_tool_context(ToolContext { + channel_name: Some(session.parent_channel_name.clone()), + sender_id: None, + chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id + session_id: Some(session.session_id.clone()), // 子代理自己的 session_id + message_id: None, + message_seq: None, + subagent_description: Some(session.description.clone()), + }) + }) + .map_err(|e| TaskError::AgentCreationFailed(e.to_string())) + } + + /// 执行任务(带超时控制) + async fn execute_task( + &self, + agent: AgentLoop, + session: &TaskSession, + prompt: String, + ) -> Result { + // 构建初始消息 + let history = vec![ChatMessage::user(prompt)]; + let system_prompt_context = SystemPromptContext { + session_id: Some(session.session_id.clone()), + chat_id: session.session_id.clone(), + user_message_count: 1, + }; + + // 设置超时 + let max_secs = if session.subagent_type == SubagentType::Explore { + self.config.max_execution_secs / 2 // Explore 类型时间更短 + } else { + self.config.max_execution_secs + }; + let timeout_duration = Duration::from_secs(max_secs); + + let result = tokio::time::timeout( + timeout_duration, + agent.process(history, Some(&system_prompt_context)), + ) + .await; + + match result { + Ok(Ok(process_result)) => { + let final_message = process_result.final_response; + Ok(TaskToolResult { + status: "success".to_string(), + summary: extract_summary(&final_message.content), + output: final_message.content, + task_id: session.id.clone(), + }) + } + Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())), + Err(_) => Err(TaskError::Timeout), + } + } + + /// 使用历史继续执行 + async fn execute_task_with_history( + &self, + agent: AgentLoop, + session: &TaskSession, + additional_prompt: String, + ) -> Result { + // 加载历史 + 新消息 + let mut history = self + .conversation_repository + .load_messages(&session.session_id) + .map_err(TaskError::RepositoryError)?; + history.push(ChatMessage::user(additional_prompt)); + + let user_message_count = history.iter().filter(|m| m.role == "user").count(); + let system_prompt_context = SystemPromptContext { + session_id: Some(session.session_id.clone()), + chat_id: session.session_id.clone(), + user_message_count, + }; + + let timeout_duration = Duration::from_secs(self.config.max_execution_secs); + + let result = tokio::time::timeout( + timeout_duration, + agent.process(history, Some(&system_prompt_context)), + ) + .await; + + match result { + Ok(Ok(process_result)) => { + let final_message = process_result.final_response; + Ok(TaskToolResult { + status: "success".to_string(), + summary: extract_summary(&final_message.content), + output: final_message.content, + task_id: session.id.clone(), + }) + } + Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())), + Err(_) => Err(TaskError::Timeout), + } + } +} + +#[async_trait] +impl SubAgentRuntime for DefaultSubAgentRuntime { + async fn spawn( + &self, + parent_context: &ToolContext, + task: TaskDefinition, + ) -> Result { + // 1. 验证上下文 + let session_id = parent_context + .session_id + .clone() + .ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?; + let chat_id = parent_context + .chat_id + .clone() + .ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?; + let channel_name = parent_context + .channel_name + .clone() + .ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?; + + // 2. 创建任务会话 + let session = TaskSession::new( + session_id, + chat_id, + channel_name, + task.description.clone(), + task.subagent_type, + ); + + // 3. 保存会话 + self.task_repository.save_task_session(&session).await?; + + // 4. 构建子代理系统提示词 + let system_prompt = SubagentPromptBuilder::build( + task.subagent_type, + &task.description, + &task.prompt, + ); + + // 5. 创建子代理 + let agent = self.create_subagent(&session, system_prompt)?; + + // 6. 执行任务 + let result = self + .execute_task(agent, &session, task.prompt.clone()) + .await; + + // 7. 更新会话状态并保存 + match result { + Ok(tool_result) => { + let mut session = session; + session.mark_completed(tool_result.summary.clone()); + self.task_repository.save_task_session(&session).await?; + Ok(tool_result) + } + Err(e) => { + let mut session = session; + let status = e.as_status(); + if status == "timeout" { + session.mark_timeout(); + } else { + session.mark_failed(e.to_string()); + } + self.task_repository.save_task_session(&session).await?; + Err(e) + } + } + } + + async fn resume( + &self, + task_id: &str, + parent_context: &ToolContext, + additional_prompt: String, + ) -> Result { + // 1. 加载现有会话 + let session = self + .task_repository + .load_task_session(task_id) + .await? + .ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?; + + // 2. 验证父会话匹配 + let parent_session_id = parent_context + .session_id + .clone() + .ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?; + if session.parent_session_id != parent_session_id { + return Err(TaskError::InvalidParentSession); + } + + // 3. 构建恢复提示词 + let system_prompt = SubagentPromptBuilder::build_resume_prompt( + &session.description, + &additional_prompt, + ); + + // 4. 创建子代理 + let agent = self.create_subagent(&session, system_prompt)?; + + // 5. 使用历史继续执行 + let result = self + .execute_task_with_history(agent, &session, additional_prompt) + .await; + + // 6. 更新会话状态 + match result { + Ok(tool_result) => { + let mut session = session; + session.mark_completed(tool_result.summary.clone()); + self.task_repository.save_task_session(&session).await?; + Ok(tool_result) + } + Err(e) => { + let mut session = session; + session.mark_failed(e.to_string()); + self.task_repository.save_task_session(&session).await?; + Err(e) + } + } + } + + async fn send_message(&self, _task_id: &str, _message: String) -> Result<(), TaskError> { + // TODO: 实现双向通信 + // 需要在 TaskSession 中添加 pending_messages 队列 + Err(TaskError::InvalidArguments("send_message not implemented yet".to_string())) + } + + async fn cleanup_expired(&self) -> Result { + self.task_repository + .cleanup_expired_tasks(self.config.ttl_hours) + .await + .map_err(TaskError::from) + } +} \ No newline at end of file diff --git a/src/tools/task/tool.rs b/src/tools/task/tool.rs new file mode 100644 index 0000000..6a0b66d --- /dev/null +++ b/src/tools/task/tool.rs @@ -0,0 +1,153 @@ +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +use crate::tools::{Tool, ToolContext, ToolResult}; + +use super::runtime::SubAgentRuntime; +use super::types::{TaskDefinition, TaskToolArgs}; + +/// Task 工具 - 创建和管理子代理 +pub struct TaskTool { + runtime: Arc, +} + +impl TaskTool { + pub fn new(runtime: Arc) -> Self { + Self { runtime } + } +} + +#[async_trait] +impl Tool for TaskTool { + fn name(&self) -> &str { + "task" + } + + fn description(&self) -> &str { + "Launch a specialized subagent to handle complex, multi-step tasks. \ + Subagents run in isolated contexts and can work in parallel. \ + Use 'general' type for complex tasks, 'explore' type for read-only exploration. \ + You can resume a previous task by providing its task_id." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Short description (3-5 words) of what this task does", + "maxLength": 50 + }, + "prompt": { + "type": "string", + "description": "Detailed instructions for the subagent to execute" + }, + "subagent_type": { + "type": "string", + "enum": ["general", "explore"], + "default": "general", + "description": "Type of subagent: 'general' for complex multi-step tasks, 'explore' for read-only search/exploration" + }, + "task_id": { + "type": "string", + "description": "Optional: Resume an existing task session by providing its task_id" + } + }, + "required": ["description", "prompt"] + }) + } + + fn read_only(&self) -> bool { + false + } + + fn exclusive(&self) -> bool { + // Task 工具创建子代理,不应与其他工具并发执行 + true + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + // Task 工具必须通过 execute_with_context 获取父会话信息 + Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "task tool requires tool context with session_id, chat_id, and channel_name" + .to_string(), + ), + }) + } + + async fn execute_with_context( + &self, + context: &ToolContext, + args: serde_json::Value, + ) -> anyhow::Result { + // 1. 解析参数 + let task_args: TaskToolArgs = serde_json::from_value(args.clone()) + .map_err(|e| anyhow::anyhow!("invalid task arguments: {}", e))?; + + // 2. 验证描述长度 + let word_count = task_args.description.split_whitespace().count(); + if task_args.description.len() > 50 || word_count > 7 || word_count < 1 { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "description should be 1-5 words, max 50 characters".to_string(), + ), + }); + } + + // 3. 验证上下文 + if context.session_id.is_none() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("task tool requires session_id in context".to_string()), + }); + } + if context.chat_id.is_none() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("task tool requires chat_id in context".to_string()), + }); + } + if context.channel_name.is_none() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("task tool requires channel_name in context".to_string()), + }); + } + + // 4. 执行任务 + let result = if let Some(task_id) = task_args.task_id { + // 恢复现有任务 + self.runtime + .resume(&task_id, context, task_args.prompt) + .await + } else { + // 创建新任务 + let task_def = TaskDefinition::from(task_args); + self.runtime.spawn(context, task_def).await + }; + + // 5. 构建返回结果 + match result { + Ok(task_result) => Ok(ToolResult { + success: true, + output: serde_json::to_string(&task_result)?, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} \ No newline at end of file diff --git a/src/tools/task/types.rs b/src/tools/task/types.rs new file mode 100644 index 0000000..349ff79 --- /dev/null +++ b/src/tools/task/types.rs @@ -0,0 +1,190 @@ +use serde::{Deserialize, Serialize}; + +/// 子代理会话状态 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum TaskSessionState { + /// 正在执行 + Running, + /// 已完成 + Completed, + /// 已失败 + Failed, + /// 已超时 + Timeout, +} + +impl Default for TaskSessionState { + fn default() -> Self { + Self::Running + } +} + +/// 子代理类型 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SubagentType { + /// 通用型 - 处理复杂多步骤任务 + #[default] + General, + /// 探索型 - 只读搜索代理 + Explore, +} + +impl SubagentType { + pub fn as_str(&self) -> &'static str { + match self { + Self::General => "general", + Self::Explore => "explore", + } + } + + pub fn from_str(s: &str) -> Option { + match s { + "general" => Some(Self::General), + "explore" => Some(Self::Explore), + _ => None, + } + } +} + +/// 任务会话记录 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskSession { + /// 任务唯一 ID (UUID) + pub id: String, + /// 子代理独立的 session_id(存储在 message 表) + pub session_id: String, + /// 父会话 ID (用于关联) + pub parent_session_id: String, + /// 父 chat_id + pub parent_chat_id: String, + /// 父 channel_name + pub parent_channel_name: String, + /// 任务描述 + pub description: String, + /// 子代理类型 + pub subagent_type: SubagentType, + /// 当前状态 + pub state: TaskSessionState, + /// 创建时间 + pub created_at: i64, + /// 最后更新时间 + pub updated_at: i64, + /// 执行摘要 + pub summary: Option, + /// 错误信息 + pub error: Option, +} + +impl TaskSession { + pub fn new( + parent_session_id: String, + parent_chat_id: String, + parent_channel_name: String, + description: String, + subagent_type: SubagentType, + ) -> Self { + let id = format!("task:{}", uuid::Uuid::new_v4()); + let session_id = format!("sub:{}:{}", parent_session_id, id); + let now = current_timestamp(); + Self { + id, + session_id, + parent_session_id, + parent_chat_id, + parent_channel_name, + description, + subagent_type, + state: TaskSessionState::Running, + created_at: now, + updated_at: now, + summary: None, + error: None, + } + } + + /// 标记完成 + pub fn mark_completed(&mut self, summary: String) { + self.state = TaskSessionState::Completed; + self.summary = Some(summary); + self.updated_at = current_timestamp(); + } + + /// 标记失败 + pub fn mark_failed(&mut self, error: String) { + self.state = TaskSessionState::Failed; + self.error = Some(error); + self.updated_at = current_timestamp(); + } + + /// 标记超时 + pub fn mark_timeout(&mut self) { + self.state = TaskSessionState::Timeout; + self.error = Some("Task execution timed out".to_string()); + self.updated_at = current_timestamp(); + } +} + +/// 任务工具参数 +#[derive(Debug, Clone, Deserialize)] +pub struct TaskToolArgs { + /// 简短描述(3-5词) + pub description: String, + /// 详细指令 + pub prompt: String, + /// 子代理类型 + #[serde(default)] + pub subagent_type: SubagentType, + /// 恢复现有会话的 task_id + #[serde(default)] + pub task_id: Option, +} + +/// 任务定义(用于 SubAgentRuntime::spawn) +#[derive(Debug, Clone)] +pub struct TaskDefinition { + pub description: String, + pub prompt: String, + pub subagent_type: SubagentType, + pub max_execution_secs: Option, +} + +impl From for TaskDefinition { + fn from(args: TaskToolArgs) -> Self { + Self { + description: args.description, + prompt: args.prompt, + subagent_type: args.subagent_type, + max_execution_secs: None, + } + } +} + +/// 任务句柄(运行中任务) +#[derive(Debug, Clone)] +pub struct TaskHandle { + pub task_id: String, + pub session_id: String, + pub status: TaskSessionState, +} + +/// 任务执行结果 +#[derive(Debug, Clone, Serialize)] +pub struct TaskToolResult { + /// 状态: success/failed/timeout + pub status: String, + /// 任务完成总结 + pub summary: String, + /// 详细输出 + pub output: String, + /// 会话 ID(用于恢复) + pub task_id: String, +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_millis() as i64 +} \ No newline at end of file diff --git a/src/tools/traits.rs b/src/tools/traits.rs index b00ca14..ce36f27 100644 --- a/src/tools/traits.rs +++ b/src/tools/traits.rs @@ -15,6 +15,8 @@ pub struct ToolContext { pub session_id: Option, pub message_id: Option, pub message_seq: Option, + /// 子代理标识,用于标注消息来源 + pub subagent_description: Option, } #[async_trait]