From 881fcace47724a7ab3a73365ba8ecd3b8cecd9ea Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Fri, 12 Jun 2026 14:19:07 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20todo=5Fwrite=20?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=EF=BC=8C=E6=94=AF=E6=8C=81=E5=85=A8=E9=87=8F?= =?UTF-8?q?=E6=9B=BF=E6=8D=A2=E5=92=8C=E5=A2=9E=E9=87=8F=E5=90=88=E5=B9=B6?= =?UTF-8?q?=E4=B8=A4=E7=A7=8D=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Tool: 纯内存实现 (Arc>),零 DB 依赖,解耦持久化 - 状态机: pending → in_progress → completed/cancelled,单 in_progress 约束 - merge=false: 全量替换模式(默认) - merge=true: 增量更新模式,只传变更的项,其余保留 - 隔离: scope_key = topic_id.unwrap_or(session_id),topic 和子代理隔离 - 持久化: TodoRepository trait + SessionStore SQLite 实现,在 Session 拦截器层完成 - 前端推送: WsOutbound::TodoList 事件 - Prompt: TodoPromptProvider 中文指令,子代理模板也包含 - 测试: 16 个单元测试,全部通过 Co-Authored-By: Claude Opus 4.8 --- src/gateway/agent_factory.rs | 2 + src/gateway/execution.rs | 8 + src/gateway/mod.rs | 1 + src/gateway/runtime.rs | 8 + src/gateway/session.rs | 78 ++ src/gateway/todo_prompt_provider.rs | 62 ++ src/gateway/tool_registry_factory.rs | 29 +- src/protocol/mod.rs | 16 + src/storage/mod.rs | 111 ++- src/storage/ports.rs | 28 +- src/storage/records.rs | 13 + src/tools/mod.rs | 2 + src/tools/task/prompt.rs | 2 + src/tools/task/types.rs | 2 +- src/tools/todo_write.rs | 1153 ++++++++++++++++++++++++++ 15 files changed, 1508 insertions(+), 7 deletions(-) create mode 100644 src/gateway/todo_prompt_provider.rs create mode 100644 src/tools/todo_write.rs diff --git a/src/gateway/agent_factory.rs b/src/gateway/agent_factory.rs index 93bc6be..cc3be2c 100644 --- a/src/gateway/agent_factory.rs +++ b/src/gateway/agent_factory.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use crate::agent::{AgentError, AgentLoop, CompositeSystemPromptProvider}; use crate::config::LLMProviderConfig; use crate::gateway::agent_prompt_provider::AgentPromptProvider; +use crate::gateway::todo_prompt_provider::TodoPromptProvider; use crate::skills::{SkillPromptProvider, SkillRuntime}; use crate::storage::persistent_session_id; use crate::storage::PromptInjectionRepository; @@ -53,6 +54,7 @@ impl AgentFactory { self.prompt_repository.clone(), )), Box::new(SkillPromptProvider::new(self.skills.clone())), + Box::new(TodoPromptProvider::new()), ])); AgentLoop::with_tools_and_system_prompt_provider( diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index 4ab20be..c4e2bee 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -190,6 +190,14 @@ impl AgentExecutionService { // 只有当是最新回合时才触发历史压缩 let should_schedule_compaction = is_current_turn; + // 拦截 todo_write 结果:持久化 + 前端推送 + if is_current_turn { + session.intercept_todo_write_results( + &request.result.emitted_messages, + request.chat_id, + ); + } + Ok(FinalizedAgentResult { outbound_messages, should_schedule_compaction, diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 2453933..4d9d724 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -25,6 +25,7 @@ pub mod session_message_service; pub mod session_pool; pub mod static_files; pub mod tool_registry_factory; +pub mod todo_prompt_provider; pub mod ws; use axum::{Router, routing}; diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index bc40403..43a77e7 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -3,6 +3,8 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use tokio::sync::RwLock; + use crate::agent::AgentError; use crate::bus::MessageBus; use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, SubagentsConfig, TaskConfig}; @@ -18,6 +20,7 @@ use crate::tools::{ SessionMessageSender, SubAgentRuntimeConfig, SubagentCatalog, ToolRegistry, }; use crate::tools::task::repository::TaskRepository; +use crate::tools::todo_write::TodoItem; use super::agent_factory::AgentFactory; use super::cli_session::CliSessionService; @@ -117,6 +120,11 @@ pub(crate) fn build_session_manager_with_sender( task_config.clone(), ); + // Create shared todo state for TodoWriteTool + let todo_state: Arc>>> = + Arc::new(RwLock::new(HashMap::new())); + let factory = factory.with_todo_state(todo_state); + // Create MCP Initializer (async, non-blocking) // MCP servers connect in background task let mut mcp_initializer = McpInitializer::with_config(mcp_config); diff --git a/src/gateway/session.rs b/src/gateway/session.rs index c93fded..6968d4f 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -385,6 +385,84 @@ impl Session { let _ = self.user_tx.send(msg).await; } + /// 扫描 agent 结果中的 todo_write 工具消息, + /// 提取 todos 并做持久化 + 前端推送(同步版本)。 + pub(crate) fn intercept_todo_write_results( + &self, + emitted_messages: &[ChatMessage], + chat_id: &str, + ) { + for msg in emitted_messages { + if msg.role != "tool" { + continue; + } + if msg.tool_name.as_deref() != Some("todo_write") { + continue; + } + + // 解析工具返回的 JSON + let parsed: serde_json::Value = match serde_json::from_str(&msg.content) { + Ok(v) => v, + Err(_) => continue, + }; + + let Some(todos_array) = parsed + .get("current_todos") + .and_then(|v| v.as_array()) + else { + continue; + }; + + // 计算持久化所需的 key + let session_id = crate::storage::persistent_session_id(&self.channel_name, chat_id); + let topic_id = self.current_topic(chat_id); + let scope_key = topic_id.map(|t| t.to_string()).unwrap_or_else(|| session_id.clone()); + + // 转换为 TodoRecord 并持久化 + let records: Vec = todos_array + .iter() + .filter_map(|item| { + Some(crate::storage::TodoRecord { + id: item.get("id")?.as_str()?.to_string(), + scope_key: scope_key.clone(), + session_id: session_id.clone(), + topic_id: topic_id.map(|t| t.to_string()), + content: item.get("content")?.as_str()?.to_string(), + status: item.get("status")?.as_str()?.to_string(), + priority: item.get("priority")?.as_str()?.to_string(), + created_at: item.get("created_at")?.as_i64()?, + updated_at: item.get("updated_at")?.as_i64()?, + }) + }) + .collect(); + + // 持久化到 SQLite + if let Err(e) = self.store.replace_todos(&scope_key, &records) { + tracing::warn!(error = %e, scope_key = %scope_key, "Failed to persist todo list"); + } + + // 推送到前端(使用 try_send 避免异步) + let summaries: Vec = records + .iter() + .map(|r| crate::protocol::TodoItemSummary { + id: r.id.clone(), + content: r.content.clone(), + status: r.status.clone(), + priority: r.priority.clone(), + created_at: r.created_at, + updated_at: r.updated_at, + }) + .collect(); + + let _ = self.user_tx.try_send(crate::protocol::WsOutbound::TodoList { + todos: summaries, + scope_key: scope_key.clone(), + }); + + break; // 只处理第一个成功的 todo_write + } + } + /// 获取 provider_config 引用 pub fn provider_config(&self) -> &LLMProviderConfig { &self.provider_config diff --git a/src/gateway/todo_prompt_provider.rs b/src/gateway/todo_prompt_provider.rs new file mode 100644 index 0000000..b5de9c3 --- /dev/null +++ b/src/gateway/todo_prompt_provider.rs @@ -0,0 +1,62 @@ +use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; + +pub struct TodoPromptProvider; + +impl TodoPromptProvider { + pub fn new() -> Self { + Self + } +} + +impl SystemPromptProvider for TodoPromptProvider { + fn build(&self, _context: &SystemPromptContext) -> Option { + Some(SystemPrompt { + content: TODO_WRITE_INSTRUCTIONS.to_string(), + context: Some("todo_write".to_string()), + }) + } +} + +const TODO_WRITE_INSTRUCTIONS: &str = r#" +## TodoWrite 工具 + +你可以使用 `todo_write` 工具在对话中维护结构化的任务列表。 + +### 何时使用 +- 当任务有 3 个或以上明确步骤时,应该使用 todo_write 追踪进度 +- 不需要为简单的单步操作(如回答一个问题、读取一个文件)创建 todo + +### merge 参数 +- `merge: false`(默认):全量替换 — 只传入需要追踪的 todo,不在列表中的项将被移除 +- `merge: true`(推荐):增量更新 — 只传入需要添加或更新的项,未提及的项保持不变。**绝大多数情况应该使用 merge=true,这样你不需要记住所有 id** + +### 状态语义 +- `pending` — 尚未开始 +- `in_progress` — 当前正在执行(同一时间只能有一个) +- `completed` — 已完成 +- `cancelled` — 不再需要 + +### 核心规则 +1. 同一时间只能有一个任务处于 `in_progress` 状态 +2. 必须先完成当前 `in_progress` 的任务,再开始下一个 +3. `completed` 和 `cancelled` 是终端状态,已完成的项不能被重新激活 +4. 不要先标记 completed 再去实际执行 — 先完成工作,再标记 +5. `content` 字段保持简洁、可执行 + +### 使用范例 + +开始任务时(merge 模式): +```json +{"merge": true, "todos": [{"content": "修复登录 bug", "status": "in_progress"}]} +``` + +发现新任务时: +```json +{"merge": true, "todos": [{"content": "补充测试", "status": "pending"}]} +``` + +完成任务时(传入 id + 新状态): +```json +{"merge": true, "todos": [{"id": "xxx", "content": "修复登录 bug", "status": "completed"}]} +``` +"#; diff --git a/src/gateway/tool_registry_factory.rs b/src/gateway/tool_registry_factory.rs index 69f7c3e..520928a 100644 --- a/src/gateway/tool_registry_factory.rs +++ b/src/gateway/tool_registry_factory.rs @@ -1,16 +1,19 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use tokio::sync::RwLock; + use crate::config::TaskConfig; use crate::mcp::McpClientManager; use crate::skills::SkillRuntime; use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; +use crate::tools::todo_write::TodoItem; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, MemoryManageTool, MemorySearchTool, SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool, - ToolRegistry, WebFetchTool, + TodoWriteTool, ToolRegistry, WebFetchTool, }; pub(crate) struct ToolRegistryFactory { @@ -25,6 +28,7 @@ pub(crate) struct ToolRegistryFactory { task_config: TaskConfig, subagent_runtime: Option>, mcp_manager: Option>, + todo_state: Option>>>>, } impl ToolRegistryFactory { @@ -51,9 +55,18 @@ impl ToolRegistryFactory { task_config, subagent_runtime: None, mcp_manager: None, + todo_state: None, } } + pub(crate) fn with_todo_state( + mut self, + state: Arc>>>, + ) -> Self { + self.todo_state = Some(state); + self + } + pub(crate) fn with_subagent_runtime( mut self, runtime: Arc, @@ -98,6 +111,11 @@ impl ToolRegistryFactory { if self.is_enabled("memory_manage") { registry.register(MemoryManageTool::new(self.memories.clone())); } + if self.is_enabled("todo_write") { + if let Some(ref state) = self.todo_state { + registry.register(TodoWriteTool::new(state.clone())); + } + } if self.is_enabled("session_send") { registry.register(SessionSendTool::new(self.session_message_sender.clone())); } @@ -198,6 +216,13 @@ impl ToolRegistryFactory { registry.register(SessionSendTool::new(self.session_message_sender.clone())); } + // Todo 追踪工具 + if self.is_enabled("todo_write") { + if let Some(ref state) = self.todo_state { + registry.register(TodoWriteTool::new(state.clone())); + } + } + // 注册 MCP 工具(如果提供) if let Some(mcp_tools) = mcp_tools { for tool in mcp_tools { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 87cd453..79825b4 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -82,6 +82,17 @@ pub struct SkillSummary { pub source: String, } +/// Todo item 摘要(发送给前端) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TodoItemSummary { + pub id: String, + pub content: String, + pub status: String, + pub priority: String, + pub created_at: i64, + pub updated_at: i64, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SchedulerJobSummary { pub id: String, @@ -257,6 +268,11 @@ pub enum WsOutbound { }, #[serde(rename = "execution_cancelled")] ExecutionCancelled { message: String }, + #[serde(rename = "todo_list")] + TodoList { + todos: Vec, + scope_key: String, + }, #[serde(rename = "pong")] Pong, } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 0185a1f..221c53d 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -14,13 +14,13 @@ pub mod records; pub use error::StorageError; pub use ports::{ ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository, - SkillEventRepository, + SkillEventRepository, TodoRepository, }; pub use records::{ allowed_namespace_names, get_namespace_description, is_valid_namespace, ALLOWED_MEMORY_NAMESPACES, GLOBAL_SCOPE_KEY, MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord, - TopicRecord, + TodoRecord, TopicRecord, }; #[derive(Clone)] @@ -217,6 +217,7 @@ impl SessionStore { ensure_messages_schema(&conn)?; ensure_scheduler_schema(&conn)?; ensure_memory_scope_key_migration(&conn)?; + ensure_todos_schema(&conn)?; drop(conn); @@ -1491,6 +1492,74 @@ impl SessionStore { ) .map_err(StorageError::from) } + + pub fn replace_todos( + &self, + scope_key: &str, + items: &[TodoRecord], + ) -> Result, StorageError> { + let conn = self.pool.get()?; + let now = current_timestamp(); + + // Delete existing todos for this scope_key + conn.execute( + "DELETE FROM todos WHERE scope_key = ?1", + params![scope_key], + )?; + + // Insert new todos + for item in items { + conn.execute( + "INSERT INTO todos (id, scope_key, session_id, topic_id, content, status, priority, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + params![ + item.id, + scope_key, + item.session_id, + item.topic_id, + item.content, + item.status, + item.priority, + item.created_at, + now, + ], + )?; + } + + drop(conn); + + self.list_todos(scope_key) + } + + pub fn list_todos(&self, scope_key: &str) -> Result, StorageError> { + let conn = self.pool.get()?; + let mut stmt = conn.prepare( + "SELECT id, scope_key, session_id, topic_id, content, status, priority, created_at, updated_at + FROM todos + WHERE scope_key = ?1 + ORDER BY created_at ASC", + )?; + + let rows = stmt.query_map(params![scope_key], |row| { + Ok(TodoRecord { + id: row.get(0)?, + scope_key: row.get(1)?, + session_id: row.get(2)?, + topic_id: row.get(3)?, + content: row.get(4)?, + status: row.get(5)?, + priority: row.get(6)?, + created_at: row.get(7)?, + updated_at: row.get(8)?, + }) + })?; + + let mut todos = Vec::new(); + for row in rows { + todos.push(row?); + } + Ok(todos) + } } pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String { @@ -1800,6 +1869,42 @@ fn ensure_memory_scope_key_migration(conn: &Connection) -> Result<(), StorageErr Ok(()) } +fn ensure_todos_schema(conn: &Connection) -> Result<(), StorageError> { + let table_exists: bool = conn + .query_row( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='todos'", + [], + |row| row.get::<_, i64>(0), + ) + .map(|count| count > 0)?; + + if !table_exists { + conn.execute_batch( + " + CREATE TABLE IF NOT EXISTS todos ( + id TEXT PRIMARY KEY, + scope_key TEXT NOT NULL, + session_id TEXT NOT NULL, + topic_id TEXT, + content TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + priority TEXT NOT NULL DEFAULT 'medium', + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_todos_scope + ON todos(scope_key, created_at ASC); + + CREATE INDEX IF NOT EXISTS idx_todos_session + ON todos(session_id); + ", + )?; + } + + Ok(()) +} + fn has_column( conn: &Connection, table_name: &str, @@ -2009,7 +2114,7 @@ fn load_messages_after( messages.push(row?); } Ok(messages) -} + } fn current_timestamp() -> i64 { std::time::SystemTime::now() diff --git a/src/storage/ports.rs b/src/storage/ports.rs index 35fbb87..a9c143f 100644 --- a/src/storage/ports.rs +++ b/src/storage/ports.rs @@ -1,6 +1,6 @@ use super::{ MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, - SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError, + SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError, TodoRecord, }; use crate::bus::ChatMessage; @@ -145,6 +145,18 @@ pub trait SkillEventRepository: Send + Sync + 'static { ) -> Result, StorageError>; } +pub trait TodoRepository: Send + Sync + 'static { + /// Replace all todos for a scope (full replacement pattern). + fn replace_todos( + &self, + scope_key: &str, + todo_records: &[TodoRecord], + ) -> Result, StorageError>; + + /// Load all todos for a scope, ordered by created_at. + fn list_todos(&self, scope_key: &str) -> Result, StorageError>; +} + impl ConversationRepository for super::SessionStore { fn ensure_channel_session( &self, @@ -356,3 +368,17 @@ impl SkillEventRepository for super::SessionStore { super::SessionStore::list_skill_events(self, session_id) } } + +impl TodoRepository for super::SessionStore { + fn replace_todos( + &self, + scope_key: &str, + todo_records: &[TodoRecord], + ) -> Result, StorageError> { + super::SessionStore::replace_todos(self, scope_key, todo_records) + } + + fn list_todos(&self, scope_key: &str) -> Result, StorageError> { + super::SessionStore::list_todos(self, scope_key) + } +} diff --git a/src/storage/records.rs b/src/storage/records.rs index 7bc1c0e..5c77c53 100644 --- a/src/storage/records.rs +++ b/src/storage/records.rs @@ -35,6 +35,19 @@ pub fn allowed_namespace_names() -> Vec<&'static str> { ALLOWED_MEMORY_NAMESPACES.iter().map(|(name, _)| *name).collect() } +#[derive(Debug, Clone)] +pub struct TodoRecord { + pub id: String, + pub scope_key: String, + pub session_id: String, + pub topic_id: Option, + pub content: String, + pub status: String, + pub priority: String, + pub created_at: i64, + pub updated_at: i64, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SkillEventRecord { pub id: String, diff --git a/src/tools/mod.rs b/src/tools/mod.rs index f7f79c8..eeca684 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -14,6 +14,7 @@ pub mod skill_activate; pub mod skill_manage; pub mod task; pub mod time; +pub mod todo_write; pub mod traits; pub mod web_fetch; @@ -39,6 +40,7 @@ pub use task::{ SubagentCatalog, TaskError, TaskRepository, TaskTool, }; pub use time::TimeTool; +pub use todo_write::TodoWriteTool; pub use traits::{Tool, ToolContext, ToolResult}; pub use web_fetch::WebFetchTool; diff --git a/src/tools/task/prompt.rs b/src/tools/task/prompt.rs index e5e830d..301d87f 100644 --- a/src/tools/task/prompt.rs +++ b/src/tools/task/prompt.rs @@ -52,6 +52,8 @@ impl SubagentPromptBuilder { 2. 使用可用的工具进行必要操作\n\ 3. 完成后给出简洁的总结\n\ 4. 不要尝试创建新的子代理任务\n\n\ + 任务追踪:\n\ + 你可以使用 `todo_write` 工具追踪子任务进度。规则:同一时间只有一个 in_progress,完成后再标记下一个,3步以上才使用。\n\n\ 注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。" } else { &def.prompt_template diff --git a/src/tools/task/types.rs b/src/tools/task/types.rs index 72d1419..281a7b4 100644 --- a/src/tools/task/types.rs +++ b/src/tools/task/types.rs @@ -61,7 +61,7 @@ impl SubagentDef { Self { name: "general".to_string(), description: "通用型子代理 - 处理复杂多步骤任务".to_string(), - prompt_template: "你是一个专注的子代理,正在执行一个独立任务。\n\n任务描述: {{description}}\n\n你应该:\n1. 专注于完成任务,不要偏离目标\n2. 使用可用的工具进行必要操作\n3. 完成后给出简洁的总结\n4. 不要尝试创建新的子代理任务\n\n注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。".to_string(), + prompt_template: "你是一个专注的子代理,正在执行一个独立任务。\n\n任务描述: {{description}}\n\n你应该:\n1. 专注于完成任务,不要偏离目标\n2. 使用可用的工具进行必要操作\n3. 完成后给出简洁的总结\n4. 不要尝试创建新的子代理任务\n\n任务追踪:\n你可以使用 `todo_write` 工具追踪子任务进度。规则:同一时间只有一个 in_progress,完成后再标记下一个,3步以上才使用。\n\n注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。".to_string(), body: None, allowed_tools: None, max_execution_secs: None, diff --git a/src/tools/todo_write.rs b/src/tools/todo_write.rs new file mode 100644 index 0000000..3fc9b71 --- /dev/null +++ b/src/tools/todo_write.rs @@ -0,0 +1,1153 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use serde::Serialize; +use serde_json::json; +use tokio::sync::RwLock; + +use crate::tools::traits::{Tool, ToolContext, ToolResult}; + +// ── 数据模型 ────────────────────────────────────────────── + +#[derive(Debug, Clone, PartialEq, Eq)] +enum TodoStatus { + Pending, + InProgress, + Completed, + Cancelled, +} + +impl TodoStatus { + fn as_str(&self) -> &'static str { + match self { + TodoStatus::Pending => "pending", + TodoStatus::InProgress => "in_progress", + TodoStatus::Completed => "completed", + TodoStatus::Cancelled => "cancelled", + } + } + + fn from_str(value: &str) -> Option { + match value { + "pending" => Some(Self::Pending), + "in_progress" => Some(Self::InProgress), + "completed" => Some(Self::Completed), + "cancelled" => Some(Self::Cancelled), + _ => None, + } + } + + /// 是否为终端状态(不可再变更) + fn is_terminal(&self) -> bool { + matches!(self, TodoStatus::Completed | TodoStatus::Cancelled) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum TodoPriority { + High, + Medium, + Low, +} + +impl TodoPriority { + fn as_str(&self) -> &'static str { + match self { + TodoPriority::High => "high", + TodoPriority::Medium => "medium", + TodoPriority::Low => "low", + } + } + + fn from_str(value: &str) -> Option { + match value { + "high" => Some(Self::High), + "medium" => Some(Self::Medium), + "low" => Some(Self::Low), + _ => None, + } + } +} + +/// 内存中的 Todo 项 +#[derive(Debug, Clone, Serialize)] +pub(crate) struct TodoItem { + pub id: String, + pub content: String, + pub status: String, + pub priority: String, + pub created_at: i64, + pub updated_at: i64, +} + +/// 变更摘要 +#[derive(Debug, Clone, Serialize)] +struct ChangeSummary { + added: Vec, + updated: Vec, + removed: Vec, // ids of removed items +} + +/// 工具完整返回 +#[derive(Debug, Clone, Serialize)] +struct TodoWriteOutput { + current_todos: Vec, + changes: ChangeSummary, + message: String, +} + +// ── 工具实现 ────────────────────────────────────────────── + +pub struct TodoWriteTool { + /// 内存状态:scope_key → Vec + /// scope_key = topic_id.unwrap_or(session_id) + state: Arc>>>, +} + +impl TodoWriteTool { + pub(crate) fn new(state: Arc>>>) -> Self { + Self { state } + } +} + +#[async_trait] +impl Tool for TodoWriteTool { + fn name(&self) -> &str { + "todo_write" + } + + fn description(&self) -> &str { + "Manage a structured task list for tracking work within the current conversation. \ + Two modes: merge=false (default, full replacement — omitted items are removed); \ + merge=true (incremental — only send the items you want to add/update, \ + previously completed items are preserved). \ + Use when you have 3+ distinct steps to track. \ + Rules: only ONE in_progress at a time, complete work before marking completed, \ + completed/cancelled are terminal states." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "merge": { + "type": "boolean", + "description": "false (default): full replacement — todos not in the list are removed. true: incremental — only send items you want to add or update, existing items not mentioned are preserved." + }, + "todos": { + "type": "array", + "description": "The todo items to add or update. In merge=false mode, this is the complete replacement list. In merge=true mode, only send the items that changed — unreferenced items are kept as-is.", + "items": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Brief, actionable description of the task" + }, + "status": { + "type": "string", + "enum": ["pending", "in_progress", "completed", "cancelled"], + "description": "Current status: pending=not started, in_progress=working on (only ONE at a time), completed=done, cancelled=no longer needed" + }, + "priority": { + "type": "string", + "enum": ["high", "medium", "low"], + "description": "Task priority. Defaults to medium if not specified." + } + }, + "required": ["content", "status"] + } + } + }, + "required": ["todos"] + }) + } + + fn read_only(&self) -> bool { + false + } + + fn concurrency_safe(&self) -> bool { + false + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(error_result("todo_write requires tool context (session_id)")) + } + + async fn execute_with_context( + &self, + context: &ToolContext, + args: serde_json::Value, + ) -> anyhow::Result { + // 1. 计算 scope_key + let scope_key = match scope_key_from_context(context) { + Some(key) => key, + None => return Ok(error_result("todo_write requires session_id or topic_id in tool context")), + }; + + // 2. 解析入参 + let todos_array = match args.get("todos").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => return Ok(error_result("Missing required parameter: todos (must be an array)")), + }; + + let merge_mode = args + .get("merge") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + // 3. 解析并校验每个输入项 + let now = current_timestamp(); + + // 读锁获取旧状态 + let old_items = { + let guard = self.state.read().await; + guard.get(&scope_key).cloned().unwrap_or_default() + }; + + // 构建 id → TodoItem 的旧状态映射 + let old_map: HashMap<&str, &TodoItem> = old_items.iter().map(|item| (item.id.as_str(), item)).collect(); + + let mut processed_items: Vec = Vec::new(); + let mut validation_errors: Vec = Vec::new(); + + for (idx, input) in todos_array.iter().enumerate() { + let content = match input.get("content").and_then(|v| v.as_str()) { + Some(s) if !s.trim().is_empty() => s.trim().to_string(), + _ => { + validation_errors.push(format!("Item {}: missing or empty 'content'", idx)); + continue; + } + }; + + let status_str = input + .get("status") + .and_then(|v| v.as_str()) + .unwrap_or("pending"); + + let new_status = match TodoStatus::from_str(status_str) { + Some(s) => s, + None => { + validation_errors.push(format!( + "Item '{}': invalid status '{}'. Valid: pending, in_progress, completed, cancelled", + content, status_str + )); + continue; + } + }; + + let priority_str = input + .get("priority") + .and_then(|v| v.as_str()) + .unwrap_or("medium"); + + let priority = TodoPriority::from_str(priority_str).unwrap_or(TodoPriority::Medium); + + let input_id = input.get("id").and_then(|v| v.as_str()); + + if let Some(id) = input_id { + if let Some(old_item) = old_map.get(id) { + // 已有 item — 校验状态转换 + let old_status = match TodoStatus::from_str(&old_item.status) { + Some(s) => s, + None => { + validation_errors.push(format!("Item '{}': corrupted old status", content)); + continue; + } + }; + + if let Err(err) = validate_transition(&old_status, &new_status) { + validation_errors.push(format!("Item '{}': {}", content, err)); + continue; + } + + processed_items.push(TodoItem { + id: id.to_string(), + content, + status: new_status.as_str().to_string(), + priority: priority.as_str().to_string(), + created_at: old_item.created_at, + updated_at: now, + }); + } else { + // 传入 id 但旧状态中没有 → 按新 item 处理 + if new_status != TodoStatus::Pending { + validation_errors.push(format!( + "Item '{}': new items must start as 'pending', got '{}'", + content, status_str + )); + continue; + } + + processed_items.push(TodoItem { + id: id.to_string(), + content, + status: new_status.as_str().to_string(), + priority: priority.as_str().to_string(), + created_at: now, + updated_at: now, + }); + } + } else { + // 新 item(无 id)— 必须是 pending + if new_status != TodoStatus::Pending { + validation_errors.push(format!( + "Item '{}': new items must start as 'pending', got '{}'", + content, status_str + )); + continue; + } + + processed_items.push(TodoItem { + id: uuid::Uuid::new_v4().to_string(), + content, + status: new_status.as_str().to_string(), + priority: priority.as_str().to_string(), + created_at: now, + updated_at: now, + }); + } + } + + if !validation_errors.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(validation_errors.join("\n")), + }); + } + + // 4. 提前收集 processed ids(在 move 之前) + let processed_ids: std::collections::HashSet<&str> = + processed_items.iter().map(|item| item.id.as_str()).collect(); + let old_ids: std::collections::HashSet<&str> = old_items.iter().map(|item| item.id.as_str()).collect(); + + // 5. 计算 diff(在 processed_items 被 move 之前) + let added: Vec = processed_items + .iter() + .filter(|item| !old_ids.contains(item.id.as_str())) + .cloned() + .collect(); + + let updated: Vec = processed_items + .iter() + .filter(|item| { + old_ids.contains(item.id.as_str()) + && old_map.get(item.id.as_str()).map_or(true, |old| { + old.status != item.status || old.content != item.content + }) + }) + .cloned() + .collect(); + + // 6. 合并模式:将旧列表中未被引用的项保留 + let final_items: Vec = if merge_mode { + let mut merged = Vec::new(); + for item in &processed_items { + merged.push(item.clone()); + } + for old in &old_items { + if !processed_ids.contains(old.id.as_str()) { + merged.push(old.clone()); + } + } + merged.sort_by_key(|item| item.created_at); + merged + } else { + // full replacement: processed_items 就是全部 + processed_items + }; + + // 7. 全局约束:只有一个 in_progress + let in_progress_count = final_items + .iter() + .filter(|item| item.status == "in_progress") + .count(); + if in_progress_count > 1 { + return Ok(error_result(&format!( + "Only one task can be 'in_progress' at a time. Found {} in_progress tasks.", + in_progress_count + ))); + } + + let final_ids: std::collections::HashSet<&str> = final_items.iter().map(|item| item.id.as_str()).collect(); + + // merge 模式下从不删除,只有全量替换模式才会删除 + let removed: Vec = if merge_mode { + Vec::new() + } else { + old_items + .iter() + .filter(|item| !final_ids.contains(item.id.as_str())) + .map(|item| item.id.clone()) + .collect() + }; + + let changes = ChangeSummary { + added, + updated, + removed, + }; + + // 8. 更新内存状态 + { + let mut guard = self.state.write().await; + guard.insert(scope_key.clone(), final_items.clone()); + } + + // 9. 生成友好消息 + let message = build_summary_message(&changes, merge_mode); + + let output = TodoWriteOutput { + current_todos: final_items, + changes, + message, + }; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&output)?, + error: None, + }) + } +} + +// ── 辅助函数 ────────────────────────────────────────────── + +/// 计算 scope_key:优先 topic_id,否则 session_id +pub(crate) fn scope_key_from_context(context: &ToolContext) -> Option { + let tid = context.topic_id.as_deref().filter(|t| !t.is_empty()); + let sid = context.session_id.as_deref().filter(|s| !s.is_empty()); + tid.or(sid).map(str::to_string) +} + +/// 校验状态转换合法性 +fn validate_transition(old: &TodoStatus, new: &TodoStatus) -> Result<(), String> { + if old.is_terminal() { + return Err(format!( + "Cannot change status of a {} task (terminal state)", + old.as_str() + )); + } + + match (old, new) { + // pending → anything is allowed + (TodoStatus::Pending, _) => Ok(()), + + // in_progress → completed or cancelled + (TodoStatus::InProgress, TodoStatus::Completed) => Ok(()), + (TodoStatus::InProgress, TodoStatus::Cancelled) => Ok(()), + (TodoStatus::InProgress, TodoStatus::Pending) => Err( + "Cannot move an in_progress task back to pending. Use completed or cancelled.".to_string(), + ), + (TodoStatus::InProgress, TodoStatus::InProgress) => Ok(()), // no change + + // completed or cancelled → nothing (handled by is_terminal check above) + _ => Ok(()), + } +} + +fn build_summary_message(changes: &ChangeSummary, merge_mode: bool) -> String { + let mut parts: Vec = Vec::new(); + + if !changes.added.is_empty() { + let names: Vec<&str> = changes.added.iter().map(|item| item.content.as_str()).collect(); + parts.push(format!("新增 {} 项: {}", changes.added.len(), names.join(", "))); + } + + if !changes.updated.is_empty() { + let names: Vec = changes + .updated + .iter() + .map(|item| format!("{}→{}", item.content, item.status)) + .collect(); + parts.push(format!("更新 {} 项: {}", changes.updated.len(), names.join(", "))); + } + + if !changes.removed.is_empty() { + parts.push(format!("移除 {} 项", changes.removed.len())); + } + + if parts.is_empty() { + if merge_mode { + "Todo list unchanged (merge mode — no items were added, updated, or removed).".to_string() + } else { + "Todo list unchanged.".to_string() + } + } else { + parts.join("; ") + } +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64 +} + +fn error_result(message: &str) -> ToolResult { + ToolResult { + success: false, + output: String::new(), + error: Some(message.to_string()), + } +} + +// ── 测试 ────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::tools::traits::ToolContext; + + fn test_context() -> ToolContext { + ToolContext { + channel_name: Some("cli".to_string()), + sender_id: Some("user-1".to_string()), + chat_id: Some("chat-1".to_string()), + session_id: Some("cli:chat-1".to_string()), + topic_id: None, + message_id: Some("msg-1".to_string()), + message_seq: Some(1), + subagent_description: None, + } + } + + fn test_state() -> Arc>>> { + Arc::new(RwLock::new(HashMap::new())) + } + + #[tokio::test] + async fn test_create_initial_todos() { + let tool = TodoWriteTool::new(test_state()); + let context = test_context(); + + let result = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "设计数据库", "status": "pending", "priority": "high"}, + {"content": "实现 API", "status": "pending", "priority": "medium"}, + {"content": "写测试", "status": "pending", "priority": "low"} + ] + }), + ) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["current_todos"].as_array().unwrap().len(), 3); + assert_eq!(output["changes"]["added"].as_array().unwrap().len(), 3); + assert_eq!(output["changes"]["updated"].as_array().unwrap().len(), 0); + assert_eq!(output["changes"]["removed"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn test_single_in_progress_constraint() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 先创建两个 pending 任务 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "pending"}, + {"content": "任务B", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // 获取它们的 id + let guard = state.read().await; + let scope_key = scope_key_from_context(&context).unwrap(); + let items = guard.get(&scope_key).unwrap(); + let id_a = items[0].id.clone(); + let id_b = items[1].id.clone(); + drop(guard); + + // 尝试将两个都设为 in_progress + let result = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"id": id_a, "content": "任务A", "status": "in_progress"}, + {"id": id_b, "content": "任务B", "status": "in_progress"} + ] + }), + ) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("Only one task can be 'in_progress'")); + } + + #[tokio::test] + async fn test_state_transition_in_progress_to_completed() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 先创建 todo 列表 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // 获取创建后的 id + let guard = state.read().await; + let scope_key = scope_key_from_context(&context).unwrap(); + let items = guard.get(&scope_key).unwrap(); + let task_id = items[0].id.clone(); + drop(guard); + + // 更新为 in_progress → completed + let result = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"id": task_id, "content": "任务A", "status": "completed"} + ] + }), + ) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + let todos = output["current_todos"].as_array().unwrap(); + assert_eq!(todos[0]["status"], "completed"); + } + + #[tokio::test] + async fn test_terminal_state_cannot_change() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 创建并完成一个任务 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + let guard = state.read().await; + let scope_key = scope_key_from_context(&context).unwrap(); + let items = guard.get(&scope_key).unwrap(); + let task_id = items[0].id.clone(); + drop(guard); + + // 标记为 completed + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"id": task_id, "content": "任务A", "status": "completed"} + ] + }), + ) + .await + .unwrap(); + + // 尝试从 completed 改回 pending + let result = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"id": task_id, "content": "任务A", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("terminal state")); + } + + #[tokio::test] + async fn test_new_item_must_be_pending() { + let tool = TodoWriteTool::new(test_state()); + let context = test_context(); + + let result = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "completed"} + ] + }), + ) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("must start as 'pending'")); + } + + #[tokio::test] + async fn test_remove_items_by_omission() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 创建两个任务 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "pending"}, + {"content": "任务B", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // 只传入一个任务(任务B 被移除) + let guard = state.read().await; + let scope_key = scope_key_from_context(&context).unwrap(); + let items = guard.get(&scope_key).unwrap(); + let task_a_id = items[0].id.clone(); + drop(guard); + + let result = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"id": task_a_id, "content": "任务A", "status": "in_progress"} + ] + }), + ) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["current_todos"].as_array().unwrap().len(), 1); + assert_eq!(output["changes"]["removed"].as_array().unwrap().len(), 1); + } + + #[tokio::test] + async fn test_topic_isolation() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + + // 在主会话中创建 todo + let main_context = ToolContext { + session_id: Some("cli:chat-1".to_string()), + topic_id: None, + ..ToolContext::default() + }; + + let _ = tool + .execute_with_context( + &main_context, + json!({ + "todos": [ + {"content": "主会话任务", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // 在 topic 中创建 todo + let topic_context = ToolContext { + session_id: Some("cli:chat-1".to_string()), + topic_id: Some("topic-xyz".to_string()), + ..ToolContext::default() + }; + + let _ = tool + .execute_with_context( + &topic_context, + json!({ + "todos": [ + {"content": "话题任务", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // 验证隔离 + let guard = state.read().await; + let main_items = guard.get("cli:chat-1").unwrap(); + let topic_items = guard.get("topic-xyz").unwrap(); + + assert_eq!(main_items.len(), 1); + assert_eq!(main_items[0].content, "主会话任务"); + assert_eq!(topic_items.len(), 1); + assert_eq!(topic_items[0].content, "话题任务"); + } + + #[tokio::test] + async fn test_empty_list() { + let tool = TodoWriteTool::new(test_state()); + let context = test_context(); + + let result = tool + .execute_with_context( + &context, + json!({"todos": []}), + ) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["current_todos"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn test_missing_todos_param() { + let tool = TodoWriteTool::new(test_state()); + let context = test_context(); + + let result = tool + .execute_with_context(&context, json!({})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("Missing required parameter")); + } + + #[tokio::test] + async fn test_no_context() { + let tool = TodoWriteTool::new(test_state()); + + let result = tool.execute(json!({})).await.unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("requires tool context")); + } + + #[tokio::test] + async fn test_subagent_isolation() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + + // 父代理 + let parent_ctx = ToolContext { + session_id: Some("cli:chat-1".to_string()), + ..ToolContext::default() + }; + + let _ = tool + .execute_with_context( + &parent_ctx, + json!({ + "todos": [ + {"content": "父代理任务", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // 子代理(不同 session_id) + let child_ctx = ToolContext { + session_id: Some("sub:cli:chat-1:task:uuid-abc".to_string()), + ..ToolContext::default() + }; + + let _ = tool + .execute_with_context( + &child_ctx, + json!({ + "todos": [ + {"content": "子代理任务", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // 验证隔离 + let guard = state.read().await; + let parent_items = guard.get("cli:chat-1").unwrap(); + let child_items = guard.get("sub:cli:chat-1:task:uuid-abc").unwrap(); + + assert_eq!(parent_items.len(), 1); + assert_eq!(parent_items[0].content, "父代理任务"); + assert_eq!(child_items.len(), 1); + assert_eq!(child_items[0].content, "子代理任务"); + } + + #[tokio::test] + async fn test_in_progress_cannot_revert_to_pending() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 先创建 pending 任务 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + let guard = state.read().await; + let scope_key = scope_key_from_context(&context).unwrap(); + let items = guard.get(&scope_key).unwrap(); + let task_id = items[0].id.clone(); + drop(guard); + + // 更新为 in_progress + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"id": task_id, "content": "任务A", "status": "in_progress"} + ] + }), + ) + .await + .unwrap(); + + // 尝试从 in_progress 退回 pending + let result = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"id": task_id, "content": "任务A", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("Cannot move an in_progress task back to pending")); + } + + // ── merge 模式测试 ────────────────────────────────────── + + #[tokio::test] + async fn test_merge_mode_preserves_unreferenced_items() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 先创建三个 pending 任务 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "pending"}, + {"content": "任务B", "status": "pending"}, + {"content": "任务C", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // 获取任务 A 的 id + let guard = state.read().await; + let scope_key = scope_key_from_context(&context).unwrap(); + let items = guard.get(&scope_key).unwrap(); + let id_a = items[0].id.clone(); + drop(guard); + + // merge: true — 只传任务 A(改为 in_progress),B 和 C 应被保留 + let result = tool + .execute_with_context( + &context, + json!({ + "merge": true, + "todos": [ + {"id": id_a, "content": "任务A", "status": "in_progress"} + ] + }), + ) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + let todos = output["current_todos"].as_array().unwrap(); + // 3 项都应在 + assert_eq!(todos.len(), 3); + // 已更新的一项状态是 in_progress + let task_a = todos.iter().find(|t| t["id"] == id_a).unwrap(); + assert_eq!(task_a["status"], "in_progress"); + // diff 中 updated 1, added 0, removed 0 + assert_eq!(output["changes"]["updated"].as_array().unwrap().len(), 1); + assert_eq!(output["changes"]["added"].as_array().unwrap().len(), 0); + assert_eq!(output["changes"]["removed"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn test_merge_mode_add_new_item_without_id() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 先创建两个任务 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "已有任务", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // merge: true — 添加一个新的 pending 任务 + let result = tool + .execute_with_context( + &context, + json!({ + "merge": true, + "todos": [ + {"content": "新任务", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + let todos = output["current_todos"].as_array().unwrap(); + // 旧 + 新 = 2 + assert_eq!(todos.len(), 2); + // diff: added 1, updated 0, removed 0 + assert_eq!(output["changes"]["added"].as_array().unwrap().len(), 1); + assert_eq!(output["changes"]["updated"].as_array().unwrap().len(), 0); + assert_eq!(output["changes"]["removed"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn test_merge_mode_never_removes() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 创建两个任务 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "pending"}, + {"content": "任务B", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + // merge: true — 传入空列表,不应删除任何项 + let result = tool + .execute_with_context( + &context, + json!({ + "merge": true, + "todos": [] + }), + ) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + let todos = output["current_todos"].as_array().unwrap(); + // 2 项都在 + assert_eq!(todos.len(), 2); + // removed 始终为空 + assert_eq!(output["changes"]["removed"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn test_non_merge_still_removes_by_omission() { + let state = test_state(); + let tool = TodoWriteTool::new(state.clone()); + let context = test_context(); + + // 创建两个任务 + let _ = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"content": "任务A", "status": "pending"}, + {"content": "任务B", "status": "pending"} + ] + }), + ) + .await + .unwrap(); + + let guard = state.read().await; + let scope_key = scope_key_from_context(&context).unwrap(); + let items = guard.get(&scope_key).unwrap(); + let id_a = items[0].id.clone(); + drop(guard); + + // merge 未设置(默认 false)— 只传一个,另一个被删 + let result = tool + .execute_with_context( + &context, + json!({ + "todos": [ + {"id": id_a, "content": "任务A", "status": "in_progress"} + ] + }), + ) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + let todos = output["current_todos"].as_array().unwrap(); + assert_eq!(todos.len(), 1); + assert_eq!(output["changes"]["removed"].as_array().unwrap().len(), 1); + } +}