增加sub-agent机制。

This commit is contained in:
xiaoski 2026-05-25 23:23:10 +08:00
parent a77c026826
commit 41b4895ff0
12 changed files with 1357 additions and 20 deletions

View File

@ -12,6 +12,8 @@ serde_json = "1.0"
async-trait = "0.1" async-trait = "0.1"
thiserror = "2.0.18" thiserror = "2.0.18"
tokio = { version = "1.52", features = ["full"] } tokio = { version = "1.52", features = ["full"] }
tokio-util = { version = "0.7", features = ["rt"] }
dashmap = "6.1"
uuid = { version = "1.23", features = ["v4"] } uuid = { version = "1.23", features = ["v4"] }
axum = { version = "0.8", features = ["ws"] } axum = { version = "0.8", features = ["ws"] }
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] } tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }

View File

@ -1,8 +1,10 @@
pub mod agent_loop; pub mod agent_loop;
pub mod context_compressor; pub mod context_compressor;
pub mod media_handler; pub mod media_handler;
pub mod sub_agent;
pub mod system_prompt; pub mod system_prompt;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
pub use context_compressor::{ContextCompressor, estimate_tokens}; pub use context_compressor::{ContextCompressor, estimate_tokens};
pub use sub_agent::{DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, TaskNotification, TaskStatus};
pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder}; pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder};

611
src/agent/sub_agent.rs Normal file
View File

