// PicoBot/src/gateway/tool_registry_factory.rs
//
// 197 lines
// 6.4 KiB
// Rust
// Raw Blame History
//
// This file contains ambiguous Unicode characters
//
// This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashSet;
use std::sync::Arc;
use crate::config::TaskConfig;
use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, MemorySearchTool,
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, SkillListTool,
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
ToolRegistry, WebFetchTool,
};
/// Holds the shared services and settings needed to assemble a `ToolRegistry`.
///
/// The factory is configured once and can then produce registries on demand via
/// `build` (full tool set) and `build_subagent_tools` (reduced set without the
/// `task` tool).
pub(crate) struct ToolRegistryFactory {
/// Skill runtime handed to the `skill_activate`, `skill_list` and `skill_manage` tools.
skills: Arc<SkillRuntime>,
/// Memory repository handed to the `memory_search` and `memory_manage` tools.
memories: Arc<dyn MemoryRepository>,
/// Scheduler job repository handed to the `scheduler_manage` tool.
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
/// Skill event repository handed to the `skill_activate` tool.
skill_events: Arc<dyn SkillEventRepository>,
/// Message sender handed to the `session_send` tool.
session_message_sender: Arc<dyn SessionMessageSender>,
/// Agent names passed (as a clone) to the `scheduler_manage` tool.
known_agents: HashSet<String>,
/// Timezone string passed to the `get_time` tool.
default_timezone: String,
/// Names of tools that must be skipped during registration (see `is_enabled`).
disabled_tools: HashSet<String>,
/// Task settings; its `enabled` flag gates registration of the `task` tool.
task_config: TaskConfig,
/// Sub-agent runtime required by the `task` tool; `None` until
/// `with_subagent_runtime` is called, in which case `task` is never registered.
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
}
impl ToolRegistryFactory {
pub(crate) fn new(
skills: Arc<SkillRuntime>,
memories: Arc<dyn MemoryRepository>,
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
skill_events: Arc<dyn SkillEventRepository>,
session_message_sender: Arc<dyn SessionMessageSender>,
known_agents: HashSet<String>,
default_timezone: String,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
) -> Self {
Self {
skills,
memories,
scheduler_jobs,
skill_events,
session_message_sender,
known_agents,
default_timezone,
disabled_tools,
task_config,
subagent_runtime: None,
}
}
pub(crate) fn with_subagent_runtime(
mut self,
runtime: Arc<dyn SubAgentRuntime>,
) -> Self {
self.subagent_runtime = Some(runtime);
self
}
fn is_enabled(&self, tool_name: &str) -> bool {
!self.disabled_tools.contains(tool_name)
}
pub(crate) fn build(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new();
if self.is_enabled("calculator") {
registry.register(CalculatorTool::new());
}
if self.is_enabled("get_time") {
registry.register(TimeTool::new(self.default_timezone.clone()));
}
if self.is_enabled("read") {
registry.register(FileReadTool::new());
}
if self.is_enabled("write") {
registry.register(FileWriteTool::new());
}
if self.is_enabled("edit") {
registry.register(FileEditTool::new());
}
if self.is_enabled("memory_search") {
registry.register(MemorySearchTool::new(self.memories.clone()));
}
if self.is_enabled("memory_manage") {
registry.register(MemoryManageTool::new(self.memories.clone()));
}
if self.is_enabled("session_send") {
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
}
if self.is_enabled("scheduler_manage") {
registry.register(SchedulerManageTool::new(
self.scheduler_jobs.clone(),
self.known_agents.clone(),
));
}
if self.is_enabled("skill_activate") {
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.skill_events.clone(),
));
}
if self.is_enabled("skill_list") {
registry.register(SkillListTool::new(self.skills.clone()));
}
if self.is_enabled("skill_manage") {
registry.register(SkillManageTool::new(self.skills.clone()));
}
if self.is_enabled("bash") {
registry.register(BashTool::new());
}
if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
}
if self.is_enabled("web_fetch") {
registry.register(WebFetchTool::new(50_000, 30));
}
// 注册 Task 工具(如果启用且有 subagent_runtime
if self.is_enabled("task") && self.task_config.enabled {
if let Some(runtime) = &self.subagent_runtime {
registry.register(TaskTool::new(runtime.clone()));
}
}
registry
}
/// 构建子代理专用工具集(不包含 task 工具防止递归)
pub(crate) fn build_subagent_tools(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new();
// 基础工具
if self.is_enabled("calculator") {
registry.register(CalculatorTool::new());
}
if self.is_enabled("get_time") {
registry.register(TimeTool::new(self.default_timezone.clone()));
}
if self.is_enabled("read") {
registry.register(FileReadTool::new());
}
if self.is_enabled("write") {
registry.register(FileWriteTool::new());
}
if self.is_enabled("edit") {
registry.register(FileEditTool::new());
}
if self.is_enabled("bash") {
registry.register(BashTool::new());
}
if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
}
if self.is_enabled("web_fetch") {
registry.register(WebFetchTool::new(50_000, 30));
}
// 记忆工具(只读)
if self.is_enabled("memory_search") {
registry.register(MemorySearchTool::new(self.memories.clone()));
}
// Skill 工具
if self.is_enabled("skill_activate") {
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.skill_events.clone(),
));
}
if self.is_enabled("skill_list") {
registry.register(SkillListTool::new(self.skills.clone()));
}
// 进度通知工具
if self.is_enabled("session_send") {
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
}
// 注意:不注册 task 工具,防止递归创建子代理
registry
}
}