180 lines
5.9 KiB
Rust
180 lines
5.9 KiB
Rust
use async_trait::async_trait;
|
||
use serde_json::json;
|
||
use std::sync::Arc;
|
||
|
||
use crate::tools::{Tool, ToolContext, ToolResult};
|
||
|
||
use super::runtime::SubAgentRuntime;
|
||
use super::types::{TaskDefinition, TaskToolArgs};
|
||
|
||
/// Task 工具 - 创建和管理子代理
|
||
pub struct TaskTool {
|
||
runtime: Arc<dyn SubAgentRuntime>,
|
||
/// 最大嵌套深度(0 = 无限制用于主 agent,>0 = 限制子代理最大嵌套层级)
|
||
max_nesting_depth: u32,
|
||
}
|
||
|
||
impl TaskTool {
|
||
pub fn new(runtime: Arc<dyn SubAgentRuntime>) -> Self {
|
||
Self {
|
||
runtime,
|
||
max_nesting_depth: 0, // 主 agent 无深度限制
|
||
}
|
||
}
|
||
|
||
/// 创建带嵌套深度限制的 TaskTool(用于子代理)
|
||
pub fn new_with_depth(runtime: Arc<dyn SubAgentRuntime>, max_nesting_depth: u32) -> Self {
|
||
Self {
|
||
runtime,
|
||
max_nesting_depth,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl Tool for TaskTool {
|
||
fn name(&self) -> &str {
|
||
"task"
|
||
}
|
||
|
||
fn description(&self) -> &str {
|
||
"Launch a specialized subagent to handle complex, multi-step tasks. \
|
||
Subagents run in isolated contexts and can work in parallel. \
|
||
You can resume a previous task by providing its task_id."
|
||
}
|
||
|
||
fn parameters_schema(&self) -> serde_json::Value {
|
||
let types = self.runtime.available_subagent_names();
|
||
let types_array: Vec<serde_json::Value> = types.into_iter().map(|t| json!(t)).collect();
|
||
|
||
json!({
|
||
"type": "object",
|
||
"properties": {
|
||
"description": {
|
||
"type": "string",
|
||
"description": "Short description (3-5 words) of what this task does",
|
||
"maxLength": 50
|
||
},
|
||
"prompt": {
|
||
"type": "string",
|
||
"description": "Detailed instructions for the subagent to execute"
|
||
},
|
||
"subagent_type": {
|
||
"type": "string",
|
||
"enum": types_array,
|
||
"default": "general",
|
||
"description": "Type of subagent to use for the task"
|
||
},
|
||
"task_id": {
|
||
"type": "string",
|
||
"description": "Optional: Resume an existing task session by providing its task_id"
|
||
}
|
||
},
|
||
"required": ["description", "prompt"]
|
||
})
|
||
}
|
||
|
||
fn read_only(&self) -> bool {
|
||
false
|
||
}
|
||
|
||
fn exclusive(&self) -> bool {
|
||
// Task 工具创建子代理,不应与其他工具并发执行
|
||
true
|
||
}
|
||
|
||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||
// Task 工具必须通过 execute_with_context 获取父会话信息
|
||
Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some(
|
||
"task tool requires tool context with session_id, chat_id, and channel_name"
|
||
.to_string(),
|
||
),
|
||
})
|
||
}
|
||
|
||
async fn execute_with_context(
|
||
&self,
|
||
context: &ToolContext,
|
||
args: serde_json::Value,
|
||
) -> anyhow::Result<ToolResult> {
|
||
// 1. 解析参数
|
||
let task_args: TaskToolArgs = serde_json::from_value(args.clone())
|
||
.map_err(|e| anyhow::anyhow!("invalid task arguments: {}", e))?;
|
||
|
||
// 2. 验证描述长度
|
||
let word_count = task_args.description.split_whitespace().count();
|
||
if task_args.description.len() > 50 || word_count > 7 || word_count < 1 {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some(
|
||
"description should be 1-5 words, max 50 characters".to_string(),
|
||
),
|
||
});
|
||
}
|
||
|
||
// 3. 验证上下文
|
||
if context.session_id.is_none() {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some("task tool requires session_id in context".to_string()),
|
||
});
|
||
}
|
||
if context.chat_id.is_none() {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some("task tool requires chat_id in context".to_string()),
|
||
});
|
||
}
|
||
if context.channel_name.is_none() {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some("task tool requires channel_name in context".to_string()),
|
||
});
|
||
}
|
||
|
||
// 4. 深度校验(仅对嵌套场景生效,主 agent 的 max_nesting_depth = 0 不限制)
|
||
if self.max_nesting_depth > 0 && context.nesting_depth >= self.max_nesting_depth {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some(format!(
|
||
"Cannot create nested subagent: max nesting depth ({}) reached",
|
||
self.max_nesting_depth
|
||
)),
|
||
});
|
||
}
|
||
|
||
// 5. 执行任务
|
||
let result = if let Some(task_id) = task_args.task_id {
|
||
// 恢复现有任务
|
||
self.runtime
|
||
.resume(&task_id, context, task_args.prompt)
|
||
.await
|
||
} else {
|
||
// 创建新任务
|
||
let task_def = TaskDefinition::from(task_args);
|
||
self.runtime.spawn(context, task_def).await
|
||
};
|
||
|
||
// 5. 构建返回结果
|
||
match result {
|
||
Ok(task_result) => Ok(ToolResult {
|
||
success: true,
|
||
output: serde_json::to_string(&task_result)?,
|
||
error: None,
|
||
}),
|
||
Err(e) => Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some(e.to_string()),
|
||
}),
|
||
}
|
||
}
|
||
} |