PicoBot/src/tools/task/runtime.rs
oudecheng 6a496ce212 feat(task): 优化孙智能体任务消息的发送与工具注册
- 在发送任务消息时增加可选的任务仓库参数支持子任务重发
- 新增 extract_parent_task_id 函数用于提取孙智能体的父任务 ID
- 补发子任务(孙智能体)的 TaskStarted 事件,解决视图重进导致的 navigateToTaskId 丢失
- 判断并附加子任务的父任务 ID,完善日志记录与事件发送
- 在子智能体运行时根据深度排除 task 工具,防止无限嵌套调用
- ToolRegistry 新增 without 方法,可创建排除指定工具的新实例用于子智能体配置
2026-06-22 11:31:41 +08:00

1035 lines
36 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde::Deserialize;
use crate::agent::{AgentLoop, AgentRuntimeConfig, EmittedMessageHandler, PersistingEmittedMessageHandler, SystemPrompt, SystemPromptContext, SystemPromptProvider};
use crate::bus::ChatMessage;
use crate::bus::message::{OutboundMessage, OutboundEventKind};
use crate::bus::MessageBus;
use crate::providers::StreamDelta;
use crate::config::{LLMProviderConfig, SubagentsConfig};
use crate::storage::{ConversationRepository, SessionStore};
use crate::tools::{ToolContext, ToolRegistry};
use super::error::TaskError;
use super::prompt::{extract_summary, SubagentPromptBuilder};
use super::repository::TaskRepository;
use super::tool::TaskTool;
use super::types::{SubagentDef, SubagentSource, TaskDefinition, TaskSession, TaskToolResult};
/// Runtime configuration for sub-agents.
#[derive(Debug, Clone)]
pub struct SubAgentRuntimeConfig {
    /// Default tool allow-list, used when a definition does not specify one.
    pub default_allowed_tools: HashSet<String>,
    /// Default maximum execution time, in seconds.
    pub default_max_execution_secs: u64,
    /// Maximum execution time for `explore` sub-agents, in seconds.
    pub explore_max_execution_secs: u64,
    /// Task time-to-live, in hours.
    pub ttl_hours: u64,
    /// Optional pre-rendered skill index string.
    pub skills_index: Option<String>,
    /// Maximum sub-agent nesting depth (0 = no nesting, 1 = one level of grand-child agents).
    pub max_nesting_depth: u32,
}

impl Default for SubAgentRuntimeConfig {
    fn default() -> Self {
        // Tools every sub-agent may use unless its definition narrows the list.
        // `send_session_message` is included so sub-agents can report progress.
        const BUILTIN_TOOLS: [&str; 12] = [
            "read",
            "edit",
            "write",
            "bash",
            "http_request",
            "web_fetch",
            "memory_search",
            "get_time",
            "calculator",
            "skill_activate",
            "skill_list",
            "send_session_message",
        ];
        // One hour, for both the generic and the explore limit.
        const ONE_HOUR_SECS: u64 = 60 * 60;

        let default_allowed_tools: HashSet<String> =
            BUILTIN_TOOLS.iter().copied().map(String::from).collect();

        Self {
            default_allowed_tools,
            default_max_execution_secs: ONE_HOUR_SECS,
            explore_max_execution_secs: ONE_HOUR_SECS,
            ttl_hours: 24,
            skills_index: None,
            max_nesting_depth: 1,
        }
    }
}
/// Abstraction over the machinery that creates and drives sub-agent tasks.
///
/// Implementors must be `Send + Sync + 'static` so a single runtime can be shared
/// by every tool invocation.
#[async_trait]
pub trait SubAgentRuntime: Send + Sync + 'static {
    /// Creates a new sub-agent task described by `task` and executes it,
    /// attributing it to the conversation found in `parent_context`.
    async fn spawn(&self, parent_context: &ToolContext, task: TaskDefinition)
        -> Result<TaskToolResult, TaskError>;

    /// Continues the existing task `task_id` with `additional_prompt`.
    async fn resume(
        &self,
        task_id: &str,
        parent_context: &ToolContext,
        additional_prompt: String,
    ) -> Result<TaskToolResult, TaskError>;

    /// Sends a message (an interruption or extra instructions) to a sub-agent.
    async fn send_message(&self, task_id: &str, message: String) -> Result<(), TaskError>;

    /// Cleans up expired tasks; the `usize` is presumably the number removed.
    async fn cleanup_expired(&self) -> Result<usize, TaskError>;

    /// Names of the sub-agent types that can currently be spawned.
    fn available_subagent_names(&self) -> Vec<String>;
}
/// System prompt provider that always hands out one fixed prompt (used for sub-agents).
pub struct StaticSystemPromptProvider {
    /// The pre-built prompt text.
    prompt: String,
}

