PicoBot/src/gateway/tool_registry_factory.rs
oudecheng 7626ba2d2f feat(gateway): 添加待办事项读取功能
- 引入 TodoReadTool 工具支持读取当前对话的待办事项列表
- 实现从内存或SQLite数据库读取待办事项的功能
- 添加内存回填机制确保数据一致性
- 在ToolRegistryFactory中注册新的待办事项读取工具
- 更新会话初始化逻辑以传递待办事项存储依赖
- 添加完整的单元测试验证各种读取场景
2026-06-15 15:33:43 +08:00

250 lines
8.3 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::config::TaskConfig;
use crate::mcp::McpClientManager;
use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository, TodoRepository};
use crate::tools::todo_write::TodoItem;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, MemorySearchTool,
SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager,
SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
TodoReadTool, TodoWriteTool, ToolRegistry, WebFetchTool,
};
/// Builds [`ToolRegistry`] instances for top-level sessions and for sub-agents.
///
/// The factory owns the shared dependencies the individual tools are constructed
/// from, plus the `disabled_tools` set that decides which tools get registered.
pub(crate) struct ToolRegistryFactory {
    // Skill tooling.
    /// Runtime handed to the skill activation / management tools.
    skills: Arc<SkillRuntime>,
    /// Event store handed to the skill activation tool.
    skill_events: Arc<dyn SkillEventRepository>,

    // Memory tooling.
    /// Store backing the memory search / manage tools.
    memories: Arc<dyn MemoryRepository>,

    // Scheduler tooling.
    /// Job store handed to the scheduler management tool.
    scheduler_jobs: Arc<dyn SchedulerJobRepository>,
    /// Agent names handed to the scheduler management tool.
    known_agents: HashSet<String>,

    // Todo tooling.
    /// Persistent todo store handed to the todo read tool.
    todo_repository: Arc<dyn TodoRepository>,
    /// Shared in-memory todo lists; the todo tools are only registered when this is set.
    /// NOTE(review): keys presumably identify a conversation — confirm against callers.
    todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,

    // Messaging and time.
    /// Sender handed to the session send tool.
    session_message_sender: Arc<dyn SessionMessageSender>,
    /// Timezone string handed to the time tool.
    default_timezone: String,

    // Registration policy.
    /// Tool names that must not be registered.
    disabled_tools: HashSet<String>,
    /// Task settings; `enabled` gates registration of the task tool.
    task_config: TaskConfig,

