- 在待办相关数据结构和存储中新增 created_by_message_id 字段 - 记录待办项创建时对应的消息ID,支持追溯来源 - 在前端待办列表项增加点击事件,点击后滚动并高亮对应消息 - 在消息列表组件中实现高亮动画及自动滚动功能 - 更新相关工具、协议和数据库查询,确保新字段正确传递和存储 - 增加 CSS 动画实现待办对应消息的高亮闪烁效果 - 优化前端状态管理,支持设置与获取高亮消息ID
1144 lines
38 KiB
Rust
1144 lines
38 KiB
Rust
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
|
||
use async_trait::async_trait;
|
||
use serde::Serialize;
|
||
use serde_json::json;
|
||
use tokio::sync::RwLock;
|
||
|
||
use crate::tools::traits::{Tool, ToolContext, ToolResult};
|
||
|
||
// ── 数据模型 ──────────────────────────────────────────────
|
||
|
||
/// Lifecycle state of a todo item.
///
/// The wire/storage representation is the lowercase snake_case string
/// produced by [`TodoStatus::as_str`] and parsed by [`TodoStatus::from_str`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum TodoStatus {
    Pending,
    InProgress,
    Completed,
    Cancelled,
}

impl TodoStatus {
    /// Every variant, used to drive string parsing from the same table as
    /// `as_str` so the two directions cannot drift apart.
    const ALL: [TodoStatus; 4] = [
        TodoStatus::Pending,
        TodoStatus::InProgress,
        TodoStatus::Completed,
        TodoStatus::Cancelled,
    ];

