1036 lines
36 KiB
Rust
1036 lines
36 KiB
Rust
use std::collections::{HashMap, HashSet};
|
||
use std::fs;
|
||
use std::path::{Path, PathBuf};
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
|
||
use async_trait::async_trait;
|
||
use serde::Deserialize;
|
||
|
||
use crate::agent::{AgentLoop, AgentRuntimeConfig, EmittedMessageHandler, PersistingEmittedMessageHandler, SystemPrompt, SystemPromptContext, SystemPromptProvider};
|
||
use crate::bus::ChatMessage;
|
||
use crate::bus::message::{OutboundMessage, OutboundEventKind};
|
||
use crate::bus::MessageBus;
|
||
use crate::providers::StreamDelta;
|
||
use crate::config::{LLMProviderConfig, SubagentsConfig};
|
||
use crate::storage::{ConversationRepository, SessionStore};
|
||
use crate::tools::{ToolContext, ToolRegistry};
|
||
|
||
use super::error::TaskError;
|
||
use super::prompt::{extract_summary, SubagentPromptBuilder};
|
||
use super::repository::TaskRepository;
|
||
use super::tool::TaskTool;
|
||
use super::types::{SubagentDef, SubagentSource, TaskDefinition, TaskSession, TaskToolResult};
|
||
|
||
/// Runtime configuration for sub-agents.
#[derive(Debug, Clone)]
pub struct SubAgentRuntimeConfig {
    /// Tool whitelist used when a subagent definition does not specify its own.
    pub default_allowed_tools: HashSet<String>,
    /// Default maximum execution time, in seconds.
    pub default_max_execution_secs: u64,
    /// Maximum execution time for the `explore` subagent type, in seconds.
    pub explore_max_execution_secs: u64,
    /// Task time-to-live, in hours.
    pub ttl_hours: u64,
    /// Optional pre-rendered skills index string injected into subagent prompts.
    pub skills_index: Option<String>,
    /// Maximum sub-agent nesting depth (0 = no nesting, 1 = one level of grandchild agents).
    pub max_nesting_depth: u32,
}

impl Default for SubAgentRuntimeConfig {
    fn default() -> Self {
        // `send_session_message` is included so sub-agents can report progress.
        const BASELINE_TOOLS: [&str; 12] = [
            "read",
            "edit",
            "write",
            "bash",
            "http_request",
            "web_fetch",
            "memory_search",
            "get_time",
            "calculator",
            "skill_activate",
            "skill_list",
            "send_session_message",
        ];
        let one_hour_secs: u64 = 60 * 60;

        SubAgentRuntimeConfig {
            default_allowed_tools: BASELINE_TOOLS.iter().map(|tool| tool.to_string()).collect(),
            default_max_execution_secs: one_hour_secs,
            explore_max_execution_secs: one_hour_secs,
            ttl_hours: 24,
            skills_index: None,
            max_nesting_depth: 1,
        }
    }
}
|
||
|
||
/// 子代理运行时抽象接口
|
||
#[async_trait]
|
||
pub trait SubAgentRuntime: Send + Sync + 'static {
|
||
/// 创建并执行子代理任务
|
||
async fn spawn(
|
||
&self,
|
||
parent_context: &ToolContext,
|
||
task: TaskDefinition,
|
||
) -> Result<TaskToolResult, TaskError>;
|
||
|
||
/// 恢复现有任务
|
||
async fn resume(
|
||
&self,
|
||
task_id: &str,
|
||
parent_context: &ToolContext,
|
||
additional_prompt: String,
|
||
) -> Result<TaskToolResult, TaskError>;
|
||
|
||
/// 发送消息给子代理(支持中断或补充指令)
|
||
async fn send_message(&self, task_id: &str, message: String) -> Result<(), TaskError>;
|
||
|
||
/// 清理过期任务
|
||
async fn cleanup_expired(&self) -> Result<usize, TaskError>;
|
||
|
||
/// 获取可用的子代理类型列表
|
||
fn available_subagent_names(&self) -> Vec<String>;
|
||
}
|
||
|
||
/// A system-prompt provider that always hands back the same pre-rendered text.
///
/// Sub-agents use it because their system prompt is fully built before the agent starts.
pub struct StaticSystemPromptProvider {
    /// The complete system prompt text.
    prompt: String,
}