impl StaticSystemPromptProvider {
    /// Wraps an already rendered prompt.
    pub fn new(prompt: String) -> Self {
        StaticSystemPromptProvider { prompt }
    }
}
/// Real-time broadcaster for a sub-agent's tool calls, tool results and stream deltas.
///
/// Events are published to the parent conversation over the [`MessageBus`] without going
/// through the gateway layer. Results of the `todo_write` tool are additionally persisted
/// to the [`SessionStore`], scoped by this agent's own task id.
struct SubAgentEmitter {
    /// Bus the outbound events are published on.
    bus: Arc<MessageBus>,
    /// Channel of the parent conversation.
    channel_name: String,
    /// Chat id of the parent conversation.
    chat_id: String,
    /// Metadata attached to every published event.
    metadata: HashMap<String, String>,
    /// Store used to persist `todo_write` results.
    store: Arc<SessionStore>,
    /// Task id of this sub/grand-child agent; used as the todo `scope_key` when persisting.
    task_id: String,
    /// Message id of the current stream: generated lazily on the first delta unless it was
    /// provided through `set_stream_message_id`.
    stream_message_id: std::sync::Mutex<Option<String>>,
}

#[async_trait]
impl EmittedMessageHandler for SubAgentEmitter {
    async fn handle(&self, message: ChatMessage) {
        self.publish_chat_message(
            &self.metadata,
            &message,
            "Failed to publish live sub-agent tool call",
        )
        .await;
    }

    async fn handle_tool_result(&self, message: ChatMessage, duration_ms: Option<u64>) {
        let mut metadata = self.metadata.clone();
        if let Some(ms) = duration_ms {
            metadata.insert("tool_duration_ms".to_string(), ms.to_string());
        }
        self.publish_chat_message(
            &metadata,
            &message,
            "Failed to publish live sub-agent tool result",
        )
        .await;
        // Persist todo_write results to SQLite using the task_id as scope_key
        // (sub-agents use their task_id so list_todos sees the same scope).
        if message.tool_name.as_deref() == Some("todo_write") {
            self.persist_todo_write_result(&message);
        }
    }

    async fn handle_stream_delta(&self, delta: &StreamDelta) {
        // The std mutex guard lives only inside this block, i.e. never across an `.await`.
        let message_id = {
            let mut guard = self.lock_stream_message_id();
            guard
                .get_or_insert_with(|| uuid::Uuid::new_v4().to_string())
                .clone()
        };
        // A delta with neither content nor reasoning marks the end of the stream.
        let outbound = if delta.content.is_empty() && delta.reasoning_content.is_none() {
            OutboundMessage::stream_end(
                &self.channel_name,
                &self.chat_id,
                None,
                &message_id,
                self.metadata.clone(),
            )
        } else {
            OutboundMessage::stream_delta(
                &self.channel_name,
                &self.chat_id,
                None,
                &message_id,
                &delta.content,
                delta.reasoning_content.clone(),
                self.metadata.clone(),
            )
        };
        if let Err(error) = self.bus.publish_outbound(outbound).await {
            tracing::error!(error = %error, channel = %self.channel_name, "Failed to publish sub-agent stream delta");
        }
    }

    async fn set_stream_message_id(&self, id: &str) {
        *self.lock_stream_message_id() = Some(id.to_string());
    }
}

impl SubAgentEmitter {
    /// Converts `message` into outbound events and publishes each of them.
    ///
    /// Publishing is best-effort: failures are logged with `failure_context` and never
    /// propagated, so a broken bus cannot abort the sub-agent.
    async fn publish_chat_message(
        &self,
        metadata: &HashMap<String, String>,
        message: &ChatMessage,
        failure_context: &str,
    ) {
        for outbound in OutboundMessage::from_chat_message(
            &self.channel_name,
            &self.chat_id,
            None,
            None,
            metadata,
            message,
        ) {
            if let Err(error) = self.bus.publish_outbound(outbound).await {
                tracing::error!(
                    error = %error,
                    channel = %self.channel_name,
                    chat_id = %self.chat_id,
                    "{}",
                    failure_context
                );
            }
        }
    }

