PicoBot/src/tools/task/runtime.rs

422 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use crate::agent::{AgentLoop, AgentRuntimeConfig, SystemPrompt, SystemPromptContext, SystemPromptProvider};
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::storage::ConversationRepository;
use crate::tools::{ToolContext, ToolRegistry};
use super::error::TaskError;
use super::prompt::{extract_summary, SubagentPromptBuilder};
use super::repository::TaskRepository;
use super::types::{SubagentType, TaskDefinition, TaskSession, TaskToolResult};
/// Runtime configuration for subagents.
#[derive(Debug, Clone)]
pub struct SubAgentRuntimeConfig {
    /// Names of the tools a subagent may use (a whitelist).
    /// NOTE(review): not read anywhere in this file; presumably applied when the
    /// subagent `ToolRegistry` is built — confirm.
    pub allowed_tools: HashSet<String>,
    /// Maximum execution time, in seconds, for General-type tasks.
    pub max_execution_secs: u64,
    /// Maximum execution time, in seconds, for Explore-type tasks.
    pub explore_max_execution_secs: u64,
    /// Maximum number of tool calls for Explore-type tasks.
    /// NOTE(review): not read anywhere in this file — confirm where it is enforced.
    pub explore_max_tool_calls: usize,
    /// Task time-to-live, in hours; passed to
    /// `TaskRepository::cleanup_expired_tasks` by `cleanup_expired`.
    pub ttl_hours: u64,
    /// Optional skills index (a pre-generated skill listing string) that is
    /// passed to the subagent system-prompt builder.
    pub skills_index: Option<String>,
}
impl Default for SubAgentRuntimeConfig {
    /// Returns the stock subagent configuration:
    ///
    /// - a fixed tool whitelist;
    /// - a 20-minute limit for general tasks;
    /// - a 10-minute / 20-tool-call budget for explore tasks;
    /// - a 24-hour task TTL;
    /// - no skills index.
    fn default() -> Self {
        // Tools a subagent may call. `send_session_message` is included so the
        // subagent can post progress notifications.
        const DEFAULT_TOOLS: [&str; 12] = [
            "read",
            "edit",
            "write",
            "bash",
            "http_request",
            "web_fetch",
            "memory_search",
            "get_time",
            "calculator",
            "skill_activate",
            "skill_list",
            "send_session_message",
        ];

        let allowed_tools: HashSet<String> = DEFAULT_TOOLS
            .iter()
            .map(|name| name.to_string())
            .collect();

        Self {
            allowed_tools,
            max_execution_secs: 20 * 60,
            explore_max_execution_secs: 10 * 60,
            explore_max_tool_calls: 20,
            ttl_hours: 24,
            skills_index: None,
        }
    }
}
/// Abstract interface for the runtime that creates and drives subagent tasks.
#[async_trait]
pub trait SubAgentRuntime: Send + Sync + 'static {
    /// Creates and executes a subagent task described by `task`, on behalf of
    /// the parent agent whose tool context is `parent_context`.
    async fn spawn(
        &self,
        parent_context: &ToolContext,
        task: TaskDefinition,
    ) -> Result<TaskToolResult, TaskError>;
    /// Resumes the existing task `task_id`, continuing it with
    /// `additional_prompt`.
    async fn resume(
        &self,
        task_id: &str,
        parent_context: &ToolContext,
        additional_prompt: String,
    ) -> Result<TaskToolResult, TaskError>;
    /// Sends a message to a subagent (intended for interrupts or supplementary
    /// instructions).
    async fn send_message(&self, task_id: &str, message: String) -> Result<(), TaskError>;
    /// Cleans up expired tasks. The returned `usize` is presumably the number
    /// of tasks removed — confirm against implementations.
    async fn cleanup_expired(&self) -> Result<usize, TaskError>;
}
/// A [`SystemPromptProvider`] that always serves the same, pre-built prompt.
///
/// Subagents have their whole system prompt assembled up front. This provider
/// therefore ignores the per-call context and returns a copy of that text,
/// tagged with the `"subagent"` context label.
pub struct StaticSystemPromptProvider {
    // Fixed prompt text handed out on every `build` call.
    prompt: String,
}

impl StaticSystemPromptProvider {
    /// Wraps `prompt` so it can be served as a system prompt.
    pub fn new(prompt: String) -> Self {
        StaticSystemPromptProvider { prompt }
    }
}