impl StaticSystemPromptProvider {
    /// Wraps an already-rendered system prompt.
    pub fn new(prompt: String) -> Self {
        StaticSystemPromptProvider { prompt }
    }
}
|
||
|
||
/// 子智能体工具调用实时广播器(不依赖 gateway 层)
|
||
struct SubAgentEmitter {
|
||
bus: Arc<MessageBus>,
|
||
channel_name: String,
|
||
chat_id: String,
|
||
metadata: HashMap<String, String>,
|
||
store: Arc<SessionStore>,
|
||
/// 子/孙智能体自身的 task_id,用于持久化时作为 scope_key
|
||
task_id: String,
|
||
stream_message_id: std::sync::Mutex<Option<String>>,
|
||
}
|
||
|
||
#[async_trait]
|
||
impl EmittedMessageHandler for SubAgentEmitter {
|
||
async fn handle(&self, message: ChatMessage) {
|
||
for outbound in OutboundMessage::from_chat_message(
|
||
&self.channel_name,
|
||
&self.chat_id,
|
||
None,
|
||
None,
|
||
&self.metadata,
|
||
&message,
|
||
) {
|
||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||
tracing::error!(
|
||
error = %error,
|
||
channel = %self.channel_name,
|
||
chat_id = %self.chat_id,
|
||
"Failed to publish live sub-agent tool call"
|
||
);
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn handle_tool_result(&self, message: ChatMessage, duration_ms: Option<u64>) {
|
||
let mut metadata = self.metadata.clone();
|
||
if let Some(ms) = duration_ms {
|
||
metadata.insert("tool_duration_ms".to_string(), ms.to_string());
|
||
}
|
||
for outbound in OutboundMessage::from_chat_message(
|
||
&self.channel_name,
|
||
&self.chat_id,
|
||
None,
|
||
None,
|
||
&metadata,
|
||
&message,
|
||
) {
|
||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||
tracing::error!(
|
||
error = %error,
|
||
channel = %self.channel_name,
|
||
chat_id = %self.chat_id,
|
||
"Failed to publish live sub-agent tool call"
|
||
);
|
||
}
|
||
}
|
||
|
||
// 拦截 todo_write 结果:持久化到 SQLite(子代理用 task_id 作为 scope_key,与 list_todos 保持一致)
|
||
if message.tool_name.as_deref() == Some("todo_write") {
|
||
self.persist_todo_write_result(&message);
|
||
}
|
||
}
|
||
|
||
async fn handle_stream_delta(&self, delta: &StreamDelta) {
|
||
let message_id = {
|
||
let mut guard = self.stream_message_id.lock().unwrap();
|
||
guard.get_or_insert_with(|| uuid::Uuid::new_v4().to_string()).clone()
|
||
};
|
||
|
||
let outbound = if delta.content.is_empty() && delta.reasoning_content.is_none() {
|
||
OutboundMessage::stream_end(
|
||
&self.channel_name,
|
||
&self.chat_id,
|
||
None,
|
||
&message_id,
|
||
self.metadata.clone(),
|
||
)
|
||
} else {
|
||
OutboundMessage::stream_delta(
|
||
&self.channel_name,
|
||
&self.chat_id,
|
||
None,
|
||
&message_id,
|
||
&delta.content,
|
||
delta.reasoning_content.clone(),
|
||
self.metadata.clone(),
|
||
)
|
||
};
|
||
|
||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||
tracing::error!(error = %error, channel = %self.channel_name, "Failed to publish sub-agent stream delta");
|
||
}
|
||
}
|
||
|
||
async fn set_stream_message_id(&self, id: &str) {
|
||
*self.stream_message_id.lock().unwrap() = Some(id.to_string());
|
||
}
|
||
}
|
||
|
||
impl SubAgentEmitter {
|
||
fn persist_todo_write_result(&self, message: &ChatMessage) {
|
||
let parsed: serde_json::Value = match serde_json::from_str(&message.content) {
|
||
Ok(v) => v,
|
||
Err(_) => return,
|
||
};
|
||
|
||
let Some(todos_array) = parsed.get("current_todos").and_then(|v| v.as_array()) else {
|
||
return;
|
||
};
|
||
|
||
let scope_key = &self.task_id;
|
||
|
||
let now = std::time::SystemTime::now()
|
||
.duration_since(std::time::UNIX_EPOCH)
|
||
.unwrap_or_default()
|
||
.as_secs() as i64;
|
||
|
||
let records: Vec<crate::storage::TodoRecord> = todos_array
|
||
.iter()
|
||
.enumerate()
|
||
.filter_map(|(idx, item)| {
|
||
Some(crate::storage::TodoRecord {
|
||
id: item.get("id")?.as_str()?.to_string(),
|
||
scope_key: scope_key.clone(),
|
||
session_id: scope_key.clone(),
|
||
topic_id: None,
|
||
content: item.get("content")?.as_str()?.to_string(),
|
||
status: item.get("status")?.as_str()?.to_string(),
|
||
priority: "medium".to_string(),
|
||
created_at: now + idx as i64,
|
||
updated_at: now,
|
||
created_by_message_id: message.tool_call_id.clone(),
|
||
})
|
||
})
|
||
.collect();
|
||
|
||
if records.is_empty() {
|
||
return;
|
||
}
|
||
|
||
tracing::info!(
|
||
scope_key = %scope_key,
|
||
todo_count = records.len(),
|
||
"SubAgentEmitter: persisting todo_write result"
|
||
);
|
||
|
||
if let Err(e) = self.store.replace_todos(scope_key, &records) {
|
||
tracing::warn!(error = %e, %scope_key, "Failed to persist sub-agent todo list");
|
||
}
|
||
}
|
||
}
|
||
|
||
impl SystemPromptProvider for StaticSystemPromptProvider {
|
||
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
|
||
Some(SystemPrompt {
|
||
content: self.prompt.clone(),
|
||
context: Some("subagent".to_string()),
|
||
})
|
||
}
|
||
}
|
||
|
||
/// 默认子代理运行时实现
|
||
pub struct DefaultSubAgentRuntime {
|
||
config: SubAgentRuntimeConfig,
|
||
task_repository: Arc<dyn TaskRepository>,
|
||
conversation_repository: Arc<dyn ConversationRepository>,
|
||
subagent_tools: Arc<ToolRegistry>,
|
||
provider_config: LLMProviderConfig,
|
||
/// 子代理定义目录(内置 + 自定义)
|
||
catalog: Arc<SubagentCatalog>,
|
||
bus: Option<Arc<MessageBus>>,
|
||
store: Arc<SessionStore>,
|
||
}
|
||
|
||
impl DefaultSubAgentRuntime {
|
||
pub fn new(
|
||
config: SubAgentRuntimeConfig,
|
||
task_repository: Arc<dyn TaskRepository>,
|
||
conversation_repository: Arc<dyn ConversationRepository>,
|
||
subagent_tools: Arc<ToolRegistry>,
|
||
provider_config: LLMProviderConfig,
|
||
catalog: Arc<SubagentCatalog>,
|
||
bus: Option<Arc<MessageBus>>,
|
||
store: Arc<SessionStore>,
|
||
) -> Self {
|
||
Self {
|
||
config,
|
||
task_repository,
|
||
conversation_repository,
|
||
subagent_tools,
|
||
provider_config,
|
||
catalog,
|
||
bus,
|
||
store,
|
||
}
|
||
}
|
||
|
||
/// 查找子代理定义,找不到时 fallback 到 general
|
||
fn find_subagent_def(&self, type_name: &str) -> SubagentDef {
|
||
self.catalog
|
||
.find(type_name)
|
||
.cloned()
|
||
.unwrap_or_else(|| self.catalog.find("general").expect("general subagent must exist").clone())
|
||
}
|
||
|
||
/// 获取实际使用的工具白名单(预留,未来可用于动态工具过滤)
|
||
#[allow(dead_code)]
|
||
fn effective_allowed_tools(&self, def: &SubagentDef) -> HashSet<String> {
|
||
def.allowed_tools
|
||
.as_ref()
|
||
.map(|tools| tools.iter().cloned().collect())
|
||
.unwrap_or_else(|| self.config.default_allowed_tools.clone())
|
||
}
|
||
|
||
/// 获取实际执行时间
|
||
fn effective_max_execution_secs(&self, def: &SubagentDef) -> u64 {
|
||
def.max_execution_secs
|
||
.unwrap_or(self.config.default_max_execution_secs)
|
||
}
|
||
|
||
/// 创建子代理实例
|
||
fn create_subagent(
|
||
&self,
|
||
session: &TaskSession,
|
||
system_prompt: String,
|
||
parent_nesting_depth: u32,
|
||
parent_task_id: Option<String>,
|
||
) -> Result<AgentLoop, TaskError> {
|
||
let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt));
|
||
|
||
// 孙智能体(depth >= 2)不注册 task 工具,防止无限嵌套
|
||
let child_depth = parent_nesting_depth + 1;
|
||
let tools = if child_depth >= 2 {
|
||
Arc::new(self.subagent_tools.without(&[TaskTool::TOOL_NAME]))
|
||
} else {
|
||
self.subagent_tools.clone()
|
||
};
|
||
|
||
AgentLoop::with_tools_and_system_prompt_provider(
|
||
AgentRuntimeConfig::from(self.provider_config.clone()),
|
||
tools,
|
||
prompt_provider,
|
||
None, // 子代理不需要 skill provider
|
||
)
|
||
.map(|agent| {
|
||
let agent = agent.with_tool_context(ToolContext {
|
||
channel_name: Some(session.parent_channel_name.clone()),
|
||
sender_id: None,
|
||
chat_id: Some(session.parent_chat_id.clone()),
|
||
session_id: Some(session.session_id.clone()),
|
||
topic_id: session.parent_topic_id.clone(),
|
||
message_id: None,
|
||
message_seq: None,
|
||
subagent_description: Some(session.description.clone()),
|
||
nesting_depth: parent_nesting_depth + 1,
|
||
task_id: Some(session.id.clone()),
|
||
parent_task_id,
|
||
tool_call_id: None,
|
||
});
|
||
|
||
// 如果有 MessageBus,附加实时广播 emitter
|
||
if let Some(bus) = &self.bus {
|
||
let mut metadata = HashMap::new();
|
||
metadata.insert("subagent_task_id".to_string(), session.id.clone());
|
||
metadata.insert("is_subagent_event".to_string(), "true".to_string());
|
||
metadata.insert("topic_id".to_string(), session.parent_topic_id.clone().unwrap_or_default());
|
||
|
||
let emitter = Arc::new(PersistingEmittedMessageHandler::new(
|
||
SubAgentEmitter {
|
||
bus: bus.clone(),
|
||
channel_name: session.parent_channel_name.clone(),
|
||
chat_id: session.parent_chat_id.clone(),
|
||
metadata,
|
||
store: self.store.clone(),
|
||
task_id: session.id.clone(),
|
||
stream_message_id: std::sync::Mutex::new(None),
|
||
},
|
||
self.conversation_repository.clone(),
|
||
session.session_id.clone(),
|
||
session.parent_topic_id.clone(),
|
||
));
|
||
|
||
return agent.with_emitted_message_handler(emitter);
|
||
}
|
||
|
||
agent
|
||
})
|
||
.map_err(|e| TaskError::AgentCreationFailed(e.to_string()))
|
||
}
|
||
|
||
/// 执行任务(带超时控制)
|
||
async fn execute_task(
|
||
&self,
|
||
agent: AgentLoop,
|
||
session: &TaskSession,
|
||
def: &SubagentDef,
|
||
prompt: String,
|
||
) -> Result<TaskToolResult, TaskError> {
|
||
// 构建初始消息
|
||
let history = vec![ChatMessage::user(prompt)];
|
||
let system_prompt_context = SystemPromptContext {
|
||
session_id: Some(session.session_id.clone()),
|
||
chat_id: session.session_id.clone(),
|
||
user_message_count: 1,
|
||
};
|
||
|
||
// 设置超时
|
||
let max_secs = if session.subagent_type == "explore" {
|
||
self.config.explore_max_execution_secs
|
||
} else {
|
||
self.effective_max_execution_secs(def)
|
||
};
|
||
let timeout_duration = Duration::from_secs(max_secs);
|
||
|
||
let result = tokio::time::timeout(
|
||
timeout_duration,
|
||
agent.process(history, Some(&system_prompt_context)),
|
||
)
|
||
.await;
|
||
|
||
match result {
|
||
Ok(Ok(process_result)) => {
|
||
let final_message = process_result.final_response;
|
||
Ok(TaskToolResult {
|
||
status: "success".to_string(),
|
||
summary: extract_summary(&final_message.content),
|
||
output: final_message.content,
|
||
task_id: session.id.clone(),
|
||
})
|
||
}
|
||
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
|
||
Err(_) => Err(TaskError::Timeout),
|
||
}
|
||
}
|
||
|
||
/// 使用历史继续执行
|
||
async fn execute_task_with_history(
|
||
&self,
|
||
agent: AgentLoop,
|
||
session: &TaskSession,
|
||
additional_prompt: String,
|
||
) -> Result<TaskToolResult, TaskError> {
|
||
// 加载历史 + 新消息
|
||
let mut history = self
|
||
.conversation_repository
|
||
.load_messages(&session.session_id)
|
||
.map_err(TaskError::RepositoryError)?;
|
||
history.push(ChatMessage::user(additional_prompt));
|
||
|
||
let user_message_count = history.iter().filter(|m| m.role == "user").count();
|
||
let system_prompt_context = SystemPromptContext {
|
||
session_id: Some(session.session_id.clone()),
|
||
chat_id: session.session_id.clone(),
|
||
user_message_count,
|
||
};
|
||
|
||
// 使用默认执行时间(恢复任务时原始定义可能已不存在)
|
||
let timeout_duration = Duration::from_secs(self.config.default_max_execution_secs);
|
||
|
||
let result = tokio::time::timeout(
|
||
timeout_duration,
|
||
agent.process(history, Some(&system_prompt_context)),
|
||
)
|
||
.await;
|
||
|
||
match result {
|
||
Ok(Ok(process_result)) => {
|
||
let final_message = process_result.final_response;
|
||
Ok(TaskToolResult {
|
||
status: "success".to_string(),
|
||
summary: extract_summary(&final_message.content),
|
||
output: final_message.content,
|
||
task_id: session.id.clone(),
|
||
})
|
||
}
|
||
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
|
||
Err(_) => Err(TaskError::Timeout),
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl SubAgentRuntime for DefaultSubAgentRuntime {
|
||
async fn spawn(
|
||
&self,
|
||
parent_context: &ToolContext,
|
||
task: TaskDefinition,
|
||
) -> Result<TaskToolResult, TaskError> {
|
||
// 1. 验证上下文
|
||
let session_id = parent_context
|
||
.session_id
|
||
.clone()
|
||
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
|
||
let chat_id = parent_context
|
||
.chat_id
|
||
.clone()
|
||
.ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?;
|
||
let channel_name = parent_context
|
||
.channel_name
|
||
.clone()
|
||
.ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?;
|
||
|
||
// 2. 查找子代理定义
|
||
let def = self.find_subagent_def(task.subagent_type.as_str());
|
||
|
||
// 3. 创建任务会话
|
||
let topic_id = parent_context.topic_id.clone();
|
||
let session = TaskSession::new(
|
||
session_id,
|
||
topic_id,
|
||
chat_id,
|
||
channel_name,
|
||
task.description.clone(),
|
||
task.subagent_type,
|
||
);
|
||
|
||
// 4. 在 sessions 表中创建子智能体会话(确保外键约束满足)
|
||
let session_title = format!("Subagent [{}]: {}", session.subagent_type, task.description);
|
||
if let Err(e) = self.conversation_repository.ensure_session(
|
||
&session.session_id,
|
||
&session.parent_channel_name,
|
||
&session.parent_chat_id,
|
||
&session_title,
|
||
) {
|
||
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session");
|
||
}
|
||
|
||
// 5. 保存任务会话
|
||
tracing::info!(
|
||
task_id = %session.id,
|
||
session_id = %session.session_id,
|
||
description = %session.description,
|
||
subagent_type = %session.subagent_type,
|
||
"Spawning sub-agent task"
|
||
);
|
||
self.task_repository.save_task_session(&session).await?;
|
||
|
||
// 5.1 立即通知前端 task_id(让前端可以显示"查看实时进度"按钮)
|
||
if let Some(bus) = &self.bus {
|
||
let mut metadata = HashMap::new();
|
||
metadata.insert("task_id".to_string(), session.id.clone());
|
||
metadata.insert("task_description".to_string(), session.description.clone());
|
||
metadata.insert("task_subagent_type".to_string(), session.subagent_type.clone());
|
||
metadata.insert("topic_id".to_string(), session.parent_topic_id.clone().unwrap_or_default());
|
||
|
||
// 如果是子智能体创建的孙智能体,传递父 task_id
|
||
if let Some(ref ptid) = parent_context.task_id {
|
||
metadata.insert("parent_task_id".to_string(), ptid.clone());
|
||
}
|
||
|
||
// 传递 tool_call_id,前端据此精确匹配创建此任务的 tool_call
|
||
if let Some(ref tcid) = parent_context.tool_call_id {
|
||
metadata.insert("tool_call_id".to_string(), tcid.clone());
|
||
}
|
||
|
||
let event = OutboundMessage {
|
||
channel: session.parent_channel_name.clone(),
|
||
chat_id: session.parent_chat_id.clone(),
|
||
session_id: Some(session.parent_session_id.clone()),
|
||
content: String::new(),
|
||
reply_to: None,
|
||
media: Vec::new(),
|
||
metadata,
|
||
event_kind: OutboundEventKind::TaskStarted,
|
||
role: "system".to_string(),
|
||
tool_call_id: None,
|
||
tool_name: None,
|
||
tool_arguments: None,
|
||
reasoning_content: None,
|
||
message_id: None,
|
||
};
|
||
|
||
if let Err(e) = bus.publish_outbound(event).await {
|
||
tracing::warn!(error = %e, task_id = %session.id, "Failed to publish TaskStarted event");
|
||
}
|
||
}
|
||
|
||
// 6. 构建子代理系统提示词
|
||
let system_prompt = SubagentPromptBuilder::build(
|
||
&def,
|
||
&task.description,
|
||
&task.prompt,
|
||
&self.provider_config,
|
||
self.config.skills_index.as_deref(),
|
||
);
|
||
|
||
// 7. 创建子代理
|
||
let agent = self.create_subagent(&session, system_prompt, parent_context.nesting_depth, parent_context.task_id.clone())?;
|
||
|
||
// 8. 执行任务
|
||
let result = self
|
||
.execute_task(agent, &session, &def, task.prompt.clone())
|
||
.await;
|
||
|
||
// 9. 更新会话状态并保存
|
||
match result {
|
||
Ok(tool_result) => {
|
||
let mut session = session;
|
||
session.mark_completed(tool_result.summary.clone());
|
||
tracing::info!(
|
||
task_id = %session.id,
|
||
session_id = %session.session_id,
|
||
"Task completed, updating session"
|
||
);
|
||
self.task_repository.save_task_session(&session).await?;
|
||
Ok(tool_result)
|
||
}
|
||
Err(e) => {
|
||
let mut session = session;
|
||
let status = e.as_status();
|
||
tracing::warn!(
|
||
task_id = %session.id,
|
||
session_id = %session.session_id,
|
||
status = %status,
|
||
error = %e,
|
||
"Task failed, updating session"
|
||
);
|
||
if status == "timeout" {
|
||
session.mark_timeout();
|
||
} else {
|
||
session.mark_failed(e.to_string());
|
||
}
|
||
self.task_repository.save_task_session(&session).await?;
|
||
Err(e)
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn resume(
|
||
&self,
|
||
task_id: &str,
|
||
parent_context: &ToolContext,
|
||
additional_prompt: String,
|
||
) -> Result<TaskToolResult, TaskError> {
|
||
// 1. 加载现有会话
|
||
let session = self
|
||
.task_repository
|
||
.load_task_session(task_id)
|
||
.await?
|
||
.ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?;
|
||
|
||
// 2. 验证父会话匹配
|
||
let parent_session_id = parent_context
|
||
.session_id
|
||
.clone()
|
||
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
|
||
if session.parent_session_id != parent_session_id {
|
||
return Err(TaskError::InvalidParentSession);
|
||
}
|
||
|
||
// 3. 确保 sessions 表中存在子智能体会话记录
|
||
let session_title = format!("Subagent [{}]: {}", session.subagent_type, session.description);
|
||
if let Err(e) = self.conversation_repository.ensure_session(
|
||
&session.session_id,
|
||
&session.parent_channel_name,
|
||
&session.parent_chat_id,
|
||
&session_title,
|
||
) {
|
||
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session on resume");
|
||
}
|
||
|
||
// 4. 构建恢复提示词
|
||
let system_prompt = SubagentPromptBuilder::build_resume_prompt(
|
||
&session.description,
|
||
&additional_prompt,
|
||
);
|
||
|
||
// 5. 创建子代理
|
||
let agent = self.create_subagent(&session, system_prompt, parent_context.nesting_depth, parent_context.task_id.clone())?;
|
||
|
||
// 6. 使用历史继续执行
|
||
let result = self
|
||
.execute_task_with_history(agent, &session, additional_prompt)
|
||
.await;
|
||
|
||
// 7. 更新会话状态
|
||
match result {
|
||
Ok(tool_result) => {
|
||
let mut session = session;
|
||
session.mark_completed(tool_result.summary.clone());
|
||
self.task_repository.save_task_session(&session).await?;
|
||
Ok(tool_result)
|
||
}
|
||
Err(e) => {
|
||
let mut session = session;
|
||
session.mark_failed(e.to_string());
|
||
self.task_repository.save_task_session(&session).await?;
|
||
Err(e)
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn send_message(&self, _task_id: &str, _message: String) -> Result<(), TaskError> {
|
||
// TODO: 实现双向通信
|
||
// 需要在 TaskSession 中添加 pending_messages 队列
|
||
Err(TaskError::InvalidArguments("send_message not implemented yet".to_string()))
|
||
}
|
||
|
||
async fn cleanup_expired(&self) -> Result<usize, TaskError> {
|
||
self.task_repository
|
||
.cleanup_expired_tasks(self.config.ttl_hours)
|
||
.await
|
||
.map_err(TaskError::from)
|
||
}
|
||
|
||
fn available_subagent_names(&self) -> Vec<String> {
|
||
self.catalog.names()
|
||
}
|
||
}
|
||
|
||
/// 子代理定义目录
|
||
///
|
||
/// 管理所有可用的子代理定义,包括内置和自定义。
|
||
/// 支持用户级(~/.picobot/subagents/)和项目级(./.picobot/subagents/)定义,
|
||
/// 项目级定义会覆盖同名的用户级定义。
|
||
#[derive(Debug, Default)]
|
||
pub struct SubagentCatalog {
|
||
definitions: std::collections::HashMap<String, SubagentDef>,
|
||
}
|
||
|
||
impl SubagentCatalog {
|
||
/// 创建空的目录,并注册内置子代理
|
||
pub fn new() -> Self {
|
||
let mut catalog = Self::default();
|
||
catalog.register(SubagentDef::builtin_general());
|
||
catalog.register(SubagentDef::builtin_explore());
|
||
catalog
|
||
}
|
||
|
||
/// 从配置发现子代理(内置 + 文件系统自定义)
|
||
///
|
||
/// 发现顺序:先内置,后按 sources 配置顺序扫描目录
|
||
/// 后发现的同名定义会覆盖先发现的(项目覆盖用户)
|
||
pub fn discover(config: &SubagentsConfig) -> Self {
|
||
let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
|
||
Self::discover_with_cwd(config, &cwd)
|
||
}
|
||
|
||
fn discover_with_cwd(config: &SubagentsConfig, cwd: &Path) -> Self {
|
||
// 先内置作为基础
|
||
let mut merged: std::collections::HashMap<String, SubagentDef> = std::collections::HashMap::new();
|
||
merged.insert("general".to_string(), SubagentDef::builtin_general());
|
||
merged.insert("explore".to_string(), SubagentDef::builtin_explore());
|
||
|
||
tracing::debug!(cwd = %cwd.display(), "Discovering subagents from cwd");
|
||
|
||
// 按配置顺序扫描源目录
|
||
if config.enabled {
|
||
for source in source_order(&config.sources) {
|
||
let root = source_root(&source, cwd);
|
||
tracing::debug!(source = ?source, root = ?root.as_ref().map(|p| p.display().to_string()), "Checking subagent source");
|
||
if let Some(root) = root {
|
||
if root.exists() {
|
||
tracing::info!(path = %root.display(), "Scanning subagents directory");
|
||
} else {
|
||
tracing::debug!(path = %root.display(), "Subagents directory does not exist, skipping");
|
||
}
|
||
for def in load_subagents_from_root(&root, source.clone()) {
|
||
if let Some(existing) = merged.get(&def.name) {
|
||
tracing::warn!(
|
||
subagent = %def.name,
|
||
old_source = ?existing.source,
|
||
new_source = ?def.source,
|
||
"Duplicate subagent name found; overriding with later source"
|
||
);
|
||
}
|
||
merged.insert(def.name.clone(), def);
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
tracing::debug!("Subagents discovery is disabled");
|
||
}
|
||
|
||
// 构建 catalog
|
||
let mut catalog = Self::default();
|
||
for def in merged.into_values() {
|
||
catalog.register(def);
|
||
}
|
||
|
||
tracing::info!(
|
||
discovered = catalog.definitions.len(),
|
||
"Subagents discovery completed"
|
||
);
|
||
|
||
catalog
|
||
}
|
||
|
||
/// 注册一个子代理定义(同名覆盖)
|
||
pub fn register(&mut self, def: SubagentDef) {
|
||
self.definitions.insert(def.name.clone(), def);
|
||
}
|
||
|
||
/// 查找子代理定义
|
||
pub fn find(&self, name: &str) -> Option<&SubagentDef> {
|
||
self.definitions.get(name)
|
||
}
|
||
|
||
/// 获取所有可用的子代理名称
|
||
pub fn names(&self) -> Vec<String> {
|
||
self.definitions.keys().cloned().collect()
|
||
}
|
||
|
||
/// 获取所有可用的子代理定义(用于生成索引提示)
|
||
pub fn all(&self) -> Vec<&SubagentDef> {
|
||
self.definitions.values().collect()
|
||
}
|
||
|
||
/// 生成系统索引提示词(用于注入主 agent)
|
||
pub fn system_index_prompt(&self) -> Option<String> {
|
||
let defs = self.all();
|
||
if defs.is_empty() {
|
||
return None;
|
||
}
|
||
|
||
let mut prompt = String::from(
|
||
"# 子代理系统\n\n\
|
||
子代理是专用的执行单元,用于处理特定类型的任务。\n\
|
||
创建子代理任务时,可以选择以下类型之一:\n\n\
|
||
<available_subagents>\n"
|
||
);
|
||
|
||
for def in defs {
|
||
prompt.push_str(&format!(
|
||
" <subagent>\n <name>{}</name>\n <description>{}</description>\n </subagent>\n",
|
||
xml_escape(&def.name),
|
||
xml_escape(&def.description),
|
||
));
|
||
}
|
||
|
||
prompt.push_str("</available_subagents>");
|
||
Some(prompt)
|
||
}
|
||
}
|
||
|
||
/// Escapes the five XML special characters (`&`, `<`, `>`, `"`, `'`) so arbitrary text
/// can be embedded safely in element content or attribute values.
///
/// Works in a single pass, so an already-escaped `&` is escaped again rather than
/// double-processed by chained replacements.
fn xml_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            other => escaped.push(other),
        }
    }
    escaped
}
|
||
|
||
// ========== 自定义子代理发现 ==========
|
||
|
||
/// 源顺序解析
|
||
fn source_order(sources: &[String]) -> Vec<SubagentSource> {
|
||
let mut result = Vec::new();
|
||
for source in sources {
|
||
match source.as_str() {
|
||
"user" => {
|
||
if !result.contains(&SubagentSource::User) {
|
||
result.push(SubagentSource::User);
|
||
}
|
||
}
|
||
"project" => {
|
||
if !result.contains(&SubagentSource::Project) {
|
||
result.push(SubagentSource::Project);
|
||
}
|
||
}
|
||
unknown => {
|
||
let custom = SubagentSource::Custom(unknown.to_string());
|
||
if !result.contains(&custom) {
|
||
result.push(custom);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 默认顺序:先 user 后 project(项目覆盖用户)
|
||
if result.is_empty() {
|
||
vec![SubagentSource::User, SubagentSource::Project]
|
||
} else {
|
||
result
|
||
}
|
||
}
|
||
|
||
/// 获取源目录根路径
|
||
fn source_root(source: &SubagentSource, cwd: &Path) -> Option<std::path::PathBuf> {
|
||
match source {
|
||
SubagentSource::User => dirs::home_dir().map(|p| p.join(".picobot").join("subagents")),
|
||
SubagentSource::Project => Some(cwd.join(".picobot").join("subagents")),
|
||
SubagentSource::Builtin => None,
|
||
SubagentSource::Custom(path) => {
|
||
let p = std::path::PathBuf::from(path);
|
||
if p.is_absolute() {
|
||
Some(p)
|
||
} else {
|
||
tracing::warn!(path = %path, "Custom subagents source must be an absolute path, skipping");
|
||
None
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 子代理 frontmatter 结构
|
||
#[derive(Debug, Deserialize)]
|
||
struct SubagentFrontmatter {
|
||
#[serde(default)]
|
||
name: Option<String>,
|
||
description: String,
|
||
#[serde(default)]
|
||
prompt_template: Option<String>,
|
||
#[serde(default)]
|
||
allowed_tools: Option<Vec<String>>,
|
||
#[serde(default)]
|
||
max_execution_secs: Option<u64>,
|
||
}
|
||
|
||
/// 从根目录加载所有子代理
|
||
fn load_subagents_from_root(root: &Path, source: SubagentSource) -> Vec<SubagentDef> {
|
||
let mut out = Vec::new();
|
||
if !root.exists() {
|
||
tracing::debug!(path = %root.display(), "Subagents root directory does not exist");
|
||
return out;
|
||
}
|
||
|
||
tracing::debug!(path = %root.display(), "Reading subagents directory");
|
||
|
||
let entries = match fs::read_dir(root) {
|
||
Ok(entries) => entries,
|
||
Err(err) => {
|
||
tracing::warn!(path = %root.display(), error = %err, "Failed to read subagents directory");
|
||
return out;
|
||
}
|
||
};
|
||
|
||
let mut found_dirs = 0;
|
||
let mut found_files = 0;
|
||
|
||
for entry in entries.flatten() {
|
||
let path = entry.path();
|
||
if !path.is_dir() {
|
||
tracing::debug!(path = %path.display(), "Skipping non-directory entry");
|
||
continue;
|
||
}
|
||
found_dirs += 1;
|
||
let subagent_md = path.join("SUBAGENT.md");
|
||
tracing::debug!(dir = %path.display(), subagent_file = %subagent_md.display(), "Checking subagent directory");
|
||
if !subagent_md.exists() {
|
||
tracing::debug!(path = %subagent_md.display(), "SUBAGENT.md not found");
|
||
continue;
|
||
}
|
||
found_files += 1;
|
||
|
||
match parse_subagent_file(&subagent_md, source.clone()) {
|
||
Ok(def) => {
|
||
tracing::info!(name = %def.name, path = %subagent_md.display(), "Loaded subagent");
|
||
out.push(def);
|
||
}
|
||
Err(err) => {
|
||
tracing::warn!(path = %subagent_md.display(), error = %err, "Skipping invalid subagent file");
|
||
}
|
||
}
|
||
}
|
||
|
||
tracing::debug!(path = %root.display(), dirs = found_dirs, files = found_files, loaded = out.len(), "Subagents scan completed");
|
||
|
||
out
|
||
}
|
||
|
||
/// 解析子代理文件
|
||
fn parse_subagent_file(path: &Path, source: SubagentSource) -> Result<SubagentDef, String> {
|
||
let content = fs::read_to_string(path)
|
||
.map_err(|e| format!("failed to read file: {}", e))?;
|
||
|
||
let (frontmatter_raw, body) = split_frontmatter(&content)
|
||
.ok_or_else(|| "missing YAML frontmatter block".to_string())?;
|
||
|
||
let frontmatter: SubagentFrontmatter = serde_yaml::from_str(frontmatter_raw)
|
||
.map_err(|e| format!("invalid YAML frontmatter: {}", e))?;
|
||
|
||
if frontmatter.description.trim().is_empty() {
|
||
return Err("description is required and cannot be empty".to_string());
|
||
}
|
||
|
||
// name 可选,默认使用目录名
|
||
let dir_name = path
|
||
.parent()
|
||
.and_then(|p| p.file_name())
|
||
.map(|s| s.to_string_lossy().to_string())
|
||
.unwrap_or_else(|| "unknown-subagent".to_string());
|
||
|
||
let name = frontmatter.name.unwrap_or(dir_name).trim().to_string();
|
||
let prompt_template = frontmatter.prompt_template.unwrap_or_default().trim().to_string();
|
||
let body_content = body.trim().to_string();
|
||
|
||
Ok(SubagentDef {
|
||
name,
|
||
description: frontmatter.description.trim().to_string(),
|
||
prompt_template,
|
||
body: if body_content.is_empty() { None } else { Some(body_content) },
|
||
allowed_tools: frontmatter.allowed_tools,
|
||
max_execution_secs: frontmatter.max_execution_secs,
|
||
source,
|
||
path: Some(path.to_path_buf()),
|
||
})
|
||
}
|
||
|
||
/// Splits a document into its YAML frontmatter and the body that follows it.
///
/// The first line (after an optional UTF-8 BOM) must be exactly `---`; trailing
/// whitespace and `\r` are tolerated. The frontmatter runs until the *first* later line
/// that is exactly `---`.
///
/// Both LF and CRLF line endings are accepted. Blank lines at the edges of the
/// frontmatter are trimmed, and leading line breaks are trimmed from the body.
///
/// Returns `None` when the opening or closing delimiter is missing.
fn split_frontmatter(content: &str) -> Option<(&str, &str)> {
    let is_line_break = |c: char| c == '\r' || c == '\n';

    // Editors on some platforms prepend a byte-order mark; ignore it.
    let content = content.strip_prefix('\u{feff}').unwrap_or(content);

    // The opening delimiter must occupy the whole first line.
    let (opening, rest) = match content.find('\n') {
        Some(pos) => (&content[..pos], &content[pos + 1..]),
        None => (content, ""),
    };
    if opening.trim_end() != "---" {
        return None;
    }

    // Walk line by line to the earliest closing delimiter. Matching whole lines means
    // `----` or `--- text` never closes the block, whatever the line ending style.
    let mut line_start = 0;
    while line_start < rest.len() {
        let line_end = rest[line_start..]
            .find('\n')
            .map_or(rest.len(), |offset| line_start + offset);
        if rest[line_start..line_end].trim_end() == "---" {
            let frontmatter = rest[..line_start]
                .trim_start_matches(is_line_break)
                .trim_end_matches(is_line_break);
            let body = rest
                .get(line_end + 1..)
                .unwrap_or("")
                .trim_start_matches(is_line_break);
            return Some((frontmatter, body));
        }
        line_start = line_end + 1;
    }

    None
}
|