feat(gateway): 添加待办事项读取功能

- 引入 TodoReadTool 工具支持读取当前对话的待办事项列表
- 实现从内存或SQLite数据库读取待办事项的功能
- 添加内存回填机制确保数据一致性
- 在ToolRegistryFactory中注册新的待办事项读取工具
- 更新会话初始化逻辑以传递待办事项存储依赖
- 添加完整的单元测试验证各种读取场景
This commit is contained in:
oudecheng 2026-06-15 15:33:43 +08:00
parent 09dd15f557
commit 7626ba2d2f
6 changed files with 370 additions and 4 deletions

View File

@ -13,7 +13,7 @@ use crate::mcp::McpInitializer;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{ use crate::storage::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository, ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
SessionStore, SkillEventRepository, SessionStore, SkillEventRepository, TodoRepository,
}; };
use crate::tools::{ use crate::tools::{
DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender, DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender,
@ -106,6 +106,7 @@ pub(crate) fn build_session_manager_with_sender(
let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone(); let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone();
let skill_events: Arc<dyn SkillEventRepository> = store.clone(); let skill_events: Arc<dyn SkillEventRepository> = store.clone();
let conversations: Arc<dyn ConversationRepository> = store.clone(); let conversations: Arc<dyn ConversationRepository> = store.clone();
let todo_repository: Arc<dyn TodoRepository> = store.clone();
// Create ToolRegistryFactory // Create ToolRegistryFactory
let factory = ToolRegistryFactory::new( let factory = ToolRegistryFactory::new(
@ -113,6 +114,7 @@ pub(crate) fn build_session_manager_with_sender(
memories, memories,
scheduler_jobs, scheduler_jobs,
skill_events.clone(), skill_events.clone(),
todo_repository,
session_message_sender.clone(), session_message_sender.clone(),
known_agents, known_agents,
default_timezone, default_timezone,

View File

@ -895,6 +895,7 @@ mod tests {
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(),
Arc::new(NoopSessionMessageSender), Arc::new(NoopSessionMessageSender),
HashSet::new(), HashSet::new(),
"Asia/Shanghai".to_string(), "Asia/Shanghai".to_string(),
@ -942,6 +943,7 @@ mod tests {
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(),
Arc::new(NoopSessionMessageSender), Arc::new(NoopSessionMessageSender),
HashSet::new(), HashSet::new(),
"Asia/Shanghai".to_string(), "Asia/Shanghai".to_string(),
@ -1967,6 +1969,7 @@ mod tests {
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(),
Arc::new(NoopSessionMessageSender), Arc::new(NoopSessionMessageSender),
HashSet::new(), HashSet::new(),
"Asia/Shanghai".to_string(), "Asia/Shanghai".to_string(),
@ -2006,6 +2009,7 @@ mod tests {
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(),
Arc::new(NoopSessionMessageSender), Arc::new(NoopSessionMessageSender),
HashSet::new(), HashSet::new(),
"Asia/Shanghai".to_string(), "Asia/Shanghai".to_string(),
@ -2078,6 +2082,7 @@ mod tests {
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(), store.clone(),
store.clone(),
Arc::new(NoopSessionMessageSender), Arc::new(NoopSessionMessageSender),
HashSet::new(), HashSet::new(),
"Asia/Shanghai".to_string(), "Asia/Shanghai".to_string(),
@ -2128,7 +2133,8 @@ mod tests {
skills, skills,
store.clone(), store.clone(),
store.clone(), store.clone(),
store, store.clone(),
store.clone(),
Arc::new(NoopSessionMessageSender), Arc::new(NoopSessionMessageSender),
HashSet::new(), HashSet::new(),
"Asia/Shanghai".to_string(), "Asia/Shanghai".to_string(),

View File

@ -69,4 +69,16 @@ const TODO_WRITE_INSTRUCTIONS: &str = r#"
{"id": "pQ7nWy2z", "content": "补充测试", "status": "in_progress"} {"id": "pQ7nWy2z", "content": "补充测试", "status": "in_progress"}
]} ]}
``` ```
###
使 `todo_read`
```json
{}
```
`todo_read`
-
-
-
"#; "#;

View File

@ -6,14 +6,14 @@ use tokio::sync::RwLock;
use crate::config::TaskConfig; use crate::config::TaskConfig;
use crate::mcp::McpClientManager; use crate::mcp::McpClientManager;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository, TodoRepository};
use crate::tools::todo_write::TodoItem; use crate::tools::todo_write::TodoItem;
use crate::tools::{ use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, MemorySearchTool, HttpRequestTool, MemoryManageTool, MemorySearchTool,
SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager, SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager,
SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool, SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
TodoWriteTool, ToolRegistry, WebFetchTool, TodoReadTool, TodoWriteTool, ToolRegistry, WebFetchTool,
}; };
pub(crate) struct ToolRegistryFactory { pub(crate) struct ToolRegistryFactory {
@ -21,6 +21,7 @@ pub(crate) struct ToolRegistryFactory {
memories: Arc<dyn MemoryRepository>, memories: Arc<dyn MemoryRepository>,
scheduler_jobs: Arc<dyn SchedulerJobRepository>, scheduler_jobs: Arc<dyn SchedulerJobRepository>,
skill_events: Arc<dyn SkillEventRepository>, skill_events: Arc<dyn SkillEventRepository>,
todo_repository: Arc<dyn TodoRepository>,
session_message_sender: Arc<dyn SessionMessageSender>, session_message_sender: Arc<dyn SessionMessageSender>,
known_agents: HashSet<String>, known_agents: HashSet<String>,
default_timezone: String, default_timezone: String,
@ -38,6 +39,7 @@ impl ToolRegistryFactory {
memories: Arc<dyn MemoryRepository>, memories: Arc<dyn MemoryRepository>,
scheduler_jobs: Arc<dyn SchedulerJobRepository>, scheduler_jobs: Arc<dyn SchedulerJobRepository>,
skill_events: Arc<dyn SkillEventRepository>, skill_events: Arc<dyn SkillEventRepository>,
todo_repository: Arc<dyn TodoRepository>,
session_message_sender: Arc<dyn SessionMessageSender>, session_message_sender: Arc<dyn SessionMessageSender>,
known_agents: HashSet<String>, known_agents: HashSet<String>,
default_timezone: String, default_timezone: String,
@ -49,6 +51,7 @@ impl ToolRegistryFactory {
memories, memories,
scheduler_jobs, scheduler_jobs,
skill_events, skill_events,
todo_repository,
session_message_sender, session_message_sender,
known_agents, known_agents,
default_timezone, default_timezone,
@ -121,6 +124,7 @@ impl ToolRegistryFactory {
if self.is_enabled("todo_write") { if self.is_enabled("todo_write") {
if let Some(ref state) = self.todo_state { if let Some(ref state) = self.todo_state {
registry.register(TodoWriteTool::new(state.clone())); registry.register(TodoWriteTool::new(state.clone()));
registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone()));
} }
} }
if self.is_enabled("session_send") { if self.is_enabled("session_send") {
@ -227,6 +231,7 @@ impl ToolRegistryFactory {
if self.is_enabled("todo_write") { if self.is_enabled("todo_write") {
if let Some(ref state) = self.todo_state { if let Some(ref state) = self.todo_state {
registry.register(TodoWriteTool::new(state.clone())); registry.register(TodoWriteTool::new(state.clone()));
registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone()));
} }
} }

View File

@ -15,6 +15,7 @@ pub mod skill_activate;
pub mod skill_manage; pub mod skill_manage;
pub mod task; pub mod task;
pub mod time; pub mod time;
pub mod todo_read;
pub mod todo_write; pub mod todo_write;
pub mod traits; pub mod traits;
pub mod web_fetch; pub mod web_fetch;
@ -42,6 +43,7 @@ pub use task::{
SubagentCatalog, TaskError, TaskRepository, TaskTool, SubagentCatalog, TaskError, TaskRepository, TaskTool,
}; };
pub use time::TimeTool; pub use time::TimeTool;
pub use todo_read::TodoReadTool;
pub use todo_write::TodoWriteTool; pub use todo_write::TodoWriteTool;
pub use traits::{Tool, ToolContext, ToolResult}; pub use traits::{Tool, ToolContext, ToolResult};
pub use web_fetch::WebFetchTool; pub use web_fetch::WebFetchTool;

339
src/tools/todo_read.rs Normal file
View File

@ -0,0 +1,339 @@
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::Serialize;
use serde_json::json;
use tokio::sync::RwLock;
use crate::storage::TodoRepository;
use crate::tools::traits::{Tool, ToolContext, ToolResult};
use crate::tools::todo_write::{TodoItem, scope_key_from_context};
// ── 输出结构 ──────────────────────────────────────────────
#[derive(Debug, Clone, Serialize)]
struct TodoReadOutput {
todos: Vec<TodoItem>,
count: usize,
scope_key: String,
source: &'static str,
}
// ── 工具实现 ──────────────────────────────────────────────
pub struct TodoReadTool {
/// 共享内存状态(与 TodoWriteTool 同一实例)
state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
/// SQLite 持久化层,用于进程重启后回填内存
repository: Arc<dyn TodoRepository>,
}
impl TodoReadTool {
pub(crate) fn new(
state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
repository: Arc<dyn TodoRepository>,
) -> Self {
Self { state, repository }
}
}
#[async_trait]
impl Tool for TodoReadTool {
fn name(&self) -> &str {
"todo_read"
}
fn description(&self) -> &str {
"Read the current todo list for this conversation without modifying it. \
Returns all tracked tasks with their id, content, and status. \
No parameters required."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {},
"required": []
})
}
fn read_only(&self) -> bool {
true
}
fn concurrency_safe(&self) -> bool {
true
}
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
Ok(error_result("todo_read requires tool context (session_id)"))
}
async fn execute_with_context(
&self,
context: &ToolContext,
_args: serde_json::Value,
) -> anyhow::Result<ToolResult> {
// 1. 计算 scope_key
let scope_key = match scope_key_from_context(context) {
Some(key) => key,
None => {
return Ok(error_result(
"todo_read requires session_id or topic_id in tool context",
))
}
};
// 2. 读锁查内存
{
let guard = self.state.read().await;
if let Some(items) = guard.get(&scope_key) {
if !items.is_empty() {
return Ok(success_result(items, &scope_key, "memory"));
}
}
}
// 3. 内存为空 → 查 SQLite 并回填
let records = match self.repository.list_todos(&scope_key) {
Ok(records) => records,
Err(e) => {
tracing::warn!(
error = %e,
%scope_key,
"TodoReadTool: failed to load todos from SQLite"
);
return Ok(success_result(&[], &scope_key, "memory"));
}
};
if records.is_empty() {
return Ok(success_result(&[], &scope_key, "sqlite"));
}
let items: Vec<TodoItem> = records
.into_iter()
.map(|r| TodoItem {
id: r.id,
content: r.content,
status: r.status,
})
.collect();
// 回填内存
{
let mut guard = self.state.write().await;
guard.insert(scope_key.clone(), items.clone());
}
tracing::info!(
scope_key = %scope_key,
todo_count = items.len(),
"TodoReadTool: backfilled memory from SQLite"
);
Ok(success_result(&items, &scope_key, "sqlite"))
}
}
// ── 辅助函数 ──────────────────────────────────────────────
fn success_result(
items: &[TodoItem],
scope_key: &str,
source: &'static str,
) -> ToolResult {
let output = TodoReadOutput {
todos: items.to_vec(),
count: items.len(),
scope_key: scope_key.to_string(),
source,
};
ToolResult {
success: true,
output: serde_json::to_string_pretty(&output).unwrap_or_default(),
error: None,
}
}
fn error_result(message: &str) -> ToolResult {
ToolResult {
success: false,
output: String::new(),
error: Some(message.to_string()),
}
}
// ── 测试 ──────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use crate::tools::traits::ToolContext;
fn test_context() -> ToolContext {
ToolContext {
channel_name: Some("cli".to_string()),
sender_id: Some("user-1".to_string()),
chat_id: Some("chat-1".to_string()),
session_id: Some("cli:chat-1".to_string()),
topic_id: None,
message_id: Some("msg-1".to_string()),
message_seq: Some(1),
subagent_description: None,
}
}
fn test_state() -> Arc<RwLock<HashMap<String, Vec<TodoItem>>>> {
Arc::new(RwLock::new(HashMap::new()))
}
struct MockTodoRepository {
records: Vec<crate::storage::TodoRecord>,
}
impl TodoRepository for MockTodoRepository {
fn replace_todos(
&self,
_scope_key: &str,
_items: &[crate::storage::TodoRecord],
) -> Result<Vec<crate::storage::TodoRecord>, crate::storage::StorageError> {
Ok(vec![])
}
fn list_todos(
&self,
scope_key: &str,
) -> Result<Vec<crate::storage::TodoRecord>, crate::storage::StorageError> {
Ok(self.records.iter().filter(|r| r.scope_key == scope_key).cloned().collect())
}
}
fn mock_record(scope_key: &str, id: &str, content: &str, status: &str) -> crate::storage::TodoRecord {
crate::storage::TodoRecord {
id: id.to_string(),
scope_key: scope_key.to_string(),
session_id: "cli:chat-1".to_string(),
topic_id: None,
content: content.to_string(),
status: status.to_string(),
priority: "medium".to_string(),
created_at: 1000,
updated_at: 1000,
}
}
#[tokio::test]
async fn test_read_from_memory() {
let state = test_state();
{
let mut guard = state.write().await;
guard.insert(
"cli:chat-1".to_string(),
vec![TodoItem {
id: "a1".to_string(),
content: "任务A".to_string(),
status: "pending".to_string(),
}],
);
}
let repo = Arc::new(MockTodoRepository { records: vec![] });
let tool = TodoReadTool::new(state, repo);
let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["count"], 1);
assert_eq!(output["source"], "memory");
assert_eq!(output["todos"][0]["id"], "a1");
}
#[tokio::test]
async fn test_read_from_sqlite_backfill() {
let state = test_state();
let repo = Arc::new(MockTodoRepository {
records: vec![
mock_record("cli:chat-1", "b1", "任务B", "in_progress"),
mock_record("cli:chat-1", "b2", "任务C", "pending"),
],
});
let tool = TodoReadTool::new(state.clone(), repo);
let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["count"], 2);
assert_eq!(output["source"], "sqlite");
// 验证内存已被回填
let guard = state.read().await;
let items = guard.get("cli:chat-1").unwrap();
assert_eq!(items.len(), 2);
}
#[tokio::test]
async fn test_read_empty_list() {
let state = test_state();
let repo = Arc::new(MockTodoRepository { records: vec![] });
let tool = TodoReadTool::new(state, repo);
let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["count"], 0);
}
#[tokio::test]
async fn test_read_no_context() {
let state = test_state();
let repo = Arc::new(MockTodoRepository { records: vec![] });
let tool = TodoReadTool::new(state, repo);
let result = tool.execute(json!({})).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("requires tool context"));
}
#[tokio::test]
async fn test_read_topic_isolation() {
let state = test_state();
{
let mut guard = state.write().await;
guard.insert(
"cli:chat-1".to_string(),
vec![TodoItem {
id: "m1".to_string(),
content: "主会话任务".to_string(),
status: "pending".to_string(),
}],
);
}
let repo = Arc::new(MockTodoRepository {
records: vec![
mock_record("topic-xyz", "t1", "话题任务", "completed"),
],
});
let tool = TodoReadTool::new(state, repo);
let topic_ctx = ToolContext {
session_id: Some("cli:chat-1".to_string()),
topic_id: Some("topic-xyz".to_string()),
..ToolContext::default()
};
let result = tool.execute_with_context(&topic_ctx, json!({})).await.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["scope_key"], "topic-xyz");
assert_eq!(output["count"], 1);
assert_eq!(output["todos"][0]["content"], "话题任务");
assert_eq!(output["source"], "sqlite");
}
}