PicoBot/src/gateway/execution.rs

366 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::{AgentError, AgentProcessResult, EmittedMessageHandler, SystemPromptContext};
use crate::bus::message::ToolMessageState;
use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDULED_PROMPT};
use crate::config::LLMProviderConfig;
use tokio::sync::Mutex;
use super::compaction::schedule_background_history_compaction;
use super::message_prepare::enrich_user_content_with_media_refs;
use super::session::Session;
/// Base system instruction for every scheduled-task execution.
///
/// In Chinese, it tells the model the following:
/// - The current input comes from a scheduled task that has already fired.
/// - It must carry out the task content itself, not create, modify, resume, pause or query
///   scheduled tasks (unless the task explicitly asks to manage the scheduler).
/// - Words such as "每小时", "每天", "cron" and "定时" are background only, not new
///   task-creation requests.
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器否则不要调用任何定时任务管理工具像“每小时”、“每天”、“cron”、“定时”等词只应视为任务背景不应再解释为新的建任务请求。";

/// Builds the system prompt used when executing a scheduled task.
///
/// The base [`SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT`] is always included. If
/// `system_prompt` is present and not blank after trimming, it is appended after a blank
/// line with the `任务专属要求:` ("task-specific requirements") label. Otherwise the base
/// prompt is returned alone.
pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
    let task_specific = system_prompt.map(str::trim).unwrap_or_default();
    if task_specific.is_empty() {
        return SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string();
    }
    format!(
        "{}\n\n任务专属要求:{}",
        SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, task_specific
    )
}
/// Runs agent turns against a session and converts agent output into outbound messages.
pub(crate) struct AgentExecutionService {
    // When true, every `tool` message is forwarded to the user. When false, only tool
    // messages in the `PendingUserAction` state are (see `should_display_message_to_user`).
    show_tool_results: bool,
}
/// Inputs for `AgentExecutionService::finalize_result`.
pub(crate) struct FinalizeAgentResultRequest<'a> {
    /// Channel the turn belongs to; used for outbound messages and log fields.
    pub(crate) channel_name: &'a str,
    /// Chat whose history receives the agent's emitted messages.
    pub(crate) chat_id: &'a str,
    /// The user message that triggered this run. The result is discarded as stale if
    /// this is no longer the chat's current user turn.
    pub(crate) user_message: &'a ChatMessage,
    /// Agent output. Its `emitted_messages` are persisted and, after filtering,
    /// converted to outbound messages.
    pub(crate) result: AgentProcessResult,
    /// Passed to `OutboundMessage::from_chat_message` for each outbound message.
    pub(crate) metadata: &'a HashMap<String, String>,
    /// When true, assistant tool-call messages are left out of the outbound messages
    /// (they are still persisted). `prepare_and_execute_message` sets this when a live
    /// emitter is present.
    pub(crate) suppress_live_tool_calls: bool,
    /// Label used in log fields, e.g. `"message"` or `"scheduled_task"`.
    pub(crate) execution_kind: &'a str,
}
/// Output of `AgentExecutionService::finalize_result`.
pub(crate) struct FinalizedAgentResult {
    /// Messages to deliver to the user; always empty when the result was stale.
    pub(crate) outbound_messages: Vec<OutboundMessage>,
    /// Whether background history compaction should be scheduled; `false` when the
    /// result was discarded as stale.
    pub(crate) should_schedule_compaction: bool,
}
/// Inputs for `AgentExecutionService::prepare_and_execute_message` (an inbound user message).
pub(crate) struct MessageExecutionRequest<'a> {
    /// Session owning the chat history; locked while preparing and while finalizing.
    pub(crate) session: Arc<Mutex<Session>>,
    pub(crate) channel_name: &'a str,
    /// Passed to agent creation as the sender.
    pub(crate) sender_id: &'a str,
    pub(crate) chat_id: &'a str,
    /// Raw user text; enriched with media references before it is stored.
    pub(crate) content: &'a str,
    /// Attached media; each item's `path` is recorded as a media reference.
    pub(crate) media: Vec<MediaItem>,
    /// Optional handler attached to the agent via `with_emitted_message_handler`
    /// (presumably invoked as messages are emitted — confirm). When present, assistant
    /// tool-call messages are omitted from the final outbound batch.
    pub(crate) live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
}
/// Inputs for `AgentExecutionService::prepare_and_execute_scheduled_task`.
pub(crate) struct ScheduledExecutionRequest<'a> {
    /// Session owning the chat history; locked while preparing and while finalizing.
    pub(crate) session: Arc<Mutex<Session>>,
    pub(crate) channel_name: &'a str,
    /// Chat whose history the task runs in.
    pub(crate) chat_id: &'a str,
    /// Extra chat id handed to agent creation (described by the caller as the "real"
    /// chat id). NOTE(review): presumably where notifications are delivered — confirm.
    pub(crate) notification_chat_id: Option<&'a str>,
    /// Task content, stored as the user message for this run.
    pub(crate) prompt: &'a str,
    pub(crate) sender_id: &'a str,
    /// LLM provider settings used to create the agent for this run.
    pub(crate) provider_config: LLMProviderConfig,
    /// When true, the chat context is reset before the task runs.
    pub(crate) fresh_session: bool,
    /// Optional task-specific instructions appended to the scheduled-task system prompt;
    /// blank values are ignored.
    pub(crate) system_prompt: Option<&'a str>,
    /// Passed through to every outbound message produced by the run.
    pub(crate) metadata: &'a HashMap<String, String>,
}
impl AgentExecutionService {
pub(crate) fn new(show_tool_results: bool) -> Self {
Self { show_tool_results }
}
pub(crate) fn finalize_result(
&self,
session: &mut Session,
request: FinalizeAgentResultRequest<'_>,
) -> Result<FinalizedAgentResult, AgentError> {
if !session.matches_current_user_turn(request.chat_id, request.user_message) {
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
session.stale_result_diagnostics(request.chat_id);
tracing::warn!(
channel = %request.channel_name,
chat_id = %request.chat_id,
user_message_id = %request.user_message.id,
latest_user_id,
latest_user_preview,
compression_in_flight,
history_len,
execution_kind = %request.execution_kind,
"Skipping stale agent result because a newer user message is already present"
);
return Ok(FinalizedAgentResult {
outbound_messages: Vec::new(),
should_schedule_compaction: false,
});
}
session
.append_persisted_messages(request.chat_id, request.result.emitted_messages.clone())?;
let outbound_messages = request
.result
.emitted_messages
.iter()
.filter(|message| {
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
&& should_display_message_to_user(self.show_tool_results, message)
})
.flat_map(|message| {
OutboundMessage::from_chat_message(
request.channel_name,
request.chat_id,
None, // session_id
None,
request.metadata,
message,
)
})
.collect();
Ok(FinalizedAgentResult {
outbound_messages,
should_schedule_compaction: true,
})
}
pub(crate) async fn prepare_and_execute_message(
&self,
request: MessageExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let (history, agent, user_message, user_message_count) = {
let mut session_guard = request.session.lock().await;
session_guard.ensure_persistent_session(request.chat_id)?;
session_guard.ensure_chat_loaded(request.chat_id)?;
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
let media_refs: Vec<String> = request
.media
.iter()
.map(|media| media.path.clone())
.collect();
#[cfg(debug_assertions)]
if !media_refs.is_empty() {
tracing::debug!(media_count = %request.media.len(), media_refs = ?media_refs, "Adding user message with media");
}
let enriched_content =
enrich_user_content_with_media_refs(request.content, &media_refs)?;
// 先计算 user_message_count在添加新消息之前
let history_before = session_guard.get_or_create_history(request.chat_id).clone();
let user_message_count = history_before.iter().filter(|m| m.role == "user").count();
let user_message = session_guard.create_user_message(&enriched_content, media_refs);
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
// 再获取包含新消息的完整历史记录
let history = session_guard.get_or_create_history(request.chat_id).clone();
session_guard.record_skill_offer(request.chat_id)?;
let mut agent = session_guard.create_agent(
request.chat_id,
Some(request.sender_id),
Some(&user_message.id),
)?;
if let Some(handler) = request.live_emitter.clone() {
agent = agent.with_emitted_message_handler(handler);
}
(history, agent, user_message, user_message_count)
};
// 构建系统提示词上下文
let system_prompt_context = SystemPromptContext {
session_id: Some(format!("{}:{}", request.channel_name, request.chat_id)),
chat_id: request.chat_id.to_string(),
user_message_count,
};
let result = agent.process(history, Some(&system_prompt_context)).await?;
let metadata = HashMap::new();
self.finalize_result_and_schedule_compaction(
request.session.clone(),
FinalizeAgentResultRequest {
channel_name: request.channel_name,
chat_id: request.chat_id,
user_message: &user_message,
result,
metadata: &metadata,
suppress_live_tool_calls: request.live_emitter.is_some(),
execution_kind: "message",
},
)
.await
}
pub(crate) async fn prepare_and_execute_scheduled_task(
&self,
request: ScheduledExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let (history, agent, user_message, user_message_count) = {
let mut session_guard = request.session.lock().await;
session_guard.ensure_persistent_session(request.chat_id)?;
if request.fresh_session {
session_guard.reset_chat_context(request.chat_id)?;
}
session_guard.ensure_chat_loaded(request.chat_id)?;
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
let scheduled_system_prompt =
compose_scheduled_task_system_prompt(request.system_prompt);
session_guard.append_persisted_message(
request.chat_id,
ChatMessage::system_with_context(
&scheduled_system_prompt,
Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
),
)?;
// 先计算 user_message_count在添加新消息之前
let history_before = session_guard.get_or_create_history(request.chat_id).clone();
let user_message_count = history_before.iter().filter(|m| m.role == "user").count();
let user_message = session_guard.create_user_message(request.prompt, Vec::new());
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
// 再获取包含新消息的完整历史记录
let history = session_guard.get_or_create_history(request.chat_id).clone();
session_guard.record_skill_offer(request.chat_id)?;
let agent = session_guard.create_agent_with_provider_config(
request.chat_id,
request.notification_chat_id, // 传入真实 chat_id
Some(request.sender_id),
Some(&user_message.id),
request.provider_config.clone(),
)?;
(history, agent, user_message, user_message_count)
};
// 构建系统提示词上下文
let system_prompt_context = SystemPromptContext {
session_id: Some(format!("{}:{}", request.channel_name, request.chat_id)),
chat_id: request.chat_id.to_string(),
user_message_count,
};
let result = agent.process(history, Some(&system_prompt_context)).await?;
self.finalize_result_and_schedule_compaction(
request.session.clone(),
FinalizeAgentResultRequest {
channel_name: request.channel_name,
chat_id: request.chat_id,
user_message: &user_message,
result,
metadata: request.metadata,
suppress_live_tool_calls: false,
execution_kind: "scheduled_task",
},
)
.await
}
pub(crate) async fn finalize_result_and_schedule_compaction(
&self,
session: Arc<Mutex<Session>>,
request: FinalizeAgentResultRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let channel_name = request.channel_name.to_string();
let chat_id = request.chat_id.to_string();
let execution_kind = request.execution_kind.to_string();
let finalized_result = {
let mut session_guard = session.lock().await;
self.finalize_result(&mut session_guard, request)?
};
if finalized_result.should_schedule_compaction {
if let Err(error) =
schedule_background_history_compaction(session.clone(), chat_id.clone()).await
{
tracing::warn!(
channel = %channel_name,
chat_id = %chat_id,
execution_kind = %execution_kind,
error = %error,
"Failed to schedule background history compaction"
);
}
}
Ok(finalized_result.outbound_messages)
}
}
/// Decides whether `message` should be shown to the user.
///
/// Non-tool messages are always shown. Tool messages are shown when
/// `show_tool_results` is enabled. When it is disabled, a tool message is shown only if
/// it is waiting on the user (`ToolMessageState::PendingUserAction`). A missing
/// `tool_state` counts as completed, so such messages are hidden.
pub(crate) fn should_display_message_to_user(
    show_tool_results: bool,
    message: &ChatMessage,
) -> bool {
    let is_tool_message = message.role == "tool";
    if !is_tool_message || show_tool_results {
        return true;
    }
    matches!(
        message.tool_state,
        Some(ToolMessageState::PendingUserAction)
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bus::ChatMessage;

    /// Text that every composed scheduled-task prompt must contain.
    const BASE_PROMPT_MARKER: &str = "当前输入来自一次已经触发的定时任务执行";
    /// Label that introduces task-specific requirements in the composed prompt.
    const TASK_SPECIFIC_LABEL: &str = "任务专属要求";

    #[test]
    fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
        let composed = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
        assert!(composed.contains(BASE_PROMPT_MARKER));
        // Surrounding whitespace in the override is trimmed before it is appended.
        let expected_suffix = format!("{}:只汇报异常", TASK_SPECIFIC_LABEL);
        assert!(composed.contains(&expected_suffix));
    }

    #[test]
    fn test_compose_scheduled_task_system_prompt_ignores_blank_override() {
        let composed = compose_scheduled_task_system_prompt(Some(" "));
        assert!(composed.contains(BASE_PROMPT_MARKER));
        assert!(!composed.contains(TASK_SPECIFIC_LABEL));
    }

    #[test]
    fn test_should_display_message_to_user_keeps_pending_tool_action_visible() {
        let pending = ChatMessage::tool_with_state(
            "call-1",
            "approval",
            "需要用户确认",
            ToolMessageState::PendingUserAction,
        );
        // A tool waiting on the user stays visible even when tool results are hidden.
        assert!(should_display_message_to_user(false, &pending));
    }

    #[test]
    fn test_should_display_message_to_user_hides_completed_tool_when_disabled() {
        let completed = ChatMessage::tool("call-1", "calculator", "2");
        for (show_tool_results, expected) in [(false, false), (true, true)] {
            assert_eq!(
                should_display_message_to_user(show_tool_results, &completed),
                expected
            );
        }
    }
}
}