    // Optional integrations.
    /// Runtime required by the task tool; the task tool is skipped when unset.
    subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
    /// MCP client manager attached via the builder (not read by the registry builders shown here).
    mcp_manager: Option<Arc<McpClientManager>>,
    /// Shell session manager shared by every bash tool this factory creates.
    shell_session_manager: Arc<ShellSessionManager>,
}
impl ToolRegistryFactory {
pub(crate) fn new(
skills: Arc<SkillRuntime>,
memories: Arc<dyn MemoryRepository>,
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
skill_events: Arc<dyn SkillEventRepository>,
todo_repository: Arc<dyn TodoRepository>,
session_message_sender: Arc<dyn SessionMessageSender>,
known_agents: HashSet<String>,
default_timezone: String,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
) -> Self {
Self {
skills,
memories,
scheduler_jobs,
skill_events,
todo_repository,
session_message_sender,
known_agents,
default_timezone,
disabled_tools,
task_config,
subagent_runtime: None,
mcp_manager: None,
todo_state: None,
shell_session_manager: Arc::new(ShellSessionManager::new()),
}
}
pub(crate) fn with_todo_state(
mut self,
state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
) -> Self {
self.todo_state = Some(state);
self
}
pub(crate) fn with_subagent_runtime(
mut self,
runtime: Arc<dyn SubAgentRuntime>,
) -> Self {
self.subagent_runtime = Some(runtime);
self
}
pub(crate) fn with_mcp_manager(
mut self,
manager: Arc<McpClientManager>,
) -> Self {
self.mcp_manager = Some(manager);
self
}
fn is_enabled(&self, tool_name: &str) -> bool {
!self.disabled_tools.contains(tool_name)
}
/// Get a reference to the shell session manager for lifecycle control.
pub(crate) fn shell_session_manager(&self) -> Arc<ShellSessionManager> {
self.shell_session_manager.clone()
}
pub(crate) fn build(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new();
if self.is_enabled("calculator") {
registry.register(CalculatorTool::new());
}
if self.is_enabled("get_time") {
registry.register(TimeTool::new(self.default_timezone.clone()));
}
if self.is_enabled("read") {
registry.register(FileReadTool::new());
}
if self.is_enabled("write") {
registry.register(FileWriteTool::new());
}
if self.is_enabled("edit") {
registry.register(FileEditTool::new());
}
if self.is_enabled("memory_search") {
registry.register(MemorySearchTool::new(self.memories.clone()));
}
if self.is_enabled("memory_manage") {
registry.register(MemoryManageTool::new(self.memories.clone()));
}
if self.is_enabled("todo_write") {
if let Some(ref state) = self.todo_state {
registry.register(TodoWriteTool::new(state.clone()));
registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone()));
}
}
if self.is_enabled("session_send") {
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
}
if self.is_enabled("scheduler_manage") {
registry.register(SchedulerManageTool::new(
self.scheduler_jobs.clone(),
self.known_agents.clone(),
));
}
if self.is_enabled("skill_activate") {
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.skill_events.clone(),
));
}
if self.is_enabled("skill_manage") {
registry.register(SkillManageTool::new(self.skills.clone()));
}
if self.is_enabled("bash") {
registry.register(BashTool::new(self.shell_session_manager.clone()));
}
if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
}
if self.is_enabled("web_fetch") {
registry.register(WebFetchTool::new(50_000, 30));
}
// 注册 Task 工具(如果启用且有 subagent_runtime
if self.is_enabled("task") && self.task_config.enabled {
if let Some(runtime) = &self.subagent_runtime {
registry.register(TaskTool::new(runtime.clone()));
}
}
registry
}
/// 构建子代理专用工具集(不包含 task 工具防止递归)
/// 可选地包含 MCP 工具(通过 mcp_tools 参数传递)
pub(crate) fn build_subagent_tools(
&self,
mcp_tools: Option<Vec<crate::mcp::tool_adapter::McpToolWrapper>>,
) -> ToolRegistry {
let mut registry = ToolRegistry::new();
// 基础工具
if self.is_enabled("calculator") {
registry.register(CalculatorTool::new());
}
if self.is_enabled("get_time") {
registry.register(TimeTool::new(self.default_timezone.clone()));
}
if self.is_enabled("read") {
registry.register(FileReadTool::new());
}
if self.is_enabled("write") {
registry.register(FileWriteTool::new());
}
if self.is_enabled("edit") {
registry.register(FileEditTool::new());
}
if self.is_enabled("bash") {
registry.register(BashTool::new(self.shell_session_manager.clone()));
}
if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
}
if self.is_enabled("web_fetch") {
registry.register(WebFetchTool::new(50_000, 30));
}
// 记忆工具(只读)
if self.is_enabled("memory_search") {
registry.register(MemorySearchTool::new(self.memories.clone()));
}
// Skill 工具
if self.is_enabled("skill_activate") {
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.skill_events.clone(),
));
}
// 进度通知工具
if self.is_enabled("session_send") {
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
}
// Todo 追踪工具
if self.is_enabled("todo_write") {
if let Some(ref state) = self.todo_state {
registry.register(TodoWriteTool::new(state.clone()));
registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone()));
}
}
// 注册 MCP 工具(如果提供)
if let Some(mcp_tools) = mcp_tools {
for tool in mcp_tools {
registry.register(tool);
}
}
// 注意:不注册 task 工具,防止递归创建子代理
registry
}
}