- Tool: 纯内存实现 (Arc<RwLock<HashMap>>),零 DB 依赖,解耦持久化 - 状态机: pending → in_progress → completed/cancelled,单 in_progress 约束 - merge=false: 全量替换模式(默认) - merge=true: 增量更新模式,只传变更的项,其余保留 - 隔离: scope_key = topic_id.unwrap_or(session_id),topic 和子代理隔离 - 持久化: TodoRepository trait + SessionStore SQLite 实现,在 Session 拦截器层完成 - 前端推送: WsOutbound::TodoList 事件 - Prompt: TodoPromptProvider 中文指令,子代理模板也包含 - 测试: 16 个单元测试,全部通过 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
238 lines
7.5 KiB
Rust
238 lines
7.5 KiB
Rust
use std::collections::{HashMap, HashSet};
|
||
use std::sync::Arc;
|
||
|
||
use tokio::sync::RwLock;
|
||
|
||
use crate::config::TaskConfig;
|
||
use crate::mcp::McpClientManager;
|
||
use crate::skills::SkillRuntime;
|
||
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||
use crate::tools::todo_write::TodoItem;
|
||
use crate::tools::{
|
||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||
HttpRequestTool, MemoryManageTool, MemorySearchTool,
|
||
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool,
|
||
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
||
TodoWriteTool, ToolRegistry, WebFetchTool,
|
||
};
|
||
|
||
pub(crate) struct ToolRegistryFactory {
|
||
skills: Arc<SkillRuntime>,
|
||
memories: Arc<dyn MemoryRepository>,
|
||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||
skill_events: Arc<dyn SkillEventRepository>,
|
||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||
known_agents: HashSet<String>,
|
||
default_timezone: String,
|
||
disabled_tools: HashSet<String>,
|
||
task_config: TaskConfig,
|
||
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
||
mcp_manager: Option<Arc<McpClientManager>>,
|
||
todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,
|
||
}
|
||
|
||
impl ToolRegistryFactory {
|
||
pub(crate) fn new(
|
||
skills: Arc<SkillRuntime>,
|
||
memories: Arc<dyn MemoryRepository>,
|
||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||
skill_events: Arc<dyn SkillEventRepository>,
|
||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||
known_agents: HashSet<String>,
|
||
default_timezone: String,
|
||
disabled_tools: HashSet<String>,
|
||
task_config: TaskConfig,
|
||
) -> Self {
|
||
Self {
|
||
skills,
|
||
memories,
|
||
scheduler_jobs,
|
||
skill_events,
|
||
session_message_sender,
|
||
known_agents,
|
||
default_timezone,
|
||
disabled_tools,
|
||
task_config,
|
||
subagent_runtime: None,
|
||
mcp_manager: None,
|
||
todo_state: None,
|
||
}
|
||
}
|
||
|
||
pub(crate) fn with_todo_state(
|
||
mut self,
|
||
state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
|
||
) -> Self {
|
||
self.todo_state = Some(state);
|
||
self
|
||
}
|
||
|
||
pub(crate) fn with_subagent_runtime(
|
||
mut self,
|
||
runtime: Arc<dyn SubAgentRuntime>,
|
||
) -> Self {
|
||
self.subagent_runtime = Some(runtime);
|
||
self
|
||
}
|
||
|
||
pub(crate) fn with_mcp_manager(
|
||
mut self,
|
||
manager: Arc<McpClientManager>,
|
||
) -> Self {
|
||
self.mcp_manager = Some(manager);
|
||
self
|
||
}
|
||
|
||
fn is_enabled(&self, tool_name: &str) -> bool {
|
||
!self.disabled_tools.contains(tool_name)
|
||
}
|
||
|
||
pub(crate) fn build(&self) -> ToolRegistry {
|
||
let mut registry = ToolRegistry::new();
|
||
|
||
if self.is_enabled("calculator") {
|
||
registry.register(CalculatorTool::new());
|
||
}
|
||
if self.is_enabled("get_time") {
|
||
registry.register(TimeTool::new(self.default_timezone.clone()));
|
||
}
|
||
if self.is_enabled("read") {
|
||
registry.register(FileReadTool::new());
|
||
}
|
||
if self.is_enabled("write") {
|
||
registry.register(FileWriteTool::new());
|
||
}
|
||
if self.is_enabled("edit") {
|
||
registry.register(FileEditTool::new());
|
||
}
|
||
if self.is_enabled("memory_search") {
|
||
registry.register(MemorySearchTool::new(self.memories.clone()));
|
||
}
|
||
if self.is_enabled("memory_manage") {
|
||
registry.register(MemoryManageTool::new(self.memories.clone()));
|
||
}
|
||
if self.is_enabled("todo_write") {
|
||
if let Some(ref state) = self.todo_state {
|
||
registry.register(TodoWriteTool::new(state.clone()));
|
||
}
|
||
}
|
||
if self.is_enabled("session_send") {
|
||
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
|
||
}
|
||
if self.is_enabled("scheduler_manage") {
|
||
registry.register(SchedulerManageTool::new(
|
||
self.scheduler_jobs.clone(),
|
||
self.known_agents.clone(),
|
||
));
|
||
}
|
||
if self.is_enabled("skill_activate") {
|
||
registry.register(SkillActivateTool::new(
|
||
self.skills.clone(),
|
||
self.skill_events.clone(),
|
||
));
|
||
}
|
||
if self.is_enabled("skill_manage") {
|
||
registry.register(SkillManageTool::new(self.skills.clone()));
|
||
}
|
||
if self.is_enabled("bash") {
|
||
registry.register(BashTool::new());
|
||
}
|
||
if self.is_enabled("http_request") {
|
||
registry.register(HttpRequestTool::new(
|
||
vec!["*".to_string()],
|
||
1_000_000,
|
||
30,
|
||
false,
|
||
));
|
||
}
|
||
if self.is_enabled("web_fetch") {
|
||
registry.register(WebFetchTool::new(50_000, 30));
|
||
}
|
||
|
||
// 注册 Task 工具(如果启用且有 subagent_runtime)
|
||
if self.is_enabled("task") && self.task_config.enabled {
|
||
if let Some(runtime) = &self.subagent_runtime {
|
||
registry.register(TaskTool::new(runtime.clone()));
|
||
}
|
||
}
|
||
|
||
registry
|
||
}
|
||
|
||
/// 构建子代理专用工具集(不包含 task 工具防止递归)
|
||
/// 可选地包含 MCP 工具(通过 mcp_tools 参数传递)
|
||
pub(crate) fn build_subagent_tools(
|
||
&self,
|
||
mcp_tools: Option<Vec<crate::mcp::tool_adapter::McpToolWrapper>>,
|
||
) -> ToolRegistry {
|
||
let mut registry = ToolRegistry::new();
|
||
|
||
// 基础工具
|
||
if self.is_enabled("calculator") {
|
||
registry.register(CalculatorTool::new());
|
||
}
|
||
if self.is_enabled("get_time") {
|
||
registry.register(TimeTool::new(self.default_timezone.clone()));
|
||
}
|
||
if self.is_enabled("read") {
|
||
registry.register(FileReadTool::new());
|
||
}
|
||
if self.is_enabled("write") {
|
||
registry.register(FileWriteTool::new());
|
||
}
|
||
if self.is_enabled("edit") {
|
||
registry.register(FileEditTool::new());
|
||
}
|
||
if self.is_enabled("bash") {
|
||
registry.register(BashTool::new());
|
||
}
|
||
if self.is_enabled("http_request") {
|
||
registry.register(HttpRequestTool::new(
|
||
vec!["*".to_string()],
|
||
1_000_000,
|
||
30,
|
||
false,
|
||
));
|
||
}
|
||
if self.is_enabled("web_fetch") {
|
||
registry.register(WebFetchTool::new(50_000, 30));
|
||
}
|
||
|
||
// 记忆工具(只读)
|
||
if self.is_enabled("memory_search") {
|
||
registry.register(MemorySearchTool::new(self.memories.clone()));
|
||
}
|
||
|
||
// Skill 工具
|
||
if self.is_enabled("skill_activate") {
|
||
registry.register(SkillActivateTool::new(
|
||
self.skills.clone(),
|
||
self.skill_events.clone(),
|
||
));
|
||
}
|
||
|
||
// 进度通知工具
|
||
if self.is_enabled("session_send") {
|
||
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
|
||
}
|
||
|
||
// Todo 追踪工具
|
||
if self.is_enabled("todo_write") {
|
||
if let Some(ref state) = self.todo_state {
|
||
registry.register(TodoWriteTool::new(state.clone()));
|
||
}
|
||
}
|
||
|
||
// 注册 MCP 工具(如果提供)
|
||
if let Some(mcp_tools) = mcp_tools {
|
||
for tool in mcp_tools {
|
||
registry.register(tool);
|
||
}
|
||
}
|
||
|
||
// 注意:不注册 task 工具,防止递归创建子代理
|
||
|
||
registry
|
||
}
|
||
}
|