feat(gateway): 添加待办事项读取功能
- 引入 TodoReadTool 工具支持读取当前对话的待办事项列表 - 实现从内存或SQLite数据库读取待办事项的功能 - 添加内存回填机制确保数据一致性 - 在ToolRegistryFactory中注册新的待办事项读取工具 - 更新会话初始化逻辑以传递待办事项存储依赖 - 添加完整的单元测试验证各种读取场景
This commit is contained in:
parent
09dd15f557
commit
7626ba2d2f
@ -13,7 +13,7 @@ use crate::mcp::McpInitializer;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{
|
||||
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
||||
SessionStore, SkillEventRepository,
|
||||
SessionStore, SkillEventRepository, TodoRepository,
|
||||
};
|
||||
use crate::tools::{
|
||||
DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender,
|
||||
@ -106,6 +106,7 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone();
|
||||
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
||||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||||
let todo_repository: Arc<dyn TodoRepository> = store.clone();
|
||||
|
||||
// Create ToolRegistryFactory
|
||||
let factory = ToolRegistryFactory::new(
|
||||
@ -113,6 +114,7 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
memories,
|
||||
scheduler_jobs,
|
||||
skill_events.clone(),
|
||||
todo_repository,
|
||||
session_message_sender.clone(),
|
||||
known_agents,
|
||||
default_timezone,
|
||||
|
||||
@ -895,6 +895,7 @@ mod tests {
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
@ -942,6 +943,7 @@ mod tests {
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
@ -1967,6 +1969,7 @@ mod tests {
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
@ -2006,6 +2009,7 @@ mod tests {
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
@ -2078,6 +2082,7 @@ mod tests {
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
@ -2128,7 +2133,8 @@ mod tests {
|
||||
skills,
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store,
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
|
||||
@ -69,4 +69,16 @@ const TODO_WRITE_INSTRUCTIONS: &str = r#"
|
||||
{"id": "pQ7nWy2z", "content": "补充测试", "status": "in_progress"}
|
||||
]}
|
||||
```
|
||||
|
||||
### 查询当前列表
|
||||
|
||||
使用 `todo_read` 工具查看当前任务列表,无需任何参数:
|
||||
```json
|
||||
{}
|
||||
```
|
||||
|
||||
在以下场景应主动调用 `todo_read`:
|
||||
- 对话开始时,检查是否有未完成的任务
|
||||
- 不确定当前任务状态时,先查询再操作
|
||||
- 完成一个任务后,查看剩余任务
|
||||
"#;
|
||||
|
||||
@ -6,14 +6,14 @@ use tokio::sync::RwLock;
|
||||
use crate::config::TaskConfig;
|
||||
use crate::mcp::McpClientManager;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||||
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository, TodoRepository};
|
||||
use crate::tools::todo_write::TodoItem;
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||
HttpRequestTool, MemoryManageTool, MemorySearchTool,
|
||||
SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager,
|
||||
SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
||||
TodoWriteTool, ToolRegistry, WebFetchTool,
|
||||
TodoReadTool, TodoWriteTool, ToolRegistry, WebFetchTool,
|
||||
};
|
||||
|
||||
pub(crate) struct ToolRegistryFactory {
|
||||
@ -21,6 +21,7 @@ pub(crate) struct ToolRegistryFactory {
|
||||
memories: Arc<dyn MemoryRepository>,
|
||||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
todo_repository: Arc<dyn TodoRepository>,
|
||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||
known_agents: HashSet<String>,
|
||||
default_timezone: String,
|
||||
@ -38,6 +39,7 @@ impl ToolRegistryFactory {
|
||||
memories: Arc<dyn MemoryRepository>,
|
||||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
todo_repository: Arc<dyn TodoRepository>,
|
||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||
known_agents: HashSet<String>,
|
||||
default_timezone: String,
|
||||
@ -49,6 +51,7 @@ impl ToolRegistryFactory {
|
||||
memories,
|
||||
scheduler_jobs,
|
||||
skill_events,
|
||||
todo_repository,
|
||||
session_message_sender,
|
||||
known_agents,
|
||||
default_timezone,
|
||||
@ -121,6 +124,7 @@ impl ToolRegistryFactory {
|
||||
if self.is_enabled("todo_write") {
|
||||
if let Some(ref state) = self.todo_state {
|
||||
registry.register(TodoWriteTool::new(state.clone()));
|
||||
registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone()));
|
||||
}
|
||||
}
|
||||
if self.is_enabled("session_send") {
|
||||
@ -227,6 +231,7 @@ impl ToolRegistryFactory {
|
||||
if self.is_enabled("todo_write") {
|
||||
if let Some(ref state) = self.todo_state {
|
||||
registry.register(TodoWriteTool::new(state.clone()));
|
||||
registry.register(TodoReadTool::new(state.clone(), self.todo_repository.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ pub mod skill_activate;
|
||||
pub mod skill_manage;
|
||||
pub mod task;
|
||||
pub mod time;
|
||||
pub mod todo_read;
|
||||
pub mod todo_write;
|
||||
pub mod traits;
|
||||
pub mod web_fetch;
|
||||
@ -42,6 +43,7 @@ pub use task::{
|
||||
SubagentCatalog, TaskError, TaskRepository, TaskTool,
|
||||
};
|
||||
pub use time::TimeTool;
|
||||
pub use todo_read::TodoReadTool;
|
||||
pub use todo_write::TodoWriteTool;
|
||||
pub use traits::{Tool, ToolContext, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
339
src/tools/todo_read.rs
Normal file
339
src/tools/todo_read.rs
Normal file
@ -0,0 +1,339 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::storage::TodoRepository;
|
||||
use crate::tools::traits::{Tool, ToolContext, ToolResult};
|
||||
use crate::tools::todo_write::{TodoItem, scope_key_from_context};
|
||||
|
||||
// ── 输出结构 ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct TodoReadOutput {
|
||||
todos: Vec<TodoItem>,
|
||||
count: usize,
|
||||
scope_key: String,
|
||||
source: &'static str,
|
||||
}
|
||||
|
||||
// ── 工具实现 ──────────────────────────────────────────────
|
||||
|
||||
pub struct TodoReadTool {
|
||||
/// 共享内存状态(与 TodoWriteTool 同一实例)
|
||||
state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
|
||||
/// SQLite 持久化层,用于进程重启后回填内存
|
||||
repository: Arc<dyn TodoRepository>,
|
||||
}
|
||||
|
||||
impl TodoReadTool {
|
||||
pub(crate) fn new(
|
||||
state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
|
||||
repository: Arc<dyn TodoRepository>,
|
||||
) -> Self {
|
||||
Self { state, repository }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TodoReadTool {
|
||||
fn name(&self) -> &str {
|
||||
"todo_read"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Read the current todo list for this conversation without modifying it. \
|
||||
Returns all tracked tasks with their id, content, and status. \
|
||||
No parameters required."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
})
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn concurrency_safe(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
Ok(error_result("todo_read requires tool context (session_id)"))
|
||||
}
|
||||
|
||||
async fn execute_with_context(
|
||||
&self,
|
||||
context: &ToolContext,
|
||||
_args: serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
// 1. 计算 scope_key
|
||||
let scope_key = match scope_key_from_context(context) {
|
||||
Some(key) => key,
|
||||
None => {
|
||||
return Ok(error_result(
|
||||
"todo_read requires session_id or topic_id in tool context",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// 2. 读锁查内存
|
||||
{
|
||||
let guard = self.state.read().await;
|
||||
if let Some(items) = guard.get(&scope_key) {
|
||||
if !items.is_empty() {
|
||||
return Ok(success_result(items, &scope_key, "memory"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 内存为空 → 查 SQLite 并回填
|
||||
let records = match self.repository.list_todos(&scope_key) {
|
||||
Ok(records) => records,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
%scope_key,
|
||||
"TodoReadTool: failed to load todos from SQLite"
|
||||
);
|
||||
return Ok(success_result(&[], &scope_key, "memory"));
|
||||
}
|
||||
};
|
||||
|
||||
if records.is_empty() {
|
||||
return Ok(success_result(&[], &scope_key, "sqlite"));
|
||||
}
|
||||
|
||||
let items: Vec<TodoItem> = records
|
||||
.into_iter()
|
||||
.map(|r| TodoItem {
|
||||
id: r.id,
|
||||
content: r.content,
|
||||
status: r.status,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 回填内存
|
||||
{
|
||||
let mut guard = self.state.write().await;
|
||||
guard.insert(scope_key.clone(), items.clone());
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
scope_key = %scope_key,
|
||||
todo_count = items.len(),
|
||||
"TodoReadTool: backfilled memory from SQLite"
|
||||
);
|
||||
|
||||
Ok(success_result(&items, &scope_key, "sqlite"))
|
||||
}
|
||||
}
|
||||
|
||||
// ── 辅助函数 ──────────────────────────────────────────────
|
||||
|
||||
fn success_result(
|
||||
items: &[TodoItem],
|
||||
scope_key: &str,
|
||||
source: &'static str,
|
||||
) -> ToolResult {
|
||||
let output = TodoReadOutput {
|
||||
todos: items.to_vec(),
|
||||
count: items.len(),
|
||||
scope_key: scope_key.to_string(),
|
||||
source,
|
||||
};
|
||||
ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&output).unwrap_or_default(),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn error_result(message: &str) -> ToolResult {
|
||||
ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(message.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
// ── 测试 ──────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tools::traits::ToolContext;
|
||||
|
||||
fn test_context() -> ToolContext {
|
||||
ToolContext {
|
||||
channel_name: Some("cli".to_string()),
|
||||
sender_id: Some("user-1".to_string()),
|
||||
chat_id: Some("chat-1".to_string()),
|
||||
session_id: Some("cli:chat-1".to_string()),
|
||||
topic_id: None,
|
||||
message_id: Some("msg-1".to_string()),
|
||||
message_seq: Some(1),
|
||||
subagent_description: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn test_state() -> Arc<RwLock<HashMap<String, Vec<TodoItem>>>> {
|
||||
Arc::new(RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
struct MockTodoRepository {
|
||||
records: Vec<crate::storage::TodoRecord>,
|
||||
}
|
||||
|
||||
impl TodoRepository for MockTodoRepository {
|
||||
fn replace_todos(
|
||||
&self,
|
||||
_scope_key: &str,
|
||||
_items: &[crate::storage::TodoRecord],
|
||||
) -> Result<Vec<crate::storage::TodoRecord>, crate::storage::StorageError> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn list_todos(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Vec<crate::storage::TodoRecord>, crate::storage::StorageError> {
|
||||
Ok(self.records.iter().filter(|r| r.scope_key == scope_key).cloned().collect())
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_record(scope_key: &str, id: &str, content: &str, status: &str) -> crate::storage::TodoRecord {
|
||||
crate::storage::TodoRecord {
|
||||
id: id.to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
session_id: "cli:chat-1".to_string(),
|
||||
topic_id: None,
|
||||
content: content.to_string(),
|
||||
status: status.to_string(),
|
||||
priority: "medium".to_string(),
|
||||
created_at: 1000,
|
||||
updated_at: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_from_memory() {
|
||||
let state = test_state();
|
||||
{
|
||||
let mut guard = state.write().await;
|
||||
guard.insert(
|
||||
"cli:chat-1".to_string(),
|
||||
vec![TodoItem {
|
||||
id: "a1".to_string(),
|
||||
content: "任务A".to_string(),
|
||||
status: "pending".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
let repo = Arc::new(MockTodoRepository { records: vec![] });
|
||||
let tool = TodoReadTool::new(state, repo);
|
||||
|
||||
let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
|
||||
assert_eq!(output["count"], 1);
|
||||
assert_eq!(output["source"], "memory");
|
||||
assert_eq!(output["todos"][0]["id"], "a1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_from_sqlite_backfill() {
|
||||
let state = test_state();
|
||||
let repo = Arc::new(MockTodoRepository {
|
||||
records: vec![
|
||||
mock_record("cli:chat-1", "b1", "任务B", "in_progress"),
|
||||
mock_record("cli:chat-1", "b2", "任务C", "pending"),
|
||||
],
|
||||
});
|
||||
let tool = TodoReadTool::new(state.clone(), repo);
|
||||
|
||||
let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
|
||||
assert_eq!(output["count"], 2);
|
||||
assert_eq!(output["source"], "sqlite");
|
||||
|
||||
// 验证内存已被回填
|
||||
let guard = state.read().await;
|
||||
let items = guard.get("cli:chat-1").unwrap();
|
||||
assert_eq!(items.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_empty_list() {
|
||||
let state = test_state();
|
||||
let repo = Arc::new(MockTodoRepository { records: vec![] });
|
||||
let tool = TodoReadTool::new(state, repo);
|
||||
|
||||
let result = tool.execute_with_context(&test_context(), json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
|
||||
assert_eq!(output["count"], 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_no_context() {
|
||||
let state = test_state();
|
||||
let repo = Arc::new(MockTodoRepository { records: vec![] });
|
||||
let tool = TodoReadTool::new(state, repo);
|
||||
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("requires tool context"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_topic_isolation() {
|
||||
let state = test_state();
|
||||
{
|
||||
let mut guard = state.write().await;
|
||||
guard.insert(
|
||||
"cli:chat-1".to_string(),
|
||||
vec![TodoItem {
|
||||
id: "m1".to_string(),
|
||||
content: "主会话任务".to_string(),
|
||||
status: "pending".to_string(),
|
||||
}],
|
||||
);
|
||||
}
|
||||
|
||||
let repo = Arc::new(MockTodoRepository {
|
||||
records: vec![
|
||||
mock_record("topic-xyz", "t1", "话题任务", "completed"),
|
||||
],
|
||||
});
|
||||
let tool = TodoReadTool::new(state, repo);
|
||||
|
||||
let topic_ctx = ToolContext {
|
||||
session_id: Some("cli:chat-1".to_string()),
|
||||
topic_id: Some("topic-xyz".to_string()),
|
||||
..ToolContext::default()
|
||||
};
|
||||
|
||||
let result = tool.execute_with_context(&topic_ctx, json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
|
||||
assert_eq!(output["scope_key"], "topic-xyz");
|
||||
assert_eq!(output["count"], 1);
|
||||
assert_eq!(output["todos"][0]["content"], "话题任务");
|
||||
assert_eq!(output["source"], "sqlite");
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user