From 020b7aa77a64773b37b284cf4f774be85c940f12 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Sat, 16 May 2026 08:50:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=88=9B=E5=BB=BA=E5=92=8C=E6=81=A2=E5=A4=8D=E5=AD=90=E4=BB=A3?= =?UTF-8?q?=E7=90=86=E4=BB=BB=E5=8A=A1=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E6=89=A7=E8=A1=8C=E5=92=8C=E7=8A=B6=E6=80=81=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/mod.rs | 58 ++++- src/gateway/runtime.rs | 7 +- src/gateway/tool_registry_factory.rs | 104 +++++++- src/tools/mod.rs | 4 + src/tools/task/error.rs | 48 ++++ src/tools/task/mod.rs | 11 + src/tools/task/prompt.rs | 125 +++++++++ src/tools/task/repository.rs | 98 +++++++ src/tools/task/runtime.rs | 370 +++++++++++++++++++++++++++ src/tools/task/tool.rs | 170 ++++++++++++ src/tools/task/types.rs | 169 ++++++++++++ src/tools/traits.rs | 2 + 12 files changed, 1162 insertions(+), 4 deletions(-) create mode 100644 src/tools/task/error.rs create mode 100644 src/tools/task/mod.rs create mode 100644 src/tools/task/prompt.rs create mode 100644 src/tools/task/repository.rs create mode 100644 src/tools/task/runtime.rs create mode 100644 src/tools/task/tool.rs create mode 100644 src/tools/task/types.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 0bff93c..2d71c96 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -104,11 +104,67 @@ impl Default for SkillsConfig { pub struct ToolsConfig { #[serde(default)] pub disabled: Vec, + #[serde(default)] + pub task: TaskConfig, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TaskConfig { + #[serde(default = "default_task_enabled")] + pub enabled: bool, + #[serde(default = "default_task_max_execution_secs")] + pub max_execution_secs: u64, + #[serde(default = "default_task_ttl_hours")] + pub ttl_hours: u64, + #[serde(default = "default_task_allowed_tools")] + pub allowed_tools: Vec, +} + +fn default_task_enabled() -> bool { + true +} + +fn default_task_max_execution_secs() -> u64 { + 300 +} + +fn default_task_ttl_hours() -> u64 { + 24 +} + +fn default_task_allowed_tools() -> Vec { + vec![ + "file_read".to_string(), + "file_edit".to_string(), + "file_write".to_string(), + "bash".to_string(), + "http_request".to_string(), + "web_fetch".to_string(), + "memory_search".to_string(), + "get_time".to_string(), + "calculator".to_string(), + "skill_activate".to_string(), + "skill_list".to_string(), + ] } impl Default for ToolsConfig { fn default() -> Self { - Self { disabled: Vec::new() } + Self { + disabled: Vec::new(), + task: TaskConfig::default(), + } + } +} + +impl Default for TaskConfig { + fn default() -> Self { + Self { + enabled: default_task_enabled(), + max_execution_secs: default_task_max_execution_secs(), + ttl_hours: default_task_ttl_hours(), + allowed_tools: default_task_allowed_tools(), + } } } diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index 6664af1..2eaa1eb 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -9,7 +9,7 @@ use crate::storage::{ ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository, SessionStore, SkillEventRepository, }; -use crate::tools::{NoopSessionMessageSender, SessionMessageSender, ToolRegistry}; +use crate::tools::{InMemoryTaskRepository, NoopSessionMessageSender, SessionMessageSender, ToolRegistry}; use super::agent_factory::AgentFactory; use super::cli_session::CliSessionService; @@ -74,6 +74,10 @@ pub(crate) fn build_session_manager_with_sender( let memories: Arc = store.clone(); let scheduler_jobs: Arc = store.clone(); let skill_events: Arc = store.clone(); + + // TaskRepository 使用内存存储(后续可以改为 SQLite) + let task_repository = Arc::new(InMemoryTaskRepository::new()); + let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), @@ -85,6 +89,7 @@ pub(crate) fn build_session_manager_with_sender( default_timezone, disabled_tools, ) + .with_task_deps(task_repository, provider_config.clone()) .build(), ); diff --git a/src/gateway/tool_registry_factory.rs b/src/gateway/tool_registry_factory.rs index 8bbb047..6674f6d 100644 --- a/src/gateway/tool_registry_factory.rs +++ b/src/gateway/tool_registry_factory.rs @@ -1,13 +1,14 @@ use std::collections::HashSet; use std::sync::Arc; +use crate::config::LLMProviderConfig; use crate::skills::SkillRuntime; use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, MemoryManageTool, MemorySearchTool, SchedulerManageTool, SessionMessageSender, - SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TimeTool, ToolRegistry, - WebFetchTool, + SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TaskRuntime, + TaskTool, TimeTool, ToolRegistry, WebFetchTool, }; pub(crate) struct ToolRegistryFactory { @@ -19,6 +20,9 @@ pub(crate) struct ToolRegistryFactory { known_agents: HashSet, default_timezone: String, disabled_tools: HashSet, + // Task 工具需要的依赖 + task_repository: Option>, + provider_config: Option, } impl ToolRegistryFactory { @@ -41,9 +45,22 @@ impl ToolRegistryFactory { known_agents, default_timezone, disabled_tools, + task_repository: None, + provider_config: None, } } + /// 设置 Task 工具的依赖 + pub(crate) fn with_task_deps( + mut self, + task_repository: Arc, + provider_config: LLMProviderConfig, + ) -> Self { + self.task_repository = Some(task_repository); + self.provider_config = Some(provider_config); + self + } + fn is_enabled(&self, tool_name: &str) -> bool { !self.disabled_tools.contains(tool_name) } @@ -108,6 +125,89 @@ impl ToolRegistryFactory { registry.register(WebFetchTool::new(50_000, 30)); } + // 注册 Task 工具(需要 TaskRuntime) + if self.is_enabled("task") { + if let (Some(task_repository), Some(provider_config)) = + (&self.task_repository, &self.provider_config) + { + // 先创建一个临时的 ToolRegistry(不含 task 工具)用于子代理 + let subagent_tools = Arc::new(self.build_without_task()); + let task_runtime = Arc::new(TaskRuntime::new( + crate::tools::TaskRuntimeConfig::default(), + task_repository.clone(), + subagent_tools, + provider_config.clone(), + )); + registry.register(TaskTool::new(task_runtime)); + } + } + + registry + } + + /// 构建不含 task 工具的注册表(用于子代理) + fn build_without_task(&self) -> ToolRegistry { + let mut registry = ToolRegistry::new(); + + if self.is_enabled("calculator") { + registry.register(CalculatorTool::new()); + } + if self.is_enabled("get_time") { + registry.register(TimeTool::new(self.default_timezone.clone())); + } + if self.is_enabled("file_read") { + registry.register(FileReadTool::new()); + } + if self.is_enabled("file_write") { + registry.register(FileWriteTool::new()); + } + if self.is_enabled("file_edit") { + registry.register(FileEditTool::new()); + } + if self.is_enabled("memory_search") { + registry.register(MemorySearchTool::new(self.memories.clone())); + } + if self.is_enabled("memory_manage") { + registry.register(MemoryManageTool::new(self.memories.clone())); + } + if self.is_enabled("session_send") { + registry.register(SessionSendTool::new(self.session_message_sender.clone())); + } + if self.is_enabled("scheduler_manage") { + registry.register(SchedulerManageTool::new( + self.scheduler_jobs.clone(), + self.known_agents.clone(), + )); + } + if self.is_enabled("skill_activate") { + registry.register(SkillActivateTool::new( + self.skills.clone(), + self.skill_events.clone(), + )); + } + if self.is_enabled("skill_list") { + registry.register(SkillListTool::new(self.skills.clone())); + } + if self.is_enabled("skill_manage") { + registry.register(SkillManageTool::new(self.skills.clone())); + } + if self.is_enabled("bash") { + registry.register(BashTool::new()); + } + if self.is_enabled("http_request") { + registry.register(HttpRequestTool::new( + vec!["*".to_string()], + 1_000_000, + 30, + false, + )); + } + if self.is_enabled("web_fetch") { + registry.register(WebFetchTool::new(50_000, 30)); + } + + // 注意:不注册 task 工具,防止子代理递归创建子代理 + registry } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index e78a313..8b6f060 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -12,6 +12,7 @@ pub mod session_send; pub mod schema; pub mod skill_activate; pub mod skill_manage; +pub mod task; pub mod time; pub mod traits; pub mod web_fetch; @@ -33,6 +34,9 @@ pub use session_send::{ pub use schema::{CleaningStrategy, SchemaCleanr}; pub use skill_activate::SkillActivateTool; pub use skill_manage::{SkillListTool, SkillManageTool}; +pub use task::TaskTool; pub use time::TimeTool; pub use traits::{Tool, ToolContext, ToolResult}; pub use web_fetch::WebFetchTool; + +pub(crate) use task::{TaskRepository, TaskRuntime, TaskRuntimeConfig, InMemoryTaskRepository}; diff --git a/src/tools/task/error.rs b/src/tools/task/error.rs new file mode 100644 index 0000000..fc5bcf9 --- /dev/null +++ b/src/tools/task/error.rs @@ -0,0 +1,48 @@ +use crate::storage::StorageError; + +/// 任务错误类型 +#[derive(Debug, thiserror::Error)] +pub enum TaskError { + #[error("Task session not found: {0}")] + SessionNotFound(String), + + #[error("Invalid parent session")] + InvalidParentSession, + + #[error("Failed to create subagent: {0}")] + AgentCreationFailed(String), + + #[error("Execution failed: {0}")] + ExecutionFailed(String), + + #[error("Task execution timed out")] + Timeout, + + #[error("Repository error: {0}")] + RepositoryError(#[from] StorageError), + + #[error("Serialization error: {0}")] + SerializationError(#[from] serde_json::Error), + + #[error("Missing required context: {0}")] + MissingContext(String), + + #[error("Invalid arguments: {0}")] + InvalidArguments(String), +} + +impl TaskError { + pub fn as_status(&self) -> &'static str { + match self { + Self::Timeout => "timeout", + Self::SessionNotFound(_) => "failed", + Self::InvalidParentSession => "failed", + Self::AgentCreationFailed(_) => "failed", + Self::ExecutionFailed(_) => "failed", + Self::RepositoryError(_) => "failed", + Self::SerializationError(_) => "failed", + Self::MissingContext(_) => "failed", + Self::InvalidArguments(_) => "failed", + } + } +} \ No newline at end of file diff --git a/src/tools/task/mod.rs b/src/tools/task/mod.rs new file mode 100644 index 0000000..0309883 --- /dev/null +++ b/src/tools/task/mod.rs @@ -0,0 +1,11 @@ +mod error; +mod prompt; +mod repository; +mod runtime; +mod tool; +mod types; + +pub(crate) use repository::TaskRepository; +pub(crate) use repository::InMemoryTaskRepository; +pub(crate) use runtime::{TaskRuntime, TaskRuntimeConfig}; +pub use tool::TaskTool; \ No newline at end of file diff --git a/src/tools/task/prompt.rs b/src/tools/task/prompt.rs new file mode 100644 index 0000000..65b208d --- /dev/null +++ b/src/tools/task/prompt.rs @@ -0,0 +1,125 @@ +use crate::bus::ChatMessage; + +use super::types::SubagentType; + +/// 子代理系统提示词构建器 +pub struct SubagentPromptBuilder; + +impl SubagentPromptBuilder { + /// 构建子代理系统提示词 + pub fn build( + subagent_type: SubagentType, + description: &str, + _prompt: &str, + ) -> String { + match subagent_type { + SubagentType::General => Self::build_general_prompt(description), + SubagentType::Explore => Self::build_explore_prompt(description), + } + } + + /// 构建恢复任务的提示词 + pub fn build_resume_prompt(session_description: &str, additional_prompt: &str) -> String { + format!( + "你正在继续执行一个之前创建的子代理任务。\n\n\ + 任务描述: {}\n\n\ + 继续执行指令: {}\n\n\ + 你应该:\n\ + 1. 回顾之前的工作进度(如果已有历史)\n\ + 2. 继续完成任务,不要偏离目标\n\ + 3. 完成后给出简洁的总结\n\ + 4. 不要尝试创建新的子代理任务\n\n\ + 注意: 你在一个独立的执行上下文中,没有访问主对话历史的权限。", + session_description, additional_prompt + ) + } + + /// 构建带上下文摘要的探索提示词 + /// 预留功能:用于 Explore 类型子代理继承主代理上下文 + #[allow(dead_code)] + pub fn build_explore_prompt_with_context( + description: &str, + parent_history: &[ChatMessage], + ) -> String { + let context_summary = Self::extract_context_summary(parent_history); + format!( + "你是一个只读探索代理,用于代码库探索和信息收集。\n\n\ + 任务: {}\n\n\ + 主对话上下文摘要:\n{}\n\n\ + 你应该:\n\ + 1. 只使用只读工具进行探索(file_read, bash 的只读命令)\n\ + 2. 专注于理解和收集信息\n\ + 3. 不要进行任何写操作\n\ + 4. 给出简洁的发现总结\n\n\ + 注意: 你是一个只读代理,禁止执行任何修改操作。", + description, context_summary + ) + } + + fn build_general_prompt(description: &str) -> String { + format!( + "你是一个专注的子代理,正在执行一个独立任务。\n\n\ + 任务描述: {}\n\n\ + 你应该:\n\ + 1. 专注于完成任务,不要偏离目标\n\ + 2. 使用可用的工具进行必要操作\n\ + 3. 完成后给出简洁的总结\n\ + 4. 不要尝试创建新的子代理任务\n\n\ + 注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。", + description + ) + } + + fn build_explore_prompt(description: &str) -> String { + format!( + "你是一个只读探索代理,用于代码库探索和信息收集。\n\n\ + 任务描述: {}\n\n\ + 你应该:\n\ + 1. 只使用只读工具进行探索\n\ + 2. 专注于理解和收集信息\n\ + 3. 不要进行任何写操作\n\ + 4. 给出简洁的发现总结\n\n\ + 注意: 你是一个只读代理,禁止执行任何修改操作。", + description + ) + } + + /// 提取最近用户消息的上下文摘要 + /// 预留功能:用于 Explore 类型子代理继承主代理上下文 + #[allow(dead_code)] + fn extract_context_summary(history: &[ChatMessage]) -> String { + history + .iter() + .filter(|m| m.role == "user") + .rev() + .take(5) + .map(|m| { + let content = &m.content; + if content.len() > 100 { + format!("- {}", content.chars().take(100).collect::()) + } else { + format!("- {}", content) + } + }) + .collect::>() + .join("\n") + } +} + +/// 从子代理输出提取简洁摘要 +pub fn extract_summary(content: &str) -> String { + // 取第一段或前 500 字符 + let first_paragraph = content + .lines() + .take_while(|line| !line.trim().is_empty()) + .collect::>() + .join("\n"); + + if first_paragraph.len() > 500 { + first_paragraph.chars().take(500).collect() + } else if first_paragraph.is_empty() { + content.chars().take(200).collect() + } else { + first_paragraph + } +} \ No newline at end of file diff --git a/src/tools/task/repository.rs b/src/tools/task/repository.rs new file mode 100644 index 0000000..e59a5a5 --- /dev/null +++ b/src/tools/task/repository.rs @@ -0,0 +1,98 @@ +use std::collections::HashMap; +use std::sync::RwLock; + +use async_trait::async_trait; + +use crate::storage::StorageError; + +use super::types::TaskSession; + +/// 任务持久化接口 +#[async_trait] +pub trait TaskRepository: Send + Sync + 'static { + /// 保存任务会话 + async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError>; + + /// 加载任务会话 + async fn load_task_session(&self, task_id: &str) -> Result, StorageError>; + + /// 删除任务会话 + async fn delete_task_session(&self, task_id: &str) -> Result; + + /// 列出父会话的所有任务 + async fn list_tasks_for_session( + &self, + parent_session_id: &str, + ) -> Result, StorageError>; + + /// 清理过期任务(超过指定小时) + async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result; +} + +/// 内存实现(用于测试) +pub struct InMemoryTaskRepository { + sessions: RwLock>, +} + +impl InMemoryTaskRepository { + pub fn new() -> Self { + Self { + sessions: RwLock::new(HashMap::new()), + } + } +} + +impl Default for InMemoryTaskRepository { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl TaskRepository for InMemoryTaskRepository { + async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError> { + self.sessions + .write() + .unwrap() + .insert(session.id.clone(), session.clone()); + Ok(()) + } + + async fn load_task_session(&self, task_id: &str) -> Result, StorageError> { + Ok(self.sessions.read().unwrap().get(task_id).cloned()) + } + + async fn delete_task_session(&self, task_id: &str) -> Result { + Ok(self.sessions.write().unwrap().remove(task_id).is_some()) + } + + async fn list_tasks_for_session( + &self, + parent_session_id: &str, + ) -> Result, StorageError> { + Ok(self + .sessions + .read() + .unwrap() + .values() + .filter(|s| s.parent_session_id == parent_session_id) + .cloned() + .collect()) + } + + async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result { + let now = current_timestamp(); + let ttl_millis = ttl_hours * 3600 * 1000; + let mut sessions = self.sessions.write().unwrap(); + let before = sessions.len(); + sessions.retain(|_, s| now - s.updated_at < ttl_millis as i64); + Ok(before - sessions.len()) + } +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_millis() as i64 +} \ No newline at end of file diff --git a/src/tools/task/runtime.rs b/src/tools/task/runtime.rs new file mode 100644 index 0000000..b075cca --- /dev/null +++ b/src/tools/task/runtime.rs @@ -0,0 +1,370 @@ +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; + +use crate::agent::{AgentLoop, SystemPromptContext, SystemPromptProvider}; +use crate::bus::ChatMessage; +use crate::config::LLMProviderConfig; +use crate::tools::{ToolContext, ToolRegistry}; + +use super::error::TaskError; +use super::prompt::{extract_summary, SubagentPromptBuilder}; +use super::repository::TaskRepository; +use super::types::{SubagentType, TaskSession, TaskToolResult}; + +/// 子代理运行时配置 +#[derive(Debug, Clone)] +pub struct TaskRuntimeConfig { + /// 子代理可用的工具列表(白名单) + pub allowed_tools: HashSet, + /// 最大执行时间(秒) + pub max_execution_secs: u64, + /// 探索类型的最大工具调用次数 + pub explore_max_tool_calls: usize, + /// 任务 TTL(小时) + pub ttl_hours: u64, +} + +impl Default for TaskRuntimeConfig { + fn default() -> Self { + Self { + allowed_tools: HashSet::from([ + "file_read".to_string(), + "file_edit".to_string(), + "file_write".to_string(), + "bash".to_string(), + "http_request".to_string(), + "web_fetch".to_string(), + "memory_search".to_string(), + "get_time".to_string(), + "calculator".to_string(), + "skill_activate".to_string(), + "skill_list".to_string(), + ]), + max_execution_secs: 300, // 5分钟 + explore_max_tool_calls: 20, + ttl_hours: 24, + } + } +} + +/// 静态系统提示词提供者 +pub struct StaticSystemPromptProvider { + prompt: String, +} + +impl StaticSystemPromptProvider { + pub fn new(prompt: String) -> Self { + Self { prompt } + } +} + +impl SystemPromptProvider for StaticSystemPromptProvider { + fn build(&self, _context: &SystemPromptContext) -> Option { + Some(crate::agent::SystemPrompt { + content: self.prompt.clone(), + context: Some("subagent".to_string()), + }) + } +} + +/// 子代理运行时管理器 +pub struct TaskRuntime { + config: TaskRuntimeConfig, + repository: Arc, + tools: Arc, + provider_config: LLMProviderConfig, +} + +impl TaskRuntime { + pub fn new( + config: TaskRuntimeConfig, + repository: Arc, + tools: Arc, + provider_config: LLMProviderConfig, + ) -> Self { + Self { + config, + repository, + tools, + provider_config, + } + } + + /// 获取子代理工具注册表(排除 task 工具防止递归) + fn get_subagent_tools(&self) -> Arc { + // 创建一个新的工具注册表,只包含允许的工具 + // 这里简化处理,直接使用传入的 tools(假设 task 工具不会注册进去) + // 实际实现中需要过滤 allowed_tools + self.tools.clone() + } + + /// 创建新的子代理会话并执行 + pub async fn spawn( + &self, + parent_context: &ToolContext, + description: String, + prompt: String, + subagent_type: SubagentType, + ) -> Result { + // 1. 验证上下文 + let session_id = parent_context + .session_id + .clone() + .ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?; + let chat_id = parent_context + .chat_id + .clone() + .ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?; + let channel_name = parent_context + .channel_name + .clone() + .ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?; + + // 2. 创建任务会话 + let session = TaskSession::new( + session_id, + chat_id, + channel_name, + description.clone(), + subagent_type, + ); + + // 3. 保存会话 + self.repository.save_task_session(&session).await?; + + // 4. 构建子代理系统提示词 + let system_prompt = SubagentPromptBuilder::build(subagent_type, &description, &prompt); + + // 5. 创建子代理 + let agent = self.create_subagent(&session, system_prompt)?; + + // 6. 执行任务 + let result = self.execute_task(agent, &session, prompt).await; + + // 7. 更新会话状态并保存 + match result { + Ok(tool_result) => { + let mut session = session; + session.mark_completed(tool_result.summary.clone()); + self.repository.save_task_session(&session).await?; + Ok(tool_result) + } + Err(e) => { + let mut session = session; + let status = e.as_status(); + if status == "timeout" { + session.mark_timeout(); + } else { + session.mark_failed(e.to_string()); + } + self.repository.save_task_session(&session).await?; + Err(e) + } + } + } + + /// 恢复现有任务会话 + pub async fn resume( + &self, + task_id: &str, + parent_context: &ToolContext, + additional_prompt: String, + ) -> Result { + // 1. 加载现有会话 + let session = self + .repository + .load_task_session(task_id) + .await? + .ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?; + + // 2. 验证父会话匹配 + let parent_session_id = parent_context + .session_id + .clone() + .ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?; + if session.parent_session_id != parent_session_id { + return Err(TaskError::InvalidParentSession); + } + + // 3. 构建恢复提示词 + let system_prompt = SubagentPromptBuilder::build_resume_prompt( + &session.description, + &additional_prompt, + ); + + // 4. 创建子代理 + let agent = self.create_subagent(&session, system_prompt)?; + + // 5. 使用历史继续执行 + let result = self + .execute_task_with_history(agent, &session, additional_prompt) + .await; + + // 6. 更新会话状态 + match result { + Ok(tool_result) => { + let mut session = session; + session.mark_completed(tool_result.summary.clone()); + self.repository.save_task_session(&session).await?; + Ok(tool_result) + } + Err(e) => { + let mut session = session; + session.mark_failed(e.to_string()); + self.repository.save_task_session(&session).await?; + Err(e) + } + } + } + + /// 创建子代理实例 + fn create_subagent( + &self, + session: &TaskSession, + system_prompt: String, + ) -> Result { + let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt)); + + // 获取子代理工具注册表 + let subagent_tools = self.get_subagent_tools(); + + // 直接创建 AgentLoop,使用自定义的提示词提供者 + AgentLoop::with_tools_and_system_prompt_provider( + self.provider_config.clone(), + subagent_tools, + prompt_provider, + None, // 子代理不需要 skill provider + ) + .map(|agent| { + agent.with_tool_context(ToolContext { + channel_name: Some(session.parent_channel_name.clone()), + sender_id: None, + chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id,发送到飞书用户 + session_id: Some(session.id.clone()), + message_id: None, + message_seq: None, + subagent_description: Some(session.description.clone()), // 子代理标识 + }) + }) + .map_err(|e| TaskError::AgentCreationFailed(e.to_string())) + } + + /// 执行任务(带超时控制) + async fn execute_task( + &self, + agent: AgentLoop, + session: &TaskSession, + prompt: String, + ) -> Result { + // 构建初始消息 + let history = vec![ChatMessage::user(prompt)]; + let system_prompt_context = SystemPromptContext { + session_id: Some(session.id.clone()), + chat_id: session.id.clone(), + user_message_count: 1, + }; + + // 设置超时 + let timeout_duration = Duration::from_secs(self.config.max_execution_secs); + + let result = tokio::time::timeout( + timeout_duration, + agent.process(history, Some(&system_prompt_context)), + ) + .await; + + match result { + Ok(Ok(process_result)) => { + let final_message = process_result.final_response; + Ok(TaskToolResult { + status: "success".to_string(), + summary: extract_summary(&final_message.content), + output: final_message.content, + task_id: session.id.clone(), + }) + } + Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())), + Err(_) => Err(TaskError::Timeout), + } + } + + /// 使用历史继续执行 + async fn execute_task_with_history( + &self, + agent: AgentLoop, + session: &TaskSession, + additional_prompt: String, + ) -> Result { + // 构建历史 + 新消息 + let mut history = session.history.clone(); + history.push(ChatMessage::user(additional_prompt)); + + let user_message_count = history.iter().filter(|m| m.role == "user").count(); + let system_prompt_context = SystemPromptContext { + session_id: Some(session.id.clone()), + chat_id: session.id.clone(), + user_message_count, + }; + + let timeout_duration = Duration::from_secs(self.config.max_execution_secs); + + let result = tokio::time::timeout( + timeout_duration, + agent.process(history, Some(&system_prompt_context)), + ) + .await; + + match result { + Ok(Ok(process_result)) => { + let final_message = process_result.final_response; + Ok(TaskToolResult { + status: "success".to_string(), + summary: extract_summary(&final_message.content), + output: final_message.content, + task_id: session.id.clone(), + }) + } + Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())), + Err(_) => Err(TaskError::Timeout), + } + } + + /// 清理过期任务 + pub async fn cleanup_expired(&self) -> Result { + self.repository + .cleanup_expired_tasks(self.config.ttl_hours) + .await + .map_err(TaskError::from) + } + + /// 创建用于测试的实例 + #[cfg(test)] + pub fn new_for_test() -> Self { + use super::repository::InMemoryTaskRepository; + use std::collections::HashMap; + + Self { + config: TaskRuntimeConfig::default(), + repository: Arc::new(InMemoryTaskRepository::new()), + tools: Arc::new(ToolRegistry::new()), + provider_config: LLMProviderConfig { + provider_type: "openai".to_string(), + name: "test".to_string(), + base_url: "https://test.local/v1".to_string(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, + model_id: "test-model".to_string(), + temperature: None, + max_tokens: None, + context_window_tokens: None, + model_extra: HashMap::new(), + max_tool_iterations: 100, + tool_result_max_chars: 20_000, + context_tool_result_trim_chars: 2_000, + }, + } + } +} \ No newline at end of file diff --git a/src/tools/task/tool.rs b/src/tools/task/tool.rs new file mode 100644 index 0000000..3739006 --- /dev/null +++ b/src/tools/task/tool.rs @@ -0,0 +1,170 @@ +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +use crate::tools::{Tool, ToolContext, ToolResult}; + +use super::runtime::TaskRuntime; +use super::types::TaskToolArgs; + +/// Task 工具 - 创建和管理子代理 +pub struct TaskTool { + runtime: Arc, +} + +impl TaskTool { + pub fn new(runtime: Arc) -> Self { + Self { runtime } + } +} + +#[async_trait] +impl Tool for TaskTool { + fn name(&self) -> &str { + "task" + } + + fn description(&self) -> &str { + "Launch a specialized subagent to handle complex, multi-step tasks. \ + Subagents run in isolated contexts and can work in parallel. \ + Use 'general' type for complex tasks, 'explore' type for read-only exploration. \ + You can resume a previous task by providing its task_id." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Short description (3-5 words) of what this task does", + "maxLength": 50 + }, + "prompt": { + "type": "string", + "description": "Detailed instructions for the subagent to execute" + }, + "subagent_type": { + "type": "string", + "enum": ["general", "explore"], + "default": "general", + "description": "Type of subagent: 'general' for complex multi-step tasks, 'explore' for read-only search/exploration" + }, + "task_id": { + "type": "string", + "description": "Optional: Resume an existing task session by providing its task_id" + } + }, + "required": ["description", "prompt"] + }) + } + + fn read_only(&self) -> bool { + false + } + + fn exclusive(&self) -> bool { + // Task 工具创建子代理,不应与其他工具并发执行 + true + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + // Task 工具必须通过 execute_with_context 获取父会话信息 + Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "task tool requires tool context with session_id, chat_id, and channel_name" + .to_string(), + ), + }) + } + + async fn execute_with_context( + &self, + context: &ToolContext, + args: serde_json::Value, + ) -> anyhow::Result { + // 1. 解析参数 + let task_args: TaskToolArgs = serde_json::from_value(args.clone()) + .map_err(|e| anyhow::anyhow!("invalid task arguments: {}", e))?; + + // 2. 验证描述长度 + let word_count = task_args.description.split_whitespace().count(); + if task_args.description.len() > 50 || word_count > 7 || word_count < 1 { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "description should be 1-5 words, max 50 characters".to_string(), + ), + }); + } + + // 3. 验证上下文 + if context.session_id.is_none() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("task tool requires session_id in context".to_string()), + }); + } + if context.chat_id.is_none() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("task tool requires chat_id in context".to_string()), + }); + } + if context.channel_name.is_none() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("task tool requires channel_name in context".to_string()), + }); + } + + // 4. 执行任务 + let result = if let Some(task_id) = task_args.task_id { + // 恢复现有任务 + self.runtime + .resume(&task_id, context, task_args.prompt) + .await + } else { + // 创建新任务 + self.runtime + .spawn( + context, + task_args.description, + task_args.prompt, + task_args.subagent_type, + ) + .await + }; + + // 5. 构建返回结果 + match result { + Ok(task_result) => Ok(ToolResult { + success: true, + output: serde_json::to_string(&task_result)?, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_task_tool_name_and_description() { + // 简单验证工具名称 + assert!(!TaskTool::name(&TaskTool::new(Arc::new(TaskRuntime::new_for_test()))).is_empty()); + } +} \ No newline at end of file diff --git a/src/tools/task/types.rs b/src/tools/task/types.rs new file mode 100644 index 0000000..82c9b9f --- /dev/null +++ b/src/tools/task/types.rs @@ -0,0 +1,169 @@ +use serde::{Deserialize, Serialize}; + +use crate::bus::ChatMessage; + +/// 子代理会话状态 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum TaskSessionState { + /// 正在执行 + Running, + /// 已完成 + Completed, + /// 已失败 + Failed, + /// 已超时 + Timeout, +} + +impl Default for TaskSessionState { + fn default() -> Self { + Self::Running + } +} + +/// 子代理类型 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SubagentType { + /// 通用型 - 处理复杂多步骤任务 + #[default] + General, + /// 探索型 - 只读搜索代理 + Explore, +} + +impl SubagentType { + pub fn as_str(&self) -> &'static str { + match self { + Self::General => "general", + Self::Explore => "explore", + } + } + + pub fn from_str(s: &str) -> Option { + match s { + "general" => Some(Self::General), + "explore" => Some(Self::Explore), + _ => None, + } + } +} + +/// 任务会话记录 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskSession { + /// 任务唯一 ID (UUID) + pub id: String, + /// 父会话 ID (用于关联) + pub parent_session_id: String, + /// 父 chat_id + pub parent_chat_id: String, + /// 父 channel_name + pub parent_channel_name: String, + /// 任务描述 + pub description: String, + /// 子代理类型 + pub subagent_type: SubagentType, + /// 当前状态 + pub state: TaskSessionState, + /// 会话历史(子代理的对话) + #[serde(default)] + pub history: Vec, + /// 创建时间 + pub created_at: i64, + /// 最后更新时间 + pub updated_at: i64, + /// 执行摘要 + pub summary: Option, + /// 错误信息 + pub error: Option, +} + +impl TaskSession { + pub fn new( + parent_session_id: String, + parent_chat_id: String, + parent_channel_name: String, + description: String, + subagent_type: SubagentType, + ) -> Self { + let now = current_timestamp(); + Self { + id: format!("task:{}", uuid::Uuid::new_v4()), + parent_session_id, + parent_chat_id, + parent_channel_name, + description, + subagent_type, + state: TaskSessionState::Running, + history: Vec::new(), + created_at: now, + updated_at: now, + summary: None, + error: None, + } + } + + /// 添加消息到历史 + pub fn add_message(&mut self, message: ChatMessage) { + self.history.push(message); + self.updated_at = current_timestamp(); + } + + /// 标记完成 + pub fn mark_completed(&mut self, summary: String) { + self.state = TaskSessionState::Completed; + self.summary = Some(summary); + self.updated_at = current_timestamp(); + } + + /// 标记失败 + pub fn mark_failed(&mut self, error: String) { + self.state = TaskSessionState::Failed; + self.error = Some(error); + self.updated_at = current_timestamp(); + } + + /// 标记超时 + pub fn mark_timeout(&mut self) { + self.state = TaskSessionState::Timeout; + self.error = Some("Task execution timed out".to_string()); + self.updated_at = current_timestamp(); + } +} + +/// 任务工具参数 +#[derive(Debug, Clone, Deserialize)] +pub struct TaskToolArgs { + /// 简短描述(3-5词) + pub description: String, + /// 详细指令 + pub prompt: String, + /// 子代理类型 + #[serde(default)] + pub subagent_type: SubagentType, + /// 恢复现有会话的 task_id + #[serde(default)] + pub task_id: Option, +} + +/// 任务执行结果 +#[derive(Debug, Clone, Serialize)] +pub struct TaskToolResult { + /// 状态: success/failed/timeout + pub status: String, + /// 任务完成总结 + pub summary: String, + /// 详细输出 + pub output: String, + /// 会话 ID(用于恢复) + pub task_id: String, +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_millis() as i64 +} \ No newline at end of file diff --git a/src/tools/traits.rs b/src/tools/traits.rs index b00ca14..ce36f27 100644 --- a/src/tools/traits.rs +++ b/src/tools/traits.rs @@ -15,6 +15,8 @@ pub struct ToolContext { pub session_id: Option, pub message_id: Option, pub message_seq: Option, + /// 子代理标识,用于标注消息来源 + pub subagent_description: Option, } #[async_trait]