    /// Locks `stream_message_id`, recovering the value if the mutex was poisoned.
    ///
    /// The guarded `Option<String>` is only ever replaced as a whole, so a panic in another
    /// holder cannot leave it half-updated; panicking here would only turn one failure into two.
    fn lock_stream_message_id(&self) -> std::sync::MutexGuard<'_, Option<String>> {
        self.stream_message_id
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Persists the `current_todos` list reported by a `todo_write` tool result.
    ///
    /// The result content is parsed as JSON. Unparsable content or a missing
    /// `current_todos` array is ignored. Entries lacking a string `id`, `content` or
    /// `status` are skipped. The stored list is written via `replace_todos` under this
    /// agent's task id. Storage errors are logged, not propagated.
    fn persist_todo_write_result(&self, message: &ChatMessage) {
        let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&message.content) else {
            return;
        };
        let Some(todos_array) = parsed.get("current_todos").and_then(|v| v.as_array()) else {
            return;
        };
        let scope_key = &self.task_id;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs() as i64;
        let records: Vec<crate::storage::TodoRecord> = todos_array
            .iter()
            .enumerate()
            .filter_map(|(idx, item)| {
                Some(crate::storage::TodoRecord {
                    id: item.get("id")?.as_str()?.to_string(),
                    scope_key: scope_key.clone(),
                    session_id: scope_key.clone(),
                    topic_id: None,
                    content: item.get("content")?.as_str()?.to_string(),
                    status: item.get("status")?.as_str()?.to_string(),
                    priority: "medium".to_string(),
                    // Offset by index so creation order is preserved for equal timestamps.
                    created_at: now + idx as i64,
                    updated_at: now,
                })
            })
            .collect();
        // An explicitly empty list means every todo was cleared and must be persisted,
        // otherwise the previously stored list would survive. Only bail out when the tool
        // reported entries but none of them could be parsed, so malformed input never
        // wipes stored data.
        if records.is_empty() && !todos_array.is_empty() {
            return;
        }
        tracing::info!(
            scope_key = %scope_key,
            todo_count = records.len(),
            "SubAgentEmitter: persisting todo_write result"
        );
        // NOTE(review): `replace_todos` is a synchronous store call made from async context.
        if let Err(e) = self.store.replace_todos(scope_key, &records) {
            tracing::warn!(error = %e, %scope_key, "Failed to persist sub-agent todo list");
        }
    }
}
impl SystemPromptProvider for StaticSystemPromptProvider {
    /// Ignores the context and always yields the stored prompt, labelled `"subagent"`.
    fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
        let content = self.prompt.clone();
        let context = Some(String::from("subagent"));
        Some(SystemPrompt { content, context })
    }
}
/// Default [`SubAgentRuntime`] implementation.
pub struct DefaultSubAgentRuntime {
    /// Runtime limits and defaults.
    config: SubAgentRuntimeConfig,
    /// Persistence for task sessions.
    task_repository: Arc<dyn TaskRepository>,
    /// Persistence for sub-agent conversations (sessions and messages).
    conversation_repository: Arc<dyn ConversationRepository>,
    /// Tools made available to sub-agents.
    subagent_tools: Arc<ToolRegistry>,
    /// LLM provider settings used to build each sub-agent loop.
    provider_config: LLMProviderConfig,
    /// Available sub-agent definitions (built-in and custom).
    catalog: Arc<SubagentCatalog>,
    /// Optional bus for live events; `None` disables broadcasting.
    bus: Option<Arc<MessageBus>>,
    /// Session store handed to the live emitter.
    store: Arc<SessionStore>,
}
impl DefaultSubAgentRuntime {
pub fn new(
config: SubAgentRuntimeConfig,
task_repository: Arc<dyn TaskRepository>,
conversation_repository: Arc<dyn ConversationRepository>,
subagent_tools: Arc<ToolRegistry>,
provider_config: LLMProviderConfig,
catalog: Arc<SubagentCatalog>,
bus: Option<Arc<MessageBus>>,
store: Arc<SessionStore>,
) -> Self {
Self {
config,
task_repository,
conversation_repository,
subagent_tools,
provider_config,
catalog,
bus,
store,
}
}
/// 查找子代理定义,找不到时 fallback 到 general
fn find_subagent_def(&self, type_name: &str) -> SubagentDef {
self.catalog
.find(type_name)
.cloned()
.unwrap_or_else(|| self.catalog.find("general").expect("general subagent must exist").clone())
}
/// 获取实际使用的工具白名单(预留,未来可用于动态工具过滤)
#[allow(dead_code)]
fn effective_allowed_tools(&self, def: &SubagentDef) -> HashSet<String> {
def.allowed_tools
.as_ref()
.map(|tools| tools.iter().cloned().collect())
.unwrap_or_else(|| self.config.default_allowed_tools.clone())
}
/// 获取实际执行时间
fn effective_max_execution_secs(&self, def: &SubagentDef) -> u64 {
def.max_execution_secs
.unwrap_or(self.config.default_max_execution_secs)
}
/// 创建子代理实例
fn create_subagent(
&self,
session: &TaskSession,
system_prompt: String,
parent_nesting_depth: u32,
parent_task_id: Option<String>,
) -> Result<AgentLoop, TaskError> {
let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt));
// 孙智能体depth >= 2不注册 task 工具,防止无限嵌套
let child_depth = parent_nesting_depth + 1;
let tools = if child_depth >= 2 {
Arc::new(self.subagent_tools.without(&[TaskTool::TOOL_NAME]))
} else {
self.subagent_tools.clone()
};
AgentLoop::with_tools_and_system_prompt_provider(
AgentRuntimeConfig::from(self.provider_config.clone()),
tools,
prompt_provider,
None, // 子代理不需要 skill provider
)
.map(|agent| {
let agent = agent.with_tool_context(ToolContext {
channel_name: Some(session.parent_channel_name.clone()),
sender_id: None,
chat_id: Some(session.parent_chat_id.clone()),
session_id: Some(session.session_id.clone()),
topic_id: session.parent_topic_id.clone(),
message_id: None,
message_seq: None,
subagent_description: Some(session.description.clone()),
nesting_depth: parent_nesting_depth + 1,
task_id: Some(session.id.clone()),
parent_task_id,
tool_call_id: None,
});
// 如果有 MessageBus附加实时广播 emitter
if let Some(bus) = &self.bus {
let mut metadata = HashMap::new();
metadata.insert("subagent_task_id".to_string(), session.id.clone());
metadata.insert("is_subagent_event".to_string(), "true".to_string());
metadata.insert("topic_id".to_string(), session.parent_topic_id.clone().unwrap_or_default());
let emitter = Arc::new(PersistingEmittedMessageHandler::new(
SubAgentEmitter {
bus: bus.clone(),
channel_name: session.parent_channel_name.clone(),
chat_id: session.parent_chat_id.clone(),
metadata,
store: self.store.clone(),
task_id: session.id.clone(),
stream_message_id: std::sync::Mutex::new(None),
},
self.conversation_repository.clone(),
session.session_id.clone(),
session.parent_topic_id.clone(),
));
return agent.with_emitted_message_handler(emitter);
}
agent
})
.map_err(|e| TaskError::AgentCreationFailed(e.to_string()))
}
/// 执行任务(带超时控制)
async fn execute_task(
&self,
agent: AgentLoop,
session: &TaskSession,
def: &SubagentDef,
prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 构建初始消息
let history = vec![ChatMessage::user(prompt)];
let system_prompt_context = SystemPromptContext {
session_id: Some(session.session_id.clone()),
chat_id: session.session_id.clone(),
user_message_count: 1,
};
// 设置超时
let max_secs = if session.subagent_type == "explore" {
self.config.explore_max_execution_secs
} else {
self.effective_max_execution_secs(def)
};
let timeout_duration = Duration::from_secs(max_secs);
let result = tokio::time::timeout(
timeout_duration,
agent.process(history, Some(&system_prompt_context)),
)
.await;
match result {
Ok(Ok(process_result)) => {
let final_message = process_result.final_response;
Ok(TaskToolResult {
status: "success".to_string(),
summary: extract_summary(&final_message.content),
output: final_message.content,
task_id: session.id.clone(),
})
}
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
Err(_) => Err(TaskError::Timeout),
}
}
/// 使用历史继续执行
async fn execute_task_with_history(
&self,
agent: AgentLoop,
session: &TaskSession,
additional_prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 加载历史 + 新消息
let mut history = self
.conversation_repository
.load_messages(&session.session_id)
.map_err(TaskError::RepositoryError)?;
history.push(ChatMessage::user(additional_prompt));
let user_message_count = history.iter().filter(|m| m.role == "user").count();
let system_prompt_context = SystemPromptContext {
session_id: Some(session.session_id.clone()),
chat_id: session.session_id.clone(),
user_message_count,
};
// 使用默认执行时间(恢复任务时原始定义可能已不存在)
let timeout_duration = Duration::from_secs(self.config.default_max_execution_secs);
let result = tokio::time::timeout(
timeout_duration,
agent.process(history, Some(&system_prompt_context)),
)
.await;
match result {
Ok(Ok(process_result)) => {
let final_message = process_result.final_response;
Ok(TaskToolResult {
status: "success".to_string(),
summary: extract_summary(&final_message.content),
output: final_message.content,
task_id: session.id.clone(),
})
}
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
Err(_) => Err(TaskError::Timeout),
}
}
}
#[async_trait]
impl SubAgentRuntime for DefaultSubAgentRuntime {
async fn spawn(
&self,
parent_context: &ToolContext,
task: TaskDefinition,
) -> Result<TaskToolResult, TaskError> {
// 1. 验证上下文
let session_id = parent_context
.session_id
.clone()
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
let chat_id = parent_context
.chat_id
.clone()
.ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?;
let channel_name = parent_context
.channel_name
.clone()
.ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?;
// 2. 查找子代理定义
let def = self.find_subagent_def(task.subagent_type.as_str());
// 3. 创建任务会话
let topic_id = parent_context.topic_id.clone();
let session = TaskSession::new(
session_id,
topic_id,
chat_id,
channel_name,
task.description.clone(),
task.subagent_type,
);
// 4. 在 sessions 表中创建子智能体会话(确保外键约束满足)
let session_title = format!("Subagent [{}]: {}", session.subagent_type, task.description);
if let Err(e) = self.conversation_repository.ensure_session(
&session.session_id,
&session.parent_channel_name,
&session.parent_chat_id,
&session_title,
) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session");
}
// 5. 保存任务会话
tracing::info!(
task_id = %session.id,
session_id = %session.session_id,
description = %session.description,
subagent_type = %session.subagent_type,
"Spawning sub-agent task"
);
self.task_repository.save_task_session(&session).await?;
// 5.1 立即通知前端 task_id让前端可以显示"查看实时进度"按钮)
if let Some(bus) = &self.bus {
let mut metadata = HashMap::new();
metadata.insert("task_id".to_string(), session.id.clone());
metadata.insert("task_description".to_string(), session.description.clone());
metadata.insert("task_subagent_type".to_string(), session.subagent_type.clone());
metadata.insert("topic_id".to_string(), session.parent_topic_id.clone().unwrap_or_default());
// 如果是子智能体创建的孙智能体,传递父 task_id
if let Some(ref ptid) = parent_context.task_id {
metadata.insert("parent_task_id".to_string(), ptid.clone());
}
// 传递 tool_call_id前端据此精确匹配创建此任务的 tool_call
if let Some(ref tcid) = parent_context.tool_call_id {
metadata.insert("tool_call_id".to_string(), tcid.clone());
}
let event = OutboundMessage {
channel: session.parent_channel_name.clone(),
chat_id: session.parent_chat_id.clone(),
session_id: Some(session.parent_session_id.clone()),
content: String::new(),
reply_to: None,
media: Vec::new(),
metadata,
event_kind: OutboundEventKind::TaskStarted,
role: "system".to_string(),
tool_call_id: None,
tool_name: None,
tool_arguments: None,
reasoning_content: None,
message_id: None,
};
if let Err(e) = bus.publish_outbound(event).await {
tracing::warn!(error = %e, task_id = %session.id, "Failed to publish TaskStarted event");
}
}
// 6. 构建子代理系统提示词
let system_prompt = SubagentPromptBuilder::build(
&def,
&task.description,
&task.prompt,
&self.provider_config,
self.config.skills_index.as_deref(),
);
// 7. 创建子代理
let agent = self.create_subagent(&session, system_prompt, parent_context.nesting_depth, parent_context.task_id.clone())?;
// 8. 执行任务
let result = self
.execute_task(agent, &session, &def, task.prompt.clone())
.await;
// 9. 更新会话状态并保存
match result {
Ok(tool_result) => {
let mut session = session;
session.mark_completed(tool_result.summary.clone());
tracing::info!(
task_id = %session.id,
session_id = %session.session_id,
"Task completed, updating session"
);
self.task_repository.save_task_session(&session).await?;
Ok(tool_result)
}
Err(e) => {
let mut session = session;
let status = e.as_status();
tracing::warn!(
task_id = %session.id,
session_id = %session.session_id,
status = %status,
error = %e,
"Task failed, updating session"
);
if status == "timeout" {
session.mark_timeout();
} else {
session.mark_failed(e.to_string());
}
self.task_repository.save_task_session(&session).await?;
Err(e)
}
}
}
async fn resume(
&self,
task_id: &str,
parent_context: &ToolContext,
additional_prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 1. 加载现有会话
let session = self
.task_repository
.load_task_session(task_id)
.await?
.ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?;
// 2. 验证父会话匹配
let parent_session_id = parent_context
.session_id
.clone()
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
if session.parent_session_id != parent_session_id {
return Err(TaskError::InvalidParentSession);
}
// 3. 确保 sessions 表中存在子智能体会话记录
let session_title = format!("Subagent [{}]: {}", session.subagent_type, session.description);
if let Err(e) = self.conversation_repository.ensure_session(
&session.session_id,
&session.parent_channel_name,
&session.parent_chat_id,
&session_title,
) {
tracing::warn!(error = %e, session_id = %session.session_id, "Failed to ensure subagent session on resume");
}
// 4. 构建恢复提示词
let system_prompt = SubagentPromptBuilder::build_resume_prompt(
&session.description,
&additional_prompt,
);
// 5. 创建子代理
let agent = self.create_subagent(&session, system_prompt, parent_context.nesting_depth, parent_context.task_id.clone())?;
// 6. 使用历史继续执行
let result = self
.execute_task_with_history(agent, &session, additional_prompt)
.await;
// 7. 更新会话状态
match result {
Ok(tool_result) => {
let mut session = session;
session.mark_completed(tool_result.summary.clone());
self.task_repository.save_task_session(&session).await?;
Ok(tool_result)
}
Err(e) => {
let mut session = session;
session.mark_failed(e.to_string());
self.task_repository.save_task_session(&session).await?;
Err(e)
}
}
}
async fn send_message(&self, _task_id: &str, _message: String) -> Result<(), TaskError> {
// TODO: 实现双向通信
// 需要在 TaskSession 中添加 pending_messages 队列
Err(TaskError::InvalidArguments("send_message not implemented yet".to_string()))
}
async fn cleanup_expired(&self) -> Result<usize, TaskError> {
self.task_repository
.cleanup_expired_tasks(self.config.ttl_hours)
.await
.map_err(TaskError::from)
}
fn available_subagent_names(&self) -> Vec<String> {
self.catalog.names()
}
}
/// 子代理定义目录
///
/// 管理所有可用的子代理定义,包括内置和自定义。
/// 支持用户级(~/.picobot/subagents/)和项目级(./.picobot/subagents/)定义,
/// 项目级定义会覆盖同名的用户级定义。
#[derive(Debug, Default)]
pub struct SubagentCatalog {
definitions: std::collections::HashMap<String, SubagentDef>,
}
impl SubagentCatalog {
/// 创建空的目录,并注册内置子代理
pub fn new() -> Self {
let mut catalog = Self::default();
catalog.register(SubagentDef::builtin_general());
catalog.register(SubagentDef::builtin_explore());
catalog
}
/// 从配置发现子代理(内置 + 文件系统自定义)
///
/// 发现顺序:先内置,后按 sources 配置顺序扫描目录
/// 后发现的同名定义会覆盖先发现的(项目覆盖用户)
pub fn discover(config: &SubagentsConfig) -> Self {
let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
Self::discover_with_cwd(config, &cwd)
}
fn discover_with_cwd(config: &SubagentsConfig, cwd: &Path) -> Self {
// 先内置作为基础
let mut merged: std::collections::HashMap<String, SubagentDef> = std::collections::HashMap::new();
merged.insert("general".to_string(), SubagentDef::builtin_general());
merged.insert("explore".to_string(), SubagentDef::builtin_explore());
tracing::debug!(cwd = %cwd.display(), "Discovering subagents from cwd");
// 按配置顺序扫描源目录
if config.enabled {
for source in source_order(&config.sources) {
let root = source_root(&source, cwd);
tracing::debug!(source = ?source, root = ?root.as_ref().map(|p| p.display().to_string()), "Checking subagent source");
if let Some(root) = root {
if root.exists() {
tracing::info!(path = %root.display(), "Scanning subagents directory");
} else {
tracing::debug!(path = %root.display(), "Subagents directory does not exist, skipping");
}
for def in load_subagents_from_root(&root, source.clone()) {
if let Some(existing) = merged.get(&def.name) {
tracing::warn!(
subagent = %def.name,
old_source = ?existing.source,
new_source = ?def.source,
"Duplicate subagent name found; overriding with later source"
);
}
merged.insert(def.name.clone(), def);
}
}
}
} else {
tracing::debug!("Subagents discovery is disabled");
}
// 构建 catalog
let mut catalog = Self::default();
for def in merged.into_values() {
catalog.register(def);
}
tracing::info!(
discovered = catalog.definitions.len(),
"Subagents discovery completed"
);
catalog
}
/// 注册一个子代理定义(同名覆盖)
pub fn register(&mut self, def: SubagentDef) {
self.definitions.insert(def.name.clone(), def);
}
/// 查找子代理定义
pub fn find(&self, name: &str) -> Option<&SubagentDef> {
self.definitions.get(name)
}
/// 获取所有可用的子代理名称
pub fn names(&self) -> Vec<String> {
self.definitions.keys().cloned().collect()
}
/// 获取所有可用的子代理定义(用于生成索引提示)
pub fn all(&self) -> Vec<&SubagentDef> {
self.definitions.values().collect()
}
/// 生成系统索引提示词(用于注入主 agent
pub fn system_index_prompt(&self) -> Option<String> {
let defs = self.all();
if defs.is_empty() {
return None;
}
let mut prompt = String::from(
"# 子代理系统\n\n\
子代理是专用的执行单元,用于处理特定类型的任务。\n\
创建子代理任务时,可以选择以下类型之一:\n\n\
<available_subagents>\n"
);
for def in defs {
prompt.push_str(&format!(
" <subagent>\n <name>{}</name>\n <description>{}</description>\n </subagent>\n",
xml_escape(&def.name),
xml_escape(&def.description),
));
}
prompt.push_str("</available_subagents>");
Some(prompt)
}
}
/// Escapes the five XML special characters (`& < > " '`) in `s`.
///
/// Done in one pass over the characters, so an `&` introduced by an entity is never
/// escaped twice.
fn xml_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            other => escaped.push(other),
        }
    }
    escaped
}
// ========== Custom sub-agent discovery ==========