@ -0,0 +1,611 @@
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::agent::AgentLoop;
use crate::agent::AgentError;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, LLMProvider};
use crate::tools::ToolRegistry;
tokio::task_local! {
pub(crate) static DELEGATE_CONTEXT: DelegateContext;
}
/// Read the delegate context from the current task. Returns an error if not set.
pub fn get_delegate_context() -> Result<DelegateContext, String> {
DELEGATE_CONTEXT.try_with(|ctx| ctx.clone())
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
}
const DEFAULT_MAX_ITERATIONS: usize = 99;
const DEFAULT_TIMEOUT_SECS: u64 = 3600;
const MAX_INLINE_RESULT_CHARS: usize = 8000;
const DEFAULT_READONLY_TOOLS: &[&str] = &[
"file_read",
"file_search",
"content_search",
"web_fetch",
"http_request",
"calculator",
];
#[derive(Debug, Clone)]
pub struct SubAgentConfig {
pub prompt: String,
pub mode: ExecutionMode,
pub allowed_tools: Option<Vec<String>>,
pub max_iterations: Option<usize>,
pub timeout_secs: Option<u64>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionMode {
Inline,
Background,
Parallel,
}
#[derive(Debug, Clone)]
pub struct SubAgentResult {
pub task_id: String,
pub content: String,
pub content_truncated: bool,
pub status: TaskStatus,
pub tool_calls_count: usize,
pub iterations: usize,
pub duration_ms: u64,
}
#[derive(Debug, Clone)]
pub enum TaskStatus {
Completed,
Failed(String),
Cancelled,
TimedOut,
}
#[derive(Debug, Clone)]
pub struct TaskNotification {
pub task_id: String,
pub session_id: String,
pub channel: String,
pub chat_id: String,
pub status: TaskStatus,
pub result_summary: String,
}
#[derive(Debug, Clone)]
pub struct DelegateContext {
pub session_id: String,
pub channel: String,
pub chat_id: String,
}
#[derive(Debug)]
pub enum SubAgentError {
TooManyTasks(usize),
ProviderCreation(String),
Storage(String),
Other(String),
}
impl std::fmt::Display for SubAgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TooManyTasks(max) => write!(f, "后台任务已达上限({}),请稍后重试", max),
Self::ProviderCreation(e) => write!(f, "provider creation failed: {}", e),
Self::Storage(e) => write!(f, "storage error: {}", e),
Self::Other(e) => write!(f, "{}", e),
}
}
}
impl std::error::Error for SubAgentError {}
pub struct SubAgentManager {
provider_config: LLMProviderConfig,
full_tools: Arc<ToolRegistry>,
storage: Option<Arc<crate::storage::Storage>>,
active_tasks: Arc<DashMap<String, CancellationToken>>,
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
max_concurrent_background_tasks: usize,
}
impl SubAgentManager {
pub fn new(
provider_config: LLMProviderConfig,
full_tools: Arc<ToolRegistry>,
storage: Option<Arc<crate::storage::Storage>>,
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
max_concurrent_background_tasks: usize,
) -> Self {
Self {
provider_config,
full_tools,
storage,
active_tasks: Arc::new(DashMap::new()),
notify_tx,
max_concurrent_background_tasks,
}
}
pub fn filter_tools(&self, allowed: &Option<Vec<String>>) -> Arc<ToolRegistry> {
let allowed_set: HashSet<&str> = match allowed {
Some(list) => list.iter().map(|s| s.as_str()).collect(),
None => DEFAULT_READONLY_TOOLS.iter().copied().collect(),
};
let filtered = ToolRegistry::new();
for (name, tool) in self.full_tools.iter() {
if allowed_set.contains(name.as_str()) && name != "delegate" {
filtered.register_raw(name, tool);
}
}
Arc::new(filtered)
}
pub fn build_system_prompt(&self, config: &SubAgentConfig, tools: &ToolRegistry) -> String {
let timeout_human = format_duration(config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS));
let tool_descriptions = tools.describe_for_prompt();
let http_only_note = if config.allowed_tools.is_none()
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request"))
{
"- When using http_request, only the GET method is permitted. \
Do NOT use POST, PUT, DELETE, or any other method."
} else {
""
};
format!(
"You are a sub-agent working on a delegated task. Complete the task below \
and return a single, self-contained result.\n\
\n\
## Task\n\
{task}\n\
\n\
## Rules\n\
- Focus ONLY on the task above. Do not explore unrelated topics.\n\
- Use tools only when necessary for the task.\n\
- Do NOT use the delegate tool sub-agent recursion is forbidden.\n\
- If the task cannot be completed, explain why clearly.\n\
- Return only the final result. Do not describe your process.\n\
{http_only}\n\
- Timeout: {timeout_human}. If approaching the limit, return partial results.\n\
\n\
## Available Tools\n\
{tool_descriptions}\n\
\n\
## Workspace\n\
{workspace}",
task = config.prompt,
http_only = http_only_note,
timeout_human = timeout_human,
tool_descriptions = tool_descriptions,
workspace = self.provider_config.workspace_dir.display(),
)
}
pub fn build_sub_agent(
&self,
config: &SubAgentConfig,
tools: Arc<ToolRegistry>,
) -> Result<AgentLoop, AgentError> {
let mut provider = create_provider(self.provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
if let Some(ref s) = self.storage {
provider.set_storage(s.clone());
}
let provider: Arc<dyn LLMProvider> = Arc::from(provider);
let max_iterations = config.max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS);
let workspace_dir = self.provider_config.workspace_dir.clone();
let model_name = self.provider_config.model_id.clone();
let input_types = self.provider_config.input_types.clone();
let agent = AgentLoop::with_provider_and_tools(
provider,
tools,
max_iterations,
model_name,
workspace_dir,
input_types,
)
.with_context_window(self.provider_config.token_limit);
Ok(agent)
}
pub async fn run_inline(
&self,
config: SubAgentConfig,
) -> Result<SubAgentResult, SubAgentError> {
let task_id = generate_task_id();
let tools = self.filter_tools(&config.allowed_tools);
let system_prompt = self.build_system_prompt(&config, &tools);
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let agent = self.build_sub_agent(&config, tools)
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
let history = vec![
ChatMessage::system(system_prompt),
ChatMessage::user(&config.prompt),
];
let start = Instant::now();
let result = tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
agent.process(history),
)
.await;
let duration_ms = start.elapsed().as_millis() as u64;
match result {
Ok(Ok(agent_result)) => {
let (content, truncated) =
truncate_sub_agent_result(&agent_result.final_response.content);
let tool_calls_count = agent_result.emitted_messages.iter()
.filter(|m| m.tool_calls.is_some())
.count();
let iterations = agent_result.emitted_messages.iter()
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
.count();
Ok(SubAgentResult {
task_id,
content,
content_truncated: truncated,
status: TaskStatus::Completed,
tool_calls_count,
iterations,
duration_ms,
})
}
Ok(Err(e)) => Ok(SubAgentResult {
task_id,
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed(e.to_string()),
tool_calls_count: 0,
iterations: 0,
duration_ms,
}),
Err(_elapsed) => Ok(SubAgentResult {
task_id,
content: String::new(),
content_truncated: false,
status: TaskStatus::TimedOut,
tool_calls_count: 0,
iterations: 0,
duration_ms,
}),
}
}
pub async fn run_parallel(
&self,
configs: Vec<SubAgentConfig>,
) -> Result<Vec<SubAgentResult>, SubAgentError> {
let futures: Vec<_> = configs
.into_iter()
.map(|config| {
let mgr = self; // &self borrow, all tasks share the same manager
async move { mgr.run_inline(config).await }
})
.collect();
let results = futures_util::future::join_all(futures).await;
Ok(results.into_iter().collect::<Result<Vec<_>, _>>()?)
}
pub async fn run_background(
&self,
config: SubAgentConfig,
ctx: DelegateContext,
) -> Result<String, SubAgentError> {
if self.active_tasks.len() >= self.max_concurrent_background_tasks {
return Err(SubAgentError::TooManyTasks(
self.max_concurrent_background_tasks,
));
}
let task_id = generate_task_id();
let cancel_token = CancellationToken::new();
self.active_tasks
.insert(task_id.clone(), cancel_token.clone());
// Write DB: pending
if let Some(ref storage) = self.storage {
let allowed_tools_json = config
.allowed_tools
.as_ref()
.and_then(|v| serde_json::to_string(v).ok());
let record = crate::storage::BackgroundTask {
id: task_id.clone(),
session_id: ctx.session_id.clone(),
channel: ctx.channel.clone(),
chat_id: ctx.chat_id.clone(),
prompt: config.prompt.clone(),
allowed_tools: allowed_tools_json,
status: "pending".to_string(),
result: None,
error: None,
tool_calls_count: 0,
iterations: 0,
started_at: None,
finished_at: None,
created_at: chrono::Utc::now().timestamp_millis(),
};
storage
.create_background_task(&record)
.await
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
}
let tools = self.filter_tools(&config.allowed_tools);
let system_prompt = self.build_system_prompt(&config, &tools);
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let provider_config = self.provider_config.clone();
let storage = self.storage.clone();
let notify_tx = self.notify_tx.clone();
let active_tasks = Arc::clone(&self.active_tasks);
let tid = task_id.clone();
let sess_id = ctx.session_id.clone();
let ch = ctx.channel.clone();
let cid = ctx.chat_id.clone();
let prompt = config.prompt.clone();
tokio::spawn(async move {
let started_at = chrono::Utc::now().timestamp_millis();
// Update DB: running
if let Some(ref s) = storage {
let _ = s
.update_background_task_status(
&tid, "running", None, None,
Some(started_at), None,
)
.await;
}
let mut provider = create_provider(provider_config.clone()).ok();
if let Some(ref mut p) = provider {
if let Some(ref s) = storage {
p.set_storage(s.clone());
}
}
let provider_result: Option<Arc<dyn LLMProvider>> =
provider.map(|p| Arc::from(p));
let result = match provider_result {
Some(provider) => {
let agent = AgentLoop::with_provider_and_tools(
provider,
tools,
DEFAULT_MAX_ITERATIONS,
provider_config.model_id.clone(),
provider_config.workspace_dir.clone(),
provider_config.input_types.clone(),
)
.with_context_window(provider_config.token_limit);
let history = vec![
ChatMessage::system(system_prompt),
ChatMessage::user(&prompt),
];
tokio::select! {
r = tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
agent.process(history),
) => {
match r {
Ok(Ok(agent_result)) => SubAgentResult {
task_id: tid.clone(),
content: agent_result.final_response.content,
content_truncated: false,
status: TaskStatus::Completed,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
Ok(Err(e)) => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed(e.to_string()),
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
Err(_) => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::TimedOut,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
}
}
_ = cancel_token.cancelled() => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Cancelled,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
}
}
None => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed("provider creation failed".into()),
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
};
let finished_at = chrono::Utc::now().timestamp_millis();
let duration_ms = (finished_at - started_at) as u64;
let (status_str, error_val) = match &result.status {
TaskStatus::Completed => ("completed".to_string(), None),
TaskStatus::Failed(e) => ("failed".to_string(), Some(e.clone())),
TaskStatus::Cancelled => ("cancelled".to_string(), None),
TaskStatus::TimedOut => ("failed".to_string(), Some("timeout".to_string())),
};
if let Some(ref s) = storage {
let _ = s
.update_background_task_status(
&tid, &status_str,
Some(&result.content), error_val.as_deref(),
Some(started_at), Some(finished_at),
)
.await;
}
let _ = notify_tx.send(TaskNotification {
task_id: tid.clone(),
session_id: sess_id,
channel: ch,
chat_id: cid,
status: result.status,
result_summary: summarize_for_notification(&result.content, duration_ms),
});
active_tasks.remove(&tid);
});
Ok(task_id)
}
pub async fn cancel_task(&self, task_id: &str) -> Result<bool, SubAgentError> {
if let Some((_, token)) = self.active_tasks.remove(task_id) {
token.cancel();
if let Some(ref s) = self.storage {
s.update_background_task_status(
task_id,
"cancelled",
None,
None,
None,
Some(chrono::Utc::now().timestamp_millis()),
)
.await
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
}
Ok(true)
} else if let Some(ref s) = self.storage {
match s.get_background_task(task_id).await {
Ok(task) => {
match task.status.as_str() {
"pending" | "running" => {
tracing::warn!(task_id, "task in DB but not in active_tasks");
Ok(false)
}
_ => Ok(false),
}
}
Err(_) => Ok(false),
}
} else {
Ok(false)
}
}
pub async fn check_task(
&self,
task_id: &str,
) -> Option<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage {
s.get_background_task(task_id).await.ok()
} else {
None
}
}
pub async fn list_tasks(
&self,
session_id: &str,
) -> Vec<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage {
s.list_background_tasks(session_id).await.unwrap_or_default()
} else {
vec![]
}
}
pub async fn cancel_by_session(&self, session_id: &str) {
// Cancel all running tasks for a session by checking DB
if let Some(ref s) = self.storage {
if let Ok(tasks) = s.list_background_tasks(session_id).await {
for task in &tasks {
if task.status == "pending" || task.status == "running" {
let _ = self.cancel_task(&task.id).await;
}
}
}
}
}
pub fn active_task_count(&self) -> usize {
self.active_tasks.len()
}
}
fn generate_task_id() -> String {
Uuid::new_v4().to_string()[..8].to_string()
}
fn format_duration(seconds: u64) -> String {
if seconds < 60 {
format!("{}s", seconds)
} else if seconds < 3600 {
format!("{}m", seconds / 60)
} else {
format!("{}h", seconds / 3600)
}
}
fn truncate_sub_agent_result(content: &str) -> (String, bool) {
if content.len() <= MAX_INLINE_RESULT_CHARS {
(content.to_string(), false)
} else {
let truncate_at = content.floor_char_boundary(MAX_INLINE_RESULT_CHARS);
(
format!(
"{}\n\n[... 结果已截断,共 {} 字符,完整结果请使用 check_task 查看 ...]",
&content[..truncate_at],
content.len()
),
true,
)
}
}
fn summarize_for_notification(content: &str, _duration_ms: u64) -> String {
const MAX_SUMMARY_BYTES: usize = 500;
if content.len() <= MAX_SUMMARY_BYTES {
content.to_string()
} else {
let truncate_at = content.floor_char_boundary(MAX_SUMMARY_BYTES);
format!("{}...", &content[..truncate_at])
}
}