    /// Returns the canonical string form of this status.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Pending => "pending",
            Self::InProgress => "in_progress",
            Self::Completed => "completed",
            Self::Cancelled => "cancelled",
        }
    }

    /// Parses a canonical status string.
    ///
    /// Matching is exact (case-sensitive, no trimming); any other input
    /// yields `None`.
    fn from_str(value: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .find(|candidate| candidate.as_str() == value)
            .cloned()
    }
}
|
||
|
||
/// A single todo entry as held in the in-memory store and serialized into
/// the tool output.
#[derive(Debug, Clone, Serialize)]
pub(crate) struct TodoItem {
    /// Identifier of the item; the tool matches incoming items to stored
    /// ones by this value.
    pub id: String,
    /// Task description text.
    pub content: String,
    /// Status kept as its string form; the tool writes the values produced
    /// by `TodoStatus::as_str` ("pending", "in_progress", "completed",
    /// "cancelled").
    pub status: String,
    /// Id of the message associated with this todo's creation, taken from
    /// the tool context's `message_id`; `None` when the context had none.
    pub created_by_message_id: Option<String>,
}
|
||
|
||
/// Full payload returned by the tool on success; it is serialized to
/// pretty-printed JSON and placed in `ToolResult::output`.
#[derive(Debug, Clone, Serialize)]
struct TodoWriteOutput {
    /// The complete todo list for the scope after this call was applied.
    current_todos: Vec<TodoItem>,
    /// Short human-readable summary built by `build_message`.
    message: String,
}
|
||
|
||
// ── 工具实现 ──────────────────────────────────────────────
|
||
|
||
/// Tool that creates and updates the structured todo list of a conversation.
pub struct TodoWriteTool {
    /// In-memory state: scope key → list of todo items.
    /// The scope key is whatever `scope_key_from_context` returns for the
    /// calling context (topic/session id for the main agent, task id for
    /// nested agents).
    state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>,
}
|
||
|
||
impl TodoWriteTool {
|
||
pub(crate) fn new(state: Arc<RwLock<HashMap<String, Vec<TodoItem>>>>) -> Self {
|
||
Self { state }
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl Tool for TodoWriteTool {
|
||
fn name(&self) -> &str {
|
||
"todo_write"
|
||
}
|
||
|
||
fn description(&self) -> &str {
|
||
"Manage a structured task list for tracking work within the current conversation. \
|
||
Two modes: merge=false (default, full replacement — omitted items are removed); \
|
||
merge=true (incremental — only send the items you want to add/update, \
|
||
previously existing items are preserved). \
|
||
Use when you have 3+ distinct steps to track. \
|
||
Rules: only ONE in_progress at a time, complete work before marking completed, \
|
||
every item requires an id (generate a short random string for new items)."
|
||
}
|
||
|
||
fn parameters_schema(&self) -> serde_json::Value {
|
||
json!({
|
||
"type": "object",
|
||
"properties": {
|
||
"merge": {
|
||
"type": "boolean",
|
||
"description": "false (default): full replacement — todos not in the list are removed. true: incremental — only send items you want to add or update, existing items not mentioned are preserved."
|
||
},
|
||
"todos": {
|
||
"type": "array",
|
||
"description": "The todo items to add or update. In merge=false mode, this is the complete replacement list. In merge=true mode, only send the items that changed — unreferenced items are kept as-is.",
|
||
"items": {
|
||
"type": "object",
|
||
"properties": {
|
||
"id": {
|
||
"type": "string",
|
||
"description": "Unique identifier for the todo item. Generate a short random string (e.g. 'r9Tg8Kq2pLm7') for new items. Use the existing id to update."
|
||
},
|
||
"content": {
|
||
"type": "string",
|
||
"description": "Brief, actionable description of the task"
|
||
},
|
||
"status": {
|
||
"type": "string",
|
||
"enum": ["pending", "in_progress", "completed", "cancelled"],
|
||
"description": "Current status: pending=not started, in_progress=working on (only ONE at a time), completed=done, cancelled=no longer needed"
|
||
}
|
||
},
|
||
"required": ["id", "content", "status"]
|
||
}
|
||
}
|
||
},
|
||
"required": ["todos"]
|
||
})
|
||
}
|
||
|
||
fn read_only(&self) -> bool {
|
||
false
|
||
}
|
||
|
||
fn concurrency_safe(&self) -> bool {
|
||
false
|
||
}
|
||
|
||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||
Ok(error_result("todo_write requires tool context (session_id)"))
|
||
}
|
||
|
||
async fn execute_with_context(
|
||
&self,
|
||
context: &ToolContext,
|
||
args: serde_json::Value,
|
||
) -> anyhow::Result<ToolResult> {
|
||
// 1. 计算 scope_key
|
||
let scope_key = match scope_key_from_context(context) {
|
||
Some(key) => key,
|
||
None => return Ok(error_result("todo_write requires session_id or topic_id in tool context")),
|
||
};
|
||
|
||
// 2. 提取当前消息 ID(用于记录待办的创建来源)
|
||
let message_id = context.message_id.clone();
|
||
|
||
// 3. 解析入参
|
||
let todos_array = match args.get("todos").and_then(|v| v.as_array()) {
|
||
Some(arr) => arr,
|
||
None => return Ok(error_result("Missing required parameter: todos (must be an array)")),
|
||
};
|
||
|
||
let merge_mode = args
|
||
.get("merge")
|
||
.and_then(|v| v.as_bool())
|
||
.unwrap_or(false);
|
||
|
||
// 3. 读锁获取旧状态
|
||
let old_items = {
|
||
let guard = self.state.read().await;
|
||
guard.get(&scope_key).cloned().unwrap_or_default()
|
||
};
|
||
|
||
// 构建 id → TodoItem 的旧状态映射
|
||
let old_map: HashMap<&str, &TodoItem> = old_items.iter().map(|item| (item.id.as_str(), item)).collect();
|
||
|
||
// 4. 解析并校验每个输入项
|
||
let mut processed_items: Vec<TodoItem> = Vec::new();
|
||
let mut validation_errors: Vec<String> = Vec::new();
|
||
|
||
for (idx, input) in todos_array.iter().enumerate() {
|
||
let id = match input.get("id").and_then(|v| v.as_str()) {
|
||
Some(s) if !s.trim().is_empty() => s.trim().to_string(),
|
||
_ => {
|
||
validation_errors.push(format!("Item {}: missing or empty 'id'", idx));
|
||
continue;
|
||
}
|
||
};
|
||
|
||
let content = match input.get("content").and_then(|v| v.as_str()) {
|
||
Some(s) if !s.trim().is_empty() => s.trim().to_string(),
|
||
_ => {
|
||
validation_errors.push(format!("Item {}: missing or empty 'content'", idx));
|
||
continue;
|
||
}
|
||
};
|
||
|
||
let status_str = input
|
||
.get("status")
|
||
.and_then(|v| v.as_str())
|
||
.unwrap_or("pending");
|
||
|
||
let new_status = match TodoStatus::from_str(status_str) {
|
||
Some(s) => s,
|
||
None => {
|
||
validation_errors.push(format!(
|
||
"Item '{}': invalid status '{}'. Valid: pending, in_progress, completed, cancelled",
|
||
content, status_str
|
||
));
|
||
continue;
|
||
}
|
||
};
|
||
|
||
if let Some(old_item) = old_map.get(id.as_str()) {
|
||
// id 匹配旧项 → 更新,校验状态转换
|
||
let old_status = match TodoStatus::from_str(&old_item.status) {
|
||
Some(s) => s,
|
||
None => {
|
||
validation_errors.push(format!("Item '{}': corrupted old status", content));
|
||
continue;
|
||
}
|
||
};
|
||
|
||
if let Err(err) = validate_transition(&old_status, &new_status) {
|
||
validation_errors.push(format!("Item '{}': {}", content, err));
|
||
continue;
|
||
}
|
||
|
||
processed_items.push(TodoItem {
|
||
id,
|
||
content,
|
||
status: new_status.as_str().to_string(),
|
||
created_by_message_id: message_id.clone(),
|
||
});
|
||
} else if merge_mode {
|
||
// merge 模式:id 不匹配,尝试 content fallback
|
||
if let Some(old_item) = old_items.iter().find(|oi| oi.content == content) {
|
||
let old_status = match TodoStatus::from_str(&old_item.status) {
|
||
Some(s) => s,
|
||
None => {
|
||
validation_errors.push(format!("Item '{}': corrupted old status", content));
|
||
continue;
|
||
}
|
||
};
|
||
if let Err(err) = validate_transition(&old_status, &new_status) {
|
||
validation_errors.push(format!("Item '{}': {}", content, err));
|
||
continue;
|
||
}
|
||
processed_items.push(TodoItem {
|
||
id: old_item.id.clone(),
|
||
content,
|
||
status: new_status.as_str().to_string(),
|
||
created_by_message_id: message_id.clone(),
|
||
});
|
||
} else {
|
||
// 全新项
|
||
processed_items.push(TodoItem {
|
||
id,
|
||
content,
|
||
status: new_status.as_str().to_string(),
|
||
created_by_message_id: message_id.clone(),
|
||
});
|
||
}
|
||
} else {
|
||
// 全量替换模式:id 不匹配 → 全新项
|
||
processed_items.push(TodoItem {
|
||
id,
|
||
content,
|
||
status: new_status.as_str().to_string(),
|
||
created_by_message_id: message_id.clone(),
|
||
});
|
||
}
|
||
}
|
||
|
||
if !validation_errors.is_empty() {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some(validation_errors.join("\n")),
|
||
});
|
||
}
|
||
|
||
// 5. 合并模式:将旧列表中未被引用的项保留
|
||
let processed_ids: std::collections::HashSet<&str> =
|
||
processed_items.iter().map(|item| item.id.as_str()).collect();
|
||
|
||
let final_items: Vec<TodoItem> = if merge_mode {
|
||
let mut merged = processed_items.clone();
|
||
for old in &old_items {
|
||
if !processed_ids.contains(old.id.as_str()) {
|
||
merged.push(old.clone());
|
||
}
|
||
}
|
||
merged
|
||
} else {
|
||
processed_items
|
||
};
|
||
|
||
// 6. 全局约束:只有一个 in_progress
|
||
let in_progress_count = final_items
|
||
.iter()
|
||
.filter(|item| item.status == "in_progress")
|
||
.count();
|
||
if in_progress_count > 1 {
|
||
return Ok(error_result(&format!(
|
||
"Only one task can be 'in_progress' at a time. Found {} in_progress tasks.",
|
||
in_progress_count
|
||
)));
|
||
}
|
||
|
||
// 7. 计算 removed 数量(仅全量替换模式)
|
||
let final_ids: std::collections::HashSet<&str> = final_items.iter().map(|item| item.id.as_str()).collect();
|
||
let removed_count = if merge_mode {
|
||
0
|
||
} else {
|
||
old_items
|
||
.iter()
|
||
.filter(|item| !final_ids.contains(item.id.as_str()))
|
||
.count()
|
||
};
|
||
|
||
// 8. 更新内存状态
|
||
{
|
||
let mut guard = self.state.write().await;
|
||
guard.insert(scope_key.clone(), final_items.clone());
|
||
}
|
||
|
||
// 9. 生成友好消息
|
||
let message = build_message(final_items.len(), removed_count, merge_mode);
|
||
|
||
let output = TodoWriteOutput {
|
||
current_todos: final_items,
|
||
message,
|
||
};
|
||
|
||
Ok(ToolResult {
|
||
success: true,
|
||
output: serde_json::to_string_pretty(&output)?,
|
||
error: None,
|
||
})
|
||
}
|
||
}
|
||
|
||
// ── 辅助函数 ──────────────────────────────────────────────
|
||
|
||
/// 计算 scope_key:
|
||
/// - 主代理 (nesting_depth == 0):优先 topic_id,否则 session_id
|
||
/// - 子/孙代理 (nesting_depth > 0):使用 task_id 隔离(全局唯一,与 list_todos 保持一致)
|
||
pub(crate) fn scope_key_from_context(context: &ToolContext) -> Option<String> {
|
||
if context.nesting_depth > 0 {
|
||
// 使用 task_id 而不是 session_id 作为 scope_key。
|
||
// session_id 对于孙智能体包含父链(如 sub:sub:root:parent:task),
|
||
// 而 list_todos handler 用根 session + task_id 拼接,两者不匹配。
|
||
// task_id 是全局唯一的 UUID(task:xxx),直接使用可避免层级不一致。
|
||
context.task_id.clone().filter(|s| !s.is_empty())
|
||
} else {
|
||
let tid = context.topic_id.as_deref().filter(|t| !t.is_empty());
|
||
let sid = context.session_id.as_deref().filter(|s| !s.is_empty());
|
||
tid.or(sid).map(str::to_string)
|
||
}
|
||
}
|
||
|
||
/// 校验状态转换合法性
|
||
fn validate_transition(old: &TodoStatus, new: &TodoStatus) -> Result<(), String> {
|
||
match (old, new) {
|
||
// pending → anything is allowed
|
||
(TodoStatus::Pending, _) => Ok(()),
|
||
|
||
// in_progress → completed, cancelled, or same
|
||
(TodoStatus::InProgress, TodoStatus::Completed) => Ok(()),
|
||
(TodoStatus::InProgress, TodoStatus::Cancelled) => Ok(()),
|
||
(TodoStatus::InProgress, TodoStatus::InProgress) => Ok(()),
|
||
(TodoStatus::InProgress, TodoStatus::Pending) => Err(
|
||
"Cannot move an in_progress task back to pending. Use completed or cancelled.".to_string(),
|
||
),
|
||
|
||
// completed → can reactivate to in_progress or pending
|
||
(TodoStatus::Completed, TodoStatus::InProgress) => Ok(()),
|
||
(TodoStatus::Completed, TodoStatus::Pending) => Ok(()),
|
||
(TodoStatus::Completed, TodoStatus::Completed) => Ok(()),
|
||
(TodoStatus::Completed, TodoStatus::Cancelled) => Err(
|
||
"Cannot cancel a completed task. Move it to pending first if needed.".to_string(),
|
||
),
|
||
|
||
// cancelled → can reactivate to pending or in_progress
|
||
(TodoStatus::Cancelled, TodoStatus::Pending) => Ok(()),
|
||
(TodoStatus::Cancelled, TodoStatus::InProgress) => Ok(()),
|
||
(TodoStatus::Cancelled, TodoStatus::Cancelled) => Ok(()),
|
||
(TodoStatus::Cancelled, TodoStatus::Completed) => Err(
|
||
"Cannot complete a cancelled task. Move it to pending first if needed.".to_string(),
|
||
),
|
||
}
|
||
}
|
||
|
||
/// Builds the one-line summary returned with a successful write.
///
/// - merge mode: reports only the total (`removed` is ignored);
/// - replacement that dropped items: reports the total and the removed count;
/// - replacement that dropped nothing: reports only the total.
fn build_message(total: usize, removed: usize, merge_mode: bool) -> String {
    match (merge_mode, removed) {
        (true, _) => format!("Todo list updated: {} items total", total),
        (false, 0) => format!("Todo list set: {} items", total),
        (false, dropped) => format!("Replaced todo list: {} items ({} removed)", total, dropped),
    }
}
|
||
|
||
fn error_result(message: &str) -> ToolResult {
|
||
ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some(message.to_string()),
|
||
}
|
||
}
|
||
|
||
// ── 测试 ──────────────────────────────────────────────────
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::traits::ToolContext;

    /// Main-agent context (depth 0, no topic) scoped to session `cli:chat-1`.
    fn test_context() -> ToolContext {
        ToolContext {
            channel_name: Some("cli".to_string()),
            sender_id: Some("user-1".to_string()),
            chat_id: Some("chat-1".to_string()),
            session_id: Some("cli:chat-1".to_string()),
            topic_id: None,
            message_id: Some("msg-1".to_string()),
            message_seq: Some(1),
            subagent_description: None,
            nesting_depth: 0,
            task_id: None,
            parent_task_id: None,
            tool_call_id: None,
        }
    }

    /// Fresh, empty todo store.
    fn test_state() -> Arc<RwLock<HashMap<String, Vec<TodoItem>>>> {
        Arc::new(RwLock::new(HashMap::new()))
    }

    /// Runs `todo_write` with the given context and returns the raw result.
    async fn run(tool: &TodoWriteTool, context: &ToolContext, args: serde_json::Value) -> ToolResult {
        tool.execute_with_context(context, args).await.unwrap()
    }

    /// Parses the JSON payload of a tool result.
    fn parse_output(result: &ToolResult) -> serde_json::Value {
        serde_json::from_str(&result.output).unwrap()
    }

    /// Runs `todo_write`, asserts that it succeeded, and returns the parsed
    /// output. Used for setup steps too, so a broken precondition fails the
    /// test at the step that broke instead of at a later assertion.
    async fn run_ok(
        tool: &TodoWriteTool,
        context: &ToolContext,
        args: serde_json::Value,
    ) -> serde_json::Value {
        let result = run(tool, context, args).await;
        assert!(result.success, "todo_write unexpectedly failed: {:?}", result.error);
        parse_output(&result)
    }

    #[tokio::test]
    async fn test_create_initial_todos() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        let output = run_ok(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "a1", "content": "设计数据库", "status": "pending"},
                    {"id": "a2", "content": "实现 API", "status": "pending"},
                    {"id": "a3", "content": "写测试", "status": "pending"}
                ]
            }),
        )
        .await;

        assert_eq!(output["current_todos"].as_array().unwrap().len(), 3);
        // The output must not contain a `changes` field.
        assert!(output.get("changes").is_none());
    }

    #[tokio::test]
    async fn test_single_in_progress_constraint() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "b1", "content": "任务A", "status": "pending"},
                    {"id": "b2", "content": "任务B", "status": "pending"}
                ]
            }),
        )
        .await;

        // Try to set both items to in_progress.
        let result = run(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "b1", "content": "任务A", "status": "in_progress"},
                    {"id": "b2", "content": "任务B", "status": "in_progress"}
                ]
            }),
        )
        .await;

        assert!(!result.success);
        assert!(result.error.unwrap().contains("Only one task can be 'in_progress'"));
    }

    #[tokio::test]
    async fn test_state_transition_in_progress_to_completed() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "c1", "content": "任务A", "status": "pending"}]}),
        )
        .await;

        // pending → in_progress
        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "c1", "content": "任务A", "status": "in_progress"}]}),
        )
        .await;

        // in_progress → completed
        let output = run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "c1", "content": "任务A", "status": "completed"}]}),
        )
        .await;

        let todos = output["current_todos"].as_array().unwrap();
        assert_eq!(todos[0]["status"], "completed");
    }

    #[tokio::test]
    async fn test_completed_can_revert_to_in_progress() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        // Create a task and complete it.
        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "d1", "content": "任务A", "status": "pending"}]}),
        )
        .await;
        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "d1", "content": "任务A", "status": "completed"}]}),
        )
        .await;

        // completed → in_progress (rework)
        let output = run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "d1", "content": "任务A", "status": "in_progress"}]}),
        )
        .await;

        assert_eq!(output["current_todos"][0]["status"], "in_progress");
    }

    #[tokio::test]
    async fn test_cancelled_can_revert_to_pending() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "e1", "content": "任务A", "status": "pending"}]}),
        )
        .await;
        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "e1", "content": "任务A", "status": "cancelled"}]}),
        )
        .await;

        // cancelled → pending (restore)
        let output = run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "e1", "content": "任务A", "status": "pending"}]}),
        )
        .await;

        assert_eq!(output["current_todos"][0]["status"], "pending");
    }

    #[tokio::test]
    async fn test_in_progress_cannot_revert_to_pending() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "f1", "content": "任务A", "status": "pending"}]}),
        )
        .await;
        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "f1", "content": "任务A", "status": "in_progress"}]}),
        )
        .await;

        // in_progress → pending (forbidden)
        let result = run(
            &tool,
            &context,
            json!({"todos": [{"id": "f1", "content": "任务A", "status": "pending"}]}),
        )
        .await;

        assert!(!result.success);
        assert!(result.error.unwrap().contains("Cannot move an in_progress task back to pending"));
    }

    #[tokio::test]
    async fn test_new_item_can_be_any_status() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        // A new item created directly as completed is allowed (the initial
        // status is no longer restricted now that ids are mandatory).
        let result = run(
            &tool,
            &context,
            json!({"todos": [{"id": "g1", "content": "任务A", "status": "completed"}]}),
        )
        .await;

        assert!(result.success);
    }

    #[tokio::test]
    async fn test_new_item_can_start_as_in_progress() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        let output = run_ok(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "h1", "content": "第一个任务", "status": "in_progress"},
                    {"id": "h2", "content": "第二个任务", "status": "pending"},
                    {"id": "h3", "content": "第三个任务", "status": "pending"}
                ]
            }),
        )
        .await;

        let todos = output["current_todos"].as_array().unwrap();
        assert_eq!(todos.len(), 3);
        assert_eq!(todos[0]["status"], "in_progress");
    }

    #[tokio::test]
    async fn test_remove_items_by_omission() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "i1", "content": "任务A", "status": "pending"},
                    {"id": "i2", "content": "任务B", "status": "pending"}
                ]
            }),
        )
        .await;

        // Send only one task (task B gets removed).
        let output = run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "i1", "content": "任务A", "status": "in_progress"}]}),
        )
        .await;

        assert_eq!(output["current_todos"].as_array().unwrap().len(), 1);
        // The message must include the removed count.
        assert!(output["message"].as_str().unwrap().contains("1 removed"));
    }

    #[tokio::test]
    async fn test_topic_isolation() {
        let state = test_state();
        let tool = TodoWriteTool::new(state.clone());

        let main_context = ToolContext {
            session_id: Some("cli:chat-1".to_string()),
            topic_id: None,
            ..ToolContext::default()
        };
        run_ok(
            &tool,
            &main_context,
            json!({"todos": [{"id": "j1", "content": "主会话任务", "status": "pending"}]}),
        )
        .await;

        let topic_context = ToolContext {
            session_id: Some("cli:chat-1".to_string()),
            topic_id: Some("topic-xyz".to_string()),
            ..ToolContext::default()
        };
        run_ok(
            &tool,
            &topic_context,
            json!({"todos": [{"id": "j2", "content": "话题任务", "status": "pending"}]}),
        )
        .await;

        let guard = state.read().await;
        let main_items = guard.get("cli:chat-1").unwrap();
        let topic_items = guard.get("topic-xyz").unwrap();

        assert_eq!(main_items.len(), 1);
        assert_eq!(main_items[0].content, "主会话任务");
        assert_eq!(topic_items.len(), 1);
        assert_eq!(topic_items[0].content, "话题任务");
    }

    #[tokio::test]
    async fn test_empty_list() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        let output = run_ok(&tool, &context, json!({"todos": []})).await;

        assert_eq!(output["current_todos"].as_array().unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_missing_todos_param() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        let result = run(&tool, &context, json!({})).await;

        assert!(!result.success);
        assert!(result.error.unwrap().contains("Missing required parameter"));
    }

    #[tokio::test]
    async fn test_no_context() {
        let tool = TodoWriteTool::new(test_state());

        let result = tool.execute(json!({})).await.unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("requires tool context"));
    }

    #[tokio::test]
    async fn test_subagent_isolation() {
        let state = test_state();
        let tool = TodoWriteTool::new(state.clone());

        let parent_ctx = ToolContext {
            session_id: Some("cli:chat-1".to_string()),
            ..ToolContext::default()
        };
        run_ok(
            &tool,
            &parent_ctx,
            json!({"todos": [{"id": "k1", "content": "父代理任务", "status": "pending"}]}),
        )
        .await;

        // A real sub-agent context has nesting_depth > 0, which makes
        // scope_key_from_context key the list by task_id rather than by
        // session_id. Without these two fields the test would only prove
        // that two different session ids do not collide.
        let child_ctx = ToolContext {
            session_id: Some("sub:cli:chat-1:task:uuid-abc".to_string()),
            nesting_depth: 1,
            task_id: Some("task:uuid-abc".to_string()),
            ..ToolContext::default()
        };
        run_ok(
            &tool,
            &child_ctx,
            json!({"todos": [{"id": "k2", "content": "子代理任务", "status": "pending"}]}),
        )
        .await;

        let guard = state.read().await;
        let parent_items = guard.get("cli:chat-1").unwrap();
        let child_items = guard.get("task:uuid-abc").unwrap();

        assert_eq!(parent_items.len(), 1);
        assert_eq!(parent_items[0].content, "父代理任务");
        assert_eq!(child_items.len(), 1);
        assert_eq!(child_items[0].content, "子代理任务");
        // The sub-agent's session id must not be used as a key.
        assert!(guard.get("sub:cli:chat-1:task:uuid-abc").is_none());
    }

    // ── merge-mode tests ──────────────────────────────────────

    #[tokio::test]
    async fn test_merge_mode_preserves_unreferenced_items() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "m1", "content": "任务A", "status": "pending"},
                    {"id": "m2", "content": "任务B", "status": "pending"},
                    {"id": "m3", "content": "任务C", "status": "pending"}
                ]
            }),
        )
        .await;

        // merge: true — only task A is sent (now in_progress); B and C must be kept.
        let output = run_ok(
            &tool,
            &context,
            json!({
                "merge": true,
                "todos": [
                    {"id": "m1", "content": "任务A", "status": "in_progress"}
                ]
            }),
        )
        .await;

        let todos = output["current_todos"].as_array().unwrap();
        assert_eq!(todos.len(), 3);
        let task_a = todos.iter().find(|t| t["id"] == "m1").unwrap();
        assert_eq!(task_a["status"], "in_progress");
    }

    #[tokio::test]
    async fn test_merge_mode_add_new_item() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "n1", "content": "已有任务", "status": "pending"}]}),
        )
        .await;

        // merge: true — add one new pending task.
        let output = run_ok(
            &tool,
            &context,
            json!({
                "merge": true,
                "todos": [
                    {"id": "n2", "content": "新任务", "status": "pending"}
                ]
            }),
        )
        .await;

        let todos = output["current_todos"].as_array().unwrap();
        assert_eq!(todos.len(), 2);
    }

    #[tokio::test]
    async fn test_merge_mode_never_removes() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "o1", "content": "任务A", "status": "pending"},
                    {"id": "o2", "content": "任务B", "status": "pending"}
                ]
            }),
        )
        .await;

        // merge: true with an empty list must not delete anything.
        let output = run_ok(
            &tool,
            &context,
            json!({
                "merge": true,
                "todos": []
            }),
        )
        .await;

        let todos = output["current_todos"].as_array().unwrap();
        assert_eq!(todos.len(), 2);
    }

    #[tokio::test]
    async fn test_non_merge_still_removes_by_omission() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "p1", "content": "任务A", "status": "pending"},
                    {"id": "p2", "content": "任务B", "status": "pending"}
                ]
            }),
        )
        .await;

        // merge=false (default) — only one item is sent, the other is dropped.
        let output = run_ok(
            &tool,
            &context,
            json!({"todos": [{"id": "p1", "content": "任务A", "status": "in_progress"}]}),
        )
        .await;

        let todos = output["current_todos"].as_array().unwrap();
        assert_eq!(todos.len(), 1);
    }

    #[tokio::test]
    async fn test_merge_match_by_content_fallback() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        run_ok(
            &tool,
            &context,
            json!({
                "todos": [
                    {"id": "q1", "content": "任务1", "status": "pending"},
                    {"id": "q2", "content": "任务2", "status": "pending"},
                    {"id": "q3", "content": "任务3", "status": "pending"}
                ]
            }),
        )
        .await;

        // merge: true — different ids but identical content must be matched
        // through the content fallback.
        let output = run_ok(
            &tool,
            &context,
            json!({
                "merge": true,
                "todos": [
                    {"id": "wrong-id-1", "content": "任务1", "status": "completed"},
                    {"id": "wrong-id-2", "content": "任务2", "status": "completed"},
                    {"id": "wrong-id-3", "content": "任务3", "status": "cancelled"}
                ]
            }),
        )
        .await;

        let todos = output["current_todos"].as_array().unwrap();
        assert_eq!(todos.len(), 3);
        // After a content-fallback match the stored id is the one kept.
        assert_eq!(todos.iter().find(|t| t["content"] == "任务1").unwrap()["id"], "q1");
        assert_eq!(todos.iter().find(|t| t["content"] == "任务1").unwrap()["status"], "completed");
        assert_eq!(todos.iter().find(|t| t["content"] == "任务2").unwrap()["status"], "completed");
        assert_eq!(todos.iter().find(|t| t["content"] == "任务3").unwrap()["status"], "cancelled");
    }

    #[tokio::test]
    async fn test_missing_id_validation_error() {
        let tool = TodoWriteTool::new(test_state());
        let context = test_context();

        let result = run(
            &tool,
            &context,
            json!({"todos": [{"content": "缺少 id 的任务", "status": "pending"}]}),
        )
        .await;

        assert!(!result.success);
        assert!(result.error.unwrap().contains("missing or empty 'id'"));
    }
}
|