use std::collections::HashSet; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; use crate::agent::{AgentLoop, AgentRuntimeConfig, SystemPrompt, SystemPromptContext, SystemPromptProvider}; use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; use crate::storage::ConversationRepository; use crate::tools::{ToolContext, ToolRegistry}; use super::error::TaskError; use super::prompt::{extract_summary, SubagentPromptBuilder}; use super::repository::TaskRepository; use super::types::{SubagentType, TaskDefinition, TaskSession, TaskToolResult}; /// 子代理运行时配置 #[derive(Debug, Clone)] pub struct SubAgentRuntimeConfig { /// 子代理可用的工具列表(白名单) pub allowed_tools: HashSet, /// 最大执行时间(秒) - General 类型 pub max_execution_secs: u64, /// Explore 类型的最大执行时间(秒) pub explore_max_execution_secs: u64, /// 探索类型的最大工具调用次数 pub explore_max_tool_calls: usize, /// 任务 TTL(小时) pub ttl_hours: u64, /// 技能索引(可选,预生成的技能列表字符串) pub skills_index: Option, } impl Default for SubAgentRuntimeConfig { fn default() -> Self { Self { allowed_tools: HashSet::from([ "read".to_string(), "edit".to_string(), "write".to_string(), "bash".to_string(), "http_request".to_string(), "web_fetch".to_string(), "memory_search".to_string(), "get_time".to_string(), "calculator".to_string(), "skill_activate".to_string(), "skill_list".to_string(), "send_session_message".to_string(), // 用于进度通知 ]), max_execution_secs: 1200, // 20分钟 explore_max_execution_secs: 600, // 10分钟 explore_max_tool_calls: 20, ttl_hours: 24, skills_index: None, } } } /// 子代理运行时抽象接口 #[async_trait] pub trait SubAgentRuntime: Send + Sync + 'static { /// 创建并执行子代理任务 async fn spawn( &self, parent_context: &ToolContext, task: TaskDefinition, ) -> Result; /// 恢复现有任务 async fn resume( &self, task_id: &str, parent_context: &ToolContext, additional_prompt: String, ) -> Result; /// 发送消息给子代理(支持中断或补充指令) async fn send_message(&self, task_id: &str, message: String) -> Result<(), TaskError>; /// 清理过期任务 async fn cleanup_expired(&self) -> Result; } /// 静态系统提示词提供者(用于子代理) pub struct StaticSystemPromptProvider { prompt: String, } impl StaticSystemPromptProvider { pub fn new(prompt: String) -> Self { Self { prompt } } } impl SystemPromptProvider for StaticSystemPromptProvider { fn build(&self, _context: &SystemPromptContext) -> Option { Some(SystemPrompt { content: self.prompt.clone(), context: Some("subagent".to_string()), }) } } /// 默认子代理运行时实现 pub struct DefaultSubAgentRuntime { config: SubAgentRuntimeConfig, task_repository: Arc, conversation_repository: Arc, subagent_tools: Arc, provider_config: LLMProviderConfig, } impl DefaultSubAgentRuntime { pub fn new( config: SubAgentRuntimeConfig, task_repository: Arc, conversation_repository: Arc, subagent_tools: Arc, provider_config: LLMProviderConfig, ) -> Self { Self { config, task_repository, conversation_repository, subagent_tools, provider_config, } } /// 创建子代理实例 fn create_subagent( &self, session: &TaskSession, system_prompt: String, ) -> Result { let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt)); AgentLoop::with_tools_and_system_prompt_provider( AgentRuntimeConfig::from(self.provider_config.clone()), self.subagent_tools.clone(), prompt_provider, None, // 子代理不需要 skill provider ) .map(|agent| { agent.with_tool_context(ToolContext { channel_name: Some(session.parent_channel_name.clone()), sender_id: None, chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id session_id: Some(session.session_id.clone()), // 子代理自己的 session_id message_id: None, message_seq: None, subagent_description: Some(session.description.clone()), }) }) .map_err(|e| TaskError::AgentCreationFailed(e.to_string())) } /// 执行任务(带超时控制) async fn execute_task( &self, agent: AgentLoop, session: &TaskSession, prompt: String, ) -> Result { // 构建初始消息 let history = vec![ChatMessage::user(prompt)]; let system_prompt_context = SystemPromptContext { session_id: Some(session.session_id.clone()), chat_id: session.session_id.clone(), user_message_count: 1, }; // 设置超时 let max_secs = if session.subagent_type == SubagentType::Explore { self.config.explore_max_execution_secs } else { self.config.max_execution_secs }; let timeout_duration = Duration::from_secs(max_secs); let result = tokio::time::timeout( timeout_duration, agent.process(history, Some(&system_prompt_context)), ) .await; match result { Ok(Ok(process_result)) => { // 保存子智能体产生的所有消息到数据库 for message in &process_result.emitted_messages { if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) { tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message"); } } let final_message = process_result.final_response; Ok(TaskToolResult { status: "success".to_string(), summary: extract_summary(&final_message.content), output: final_message.content, task_id: session.id.clone(), }) } Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())), Err(_) => Err(TaskError::Timeout), } } /// 使用历史继续执行 async fn execute_task_with_history( &self, agent: AgentLoop, session: &TaskSession, additional_prompt: String, ) -> Result { // 加载历史 + 新消息 let mut history = self .conversation_repository .load_messages(&session.session_id) .map_err(TaskError::RepositoryError)?; history.push(ChatMessage::user(additional_prompt)); let user_message_count = history.iter().filter(|m| m.role == "user").count(); let system_prompt_context = SystemPromptContext { session_id: Some(session.session_id.clone()), chat_id: session.session_id.clone(), user_message_count, }; let timeout_duration = Duration::from_secs(self.config.max_execution_secs); let result = tokio::time::timeout( timeout_duration, agent.process(history, Some(&system_prompt_context)), ) .await; match result { Ok(Ok(process_result)) => { // 保存子智能体产生的所有消息到数据库 for message in &process_result.emitted_messages { if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) { tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message"); } } let final_message = process_result.final_response; Ok(TaskToolResult { status: "success".to_string(), summary: extract_summary(&final_message.content), output: final_message.content, task_id: session.id.clone(), }) } Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())), Err(_) => Err(TaskError::Timeout), } } } #[async_trait] impl SubAgentRuntime for DefaultSubAgentRuntime { async fn spawn( &self, parent_context: &ToolContext, task: TaskDefinition, ) -> Result { // 1. 验证上下文 let session_id = parent_context .session_id .clone() .ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?; let chat_id = parent_context .chat_id .clone() .ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?; let channel_name = parent_context .channel_name .clone() .ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?; // 2. 创建任务会话 let session = TaskSession::new( session_id, chat_id, channel_name, task.description.clone(), task.subagent_type, ); // 3. 在 sessions 表中创建子智能体会话(确保外键约束满足) let session_title = format!("Subagent: {}", task.description); if let Err(e) = self.conversation_repository.ensure_session( &session.session_id, &session.parent_channel_name, &session.parent_chat_id, &session_title, ) { tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session"); } // 4. 保存任务会话 self.task_repository.save_task_session(&session).await?; // 4. 构建子代理系统提示词 let system_prompt = SubagentPromptBuilder::build( task.subagent_type, &task.description, &task.prompt, &self.provider_config, self.config.skills_index.as_deref(), ); // 5. 创建子代理 let agent = self.create_subagent(&session, system_prompt)?; // 6. 执行任务 let result = self .execute_task(agent, &session, task.prompt.clone()) .await; // 7. 更新会话状态并保存 match result { Ok(tool_result) => { let mut session = session; session.mark_completed(tool_result.summary.clone()); self.task_repository.save_task_session(&session).await?; Ok(tool_result) } Err(e) => { let mut session = session; let status = e.as_status(); if status == "timeout" { session.mark_timeout(); } else { session.mark_failed(e.to_string()); } self.task_repository.save_task_session(&session).await?; Err(e) } } } async fn resume( &self, task_id: &str, parent_context: &ToolContext, additional_prompt: String, ) -> Result { // 1. 加载现有会话 let session = self .task_repository .load_task_session(task_id) .await? .ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?; // 2. 验证父会话匹配 let parent_session_id = parent_context .session_id .clone() .ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?; if session.parent_session_id != parent_session_id { return Err(TaskError::InvalidParentSession); } // 3. 确保 sessions 表中存在子智能体会话记录 let session_title = format!("Subagent: {}", session.description); if let Err(e) = self.conversation_repository.ensure_session( &session.session_id, &session.parent_channel_name, &session.parent_chat_id, &session_title, ) { tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session on resume"); } // 4. 构建恢复提示词 let system_prompt = SubagentPromptBuilder::build_resume_prompt( &session.description, &additional_prompt, ); // 5. 创建子代理 let agent = self.create_subagent(&session, system_prompt)?; // 6. 使用历史继续执行 let result = self .execute_task_with_history(agent, &session, additional_prompt) .await; // 7. 更新会话状态 match result { Ok(tool_result) => { let mut session = session; session.mark_completed(tool_result.summary.clone()); self.task_repository.save_task_session(&session).await?; Ok(tool_result) } Err(e) => { let mut session = session; session.mark_failed(e.to_string()); self.task_repository.save_task_session(&session).await?; Err(e) } } } async fn send_message(&self, _task_id: &str, _message: String) -> Result<(), TaskError> { // TODO: 实现双向通信 // 需要在 TaskSession 中添加 pending_messages 队列 Err(TaskError::InvalidArguments("send_message not implemented yet".to_string())) } async fn cleanup_expired(&self) -> Result { self.task_repository .cleanup_expired_tasks(self.config.ttl_hours) .await .map_err(TaskError::from) } }