// PicoBot/src/gateway/tool_registry_factory.rs
//
// (Code-viewer page residue, kept as a comment: 251 lines, 8.3 KiB, Rust.
// The viewer also warned that this file contains Unicode characters that might be
// confused with other characters.)

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::config::TaskConfig;
use crate::mcp::McpClientManager;
use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository, TodoRepository};
use crate::tools::todo_write::TodoItem;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, MemorySearchTool,
SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager,
SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
TodoReadTool, TodoWriteTool, ToolRegistry, WebFetchTool,
};
/// Holds the shared dependencies and configuration needed to assemble `ToolRegistry`
/// instances (see `build` and `build_subagent_tools`).
pub(crate) struct ToolRegistryFactory {
/// Skill runtime handed to `SkillActivateTool` and `SkillManageTool`.
skills: Arc<SkillRuntime>,
/// Memory repository handed to `MemorySearchTool` and `MemoryManageTool`.
memories: Arc<dyn MemoryRepository>,
/// Scheduler job repository handed to `SchedulerManageTool`.
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
/// Skill event repository handed to `SkillActivateTool`.
skill_events: Arc<dyn SkillEventRepository>,
/// Todo repository handed to `TodoReadTool`.
todo_repository: Arc<dyn TodoRepository>,
/// Message sender handed to `SessionSendTool`.
session_message_sender: Arc<dyn SessionMessageSender>,
/// Agent names passed (as a cloned set) to `SchedulerManageTool`.
known_agents: HashSet<String>,
/// Timezone string passed to `TimeTool`.
default_timezone: String,
/// Tool names that must not be registered; consulted by `is_enabled`.
disabled_tools: HashSet<String>,
/// Task configuration; its `enabled` flag gates registration of `TaskTool`.
task_config: TaskConfig,
/// Optional sub-agent runtime; `TaskTool` is only registered when this is set.
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
/// Optional MCP client manager, set via `with_mcp_manager`.
/// NOTE(review): not read anywhere in the visible part of this file — confirm it is still needed.
mcp_manager: Option<Arc<McpClientManager>>,
/// Optional shared todo state (map of string key to todo items) used by the todo tools;
/// the todo tools are only registered when this is set.
todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,
/// Shell session manager shared by every `BashTool` this factory registers.
shell_session_manager: Arc<ShellSessionManager>,
}
impl ToolRegistryFactory {
pub(crate) fn new(
skills: Arc<SkillRuntime>,
memories: Arc<dyn MemoryRepository>,
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
skill_events: Arc<dyn SkillEventRepository>,
todo_repository: Arc<dyn TodoRepository>,
session_message_sender: Arc<dyn SessionMessageSender>,
known_agents: HashSet<String>,
default_timezone: String,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
) -> Self {
Self {
skills,
memories,
scheduler_jobs,
skill_events,
todo_repository,
session_message_sender,
known_agents,
default_timezone,
disabled_tools,
task_config,
subagent_runtime: None,
mcp_manager: None,
todo_state: None,
shell_session_manager: Arc::new(ShellSessionManager::new()),
}
}
pub(crate) fn with_todo_state(
mut self,
state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
) -> Self {
self.todo_state = Some(state);
self
}
pub(crate) fn with_subagent_runtime(
mut self,
runtime: Arc<dyn SubAgentRuntime>,
) -> Self {
self.subagent_runtime = Some(runtime);
self
}
pub(crate) fn with_mcp_manager(
mut self,
manager: Arc<McpClientManager>,
) -> Self {
self.mcp_manager = Some(manager);
self
}
fn is_enabled(&self, tool_name: &str) -> bool {
!self.disabled_tools.contains(tool_name)
}
/// Get a reference to the shell session manager for lifecycle control.
#[allow(dead_code)]
pub(crate) fn shell_session_manager(&self) -> Arc<ShellSessionManager> {
self.shell_session_manager.clone()
}
pub(crate) fn build(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new();
if self.is_enabled("calculator") {
registry.register(CalculatorTool::new());
}
if self.is_enabled("get_time") {
registry.register(TimeTool::new(self.default_timezone.clone()));
}
if self.is_enabled("read") {
registry.register(FileReadTool::new());
}
if self.is_enabled("write") {
registry.register(FileWriteTool::new());
}
if self.is_enabled("edit") {
registry.register(FileEditTool::new());
}
if self.is_enabled("memory_search") {
registry.register(MemorySearchTool::new(self.memories.clone()));
}
if self.is_enabled("memory_manage") {
registry.register(MemoryManageTool::new(self.memories.clone()));
}
if self.is_enabled("todo_write") {
if let Some(ref state) = self.todo_state {
registry.register(TodoWriteTool::new(state.clone()));
registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone()));
}
}
if self.is_enabled("session_send") {
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
}
if self.is_enabled("scheduler_manage") {
registry.register(SchedulerManageTool::new(
self.scheduler_jobs.clone(),
self.known_agents.clone(),
));
}
if self.is_enabled("skill_activate") {
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.skill_events.clone(),
));
}
if self.is_enabled("skill_manage") {
registry.register(SkillManageTool::new(self.skills.clone()));
}
if self.is_enabled("bash") {
registry.register(BashTool::new(self.shell_session_manager.clone()));
}
if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
}
if self.is_enabled("web_fetch") {
registry.register(WebFetchTool::new(50_000, 30));
}
// 注册 Task 工具(如果启用且有 subagent_runtime
if self.is_enabled("task") && self.task_config.enabled {
if let Some(runtime) = &self.subagent_runtime {
registry.register(TaskTool::new(runtime.clone()));
}
}
registry
}
/// 构建子代理专用工具集(不包含 task 工具防止递归)
/// 可选地包含 MCP 工具(通过 mcp_tools 参数传递)
pub(crate) fn build_subagent_tools(
&self,
mcp_tools: Option<Vec<crate::mcp::tool_adapter::McpToolWrapper>>,
) -> ToolRegistry {
let mut registry = ToolRegistry::new();
// 基础工具
if self.is_enabled("calculator") {
registry.register(CalculatorTool::new());
}
if self.is_enabled("get_time") {
registry.register(TimeTool::new(self.default_timezone.clone()));
}
if self.is_enabled("read") {
registry.register(FileReadTool::new());
}
if self.is_enabled("write") {
registry.register(FileWriteTool::new());
}
if self.is_enabled("edit") {
registry.register(FileEditTool::new());
}
if self.is_enabled("bash") {
registry.register(BashTool::new(self.shell_session_manager.clone()));
}
if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
}
if self.is_enabled("web_fetch") {
registry.register(WebFetchTool::new(50_000, 30));
}
// 记忆工具(只读)
if self.is_enabled("memory_search") {
registry.register(MemorySearchTool::new(self.memories.clone()));
}
// Skill 工具
if self.is_enabled("skill_activate") {
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.skill_events.clone(),
));
}
// 进度通知工具
if self.is_enabled("session_send") {
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
}
// Todo 追踪工具
if self.is_enabled("todo_write") {
if let Some(ref state) = self.todo_state {
registry.register(TodoWriteTool::new(state.clone()));
registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone()));
}
}
// 注册 MCP 工具(如果提供)
if let Some(mcp_tools) = mcp_tools {
for tool in mcp_tools {
registry.register(tool);
}
}
// 注意:不注册 task 工具,防止递归创建子代理
registry
}
}