422 lines
14 KiB
Rust
422 lines
14 KiB
Rust
use std::collections::HashSet;
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
|
||
use async_trait::async_trait;
|
||
|
||
use crate::agent::{AgentLoop, AgentRuntimeConfig, SystemPrompt, SystemPromptContext, SystemPromptProvider};
|
||
use crate::bus::ChatMessage;
|
||
use crate::config::LLMProviderConfig;
|
||
use crate::storage::ConversationRepository;
|
||
use crate::tools::{ToolContext, ToolRegistry};
|
||
|
||
use super::error::TaskError;
|
||
use super::prompt::{extract_summary, SubagentPromptBuilder};
|
||
use super::repository::TaskRepository;
|
||
use super::types::{SubagentType, TaskDefinition, TaskSession, TaskToolResult};
|
||
|
||
/// Runtime configuration for sub-agents.
#[derive(Debug, Clone)]
pub struct SubAgentRuntimeConfig {
    /// Whitelist of tool names a sub-agent is allowed to use.
    pub allowed_tools: HashSet<String>,
    /// Maximum execution time in seconds for `General` tasks.
    pub max_execution_secs: u64,
    /// Maximum execution time in seconds for `Explore` tasks.
    pub explore_max_execution_secs: u64,
    /// Maximum number of tool calls for `Explore` tasks.
    pub explore_max_tool_calls: usize,
    /// Task time-to-live, in hours.
    pub ttl_hours: u64,
    /// Optional skills index (a pre-generated skill list string).
    pub skills_index: Option<String>,
}

impl Default for SubAgentRuntimeConfig {
    /// Returns the stock configuration: the built-in tool whitelist, a
    /// 20-minute limit for general tasks, a 10-minute limit and 20 tool calls
    /// for explore tasks, a 24-hour TTL and no skills index.
    fn default() -> Self {
        // Tool names enabled out of the box.
        const DEFAULT_TOOLS: [&str; 12] = [
            "read",
            "edit",
            "write",
            "bash",
            "http_request",
            "web_fetch",
            "memory_search",
            "get_time",
            "calculator",
            "skill_activate",
            "skill_list",
            "send_session_message", // used for progress notifications
        ];

        let allowed_tools = DEFAULT_TOOLS.iter().copied().map(String::from).collect();

        Self {
            allowed_tools,
            max_execution_secs: 20 * 60,
            explore_max_execution_secs: 10 * 60,
            explore_max_tool_calls: 20,
            ttl_hours: 24,
            skills_index: None,
        }
    }
}
|
||
|
||
/// 子代理运行时抽象接口
|
||
#[async_trait]
|
||
pub trait SubAgentRuntime: Send + Sync + 'static {
|
||
/// 创建并执行子代理任务
|
||
async fn spawn(
|
||
&self,
|
||
parent_context: &ToolContext,
|
||
task: TaskDefinition,
|
||
) -> Result<TaskToolResult, TaskError>;
|
||
|
||
/// 恢复现有任务
|
||
async fn resume(
|
||
&self,
|
||
task_id: &str,
|
||
parent_context: &ToolContext,
|
||
additional_prompt: String,
|
||
) -> Result<TaskToolResult, TaskError>;
|
||
|
||
/// 发送消息给子代理(支持中断或补充指令)
|
||
async fn send_message(&self, task_id: &str, message: String) -> Result<(), TaskError>;
|
||
|
||
/// 清理过期任务
|
||
async fn cleanup_expired(&self) -> Result<usize, TaskError>;
|
||
}
|
||
|
||
/// Static system prompt provider (used for sub-agents).
pub struct StaticSystemPromptProvider {
    /// The fixed prompt text held by this provider.
    prompt: String,
}

