PicoBot/src/gateway/execution.rs
oudecheng 881fcace47 feat: 添加 todo_write 工具,支持全量替换和增量合并两种模式
- Tool: 纯内存实现 (Arc<RwLock<HashMap>>),零 DB 依赖,解耦持久化
- 状态机: pending → in_progress → completed/cancelled,单 in_progress 约束
- merge=false: 全量替换模式(默认)
- merge=true: 增量更新模式,只传变更的项,其余保留
- 隔离: scope_key = topic_id.unwrap_or(session_id),topic 和子代理隔离
- 持久化: TodoRepository trait + SessionStore SQLite 实现,在 Session 拦截器层完成
- 前端推送: WsOutbound::TodoList 事件
- Prompt: TodoPromptProvider 中文指令,子代理模板也包含
- 测试: 16 个单元测试,全部通过

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-12 14:19:07 +08:00

491 lines
19 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use crate::agent::{AgentError, AgentProcessResult, EmittedMessageHandler, PersistingEmittedMessageHandler, SystemPromptContext};
use crate::bus::message::ToolMessageState;
use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDULED_PROMPT};
use crate::config::LLMProviderConfig;
use crate::storage::ConversationRepository;
use tokio::sync::Mutex;
use super::compaction::schedule_background_history_compaction;
use super::message_prepare::enrich_user_content_with_media_refs;
use super::session::Session;
/// No-op `EmittedMessageHandler`: it forwards nothing. In this file it is only used as the inner
/// handler wrapped by `PersistingEmittedMessageHandler` for scheduled task runs.
struct NoOpEmittedMessageHandler;
#[async_trait]
impl EmittedMessageHandler for NoOpEmittedMessageHandler {
// Intentionally discards the emitted message.
async fn handle(&self, _message: ChatMessage) {}
}
/// Base system prompt injected for every scheduled-task run. It tells the model that the input
/// comes from an already-triggered scheduled task and that the task content itself should be
/// executed rather than treated as a request to manage scheduled tasks.
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器否则不要调用任何定时任务管理工具像“每小时”、“每天”、“cron”、“定时”等词只应视为任务背景不应再解释为新的建任务请求。";

