PicoBot/src/tools/task/tool.rs

180 lines
5.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use crate::tools::{Tool, ToolContext, ToolResult};
use super::runtime::SubAgentRuntime;
use super::types::{TaskDefinition, TaskToolArgs};
/// `task` tool: creates and manages subagents on behalf of the calling agent.
pub struct TaskTool {
    /// Backend used to list available subagent types and to spawn/resume task sessions.
    runtime: Arc<dyn SubAgentRuntime>,
    /// Maximum nesting depth for subagents created through this tool.
    ///
    /// `0` means unlimited (used for the main agent); any value `> 0` caps how many
    /// levels of nested subagents may be created.
    max_nesting_depth: u32,
}
impl TaskTool {
    /// Creates a `TaskTool` with no nesting-depth limit (intended for the main agent).
    pub fn new(runtime: Arc<dyn SubAgentRuntime>) -> Self {
        // A depth limit of 0 disables the nesting check in `execute_with_context`.
        Self::new_with_depth(runtime, 0)
    }

    /// Creates a `TaskTool` with a nesting-depth limit (intended for subagents).
    ///
    /// When `max_nesting_depth > 0`, the tool refuses to spawn or resume a task once the
    /// caller's context nesting depth has reached `max_nesting_depth`. Passing `0`
    /// disables the limit, which is equivalent to [`TaskTool::new`].
    pub fn new_with_depth(runtime: Arc<dyn SubAgentRuntime>, max_nesting_depth: u32) -> Self {
        Self {
            runtime,
            max_nesting_depth,
        }
    }
}
/// Maximum length of the task `description`, counted in characters (not bytes).
/// Also advertised to the model as the schema's `maxLength`.
const MAX_DESCRIPTION_CHARS: usize = 50;

/// Maximum number of whitespace-separated words accepted in the task `description`.
/// Note: this is looser than the "3-5 words" requested in the schema's description text.
const MAX_DESCRIPTION_WORDS: usize = 7;

/// Builds a failed `ToolResult` with an empty output and `message` as its error.
fn task_failure(message: impl Into<String>) -> ToolResult {
    ToolResult {
        success: false,
        output: String::new(),
        error: Some(message.into()),
    }
}

#[async_trait]
impl Tool for TaskTool {
    fn name(&self) -> &str {
        "task"
    }

    fn description(&self) -> &str {
        "Launch a specialized subagent to handle complex, multi-step tasks. \
         Subagents run in isolated contexts and can work in parallel. \
         You can resume a previous task by providing its task_id."
    }

    /// JSON Schema describing the tool arguments.
    ///
    /// `subagent_type` is restricted to the names reported by the runtime. The `enum`
    /// constraint is omitted when the runtime reports none, because an empty `enum`
    /// would make every value invalid.
    fn parameters_schema(&self) -> serde_json::Value {
        let subagent_names: Vec<serde_json::Value> = self
            .runtime
            .available_subagent_names()
            .into_iter()
            .map(|name| json!(name))
            .collect();

        let mut subagent_type = json!({
            "type": "string",
            "default": "general",
            "description": "Type of subagent to use for the task"
        });
        if !subagent_names.is_empty() {
            subagent_type["enum"] = serde_json::Value::Array(subagent_names);
        }

        json!({
            "type": "object",
            "properties": {
                "description": {
                    "type": "string",
                    "description": "Short description (3-5 words) of what this task does",
                    "maxLength": MAX_DESCRIPTION_CHARS
                },
                "prompt": {
                    "type": "string",
                    "description": "Detailed instructions for the subagent to execute"
                },
                "subagent_type": subagent_type,
                "task_id": {
                    "type": "string",
                    "description": "Optional: Resume an existing task session by providing its task_id"
                }
            },
            "required": ["description", "prompt"]
        })
    }

    fn read_only(&self) -> bool {
        false
    }

    fn exclusive(&self) -> bool {
        // The task tool spawns subagents, so it must not run concurrently with other tools.
        true
    }

    /// Always fails: the task tool needs the parent session information that is only
    /// available through `execute_with_context`.
    async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
        Ok(task_failure(
            "task tool requires tool context with session_id, chat_id, and channel_name",
        ))
    }

    /// Spawns a new subagent task, or resumes an existing one when `task_id` is given.
    ///
    /// # Errors
    /// Returns `Err` if `args` cannot be deserialized into `TaskToolArgs` or if a
    /// successful task result cannot be serialized to JSON. Validation failures and
    /// runtime errors are reported as `Ok(ToolResult { success: false, .. })`.
    async fn execute_with_context(
        &self,
        context: &ToolContext,
        args: serde_json::Value,
    ) -> anyhow::Result<ToolResult> {
        // 1. Parse arguments (`args` is owned and not used afterwards, so no clone).
        let task_args: TaskToolArgs = serde_json::from_value(args)
            .map_err(|e| anyhow::anyhow!("invalid task arguments: {}", e))?;

        // 2. Validate the description. Length is counted in characters rather than
        //    bytes, so non-ASCII text gets the same budget the schema advertises.
        let char_count = task_args.description.chars().count();
        let word_count = task_args.description.split_whitespace().count();
        if char_count > MAX_DESCRIPTION_CHARS
            || !(1..=MAX_DESCRIPTION_WORDS).contains(&word_count)
        {
            return Ok(task_failure(format!(
                "description should be 1-{} words, max {} characters",
                MAX_DESCRIPTION_WORDS, MAX_DESCRIPTION_CHARS
            )));
        }

        // 3. Validate the context: parent session information is required.
        let required_fields = [
            ("session_id", context.session_id.is_none()),
            ("chat_id", context.chat_id.is_none()),
            ("channel_name", context.channel_name.is_none()),
        ];
        if let Some((field, _)) = required_fields.iter().find(|entry| entry.1) {
            return Ok(task_failure(format!(
                "task tool requires {} in context",
                field
            )));
        }

        // 4. Depth check. It only applies to nested agents; the main agent uses
        //    `max_nesting_depth == 0`, which means unlimited.
        if self.max_nesting_depth > 0 && context.nesting_depth >= self.max_nesting_depth {
            return Ok(task_failure(format!(
                "Cannot create nested subagent: max nesting depth ({}) reached",
                self.max_nesting_depth
            )));
        }

        // 5. Run the task: resume an existing session or spawn a new one.
        let result = match task_args.task_id {
            Some(task_id) => {
                self.runtime
                    .resume(&task_id, context, task_args.prompt)
                    .await
            }
            None => {
                let task_def = TaskDefinition::from(task_args);
                self.runtime.spawn(context, task_def).await
            }
        };

        // 6. Build the tool result.
        match result {
            Ok(task_result) => Ok(ToolResult {
                success: true,
                output: serde_json::to_string(&task_result)?,
                error: None,
            }),
            Err(e) => Ok(task_failure(e.to_string())),
        }
    }
}