impl SystemPromptProvider for StaticSystemPromptProvider {
    /// Returns the stored prompt regardless of `_context`; never `None`.
    fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
        let label = String::from("subagent");
        let system_prompt = SystemPrompt {
            content: self.prompt.to_owned(),
            context: Some(label),
        };
        Some(system_prompt)
    }
}
/// Default implementation of [`SubAgentRuntime`].
pub struct DefaultSubAgentRuntime {
    /// Time limits, TTL and optional skills index used when running subagents.
    config: SubAgentRuntimeConfig,
    /// Stores and loads `TaskSession` records and cleans up expired ones.
    task_repository: Arc<dyn TaskRepository>,
    /// Holds the subagent's session row and conversation messages
    /// (`ensure_session`, `append_message`, `load_messages`).
    conversation_repository: Arc<dyn ConversationRepository>,
    /// Tool registry handed to every subagent `AgentLoop`.
    subagent_tools: Arc<ToolRegistry>,
    /// LLM provider configuration. It builds each agent's runtime config and
    /// is passed to the subagent prompt builder.
    provider_config: LLMProviderConfig,
}
impl DefaultSubAgentRuntime {
pub fn new(
config: SubAgentRuntimeConfig,
task_repository: Arc<dyn TaskRepository>,
conversation_repository: Arc<dyn ConversationRepository>,
subagent_tools: Arc<ToolRegistry>,
provider_config: LLMProviderConfig,
) -> Self {
Self {
config,
task_repository,
conversation_repository,
subagent_tools,
provider_config,
}
}
/// 创建子代理实例
fn create_subagent(
&self,
session: &TaskSession,
system_prompt: String,
) -> Result<AgentLoop, TaskError> {
let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt));
AgentLoop::with_tools_and_system_prompt_provider(
AgentRuntimeConfig::from(self.provider_config.clone()),
self.subagent_tools.clone(),
prompt_provider,
None, // 子代理不需要 skill provider
)
.map(|agent| {
agent.with_tool_context(ToolContext {
channel_name: Some(session.parent_channel_name.clone()),
sender_id: None,
chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id
session_id: Some(session.session_id.clone()), // 子代理自己的 session_id
message_id: None,
message_seq: None,
subagent_description: Some(session.description.clone()),
})
})
.map_err(|e| TaskError::AgentCreationFailed(e.to_string()))
}
/// 执行任务(带超时控制)
async fn execute_task(
&self,
agent: AgentLoop,
session: &TaskSession,
prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 构建初始消息
let history = vec![ChatMessage::user(prompt)];
let system_prompt_context = SystemPromptContext {
session_id: Some(session.session_id.clone()),
chat_id: session.session_id.clone(),
user_message_count: 1,
};
// 设置超时
let max_secs = if session.subagent_type == SubagentType::Explore {
self.config.explore_max_execution_secs
} else {
self.config.max_execution_secs
};
let timeout_duration = Duration::from_secs(max_secs);
let result = tokio::time::timeout(
timeout_duration,
agent.process(history, Some(&system_prompt_context)),
)
.await;
match result {
Ok(Ok(process_result)) => {
// 保存子智能体产生的所有消息到数据库
for message in &process_result.emitted_messages {
if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message");
}
}
let final_message = process_result.final_response;
Ok(TaskToolResult {
status: "success".to_string(),
summary: extract_summary(&final_message.content),
output: final_message.content,
task_id: session.id.clone(),
})
}
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
Err(_) => Err(TaskError::Timeout),
}
}
/// 使用历史继续执行
async fn execute_task_with_history(
&self,
agent: AgentLoop,
session: &TaskSession,
additional_prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 加载历史 + 新消息
let mut history = self
.conversation_repository
.load_messages(&session.session_id)
.map_err(TaskError::RepositoryError)?;
history.push(ChatMessage::user(additional_prompt));
let user_message_count = history.iter().filter(|m| m.role == "user").count();
let system_prompt_context = SystemPromptContext {
session_id: Some(session.session_id.clone()),
chat_id: session.session_id.clone(),
user_message_count,
};
let timeout_duration = Duration::from_secs(self.config.max_execution_secs);
let result = tokio::time::timeout(
timeout_duration,
agent.process(history, Some(&system_prompt_context)),
)
.await;
match result {
Ok(Ok(process_result)) => {
// 保存子智能体产生的所有消息到数据库
for message in &process_result.emitted_messages {
if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message");
}
}
let final_message = process_result.final_response;
Ok(TaskToolResult {
status: "success".to_string(),
summary: extract_summary(&final_message.content),
output: final_message.content,
task_id: session.id.clone(),
})
}
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
Err(_) => Err(TaskError::Timeout),
}
}
}
#[async_trait]
impl SubAgentRuntime for DefaultSubAgentRuntime {
async fn spawn(
&self,
parent_context: &ToolContext,
task: TaskDefinition,
) -> Result<TaskToolResult, TaskError> {
// 1. 验证上下文
let session_id = parent_context
.session_id
.clone()
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
let chat_id = parent_context
.chat_id
.clone()
.ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?;
let channel_name = parent_context
.channel_name
.clone()
.ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?;
// 2. 创建任务会话
let session = TaskSession::new(
session_id,
chat_id,
channel_name,
task.description.clone(),
task.subagent_type,
);
// 3. 在 sessions 表中创建子智能体会话(确保外键约束满足)
let session_title = format!("Subagent: {}", task.description);
if let Err(e) = self.conversation_repository.ensure_session(
&session.session_id,
&session.parent_channel_name,
&session.parent_chat_id,
&session_title,
) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session");
}
// 4. 保存任务会话
self.task_repository.save_task_session(&session).await?;
// 4. 构建子代理系统提示词
let system_prompt = SubagentPromptBuilder::build(
task.subagent_type,
&task.description,
&task.prompt,
&self.provider_config,
self.config.skills_index.as_deref(),
);
// 5. 创建子代理
let agent = self.create_subagent(&session, system_prompt)?;
// 6. 执行任务
let result = self
.execute_task(agent, &session, task.prompt.clone())
.await;
// 7. 更新会话状态并保存
match result {
Ok(tool_result) => {
let mut session = session;
session.mark_completed(tool_result.summary.clone());
self.task_repository.save_task_session(&session).await?;
Ok(tool_result)
}
Err(e) => {
let mut session = session;
let status = e.as_status();
if status == "timeout" {
session.mark_timeout();
} else {
session.mark_failed(e.to_string());
}
self.task_repository.save_task_session(&session).await?;
Err(e)
}
}
}
async fn resume(
&self,
task_id: &str,
parent_context: &ToolContext,
additional_prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 1. 加载现有会话
let session = self
.task_repository
.load_task_session(task_id)
.await?
.ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?;
// 2. 验证父会话匹配
let parent_session_id = parent_context
.session_id
.clone()
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
if session.parent_session_id != parent_session_id {
return Err(TaskError::InvalidParentSession);
}
// 3. 确保 sessions 表中存在子智能体会话记录
let session_title = format!("Subagent: {}", session.description);
if let Err(e) = self.conversation_repository.ensure_session(
&session.session_id,
&session.parent_channel_name,
&session.parent_chat_id,
&session_title,
) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session on resume");
}
// 4. 构建恢复提示词
let system_prompt = SubagentPromptBuilder::build_resume_prompt(
&session.description,
&additional_prompt,
);
// 5. 创建子代理
let agent = self.create_subagent(&session, system_prompt)?;
// 6. 使用历史继续执行
let result = self
.execute_task_with_history(agent, &session, additional_prompt)
.await;
// 7. 更新会话状态
match result {
Ok(tool_result) => {
let mut session = session;
session.mark_completed(tool_result.summary.clone());
self.task_repository.save_task_session(&session).await?;
Ok(tool_result)
}
Err(e) => {
let mut session = session;
session.mark_failed(e.to_string());
self.task_repository.save_task_session(&session).await?;
Err(e)
}
}
}
async fn send_message(&self, _task_id: &str, _message: String) -> Result<(), TaskError> {
// TODO: 实现双向通信
// 需要在 TaskSession 中添加 pending_messages 队列
Err(TaskError::InvalidArguments("send_message not implemented yet".to_string()))
}
async fn cleanup_expired(&self) -> Result<usize, TaskError> {
self.task_repository
.cleanup_expired_tasks(self.config.ttl_hours)
.await
.map_err(TaskError::from)
}
}