/// Builds the system prompt for a scheduled-task run.
///
/// Returns the base prompt alone when `system_prompt` is `None` or blank after trimming;
/// otherwise appends the trimmed task-specific prompt after a blank line and a fixed label.
pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
    // A missing override is treated exactly like a blank one.
    let task_specific = system_prompt.map_or("", str::trim);
    if task_specific.is_empty() {
        return SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string();
    }
    format!(
        "{}\n\n任务专属要求:{}",
        SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, task_specific
    )
}
/// Runs agent executions against a session and turns their results into outbound messages.
pub(crate) struct AgentExecutionService {
/// Passed to `should_display_message_to_user`: when true, tool-role messages are also shown.
show_tool_results: bool,
}
/// Input for `AgentExecutionService::finalize_result`.
pub(crate) struct FinalizeAgentResultRequest<'a> {
/// Channel name used for logging and for building outbound messages.
pub(crate) channel_name: &'a str,
/// Chat whose history receives the result.
pub(crate) chat_id: &'a str,
/// The user message that started the run; used to check whether the turn is still current.
pub(crate) user_message: &'a ChatMessage,
/// Result of the agent run; its `emitted_messages` are persisted and possibly sent out.
pub(crate) result: AgentProcessResult,
/// Metadata forwarded to `OutboundMessage::from_chat_message`.
pub(crate) metadata: &'a HashMap<String, String>,
/// When true, no outbound messages are produced by finalization.
pub(crate) suppress_live_tool_calls: bool,
/// Label recorded in log events (callers in this file pass "message" or "scheduled_task").
pub(crate) execution_kind: &'a str,
/// Topic that was active when the run started; used as the save target for stale results.
pub(crate) original_topic_id: Option<String>,
}
/// Output of `AgentExecutionService::finalize_result`.
pub(crate) struct FinalizedAgentResult {
/// Messages to deliver to the user; empty when the result is stale or delivery is suppressed.
pub(crate) outbound_messages: Vec<OutboundMessage>,
/// True only when the result belongs to the current user turn.
pub(crate) should_schedule_compaction: bool,
}
/// Input for `AgentExecutionService::prepare_and_execute_message`.
pub(crate) struct MessageExecutionRequest<'a> {
/// Shared session; locked while the run is prepared and again while it is finalized.
pub(crate) session: Arc<Mutex<Session>>,
/// Channel name; combined with `chat_id` to form the session id of the prompt context.
pub(crate) channel_name: &'a str,
/// Sender id handed to `Session::create_agent`.
pub(crate) sender_id: &'a str,
/// Chat the message belongs to.
pub(crate) chat_id: &'a str,
/// Raw user text; media references are merged into it before it is stored.
pub(crate) content: &'a str,
/// Attached media; each item's `path` is used as a media reference.
pub(crate) media: Vec<MediaItem>,
/// Optional handler attached to the agent; when present, post-run outbound delivery is
/// suppressed.
pub(crate) live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
}
/// Input for `AgentExecutionService::prepare_and_execute_scheduled_task`.
pub(crate) struct ScheduledExecutionRequest<'a> {
/// Shared session; locked while the run is prepared, finalized and cleaned up.
pub(crate) session: Arc<Mutex<Session>>,
/// Channel name; used for the persistent session id and the prompt context.
pub(crate) channel_name: &'a str,
/// Chat whose history the scheduled run uses.
pub(crate) chat_id: &'a str,
/// Forwarded to `Session::create_agent_with_provider_config`; the call site describes it as
/// the real chat id.
pub(crate) notification_chat_id: Option<&'a str>,
/// Task content; stored as the user message of the run.
pub(crate) prompt: &'a str,
/// Sender id handed to the agent factory.
pub(crate) sender_id: &'a str,
/// Provider configuration used to create the agent for this run.
pub(crate) provider_config: LLMProviderConfig,
/// Optional task-specific prompt appended to the scheduled-task system prompt.
pub(crate) system_prompt: Option<&'a str>,
/// Metadata forwarded to the outbound messages.
pub(crate) metadata: &'a HashMap<String, String>,
/// When true, the chat history is cleared before the run starts.
pub(crate) fresh_session: bool,
}
impl AgentExecutionService {
/// Creates the service. `show_tool_results` is later passed to
/// `should_display_message_to_user` when outbound messages are selected.
pub(crate) fn new(show_tool_results: bool) -> Self {
Self { show_tool_results }
}
pub(crate) fn finalize_result(
&self,
session: &mut Session,
request: FinalizeAgentResultRequest<'_>,
) -> Result<FinalizedAgentResult, AgentError> {
// 检查是否是最新的用户回合
let is_current_turn =
session.matches_current_user_turn(request.chat_id, request.user_message);
if !is_current_turn {
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
session.stale_result_diagnostics(request.chat_id);
tracing::info!(
channel = %request.channel_name,
chat_id = %request.chat_id,
user_message_id = %request.user_message.id,
latest_user_id,
latest_user_preview,
compression_in_flight,
history_len,
execution_kind = %request.execution_kind,
original_topic_id = ?request.original_topic_id,
"User switched topic during agent execution - saving result to original topic"
);
}
// 确定保存消息的话题 ID
// 如果是最新回合,使用当前话题;否则使用原始话题
let target_topic_id = if is_current_turn {
session.current_topic(request.chat_id)
} else {
request.original_topic_id.as_deref()
};
// 将结果消息保存到确定的话题
if let Some(topic_id) = target_topic_id {
if is_current_turn {
// 如果是最新回合,使用 append_persisted_messages 保存到数据库并更新内存历史
if let Err(err) = session.append_persisted_messages(
request.chat_id,
request.result.emitted_messages.clone(),
) {
tracing::error!(
error = %err,
chat_id = %request.chat_id,
"Failed to append messages to session history"
);
}
} else {
// 如果用户已切换话题,只保存到原始话题(不更新内存历史)
if let Err(err) = session.append_messages_to_topic(
request.chat_id,
topic_id,
&request.result.emitted_messages,
) {
tracing::error!(
error = %err,
topic_id = %topic_id,
"Failed to append messages to topic"
);
}
}
} else if is_current_turn {
// 如果没有话题直接更新内存历史append_persisted_messages 会处理持久化)
if let Err(err) = session.append_persisted_messages(
request.chat_id,
request.result.emitted_messages.clone(),
) {
tracing::error!(
error = %err,
chat_id = %request.chat_id,
"Failed to append messages to session history"
);
}
}
// 只有当是最新回合时才发送 outbound 消息给用户
// 如果用户已经切换到其他话题,只保存结果,不发送消息(避免打扰)
let outbound_messages = if is_current_turn {
request
.result
.emitted_messages
.iter()
.filter(|message| {
// 当存在 live_emitter 时,所有消息已在 loop 中实时广播,不需要 post-loop 发送
!request.suppress_live_tool_calls
&& should_display_message_to_user(self.show_tool_results, message)
})
.flat_map(|message| {
OutboundMessage::from_chat_message(
request.channel_name,
request.chat_id,
None, // session_id
None,
request.metadata,
message,
)
})
.collect()
} else {
Vec::new()
};
// 只有当是最新回合时才触发历史压缩
let should_schedule_compaction = is_current_turn;
// 拦截 todo_write 结果:持久化 + 前端推送
if is_current_turn {
session.intercept_todo_write_results(
&request.result.emitted_messages,
request.chat_id,
);
}
Ok(FinalizedAgentResult {
outbound_messages,
should_schedule_compaction,
})
}
/// Handles one incoming user message end to end.
///
/// Under the session lock it makes sure the session and chat are loaded, stores the user
/// message (with media references merged into its content), snapshots the history and creates
/// the agent. The lock is released while the agent runs; afterwards the result is finalized
/// and history compaction may be scheduled.
///
/// # Errors
/// Returns the first error from session preparation, content enrichment, message persistence,
/// agent creation, the agent run, or finalization.
pub(crate) async fn prepare_and_execute_message(
    &self,
    request: MessageExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
    let (history, agent, user_message, user_message_count, original_topic_id) = {
        let mut session_guard = request.session.lock().await;
        session_guard.ensure_persistent_session(request.chat_id)?;
        session_guard.ensure_chat_loaded(request.chat_id)?;
        session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
        let media_refs: Vec<String> = request
            .media
            .iter()
            .map(|media| media.path.clone())
            .collect();
        #[cfg(debug_assertions)]
        if !media_refs.is_empty() {
            tracing::debug!(media_count = %request.media.len(), media_refs = ?media_refs, "Adding user message with media");
        }
        // Enrich exactly once: the call is fallible and its result is the stored content.
        let enriched_content =
            enrich_user_content_with_media_refs(request.content, &media_refs)?;
        // Count user messages before the new one is appended; no need to clone the history
        // just to count.
        let user_message_count = session_guard
            .get_or_create_history(request.chat_id)
            .iter()
            .filter(|m| m.role == "user")
            .count();
        // Remember the topic that is active before the user message is added.
        let original_topic_id = session_guard
            .current_topic(request.chat_id)
            .map(|s| s.to_string());
        let user_message = session_guard.create_user_message(&enriched_content, media_refs);
        session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
        // Snapshot the full history, now including the new message.
        let history = session_guard.get_or_create_history(request.chat_id).clone();
        session_guard.record_skill_offer(request.chat_id)?;
        let mut agent = session_guard.create_agent(
            request.chat_id,
            Some(request.sender_id),
            Some(&user_message.id),
        )?;
        if let Some(handler) = request.live_emitter.clone() {
            agent = agent.with_emitted_message_handler(handler);
        }
        (history, agent, user_message, user_message_count, original_topic_id)
    };
    // Build the system prompt context.
    let system_prompt_context = SystemPromptContext {
        session_id: Some(format!("{}:{}", request.channel_name, request.chat_id)),
        chat_id: request.chat_id.to_string(),
        user_message_count,
    };
    let result = agent.process(history, Some(&system_prompt_context)).await?;
    let metadata = HashMap::new();
    self.finalize_result_and_schedule_compaction(
        request.session.clone(),
        FinalizeAgentResultRequest {
            channel_name: request.channel_name,
            chat_id: request.chat_id,
            user_message: &user_message,
            result,
            metadata: &metadata,
            // With a live emitter attached, finalization must not send the messages again.
            suppress_live_tool_calls: request.live_emitter.is_some(),
            execution_kind: "message",
            original_topic_id,
        },
    )
    .await
}
/// Runs one triggered scheduled task end to end.
///
/// Under the session lock it optionally clears the chat history (`fresh_session`), appends the
/// scheduled-task system prompt and the task prompt as a user message, snapshots the history
/// and creates an agent with the task's provider config. The agent is given a
/// `PersistingEmittedMessageHandler` wrapping a no-op handler, then run without the lock. The
/// result is finalized and, on success, the in-memory history of the chat is removed.
///
/// # Errors
/// Returns the first error from session preparation, message persistence, agent creation,
/// the agent run, or finalization.
pub(crate) async fn prepare_and_execute_scheduled_task(
    &self,
    request: ScheduledExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
    let (history, mut agent, user_message, user_message_count, original_topic_id, store, session_id) = {
        let mut session_guard = request.session.lock().await;
        session_guard.ensure_persistent_session(request.chat_id)?;
        // With `fresh_session` set, clear the history first (the original note says this
        // covers both memory and database).
        if request.fresh_session {
            session_guard.clear_chat_history(request.chat_id)?;
            tracing::info!(
                chat_id = %request.chat_id,
                "Fresh session enabled, history cleared"
            );
        }
        session_guard.ensure_chat_loaded(request.chat_id)?;
        session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
        let scheduled_system_prompt =
            compose_scheduled_task_system_prompt(request.system_prompt);
        session_guard.append_persisted_message(
            request.chat_id,
            ChatMessage::system_with_context(
                &scheduled_system_prompt,
                Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
            ),
        )?;
        // Count user messages before the new one is appended; no need to clone the history
        // just to count.
        let user_message_count = session_guard
            .get_or_create_history(request.chat_id)
            .iter()
            .filter(|m| m.role == "user")
            .count();
        // Remember the topic that is active before the user message is added.
        let original_topic_id = session_guard
            .current_topic(request.chat_id)
            .map(|s| s.to_string());
        let user_message = session_guard.create_user_message(request.prompt, Vec::new());
        session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
        // Snapshot the full history, now including the new message.
        let history = session_guard.get_or_create_history(request.chat_id).clone();
        session_guard.record_skill_offer(request.chat_id)?;
        let agent = session_guard.create_agent_with_provider_config(
            request.chat_id,
            request.notification_chat_id, // pass the real chat_id
            Some(request.sender_id),
            Some(&user_message.id),
            request.provider_config.clone(),
        )?;
        // Fetch the store and session id needed to build the persisting handler.
        let store = session_guard.store();
        let session_id = crate::storage::persistent_session_id(
            request.channel_name,
            request.chat_id,
        );
        (history, agent, user_message, user_message_count, original_topic_id, store, session_id)
    };
    // Scheduled tasks have no live emitter, so a PersistingEmittedMessageHandler is attached
    // to persist emitted messages.
    {
        let persisting_handler = PersistingEmittedMessageHandler::new(
            NoOpEmittedMessageHandler,
            store as Arc<dyn ConversationRepository>,
            &session_id,
            None,
        );
        agent = agent.with_emitted_message_handler(Arc::new(persisting_handler));
    }
    // Build the system prompt context.
    let system_prompt_context = SystemPromptContext {
        session_id: Some(format!("{}:{}", request.channel_name, request.chat_id)),
        chat_id: request.chat_id.to_string(),
        user_message_count,
    };
    let result = agent.process(history, Some(&system_prompt_context)).await?;
    let outbound_messages = self.finalize_result_and_schedule_compaction(
        request.session.clone(),
        FinalizeAgentResultRequest {
            channel_name: request.channel_name,
            chat_id: request.chat_id,
            user_message: &user_message,
            result,
            metadata: request.metadata,
            suppress_live_tool_calls: false,
            execution_kind: "scheduled_task",
            original_topic_id,
        },
    )
    .await?;
    // Drop the in-memory history to free memory (original note: the stored history is kept).
    // NOTE(review): this is skipped when `process` or finalization returns early with an
    // error - confirm that is intended.
    {
        let mut session_guard = request.session.lock().await;
        session_guard.remove_history(request.chat_id);
        tracing::info!(
            chat_id = %request.chat_id,
            "Scheduled task completed, memory history released"
        );
    }
    Ok(outbound_messages)
}
/// Finalizes an agent result under the session lock and, when the result belongs to the
/// current turn, asks for background history compaction.
///
/// A failure to schedule compaction is logged as a warning and does not affect the return
/// value.
///
/// # Errors
/// Propagates any error returned by `finalize_result`.
pub(crate) async fn finalize_result_and_schedule_compaction(
    &self,
    session: Arc<Mutex<Session>>,
    request: FinalizeAgentResultRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
    // `request` is moved into `finalize_result`; keep the borrowed identifiers for later use.
    let channel_name = request.channel_name;
    let chat_id = request.chat_id;
    let execution_kind = request.execution_kind;

    // The lock is held only for the synchronous finalization step.
    let finalized = {
        let mut guard = session.lock().await;
        self.finalize_result(&mut guard, request)?
    };

    if !finalized.should_schedule_compaction {
        return Ok(finalized.outbound_messages);
    }

    let scheduled =
        schedule_background_history_compaction(Arc::clone(&session), chat_id.to_string()).await;
    if let Err(error) = scheduled {
        tracing::warn!(
            channel = %channel_name,
            chat_id = %chat_id,
            execution_kind = %execution_kind,
            error = %error,
            "Failed to schedule background history compaction"
        );
    }
    Ok(finalized.outbound_messages)
}
}
/// Decides whether an emitted message should be shown to the user.
///
/// Messages whose role is not "tool" are always shown. Tool messages are shown when
/// `show_tool_results` is true or when their state is `PendingUserAction`; a tool message
/// without a state is treated like a completed one and therefore hidden by default.
pub(crate) fn should_display_message_to_user(
    show_tool_results: bool,
    message: &ChatMessage,
) -> bool {
    let awaits_user_action = matches!(
        message.tool_state.as_ref(),
        Some(ToolMessageState::PendingUserAction)
    );
    message.role != "tool" || show_tool_results || awaits_user_action
}
// Unit tests for the prompt composition and message visibility helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::ChatMessage;
// A non-blank override is trimmed and appended after the base prompt.
#[test]
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
assert!(prompt.contains("任务专属要求:只汇报异常"));
}
// A whitespace-only override yields the base prompt without the task-specific section.
#[test]
fn test_compose_scheduled_task_system_prompt_ignores_blank_override() {
let prompt = compose_scheduled_task_system_prompt(Some(" "));
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
assert!(!prompt.contains("任务专属要求"));
}
// A tool message waiting for user action stays visible even when tool results are hidden.
#[test]
fn test_should_display_message_to_user_keeps_pending_tool_action_visible() {
let message = ChatMessage::tool_with_state(
"call-1",
"approval",
"需要用户确认",
ToolMessageState::PendingUserAction,
);
assert!(should_display_message_to_user(false, &message));
}
// A tool message built with `ChatMessage::tool` is shown only when tool results are enabled.
#[test]
fn test_should_display_message_to_user_hides_completed_tool_when_disabled() {
let message = ChatMessage::tool("call-1", "calculator", "2");
assert!(!should_display_message_to_user(false, &message));
assert!(should_display_message_to_user(true, &message));
}
}