impl StaticSystemPromptProvider {
    /// Creates a provider that owns `prompt`.
    pub fn new(prompt: String) -> Self {
        StaticSystemPromptProvider { prompt }
    }
}
|
||
|
||
impl SystemPromptProvider for StaticSystemPromptProvider {
|
||
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
|
||
Some(SystemPrompt {
|
||
content: self.prompt.clone(),
|
||
context: Some("subagent".to_string()),
|
||
})
|
||
}
|
||
}
|
||
|
||
/// 默认子代理运行时实现
|
||
pub struct DefaultSubAgentRuntime {
|
||
config: SubAgentRuntimeConfig,
|
||
task_repository: Arc<dyn TaskRepository>,
|
||
conversation_repository: Arc<dyn ConversationRepository>,
|
||
subagent_tools: Arc<ToolRegistry>,
|
||
provider_config: LLMProviderConfig,
|
||
}
|
||
|
||
impl DefaultSubAgentRuntime {
|
||
pub fn new(
|
||
config: SubAgentRuntimeConfig,
|
||
task_repository: Arc<dyn TaskRepository>,
|
||
conversation_repository: Arc<dyn ConversationRepository>,
|
||
subagent_tools: Arc<ToolRegistry>,
|
||
provider_config: LLMProviderConfig,
|
||
) -> Self {
|
||
Self {
|
||
config,
|
||
task_repository,
|
||
conversation_repository,
|
||
subagent_tools,
|
||
provider_config,
|
||
}
|
||
}
|
||
|
||
/// 创建子代理实例
|
||
fn create_subagent(
|
||
&self,
|
||
session: &TaskSession,
|
||
system_prompt: String,
|
||
) -> Result<AgentLoop, TaskError> {
|
||
let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt));
|
||
|
||
AgentLoop::with_tools_and_system_prompt_provider(
|
||
AgentRuntimeConfig::from(self.provider_config.clone()),
|
||
self.subagent_tools.clone(),
|
||
prompt_provider,
|
||
None, // 子代理不需要 skill provider
|
||
)
|
||
.map(|agent| {
|
||
agent.with_tool_context(ToolContext {
|
||
channel_name: Some(session.parent_channel_name.clone()),
|
||
sender_id: None,
|
||
chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id
|
||
session_id: Some(session.session_id.clone()), // 子代理自己的 session_id
|
||
message_id: None,
|
||
message_seq: None,
|
||
subagent_description: Some(session.description.clone()),
|
||
})
|
||
})
|
||
.map_err(|e| TaskError::AgentCreationFailed(e.to_string()))
|
||
}
|
||
|
||
/// 执行任务(带超时控制)
|
||
async fn execute_task(
|
||
&self,
|
||
agent: AgentLoop,
|
||
session: &TaskSession,
|
||
prompt: String,
|
||
) -> Result<TaskToolResult, TaskError> {
|
||
// 构建初始消息
|
||
let history = vec![ChatMessage::user(prompt)];
|
||
let system_prompt_context = SystemPromptContext {
|
||
session_id: Some(session.session_id.clone()),
|
||
chat_id: session.session_id.clone(),
|
||
user_message_count: 1,
|
||
};
|
||
|
||
// 设置超时
|
||
let max_secs = if session.subagent_type == SubagentType::Explore {
|
||
self.config.explore_max_execution_secs
|
||
} else {
|
||
self.config.max_execution_secs
|
||
};
|
||
let timeout_duration = Duration::from_secs(max_secs);
|
||
|
||
let result = tokio::time::timeout(
|
||
timeout_duration,
|
||
agent.process(history, Some(&system_prompt_context)),
|
||
)
|
||
.await;
|
||
|
||
match result {
|
||
Ok(Ok(process_result)) => {
|
||
// 保存子智能体产生的所有消息到数据库
|
||
for message in &process_result.emitted_messages {
|
||
if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) {
|
||
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message");
|
||
}
|
||
}
|
||
|
||
let final_message = process_result.final_response;
|
||
Ok(TaskToolResult {
|
||
status: "success".to_string(),
|
||
summary: extract_summary(&final_message.content),
|
||
output: final_message.content,
|
||
task_id: session.id.clone(),
|
||
})
|
||
}
|
||
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
|
||
Err(_) => Err(TaskError::Timeout),
|
||
}
|
||
}
|
||
|
||
/// 使用历史继续执行
|
||
async fn execute_task_with_history(
|
||
&self,
|
||
agent: AgentLoop,
|
||
session: &TaskSession,
|
||
additional_prompt: String,
|
||
) -> Result<TaskToolResult, TaskError> {
|
||
// 加载历史 + 新消息
|
||
let mut history = self
|
||
.conversation_repository
|
||
.load_messages(&session.session_id)
|
||
.map_err(TaskError::RepositoryError)?;
|
||
history.push(ChatMessage::user(additional_prompt));
|
||
|
||
let user_message_count = history.iter().filter(|m| m.role == "user").count();
|
||
let system_prompt_context = SystemPromptContext {
|
||
session_id: Some(session.session_id.clone()),
|
||
chat_id: session.session_id.clone(),
|
||
user_message_count,
|
||
};
|
||
|
||
let timeout_duration = Duration::from_secs(self.config.max_execution_secs);
|
||
|
||
let result = tokio::time::timeout(
|
||
timeout_duration,
|
||
agent.process(history, Some(&system_prompt_context)),
|
||
)
|
||
.await;
|
||
|
||
match result {
|
||
Ok(Ok(process_result)) => {
|
||
// 保存子智能体产生的所有消息到数据库
|
||
for message in &process_result.emitted_messages {
|
||
if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) {
|
||
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message");
|
||
}
|
||
}
|
||
|
||
let final_message = process_result.final_response;
|
||
Ok(TaskToolResult {
|
||
status: "success".to_string(),
|
||
summary: extract_summary(&final_message.content),
|
||
output: final_message.content,
|
||
task_id: session.id.clone(),
|
||
})
|
||
}
|
||
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
|
||
Err(_) => Err(TaskError::Timeout),
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl SubAgentRuntime for DefaultSubAgentRuntime {
|
||
async fn spawn(
|
||
&self,
|
||
parent_context: &ToolContext,
|
||
task: TaskDefinition,
|
||
) -> Result<TaskToolResult, TaskError> {
|
||
// 1. 验证上下文
|
||
let session_id = parent_context
|
||
.session_id
|
||
.clone()
|
||
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
|
||
let chat_id = parent_context
|
||
.chat_id
|
||
.clone()
|
||
.ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?;
|
||
let channel_name = parent_context
|
||
.channel_name
|
||
.clone()
|
||
.ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?;
|
||
|
||
// 2. 创建任务会话
|
||
let session = TaskSession::new(
|
||
session_id,
|
||
chat_id,
|
||
channel_name,
|
||
task.description.clone(),
|
||
task.subagent_type,
|
||
);
|
||
|
||
// 3. 在 sessions 表中创建子智能体会话(确保外键约束满足)
|
||
let session_title = format!("Subagent: {}", task.description);
|
||
if let Err(e) = self.conversation_repository.ensure_session(
|
||
&session.session_id,
|
||
&session.parent_channel_name,
|
||
&session.parent_chat_id,
|
||
&session_title,
|
||
) {
|
||
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session");
|
||
}
|
||
|
||
// 4. 保存任务会话
|
||
self.task_repository.save_task_session(&session).await?;
|
||
|
||
// 4. 构建子代理系统提示词
|
||
let system_prompt = SubagentPromptBuilder::build(
|
||
task.subagent_type,
|
||
&task.description,
|
||
&task.prompt,
|
||
&self.provider_config,
|
||
self.config.skills_index.as_deref(),
|
||
);
|
||
|
||
// 5. 创建子代理
|
||
let agent = self.create_subagent(&session, system_prompt)?;
|
||
|
||
// 6. 执行任务
|
||
let result = self
|
||
.execute_task(agent, &session, task.prompt.clone())
|
||
.await;
|
||
|
||
// 7. 更新会话状态并保存
|
||
match result {
|
||
Ok(tool_result) => {
|
||
let mut session = session;
|
||
session.mark_completed(tool_result.summary.clone());
|
||
self.task_repository.save_task_session(&session).await?;
|
||
Ok(tool_result)
|
||
}
|
||
Err(e) => {
|
||
let mut session = session;
|
||
let status = e.as_status();
|
||
if status == "timeout" {
|
||
session.mark_timeout();
|
||
} else {
|
||
session.mark_failed(e.to_string());
|
||
}
|
||
self.task_repository.save_task_session(&session).await?;
|
||
Err(e)
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn resume(
|
||
&self,
|
||
task_id: &str,
|
||
parent_context: &ToolContext,
|
||
additional_prompt: String,
|
||
) -> Result<TaskToolResult, TaskError> {
|
||
// 1. 加载现有会话
|
||
let session = self
|
||
.task_repository
|
||
.load_task_session(task_id)
|
||
.await?
|
||
.ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?;
|
||
|
||
// 2. 验证父会话匹配
|
||
let parent_session_id = parent_context
|
||
.session_id
|
||
.clone()
|
||
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
|
||
if session.parent_session_id != parent_session_id {
|
||
return Err(TaskError::InvalidParentSession);
|
||
}
|
||
|
||
// 3. 确保 sessions 表中存在子智能体会话记录
|
||
let session_title = format!("Subagent: {}", session.description);
|
||
if let Err(e) = self.conversation_repository.ensure_session(
|
||
&session.session_id,
|
||
&session.parent_channel_name,
|
||
&session.parent_chat_id,
|
||
&session_title,
|
||
) {
|
||
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session on resume");
|
||
}
|
||
|
||
// 4. 构建恢复提示词
|
||
let system_prompt = SubagentPromptBuilder::build_resume_prompt(
|
||
&session.description,
|
||
&additional_prompt,
|
||
);
|
||
|
||
// 5. 创建子代理
|
||
let agent = self.create_subagent(&session, system_prompt)?;
|
||
|
||
// 6. 使用历史继续执行
|
||
let result = self
|
||
.execute_task_with_history(agent, &session, additional_prompt)
|
||
.await;
|
||
|
||
// 7. 更新会话状态
|
||
match result {
|
||
Ok(tool_result) => {
|
||
let mut session = session;
|
||
session.mark_completed(tool_result.summary.clone());
|
||
self.task_repository.save_task_session(&session).await?;
|
||
Ok(tool_result)
|
||
}
|
||
Err(e) => {
|
||
let mut session = session;
|
||
session.mark_failed(e.to_string());
|
||
self.task_repository.save_task_session(&session).await?;
|
||
Err(e)
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn send_message(&self, _task_id: &str, _message: String) -> Result<(), TaskError> {
|
||
// TODO: 实现双向通信
|
||
// 需要在 TaskSession 中添加 pending_messages 队列
|
||
Err(TaskError::InvalidArguments("send_message not implemented yet".to_string()))
|
||
}
|
||
|
||
async fn cleanup_expired(&self) -> Result<usize, TaskError> {
|
||
self.task_repository
|
||
.cleanup_expired_tasks(self.config.ttl_hours)
|
||
.await
|
||
.map_err(TaskError::from)
|
||
}
|
||
} |