View File

@ -3,11 +3,7 @@
//! This module provides a modular framework for building system prompts //! This module provides a modular framework for building system prompts
//! using the SystemPromptBuilder pattern. //! using the SystemPromptBuilder pattern.
//! //!
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic //! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic → Delegation
//!
//! Configuration files loaded from ~/.picobot/:
//! - AGENTS.md — agent identity and behavior
//! - USER.md — user preferences and profile
use crate::tools::ToolRegistry; use crate::tools::ToolRegistry;
use std::path::Path; use std::path::Path;
@ -55,6 +51,7 @@ impl SystemPromptBuilder {
Box::new(CrossChannelSection), Box::new(CrossChannelSection),
Box::new(MemorySection), Box::new(MemorySection),
Box::new(HistorySection), Box::new(HistorySection),
Box::new(DelegationSection),
], ],
} }
} }
@ -353,6 +350,45 @@ impl PromptSection for HistorySection {
} }
} }
/// Sub-agent delegation principles.
pub struct DelegationSection;
impl PromptSection for DelegationSection {
fn name(&self) -> &str {
"delegation"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 子 Agent 委托原则\n\n\
使 delegate Agent\n\
\n\
### \n\
- 使 mode=\"parallel\"\n\
- 使 mode=\"background\"\n\
- \n\
\n\
### \n\
- **** Agent \n\
- **** file_readfile_searchweb_fetch bashfile_writefile_edit\n\
- **** delegate Agent\n\
- **** Agent \n\
\n\
### \n\
- prompt \n\
- prompt \"跳过 .tmp 文件\"\n\
- \n\
\n\
### \n\
- 使 mode=\"parallel\",任务定义在 tasks 数组中\n\
- \n\
- 5 \n\
\n\
### \n\
- 30s 使 mode=\"background\"\n\
- ".to_string()
}
}
// === Helper Functions === // === Helper Functions ===
/// Get user config directory (~/.picobot/). /// Get user config directory (~/.picobot/).

