PicoBot/src/gateway/tool_registry_factory.rs
oudecheng 881fcace47 feat: 添加 todo_write 工具,支持全量替换和增量合并两种模式
- Tool: 纯内存实现 (Arc<RwLock<HashMap>>),零 DB 依赖,解耦持久化
- 状态机: pending → in_progress → completed/cancelled,单 in_progress 约束
- merge=false: 全量替换模式(默认)
- merge=true: 增量更新模式,只传变更的项,其余保留
- 隔离: scope_key = topic_id.unwrap_or(session_id),topic 和子代理隔离
- 持久化: TodoRepository trait + SessionStore SQLite 实现,在 Session 拦截器层完成
- 前端推送: WsOutbound::TodoList 事件
- Prompt: TodoPromptProvider 中文指令,子代理模板也包含
- 测试: 16 个单元测试,全部通过

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-12 14:19:07 +08:00

238 lines
7.5 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::config::TaskConfig;
use crate::mcp::McpClientManager;
use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
use crate::tools::todo_write::TodoItem;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, MemorySearchTool,
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool,
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
TodoWriteTool, ToolRegistry, WebFetchTool,
};
/// Builds `ToolRegistry` instances from a fixed set of shared dependencies.
///
/// The factory holds the collaborators that the individual tools are
/// constructed with, plus the set of tool names that must not be registered.
/// `build` produces the full tool set; `build_subagent_tools` produces the
/// reduced set handed to sub-agents.
pub(crate) struct ToolRegistryFactory {
/// Skill runtime passed to `SkillActivateTool` and `SkillManageTool`.
skills: Arc<SkillRuntime>,
/// Memory repository passed to `MemorySearchTool` and `MemoryManageTool`.
memories: Arc<dyn MemoryRepository>,
/// Scheduler job repository passed to `SchedulerManageTool`.
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
/// Skill event repository passed to `SkillActivateTool`.
skill_events: Arc<dyn SkillEventRepository>,
/// Message sender passed to `SessionSendTool`.
session_message_sender: Arc<dyn SessionMessageSender>,
/// Agent names passed to `SchedulerManageTool`.
known_agents: HashSet<String>,
/// Timezone string passed to `TimeTool`.
default_timezone: String,
/// Names of tools that are skipped during registration (see `is_enabled`).
disabled_tools: HashSet<String>,
/// Task configuration; its `enabled` flag gates registration of `TaskTool`.
task_config: TaskConfig,
/// Runtime required by `TaskTool`; while `None` the `task` tool is not registered.
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
/// MCP client manager attached via `with_mcp_manager`.
/// NOTE(review): not read by any method visible in this file — confirm where it is consumed.
mcp_manager: Option<Arc<McpClientManager>>,
/// Shared in-memory todo lists handed to `TodoWriteTool`, keyed by a string.
/// While `None` the `todo_write` tool is not registered.
/// NOTE(review): the key is presumably a topic/session scope id — verify against `TodoWriteTool`.
todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,
}
impl ToolRegistryFactory {
    /// Creates a factory from the dependencies shared by the built-in tools.
    ///
    /// The optional collaborators (`subagent_runtime`, `mcp_manager` and
    /// `todo_state`) start out as `None`; attach them with the `with_*`
    /// builder methods before calling [`Self::build`].
    pub(crate) fn new(
        skills: Arc<SkillRuntime>,
        memories: Arc<dyn MemoryRepository>,
        scheduler_jobs: Arc<dyn SchedulerJobRepository>,
        skill_events: Arc<dyn SkillEventRepository>,
        session_message_sender: Arc<dyn SessionMessageSender>,
        known_agents: HashSet<String>,
        default_timezone: String,
        disabled_tools: HashSet<String>,
        task_config: TaskConfig,
    ) -> Self {
        Self {
            skills,
            memories,
            scheduler_jobs,
            skill_events,
            session_message_sender,
            known_agents,
            default_timezone,
            disabled_tools,
            task_config,
            subagent_runtime: None,
            mcp_manager: None,
            todo_state: None,
        }
    }

    /// Attaches the shared todo state; without it `todo_write` is never registered.
    pub(crate) fn with_todo_state(
        mut self,
        state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
    ) -> Self {
        self.todo_state = Some(state);
        self
    }

    /// Attaches the sub-agent runtime; without it `task` is never registered.
    pub(crate) fn with_subagent_runtime(
        mut self,
        runtime: Arc<dyn SubAgentRuntime>,
    ) -> Self {
        self.subagent_runtime = Some(runtime);
        self
    }

    /// Attaches the MCP client manager.
    pub(crate) fn with_mcp_manager(
        mut self,
        manager: Arc<McpClientManager>,
    ) -> Self {
        self.mcp_manager = Some(manager);
        self
    }

    /// Returns `true` unless `tool_name` is listed in `disabled_tools`.
    fn is_enabled(&self, tool_name: &str) -> bool {
        !self.disabled_tools.contains(tool_name)
    }

