From 7626ba2d2fa65ba7a8370f7076560fd264898268 Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Mon, 15 Jun 2026 15:33:43 +0800 Subject: [PATCH] =?UTF-8?q?feat(gateway):=20=E6=B7=BB=E5=8A=A0=E5=BE=85?= =?UTF-8?q?=E5=8A=9E=E4=BA=8B=E9=A1=B9=E8=AF=BB=E5=8F=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 引入 TodoReadTool 工具支持读取当前对话的待办事项列表 - 实现从内存或SQLite数据库读取待办事项的功能 - 添加内存回填机制确保数据一致性 - 在ToolRegistryFactory中注册新的待办事项读取工具 - 更新会话初始化逻辑以传递待办事项存储依赖 - 添加完整的单元测试验证各种读取场景 --- src/gateway/runtime.rs | 4 +- src/gateway/session.rs | 8 +- src/gateway/todo_prompt_provider.rs | 12 + src/gateway/tool_registry_factory.rs | 9 +- src/tools/mod.rs | 2 + src/tools/todo_read.rs | 339 +++++++++++++++++++++++++++ 6 files changed, 370 insertions(+), 4 deletions(-) create mode 100644 src/tools/todo_read.rs diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index ec5d112..38498cc 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -13,7 +13,7 @@ use crate::mcp::McpInitializer; use crate::skills::SkillRuntime; use crate::storage::{ ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository, - SessionStore, SkillEventRepository, + SessionStore, SkillEventRepository, TodoRepository, }; use crate::tools::{ DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender, @@ -106,6 +106,7 @@ pub(crate) fn build_session_manager_with_sender( let scheduler_jobs: Arc = store.clone(); let skill_events: Arc = store.clone(); let conversations: Arc = store.clone(); + let todo_repository: Arc = store.clone(); // Create ToolRegistryFactory let factory = ToolRegistryFactory::new( @@ -113,6 +114,7 @@ pub(crate) fn build_session_manager_with_sender( memories, scheduler_jobs, skill_events.clone(), + todo_repository, session_message_sender.clone(), known_agents, default_timezone, diff --git a/src/gateway/session.rs b/src/gateway/session.rs index daf7a97..f089479 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -895,6 +895,7 @@ mod tests { store.clone(), store.clone(), store.clone(), + store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), @@ -942,6 +943,7 @@ mod tests { store.clone(), store.clone(), store.clone(), + store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), @@ -1967,6 +1969,7 @@ mod tests { store.clone(), store.clone(), store.clone(), + store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), @@ -2006,6 +2009,7 @@ mod tests { store.clone(), store.clone(), store.clone(), + store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), @@ -2078,6 +2082,7 @@ mod tests { store.clone(), store.clone(), store.clone(), + store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), @@ -2128,7 +2133,8 @@ mod tests { skills, store.clone(), store.clone(), - store, + store.clone(), + store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), diff --git a/src/gateway/todo_prompt_provider.rs b/src/gateway/todo_prompt_provider.rs index fe8857e..2d9ade2 100644 --- a/src/gateway/todo_prompt_provider.rs +++ b/src/gateway/todo_prompt_provider.rs @@ -69,4 +69,16 @@ const TODO_WRITE_INSTRUCTIONS: &str = r#" {"id": "pQ7nWy2z", "content": "补充测试", "status": "in_progress"} ]} ``` + +### 查询当前列表 + +使用 `todo_read` 工具查看当前任务列表,无需任何参数: +```json +{} +``` + +在以下场景应主动调用 `todo_read`: +- 对话开始时,检查是否有未完成的任务 +- 不确定当前任务状态时,先查询再操作 +- 完成一个任务后,查看剩余任务 "#; diff --git a/src/gateway/tool_registry_factory.rs b/src/gateway/tool_registry_factory.rs index 6c1e026..759881d 100644 --- a/src/gateway/tool_registry_factory.rs +++ b/src/gateway/tool_registry_factory.rs @@ -6,14 +6,14 @@ use tokio::sync::RwLock; use crate::config::TaskConfig; use crate::mcp::McpClientManager; use crate::skills::SkillRuntime; -use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; +use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository, TodoRepository}; use crate::tools::todo_write::TodoItem; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, MemoryManageTool, MemorySearchTool, SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager, SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool, - TodoWriteTool, ToolRegistry, WebFetchTool, + TodoReadTool, TodoWriteTool, ToolRegistry, WebFetchTool, }; pub(crate) struct ToolRegistryFactory { @@ -21,6 +21,7 @@ pub(crate) struct ToolRegistryFactory { memories: Arc, scheduler_jobs: Arc, skill_events: Arc, + todo_repository: Arc, session_message_sender: Arc, known_agents: HashSet, default_timezone: String, @@ -38,6 +39,7 @@ impl ToolRegistryFactory { memories: Arc, scheduler_jobs: Arc, skill_events: Arc, + todo_repository: Arc, session_message_sender: Arc, known_agents: HashSet, default_timezone: String, @@ -49,6 +51,7 @@ impl ToolRegistryFactory { memories, scheduler_jobs, skill_events, + todo_repository, session_message_sender, known_agents, default_timezone, @@ -121,6 +124,7 @@ impl ToolRegistryFactory { if self.is_enabled("todo_write") { if let Some(ref state) = self.todo_state { registry.register(TodoWriteTool::new(state.clone())); + registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone())); } } if self.is_enabled("session_send") { @@ -227,6 +231,7 @@ impl ToolRegistryFactory { if self.is_enabled("todo_write") { if let Some(ref state) = self.todo_state { registry.register(TodoWriteTool::new(state.clone())); + registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone())); } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 5a421e6..c3d960a 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -15,6 +15,7 @@ pub mod skill_activate; pub mod skill_manage; pub mod task; pub mod time; +pub mod todo_read; pub mod todo_write; pub mod traits; pub mod web_fetch; @@ -42,6 +43,7 @@ pub use task::{ SubagentCatalog, TaskError, TaskRepository, TaskTool, }; pub use time::TimeTool; +pub use todo_read::TodoReadTool; pub use todo_write::TodoWriteTool; pub use traits::{Tool, ToolContext, ToolResult}; pub use web_fetch::WebFetchTool; diff --git a/src/tools/todo_read.rs b/src/tools/todo_read.rs new file mode 100644 index 0000000..ecbf4ba --- /dev/null +++ b/src/tools/todo_read.rs @@ -0,0 +1,339 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use serde::Serialize; +use serde_json::json; +use tokio::sync::RwLock; + +use crate::storage::TodoRepository; +use crate::tools::traits::{Tool, ToolContext, ToolResult}; +use crate::tools::todo_write::{TodoItem, scope_key_from_context}; + +// ── 输出结构 ────────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize)] +struct TodoReadOutput { + todos: Vec, + count: usize, + scope_key: String, + source: &'static str, +} + +// ── 工具实现 ────────────────────────────────────────────── + +pub struct TodoReadTool { + /// 共享内存状态(与 TodoWriteTool 同一实例) + state: Arc>>>, + /// SQLite 持久化层,用于进程重启后回填内存 + repository: Arc, +} + +impl TodoReadTool { + pub(crate) fn new( + state: Arc>>>, + repository: Arc, + ) -> Self { + Self { state, repository } + } +} + +#[async_trait] +impl Tool for TodoReadTool { + fn name(&self) -> &str { + "todo_read" + } + + fn description(&self) -> &str { + "Read the current todo list for this conversation without modifying it. \ + Returns all tracked tasks with their id, content, and status. \ + No parameters required." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": {}, + "required": [] + }) + } + + fn read_only(&self) -> bool { + true + } + + fn concurrency_safe(&self) -> bool { + true + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(error_result("todo_read requires tool context (session_id)")) + } + + async fn execute_with_context( + &self, + context: &ToolContext, + _args: serde_json::Value, + ) -> anyhow::Result { + // 1. 计算 scope_key + let scope_key = match scope_key_from_context(context) { + Some(key) => key, + None => { + return Ok(error_result( + "todo_read requires session_id or topic_id in tool context", + )) + } + }; + + // 2. 读锁查内存 + { + let guard = self.state.read().await; + if let Some(items) = guard.get(&scope_key) { + if !items.is_empty() { + return Ok(success_result(items, &scope_key, "memory")); + } + } + } + + // 3. 内存为空 → 查 SQLite 并回填 + let records = match self.repository.list_todos(&scope_key) { + Ok(records) => records, + Err(e) => { + tracing::warn!( + error = %e, + %scope_key, + "TodoReadTool: failed to load todos from SQLite" + ); + return Ok(success_result(&[], &scope_key, "memory")); + } + }; + + if records.is_empty() { + return Ok(success_result(&[], &scope_key, "sqlite")); + } + + let items: Vec = records + .into_iter() + .map(|r| TodoItem { + id: r.id, + content: r.content, + status: r.status, + }) + .collect(); + + // 回填内存 + { + let mut guard = self.state.write().await; + guard.insert(scope_key.clone(), items.clone()); + } + + tracing::info!( + scope_key = %scope_key, + todo_count = items.len(), + "TodoReadTool: backfilled memory from SQLite" + ); + + Ok(success_result(&items, &scope_key, "sqlite")) + } +} + +// ── 辅助函数 ────────────────────────────────────────────── + +fn success_result( + items: &[TodoItem], + scope_key: &str, + source: &'static str, +) -> ToolResult { + let output = TodoReadOutput { + todos: items.to_vec(), + count: items.len(), + scope_key: scope_key.to_string(), + source, + }; + ToolResult { + success: true, + output: serde_json::to_string_pretty(&output).unwrap_or_default(), + error: None, + } +} + +fn error_result(message: &str) -> ToolResult { + ToolResult { + success: false, + output: String::new(), + error: Some(message.to_string()), + } +} + +// ── 测试 ────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::tools::traits::ToolContext; + + fn test_context() -> ToolContext { + ToolContext { + channel_name: Some("cli".to_string()), + sender_id: Some("user-1".to_string()), + chat_id: Some("chat-1".to_string()), + session_id: Some("cli:chat-1".to_string()), + topic_id: None, + message_id: Some("msg-1".to_string()), + message_seq: Some(1), + subagent_description: None, + } + } + + fn test_state() -> Arc>>> { + Arc::new(RwLock::new(HashMap::new())) + } + + struct MockTodoRepository { + records: Vec, + } + + impl TodoRepository for MockTodoRepository { + fn replace_todos( + &self, + _scope_key: &str, + _items: &[crate::storage::TodoRecord], + ) -> Result, crate::storage::StorageError> { + Ok(vec![]) + } + + fn list_todos( + &self, + scope_key: &str, + ) -> Result, crate::storage::StorageError> { + Ok(self.records.iter().filter(|r| r.scope_key == scope_key).cloned().collect()) + } + } + + fn mock_record(scope_key: &str, id: &str, content: &str, status: &str) -> crate::storage::TodoRecord { + crate::storage::TodoRecord { + id: id.to_string(), + scope_key: scope_key.to_string(), + session_id: "cli:chat-1".to_string(), + topic_id: None, + content: content.to_string(), + status: status.to_string(), + priority: "medium".to_string(), + created_at: 1000, + updated_at: 1000, + } + } + + #[tokio::test] + async fn test_read_from_memory() { + let state = test_state(); + { + let mut guard = state.write().await; + guard.insert( + "cli:chat-1".to_string(), + vec![TodoItem { + id: "a1".to_string(), + content: "任务A".to_string(), + status: "pending".to_string(), + }], + ); + } + + let repo = Arc::new(MockTodoRepository { records: vec![] }); + let tool = TodoReadTool::new(state, repo); + + let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap(); + assert!(result.success); + + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["count"], 1); + assert_eq!(output["source"], "memory"); + assert_eq!(output["todos"][0]["id"], "a1"); + } + + #[tokio::test] + async fn test_read_from_sqlite_backfill() { + let state = test_state(); + let repo = Arc::new(MockTodoRepository { + records: vec![ + mock_record("cli:chat-1", "b1", "任务B", "in_progress"), + mock_record("cli:chat-1", "b2", "任务C", "pending"), + ], + }); + let tool = TodoReadTool::new(state.clone(), repo); + + let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap(); + assert!(result.success); + + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["count"], 2); + assert_eq!(output["source"], "sqlite"); + + // 验证内存已被回填 + let guard = state.read().await; + let items = guard.get("cli:chat-1").unwrap(); + assert_eq!(items.len(), 2); + } + + #[tokio::test] + async fn test_read_empty_list() { + let state = test_state(); + let repo = Arc::new(MockTodoRepository { records: vec![] }); + let tool = TodoReadTool::new(state, repo); + + let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap(); + assert!(result.success); + + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["count"], 0); + } + + #[tokio::test] + async fn test_read_no_context() { + let state = test_state(); + let repo = Arc::new(MockTodoRepository { records: vec![] }); + let tool = TodoReadTool::new(state, repo); + + let result = tool.execute(json!({})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("requires tool context")); + } + + #[tokio::test] + async fn test_read_topic_isolation() { + let state = test_state(); + { + let mut guard = state.write().await; + guard.insert( + "cli:chat-1".to_string(), + vec![TodoItem { + id: "m1".to_string(), + content: "主会话任务".to_string(), + status: "pending".to_string(), + }], + ); + } + + let repo = Arc::new(MockTodoRepository { + records: vec![ + mock_record("topic-xyz", "t1", "话题任务", "completed"), + ], + }); + let tool = TodoReadTool::new(state, repo); + + let topic_ctx = ToolContext { + session_id: Some("cli:chat-1".to_string()), + topic_id: Some("topic-xyz".to_string()), + ..ToolContext::default() + }; + + let result = tool.execute_with_context(&topic_ctx, json!({})).await.unwrap(); + assert!(result.success); + + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["scope_key"], "topic-xyz"); + assert_eq!(output["count"], 1); + assert_eq!(output["todos"][0]["content"], "话题任务"); + assert_eq!(output["source"], "sqlite"); + } +}