PicoBot/src/gateway/tool_registry_factory.rs
oudecheng 644f5f9132 feat: 子代理继承主代理的 MCP 工具
- 为 McpToolWrapper 添加 Clone trait,支持工具实例复用
- 修改 build_subagent_tools 方法,支持传入 MCP 工具列表
- 调整 runtime 构建顺序:先等待 MCP 连接,再将 MCP 工具传递给子代理

子代理现在可以自动使用主代理配置的 MCP 工具(如 filesystem、fetch 等)。
2026-05-26 11:53:40 +08:00

213 lines
6.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashSet;
use std::sync::Arc;
use crate::config::TaskConfig;
use crate::mcp::McpClientManager;
use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, MemorySearchTool,
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool,
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
ToolRegistry, WebFetchTool,
};
/// Assembles [`ToolRegistry`] instances for the main agent and for sub-agents.
///
/// Holds the shared repositories and runtime handles that individual tools are
/// constructed with, plus the set of tool names that must be left out.
pub(crate) struct ToolRegistryFactory {
    /// Skill runtime handed to `skill_activate` and `skill_manage`.
    skills: Arc<SkillRuntime>,
    /// Memory repository handed to `memory_search` and `memory_manage`.
    memories: Arc<dyn MemoryRepository>,
    /// Scheduler job repository handed to `scheduler_manage`.
    scheduler_jobs: Arc<dyn SchedulerJobRepository>,
    /// Skill event repository handed to `skill_activate`.
    skill_events: Arc<dyn SkillEventRepository>,
    /// Message sender handed to `session_send`.
    session_message_sender: Arc<dyn SessionMessageSender>,
    /// Set of known agents handed to `scheduler_manage`.
    known_agents: HashSet<String>,
    /// Timezone handed to the `get_time` tool.
    default_timezone: String,
    /// Names of built-in tools that must not be registered.
    disabled_tools: HashSet<String>,
    /// Task settings; its `enabled` flag gates the `task` tool.
    task_config: TaskConfig,
    /// Runtime used by the `task` tool; `task` is only registered when this is set.
    subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
    /// Optional MCP client manager.
    /// NOTE(review): set by `with_mcp_manager` but not read by any method in this file.
    mcp_manager: Option<Arc<McpClientManager>>,
}
impl ToolRegistryFactory {
    /// Creates a factory from the collaborators the built-in tools are constructed with.
    ///
    /// The sub-agent runtime and the MCP manager start out unset; attach them with
    /// [`Self::with_subagent_runtime`] and [`Self::with_mcp_manager`].
    pub(crate) fn new(
        skills: Arc<SkillRuntime>,
        memories: Arc<dyn MemoryRepository>,
        scheduler_jobs: Arc<dyn SchedulerJobRepository>,
        skill_events: Arc<dyn SkillEventRepository>,
        session_message_sender: Arc<dyn SessionMessageSender>,
        known_agents: HashSet<String>,
        default_timezone: String,
        disabled_tools: HashSet<String>,
        task_config: TaskConfig,
    ) -> Self {
        Self {
            skills,
            memories,
            scheduler_jobs,
            skill_events,
            session_message_sender,
            known_agents,
            default_timezone,
            disabled_tools,
            task_config,
            subagent_runtime: None,
            mcp_manager: None,
        }
    }

    /// Sets the runtime used by the `task` tool.
    ///
    /// Without a runtime, [`Self::build`] does not register `task` even when it is enabled.
    pub(crate) fn with_subagent_runtime(mut self, runtime: Arc<dyn SubAgentRuntime>) -> Self {
        self.subagent_runtime = Some(runtime);
        self
    }

    /// Stores an MCP client manager on the factory.
    ///
    /// NOTE(review): nothing in this impl reads the stored manager; sub-agents receive MCP
    /// tools only through the `mcp_tools` argument of [`Self::build_subagent_tools`].
    pub(crate) fn with_mcp_manager(mut self, manager: Arc<McpClientManager>) -> Self {
        self.mcp_manager = Some(manager);
        self
    }

    /// Returns `true` unless `tool_name` is listed in `disabled_tools`.
    fn is_enabled(&self, tool_name: &str) -> bool {
        !self.disabled_tools.contains(tool_name)
    }