    /// Builds the full tool registry.
    ///
    /// Every tool is skipped when its name is in `disabled_tools`. In addition,
    /// `todo_write` requires a todo state to have been attached, and `task`
    /// requires both `task_config.enabled` and an attached sub-agent runtime.
    pub(crate) fn build(&self) -> ToolRegistry {
        let mut registry = ToolRegistry::new();

        // Registration order is kept identical to the historical order.
        self.register_basic_tools(&mut registry);
        self.register_memory_search(&mut registry);
        if self.is_enabled("memory_manage") {
            registry.register(MemoryManageTool::new(Arc::clone(&self.memories)));
        }
        self.register_todo_write(&mut registry);
        self.register_session_send(&mut registry);
        if self.is_enabled("scheduler_manage") {
            registry.register(SchedulerManageTool::new(
                Arc::clone(&self.scheduler_jobs),
                self.known_agents.clone(),
            ));
        }
        self.register_skill_activate(&mut registry);
        if self.is_enabled("skill_manage") {
            registry.register(SkillManageTool::new(Arc::clone(&self.skills)));
        }
        self.register_shell_and_network_tools(&mut registry);

        // Register the Task tool only when it is enabled and a sub-agent
        // runtime is available.
        if self.is_enabled("task") && self.task_config.enabled {
            if let Some(runtime) = &self.subagent_runtime {
                registry.register(TaskTool::new(Arc::clone(runtime)));
            }
        }

        registry
    }

    /// Builds the tool set dedicated to sub-agents.
    ///
    /// Compared with [`Self::build`], this set leaves out `memory_manage`,
    /// `scheduler_manage`, `skill_manage` and `task`. The `task` tool is
    /// excluded on purpose so that a sub-agent cannot recursively create
    /// further sub-agents. MCP tools supplied through `mcp_tools` are
    /// registered last, regardless of `disabled_tools`.
    pub(crate) fn build_subagent_tools(
        &self,
        mcp_tools: Option<Vec<crate::mcp::tool_adapter::McpToolWrapper>>,
    ) -> ToolRegistry {
        let mut registry = ToolRegistry::new();

        // Basic tools.
        self.register_basic_tools(&mut registry);
        self.register_shell_and_network_tools(&mut registry);
        // Memory access is read-only for sub-agents: search only, no manage.
        self.register_memory_search(&mut registry);
        // Skill tool.
        self.register_skill_activate(&mut registry);
        // Progress notification tool.
        self.register_session_send(&mut registry);
        // Todo tracking tool.
        self.register_todo_write(&mut registry);

        // Register the MCP tools, if any were provided.
        for tool in mcp_tools.into_iter().flatten() {
            registry.register(tool);
        }

        // Note: the task tool is deliberately not registered here.
        registry
    }

    /// Registers `calculator`, `get_time`, `read`, `write` and `edit`, in that order.
    fn register_basic_tools(&self, registry: &mut ToolRegistry) {
        if self.is_enabled("calculator") {
            registry.register(CalculatorTool::new());
        }
        if self.is_enabled("get_time") {
            registry.register(TimeTool::new(self.default_timezone.clone()));
        }
        if self.is_enabled("read") {
            registry.register(FileReadTool::new());
        }
        if self.is_enabled("write") {
            registry.register(FileWriteTool::new());
        }
        if self.is_enabled("edit") {
            registry.register(FileEditTool::new());
        }
    }

    /// Registers `bash`, `http_request` and `web_fetch`, in that order.
    ///
    /// Keeping the limits in one place guarantees that the main registry and
    /// the sub-agent registry are configured identically.
    fn register_shell_and_network_tools(&self, registry: &mut ToolRegistry) {
        if self.is_enabled("bash") {
            registry.register(BashTool::new());
        }
        if self.is_enabled("http_request") {
            // NOTE(review): the positional arguments are presumably the allowed
            // host patterns, a size cap, a timeout in seconds and a boolean
            // switch — confirm against `HttpRequestTool::new` before changing.
            registry.register(HttpRequestTool::new(
                vec!["*".to_string()],
                1_000_000,
                30,
                false,
            ));
        }
        if self.is_enabled("web_fetch") {
            // NOTE(review): presumably a size cap and a timeout in seconds —
            // confirm against `WebFetchTool::new`.
            registry.register(WebFetchTool::new(50_000, 30));
        }
    }

    /// Registers `memory_search` unless it is disabled.
    fn register_memory_search(&self, registry: &mut ToolRegistry) {
        if self.is_enabled("memory_search") {
            registry.register(MemorySearchTool::new(Arc::clone(&self.memories)));
        }
    }

    /// Registers `skill_activate` unless it is disabled.
    fn register_skill_activate(&self, registry: &mut ToolRegistry) {
        if self.is_enabled("skill_activate") {
            registry.register(SkillActivateTool::new(
                Arc::clone(&self.skills),
                Arc::clone(&self.skill_events),
            ));
        }
    }

    /// Registers `session_send` unless it is disabled.
    fn register_session_send(&self, registry: &mut ToolRegistry) {
        if self.is_enabled("session_send") {
            registry.register(SessionSendTool::new(Arc::clone(
                &self.session_message_sender,
            )));
        }
    }

    /// Registers `todo_write` when it is enabled and a todo state is attached.
    fn register_todo_write(&self, registry: &mut ToolRegistry) {
        if self.is_enabled("todo_write") {
            if let Some(state) = &self.todo_state {
                registry.register(TodoWriteTool::new(Arc::clone(state)));
            }
        }
    }
}