197 lines
6.4 KiB
Rust
197 lines
6.4 KiB
Rust
use std::collections::HashSet;
|
||
use std::sync::Arc;
|
||
|
||
use crate::config::TaskConfig;
|
||
use crate::skills::SkillRuntime;
|
||
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||
use crate::tools::{
|
||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||
HttpRequestTool, MemoryManageTool, MemorySearchTool,
|
||
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, SkillListTool,
|
||
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
||
ToolRegistry, WebFetchTool,
|
||
};
|
||
|
||
pub(crate) struct ToolRegistryFactory {
|
||
skills: Arc<SkillRuntime>,
|
||
memories: Arc<dyn MemoryRepository>,
|
||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||
skill_events: Arc<dyn SkillEventRepository>,
|
||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||
known_agents: HashSet<String>,
|
||
default_timezone: String,
|
||
disabled_tools: HashSet<String>,
|
||
task_config: TaskConfig,
|
||
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
||
}
|
||
|
||
impl ToolRegistryFactory {
|
||
pub(crate) fn new(
|
||
skills: Arc<SkillRuntime>,
|
||
memories: Arc<dyn MemoryRepository>,
|
||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||
skill_events: Arc<dyn SkillEventRepository>,
|
||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||
known_agents: HashSet<String>,
|
||
default_timezone: String,
|
||
disabled_tools: HashSet<String>,
|
||
task_config: TaskConfig,
|
||
) -> Self {
|
||
Self {
|
||
skills,
|
||
memories,
|
||
scheduler_jobs,
|
||
skill_events,
|
||
session_message_sender,
|
||
known_agents,
|
||
default_timezone,
|
||
disabled_tools,
|
||
task_config,
|
||
subagent_runtime: None,
|
||
}
|
||
}
|
||
|
||
pub(crate) fn with_subagent_runtime(
|
||
mut self,
|
||
runtime: Arc<dyn SubAgentRuntime>,
|
||
) -> Self {
|
||
self.subagent_runtime = Some(runtime);
|
||
self
|
||
}
|
||
|
||
fn is_enabled(&self, tool_name: &str) -> bool {
|
||
!self.disabled_tools.contains(tool_name)
|
||
}
|
||
|
||
pub(crate) fn build(&self) -> ToolRegistry {
|
||
let mut registry = ToolRegistry::new();
|
||
|
||
if self.is_enabled("calculator") {
|
||
registry.register(CalculatorTool::new());
|
||
}
|
||
if self.is_enabled("get_time") {
|
||
registry.register(TimeTool::new(self.default_timezone.clone()));
|
||
}
|
||
if self.is_enabled("read") {
|
||
registry.register(FileReadTool::new());
|
||
}
|
||
if self.is_enabled("write") {
|
||
registry.register(FileWriteTool::new());
|
||
}
|
||
if self.is_enabled("edit") {
|
||
registry.register(FileEditTool::new());
|
||
}
|
||
if self.is_enabled("memory_search") {
|
||
registry.register(MemorySearchTool::new(self.memories.clone()));
|
||
}
|
||
if self.is_enabled("memory_manage") {
|
||
registry.register(MemoryManageTool::new(self.memories.clone()));
|
||
}
|
||
if self.is_enabled("session_send") {
|
||
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
|
||
}
|
||
if self.is_enabled("scheduler_manage") {
|
||
registry.register(SchedulerManageTool::new(
|
||
self.scheduler_jobs.clone(),
|
||
self.known_agents.clone(),
|
||
));
|
||
}
|
||
if self.is_enabled("skill_activate") {
|
||
registry.register(SkillActivateTool::new(
|
||
self.skills.clone(),
|
||
self.skill_events.clone(),
|
||
));
|
||
}
|
||
if self.is_enabled("skill_list") {
|
||
registry.register(SkillListTool::new(self.skills.clone()));
|
||
}
|
||
if self.is_enabled("skill_manage") {
|
||
registry.register(SkillManageTool::new(self.skills.clone()));
|
||
}
|
||
if self.is_enabled("bash") {
|
||
registry.register(BashTool::new());
|
||
}
|
||
if self.is_enabled("http_request") {
|
||
registry.register(HttpRequestTool::new(
|
||
vec!["*".to_string()],
|
||
1_000_000,
|
||
30,
|
||
false,
|
||
));
|
||
}
|
||
if self.is_enabled("web_fetch") {
|
||
registry.register(WebFetchTool::new(50_000, 30));
|
||
}
|
||
|
||
// 注册 Task 工具(如果启用且有 subagent_runtime)
|
||
if self.is_enabled("task") && self.task_config.enabled {
|
||
if let Some(runtime) = &self.subagent_runtime {
|
||
registry.register(TaskTool::new(runtime.clone()));
|
||
}
|
||
}
|
||
|
||
registry
|
||
}
|
||
|
||
/// 构建子代理专用工具集(不包含 task 工具防止递归)
|
||
pub(crate) fn build_subagent_tools(&self) -> ToolRegistry {
|
||
let mut registry = ToolRegistry::new();
|
||
|
||
// 基础工具
|
||
if self.is_enabled("calculator") {
|
||
registry.register(CalculatorTool::new());
|
||
}
|
||
if self.is_enabled("get_time") {
|
||
registry.register(TimeTool::new(self.default_timezone.clone()));
|
||
}
|
||
if self.is_enabled("read") {
|
||
registry.register(FileReadTool::new());
|
||
}
|
||
if self.is_enabled("write") {
|
||
registry.register(FileWriteTool::new());
|
||
}
|
||
if self.is_enabled("edit") {
|
||
registry.register(FileEditTool::new());
|
||
}
|
||
if self.is_enabled("bash") {
|
||
registry.register(BashTool::new());
|
||
}
|
||
if self.is_enabled("http_request") {
|
||
registry.register(HttpRequestTool::new(
|
||
vec!["*".to_string()],
|
||
1_000_000,
|
||
30,
|
||
false,
|
||
));
|
||
}
|
||
if self.is_enabled("web_fetch") {
|
||
registry.register(WebFetchTool::new(50_000, 30));
|
||
}
|
||
|
||
// 记忆工具(只读)
|
||
if self.is_enabled("memory_search") {
|
||
registry.register(MemorySearchTool::new(self.memories.clone()));
|
||
}
|
||
|
||
// Skill 工具
|
||
if self.is_enabled("skill_activate") {
|
||
registry.register(SkillActivateTool::new(
|
||
self.skills.clone(),
|
||
self.skill_events.clone(),
|
||
));
|
||
}
|
||
if self.is_enabled("skill_list") {
|
||
registry.register(SkillListTool::new(self.skills.clone()));
|
||
}
|
||
|
||
// 进度通知工具
|
||
if self.is_enabled("session_send") {
|
||
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
|
||
}
|
||
|
||
// 注意:不注册 task 工具,防止递归创建子代理
|
||
|
||
registry
|
||
}
|
||
}
|