增加sub-agent机制。
This commit is contained in:
parent
a77c026826
commit
41b4895ff0
@ -12,6 +12,8 @@ serde_json = "1.0"
|
||||
async-trait = "0.1"
|
||||
thiserror = "2.0.18"
|
||||
tokio = { version = "1.52", features = ["full"] }
|
||||
tokio-util = { version = "0.7", features = ["rt"] }
|
||||
dashmap = "6.1"
|
||||
uuid = { version = "1.23", features = ["v4"] }
|
||||
axum = { version = "0.8", features = ["ws"] }
|
||||
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
pub mod agent_loop;
|
||||
pub mod context_compressor;
|
||||
pub mod media_handler;
|
||||
pub mod sub_agent;
|
||||
pub mod system_prompt;
|
||||
|
||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
||||
pub use context_compressor::{ContextCompressor, estimate_tokens};
|
||||
pub use sub_agent::{DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, TaskNotification, TaskStatus};
|
||||
pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder};
|
||||
|
||||
611
src/agent/sub_agent.rs
Normal file
611
src/agent/sub_agent.rs
Normal file
@ -0,0 +1,611 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent::AgentLoop;
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{create_provider, LLMProvider};
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
tokio::task_local! {
|
||||
pub(crate) static DELEGATE_CONTEXT: DelegateContext;
|
||||
}
|
||||
|
||||
/// Read the delegate context from the current task. Returns an error if not set.
|
||||
pub fn get_delegate_context() -> Result<DelegateContext, String> {
|
||||
DELEGATE_CONTEXT.try_with(|ctx| ctx.clone())
|
||||
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
|
||||
}
|
||||
|
||||
const DEFAULT_MAX_ITERATIONS: usize = 99;
|
||||
const DEFAULT_TIMEOUT_SECS: u64 = 3600;
|
||||
const MAX_INLINE_RESULT_CHARS: usize = 8000;
|
||||
|
||||
const DEFAULT_READONLY_TOOLS: &[&str] = &[
|
||||
"file_read",
|
||||
"file_search",
|
||||
"content_search",
|
||||
"web_fetch",
|
||||
"http_request",
|
||||
"calculator",
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SubAgentConfig {
|
||||
pub prompt: String,
|
||||
pub mode: ExecutionMode,
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
pub max_iterations: Option<usize>,
|
||||
pub timeout_secs: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ExecutionMode {
|
||||
Inline,
|
||||
Background,
|
||||
Parallel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SubAgentResult {
|
||||
pub task_id: String,
|
||||
pub content: String,
|
||||
pub content_truncated: bool,
|
||||
pub status: TaskStatus,
|
||||
pub tool_calls_count: usize,
|
||||
pub iterations: usize,
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TaskStatus {
|
||||
Completed,
|
||||
Failed(String),
|
||||
Cancelled,
|
||||
TimedOut,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskNotification {
|
||||
pub task_id: String,
|
||||
pub session_id: String,
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
pub status: TaskStatus,
|
||||
pub result_summary: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DelegateContext {
|
||||
pub session_id: String,
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum SubAgentError {
|
||||
TooManyTasks(usize),
|
||||
ProviderCreation(String),
|
||||
Storage(String),
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SubAgentError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::TooManyTasks(max) => write!(f, "后台任务已达上限({}),请稍后重试", max),
|
||||
Self::ProviderCreation(e) => write!(f, "provider creation failed: {}", e),
|
||||
Self::Storage(e) => write!(f, "storage error: {}", e),
|
||||
Self::Other(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SubAgentError {}
|
||||
|
||||
pub struct SubAgentManager {
|
||||
provider_config: LLMProviderConfig,
|
||||
full_tools: Arc<ToolRegistry>,
|
||||
storage: Option<Arc<crate::storage::Storage>>,
|
||||
active_tasks: Arc<DashMap<String, CancellationToken>>,
|
||||
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
|
||||
max_concurrent_background_tasks: usize,
|
||||
}
|
||||
|
||||
impl SubAgentManager {
|
||||
pub fn new(
|
||||
provider_config: LLMProviderConfig,
|
||||
full_tools: Arc<ToolRegistry>,
|
||||
storage: Option<Arc<crate::storage::Storage>>,
|
||||
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
|
||||
max_concurrent_background_tasks: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider_config,
|
||||
full_tools,
|
||||
storage,
|
||||
active_tasks: Arc::new(DashMap::new()),
|
||||
notify_tx,
|
||||
max_concurrent_background_tasks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn filter_tools(&self, allowed: &Option<Vec<String>>) -> Arc<ToolRegistry> {
|
||||
let allowed_set: HashSet<&str> = match allowed {
|
||||
Some(list) => list.iter().map(|s| s.as_str()).collect(),
|
||||
None => DEFAULT_READONLY_TOOLS.iter().copied().collect(),
|
||||
};
|
||||
let filtered = ToolRegistry::new();
|
||||
for (name, tool) in self.full_tools.iter() {
|
||||
if allowed_set.contains(name.as_str()) && name != "delegate" {
|
||||
filtered.register_raw(name, tool);
|
||||
}
|
||||
}
|
||||
Arc::new(filtered)
|
||||
}
|
||||
|
||||
pub fn build_system_prompt(&self, config: &SubAgentConfig, tools: &ToolRegistry) -> String {
|
||||
let timeout_human = format_duration(config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS));
|
||||
let tool_descriptions = tools.describe_for_prompt();
|
||||
|
||||
let http_only_note = if config.allowed_tools.is_none()
|
||||
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request"))
|
||||
{
|
||||
"- When using http_request, only the GET method is permitted. \
|
||||
Do NOT use POST, PUT, DELETE, or any other method."
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
format!(
|
||||
"You are a sub-agent working on a delegated task. Complete the task below \
|
||||
and return a single, self-contained result.\n\
|
||||
\n\
|
||||
## Task\n\
|
||||
{task}\n\
|
||||
\n\
|
||||
## Rules\n\
|
||||
- Focus ONLY on the task above. Do not explore unrelated topics.\n\
|
||||
- Use tools only when necessary for the task.\n\
|
||||
- Do NOT use the delegate tool — sub-agent recursion is forbidden.\n\
|
||||
- If the task cannot be completed, explain why clearly.\n\
|
||||
- Return only the final result. Do not describe your process.\n\
|
||||
{http_only}\n\
|
||||
- Timeout: {timeout_human}. If approaching the limit, return partial results.\n\
|
||||
\n\
|
||||
## Available Tools\n\
|
||||
{tool_descriptions}\n\
|
||||
\n\
|
||||
## Workspace\n\
|
||||
{workspace}",
|
||||
task = config.prompt,
|
||||
http_only = http_only_note,
|
||||
timeout_human = timeout_human,
|
||||
tool_descriptions = tool_descriptions,
|
||||
workspace = self.provider_config.workspace_dir.display(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn build_sub_agent(
|
||||
&self,
|
||||
config: &SubAgentConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
) -> Result<AgentLoop, AgentError> {
|
||||
let mut provider = create_provider(self.provider_config.clone())
|
||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||
if let Some(ref s) = self.storage {
|
||||
provider.set_storage(s.clone());
|
||||
}
|
||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider);
|
||||
|
||||
let max_iterations = config.max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS);
|
||||
let workspace_dir = self.provider_config.workspace_dir.clone();
|
||||
let model_name = self.provider_config.model_id.clone();
|
||||
let input_types = self.provider_config.input_types.clone();
|
||||
|
||||
let agent = AgentLoop::with_provider_and_tools(
|
||||
provider,
|
||||
tools,
|
||||
max_iterations,
|
||||
model_name,
|
||||
workspace_dir,
|
||||
input_types,
|
||||
)
|
||||
.with_context_window(self.provider_config.token_limit);
|
||||
|
||||
Ok(agent)
|
||||
}
|
||||
|
||||
pub async fn run_inline(
|
||||
&self,
|
||||
config: SubAgentConfig,
|
||||
) -> Result<SubAgentResult, SubAgentError> {
|
||||
let task_id = generate_task_id();
|
||||
let tools = self.filter_tools(&config.allowed_tools);
|
||||
let system_prompt = self.build_system_prompt(&config, &tools);
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||
|
||||
let agent = self.build_sub_agent(&config, tools)
|
||||
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
|
||||
|
||||
let history = vec![
|
||||
ChatMessage::system(system_prompt),
|
||||
ChatMessage::user(&config.prompt),
|
||||
];
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
agent.process(history),
|
||||
)
|
||||
.await;
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
match result {
|
||||
Ok(Ok(agent_result)) => {
|
||||
let (content, truncated) =
|
||||
truncate_sub_agent_result(&agent_result.final_response.content);
|
||||
let tool_calls_count = agent_result.emitted_messages.iter()
|
||||
.filter(|m| m.tool_calls.is_some())
|
||||
.count();
|
||||
let iterations = agent_result.emitted_messages.iter()
|
||||
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
|
||||
.count();
|
||||
Ok(SubAgentResult {
|
||||
task_id,
|
||||
content,
|
||||
content_truncated: truncated,
|
||||
status: TaskStatus::Completed,
|
||||
tool_calls_count,
|
||||
iterations,
|
||||
duration_ms,
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => Ok(SubAgentResult {
|
||||
task_id,
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Failed(e.to_string()),
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms,
|
||||
}),
|
||||
Err(_elapsed) => Ok(SubAgentResult {
|
||||
task_id,
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::TimedOut,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_parallel(
|
||||
&self,
|
||||
configs: Vec<SubAgentConfig>,
|
||||
) -> Result<Vec<SubAgentResult>, SubAgentError> {
|
||||
let futures: Vec<_> = configs
|
||||
.into_iter()
|
||||
.map(|config| {
|
||||
let mgr = self; // &self borrow, all tasks share the same manager
|
||||
async move { mgr.run_inline(config).await }
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = futures_util::future::join_all(futures).await;
|
||||
Ok(results.into_iter().collect::<Result<Vec<_>, _>>()?)
|
||||
}
|
||||
|
||||
pub async fn run_background(
|
||||
&self,
|
||||
config: SubAgentConfig,
|
||||
ctx: DelegateContext,
|
||||
) -> Result<String, SubAgentError> {
|
||||
if self.active_tasks.len() >= self.max_concurrent_background_tasks {
|
||||
return Err(SubAgentError::TooManyTasks(
|
||||
self.max_concurrent_background_tasks,
|
||||
));
|
||||
}
|
||||
|
||||
let task_id = generate_task_id();
|
||||
let cancel_token = CancellationToken::new();
|
||||
|
||||
self.active_tasks
|
||||
.insert(task_id.clone(), cancel_token.clone());
|
||||
|
||||
// Write DB: pending
|
||||
if let Some(ref storage) = self.storage {
|
||||
let allowed_tools_json = config
|
||||
.allowed_tools
|
||||
.as_ref()
|
||||
.and_then(|v| serde_json::to_string(v).ok());
|
||||
let record = crate::storage::BackgroundTask {
|
||||
id: task_id.clone(),
|
||||
session_id: ctx.session_id.clone(),
|
||||
channel: ctx.channel.clone(),
|
||||
chat_id: ctx.chat_id.clone(),
|
||||
prompt: config.prompt.clone(),
|
||||
allowed_tools: allowed_tools_json,
|
||||
status: "pending".to_string(),
|
||||
result: None,
|
||||
error: None,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
started_at: None,
|
||||
finished_at: None,
|
||||
created_at: chrono::Utc::now().timestamp_millis(),
|
||||
};
|
||||
storage
|
||||
.create_background_task(&record)
|
||||
.await
|
||||
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
|
||||
}
|
||||
|
||||
let tools = self.filter_tools(&config.allowed_tools);
|
||||
let system_prompt = self.build_system_prompt(&config, &tools);
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||
let provider_config = self.provider_config.clone();
|
||||
let storage = self.storage.clone();
|
||||
let notify_tx = self.notify_tx.clone();
|
||||
let active_tasks = Arc::clone(&self.active_tasks);
|
||||
|
||||
let tid = task_id.clone();
|
||||
let sess_id = ctx.session_id.clone();
|
||||
let ch = ctx.channel.clone();
|
||||
let cid = ctx.chat_id.clone();
|
||||
let prompt = config.prompt.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let started_at = chrono::Utc::now().timestamp_millis();
|
||||
|
||||
// Update DB: running
|
||||
if let Some(ref s) = storage {
|
||||
let _ = s
|
||||
.update_background_task_status(
|
||||
&tid, "running", None, None,
|
||||
Some(started_at), None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut provider = create_provider(provider_config.clone()).ok();
|
||||
if let Some(ref mut p) = provider {
|
||||
if let Some(ref s) = storage {
|
||||
p.set_storage(s.clone());
|
||||
}
|
||||
}
|
||||
let provider_result: Option<Arc<dyn LLMProvider>> =
|
||||
provider.map(|p| Arc::from(p));
|
||||
|
||||
let result = match provider_result {
|
||||
Some(provider) => {
|
||||
let agent = AgentLoop::with_provider_and_tools(
|
||||
provider,
|
||||
tools,
|
||||
DEFAULT_MAX_ITERATIONS,
|
||||
provider_config.model_id.clone(),
|
||||
provider_config.workspace_dir.clone(),
|
||||
provider_config.input_types.clone(),
|
||||
)
|
||||
.with_context_window(provider_config.token_limit);
|
||||
|
||||
let history = vec![
|
||||
ChatMessage::system(system_prompt),
|
||||
ChatMessage::user(&prompt),
|
||||
];
|
||||
|
||||
tokio::select! {
|
||||
r = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
agent.process(history),
|
||||
) => {
|
||||
match r {
|
||||
Ok(Ok(agent_result)) => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: agent_result.final_response.content,
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Completed,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
Ok(Err(e)) => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Failed(e.to_string()),
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
Err(_) => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::TimedOut,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = cancel_token.cancelled() => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Cancelled,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
None => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Failed("provider creation failed".into()),
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
};
|
||||
|
||||
let finished_at = chrono::Utc::now().timestamp_millis();
|
||||
let duration_ms = (finished_at - started_at) as u64;
|
||||
|
||||
let (status_str, error_val) = match &result.status {
|
||||
TaskStatus::Completed => ("completed".to_string(), None),
|
||||
TaskStatus::Failed(e) => ("failed".to_string(), Some(e.clone())),
|
||||
TaskStatus::Cancelled => ("cancelled".to_string(), None),
|
||||
TaskStatus::TimedOut => ("failed".to_string(), Some("timeout".to_string())),
|
||||
};
|
||||
|
||||
if let Some(ref s) = storage {
|
||||
let _ = s
|
||||
.update_background_task_status(
|
||||
&tid, &status_str,
|
||||
Some(&result.content), error_val.as_deref(),
|
||||
Some(started_at), Some(finished_at),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let _ = notify_tx.send(TaskNotification {
|
||||
task_id: tid.clone(),
|
||||
session_id: sess_id,
|
||||
channel: ch,
|
||||
chat_id: cid,
|
||||
status: result.status,
|
||||
result_summary: summarize_for_notification(&result.content, duration_ms),
|
||||
});
|
||||
|
||||
active_tasks.remove(&tid);
|
||||
});
|
||||
|
||||
Ok(task_id)
|
||||
}
|
||||
|
||||
pub async fn cancel_task(&self, task_id: &str) -> Result<bool, SubAgentError> {
|
||||
if let Some((_, token)) = self.active_tasks.remove(task_id) {
|
||||
token.cancel();
|
||||
if let Some(ref s) = self.storage {
|
||||
s.update_background_task_status(
|
||||
task_id,
|
||||
"cancelled",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(chrono::Utc::now().timestamp_millis()),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
|
||||
}
|
||||
Ok(true)
|
||||
} else if let Some(ref s) = self.storage {
|
||||
match s.get_background_task(task_id).await {
|
||||
Ok(task) => {
|
||||
match task.status.as_str() {
|
||||
"pending" | "running" => {
|
||||
tracing::warn!(task_id, "task in DB but not in active_tasks");
|
||||
Ok(false)
|
||||
}
|
||||
_ => Ok(false),
|
||||
}
|
||||
}
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check_task(
|
||||
&self,
|
||||
task_id: &str,
|
||||
) -> Option<crate::storage::BackgroundTask> {
|
||||
if let Some(ref s) = self.storage {
|
||||
s.get_background_task(task_id).await.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_tasks(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> Vec<crate::storage::BackgroundTask> {
|
||||
if let Some(ref s) = self.storage {
|
||||
s.list_background_tasks(session_id).await.unwrap_or_default()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cancel_by_session(&self, session_id: &str) {
|
||||
// Cancel all running tasks for a session by checking DB
|
||||
if let Some(ref s) = self.storage {
|
||||
if let Ok(tasks) = s.list_background_tasks(session_id).await {
|
||||
for task in &tasks {
|
||||
if task.status == "pending" || task.status == "running" {
|
||||
let _ = self.cancel_task(&task.id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_task_count(&self) -> usize {
|
||||
self.active_tasks.len()
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_task_id() -> String {
|
||||
Uuid::new_v4().to_string()[..8].to_string()
|
||||
}
|
||||
|
||||
fn format_duration(seconds: u64) -> String {
|
||||
if seconds < 60 {
|
||||
format!("{}s", seconds)
|
||||
} else if seconds < 3600 {
|
||||
format!("{}m", seconds / 60)
|
||||
} else {
|
||||
format!("{}h", seconds / 3600)
|
||||
}
|
||||
}
|
||||
|
||||
fn truncate_sub_agent_result(content: &str) -> (String, bool) {
|
||||
if content.len() <= MAX_INLINE_RESULT_CHARS {
|
||||
(content.to_string(), false)
|
||||
} else {
|
||||
let truncate_at = content.floor_char_boundary(MAX_INLINE_RESULT_CHARS);
|
||||
(
|
||||
format!(
|
||||
"{}\n\n[... 结果已截断,共 {} 字符,完整结果请使用 check_task 查看 ...]",
|
||||
&content[..truncate_at],
|
||||
content.len()
|
||||
),
|
||||
true,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn summarize_for_notification(content: &str, _duration_ms: u64) -> String {
|
||||
const MAX_SUMMARY_BYTES: usize = 500;
|
||||
if content.len() <= MAX_SUMMARY_BYTES {
|
||||
content.to_string()
|
||||
} else {
|
||||
let truncate_at = content.floor_char_boundary(MAX_SUMMARY_BYTES);
|
||||
format!("{}...", &content[..truncate_at])
|
||||
}
|
||||
}
|
||||
@ -3,11 +3,7 @@
|
||||
//! This module provides a modular framework for building system prompts
|
||||
//! using the SystemPromptBuilder pattern.
|
||||
//!
|
||||
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic
|
||||
//!
|
||||
//! Configuration files loaded from ~/.picobot/:
|
||||
//! - AGENTS.md — agent identity and behavior
|
||||
//! - USER.md — user preferences and profile
|
||||
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic → Delegation
|
||||
|
||||
use crate::tools::ToolRegistry;
|
||||
use std::path::Path;
|
||||
@ -55,6 +51,7 @@ impl SystemPromptBuilder {
|
||||
Box::new(CrossChannelSection),
|
||||
Box::new(MemorySection),
|
||||
Box::new(HistorySection),
|
||||
Box::new(DelegationSection),
|
||||
],
|
||||
}
|
||||
}
|
||||
@ -353,6 +350,45 @@ impl PromptSection for HistorySection {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sub-agent delegation principles.
|
||||
pub struct DelegationSection;
|
||||
|
||||
impl PromptSection for DelegationSection {
|
||||
fn name(&self) -> &str {
|
||||
"delegation"
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
"## 子 Agent 委托原则\n\n\
|
||||
当任务复杂需要拆解时,使用 delegate 工具创建子 Agent:\n\
|
||||
\n\
|
||||
### 何时委托\n\
|
||||
- 多个独立子任务可以并行处理时(使用 mode=\"parallel\")\n\
|
||||
- 长时间运行的任务需要后台执行时(使用 mode=\"background\")\n\
|
||||
- 需要以不同权限(受限工具集)执行时\n\
|
||||
\n\
|
||||
### 工具分配原则\n\
|
||||
- **最小权限**:只给子 Agent 完成其任务所需的最少工具\n\
|
||||
- **只读优先**:如果可以只用 file_read、file_search、web_fetch 完成,不要给写权限(bash、file_write、file_edit)\n\
|
||||
- **禁止递归**:永远不要把 delegate 工具分配给子 Agent\n\
|
||||
- **明确边界**:每个子 Agent 只负责一个清晰、独立的子任务\n\
|
||||
\n\
|
||||
### 任务描述\n\
|
||||
- 任务 prompt 要清晰、具体、有明确输出要求\n\
|
||||
- 如需额外约束,直接写在 prompt 中(例如:\"跳过 .tmp 文件\")\n\
|
||||
- 明确说明期望的输出格式\n\
|
||||
\n\
|
||||
### 并行模式\n\
|
||||
- 多个无依赖的子任务使用 mode=\"parallel\",任务定义在 tasks 数组中\n\
|
||||
- 并行任务之间不应有数据依赖\n\
|
||||
- 并行任务数建议不超过 5 个\n\
|
||||
\n\
|
||||
### 后台模式\n\
|
||||
- 预计执行时间超过 30s 的任务使用 mode=\"background\"\n\
|
||||
- 后台任务有全局并发上限,如果失败提示用户稍后重试".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
// === Helper Functions ===
|
||||
|
||||
/// Get user config directory (~/.picobot/).
|
||||
|
||||
@ -152,10 +152,26 @@ pub struct GatewayConfig {
|
||||
pub cleanup_interval_minutes: Option<u64>,
|
||||
#[serde(default, rename = "session_db_path")]
|
||||
pub session_db_path: Option<String>,
|
||||
#[serde(default, rename = "max_concurrent_background_tasks")]
|
||||
pub max_concurrent_background_tasks: usize,
|
||||
#[serde(default)]
|
||||
pub scheduler: Option<SchedulerConfig>,
|
||||
}
|
||||
|
||||
impl Default for GatewayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_gateway_host(),
|
||||
port: default_gateway_port(),
|
||||
session_ttl_hours: None,
|
||||
cleanup_interval_minutes: None,
|
||||
session_db_path: None,
|
||||
max_concurrent_background_tasks: 10,
|
||||
scheduler: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SchedulerConfig {
|
||||
/// Whether the scheduler is enabled
|
||||
@ -209,19 +225,6 @@ fn default_gateway_url() -> String {
|
||||
"ws://127.0.0.1:19876/ws".to_string()
|
||||
}
|
||||
|
||||
impl Default for GatewayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_gateway_host(),
|
||||
port: default_gateway_port(),
|
||||
session_ttl_hours: None,
|
||||
cleanup_interval_minutes: None,
|
||||
session_db_path: None,
|
||||
scheduler: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
||||
@ -91,6 +91,7 @@ impl GatewayState {
|
||||
bus.clone(),
|
||||
memory_manager,
|
||||
browser_config,
|
||||
config.gateway.max_concurrent_background_tasks,
|
||||
)?;
|
||||
let session_manager = Arc::new(session_manager);
|
||||
|
||||
|
||||
@ -751,6 +751,7 @@ pub struct SessionManager {
|
||||
storage: Arc<Storage>,
|
||||
bus: Arc<MessageBus>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
sub_agent_manager: Arc<crate::agent::SubAgentManager>,
|
||||
}
|
||||
|
||||
struct SessionManagerInner {
|
||||
@ -847,6 +848,7 @@ impl SessionManager {
|
||||
bus: Arc<MessageBus>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
browser_config: Option<BrowserConfig>,
|
||||
max_concurrent_background_tasks: usize,
|
||||
) -> Result<Self, AgentError> {
|
||||
let mut skills_loader = SkillsLoader::new();
|
||||
skills_loader.load_skills();
|
||||
@ -856,9 +858,61 @@ impl SessionManager {
|
||||
let tools = Arc::new(create_default_tools(
|
||||
skills_loader.clone(),
|
||||
memory_manager.clone(),
|
||||
None, // SubAgentManager created below
|
||||
browser_config.as_ref(),
|
||||
));
|
||||
|
||||
// Create SubAgentManager and register DelegateTool
|
||||
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
let sub_agent_manager = Arc::new(crate::agent::SubAgentManager::new(
|
||||
provider_config.clone(),
|
||||
tools.clone(),
|
||||
Some(storage.clone()),
|
||||
notify_tx,
|
||||
max_concurrent_background_tasks,
|
||||
));
|
||||
tools.register(crate::tools::DelegateTool::new(sub_agent_manager.clone()));
|
||||
|
||||
// Start background task notification consumer
|
||||
let sm_bus = bus.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(notif) = notify_rx.recv().await {
|
||||
let content = format_task_notification(
|
||||
¬if.task_id,
|
||||
¬if.status,
|
||||
¬if.result_summary,
|
||||
);
|
||||
let outbound = OutboundMessage {
|
||||
channel: notif.channel,
|
||||
chat_id: notif.chat_id,
|
||||
content,
|
||||
reply_to: None,
|
||||
media: vec![],
|
||||
metadata: std::collections::HashMap::new(),
|
||||
};
|
||||
let _ = sm_bus.publish_outbound(outbound).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Start periodic background task cleanup (every hour, TTL 24h)
|
||||
let cleanup_storage = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(3600));
|
||||
interval.tick().await; // skip immediate first tick
|
||||
loop {
|
||||
interval.tick().await;
|
||||
match cleanup_storage.cleanup_old_tasks(86_400_000).await {
|
||||
Ok(count) if count > 0 => {
|
||||
tracing::info!(count, "Cleaned up old background tasks");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to clean up old background tasks");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
||||
sessions: HashMap::new(),
|
||||
@ -870,6 +924,7 @@ impl SessionManager {
|
||||
storage,
|
||||
bus,
|
||||
memory_manager,
|
||||
sub_agent_manager,
|
||||
})
|
||||
}
|
||||
|
||||
@ -1073,6 +1128,8 @@ impl SessionManager {
|
||||
msgs.push("消息队列已清空。".to_string());
|
||||
}
|
||||
guard.worker_generation = guard.worker_generation.wrapping_add(1);
|
||||
// Cancel all running background sub-agent tasks for this session
|
||||
self.sub_agent_manager.cancel_by_session(&sid.to_string()).await;
|
||||
let resp = if msgs.is_empty() {
|
||||
"没有正在执行的任务或队列。".to_string()
|
||||
} else {
|
||||
@ -1469,7 +1526,8 @@ fn spawn_agent_worker(
|
||||
unified_str: String,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
let _scope = CURRENT_SOURCE_SESSION.scope(Some(unified_str), async {
|
||||
let unified_for_source = unified_str.clone();
|
||||
let _scope = CURRENT_SOURCE_SESSION.scope(Some(unified_for_source), async {
|
||||
while let Some(task) = task_rx.recv().await {
|
||||
let task_chan = task.channel.clone();
|
||||
let task_cid = task.chat_id.clone();
|
||||
@ -1613,8 +1671,17 @@ fn spawn_agent_worker(
|
||||
let bus2 = bus.clone();
|
||||
let chan2 = task_chan.clone();
|
||||
let cid2 = task_cid.clone();
|
||||
let unified_str2 = unified_str.clone();
|
||||
let process_future = async move {
|
||||
let result = match agent.process(history_out.clone()).await {
|
||||
let process_result = crate::agent::sub_agent::DELEGATE_CONTEXT.scope(
|
||||
crate::agent::DelegateContext {
|
||||
session_id: unified_str2,
|
||||
channel: chan2.clone(),
|
||||
chat_id: cid2.clone(),
|
||||
},
|
||||
agent.process(history_out.clone()),
|
||||
).await;
|
||||
let result = match process_result {
|
||||
Ok(r) => r,
|
||||
Err(AgentError::LlmError(ref msg))
|
||||
if is_context_overflow_error(msg) =>
|
||||
@ -1911,3 +1978,24 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn format_task_notification(task_id: &str, status: &crate::agent::TaskStatus, summary: &str) -> String {
|
||||
match status {
|
||||
crate::agent::TaskStatus::Completed => format!(
|
||||
"📋 后台任务完成\n\n任务 ID: {}\n\n结果:\n{}",
|
||||
task_id, summary
|
||||
),
|
||||
crate::agent::TaskStatus::Failed(err) => format!(
|
||||
"📋 后台任务失败\n\n任务 ID: {}\n错误: {}",
|
||||
task_id, err
|
||||
),
|
||||
crate::agent::TaskStatus::Cancelled => format!(
|
||||
"📋 后台任务已取消\n\n任务 ID: {}",
|
||||
task_id
|
||||
),
|
||||
crate::agent::TaskStatus::TimedOut => format!(
|
||||
"📋 后台任务超时\n\n任务 ID: {}",
|
||||
task_id
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
19
src/storage/background_task.rs
Normal file
19
src/storage/background_task.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BackgroundTask {
|
||||
pub id: String,
|
||||
pub session_id: String,
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
pub prompt: String,
|
||||
pub allowed_tools: Option<String>,
|
||||
pub status: String,
|
||||
pub result: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub tool_calls_count: i64,
|
||||
pub iterations: i64,
|
||||
pub started_at: Option<i64>,
|
||||
pub finished_at: Option<i64>,
|
||||
pub created_at: i64,
|
||||
}
|
||||
@ -1,10 +1,12 @@
|
||||
pub mod error;
|
||||
pub mod memory;
|
||||
pub mod message;
|
||||
pub mod background_task;
|
||||
pub mod scheduler;
|
||||
pub mod session;
|
||||
|
||||
pub use error::StorageError;
|
||||
pub use background_task::BackgroundTask;
|
||||
pub use scheduler::{JobRun, ScheduledJob};
|
||||
|
||||
use sqlx::{Pool, Row, Sqlite, SqlitePool};
|
||||
@ -105,6 +107,48 @@ impl Storage {
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Background tasks table — for async sub-agent tasks.
|
||||
// Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL).
|
||||
// Session and task association is maintained at the application level.
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS background_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
channel TEXT NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
prompt TEXT NOT NULL,
|
||||
allowed_tools TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
result TEXT,
|
||||
error TEXT,
|
||||
tool_calls_count INTEGER DEFAULT 0,
|
||||
iterations INTEGER DEFAULT 0,
|
||||
started_at INTEGER,
|
||||
finished_at INTEGER,
|
||||
created_at INTEGER NOT NULL
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE INDEX IF NOT EXISTS idx_bg_tasks_session ON background_tasks(session_id)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE INDEX IF NOT EXISTS idx_bg_tasks_status ON background_tasks(status)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
@ -816,6 +860,148 @@ impl Storage {
|
||||
}
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
// ── Background Task CRUD ──
|
||||
|
||||
pub async fn create_background_task(
|
||||
&self,
|
||||
task: &crate::storage::background_task::BackgroundTask,
|
||||
) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO background_tasks (id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, tool_calls_count, iterations, started_at, finished_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&task.id)
|
||||
.bind(&task.session_id)
|
||||
.bind(&task.channel)
|
||||
.bind(&task.chat_id)
|
||||
.bind(&task.prompt)
|
||||
.bind(&task.allowed_tools)
|
||||
.bind(&task.status)
|
||||
.bind(&task.result)
|
||||
.bind(&task.error)
|
||||
.bind(task.tool_calls_count)
|
||||
.bind(task.iterations)
|
||||
.bind(task.started_at)
|
||||
.bind(task.finished_at)
|
||||
.bind(task.created_at)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_background_task_status(
|
||||
&self,
|
||||
id: &str,
|
||||
status: &str,
|
||||
result: Option<&str>,
|
||||
error: Option<&str>,
|
||||
started_at: Option<i64>,
|
||||
finished_at: Option<i64>,
|
||||
) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE background_tasks
|
||||
SET status = ?, result = COALESCE(?, result), error = COALESCE(?, error),
|
||||
started_at = COALESCE(?, started_at), finished_at = COALESCE(?, finished_at)
|
||||
WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(status)
|
||||
.bind(result)
|
||||
.bind(error)
|
||||
.bind(started_at)
|
||||
.bind(finished_at)
|
||||
.bind(id)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_background_task(
|
||||
&self,
|
||||
id: &str,
|
||||
) -> Result<crate::storage::background_task::BackgroundTask, StorageError> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
|
||||
tool_calls_count, iterations, started_at, finished_at, created_at
|
||||
FROM background_tasks WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(self.pool())
|
||||
.await?
|
||||
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
|
||||
|
||||
Ok(crate::storage::background_task::BackgroundTask {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
channel: row.get("channel"),
|
||||
chat_id: row.get("chat_id"),
|
||||
prompt: row.get("prompt"),
|
||||
allowed_tools: row.get("allowed_tools"),
|
||||
status: row.get("status"),
|
||||
result: row.get("result"),
|
||||
error: row.get("error"),
|
||||
tool_calls_count: row.get("tool_calls_count"),
|
||||
iterations: row.get("iterations"),
|
||||
started_at: row.get("started_at"),
|
||||
finished_at: row.get("finished_at"),
|
||||
created_at: row.get("created_at"),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn list_background_tasks(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> Result<Vec<crate::storage::background_task::BackgroundTask>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
|
||||
tool_calls_count, iterations, started_at, finished_at, created_at
|
||||
FROM background_tasks
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at DESC
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|row| crate::storage::background_task::BackgroundTask {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
channel: row.get("channel"),
|
||||
chat_id: row.get("chat_id"),
|
||||
prompt: row.get("prompt"),
|
||||
allowed_tools: row.get("allowed_tools"),
|
||||
status: row.get("status"),
|
||||
result: row.get("result"),
|
||||
error: row.get("error"),
|
||||
tool_calls_count: row.get("tool_calls_count"),
|
||||
iterations: row.get("iterations"),
|
||||
started_at: row.get("started_at"),
|
||||
finished_at: row.get("finished_at"),
|
||||
created_at: row.get("created_at"),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn cleanup_old_tasks(&self, ttl_ms: i64) -> Result<usize, StorageError> {
|
||||
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_ms;
|
||||
let result = sqlx::query(
|
||||
"DELETE FROM background_tasks WHERE status IN ('completed', 'failed', 'cancelled') AND finished_at IS NOT NULL AND finished_at < ?",
|
||||
)
|
||||
.bind(cutoff)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
Ok(result.rows_affected() as usize)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
365
src/tools/delegate.rs
Normal file
365
src/tools/delegate.rs
Normal file
@ -0,0 +1,365 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agent::{ExecutionMode, SubAgentConfig, SubAgentManager, TaskStatus};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct DelegateTool {
|
||||
sub_agent_manager: Arc<SubAgentManager>,
|
||||
}
|
||||
|
||||
impl DelegateTool {
|
||||
pub fn new(sub_agent_manager: Arc<SubAgentManager>) -> Self {
|
||||
Self { sub_agent_manager }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for DelegateTool {
|
||||
fn name(&self) -> &str {
|
||||
"delegate"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"子任务委托工具。创建子 Agent 处理独立任务,支持三种模式:\
|
||||
inline (阻塞返回结果)、background (异步执行,完成后通知)、\
|
||||
parallel (多个子 Agent 并发执行,聚合结果)。\
|
||||
也可用于查询、取消和列出后台任务。"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["run", "check_task", "cancel_task", "list_tasks"],
|
||||
"description": "操作类型: run 创建子Agent执行任务, check_task 查询后台任务, cancel_task 取消后台任务, list_tasks 列出后台任务"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "子任务描述(可内含额外约束,如:跳过 .tmp 文件)。action=run 时必填"
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["inline", "background", "parallel"],
|
||||
"description": "执行模式: inline=阻塞返回结果, background=异步执行+通知, parallel=多子Agent并发。默认 inline"
|
||||
},
|
||||
"allowed_tools": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "允许子Agent使用的工具列表。不填使用默认只读集: file_read,file_search,content_search,web_fetch,http_request,calculator"
|
||||
},
|
||||
"max_iterations": {
|
||||
"type": "integer",
|
||||
"description": "最大迭代次数,默认 99"
|
||||
},
|
||||
"timeout_secs": {
|
||||
"type": "integer",
|
||||
"description": "超时秒数,默认 3600(1小时)"
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"description": "并行模式下的多个子任务(仅 mode=parallel 时使用)",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": { "type": "string", "description": "子任务描述" },
|
||||
"allowed_tools": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "该子任务的工具列表"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "后台任务ID(action=check_task/cancel_task 时必填)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = args["action"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?;
|
||||
|
||||
match action {
|
||||
"run" => self.handle_run(&args).await,
|
||||
"check_task" => self.handle_check_task(&args).await,
|
||||
"cancel_task" => self.handle_cancel_task(&args).await,
|
||||
"list_tasks" => self.handle_list_tasks(&args).await,
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", action)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DelegateTool {
|
||||
fn parse_config_from_args(&self, args: &serde_json::Value) -> anyhow::Result<SubAgentConfig> {
|
||||
let prompt = args["prompt"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
|
||||
.to_string();
|
||||
|
||||
let allowed_tools: Option<Vec<String>> = args["allowed_tools"]
|
||||
.as_array()
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
||||
|
||||
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
|
||||
let timeout_secs = args["timeout_secs"].as_u64();
|
||||
|
||||
Ok(SubAgentConfig {
|
||||
prompt,
|
||||
mode: ExecutionMode::Inline,
|
||||
allowed_tools,
|
||||
max_iterations,
|
||||
timeout_secs,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_run(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let mode_str = args["mode"].as_str().unwrap_or("inline");
|
||||
let mode = match mode_str {
|
||||
"inline" => ExecutionMode::Inline,
|
||||
"background" => ExecutionMode::Background,
|
||||
"parallel" => ExecutionMode::Parallel,
|
||||
_ => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("unknown mode: {}. Supported: inline, background, parallel", mode_str)),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
match mode {
|
||||
ExecutionMode::Inline => {
|
||||
let config = self.parse_config_from_args(args)?;
|
||||
let result = self.sub_agent_manager.run_inline(config).await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
match result.status {
|
||||
TaskStatus::Completed => Ok(ToolResult {
|
||||
success: true,
|
||||
output: result.content,
|
||||
error: None,
|
||||
}),
|
||||
TaskStatus::Failed(err) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: result.content,
|
||||
error: Some(err),
|
||||
}),
|
||||
TaskStatus::TimedOut => Ok(ToolResult {
|
||||
success: false,
|
||||
output: result.content,
|
||||
error: Some("sub-agent timed out".into()),
|
||||
}),
|
||||
TaskStatus::Cancelled => Ok(ToolResult {
|
||||
success: false,
|
||||
output: result.content,
|
||||
error: Some("sub-agent cancelled".into()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
ExecutionMode::Background => {
|
||||
let config = self.parse_config_from_args(args)?;
|
||||
let ctx = crate::agent::sub_agent::get_delegate_context()
|
||||
.map_err(|_| anyhow::anyhow!("delegate context not available: not in an agent worker"))?;
|
||||
|
||||
let task_id = self.sub_agent_manager.run_background(config, ctx).await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("后台任务已启动。\ntask_id: {}", task_id),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
ExecutionMode::Parallel => {
|
||||
let tasks = args["tasks"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow::anyhow!("parallel mode requires 'tasks' array"))?;
|
||||
|
||||
let mut configs = Vec::new();
|
||||
for task in tasks {
|
||||
let prompt = task["prompt"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
|
||||
.to_string();
|
||||
let allowed_tools: Option<Vec<String>> = task["allowed_tools"]
|
||||
.as_array()
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
||||
|
||||
configs.push(SubAgentConfig {
|
||||
prompt,
|
||||
mode: ExecutionMode::Inline,
|
||||
allowed_tools,
|
||||
max_iterations: args["max_iterations"].as_u64().map(|v| v as usize),
|
||||
timeout_secs: args["timeout_secs"].as_u64(),
|
||||
});
|
||||
}
|
||||
|
||||
let has_args_allowed = args["allowed_tools"].as_array().is_some();
|
||||
for c in &mut configs {
|
||||
if c.allowed_tools.is_none() && has_args_allowed {
|
||||
c.allowed_tools = args["allowed_tools"]
|
||||
.as_array()
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
||||
}
|
||||
}
|
||||
|
||||
let results = self.sub_agent_manager.run_parallel(configs).await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
let mut output = String::new();
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
let status_icon = match r.status {
|
||||
TaskStatus::Completed => "✅",
|
||||
TaskStatus::Failed(_) => "❌",
|
||||
TaskStatus::TimedOut => "⏱️ 超时",
|
||||
TaskStatus::Cancelled => "🚫 已取消",
|
||||
};
|
||||
output.push_str(&format!("[task_{}] {}\n", i + 1, status_icon));
|
||||
if !r.content.is_empty() {
|
||||
output.push_str(&r.content);
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
if let TaskStatus::Failed(ref err) = r.status {
|
||||
output.push_str(&format!("错误: {}\n\n", err));
|
||||
}
|
||||
}
|
||||
|
||||
let all_success = results.iter().all(|r| matches!(r.status, TaskStatus::Completed));
|
||||
Ok(ToolResult {
|
||||
success: all_success,
|
||||
output: output.trim().to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_check_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let task_id = args["task_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
|
||||
|
||||
match self.sub_agent_manager.check_task(task_id).await {
|
||||
Some(task) => {
|
||||
let status_icon = match task.status.as_str() {
|
||||
"completed" => "✅ 已完成",
|
||||
"failed" => "❌ 失败",
|
||||
"cancelled" => "🚫 已取消",
|
||||
"running" => "🔄 运行中",
|
||||
"pending" => "⏳ 等待中",
|
||||
_ => task.status.as_str(),
|
||||
};
|
||||
let mut output = format!(
|
||||
"任务 ID: {}\n状态: {}\n任务: {}",
|
||||
task.id, status_icon, task.prompt
|
||||
);
|
||||
if let Some(ref result) = task.result {
|
||||
output.push_str(&format!("\n\n结果:\n{}", result));
|
||||
}
|
||||
if let Some(ref error) = task.error {
|
||||
output.push_str(&format!("\n错误: {}", error));
|
||||
}
|
||||
if let Some(started) = task.started_at {
|
||||
if let Some(finished) = task.finished_at {
|
||||
let duration = (finished - started) as f64 / 1000.0;
|
||||
output.push_str(&format!("\n耗时: {:.1}s", duration));
|
||||
}
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
None => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("task not found: {}", task_id)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_cancel_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let task_id = args["task_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
|
||||
|
||||
match self.sub_agent_manager.cancel_task(task_id).await {
|
||||
Ok(true) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("后台任务 {} 已取消", task_id),
|
||||
error: None,
|
||||
}),
|
||||
Ok(false) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("无法取消任务 {}(可能已完成或不存在)", task_id)),
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("取消失败: {}", e)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_list_tasks(&self, _args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let ctx = crate::agent::sub_agent::get_delegate_context()
|
||||
.map_err(|_| anyhow::anyhow!("delegate context not available"))?;
|
||||
let tasks = self.sub_agent_manager.list_tasks(&ctx.session_id).await;
|
||||
|
||||
if tasks.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: "没有后台任务".to_string(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let mut output = String::from("后台任务列表:\n\n");
|
||||
for task in &tasks {
|
||||
let status_icon = match task.status.as_str() {
|
||||
"completed" => "✅",
|
||||
"failed" => "❌",
|
||||
"cancelled" => "🚫",
|
||||
"running" => "🔄",
|
||||
"pending" => "⏳",
|
||||
_ => "❓",
|
||||
};
|
||||
output.push_str(&format!(
|
||||
"{} {} - {} - {} (created: {})\n",
|
||||
status_icon,
|
||||
&task.id[..std::cmp::min(8, task.id.len())],
|
||||
task.prompt.chars().take(60).collect::<String>(),
|
||||
task.status,
|
||||
task.created_at,
|
||||
));
|
||||
}
|
||||
output.push_str(&format!("\n共 {} 个任务", tasks.len()));
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -4,6 +4,7 @@ pub mod calculator;
|
||||
pub mod chat_manager;
|
||||
pub mod content_search;
|
||||
pub mod cron;
|
||||
pub mod delegate;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_search;
|
||||
@ -23,6 +24,7 @@ pub use browser::BrowserTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
pub use chat_manager::ChatManagerTool;
|
||||
pub use content_search::ContentSearchTool;
|
||||
pub use delegate::DelegateTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_search::FileSearchTool;
|
||||
@ -36,6 +38,7 @@ pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
use std::sync::Arc;
|
||||
use crate::agent::SubAgentManager;
|
||||
use crate::config::BrowserConfig;
|
||||
use crate::memory::MemoryManager;
|
||||
use crate::skills::SkillsLoader;
|
||||
@ -46,6 +49,7 @@ use crate::skills::SkillsLoader;
|
||||
pub fn create_default_tools(
|
||||
skills_loader: Arc<SkillsLoader>,
|
||||
memory: Arc<MemoryManager>,
|
||||
sub_agent_manager: Option<Arc<SubAgentManager>>,
|
||||
browser_config: Option<&BrowserConfig>,
|
||||
) -> ToolRegistry {
|
||||
let registry = ToolRegistry::new();
|
||||
@ -76,5 +80,9 @@ pub fn create_default_tools(
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(mgr) = sub_agent_manager {
|
||||
registry.register(DelegateTool::new(mgr));
|
||||
}
|
||||
|
||||
registry
|
||||
}
|
||||
|
||||
@ -20,6 +20,11 @@ impl ToolRegistry {
|
||||
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
|
||||
}
|
||||
|
||||
/// Register an existing Arc-wrapped tool by name
|
||||
pub fn register_raw(&self, name: String, tool: Arc<dyn ToolTrait>) {
|
||||
self.tools.lock().unwrap().insert(name, tool);
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
|
||||
self.tools.lock().unwrap().get(name).cloned()
|
||||
}
|
||||
@ -62,6 +67,17 @@ impl ToolRegistry {
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 生成工具列表描述,用于子 Agent 系统提示词
|
||||
pub fn describe_for_prompt(&self) -> String {
|
||||
let mut entries: Vec<String> = self
|
||||
.iter()
|
||||
.into_iter()
|
||||
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
|
||||
.collect();
|
||||
entries.sort();
|
||||
entries.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolRegistry {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user