View File

@ -152,10 +152,26 @@ pub struct GatewayConfig {
pub cleanup_interval_minutes: Option<u64>, pub cleanup_interval_minutes: Option<u64>,
#[serde(default, rename = "session_db_path")] #[serde(default, rename = "session_db_path")]
pub session_db_path: Option<String>, pub session_db_path: Option<String>,
#[serde(default, rename = "max_concurrent_background_tasks")]
pub max_concurrent_background_tasks: usize,
#[serde(default)] #[serde(default)]
pub scheduler: Option<SchedulerConfig>, pub scheduler: Option<SchedulerConfig>,
} }
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
cleanup_interval_minutes: None,
session_db_path: None,
max_concurrent_background_tasks: 10,
scheduler: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig { pub struct SchedulerConfig {
/// Whether the scheduler is enabled /// Whether the scheduler is enabled
@ -209,19 +225,6 @@ fn default_gateway_url() -> String {
"ws://127.0.0.1:19876/ws".to_string() "ws://127.0.0.1:19876/ws".to_string()
} }
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
cleanup_interval_minutes: None,
session_db_path: None,
scheduler: None,
}
}
}
impl Default for ClientConfig { impl Default for ClientConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {

View File

@ -91,6 +91,7 @@ impl GatewayState {
bus.clone(), bus.clone(),
memory_manager, memory_manager,
browser_config, browser_config,
config.gateway.max_concurrent_background_tasks,
)?; )?;
let session_manager = Arc::new(session_manager); let session_manager = Arc::new(session_manager);

View File

@ -751,6 +751,7 @@ pub struct SessionManager {
storage: Arc<Storage>, storage: Arc<Storage>,
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
memory_manager: Arc<crate::memory::MemoryManager>, memory_manager: Arc<crate::memory::MemoryManager>,
sub_agent_manager: Arc<crate::agent::SubAgentManager>,
} }
struct SessionManagerInner { struct SessionManagerInner {
@ -847,6 +848,7 @@ impl SessionManager {
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
memory_manager: Arc<crate::memory::MemoryManager>, memory_manager: Arc<crate::memory::MemoryManager>,
browser_config: Option<BrowserConfig>, browser_config: Option<BrowserConfig>,
max_concurrent_background_tasks: usize,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let mut skills_loader = SkillsLoader::new(); let mut skills_loader = SkillsLoader::new();
skills_loader.load_skills(); skills_loader.load_skills();
@ -856,9 +858,61 @@ impl SessionManager {
let tools = Arc::new(create_default_tools( let tools = Arc::new(create_default_tools(
skills_loader.clone(), skills_loader.clone(),
memory_manager.clone(), memory_manager.clone(),
None, // SubAgentManager created below
browser_config.as_ref(), browser_config.as_ref(),
)); ));
// Create SubAgentManager and register DelegateTool
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
let sub_agent_manager = Arc::new(crate::agent::SubAgentManager::new(
provider_config.clone(),
tools.clone(),
Some(storage.clone()),
notify_tx,
max_concurrent_background_tasks,
));
tools.register(crate::tools::DelegateTool::new(sub_agent_manager.clone()));
// Start background task notification consumer
let sm_bus = bus.clone();
tokio::spawn(async move {
while let Some(notif) = notify_rx.recv().await {
let content = format_task_notification(
&notif.task_id,
&notif.status,
&notif.result_summary,
);
let outbound = OutboundMessage {
channel: notif.channel,
chat_id: notif.chat_id,
content,
reply_to: None,
media: vec![],
metadata: std::collections::HashMap::new(),
};
let _ = sm_bus.publish_outbound(outbound).await;
}
});
// Start periodic background task cleanup (every hour, TTL 24h)
let cleanup_storage = storage.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(3600));
interval.tick().await; // skip immediate first tick
loop {
interval.tick().await;
match cleanup_storage.cleanup_old_tasks(86_400_000).await {
Ok(count) if count > 0 => {
tracing::info!(count, "Cleaned up old background tasks");
}
Err(e) => {
tracing::warn!(error = %e, "Failed to clean up old background tasks");
}
_ => {}
}
}
});
Ok(Self { Ok(Self {
inner: Arc::new(Mutex::new(SessionManagerInner { inner: Arc::new(Mutex::new(SessionManagerInner {
sessions: HashMap::new(), sessions: HashMap::new(),
@ -870,6 +924,7 @@ impl SessionManager {
storage, storage,
bus, bus,
memory_manager, memory_manager,
sub_agent_manager,
}) })
} }
@ -1073,6 +1128,8 @@ impl SessionManager {
msgs.push("消息队列已清空。".to_string()); msgs.push("消息队列已清空。".to_string());
} }
guard.worker_generation = guard.worker_generation.wrapping_add(1); guard.worker_generation = guard.worker_generation.wrapping_add(1);
// Cancel all running background sub-agent tasks for this session
self.sub_agent_manager.cancel_by_session(&sid.to_string()).await;
let resp = if msgs.is_empty() { let resp = if msgs.is_empty() {
"没有正在执行的任务或队列。".to_string() "没有正在执行的任务或队列。".to_string()
} else { } else {
@ -1469,7 +1526,8 @@ fn spawn_agent_worker(
unified_str: String, unified_str: String,
) { ) {
tokio::spawn(async move { tokio::spawn(async move {
let _scope = CURRENT_SOURCE_SESSION.scope(Some(unified_str), async { let unified_for_source = unified_str.clone();
let _scope = CURRENT_SOURCE_SESSION.scope(Some(unified_for_source), async {
while let Some(task) = task_rx.recv().await { while let Some(task) = task_rx.recv().await {
let task_chan = task.channel.clone(); let task_chan = task.channel.clone();
let task_cid = task.chat_id.clone(); let task_cid = task.chat_id.clone();
@ -1613,8 +1671,17 @@ fn spawn_agent_worker(
let bus2 = bus.clone(); let bus2 = bus.clone();
let chan2 = task_chan.clone(); let chan2 = task_chan.clone();
let cid2 = task_cid.clone(); let cid2 = task_cid.clone();
let unified_str2 = unified_str.clone();
let process_future = async move { let process_future = async move {
let result = match agent.process(history_out.clone()).await { let process_result = crate::agent::sub_agent::DELEGATE_CONTEXT.scope(
crate::agent::DelegateContext {
session_id: unified_str2,
channel: chan2.clone(),
chat_id: cid2.clone(),
},
agent.process(history_out.clone()),
).await;
let result = match process_result {
Ok(r) => r, Ok(r) => r,
Err(AgentError::LlmError(ref msg)) Err(AgentError::LlmError(ref msg))
if is_context_overflow_error(msg) => if is_context_overflow_error(msg) =>
@ -1911,3 +1978,24 @@ mod tests {
} }
} }
} }
fn format_task_notification(task_id: &str, status: &crate::agent::TaskStatus, summary: &str) -> String {
match status {
crate::agent::TaskStatus::Completed => format!(
"📋 后台任务完成\n\n任务 ID: {}\n\n结果:\n{}",
task_id, summary
),
crate::agent::TaskStatus::Failed(err) => format!(
"📋 后台任务失败\n\n任务 ID: {}\n错误: {}",
task_id, err
),
crate::agent::TaskStatus::Cancelled => format!(
"📋 后台任务已取消\n\n任务 ID: {}",
task_id
),
crate::agent::TaskStatus::TimedOut => format!(
"📋 后台任务超时\n\n任务 ID: {}",
task_id
),
}
}

View File

@ -0,0 +1,19 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackgroundTask {
pub id: String,
pub session_id: String,
pub channel: String,
pub chat_id: String,
pub prompt: String,
pub allowed_tools: Option<String>,
pub status: String,
pub result: Option<String>,
pub error: Option<String>,
pub tool_calls_count: i64,
pub iterations: i64,
pub started_at: Option<i64>,
pub finished_at: Option<i64>,
pub created_at: i64,
}

View File

@ -1,10 +1,12 @@
pub mod error; pub mod error;
pub mod memory; pub mod memory;
pub mod message; pub mod message;
pub mod background_task;
pub mod scheduler; pub mod scheduler;
pub mod session; pub mod session;
pub use error::StorageError; pub use error::StorageError;
pub use background_task::BackgroundTask;
pub use scheduler::{JobRun, ScheduledJob}; pub use scheduler::{JobRun, ScheduledJob};
use sqlx::{Pool, Row, Sqlite, SqlitePool}; use sqlx::{Pool, Row, Sqlite, SqlitePool};
@ -105,6 +107,48 @@ impl Storage {
.await .await
.ok(); .ok();
// Background tasks table — for async sub-agent tasks.
// Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL).
// Session and task association is maintained at the application level.
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS background_tasks (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
channel TEXT NOT NULL,
chat_id TEXT NOT NULL,
prompt TEXT NOT NULL,
allowed_tools TEXT,
status TEXT NOT NULL DEFAULT 'pending',
result TEXT,
error TEXT,
tool_calls_count INTEGER DEFAULT 0,
iterations INTEGER DEFAULT 0,
started_at INTEGER,
finished_at INTEGER,
created_at INTEGER NOT NULL
)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_bg_tasks_session ON background_tasks(session_id)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_bg_tasks_status ON background_tasks(status)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query( sqlx::query(
r#" r#"
CREATE TABLE IF NOT EXISTS memories ( CREATE TABLE IF NOT EXISTS memories (
@ -816,6 +860,148 @@ impl Storage {
} }
unreachable!() unreachable!()
} }
// ── Background Task CRUD ──
pub async fn create_background_task(
&self,
task: &crate::storage::background_task::BackgroundTask,
) -> Result<(), StorageError> {
sqlx::query(
r#"
INSERT INTO background_tasks (id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, tool_calls_count, iterations, started_at, finished_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&task.id)
.bind(&task.session_id)
.bind(&task.channel)
.bind(&task.chat_id)
.bind(&task.prompt)
.bind(&task.allowed_tools)
.bind(&task.status)
.bind(&task.result)
.bind(&task.error)
.bind(task.tool_calls_count)
.bind(task.iterations)
.bind(task.started_at)
.bind(task.finished_at)
.bind(task.created_at)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn update_background_task_status(
&self,
id: &str,
status: &str,
result: Option<&str>,
error: Option<&str>,
started_at: Option<i64>,
finished_at: Option<i64>,
) -> Result<(), StorageError> {
sqlx::query(
r#"
UPDATE background_tasks
SET status = ?, result = COALESCE(?, result), error = COALESCE(?, error),
started_at = COALESCE(?, started_at), finished_at = COALESCE(?, finished_at)
WHERE id = ?
"#,
)
.bind(status)
.bind(result)
.bind(error)
.bind(started_at)
.bind(finished_at)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn get_background_task(
&self,
id: &str,
) -> Result<crate::storage::background_task::BackgroundTask, StorageError> {
let row = sqlx::query(
r#"
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
tool_calls_count, iterations, started_at, finished_at, created_at
FROM background_tasks WHERE id = ?
"#,
)
.bind(id)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
Ok(crate::storage::background_task::BackgroundTask {
id: row.get("id"),
session_id: row.get("session_id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
prompt: row.get("prompt"),
allowed_tools: row.get("allowed_tools"),
status: row.get("status"),
result: row.get("result"),
error: row.get("error"),
tool_calls_count: row.get("tool_calls_count"),
iterations: row.get("iterations"),
started_at: row.get("started_at"),
finished_at: row.get("finished_at"),
created_at: row.get("created_at"),
})
}
pub async fn list_background_tasks(
&self,
session_id: &str,
) -> Result<Vec<crate::storage::background_task::BackgroundTask>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
tool_calls_count, iterations, started_at, finished_at, created_at
FROM background_tasks
WHERE session_id = ?
ORDER BY created_at DESC
"#,
)
.bind(session_id)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|row| crate::storage::background_task::BackgroundTask {
id: row.get("id"),
session_id: row.get("session_id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
prompt: row.get("prompt"),
allowed_tools: row.get("allowed_tools"),
status: row.get("status"),
result: row.get("result"),
error: row.get("error"),
tool_calls_count: row.get("tool_calls_count"),
iterations: row.get("iterations"),
started_at: row.get("started_at"),
finished_at: row.get("finished_at"),
created_at: row.get("created_at"),
})
.collect())
}
pub async fn cleanup_old_tasks(&self, ttl_ms: i64) -> Result<usize, StorageError> {
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_ms;
let result = sqlx::query(
"DELETE FROM background_tasks WHERE status IN ('completed', 'failed', 'cancelled') AND finished_at IS NOT NULL AND finished_at < ?",
)
.bind(cutoff)
.execute(self.pool())
.await?;
Ok(result.rows_affected() as usize)
}
} }
#[cfg(test)] #[cfg(test)]

365
src/tools/delegate.rs Normal file
View File

@ -0,0 +1,365 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::agent::{ExecutionMode, SubAgentConfig, SubAgentManager, TaskStatus};
use crate::tools::traits::{Tool, ToolResult};
pub struct DelegateTool {
sub_agent_manager: Arc<SubAgentManager>,
}
impl DelegateTool {
pub fn new(sub_agent_manager: Arc<SubAgentManager>) -> Self {
Self { sub_agent_manager }
}
}
#[async_trait]
impl Tool for DelegateTool {
fn name(&self) -> &str {
"delegate"
}
fn description(&self) -> &str {
"子任务委托工具。创建子 Agent 处理独立任务,支持三种模式:\
inline ()background ()\
parallel ( Agent )\
"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["run", "check_task", "cancel_task", "list_tasks"],
"description": "操作类型: run 创建子Agent执行任务, check_task 查询后台任务, cancel_task 取消后台任务, list_tasks 列出后台任务"
},
"prompt": {
"type": "string",
"description": "子任务描述(可内含额外约束,如:跳过 .tmp 文件。action=run 时必填"
},
"mode": {
"type": "string",
"enum": ["inline", "background", "parallel"],
"description": "执行模式: inline=阻塞返回结果, background=异步执行+通知, parallel=多子Agent并发。默认 inline"
},
"allowed_tools": {
"type": "array",
"items": { "type": "string" },
"description": "允许子Agent使用的工具列表。不填使用默认只读集: file_read,file_search,content_search,web_fetch,http_request,calculator"
},
"max_iterations": {
"type": "integer",
"description": "最大迭代次数,默认 99"
},
"timeout_secs": {
"type": "integer",
"description": "超时秒数,默认 36001小时"
},
"tasks": {
"type": "array",
"description": "并行模式下的多个子任务(仅 mode=parallel 时使用)",
"items": {
"type": "object",
"properties": {
"prompt": { "type": "string", "description": "子任务描述" },
"allowed_tools": {
"type": "array",
"items": { "type": "string" },
"description": "该子任务的工具列表"
}
},
"required": ["prompt"]
}
},
"task_id": {
"type": "string",
"description": "后台任务IDaction=check_task/cancel_task 时必填)"
}
},
"required": ["action"]
})
}
fn read_only(&self) -> bool {
false
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = args["action"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?;
match action {
"run" => self.handle_run(&args).await,
"check_task" => self.handle_check_task(&args).await,
"cancel_task" => self.handle_cancel_task(&args).await,
"list_tasks" => self.handle_list_tasks(&args).await,
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", action)),
}),
}
}
}
impl DelegateTool {
fn parse_config_from_args(&self, args: &serde_json::Value) -> anyhow::Result<SubAgentConfig> {
let prompt = args["prompt"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
.to_string();
let allowed_tools: Option<Vec<String>> = args["allowed_tools"]
.as_array()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
let timeout_secs = args["timeout_secs"].as_u64();
Ok(SubAgentConfig {
prompt,
mode: ExecutionMode::Inline,
allowed_tools,
max_iterations,
timeout_secs,
})
}
async fn handle_run(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let mode_str = args["mode"].as_str().unwrap_or("inline");
let mode = match mode_str {
"inline" => ExecutionMode::Inline,
"background" => ExecutionMode::Background,
"parallel" => ExecutionMode::Parallel,
_ => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("unknown mode: {}. Supported: inline, background, parallel", mode_str)),
})
}
};
match mode {
ExecutionMode::Inline => {
let config = self.parse_config_from_args(args)?;
let result = self.sub_agent_manager.run_inline(config).await
.map_err(|e| anyhow::anyhow!("{}", e))?;
match result.status {
TaskStatus::Completed => Ok(ToolResult {
success: true,
output: result.content,
error: None,
}),
TaskStatus::Failed(err) => Ok(ToolResult {
success: false,
output: result.content,
error: Some(err),
}),
TaskStatus::TimedOut => Ok(ToolResult {
success: false,
output: result.content,
error: Some("sub-agent timed out".into()),
}),
TaskStatus::Cancelled => Ok(ToolResult {
success: false,
output: result.content,
error: Some("sub-agent cancelled".into()),
}),
}
}
ExecutionMode::Background => {
let config = self.parse_config_from_args(args)?;
let ctx = crate::agent::sub_agent::get_delegate_context()
.map_err(|_| anyhow::anyhow!("delegate context not available: not in an agent worker"))?;
let task_id = self.sub_agent_manager.run_background(config, ctx).await
.map_err(|e| anyhow::anyhow!("{}", e))?;
Ok(ToolResult {
success: true,
output: format!("后台任务已启动。\ntask_id: {}", task_id),
error: None,
})
}
ExecutionMode::Parallel => {
let tasks = args["tasks"]
.as_array()
.ok_or_else(|| anyhow::anyhow!("parallel mode requires 'tasks' array"))?;
let mut configs = Vec::new();
for task in tasks {
let prompt = task["prompt"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
.to_string();
let allowed_tools: Option<Vec<String>> = task["allowed_tools"]
.as_array()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
configs.push(SubAgentConfig {
prompt,
mode: ExecutionMode::Inline,
allowed_tools,
max_iterations: args["max_iterations"].as_u64().map(|v| v as usize),
timeout_secs: args["timeout_secs"].as_u64(),
});
}
let has_args_allowed = args["allowed_tools"].as_array().is_some();
for c in &mut configs {
if c.allowed_tools.is_none() && has_args_allowed {
c.allowed_tools = args["allowed_tools"]
.as_array()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
}
}
let results = self.sub_agent_manager.run_parallel(configs).await
.map_err(|e| anyhow::anyhow!("{}", e))?;
let mut output = String::new();
for (i, r) in results.iter().enumerate() {
let status_icon = match r.status {
TaskStatus::Completed => "",
TaskStatus::Failed(_) => "",
TaskStatus::TimedOut => "⏱️ 超时",
TaskStatus::Cancelled => "🚫 已取消",
};
output.push_str(&format!("[task_{}] {}\n", i + 1, status_icon));
if !r.content.is_empty() {
output.push_str(&r.content);
output.push_str("\n\n");
}
if let TaskStatus::Failed(ref err) = r.status {
output.push_str(&format!("错误: {}\n\n", err));
}
}
let all_success = results.iter().all(|r| matches!(r.status, TaskStatus::Completed));
Ok(ToolResult {
success: all_success,
output: output.trim().to_string(),
error: None,
})
}
}
}
async fn handle_check_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let task_id = args["task_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
match self.sub_agent_manager.check_task(task_id).await {
Some(task) => {
let status_icon = match task.status.as_str() {
"completed" => "✅ 已完成",
"failed" => "❌ 失败",
"cancelled" => "🚫 已取消",
"running" => "🔄 运行中",
"pending" => "⏳ 等待中",
_ => task.status.as_str(),
};
let mut output = format!(
"任务 ID: {}\n状态: {}\n任务: {}",
task.id, status_icon, task.prompt
);
if let Some(ref result) = task.result {
output.push_str(&format!("\n\n结果:\n{}", result));
}
if let Some(ref error) = task.error {
output.push_str(&format!("\n错误: {}", error));
}
if let Some(started) = task.started_at {
if let Some(finished) = task.finished_at {
let duration = (finished - started) as f64 / 1000.0;
output.push_str(&format!("\n耗时: {:.1}s", duration));
}
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
None => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("task not found: {}", task_id)),
}),
}
}
async fn handle_cancel_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let task_id = args["task_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
match self.sub_agent_manager.cancel_task(task_id).await {
Ok(true) => Ok(ToolResult {
success: true,
output: format!("后台任务 {} 已取消", task_id),
error: None,
}),
Ok(false) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("无法取消任务 {}(可能已完成或不存在)", task_id)),
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("取消失败: {}", e)),
}),
}
}
async fn handle_list_tasks(&self, _args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let ctx = crate::agent::sub_agent::get_delegate_context()
.map_err(|_| anyhow::anyhow!("delegate context not available"))?;
let tasks = self.sub_agent_manager.list_tasks(&ctx.session_id).await;
if tasks.is_empty() {
return Ok(ToolResult {
success: true,
output: "没有后台任务".to_string(),
error: None,
});
}
let mut output = String::from("后台任务列表:\n\n");
for task in &tasks {
let status_icon = match task.status.as_str() {
"completed" => "",
"failed" => "",
"cancelled" => "🚫",
"running" => "🔄",
"pending" => "",
_ => "",
};
output.push_str(&format!(
"{} {} - {} - {} (created: {})\n",
status_icon,
&task.id[..std::cmp::min(8, task.id.len())],
task.prompt.chars().take(60).collect::<String>(),
task.status,
task.created_at,
));
}
output.push_str(&format!("\n{} 个任务", tasks.len()));
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}

View File

@ -4,6 +4,7 @@ pub mod calculator;
pub mod chat_manager; pub mod chat_manager;
pub mod content_search; pub mod content_search;
pub mod cron; pub mod cron;
pub mod delegate;
pub mod file_edit; pub mod file_edit;
pub mod file_read; pub mod file_read;
pub mod file_search; pub mod file_search;
@ -23,6 +24,7 @@ pub use browser::BrowserTool;
pub use calculator::CalculatorTool; pub use calculator::CalculatorTool;
pub use chat_manager::ChatManagerTool; pub use chat_manager::ChatManagerTool;
pub use content_search::ContentSearchTool; pub use content_search::ContentSearchTool;
pub use delegate::DelegateTool;
pub use file_edit::FileEditTool; pub use file_edit::FileEditTool;
pub use file_read::FileReadTool; pub use file_read::FileReadTool;
pub use file_search::FileSearchTool; pub use file_search::FileSearchTool;
@ -36,6 +38,7 @@ pub use traits::{OutboundMessenger, Tool, ToolResult};
pub use web_fetch::WebFetchTool; pub use web_fetch::WebFetchTool;
use std::sync::Arc; use std::sync::Arc;
use crate::agent::SubAgentManager;
use crate::config::BrowserConfig; use crate::config::BrowserConfig;
use crate::memory::MemoryManager; use crate::memory::MemoryManager;
use crate::skills::SkillsLoader; use crate::skills::SkillsLoader;
@ -46,6 +49,7 @@ use crate::skills::SkillsLoader;
pub fn create_default_tools( pub fn create_default_tools(
skills_loader: Arc<SkillsLoader>, skills_loader: Arc<SkillsLoader>,
memory: Arc<MemoryManager>, memory: Arc<MemoryManager>,
sub_agent_manager: Option<Arc<SubAgentManager>>,
browser_config: Option<&BrowserConfig>, browser_config: Option<&BrowserConfig>,
) -> ToolRegistry { ) -> ToolRegistry {
let registry = ToolRegistry::new(); let registry = ToolRegistry::new();
@ -76,5 +80,9 @@ pub fn create_default_tools(
} }
} }
if let Some(mgr) = sub_agent_manager {
registry.register(DelegateTool::new(mgr));
}
registry registry
} }

View File

@ -20,6 +20,11 @@ impl ToolRegistry {
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool)); self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
} }
/// Register an existing Arc-wrapped tool by name
pub fn register_raw(&self, name: String, tool: Arc<dyn ToolTrait>) {
self.tools.lock().unwrap().insert(name, tool);
}
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> { pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
self.tools.lock().unwrap().get(name).cloned() self.tools.lock().unwrap().get(name).cloned()
} }
@ -62,6 +67,17 @@ impl ToolRegistry {
.map(|(k, v)| (k.clone(), v.clone())) .map(|(k, v)| (k.clone(), v.clone()))
.collect() .collect()
} }
/// 生成工具列表描述,用于子 Agent 系统提示词
pub fn describe_for_prompt(&self) -> String {
let mut entries: Vec<String> = self
.iter()
.into_iter()
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
.collect();
entries.sort();
entries.join("\n")
}
} }
impl Default for ToolRegistry { impl Default for ToolRegistry {