    /// Registers the tools both registries start with, in this order: `calculator`,
    /// `get_time`, `read`, `write`, `edit` (each only if enabled).
    fn register_core_tools(&self, registry: &mut ToolRegistry) {
        if self.is_enabled("calculator") {
            registry.register(CalculatorTool::new());
        }
        if self.is_enabled("get_time") {
            registry.register(TimeTool::new(self.default_timezone.clone()));
        }
        if self.is_enabled("read") {
            registry.register(FileReadTool::new());
        }
        if self.is_enabled("write") {
            registry.register(FileWriteTool::new());
        }
        if self.is_enabled("edit") {
            registry.register(FileEditTool::new());
        }
    }

    /// Registers `bash`, `http_request` and `web_fetch` (each only if enabled), in that order.
    ///
    /// Shared by [`Self::build`] and [`Self::build_subagent_tools`] so the main agent and
    /// sub-agents are always configured identically.
    fn register_shell_and_web_tools(&self, registry: &mut ToolRegistry) {
        if self.is_enabled("bash") {
            registry.register(BashTool::new());
        }
        if self.is_enabled("http_request") {
            // NOTE(review): positional arguments presumably are a host allow-list ("*" = any),
            // a size limit, a timeout in seconds and a boolean flag — confirm against
            // `HttpRequestTool::new`.
            registry.register(HttpRequestTool::new(
                vec!["*".to_string()],
                1_000_000,
                30,
                false,
            ));
        }
        if self.is_enabled("web_fetch") {
            // NOTE(review): presumably a content-length limit and a timeout in seconds —
            // confirm against `WebFetchTool::new`.
            registry.register(WebFetchTool::new(50_000, 30));
        }
    }

    /// Builds the full tool registry for the main agent.
    ///
    /// Every built-in tool not listed in `disabled_tools` is registered. The `task` tool
    /// additionally requires `task_config.enabled` and a runtime set via
    /// [`Self::with_subagent_runtime`].
    pub(crate) fn build(&self) -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        self.register_core_tools(&mut registry);
        if self.is_enabled("memory_search") {
            registry.register(MemorySearchTool::new(self.memories.clone()));
        }
        if self.is_enabled("memory_manage") {
            registry.register(MemoryManageTool::new(self.memories.clone()));
        }
        if self.is_enabled("session_send") {
            registry.register(SessionSendTool::new(self.session_message_sender.clone()));
        }
        if self.is_enabled("scheduler_manage") {
            registry.register(SchedulerManageTool::new(
                self.scheduler_jobs.clone(),
                self.known_agents.clone(),
            ));
        }
        if self.is_enabled("skill_activate") {
            registry.register(SkillActivateTool::new(
                self.skills.clone(),
                self.skill_events.clone(),
            ));
        }
        if self.is_enabled("skill_manage") {
            registry.register(SkillManageTool::new(self.skills.clone()));
        }
        self.register_shell_and_web_tools(&mut registry);
        // `task` needs both the feature flag and a runtime to spawn sub-agents with.
        if self.is_enabled("task") && self.task_config.enabled {
            if let Some(runtime) = &self.subagent_runtime {
                registry.register(TaskTool::new(runtime.clone()));
            }
        }
        registry
    }

    /// Builds the reduced tool registry used by sub-agents.
    ///
    /// Compared with [`Self::build`], this omits `memory_manage`, `scheduler_manage`,
    /// `skill_manage` and `task`. `task` is omitted so a sub-agent cannot recursively spawn
    /// further sub-agents. Any `mcp_tools` supplied are registered after the built-in tools.
    ///
    /// NOTE(review): MCP tools are registered unconditionally; `disabled_tools` is not
    /// consulted for them.
    pub(crate) fn build_subagent_tools(
        &self,
        mcp_tools: Option<Vec<crate::mcp::tool_adapter::McpToolWrapper>>,
    ) -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        self.register_core_tools(&mut registry);
        self.register_shell_and_web_tools(&mut registry);
        // Memory access is read-only for sub-agents: search, but no manage.
        if self.is_enabled("memory_search") {
            registry.register(MemorySearchTool::new(self.memories.clone()));
        }
        if self.is_enabled("skill_activate") {
            registry.register(SkillActivateTool::new(
                self.skills.clone(),
                self.skill_events.clone(),
            ));
        }
        // Intended for progress notifications back to the session.
        if self.is_enabled("session_send") {
            registry.register(SessionSendTool::new(self.session_message_sender.clone()));
        }
        // `None` and an empty list both register nothing.
        for tool in mcp_tools.into_iter().flatten() {
            registry.register(tool);
        }
        registry
    }
}