diff --git a/src/config/mod.rs b/src/config/mod.rs index 2d71c96..0bff93c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -104,67 +104,11 @@ impl Default for SkillsConfig { pub struct ToolsConfig { #[serde(default)] pub disabled: Vec, - #[serde(default)] - pub task: TaskConfig, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct TaskConfig { - #[serde(default = "default_task_enabled")] - pub enabled: bool, - #[serde(default = "default_task_max_execution_secs")] - pub max_execution_secs: u64, - #[serde(default = "default_task_ttl_hours")] - pub ttl_hours: u64, - #[serde(default = "default_task_allowed_tools")] - pub allowed_tools: Vec, -} - -fn default_task_enabled() -> bool { - true -} - -fn default_task_max_execution_secs() -> u64 { - 300 -} - -fn default_task_ttl_hours() -> u64 { - 24 -} - -fn default_task_allowed_tools() -> Vec { - vec![ - "file_read".to_string(), - "file_edit".to_string(), - "file_write".to_string(), - "bash".to_string(), - "http_request".to_string(), - "web_fetch".to_string(), - "memory_search".to_string(), - "get_time".to_string(), - "calculator".to_string(), - "skill_activate".to_string(), - "skill_list".to_string(), - ] } impl Default for ToolsConfig { fn default() -> Self { - Self { - disabled: Vec::new(), - task: TaskConfig::default(), - } - } -} - -impl Default for TaskConfig { - fn default() -> Self { - Self { - enabled: default_task_enabled(), - max_execution_secs: default_task_max_execution_secs(), - ttl_hours: default_task_ttl_hours(), - allowed_tools: default_task_allowed_tools(), - } + Self { disabled: Vec::new() } } } diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index 2eaa1eb..6664af1 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -9,7 +9,7 @@ use crate::storage::{ ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository, SessionStore, SkillEventRepository, }; -use crate::tools::{InMemoryTaskRepository, NoopSessionMessageSender, SessionMessageSender, ToolRegistry}; +use crate::tools::{NoopSessionMessageSender, SessionMessageSender, ToolRegistry}; use super::agent_factory::AgentFactory; use super::cli_session::CliSessionService; @@ -74,10 +74,6 @@ pub(crate) fn build_session_manager_with_sender( let memories: Arc = store.clone(); let scheduler_jobs: Arc = store.clone(); let skill_events: Arc = store.clone(); - - // TaskRepository 使用内存存储(后续可以改为 SQLite) - let task_repository = Arc::new(InMemoryTaskRepository::new()); - let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), @@ -89,7 +85,6 @@ pub(crate) fn build_session_manager_with_sender( default_timezone, disabled_tools, ) - .with_task_deps(task_repository, provider_config.clone()) .build(), ); diff --git a/src/gateway/tool_registry_factory.rs b/src/gateway/tool_registry_factory.rs index 6674f6d..8bbb047 100644 --- a/src/gateway/tool_registry_factory.rs +++ b/src/gateway/tool_registry_factory.rs @@ -1,14 +1,13 @@ use std::collections::HashSet; use std::sync::Arc; -use crate::config::LLMProviderConfig; use crate::skills::SkillRuntime; use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, MemoryManageTool, MemorySearchTool, SchedulerManageTool, SessionMessageSender, - SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TaskRuntime, - TaskTool, TimeTool, ToolRegistry, WebFetchTool, + SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TimeTool, ToolRegistry, + WebFetchTool, }; pub(crate) struct ToolRegistryFactory { @@ -20,9 +19,6 @@ pub(crate) struct ToolRegistryFactory { known_agents: HashSet, default_timezone: String, disabled_tools: HashSet, - // Task 工具需要的依赖 - task_repository: Option>, - provider_config: Option, } impl ToolRegistryFactory { @@ -45,22 +41,9 @@ impl ToolRegistryFactory { known_agents, default_timezone, disabled_tools, - task_repository: None, - provider_config: None, } } - /// 设置 Task 工具的依赖 - pub(crate) fn with_task_deps( - mut self, - task_repository: Arc, - provider_config: LLMProviderConfig, - ) -> Self { - self.task_repository = Some(task_repository); - self.provider_config = Some(provider_config); - self - } - fn is_enabled(&self, tool_name: &str) -> bool { !self.disabled_tools.contains(tool_name) } @@ -125,89 +108,6 @@ impl ToolRegistryFactory { registry.register(WebFetchTool::new(50_000, 30)); } - // 注册 Task 工具(需要 TaskRuntime) - if self.is_enabled("task") { - if let (Some(task_repository), Some(provider_config)) = - (&self.task_repository, &self.provider_config) - { - // 先创建一个临时的 ToolRegistry(不含 task 工具)用于子代理 - let subagent_tools = Arc::new(self.build_without_task()); - let task_runtime = Arc::new(TaskRuntime::new( - crate::tools::TaskRuntimeConfig::default(), - task_repository.clone(), - subagent_tools, - provider_config.clone(), - )); - registry.register(TaskTool::new(task_runtime)); - } - } - - registry - } - - /// 构建不含 task 工具的注册表(用于子代理) - fn build_without_task(&self) -> ToolRegistry { - let mut registry = ToolRegistry::new(); - - if self.is_enabled("calculator") { - registry.register(CalculatorTool::new()); - } - if self.is_enabled("get_time") { - registry.register(TimeTool::new(self.default_timezone.clone())); - } - if self.is_enabled("file_read") { - registry.register(FileReadTool::new()); - } - if self.is_enabled("file_write") { - registry.register(FileWriteTool::new()); - } - if self.is_enabled("file_edit") { - registry.register(FileEditTool::new()); - } - if self.is_enabled("memory_search") { - registry.register(MemorySearchTool::new(self.memories.clone())); - } - if self.is_enabled("memory_manage") { - registry.register(MemoryManageTool::new(self.memories.clone())); - } - if self.is_enabled("session_send") { - registry.register(SessionSendTool::new(self.session_message_sender.clone())); - } - if self.is_enabled("scheduler_manage") { - registry.register(SchedulerManageTool::new( - self.scheduler_jobs.clone(), - self.known_agents.clone(), - )); - } - if self.is_enabled("skill_activate") { - registry.register(SkillActivateTool::new( - self.skills.clone(), - self.skill_events.clone(), - )); - } - if self.is_enabled("skill_list") { - registry.register(SkillListTool::new(self.skills.clone())); - } - if self.is_enabled("skill_manage") { - registry.register(SkillManageTool::new(self.skills.clone())); - } - if self.is_enabled("bash") { - registry.register(BashTool::new()); - } - if self.is_enabled("http_request") { - registry.register(HttpRequestTool::new( - vec!["*".to_string()], - 1_000_000, - 30, - false, - )); - } - if self.is_enabled("web_fetch") { - registry.register(WebFetchTool::new(50_000, 30)); - } - - // 注意:不注册 task 工具,防止子代理递归创建子代理 - registry } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 8b6f060..e78a313 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -12,7 +12,6 @@ pub mod session_send; pub mod schema; pub mod skill_activate; pub mod skill_manage; -pub mod task; pub mod time; pub mod traits; pub mod web_fetch; @@ -34,9 +33,6 @@ pub use session_send::{ pub use schema::{CleaningStrategy, SchemaCleanr}; pub use skill_activate::SkillActivateTool; pub use skill_manage::{SkillListTool, SkillManageTool}; -pub use task::TaskTool; pub use time::TimeTool; pub use traits::{Tool, ToolContext, ToolResult}; pub use web_fetch::WebFetchTool; - -pub(crate) use task::{TaskRepository, TaskRuntime, TaskRuntimeConfig, InMemoryTaskRepository}; diff --git a/src/tools/task/error.rs b/src/tools/task/error.rs deleted file mode 100644 index fc5bcf9..0000000 --- a/src/tools/task/error.rs +++ /dev/null @@ -1,48 +0,0 @@ -use crate::storage::StorageError; - -/// 任务错误类型 -#[derive(Debug, thiserror::Error)] -pub enum TaskError { - #[error("Task session not found: {0}")] - SessionNotFound(String), - - #[error("Invalid parent session")] - InvalidParentSession, - - #[error("Failed to create subagent: {0}")] - AgentCreationFailed(String), - - #[error("Execution failed: {0}")] - ExecutionFailed(String), - - #[error("Task execution timed out")] - Timeout, - - #[error("Repository error: {0}")] - RepositoryError(#[from] StorageError), - - #[error("Serialization error: {0}")] - SerializationError(#[from] serde_json::Error), - - #[error("Missing required context: {0}")] - MissingContext(String), - - #[error("Invalid arguments: {0}")] - InvalidArguments(String), -} - -impl TaskError { - pub fn as_status(&self) -> &'static str { - match self { - Self::Timeout => "timeout", - Self::SessionNotFound(_) => "failed", - Self::InvalidParentSession => "failed", - Self::AgentCreationFailed(_) => "failed", - Self::ExecutionFailed(_) => "failed", - Self::RepositoryError(_) => "failed", - Self::SerializationError(_) => "failed", - Self::MissingContext(_) => "failed", - Self::InvalidArguments(_) => "failed", - } - } -} \ No newline at end of file diff --git a/src/tools/task/mod.rs b/src/tools/task/mod.rs deleted file mode 100644 index 0309883..0000000 --- a/src/tools/task/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod error; -mod prompt; -mod repository; -mod runtime; -mod tool; -mod types; - -pub(crate) use repository::TaskRepository; -pub(crate) use repository::InMemoryTaskRepository; -pub(crate) use runtime::{TaskRuntime, TaskRuntimeConfig}; -pub use tool::TaskTool; \ No newline at end of file diff --git a/src/tools/task/prompt.rs b/src/tools/task/prompt.rs deleted file mode 100644 index 65b208d..0000000 --- a/src/tools/task/prompt.rs +++ /dev/null @@ -1,125 +0,0 @@ -use crate::bus::ChatMessage; - -use super::types::SubagentType; - -/// 子代理系统提示词构建器 -pub struct SubagentPromptBuilder; - -impl SubagentPromptBuilder { - /// 构建子代理系统提示词 - pub fn build( - subagent_type: SubagentType, - description: &str, - _prompt: &str, - ) -> String { - match subagent_type { - SubagentType::General => Self::build_general_prompt(description), - SubagentType::Explore => Self::build_explore_prompt(description), - } - } - - /// 构建恢复任务的提示词 - pub fn build_resume_prompt(session_description: &str, additional_prompt: &str) -> String { - format!( - "你正在继续执行一个之前创建的子代理任务。\n\n\ - 任务描述: {}\n\n\ - 继续执行指令: {}\n\n\ - 你应该:\n\ - 1. 回顾之前的工作进度(如果已有历史)\n\ - 2. 继续完成任务,不要偏离目标\n\ - 3. 完成后给出简洁的总结\n\ - 4. 不要尝试创建新的子代理任务\n\n\ - 注意: 你在一个独立的执行上下文中,没有访问主对话历史的权限。", - session_description, additional_prompt - ) - } - - /// 构建带上下文摘要的探索提示词 - /// 预留功能:用于 Explore 类型子代理继承主代理上下文 - #[allow(dead_code)] - pub fn build_explore_prompt_with_context( - description: &str, - parent_history: &[ChatMessage], - ) -> String { - let context_summary = Self::extract_context_summary(parent_history); - format!( - "你是一个只读探索代理,用于代码库探索和信息收集。\n\n\ - 任务: {}\n\n\ - 主对话上下文摘要:\n{}\n\n\ - 你应该:\n\ - 1. 只使用只读工具进行探索(file_read, bash 的只读命令)\n\ - 2. 专注于理解和收集信息\n\ - 3. 不要进行任何写操作\n\ - 4. 给出简洁的发现总结\n\n\ - 注意: 你是一个只读代理,禁止执行任何修改操作。", - description, context_summary - ) - } - - fn build_general_prompt(description: &str) -> String { - format!( - "你是一个专注的子代理,正在执行一个独立任务。\n\n\ - 任务描述: {}\n\n\ - 你应该:\n\ - 1. 专注于完成任务,不要偏离目标\n\ - 2. 使用可用的工具进行必要操作\n\ - 3. 完成后给出简洁的总结\n\ - 4. 不要尝试创建新的子代理任务\n\n\ - 注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。", - description - ) - } - - fn build_explore_prompt(description: &str) -> String { - format!( - "你是一个只读探索代理,用于代码库探索和信息收集。\n\n\ - 任务描述: {}\n\n\ - 你应该:\n\ - 1. 只使用只读工具进行探索\n\ - 2. 专注于理解和收集信息\n\ - 3. 不要进行任何写操作\n\ - 4. 给出简洁的发现总结\n\n\ - 注意: 你是一个只读代理,禁止执行任何修改操作。", - description - ) - } - - /// 提取最近用户消息的上下文摘要 - /// 预留功能:用于 Explore 类型子代理继承主代理上下文 - #[allow(dead_code)] - fn extract_context_summary(history: &[ChatMessage]) -> String { - history - .iter() - .filter(|m| m.role == "user") - .rev() - .take(5) - .map(|m| { - let content = &m.content; - if content.len() > 100 { - format!("- {}", content.chars().take(100).collect::()) - } else { - format!("- {}", content) - } - }) - .collect::>() - .join("\n") - } -} - -/// 从子代理输出提取简洁摘要 -pub fn extract_summary(content: &str) -> String { - // 取第一段或前 500 字符 - let first_paragraph = content - .lines() - .take_while(|line| !line.trim().is_empty()) - .collect::>() - .join("\n"); - - if first_paragraph.len() > 500 { - first_paragraph.chars().take(500).collect() - } else if first_paragraph.is_empty() { - content.chars().take(200).collect() - } else { - first_paragraph - } -} \ No newline at end of file diff --git a/src/tools/task/repository.rs b/src/tools/task/repository.rs deleted file mode 100644 index e59a5a5..0000000 --- a/src/tools/task/repository.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::collections::HashMap; -use std::sync::RwLock; - -use async_trait::async_trait; - -use crate::storage::StorageError; - -use super::types::TaskSession; - -/// 任务持久化接口 -#[async_trait] -pub trait TaskRepository: Send + Sync + 'static { - /// 保存任务会话 - async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError>; - - /// 加载任务会话 - async fn load_task_session(&self, task_id: &str) -> Result, StorageError>; - - /// 删除任务会话 - async fn delete_task_session(&self, task_id: &str) -> Result; - - /// 列出父会话的所有任务 - async fn list_tasks_for_session( - &self, - parent_session_id: &str, - ) -> Result, StorageError>; - - /// 清理过期任务(超过指定小时) - async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result; -} - -/// 内存实现(用于测试) -pub struct InMemoryTaskRepository { - sessions: RwLock>, -} - -impl InMemoryTaskRepository { - pub fn new() -> Self { - Self { - sessions: RwLock::new(HashMap::new()), - } - } -} - -impl Default for InMemoryTaskRepository { - fn default() -> Self { - Self::new() - } -} - -#[async_trait] -impl TaskRepository for InMemoryTaskRepository { - async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError> { - self.sessions - .write() - .unwrap() - .insert(session.id.clone(), session.clone()); - Ok(()) - } - - async fn load_task_session(&self, task_id: &str) -> Result, StorageError> { - Ok(self.sessions.read().unwrap().get(task_id).cloned()) - } - - async fn delete_task_session(&self, task_id: &str) -> Result { - Ok(self.sessions.write().unwrap().remove(task_id).is_some()) - } - - async fn list_tasks_for_session( - &self, - parent_session_id: &str, - ) -> Result, StorageError> { - Ok(self - .sessions - .read() - .unwrap() - .values() - .filter(|s| s.parent_session_id == parent_session_id) - .cloned() - .collect()) - } - - async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result { - let now = current_timestamp(); - let ttl_millis = ttl_hours * 3600 * 1000; - let mut sessions = self.sessions.write().unwrap(); - let before = sessions.len(); - sessions.retain(|_, s| now - s.updated_at < ttl_millis as i64); - Ok(before - sessions.len()) - } -} - -fn current_timestamp() -> i64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("system clock before unix epoch") - .as_millis() as i64 -} \ No newline at end of file diff --git a/src/tools/task/runtime.rs b/src/tools/task/runtime.rs deleted file mode 100644 index b075cca..0000000 --- a/src/tools/task/runtime.rs +++ /dev/null @@ -1,370 +0,0 @@ -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Duration; - -use crate::agent::{AgentLoop, SystemPromptContext, SystemPromptProvider}; -use crate::bus::ChatMessage; -use crate::config::LLMProviderConfig; -use crate::tools::{ToolContext, ToolRegistry}; - -use super::error::TaskError; -use super::prompt::{extract_summary, SubagentPromptBuilder}; -use super::repository::TaskRepository; -use super::types::{SubagentType, TaskSession, TaskToolResult}; - -/// 子代理运行时配置 -#[derive(Debug, Clone)] -pub struct TaskRuntimeConfig { - /// 子代理可用的工具列表(白名单) - pub allowed_tools: HashSet, - /// 最大执行时间(秒) - pub max_execution_secs: u64, - /// 探索类型的最大工具调用次数 - pub explore_max_tool_calls: usize, - /// 任务 TTL(小时) - pub ttl_hours: u64, -} - -impl Default for TaskRuntimeConfig { - fn default() -> Self { - Self { - allowed_tools: HashSet::from([ - "file_read".to_string(), - "file_edit".to_string(), - "file_write".to_string(), - "bash".to_string(), - "http_request".to_string(), - "web_fetch".to_string(), - "memory_search".to_string(), - "get_time".to_string(), - "calculator".to_string(), - "skill_activate".to_string(), - "skill_list".to_string(), - ]), - max_execution_secs: 300, // 5分钟 - explore_max_tool_calls: 20, - ttl_hours: 24, - } - } -} - -/// 静态系统提示词提供者 -pub struct StaticSystemPromptProvider { - prompt: String, -} - -impl StaticSystemPromptProvider { - pub fn new(prompt: String) -> Self { - Self { prompt } - } -} - -impl SystemPromptProvider for StaticSystemPromptProvider { - fn build(&self, _context: &SystemPromptContext) -> Option { - Some(crate::agent::SystemPrompt { - content: self.prompt.clone(), - context: Some("subagent".to_string()), - }) - } -} - -/// 子代理运行时管理器 -pub struct TaskRuntime { - config: TaskRuntimeConfig, - repository: Arc, - tools: Arc, - provider_config: LLMProviderConfig, -} - -impl TaskRuntime { - pub fn new( - config: TaskRuntimeConfig, - repository: Arc, - tools: Arc, - provider_config: LLMProviderConfig, - ) -> Self { - Self { - config, - repository, - tools, - provider_config, - } - } - - /// 获取子代理工具注册表(排除 task 工具防止递归) - fn get_subagent_tools(&self) -> Arc { - // 创建一个新的工具注册表,只包含允许的工具 - // 这里简化处理,直接使用传入的 tools(假设 task 工具不会注册进去) - // 实际实现中需要过滤 allowed_tools - self.tools.clone() - } - - /// 创建新的子代理会话并执行 - pub async fn spawn( - &self, - parent_context: &ToolContext, - description: String, - prompt: String, - subagent_type: SubagentType, - ) -> Result { - // 1. 验证上下文 - let session_id = parent_context - .session_id - .clone() - .ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?; - let chat_id = parent_context - .chat_id - .clone() - .ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?; - let channel_name = parent_context - .channel_name - .clone() - .ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?; - - // 2. 创建任务会话 - let session = TaskSession::new( - session_id, - chat_id, - channel_name, - description.clone(), - subagent_type, - ); - - // 3. 保存会话 - self.repository.save_task_session(&session).await?; - - // 4. 构建子代理系统提示词 - let system_prompt = SubagentPromptBuilder::build(subagent_type, &description, &prompt); - - // 5. 创建子代理 - let agent = self.create_subagent(&session, system_prompt)?; - - // 6. 执行任务 - let result = self.execute_task(agent, &session, prompt).await; - - // 7. 更新会话状态并保存 - match result { - Ok(tool_result) => { - let mut session = session; - session.mark_completed(tool_result.summary.clone()); - self.repository.save_task_session(&session).await?; - Ok(tool_result) - } - Err(e) => { - let mut session = session; - let status = e.as_status(); - if status == "timeout" { - session.mark_timeout(); - } else { - session.mark_failed(e.to_string()); - } - self.repository.save_task_session(&session).await?; - Err(e) - } - } - } - - /// 恢复现有任务会话 - pub async fn resume( - &self, - task_id: &str, - parent_context: &ToolContext, - additional_prompt: String, - ) -> Result { - // 1. 加载现有会话 - let session = self - .repository - .load_task_session(task_id) - .await? - .ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?; - - // 2. 验证父会话匹配 - let parent_session_id = parent_context - .session_id - .clone() - .ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?; - if session.parent_session_id != parent_session_id { - return Err(TaskError::InvalidParentSession); - } - - // 3. 构建恢复提示词 - let system_prompt = SubagentPromptBuilder::build_resume_prompt( - &session.description, - &additional_prompt, - ); - - // 4. 创建子代理 - let agent = self.create_subagent(&session, system_prompt)?; - - // 5. 使用历史继续执行 - let result = self - .execute_task_with_history(agent, &session, additional_prompt) - .await; - - // 6. 更新会话状态 - match result { - Ok(tool_result) => { - let mut session = session; - session.mark_completed(tool_result.summary.clone()); - self.repository.save_task_session(&session).await?; - Ok(tool_result) - } - Err(e) => { - let mut session = session; - session.mark_failed(e.to_string()); - self.repository.save_task_session(&session).await?; - Err(e) - } - } - } - - /// 创建子代理实例 - fn create_subagent( - &self, - session: &TaskSession, - system_prompt: String, - ) -> Result { - let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt)); - - // 获取子代理工具注册表 - let subagent_tools = self.get_subagent_tools(); - - // 直接创建 AgentLoop,使用自定义的提示词提供者 - AgentLoop::with_tools_and_system_prompt_provider( - self.provider_config.clone(), - subagent_tools, - prompt_provider, - None, // 子代理不需要 skill provider - ) - .map(|agent| { - agent.with_tool_context(ToolContext { - channel_name: Some(session.parent_channel_name.clone()), - sender_id: None, - chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id,发送到飞书用户 - session_id: Some(session.id.clone()), - message_id: None, - message_seq: None, - subagent_description: Some(session.description.clone()), // 子代理标识 - }) - }) - .map_err(|e| TaskError::AgentCreationFailed(e.to_string())) - } - - /// 执行任务(带超时控制) - async fn execute_task( - &self, - agent: AgentLoop, - session: &TaskSession, - prompt: String, - ) -> Result { - // 构建初始消息 - let history = vec![ChatMessage::user(prompt)]; - let system_prompt_context = SystemPromptContext { - session_id: Some(session.id.clone()), - chat_id: session.id.clone(), - user_message_count: 1, - }; - - // 设置超时 - let timeout_duration = Duration::from_secs(self.config.max_execution_secs); - - let result = tokio::time::timeout( - timeout_duration, - agent.process(history, Some(&system_prompt_context)), - ) - .await; - - match result { - Ok(Ok(process_result)) => { - let final_message = process_result.final_response; - Ok(TaskToolResult { - status: "success".to_string(), - summary: extract_summary(&final_message.content), - output: final_message.content, - task_id: session.id.clone(), - }) - } - Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())), - Err(_) => Err(TaskError::Timeout), - } - } - - /// 使用历史继续执行 - async fn execute_task_with_history( - &self, - agent: AgentLoop, - session: &TaskSession, - additional_prompt: String, - ) -> Result { - // 构建历史 + 新消息 - let mut history = session.history.clone(); - history.push(ChatMessage::user(additional_prompt)); - - let user_message_count = history.iter().filter(|m| m.role == "user").count(); - let system_prompt_context = SystemPromptContext { - session_id: Some(session.id.clone()), - chat_id: session.id.clone(), - user_message_count, - }; - - let timeout_duration = Duration::from_secs(self.config.max_execution_secs); - - let result = tokio::time::timeout( - timeout_duration, - agent.process(history, Some(&system_prompt_context)), - ) - .await; - - match result { - Ok(Ok(process_result)) => { - let final_message = process_result.final_response; - Ok(TaskToolResult { - status: "success".to_string(), - summary: extract_summary(&final_message.content), - output: final_message.content, - task_id: session.id.clone(), - }) - } - Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())), - Err(_) => Err(TaskError::Timeout), - } - } - - /// 清理过期任务 - pub async fn cleanup_expired(&self) -> Result { - self.repository - .cleanup_expired_tasks(self.config.ttl_hours) - .await - .map_err(TaskError::from) - } - - /// 创建用于测试的实例 - #[cfg(test)] - pub fn new_for_test() -> Self { - use super::repository::InMemoryTaskRepository; - use std::collections::HashMap; - - Self { - config: TaskRuntimeConfig::default(), - repository: Arc::new(InMemoryTaskRepository::new()), - tools: Arc::new(ToolRegistry::new()), - provider_config: LLMProviderConfig { - provider_type: "openai".to_string(), - name: "test".to_string(), - base_url: "https://test.local/v1".to_string(), - api_key: "test-key".to_string(), - extra_headers: HashMap::new(), - llm_timeout_secs: 120, - memory_maintenance_timeout_secs: 600, - model_id: "test-model".to_string(), - temperature: None, - max_tokens: None, - context_window_tokens: None, - model_extra: HashMap::new(), - max_tool_iterations: 100, - tool_result_max_chars: 20_000, - context_tool_result_trim_chars: 2_000, - }, - } - } -} \ No newline at end of file diff --git a/src/tools/task/tool.rs b/src/tools/task/tool.rs deleted file mode 100644 index 3739006..0000000 --- a/src/tools/task/tool.rs +++ /dev/null @@ -1,170 +0,0 @@ -use async_trait::async_trait; -use serde_json::json; -use std::sync::Arc; - -use crate::tools::{Tool, ToolContext, ToolResult}; - -use super::runtime::TaskRuntime; -use super::types::TaskToolArgs; - -/// Task 工具 - 创建和管理子代理 -pub struct TaskTool { - runtime: Arc, -} - -impl TaskTool { - pub fn new(runtime: Arc) -> Self { - Self { runtime } - } -} - -#[async_trait] -impl Tool for TaskTool { - fn name(&self) -> &str { - "task" - } - - fn description(&self) -> &str { - "Launch a specialized subagent to handle complex, multi-step tasks. \ - Subagents run in isolated contexts and can work in parallel. \ - Use 'general' type for complex tasks, 'explore' type for read-only exploration. \ - You can resume a previous task by providing its task_id." - } - - fn parameters_schema(&self) -> serde_json::Value { - json!({ - "type": "object", - "properties": { - "description": { - "type": "string", - "description": "Short description (3-5 words) of what this task does", - "maxLength": 50 - }, - "prompt": { - "type": "string", - "description": "Detailed instructions for the subagent to execute" - }, - "subagent_type": { - "type": "string", - "enum": ["general", "explore"], - "default": "general", - "description": "Type of subagent: 'general' for complex multi-step tasks, 'explore' for read-only search/exploration" - }, - "task_id": { - "type": "string", - "description": "Optional: Resume an existing task session by providing its task_id" - } - }, - "required": ["description", "prompt"] - }) - } - - fn read_only(&self) -> bool { - false - } - - fn exclusive(&self) -> bool { - // Task 工具创建子代理,不应与其他工具并发执行 - true - } - - async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { - // Task 工具必须通过 execute_with_context 获取父会话信息 - Ok(ToolResult { - success: false, - output: String::new(), - error: Some( - "task tool requires tool context with session_id, chat_id, and channel_name" - .to_string(), - ), - }) - } - - async fn execute_with_context( - &self, - context: &ToolContext, - args: serde_json::Value, - ) -> anyhow::Result { - // 1. 解析参数 - let task_args: TaskToolArgs = serde_json::from_value(args.clone()) - .map_err(|e| anyhow::anyhow!("invalid task arguments: {}", e))?; - - // 2. 验证描述长度 - let word_count = task_args.description.split_whitespace().count(); - if task_args.description.len() > 50 || word_count > 7 || word_count < 1 { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some( - "description should be 1-5 words, max 50 characters".to_string(), - ), - }); - } - - // 3. 验证上下文 - if context.session_id.is_none() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some("task tool requires session_id in context".to_string()), - }); - } - if context.chat_id.is_none() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some("task tool requires chat_id in context".to_string()), - }); - } - if context.channel_name.is_none() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some("task tool requires channel_name in context".to_string()), - }); - } - - // 4. 执行任务 - let result = if let Some(task_id) = task_args.task_id { - // 恢复现有任务 - self.runtime - .resume(&task_id, context, task_args.prompt) - .await - } else { - // 创建新任务 - self.runtime - .spawn( - context, - task_args.description, - task_args.prompt, - task_args.subagent_type, - ) - .await - }; - - // 5. 构建返回结果 - match result { - Ok(task_result) => Ok(ToolResult { - success: true, - output: serde_json::to_string(&task_result)?, - error: None, - }), - Err(e) => Ok(ToolResult { - success: false, - output: String::new(), - error: Some(e.to_string()), - }), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_task_tool_name_and_description() { - // 简单验证工具名称 - assert!(!TaskTool::name(&TaskTool::new(Arc::new(TaskRuntime::new_for_test()))).is_empty()); - } -} \ No newline at end of file diff --git a/src/tools/task/types.rs b/src/tools/task/types.rs deleted file mode 100644 index 82c9b9f..0000000 --- a/src/tools/task/types.rs +++ /dev/null @@ -1,169 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::bus::ChatMessage; - -/// 子代理会话状态 -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum TaskSessionState { - /// 正在执行 - Running, - /// 已完成 - Completed, - /// 已失败 - Failed, - /// 已超时 - Timeout, -} - -impl Default for TaskSessionState { - fn default() -> Self { - Self::Running - } -} - -/// 子代理类型 -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum SubagentType { - /// 通用型 - 处理复杂多步骤任务 - #[default] - General, - /// 探索型 - 只读搜索代理 - Explore, -} - -impl SubagentType { - pub fn as_str(&self) -> &'static str { - match self { - Self::General => "general", - Self::Explore => "explore", - } - } - - pub fn from_str(s: &str) -> Option { - match s { - "general" => Some(Self::General), - "explore" => Some(Self::Explore), - _ => None, - } - } -} - -/// 任务会话记录 -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TaskSession { - /// 任务唯一 ID (UUID) - pub id: String, - /// 父会话 ID (用于关联) - pub parent_session_id: String, - /// 父 chat_id - pub parent_chat_id: String, - /// 父 channel_name - pub parent_channel_name: String, - /// 任务描述 - pub description: String, - /// 子代理类型 - pub subagent_type: SubagentType, - /// 当前状态 - pub state: TaskSessionState, - /// 会话历史(子代理的对话) - #[serde(default)] - pub history: Vec, - /// 创建时间 - pub created_at: i64, - /// 最后更新时间 - pub updated_at: i64, - /// 执行摘要 - pub summary: Option, - /// 错误信息 - pub error: Option, -} - -impl TaskSession { - pub fn new( - parent_session_id: String, - parent_chat_id: String, - parent_channel_name: String, - description: String, - subagent_type: SubagentType, - ) -> Self { - let now = current_timestamp(); - Self { - id: format!("task:{}", uuid::Uuid::new_v4()), - parent_session_id, - parent_chat_id, - parent_channel_name, - description, - subagent_type, - state: TaskSessionState::Running, - history: Vec::new(), - created_at: now, - updated_at: now, - summary: None, - error: None, - } - } - - /// 添加消息到历史 - pub fn add_message(&mut self, message: ChatMessage) { - self.history.push(message); - self.updated_at = current_timestamp(); - } - - /// 标记完成 - pub fn mark_completed(&mut self, summary: String) { - self.state = TaskSessionState::Completed; - self.summary = Some(summary); - self.updated_at = current_timestamp(); - } - - /// 标记失败 - pub fn mark_failed(&mut self, error: String) { - self.state = TaskSessionState::Failed; - self.error = Some(error); - self.updated_at = current_timestamp(); - } - - /// 标记超时 - pub fn mark_timeout(&mut self) { - self.state = TaskSessionState::Timeout; - self.error = Some("Task execution timed out".to_string()); - self.updated_at = current_timestamp(); - } -} - -/// 任务工具参数 -#[derive(Debug, Clone, Deserialize)] -pub struct TaskToolArgs { - /// 简短描述(3-5词) - pub description: String, - /// 详细指令 - pub prompt: String, - /// 子代理类型 - #[serde(default)] - pub subagent_type: SubagentType, - /// 恢复现有会话的 task_id - #[serde(default)] - pub task_id: Option, -} - -/// 任务执行结果 -#[derive(Debug, Clone, Serialize)] -pub struct TaskToolResult { - /// 状态: success/failed/timeout - pub status: String, - /// 任务完成总结 - pub summary: String, - /// 详细输出 - pub output: String, - /// 会话 ID(用于恢复) - pub task_id: String, -} - -fn current_timestamp() -> i64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("system clock before unix epoch") - .as_millis() as i64 -} \ No newline at end of file diff --git a/src/tools/traits.rs b/src/tools/traits.rs index ce36f27..b00ca14 100644 --- a/src/tools/traits.rs +++ b/src/tools/traits.rs @@ -15,8 +15,6 @@ pub struct ToolContext { pub session_id: Option, pub message_id: Option, pub message_seq: Option, - /// 子代理标识,用于标注消息来源 - pub subagent_description: Option, } #[async_trait]