/// Resolves configured source names into an ordered, de-duplicated list.
///
/// `"user"` and `"project"` map to their dedicated variants; any other string becomes a
/// custom path source. Only the first occurrence of each source is kept. When nothing is
/// configured, the default is `[User, Project]`, so project definitions override user ones.
fn source_order(sources: &[String]) -> Vec<SubagentSource> {
    let mut ordered: Vec<SubagentSource> = Vec::with_capacity(sources.len());
    for raw in sources {
        let candidate = match raw.as_str() {
            "user" => SubagentSource::User,
            "project" => SubagentSource::Project,
            other => SubagentSource::Custom(other.to_string()),
        };
        if !ordered.contains(&candidate) {
            ordered.push(candidate);
        }
    }
    if ordered.is_empty() {
        return vec![SubagentSource::User, SubagentSource::Project];
    }
    ordered
}
/// Returns the directory scanned for `source`, or `None` if it has none.
///
/// User and project sources live under `.picobot/subagents` in the home directory and
/// in `cwd` respectively. Built-ins have no directory. A custom source must be an absolute
/// path; relative ones are skipped with a warning.
fn source_root(source: &SubagentSource, cwd: &Path) -> Option<std::path::PathBuf> {
    match source {
        SubagentSource::Builtin => None,
        SubagentSource::User => {
            dirs::home_dir().map(|home| home.join(".picobot").join("subagents"))
        }
        SubagentSource::Project => Some(cwd.join(".picobot").join("subagents")),
        SubagentSource::Custom(raw) => {
            let candidate = std::path::PathBuf::from(raw);
            if !candidate.is_absolute() {
                tracing::warn!(path = %raw, "Custom subagents source must be an absolute path, skipping");
                return None;
            }
            Some(candidate)
        }
    }
}
/// YAML frontmatter at the top of a `SUBAGENT.md` file.
///
/// Only `description` is mandatory; every other key may be omitted.
#[derive(Debug, Deserialize)]
struct SubagentFrontmatter {
    /// Sub-agent name; when absent, the containing directory name is used.
    #[serde(default)]
    name: Option<String>,
    /// Description listed in the sub-agent index prompt (required).
    description: String,
    /// Optional prompt template.
    #[serde(default)]
    prompt_template: Option<String>,
    /// Optional tool allow-list overriding the runtime default.
    #[serde(default)]
    allowed_tools: Option<Vec<String>>,
    /// Optional execution time limit, in seconds.
    #[serde(default)]
    max_execution_secs: Option<u64>,
}
/// 从根目录加载所有子代理
fn load_subagents_from_root(root: &Path, source: SubagentSource) -> Vec<SubagentDef> {
let mut out = Vec::new();
if !root.exists() {
tracing::debug!(path = %root.display(), "Subagents root directory does not exist");
return out;
}
tracing::debug!(path = %root.display(), "Reading subagents directory");
let entries = match fs::read_dir(root) {
Ok(entries) => entries,
Err(err) => {
tracing::warn!(path = %root.display(), error = %err, "Failed to read subagents directory");
return out;
}
};
let mut found_dirs = 0;
let mut found_files = 0;
for entry in entries.flatten() {
let path = entry.path();
if !path.is_dir() {
tracing::debug!(path = %path.display(), "Skipping non-directory entry");
continue;
}
found_dirs += 1;
let subagent_md = path.join("SUBAGENT.md");
tracing::debug!(dir = %path.display(), subagent_file = %subagent_md.display(), "Checking subagent directory");
if !subagent_md.exists() {
tracing::debug!(path = %subagent_md.display(), "SUBAGENT.md not found");
continue;
}
found_files += 1;
match parse_subagent_file(&subagent_md, source.clone()) {
Ok(def) => {
tracing::info!(name = %def.name, path = %subagent_md.display(), "Loaded subagent");
out.push(def);
}
Err(err) => {
tracing::warn!(path = %subagent_md.display(), error = %err, "Skipping invalid subagent file");
}
}
}
tracing::debug!(path = %root.display(), dirs = found_dirs, files = found_files, loaded = out.len(), "Subagents scan completed");
out
}
/// Parses one `SUBAGENT.md` file into a [`SubagentDef`].
///
/// The file must start with YAML frontmatter (see [`SubagentFrontmatter`]) followed by
/// an optional Markdown body. The name defaults to the containing directory name when it
/// is missing or blank. All text fields are trimmed; an empty body becomes `None`.
///
/// # Errors
/// A human-readable message if the file cannot be read, has no frontmatter, contains
/// invalid YAML, or has an empty `description`.
fn parse_subagent_file(path: &Path, source: SubagentSource) -> Result<SubagentDef, String> {
    let content = fs::read_to_string(path)
        .map_err(|e| format!("failed to read file: {}", e))?;
    let (frontmatter_raw, body) = split_frontmatter(&content)
        .ok_or_else(|| "missing YAML frontmatter block".to_string())?;
    let frontmatter: SubagentFrontmatter = serde_yaml::from_str(frontmatter_raw)
        .map_err(|e| format!("invalid YAML frontmatter: {}", e))?;
    let description = frontmatter.description.trim().to_string();
    if description.is_empty() {
        return Err("description is required and cannot be empty".to_string());
    }
    // An explicit but blank `name:` would register a nameless sub-agent; treat it as missing.
    let name = frontmatter
        .name
        .as_deref()
        .map(str::trim)
        .filter(|n| !n.is_empty())
        .map(str::to_string)
        .unwrap_or_else(|| {
            path.parent()
                .and_then(|p| p.file_name())
                .map(|s| s.to_string_lossy().trim().to_string())
                .unwrap_or_else(|| "unknown-subagent".to_string())
        });
    let prompt_template = frontmatter.prompt_template.unwrap_or_default().trim().to_string();
    let body_content = body.trim().to_string();
    Ok(SubagentDef {
        name,
        description,
        prompt_template,
        body: if body_content.is_empty() { None } else { Some(body_content) },
        allowed_tools: frontmatter.allowed_tools,
        max_execution_secs: frontmatter.max_execution_secs,
        source,
        path: Some(path.to_path_buf()),
    })
}
/// Splits a Markdown document into its YAML frontmatter and body.
///
/// The document must start with a `---` fence line; a leading UTF-8 BOM is tolerated.
/// The frontmatter ends at the first following line that consists solely of `---`, with
/// trailing whitespace or `\r` allowed. The scan is line-based, so a Markdown `---` rule
/// in the body can never be mistaken for the closing fence, and `----` is not a fence.
/// Trailing line breaks are removed from the frontmatter and leading line breaks from the
/// body.
///
/// Returns `None` if either fence is missing or the opening fence line has extra text.
fn split_frontmatter(content: &str) -> Option<(&str, &str)> {
    // Editors on Windows often prepend a BOM; it must not hide the opening fence.
    let content = content.strip_prefix('\u{feff}').unwrap_or(content);
    let after_open = content.strip_prefix("---")?;
    // The opening fence must occupy its own line.
    let (open_rest, rest) = match after_open.find('\n') {
        Some(pos) => (&after_open[..pos], &after_open[pos + 1..]),
        None => (after_open, ""),
    };
    if !open_rest.trim().is_empty() {
        return None;
    }
    let line_breaks: &[char] = &['\r', '\n'];
    let mut offset = 0;
    for line in rest.split_inclusive('\n') {
        if line.trim_end() == "---" {
            let frontmatter = rest[..offset].trim_end_matches(line_breaks);
            let body = rest[offset + line.len()..].trim_start_matches(line_breaks);
            return Some((frontmatter, body));
        }
        offset += line.len();
    }
    None
}