feat: 添加 todo_write 工具,支持全量替换和增量合并两种模式

- Tool: 纯内存实现 (Arc<RwLock<HashMap>>),零 DB 依赖,解耦持久化
- 状态机: pending → in_progress → completed/cancelled,单 in_progress 约束
- merge=false: 全量替换模式(默认)
- merge=true: 增量更新模式,只传变更的项,其余保留
- 隔离: scope_key = topic_id.unwrap_or(session_id),topic 和子代理隔离
- 持久化: TodoRepository trait + SessionStore SQLite 实现,在 Session 拦截器层完成
- 前端推送: WsOutbound::TodoList 事件
- Prompt: TodoPromptProvider 中文指令,子代理模板也包含
- 测试: 16 个单元测试,全部通过

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
oudecheng 2026-06-12 14:19:07 +08:00
parent cedd8b2a69
commit 881fcace47
15 changed files with 1508 additions and 7 deletions

View File

@ -3,6 +3,7 @@ use std::sync::Arc;
use crate::agent::{AgentError, AgentLoop, CompositeSystemPromptProvider}; use crate::agent::{AgentError, AgentLoop, CompositeSystemPromptProvider};
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::AgentPromptProvider; use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::gateway::todo_prompt_provider::TodoPromptProvider;
use crate::skills::{SkillPromptProvider, SkillRuntime}; use crate::skills::{SkillPromptProvider, SkillRuntime};
use crate::storage::persistent_session_id; use crate::storage::persistent_session_id;
use crate::storage::PromptInjectionRepository; use crate::storage::PromptInjectionRepository;
@ -53,6 +54,7 @@ impl AgentFactory {
self.prompt_repository.clone(), self.prompt_repository.clone(),
)), )),
Box::new(SkillPromptProvider::new(self.skills.clone())), Box::new(SkillPromptProvider::new(self.skills.clone())),
Box::new(TodoPromptProvider::new()),
])); ]));
AgentLoop::with_tools_and_system_prompt_provider( AgentLoop::with_tools_and_system_prompt_provider(

View File

@ -190,6 +190,14 @@ impl AgentExecutionService {
// 只有当是最新回合时才触发历史压缩 // 只有当是最新回合时才触发历史压缩
let should_schedule_compaction = is_current_turn; let should_schedule_compaction = is_current_turn;
// 拦截 todo_write 结果:持久化 + 前端推送
if is_current_turn {
session.intercept_todo_write_results(
&request.result.emitted_messages,
request.chat_id,
);
}
Ok(FinalizedAgentResult { Ok(FinalizedAgentResult {
outbound_messages, outbound_messages,
should_schedule_compaction, should_schedule_compaction,

View File

@ -25,6 +25,7 @@ pub mod session_message_service;
pub mod session_pool; pub mod session_pool;
pub mod static_files; pub mod static_files;
pub mod tool_registry_factory; pub mod tool_registry_factory;
pub mod todo_prompt_provider;
pub mod ws; pub mod ws;
use axum::{Router, routing}; use axum::{Router, routing};

View File

@ -3,6 +3,8 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock;
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::bus::MessageBus; use crate::bus::MessageBus;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, SubagentsConfig, TaskConfig}; use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, SubagentsConfig, TaskConfig};
@ -18,6 +20,7 @@ use crate::tools::{
SessionMessageSender, SubAgentRuntimeConfig, SubagentCatalog, ToolRegistry, SessionMessageSender, SubAgentRuntimeConfig, SubagentCatalog, ToolRegistry,
}; };
use crate::tools::task::repository::TaskRepository; use crate::tools::task::repository::TaskRepository;
use crate::tools::todo_write::TodoItem;
use super::agent_factory::AgentFactory; use super::agent_factory::AgentFactory;
use super::cli_session::CliSessionService; use super::cli_session::CliSessionService;
@ -117,6 +120,11 @@ pub(crate) fn build_session_manager_with_sender(
task_config.clone(), task_config.clone(),
); );
// Create shared todo state for TodoWriteTool
let todo_state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>> =
Arc::new(RwLock::new(HashMap::new()));
let factory = factory.with_todo_state(todo_state);
// Create MCP Initializer (async, non-blocking) // Create MCP Initializer (async, non-blocking)
// MCP servers connect in background task // MCP servers connect in background task
let mut mcp_initializer = McpInitializer::with_config(mcp_config); let mut mcp_initializer = McpInitializer::with_config(mcp_config);

View File

@ -385,6 +385,84 @@ impl Session {
let _ = self.user_tx.send(msg).await; let _ = self.user_tx.send(msg).await;
} }
/// 扫描 agent 结果中的 todo_write 工具消息,
/// 提取 todos 并做持久化 + 前端推送(同步版本)。
pub(crate) fn intercept_todo_write_results(
&self,
emitted_messages: &[ChatMessage],
chat_id: &str,
) {
for msg in emitted_messages {
if msg.role != "tool" {
continue;
}
if msg.tool_name.as_deref() != Some("todo_write") {
continue;
}
// 解析工具返回的 JSON
let parsed: serde_json::Value = match serde_json::from_str(&msg.content) {
Ok(v) => v,
Err(_) => continue,
};
let Some(todos_array) = parsed
.get("current_todos")
.and_then(|v| v.as_array())
else {
continue;
};
// 计算持久化所需的 key
let session_id = crate::storage::persistent_session_id(&self.channel_name, chat_id);
let topic_id = self.current_topic(chat_id);
let scope_key = topic_id.map(|t| t.to_string()).unwrap_or_else(|| session_id.clone());
// 转换为 TodoRecord 并持久化
let records: Vec<crate::storage::TodoRecord> = todos_array
.iter()
.filter_map(|item| {
Some(crate::storage::TodoRecord {
id: item.get("id")?.as_str()?.to_string(),
scope_key: scope_key.clone(),
session_id: session_id.clone(),
topic_id: topic_id.map(|t| t.to_string()),
content: item.get("content")?.as_str()?.to_string(),
status: item.get("status")?.as_str()?.to_string(),
priority: item.get("priority")?.as_str()?.to_string(),
created_at: item.get("created_at")?.as_i64()?,
updated_at: item.get("updated_at")?.as_i64()?,
})
})
.collect();
// 持久化到 SQLite
if let Err(e) = self.store.replace_todos(&scope_key, &records) {
tracing::warn!(error = %e, scope_key = %scope_key, "Failed to persist todo list");
}
// 推送到前端(使用 try_send 避免异步)
let summaries: Vec<crate::protocol::TodoItemSummary> = records
.iter()
.map(|r| crate::protocol::TodoItemSummary {
id: r.id.clone(),
content: r.content.clone(),
status: r.status.clone(),
priority: r.priority.clone(),
created_at: r.created_at,
updated_at: r.updated_at,
})
.collect();
let _ = self.user_tx.try_send(crate::protocol::WsOutbound::TodoList {
todos: summaries,
scope_key: scope_key.clone(),
});
break; // 只处理第一个成功的 todo_write
}
}
/// 获取 provider_config 引用 /// 获取 provider_config 引用
pub fn provider_config(&self) -> &LLMProviderConfig { pub fn provider_config(&self) -> &LLMProviderConfig {
&self.provider_config &self.provider_config

View File

@ -0,0 +1,62 @@
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
pub struct TodoPromptProvider;
impl TodoPromptProvider {
pub fn new() -> Self {
Self
}
}
impl SystemPromptProvider for TodoPromptProvider {
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
Some(SystemPrompt {
content: TODO_WRITE_INSTRUCTIONS.to_string(),
context: Some("todo_write".to_string()),
})
}
}
const TODO_WRITE_INSTRUCTIONS: &str = r#"
## TodoWrite
使 `todo_write`
### 使
- 3 使 todo_write
- todo
### merge
- `merge: false` todo
- `merge: true` **使 merge=true id**
###
- `pending`
- `in_progress`
- `completed`
- `cancelled`
###
1. `in_progress`
2. `in_progress`
3. `completed` `cancelled`
4. completed
5. `content`
### 使
merge
```json
{"merge": true, "todos": [{"content": "修复登录 bug", "status": "in_progress"}]}
```
```json
{"merge": true, "todos": [{"content": "补充测试", "status": "pending"}]}
```
id +
```json
{"merge": true, "todos": [{"id": "xxx", "content": "修复登录 bug", "status": "completed"}]}
```
"#;

View File

@ -1,16 +1,19 @@
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock;
use crate::config::TaskConfig; use crate::config::TaskConfig;
use crate::mcp::McpClientManager; use crate::mcp::McpClientManager;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
use crate::tools::todo_write::TodoItem;
use crate::tools::{ use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, MemorySearchTool, HttpRequestTool, MemoryManageTool, MemorySearchTool,
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool,
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
ToolRegistry, WebFetchTool, TodoWriteTool, ToolRegistry, WebFetchTool,
}; };
pub(crate) struct ToolRegistryFactory { pub(crate) struct ToolRegistryFactory {
@ -25,6 +28,7 @@ pub(crate) struct ToolRegistryFactory {
task_config: TaskConfig, task_config: TaskConfig,
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>, subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
mcp_manager: Option<Arc<McpClientManager>>, mcp_manager: Option<Arc<McpClientManager>>,
todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,
} }
impl ToolRegistryFactory { impl ToolRegistryFactory {
@ -51,9 +55,18 @@ impl ToolRegistryFactory {
task_config, task_config,
subagent_runtime: None, subagent_runtime: None,
mcp_manager: None, mcp_manager: None,
todo_state: None,
} }
} }
pub(crate) fn with_todo_state(
mut self,
state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
) -> Self {
self.todo_state = Some(state);
self
}
pub(crate) fn with_subagent_runtime( pub(crate) fn with_subagent_runtime(
mut self, mut self,
runtime: Arc<dyn SubAgentRuntime>, runtime: Arc<dyn SubAgentRuntime>,
@ -98,6 +111,11 @@ impl ToolRegistryFactory {
if self.is_enabled("memory_manage") { if self.is_enabled("memory_manage") {
registry.register(MemoryManageTool::new(self.memories.clone())); registry.register(MemoryManageTool::new(self.memories.clone()));
} }
if self.is_enabled("todo_write") {
if let Some(ref state) = self.todo_state {
registry.register(TodoWriteTool::new(state.clone()));
}
}
if self.is_enabled("session_send") { if self.is_enabled("session_send") {
registry.register(SessionSendTool::new(self.session_message_sender.clone())); registry.register(SessionSendTool::new(self.session_message_sender.clone()));
} }
@ -198,6 +216,13 @@ impl ToolRegistryFactory {
registry.register(SessionSendTool::new(self.session_message_sender.clone())); registry.register(SessionSendTool::new(self.session_message_sender.clone()));
} }
// Todo 追踪工具
if self.is_enabled("todo_write") {
if let Some(ref state) = self.todo_state {
registry.register(TodoWriteTool::new(state.clone()));
}
}
// 注册 MCP 工具(如果提供) // 注册 MCP 工具(如果提供)
if let Some(mcp_tools) = mcp_tools { if let Some(mcp_tools) = mcp_tools {
for tool in mcp_tools { for tool in mcp_tools {

View File

@ -82,6 +82,17 @@ pub struct SkillSummary {
pub source: String, pub source: String,
} }
/// Todo item 摘要(发送给前端)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TodoItemSummary {
pub id: String,
pub content: String,
pub status: String,
pub priority: String,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerJobSummary { pub struct SchedulerJobSummary {
pub id: String, pub id: String,
@ -257,6 +268,11 @@ pub enum WsOutbound {
}, },
#[serde(rename = "execution_cancelled")] #[serde(rename = "execution_cancelled")]
ExecutionCancelled { message: String }, ExecutionCancelled { message: String },
#[serde(rename = "todo_list")]
TodoList {
todos: Vec<TodoItemSummary>,
scope_key: String,
},
#[serde(rename = "pong")] #[serde(rename = "pong")]
Pong, Pong,
} }

View File

@ -14,13 +14,13 @@ pub mod records;
pub use error::StorageError; pub use error::StorageError;
pub use ports::{ pub use ports::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository, ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
SkillEventRepository, SkillEventRepository, TodoRepository,
}; };
pub use records::{ pub use records::{
allowed_namespace_names, get_namespace_description, is_valid_namespace, allowed_namespace_names, get_namespace_description, is_valid_namespace,
ALLOWED_MEMORY_NAMESPACES, GLOBAL_SCOPE_KEY, MemoryRecord, MemoryUpsert, SchedulerJobRecord, ALLOWED_MEMORY_NAMESPACES, GLOBAL_SCOPE_KEY, MemoryRecord, MemoryUpsert, SchedulerJobRecord,
SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord,
TopicRecord, TodoRecord, TopicRecord,
}; };
#[derive(Clone)] #[derive(Clone)]
@ -217,6 +217,7 @@ impl SessionStore {
ensure_messages_schema(&conn)?; ensure_messages_schema(&conn)?;
ensure_scheduler_schema(&conn)?; ensure_scheduler_schema(&conn)?;
ensure_memory_scope_key_migration(&conn)?; ensure_memory_scope_key_migration(&conn)?;
ensure_todos_schema(&conn)?;
drop(conn); drop(conn);
@ -1491,6 +1492,74 @@ impl SessionStore {
) )
.map_err(StorageError::from) .map_err(StorageError::from)
} }
pub fn replace_todos(
&self,
scope_key: &str,
items: &[TodoRecord],
) -> Result<Vec<TodoRecord>, StorageError> {
let conn = self.pool.get()?;
let now = current_timestamp();
// Delete existing todos for this scope_key
conn.execute(
"DELETE FROM todos WHERE scope_key = ?1",
params![scope_key],
)?;
// Insert new todos
for item in items {
conn.execute(
"INSERT INTO todos (id, scope_key, session_id, topic_id, content, status, priority, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
params![
item.id,
scope_key,
item.session_id,
item.topic_id,
item.content,
item.status,
item.priority,
item.created_at,
now,
],
)?;
}
drop(conn);
self.list_todos(scope_key)
}
pub fn list_todos(&self, scope_key: &str) -> Result<Vec<TodoRecord>, StorageError> {
let conn = self.pool.get()?;
let mut stmt = conn.prepare(
"SELECT id, scope_key, session_id, topic_id, content, status, priority, created_at, updated_at
FROM todos
WHERE scope_key = ?1
ORDER BY created_at ASC",
)?;
let rows = stmt.query_map(params![scope_key], |row| {
Ok(TodoRecord {
id: row.get(0)?,
scope_key: row.get(1)?,
session_id: row.get(2)?,
topic_id: row.get(3)?,
content: row.get(4)?,
status: row.get(5)?,
priority: row.get(6)?,
created_at: row.get(7)?,
updated_at: row.get(8)?,
})
})?;
let mut todos = Vec::new();
for row in rows {
todos.push(row?);
}
Ok(todos)
}
} }
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String { pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
@ -1800,6 +1869,42 @@ fn ensure_memory_scope_key_migration(conn: &Connection) -> Result<(), StorageErr
Ok(()) Ok(())
} }
fn ensure_todos_schema(conn: &Connection) -> Result<(), StorageError> {
let table_exists: bool = conn
.query_row(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='todos'",
[],
|row| row.get::<_, i64>(0),
)
.map(|count| count > 0)?;
if !table_exists {
conn.execute_batch(
"
CREATE TABLE IF NOT EXISTS todos (
id TEXT PRIMARY KEY,
scope_key TEXT NOT NULL,
session_id TEXT NOT NULL,
topic_id TEXT,
content TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
priority TEXT NOT NULL DEFAULT 'medium',
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_todos_scope
ON todos(scope_key, created_at ASC);
CREATE INDEX IF NOT EXISTS idx_todos_session
ON todos(session_id);
",
)?;
}
Ok(())
}
fn has_column( fn has_column(
conn: &Connection, conn: &Connection,
table_name: &str, table_name: &str,
@ -2009,7 +2114,7 @@ fn load_messages_after(
messages.push(row?); messages.push(row?);
} }
Ok(messages) Ok(messages)
} }
fn current_timestamp() -> i64 { fn current_timestamp() -> i64 {
std::time::SystemTime::now() std::time::SystemTime::now()

View File

@ -1,6 +1,6 @@
use super::{ use super::{
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError, SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError, TodoRecord,
}; };
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
@ -145,6 +145,18 @@ pub trait SkillEventRepository: Send + Sync + 'static {
) -> Result<Vec<SkillEventRecord>, StorageError>; ) -> Result<Vec<SkillEventRecord>, StorageError>;
} }
pub trait TodoRepository: Send + Sync + 'static {
/// Replace all todos for a scope (full replacement pattern).
fn replace_todos(
&self,
scope_key: &str,
todo_records: &[TodoRecord],
) -> Result<Vec<TodoRecord>, StorageError>;
/// Load all todos for a scope, ordered by created_at.
fn list_todos(&self, scope_key: &str) -> Result<Vec<TodoRecord>, StorageError>;
}
impl ConversationRepository for super::SessionStore { impl ConversationRepository for super::SessionStore {
fn ensure_channel_session( fn ensure_channel_session(
&self, &self,
@ -356,3 +368,17 @@ impl SkillEventRepository for super::SessionStore {
super::SessionStore::list_skill_events(self, session_id) super::SessionStore::list_skill_events(self, session_id)
} }
} }
impl TodoRepository for super::SessionStore {
fn replace_todos(
&self,
scope_key: &str,
todo_records: &[TodoRecord],
) -> Result<Vec<TodoRecord>, StorageError> {
super::SessionStore::replace_todos(self, scope_key, todo_records)
}
fn list_todos(&self, scope_key: &str) -> Result<Vec<TodoRecord>, StorageError> {
super::SessionStore::list_todos(self, scope_key)
}
}

View File

@ -35,6 +35,19 @@ pub fn allowed_namespace_names() -> Vec<&'static str> {
ALLOWED_MEMORY_NAMESPACES.iter().map(|(name, _)| *name).collect() ALLOWED_MEMORY_NAMESPACES.iter().map(|(name, _)| *name).collect()
} }
#[derive(Debug, Clone)]
pub struct TodoRecord {
pub id: String,
pub scope_key: String,
pub session_id: String,
pub topic_id: Option<String>,
pub content: String,
pub status: String,
pub priority: String,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillEventRecord { pub struct SkillEventRecord {
pub id: String, pub id: String,

View File

@ -14,6 +14,7 @@ pub mod skill_activate;
pub mod skill_manage; pub mod skill_manage;
pub mod task; pub mod task;
pub mod time; pub mod time;
pub mod todo_write;
pub mod traits; pub mod traits;
pub mod web_fetch; pub mod web_fetch;
@ -39,6 +40,7 @@ pub use task::{
SubagentCatalog, TaskError, TaskRepository, TaskTool, SubagentCatalog, TaskError, TaskRepository, TaskTool,
}; };
pub use time::TimeTool; pub use time::TimeTool;
pub use todo_write::TodoWriteTool;
pub use traits::{Tool, ToolContext, ToolResult}; pub use traits::{Tool, ToolContext, ToolResult};
pub use web_fetch::WebFetchTool; pub use web_fetch::WebFetchTool;

View File

@ -52,6 +52,8 @@ impl SubagentPromptBuilder {
2. 使\n\ 2. 使\n\
3. \n\ 3. \n\
4. \n\n\ 4. \n\n\
:\n\
使 `todo_write` in_progress3使\n\n\
: 访" : 访"
} else { } else {
&def.prompt_template &def.prompt_template

View File

@ -61,7 +61,7 @@ impl SubagentDef {
Self { Self {
name: "general".to_string(), name: "general".to_string(),
description: "通用型子代理 - 处理复杂多步骤任务".to_string(), description: "通用型子代理 - 处理复杂多步骤任务".to_string(),
prompt_template: "你是一个专注的子代理,正在执行一个独立任务。\n\n任务描述: {{description}}\n\n你应该:\n1. 专注于完成任务,不要偏离目标\n2. 使用可用的工具进行必要操作\n3. 完成后给出简洁的总结\n4. 不要尝试创建新的子代理任务\n\n注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。".to_string(), prompt_template: "你是一个专注的子代理,正在执行一个独立任务。\n\n任务描述: {{description}}\n\n你应该:\n1. 专注于完成任务,不要偏离目标\n2. 使用可用的工具进行必要操作\n3. 完成后给出简洁的总结\n4. 不要尝试创建新的子代理任务\n\n任务追踪:\n你可以使用 `todo_write` 工具追踪子任务进度。规则:同一时间只有一个 in_progress完成后再标记下一个3步以上才使用。\n\n注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。".to_string(),
body: None, body: None,
allowed_tools: None, allowed_tools: None,
max_execution_secs: None, max_execution_secs: None,

1153
src/tools/todo_write.rs Normal file

File diff suppressed because it is too large Load Diff