- 跳过标记行、session_id元数据和空行
- 跳过提示行,提取提示行之后的实际内容
- 限制提取内容最多20行,防止消息过长
- 当提取内容为空时,使用默认提示消息
- 改善助手消息的显示内容格式
2234 lines
83 KiB
Rust
2234 lines
83 KiB
Rust
use crate::agent::AgentRuntimeConfig;
|
||
use crate::agent::{SystemPromptContext, SystemPromptProvider};
|
||
use crate::bus::ChatMessage;
|
||
use crate::bus::message::ToolMessageState;
|
||
use crate::storage::ConversationRepository;
|
||
use crate::domain::messages::{ContentBlock, ToolCall};
|
||
use crate::observability::{
|
||
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
|
||
};
|
||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
|
||
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
||
use crate::tools::{ToolContext, ToolRegistry};
|
||
use async_trait::async_trait;
|
||
use std::collections::VecDeque;
|
||
use std::hash::{Hash, Hasher};
|
||
use std::io::Read;
|
||
use std::sync::Arc;
|
||
use std::time::Instant;
|
||
|
||
// ---- Tool output truncation ----------------------------------------------------

/// Trailing characters kept when a tool result is only slightly over its limit.
const TRUNCATION_SUFFIX_LEN: usize = 200;

// ---- Pending user action -------------------------------------------------------

/// Prefix marking a tool output as waiting on a user action; see `parse_pending_tool_output`.
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";

/// Default assistant reply for a pending user action.
/// NOTE(review): presumably used when the tool supplies no text of its own — confirm at call site.
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
    "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";

// ---- LLM error reporting -------------------------------------------------------

/// Reply used when an LLM request fails with a timeout / 504 style error.
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";

// ---- Image inlining and token budgeting ----------------------------------------

/// MIME types that may be inlined into LLM requests as image blocks.
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
/// Rough characters (bytes) per token, applied to serialized JSON and base64 lengths.
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
/// Over-estimation factor applied to serialized-JSON token estimates.
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
/// Fraction of the input window (context minus completion reserve) treated as usable.
const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9;
/// Tokens reserved for the completion when the provider config sets no `max_tokens`.
const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048;
/// Tokens set aside per image for the `data:<mime>;base64,` wrapper.
const DATA_URL_OVERHEAD_TOKENS: usize = 16;
/// JPEG qualities tried, best first, when recompressing an oversized image.
const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36];
/// Smallest longest-side (in pixels) an image is downscaled to during recompression.
const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64;
/// Header line of the text notice that lists images which could not be inlined.
const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:";
|
||
|
||
/// Build content blocks from text and media paths
|
||
fn build_content_blocks(
|
||
text: &str,
|
||
media_paths: &[String],
|
||
budget: &mut ImageInlineBudget,
|
||
) -> Vec<ContentBlock> {
|
||
build_content_blocks_with_image_budget(text, media_paths, budget)
|
||
}
|
||
|
||
fn build_content_blocks_with_image_budget(
|
||
text: &str,
|
||
media_paths: &[String],
|
||
budget: &mut ImageInlineBudget,
|
||
) -> Vec<ContentBlock> {
|
||
let mut blocks = Vec::new();
|
||
let mut skipped_image_notices = Vec::new();
|
||
|
||
// Add text block if there's text
|
||
if !text.is_empty() {
|
||
blocks.push(ContentBlock::text(text));
|
||
}
|
||
|
||
// Add image blocks for media paths
|
||
for path in media_paths {
|
||
if supported_image_mime_type(path).is_none() {
|
||
tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block");
|
||
continue;
|
||
}
|
||
|
||
let Some(target_tokens) = budget.take_next_image_tokens() else {
|
||
tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains");
|
||
skipped_image_notices.push(format!(
|
||
"- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
|
||
display_media_name(path)
|
||
));
|
||
continue;
|
||
};
|
||
|
||
match encode_image_to_base64_with_budget(path, target_tokens) {
|
||
Ok((mime_type, base64_data)) => {
|
||
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
||
blocks.push(ContentBlock::image_url(url));
|
||
}
|
||
Err(err) => {
|
||
tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed");
|
||
skipped_image_notices.push(format!(
|
||
"- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
|
||
display_media_name(path)
|
||
));
|
||
continue;
|
||
}
|
||
}
|
||
}
|
||
|
||
if !skipped_image_notices.is_empty() {
|
||
let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX);
|
||
notice.push('\n');
|
||
notice.push_str(&skipped_image_notices.join("\n"));
|
||
blocks.push(ContentBlock::text(notice));
|
||
}
|
||
|
||
// If nothing, add empty text block
|
||
if blocks.is_empty() {
|
||
blocks.push(ContentBlock::text(""));
|
||
}
|
||
|
||
blocks
|
||
}
|
||
|
||
fn supported_image_mime_type(path: &str) -> Option<String> {
|
||
let mime = mime_guess::from_path(path).first_or_octet_stream();
|
||
let essence = mime.essence_str();
|
||
|
||
if SUPPORTED_IMAGE_MIME_TYPES.contains(&essence) {
|
||
Some(essence.to_string())
|
||
} else {
|
||
None
|
||
}
|
||
}
|
||
|
||
/// Short, user-facing name for a media path: its final component, or the whole
/// path when there is no (UTF-8) file name.
fn display_media_name(path: &str) -> String {
    match std::path::Path::new(path)
        .file_name()
        .and_then(std::ffi::OsStr::to_str)
    {
        Some(name) => name.to_owned(),
        None => path.to_owned(),
    }
}
|
||
|
||
/// Token budget for inlined images in one request, shared evenly among the images
/// that still need to be encoded.
#[derive(Debug, Clone)]
struct ImageInlineBudget {
    /// Tokens not yet handed out.
    remaining_tokens: usize,
    /// Images that have not yet asked for a share.
    remaining_images: usize,
}

impl ImageInlineBudget {
    fn new(total_tokens: usize, image_count: usize) -> Self {
        Self {
            remaining_images: image_count,
            remaining_tokens: total_tokens,
        }
    }

    /// Claim the next image's share: `ceil(remaining_tokens / remaining_images)`.
    ///
    /// Always consumes one image slot; returns `None` when no tokens or no image slots are left.
    fn take_next_image_tokens(&mut self) -> Option<usize> {
        let images_left = self.remaining_images;
        self.remaining_images = images_left.saturating_sub(1);

        if images_left == 0 || self.remaining_tokens == 0 {
            return None;
        }

        let share = self.remaining_tokens.div_ceil(images_left);
        self.remaining_tokens = self.remaining_tokens.saturating_sub(share);
        Some(share)
    }
}
|
||
|
||
fn estimate_tokens_from_serialized_json<T: serde::Serialize>(value: &T) -> usize {
|
||
let raw_len = serde_json::to_string(value)
|
||
.map(|serialized| serialized.len())
|
||
.unwrap_or_default();
|
||
((raw_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64) * TOKEN_ESTIMATE_SAFETY_MULTIPLIER)
|
||
as usize
|
||
}
|
||
|
||
fn image_token_budget_for_request(
|
||
runtime_config: &AgentRuntimeConfig,
|
||
text_only_messages: &[Message],
|
||
tools: Option<&Vec<crate::domain::tools::Tool>>,
|
||
) -> usize {
|
||
let completion_reserve = runtime_config
|
||
.provider
|
||
.max_tokens
|
||
.map(|tokens| tokens as usize)
|
||
.unwrap_or(DEFAULT_COMPLETION_TOKEN_RESERVE);
|
||
let input_window = runtime_config
|
||
.context_window_tokens
|
||
.saturating_sub(completion_reserve);
|
||
let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize;
|
||
|
||
let text_tokens = estimate_tokens_from_serialized_json(&text_only_messages)
|
||
+ tools
|
||
.map(estimate_tokens_from_serialized_json)
|
||
.unwrap_or_default();
|
||
|
||
safe_input_window.saturating_sub(text_tokens)
|
||
}
|
||
|
||
fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
|
||
messages
|
||
.iter()
|
||
.flat_map(|message| message.media_refs.iter())
|
||
.filter(|path| supported_image_mime_type(path).is_some())
|
||
.count()
|
||
}
|
||
|
||
/// 过滤超出轮次限制的图片,并限制总图片数量
|
||
///
|
||
/// # 参数
|
||
/// - `messages`: 原始消息列表
|
||
/// - `max_age_rounds`: 图片超过多少消息轮次后不再发送(从最新消息向前数,包括所有 role)
|
||
/// - `max_images`: 最多发送多少张图片(优先保留最近的图片)
|
||
///
|
||
/// # 返回
|
||
/// 过滤后的消息列表,图片被转换为文本提示 "[图片已过期]"
|
||
fn filter_images_by_age_and_count(
|
||
messages: &[ChatMessage],
|
||
max_age_rounds: usize,
|
||
max_images: usize,
|
||
) -> Vec<ChatMessage> {
|
||
if messages.is_empty() || (max_age_rounds == 0 && max_images == 0) {
|
||
return messages.to_vec();
|
||
}
|
||
|
||
let total_images = count_supported_image_media_refs(messages);
|
||
if total_images == 0 {
|
||
return messages.to_vec();
|
||
}
|
||
|
||
// 从最新消息向前遍历,优先保留最新的图片
|
||
// 消息列表顺序:[old, ..., new],所以末尾是最新的
|
||
let msg_count = messages.len();
|
||
|
||
// 先从后向前遍历,计算每条消息应该保留多少张图片
|
||
// 使用 Vec<usize> 存储每条消息应该保留的图片数量
|
||
let mut images_to_keep_per_msg: Vec<usize> = vec![0; msg_count];
|
||
let mut images_kept = 0usize;
|
||
|
||
// 从后向前遍历(从最新到最旧)
|
||
for (idx, message) in messages.iter().enumerate().rev() {
|
||
// 计算距离最新消息的消息数(末尾消息的 age = 0)
|
||
let age_from_end = msg_count.saturating_sub(idx).saturating_sub(1);
|
||
|
||
// 检查是否超出轮次限制
|
||
let exceeds_age_limit = max_age_rounds > 0 && age_from_end >= max_age_rounds;
|
||
|
||
if exceeds_age_limit {
|
||
continue; // 超出轮次限制的图片不保留
|
||
}
|
||
|
||
// 计算这条消息中的图片数量
|
||
let image_count_in_msg = message.media_refs.iter()
|
||
.filter(|p| supported_image_mime_type(p).is_some())
|
||
.count();
|
||
|
||
if image_count_in_msg == 0 {
|
||
continue;
|
||
}
|
||
|
||
// 计算可以保留多少张
|
||
let can_keep = std::cmp::min(image_count_in_msg, max_images.saturating_sub(images_kept));
|
||
if can_keep > 0 {
|
||
images_to_keep_per_msg[idx] = can_keep;
|
||
images_kept += can_keep;
|
||
}
|
||
}
|
||
|
||
// 然后从前向后遍历,构建过滤后的消息列表
|
||
let mut filtered = Vec::with_capacity(msg_count);
|
||
|
||
for (idx, message) in messages.iter().enumerate() {
|
||
// 计算距离最新消息的消息数(末尾消息的 age = 0)
|
||
let age_from_end = msg_count.saturating_sub(idx).saturating_sub(1);
|
||
|
||
// 检查是否超出轮次限制
|
||
let exceeds_age_limit = max_age_rounds > 0 && age_from_end >= max_age_rounds;
|
||
|
||
let keep_count = images_to_keep_per_msg[idx];
|
||
|
||
// 过滤图片:保留非图片媒体和指定数量的图片
|
||
let mut images_kept_in_msg = 0usize;
|
||
let filtered_media_refs: Vec<String> = message.media_refs.iter()
|
||
.filter_map(|path| {
|
||
if supported_image_mime_type(path).is_some() {
|
||
if images_kept_in_msg < keep_count {
|
||
images_kept_in_msg += 1;
|
||
Some(path.clone())
|
||
} else {
|
||
None
|
||
}
|
||
} else {
|
||
Some(path.clone()) // 保留非图片媒体
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
// 如果图片被过滤,添加文本提示
|
||
let original_image_count = message.media_refs.iter()
|
||
.filter(|p| supported_image_mime_type(p).is_some())
|
||
.count();
|
||
let filtered_image_count = filtered_media_refs.iter()
|
||
.filter(|p| supported_image_mime_type(p).is_some())
|
||
.count();
|
||
|
||
let content = if original_image_count > filtered_image_count {
|
||
let notice = if exceeds_age_limit {
|
||
format!("{} [图片已过期:超出 {} 条消息范围]", message.content, max_age_rounds)
|
||
} else {
|
||
format!("{} [图片已过期:超出最大图片数量限制]", message.content)
|
||
};
|
||
notice
|
||
} else {
|
||
message.content.clone()
|
||
};
|
||
|
||
filtered.push(ChatMessage {
|
||
id: message.id.clone(),
|
||
role: message.role.clone(),
|
||
content,
|
||
media_refs: filtered_media_refs,
|
||
timestamp: message.timestamp,
|
||
system_context: message.system_context.clone(),
|
||
reasoning_content: message.reasoning_content.clone(),
|
||
tool_call_id: message.tool_call_id.clone(),
|
||
tool_name: message.tool_name.clone(),
|
||
tool_state: message.tool_state.clone(),
|
||
tool_duration_ms: message.tool_duration_ms,
|
||
tool_calls: message.tool_calls.clone(),
|
||
});
|
||
}
|
||
|
||
filtered
|
||
}
|
||
|
||
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
|
||
target_tokens
|
||
.saturating_sub(DATA_URL_OVERHEAD_TOKENS)
|
||
.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
|
||
.saturating_mul(3)
|
||
/ 4
|
||
}
|
||
|
||
/// Encode an image file to base64 data URL
|
||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||
|
||
let mime = supported_image_mime_type(path).ok_or_else(|| {
|
||
std::io::Error::new(
|
||
std::io::ErrorKind::InvalidInput,
|
||
format!("unsupported image media type for path: {}", path),
|
||
)
|
||
})?;
|
||
|
||
let mut file = std::fs::File::open(path)?;
|
||
let mut buffer = Vec::new();
|
||
file.read_to_end(&mut buffer)?;
|
||
|
||
let encoded = STANDARD.encode(&buffer);
|
||
Ok((mime, encoded))
|
||
}
|
||
|
||
fn encode_image_to_base64_with_budget(
|
||
path: &str,
|
||
target_tokens: usize,
|
||
) -> Result<(String, String), std::io::Error> {
|
||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||
|
||
let (mime, encoded) = encode_image_to_base64(path)?;
|
||
let target_base64_chars = target_tokens
|
||
.saturating_sub(DATA_URL_OVERHEAD_TOKENS)
|
||
.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
|
||
if encoded.len() <= target_base64_chars {
|
||
return Ok((mime, encoded));
|
||
}
|
||
|
||
let target_bytes = target_image_bytes_for_tokens(target_tokens);
|
||
let image_bytes = std::fs::read(path)?;
|
||
let image = image::load_from_memory(&image_bytes)
|
||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
|
||
let compressed = compress_image_to_jpeg_budget(&image, target_bytes)?;
|
||
|
||
Ok(("image/jpeg".to_string(), STANDARD.encode(compressed)))
|
||
}
|
||
|
||
fn compress_image_to_jpeg_budget(
|
||
image: &image::DynamicImage,
|
||
target_bytes: usize,
|
||
) -> Result<Vec<u8>, std::io::Error> {
|
||
use image::codecs::jpeg::JpegEncoder;
|
||
use image::imageops::FilterType;
|
||
|
||
let mut best: Option<Vec<u8>> = None;
|
||
let mut max_side = image
|
||
.width()
|
||
.max(image.height())
|
||
.max(MIN_COMPRESSED_IMAGE_SIDE);
|
||
|
||
loop {
|
||
let candidate = if image.width().max(image.height()) > max_side {
|
||
image.resize(max_side, max_side, FilterType::Triangle)
|
||
} else {
|
||
image.clone()
|
||
};
|
||
|
||
for quality in JPEG_QUALITY_STEPS {
|
||
let mut encoded = Vec::new();
|
||
let mut encoder = JpegEncoder::new_with_quality(&mut encoded, *quality);
|
||
encoder
|
||
.encode_image(&candidate)
|
||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
|
||
|
||
if encoded.len() <= target_bytes {
|
||
return Ok(encoded);
|
||
}
|
||
|
||
if best
|
||
.as_ref()
|
||
.map(|current| encoded.len() < current.len())
|
||
.unwrap_or(true)
|
||
{
|
||
best = Some(encoded);
|
||
}
|
||
}
|
||
|
||
if max_side <= MIN_COMPRESSED_IMAGE_SIDE {
|
||
break;
|
||
}
|
||
max_side = (max_side * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE);
|
||
}
|
||
|
||
best.ok_or_else(|| {
|
||
std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed")
|
||
})
|
||
}
|
||
|
||
/// Truncate tool result if it exceeds the configured limit.
|
||
/// Preserves the end of the output as it often contains the conclusion/useful result.
|
||
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
|
||
let total_chars = char_count(output);
|
||
if total_chars <= max_tool_result_chars {
|
||
return output.to_string();
|
||
}
|
||
|
||
let truncated_start_len = total_chars.saturating_sub(TRUNCATION_SUFFIX_LEN);
|
||
if truncated_start_len > max_tool_result_chars {
|
||
// Even after removing suffix, still too long - take from beginning
|
||
let head_len = max_tool_result_chars.saturating_sub(100);
|
||
let head = take_prefix_chars(output, head_len);
|
||
format!(
|
||
"{}...\n\n[Output truncated - {} characters removed]",
|
||
head,
|
||
total_chars - max_tool_result_chars + 100
|
||
)
|
||
} else {
|
||
// Keep most of the end which usually contains the useful result
|
||
let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len));
|
||
format!(
|
||
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
||
truncated_start_len, tail
|
||
)
|
||
}
|
||
}
|
||
|
||
fn parse_pending_tool_output(output: &str) -> Option<String> {
|
||
output
|
||
.strip_prefix(PENDING_USER_ACTION_MARKER)
|
||
.map(|rest| rest.trim().to_string())
|
||
}
|
||
|
||
/// Unwrap tool arguments that arrived as a JSON-encoded string.
///
/// A `Value::String` whose contents parse as JSON is replaced by the parsed value.
/// Anything else, including strings that fail to parse, is returned unchanged (cloned).
fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
    if let serde_json::Value::String(raw) = arguments {
        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(raw) {
            return parsed;
        }
    }
    arguments.clone()
}
|
||
|
||
/// Render `error` and its whole `source()` chain, one per line, joined by "\ncaused by: ".
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    std::iter::successors(Some(error), |err| err.source())
        .map(|err| err.to_string())
        .collect::<Vec<_>>()
        .join("\ncaused by: ")
}
|
||
|
||
/// Heuristic: does this provider error text look like a timeout / 504 gateway failure?
///
/// Case-insensitive substring match (ASCII lowercase) against a fixed marker list.
fn is_recoverable_llm_error(error: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];
    let lowered = error.to_ascii_lowercase();
    MARKERS.iter().any(|&marker| lowered.contains(marker))
}
|
||
|
||
fn recoverable_llm_message(error: &str) -> String {
|
||
if is_recoverable_llm_error(error) {
|
||
RECOVERABLE_LLM_ERROR_MESSAGE.to_string()
|
||
} else {
|
||
format!("模型请求失败:{}", error)
|
||
}
|
||
}
|
||
|
||
/// Outcome of recording one tool call with the loop detector.
#[derive(Debug, Clone, PartialEq, Eq)]
enum LoopDetectionResult {
    /// Nothing suspicious; no message to surface.
    Ok,
    /// The same tool was called with identical arguments repeatedly; carries the warning text.
    Warning(String),
}
|
||
|
||
/// Settings for `LoopDetector`.
#[derive(Debug, Clone)]
struct LoopDetectorConfig {
    /// When false, the detector records nothing and never warns.
    enabled: bool,
    /// Emit a warning every `warn_every` consecutive identical calls.
    warn_every: usize,
}

impl Default for LoopDetectorConfig {
    /// Enabled, warning every 5 identical calls in a row.
    fn default() -> Self {
        LoopDetectorConfig {
            warn_every: 5,
            enabled: true,
        }
    }
}
|
||
|
||
/// One tool invocation as remembered by the loop detector.
#[derive(Debug, Clone)]
struct ToolCallRecord {
    /// Tool name as requested by the model.
    name: String,
    /// Key-order-independent hash of the call's JSON arguments.
    args_hash: u64,
}
|
||
|
||
/// Stateful loop detector that monitors for repetitive patterns.
|
||
struct LoopDetector {
|
||
config: LoopDetectorConfig,
|
||
window: VecDeque<ToolCallRecord>,
|
||
}
|
||
|
||
impl LoopDetector {
|
||
fn new(config: LoopDetectorConfig) -> Self {
|
||
Self {
|
||
window: VecDeque::with_capacity(config.warn_every * 2),
|
||
config,
|
||
}
|
||
}
|
||
|
||
/// Record a completed tool call and check for loop patterns.
|
||
/// Returns Warning every `warn_every` consecutive identical calls.
|
||
fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult {
|
||
if !self.config.enabled {
|
||
return LoopDetectionResult::Ok;
|
||
}
|
||
|
||
let record = ToolCallRecord {
|
||
name: name.to_string(),
|
||
args_hash: hash_json_value(args),
|
||
};
|
||
|
||
// Maintain sliding window
|
||
if self.window.len() >= self.config.warn_every * 2 {
|
||
self.window.pop_front();
|
||
}
|
||
self.window.push_back(record);
|
||
|
||
// Count consecutive identical calls
|
||
let last = self.window.back().unwrap();
|
||
let consecutive: usize = self
|
||
.window
|
||
.iter()
|
||
.rev()
|
||
.take_while(|r| r.name == last.name && r.args_hash == last.args_hash)
|
||
.count();
|
||
|
||
// Warn every warn_every times
|
||
if consecutive > 0 && consecutive % self.config.warn_every == 0 {
|
||
LoopDetectionResult::Warning(format!(
|
||
"注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。",
|
||
last.name, consecutive
|
||
))
|
||
} else {
|
||
LoopDetectionResult::Ok
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Hash a JSON value deterministically (key-order independent).
|
||
fn hash_json_value(value: &serde_json::Value) -> u64 {
|
||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||
let canonical = canonicalise_json(value);
|
||
canonical.hash(&mut hasher);
|
||
hasher.finish()
|
||
}
|
||
|
||
/// Return a clone of value with all object keys sorted recursively.
|
||
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
|
||
match value {
|
||
serde_json::Value::Object(map) => {
|
||
let mut sorted: Vec<(&String, &serde_json::Value)> = map.iter().collect();
|
||
sorted.sort_by_key(|(k, _)| *k);
|
||
let new_map: serde_json::Map<String, serde_json::Value> = sorted
|
||
.into_iter()
|
||
.map(|(k, v)| (k.clone(), canonicalise_json(v)))
|
||
.collect();
|
||
serde_json::Value::Object(new_map)
|
||
}
|
||
serde_json::Value::Array(arr) => {
|
||
serde_json::Value::Array(arr.iter().map(canonicalise_json).collect())
|
||
}
|
||
other => other.clone(),
|
||
}
|
||
}
|
||
|
||
/// Convert ChatMessage to LLM Message format
|
||
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
|
||
let content = if m.media_refs.is_empty() {
|
||
vec![ContentBlock::text(&m.content)]
|
||
} else {
|
||
build_content_blocks(&m.content, &m.media_refs, image_budget)
|
||
};
|
||
|
||
Message {
|
||
role: m.role.clone(),
|
||
content,
|
||
reasoning_content: m.reasoning_content.clone(),
|
||
tool_call_id: m.tool_call_id.clone(),
|
||
name: m.tool_name.clone(),
|
||
tool_calls: m.tool_calls.clone(),
|
||
}
|
||
}
|
||
|
||
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
|
||
Message {
|
||
role: m.role.clone(),
|
||
content: vec![ContentBlock::text(&m.content)],
|
||
reasoning_content: m.reasoning_content.clone(),
|
||
tool_call_id: m.tool_call_id.clone(),
|
||
name: m.tool_name.clone(),
|
||
tool_calls: m.tool_calls.clone(),
|
||
}
|
||
}
|
||
|
||
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
||
/// History is managed externally by SessionManager.
|
||
pub struct AgentLoop {
|
||
runtime_config: AgentRuntimeConfig,
|
||
provider: Box<dyn LLMProvider>,
|
||
tools: Arc<ToolRegistry>,
|
||
/// 系统提示词提供者(统一注入 Agent 和 Skill 提示词)
|
||
system_prompt_provider: Option<Arc<dyn SystemPromptProvider>>,
|
||
/// Skill 提供者(用于匹配错误提示)
|
||
skills: Option<Arc<dyn SkillProvider>>,
|
||
tool_context: ToolContext,
|
||
observer: Option<Arc<dyn Observer>>,
|
||
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
|
||
max_iterations: usize,
|
||
/// 取消信号接收端:Agent 在每次迭代开始时检查是否被取消
|
||
cancel_token: Option<tokio::sync::watch::Receiver<()>>,
|
||
}
|
||
|
||
#[derive(Debug, Clone)]
|
||
pub struct AgentProcessResult {
|
||
pub final_response: ChatMessage,
|
||
pub emitted_messages: Vec<ChatMessage>,
|
||
}
|
||
|
||
#[async_trait]
|
||
pub trait EmittedMessageHandler: Send + Sync + 'static {
|
||
async fn handle(&self, message: ChatMessage);
|
||
|
||
/// Handle a tool result message with optional execution timing.
|
||
/// Default implementation delegates to `handle()`, ignoring timing.
|
||
async fn handle_tool_result(&self, message: ChatMessage, _duration_ms: Option<u64>) {
|
||
self.handle(message).await;
|
||
}
|
||
}
|
||
|
||
/// 装饰器:在内部 emitter 广播前,先将消息持久化到 DB
|
||
pub struct PersistingEmittedMessageHandler<H: EmittedMessageHandler> {
|
||
inner: H,
|
||
conversation_repository: Arc<dyn ConversationRepository>,
|
||
session_id: String,
|
||
topic_id: Option<String>,
|
||
}
|
||
|
||
impl<H: EmittedMessageHandler> PersistingEmittedMessageHandler<H> {
|
||
pub fn new(
|
||
inner: H,
|
||
conversation_repository: Arc<dyn ConversationRepository>,
|
||
session_id: impl Into<String>,
|
||
topic_id: Option<String>,
|
||
) -> Self {
|
||
Self { inner, conversation_repository, session_id: session_id.into(), topic_id }
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl<H: EmittedMessageHandler> EmittedMessageHandler for PersistingEmittedMessageHandler<H> {
|
||
async fn handle(&self, message: ChatMessage) {
|
||
if let Err(e) = self.conversation_repository
|
||
.append_message_with_topic(&self.session_id, self.topic_id.as_deref(), &message)
|
||
{
|
||
tracing::error!(error = %e, session_id = %self.session_id,
|
||
"Failed to persist emitted message");
|
||
}
|
||
self.inner.handle(message).await;
|
||
}
|
||
|
||
async fn handle_tool_result(&self, message: ChatMessage, duration_ms: Option<u64>) {
|
||
// Persist the ChatMessage first (no duration field, same as before)
|
||
if let Err(e) = self.conversation_repository
|
||
.append_message_with_topic(&self.session_id, self.topic_id.as_deref(), &message)
|
||
{
|
||
tracing::error!(error = %e, session_id = %self.session_id,
|
||
"Failed to persist emitted message");
|
||
}
|
||
self.inner.handle_tool_result(message, duration_ms).await;
|
||
}
|
||
}
|
||
|
||
/// Source of skill-related prompt text for the agent.
pub trait SkillProvider: Send + Sync + 'static {
    /// Index of available skills for the system prompt, if any.
    fn system_index_prompt(&self) -> Option<String>;

    /// Summary of the skill matching `_name`; by default no skill matches.
    fn matching_skill_summary(&self, _name: &str) -> Option<String> {
        None
    }
}
|
||
|
||
#[derive(Default)]
|
||
#[allow(dead_code)]
|
||
struct EmptySkillProvider;
|
||
|
||
impl SkillProvider for EmptySkillProvider {
|
||
fn system_index_prompt(&self) -> Option<String> {
|
||
None
|
||
}
|
||
}
|
||
|
||
impl AgentLoop {
|
||
pub fn new(config: impl Into<AgentRuntimeConfig>) -> Result<Self, AgentError> {
|
||
let runtime_config = config.into();
|
||
let max_iterations = runtime_config.max_tool_iterations;
|
||
let provider = create_provider(runtime_config.provider.clone())
|
||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||
|
||
Ok(Self {
|
||
runtime_config,
|
||
provider,
|
||
tools: Arc::new(ToolRegistry::new()),
|
||
system_prompt_provider: None,
|
||
skills: None,
|
||
tool_context: ToolContext::default(),
|
||
observer: None,
|
||
emitted_message_handler: None,
|
||
cancel_token: None,
|
||
max_iterations,
|
||
})
|
||
}
|
||
|
||
pub fn with_tools(
|
||
config: impl Into<AgentRuntimeConfig>,
|
||
tools: Arc<ToolRegistry>,
|
||
) -> Result<Self, AgentError> {
|
||
let runtime_config = config.into();
|
||
let max_iterations = runtime_config.max_tool_iterations;
|
||
let provider = create_provider(runtime_config.provider.clone())
|
||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||
|
||
Ok(Self {
|
||
runtime_config,
|
||
provider,
|
||
tools,
|
||
system_prompt_provider: None,
|
||
skills: None,
|
||
tool_context: ToolContext::default(),
|
||
observer: None,
|
||
emitted_message_handler: None,
|
||
cancel_token: None,
|
||
max_iterations,
|
||
})
|
||
}
|
||
|
||
pub fn with_tools_and_skill_provider(
|
||
config: impl Into<AgentRuntimeConfig>,
|
||
tools: Arc<ToolRegistry>,
|
||
skills: Arc<dyn SkillProvider>,
|
||
) -> Result<Self, AgentError> {
|
||
let runtime_config = config.into();
|
||
let max_iterations = runtime_config.max_tool_iterations;
|
||
let provider = create_provider(runtime_config.provider.clone())
|
||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||
|
||
Ok(Self {
|
||
runtime_config,
|
||
provider,
|
||
tools,
|
||
system_prompt_provider: None,
|
||
skills: Some(skills),
|
||
tool_context: ToolContext::default(),
|
||
observer: None,
|
||
emitted_message_handler: None,
|
||
cancel_token: None,
|
||
max_iterations,
|
||
})
|
||
}
|
||
|
||
/// 使用系统提示词提供者创建 AgentLoop
|
||
///
|
||
/// 这是新的推荐方式,支持统一注入 Agent 和 Skill 提示词。
|
||
pub fn with_tools_and_system_prompt_provider(
|
||
config: impl Into<AgentRuntimeConfig>,
|
||
tools: Arc<ToolRegistry>,
|
||
system_prompt_provider: Arc<dyn SystemPromptProvider>,
|
||
skills: Option<Arc<dyn SkillProvider>>,
|
||
) -> Result<Self, AgentError> {
|
||
let runtime_config = config.into();
|
||
let max_iterations = runtime_config.max_tool_iterations;
|
||
let provider = create_provider(runtime_config.provider.clone())
|
||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||
|
||
Ok(Self {
|
||
runtime_config,
|
||
provider,
|
||
tools,
|
||
system_prompt_provider: Some(system_prompt_provider),
|
||
skills,
|
||
tool_context: ToolContext::default(),
|
||
observer: None,
|
||
emitted_message_handler: None,
|
||
cancel_token: None,
|
||
max_iterations,
|
||
})
|
||
}
|
||
|
||
pub fn with_tool_context(mut self, context: ToolContext) -> Self {
|
||
self.tool_context = context;
|
||
self
|
||
}
|
||
|
||
/// Set an observer for tracking events.
|
||
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
||
self.observer = Some(observer);
|
||
self
|
||
}
|
||
|
||
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
|
||
self.emitted_message_handler = Some(handler);
|
||
self
|
||
}
|
||
|
||
/// 设置取消信号接收端。
|
||
///
|
||
/// Agent 在每次迭代开始时检查 `cancel_token.has_changed()`,
|
||
/// 如果已收到取消信号则提前返回。
|
||
pub fn with_cancel_token(mut self, token: tokio::sync::watch::Receiver<()>) -> Self {
|
||
self.cancel_token = Some(token);
|
||
self
|
||
}
|
||
|
||
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
||
&self.tools
|
||
}
|
||
|
||
/// Process a message using the provided conversation history.
|
||
/// History management is handled externally by SessionManager.
|
||
///
|
||
/// This method supports multi-round tool calling: after executing tools,
|
||
/// it loops back to the LLM with the tool results until either:
|
||
/// - The LLM returns no more tool calls (final response)
|
||
/// - Maximum iterations are reached
|
||
///
|
||
/// # 参数
|
||
/// - `messages`: 会话历史消息
|
||
/// - `system_prompt_context`: 系统提示词上下文(用于动态注入,可选)
|
||
pub async fn process(
|
||
&self,
|
||
mut messages: Vec<ChatMessage>,
|
||
system_prompt_context: Option<&SystemPromptContext>,
|
||
) -> Result<AgentProcessResult, AgentError> {
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(
|
||
history_len = messages.len(),
|
||
max_iterations = self.max_iterations,
|
||
"Starting agent process"
|
||
);
|
||
|
||
// Sanitize: remove any trailing incomplete tool call sequences
|
||
// that may have been persisted before a process interruption.
|
||
crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);
|
||
|
||
// Track tool calls for loop detection
|
||
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||
let mut emitted_messages = Vec::new();
|
||
|
||
for iteration in 0..self.max_iterations {
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(iteration, "Agent iteration started");
|
||
|
||
// 检查取消信号
|
||
if let Some(ref token) = self.cancel_token {
|
||
if token.has_changed().unwrap_or(false) {
|
||
tracing::info!(iteration, "Agent execution cancelled by user");
|
||
let cancel_message = format!(
|
||
"\n\n[用户已取消执行。已迭代 {} 次,取消前共生成了 {} 条消息。]",
|
||
iteration,
|
||
emitted_messages.len()
|
||
);
|
||
let assistant_message = ChatMessage::assistant(cancel_message);
|
||
emitted_messages.push(assistant_message.clone());
|
||
self.emit_live_tool_call_message(assistant_message.clone()).await;
|
||
return Ok(AgentProcessResult {
|
||
final_response: assistant_message,
|
||
emitted_messages,
|
||
});
|
||
}
|
||
}
|
||
|
||
// Build request
|
||
let tool_defs = self.tools.get_definitions();
|
||
let tools = if tool_defs.is_empty() {
|
||
None
|
||
} else {
|
||
Some(tool_defs)
|
||
};
|
||
|
||
// 过滤超出轮次和数量限制的图片
|
||
let filtered_messages = filter_images_by_age_and_count(
|
||
&messages,
|
||
self.runtime_config.max_image_age_rounds,
|
||
self.runtime_config.max_images_in_context,
|
||
);
|
||
let image_count = count_supported_image_media_refs(&filtered_messages);
|
||
|
||
// 构建系统提示词(统一注入 Agent 和 Skill 提示词)
|
||
let system_prompt = system_prompt_context.and_then(|ctx| {
|
||
self.system_prompt_provider
|
||
.as_ref()
|
||
.and_then(|provider| provider.build(ctx))
|
||
});
|
||
|
||
let mut text_only_messages: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 2);
|
||
if let Some(ref prompt) = system_prompt {
|
||
text_only_messages.push(Message::system(prompt.content.clone()));
|
||
}
|
||
text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message));
|
||
|
||
let image_tokens = image_token_budget_for_request(
|
||
&self.runtime_config,
|
||
&text_only_messages,
|
||
tools.as_ref(),
|
||
);
|
||
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
|
||
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 2);
|
||
// 使用相同的系统提示词(已构建)
|
||
if let Some(ref prompt) = system_prompt {
|
||
messages_for_llm.push(Message::system(prompt.content.clone()));
|
||
}
|
||
messages_for_llm.extend(
|
||
filtered_messages
|
||
.iter()
|
||
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
|
||
);
|
||
|
||
let request = ChatCompletionRequest {
|
||
messages: messages_for_llm,
|
||
temperature: None,
|
||
max_tokens: None,
|
||
tools,
|
||
};
|
||
|
||
let response = match (*self.provider).chat(request).await {
|
||
Ok(response) => response,
|
||
Err(e) => {
|
||
tracing::error!(
|
||
provider = %self.provider.name(),
|
||
model = %self.provider.model_id(),
|
||
error = %e,
|
||
error_details = %format_error_chain(e.as_ref()),
|
||
"LLM request failed"
|
||
);
|
||
let assistant_message =
|
||
ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||
emitted_messages.push(assistant_message.clone());
|
||
self.emit_live_tool_call_message(assistant_message.clone()).await;
|
||
return Ok(AgentProcessResult {
|
||
final_response: assistant_message,
|
||
emitted_messages,
|
||
});
|
||
}
|
||
};
|
||
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(
|
||
iteration,
|
||
response_len = response.content.len(),
|
||
tool_calls_len = response.tool_calls.len(),
|
||
"LLM response received"
|
||
);
|
||
|
||
// If no tool calls, this is the final response
|
||
if response.tool_calls.is_empty() {
|
||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||
{
|
||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||
} else {
|
||
ChatMessage::assistant(response.content)
|
||
};
|
||
emitted_messages.push(assistant_message.clone());
|
||
self.emit_live_tool_call_message(assistant_message.clone()).await;
|
||
return Ok(AgentProcessResult {
|
||
final_response: assistant_message,
|
||
emitted_messages,
|
||
});
|
||
}
|
||
|
||
// Execute tool calls
|
||
tracing::info!(
|
||
iteration,
|
||
count = response.tool_calls.len(),
|
||
"Tool calls detected, executing tools"
|
||
);
|
||
|
||
// Add assistant message with tool calls
|
||
let assistant_message =
|
||
if let Some(reasoning_content) = response.reasoning_content.clone() {
|
||
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
||
response.content.clone(),
|
||
response.tool_calls.clone(),
|
||
reasoning_content,
|
||
)
|
||
} else {
|
||
ChatMessage::assistant_with_tool_calls(
|
||
response.content.clone(),
|
||
response.tool_calls.clone(),
|
||
)
|
||
};
|
||
messages.push(assistant_message.clone());
|
||
emitted_messages.push(assistant_message);
|
||
self.emit_live_tool_call_message(
|
||
emitted_messages
|
||
.last()
|
||
.expect("assistant message just pushed")
|
||
.clone(),
|
||
)
|
||
.await;
|
||
|
||
// Execute tools and add results to messages
|
||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||
|
||
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
||
// Truncate tool result if too large
|
||
let truncated_output =
|
||
truncate_tool_result(&result.output, self.runtime_config.tool_result_max_chars);
|
||
|
||
// Record tool call and check for loops
|
||
let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments);
|
||
|
||
match loop_result {
|
||
LoopDetectionResult::Warning(msg) => {
|
||
// Add warning and proceed
|
||
tracing::warn!(
|
||
tool = %tool_call.name,
|
||
"Loop warning: {}",
|
||
msg
|
||
);
|
||
let tool_message = ChatMessage::tool_with_state(
|
||
tool_call.id.clone(),
|
||
tool_call.name.clone(),
|
||
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
||
if result.state == ToolExecutionState::PendingUserAction {
|
||
ToolMessageState::PendingUserAction
|
||
} else {
|
||
ToolMessageState::Completed
|
||
},
|
||
)
|
||
.with_tool_duration(result.duration.as_millis() as u64);
|
||
messages.push(tool_message.clone());
|
||
emitted_messages.push(tool_message.clone());
|
||
let duration_ms = Some(result.duration.as_millis() as u64);
|
||
self.emit_tool_result(tool_message, duration_ms).await;
|
||
}
|
||
LoopDetectionResult::Ok => {
|
||
let tool_message = ChatMessage::tool_with_state(
|
||
tool_call.id.clone(),
|
||
tool_call.name.clone(),
|
||
truncated_output,
|
||
if result.state == ToolExecutionState::PendingUserAction {
|
||
ToolMessageState::PendingUserAction
|
||
} else {
|
||
ToolMessageState::Completed
|
||
},
|
||
)
|
||
.with_tool_duration(result.duration.as_millis() as u64);
|
||
messages.push(tool_message.clone());
|
||
emitted_messages.push(tool_message.clone());
|
||
let duration_ms = Some(result.duration.as_millis() as u64);
|
||
self.emit_tool_result(tool_message, duration_ms).await;
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Some((tool_call, pending_result)) = response
|
||
.tool_calls
|
||
.iter()
|
||
.zip(tool_results.iter())
|
||
.find(|(_, result)| result.state == ToolExecutionState::PendingUserAction)
|
||
{
|
||
// 从工具输出中提取有意义的终端内容
|
||
// 跳过:标记行、session_id 元数据、空行、以及提示行(取提示行之后的实际内容)
|
||
let content: String = pending_result
|
||
.output
|
||
.lines()
|
||
.map(|l| l.trim())
|
||
.filter(|line| {
|
||
!line.is_empty()
|
||
&& !line.starts_with("__PICOBOT_")
|
||
&& !line.starts_with("[session_id:")
|
||
})
|
||
.skip(1) // 跳过第一行(提示行,如"进程正在等待输入...")
|
||
.take(20) // 最多取 20 行避免过长
|
||
.collect::<Vec<_>>()
|
||
.join("\n");
|
||
let display_content = if content.is_empty() {
|
||
DEFAULT_PENDING_ASSISTANT_MESSAGE
|
||
} else {
|
||
&content
|
||
};
|
||
let assistant_message = ChatMessage::assistant(format!(
|
||
"{}\n\n当前等待中的工具: {}",
|
||
display_content, tool_call.name,
|
||
));
|
||
emitted_messages.push(assistant_message.clone());
|
||
self.emit_live_tool_call_message(assistant_message.clone()).await;
|
||
return Ok(AgentProcessResult {
|
||
final_response: assistant_message,
|
||
emitted_messages,
|
||
});
|
||
}
|
||
|
||
// Loop continues to next iteration with updated messages
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(
|
||
iteration,
|
||
message_count = messages.len(),
|
||
"Tool execution complete, continuing to next iteration"
|
||
);
|
||
}
|
||
|
||
// Max iterations reached - ask LLM for a summary based on completed work
|
||
tracing::warn!("Max iterations reached, requesting final summary from LLM");
|
||
|
||
// Add a message asking for summary
|
||
let summary_request = ChatMessage::user(
|
||
"You have reached the maximum number of tool call iterations. \
|
||
Please provide your best answer based on the work completed so far.",
|
||
);
|
||
messages.push(summary_request);
|
||
|
||
// 过滤超出轮次和数量限制的图片
|
||
let filtered_messages = filter_images_by_age_and_count(
|
||
&messages,
|
||
self.runtime_config.max_image_age_rounds,
|
||
self.runtime_config.max_images_in_context,
|
||
);
|
||
|
||
// Convert messages to LLM format (使用系统提示词提供者)
|
||
let image_count = count_supported_image_media_refs(&filtered_messages);
|
||
let mut text_only_messages: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 1);
|
||
if let Some(ref provider) = self.system_prompt_provider {
|
||
if let Some(ctx) = system_prompt_context {
|
||
if let Some(prompt) = provider.build(ctx) {
|
||
text_only_messages.push(Message::system(prompt.content.clone()));
|
||
}
|
||
}
|
||
}
|
||
text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message));
|
||
let image_tokens =
|
||
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
|
||
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
|
||
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 1);
|
||
if let Some(ref provider) = self.system_prompt_provider {
|
||
if let Some(ctx) = system_prompt_context {
|
||
if let Some(prompt) = provider.build(ctx) {
|
||
messages_for_llm.push(Message::system(prompt.content.clone()));
|
||
}
|
||
}
|
||
}
|
||
messages_for_llm.extend(
|
||
filtered_messages
|
||
.iter()
|
||
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
|
||
);
|
||
|
||
let request = ChatCompletionRequest {
|
||
messages: messages_for_llm,
|
||
temperature: None,
|
||
max_tokens: None,
|
||
tools: None, // No tools in final summary call
|
||
};
|
||
|
||
match (*self.provider).chat(request).await {
|
||
Ok(response) => {
|
||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||
{
|
||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||
} else {
|
||
ChatMessage::assistant(response.content)
|
||
};
|
||
emitted_messages.push(assistant_message.clone());
|
||
self.emit_live_tool_call_message(assistant_message.clone()).await;
|
||
Ok(AgentProcessResult {
|
||
final_response: assistant_message,
|
||
emitted_messages,
|
||
})
|
||
}
|
||
Err(e) => {
|
||
tracing::error!(
|
||
provider = %self.provider.name(),
|
||
model = %self.provider.model_id(),
|
||
error = %e,
|
||
error_details = %format_error_chain(e.as_ref()),
|
||
"Failed to get summary from LLM"
|
||
);
|
||
let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||
emitted_messages.push(final_message.clone());
|
||
self.emit_live_tool_call_message(final_message.clone()).await;
|
||
Ok(AgentProcessResult {
|
||
final_response: final_message,
|
||
emitted_messages,
|
||
})
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Forwards a message produced while running the tool-call loop to the
/// configured emitted-message handler; does nothing when no handler is set.
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
    if let Some(handler) = self.emitted_message_handler.as_ref() {
        handler.handle(message).await;
    }
}
|
||
|
||
/// Hands a tool-result message, together with its optional duration in
/// milliseconds, to the configured emitted-message handler; a no-op when no
/// handler is set.
async fn emit_tool_result(&self, message: ChatMessage, duration_ms: Option<u64>) {
    if let Some(handler) = self.emitted_message_handler.as_ref() {
        handler.handle_tool_result(message, duration_ms).await;
    }
}
|
||
|
||
/// Determine whether to execute tools in parallel or sequentially.
|
||
///
|
||
/// Returns true if:
|
||
/// - There are multiple tool calls
|
||
/// - None of the tools require sequential execution (tool_search, non-concurrency-safe)
|
||
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
|
||
if tool_calls.len() <= 1 {
|
||
return false;
|
||
}
|
||
|
||
// tool_search must run sequentially to avoid MCP activation race conditions
|
||
if tool_calls.iter().any(|tc| tc.name == "tool_search") {
|
||
return false;
|
||
}
|
||
|
||
// Multiple Task tools can run in parallel
|
||
// Task tools create independent subagents with isolated contexts, no shared state
|
||
let task_count = tool_calls.iter().filter(|tc| tc.name == "task").count();
|
||
if task_count > 1 && task_count == tool_calls.len() {
|
||
return true;
|
||
}
|
||
|
||
// When Task is mixed with other tools, keep sequential to avoid complexity
|
||
if task_count > 0 {
|
||
return false;
|
||
}
|
||
|
||
// All tools must be concurrency-safe to run in parallel
|
||
tool_calls.iter().all(|tc| {
|
||
self.tools
|
||
.get(&tc.name)
|
||
.map(|t| t.concurrency_safe())
|
||
.unwrap_or(false)
|
||
})
|
||
}
|
||
|
||
/// Execute multiple tool calls, choosing parallel or sequential based on conditions.
|
||
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
||
if self.should_execute_in_parallel(tool_calls) {
|
||
tracing::debug!("Executing {} tools in parallel", tool_calls.len());
|
||
self.execute_tools_parallel(tool_calls).await
|
||
} else {
|
||
tracing::debug!("Executing {} tools sequentially", tool_calls.len());
|
||
self.execute_tools_sequential(tool_calls).await
|
||
}
|
||
}
|
||
|
||
/// Execute tools in parallel using join_all.
|
||
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
||
let futures: Vec<_> = tool_calls
|
||
.iter()
|
||
.map(|tc| self.execute_one_tool(tc))
|
||
.collect();
|
||
|
||
futures_util::future::join_all(futures).await
|
||
}
|
||
|
||
/// Execute tools sequentially.
|
||
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
||
let mut outcomes = Vec::with_capacity(tool_calls.len());
|
||
|
||
for tool_call in tool_calls {
|
||
outcomes.push(self.execute_one_tool(tool_call).await);
|
||
}
|
||
|
||
outcomes
|
||
}
|
||
|
||
/// Execute a single tool and return the outcome with event tracking.
|
||
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
|
||
let start = Instant::now();
|
||
let tool_name = tool_call.name.clone();
|
||
|
||
// Log function call with name and arguments before execution
|
||
let args_str = match &tool_call.arguments {
|
||
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
||
other => {
|
||
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
|
||
}
|
||
};
|
||
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
||
|
||
// Record ToolCallStart event
|
||
if let Some(ref observer) = self.observer {
|
||
observer.record_event(&ObserverEvent::ToolCallStart {
|
||
tool: tool_name.clone(),
|
||
arguments: Some(truncate_args(&tool_call.arguments, 300)),
|
||
});
|
||
}
|
||
|
||
let result = self.execute_tool_internal(tool_call).await;
|
||
let duration = start.elapsed();
|
||
|
||
// Record ToolCall event
|
||
if let Some(ref observer) = self.observer {
|
||
observer.record_event(&ObserverEvent::ToolCall {
|
||
tool: tool_name.clone(),
|
||
duration,
|
||
success: result.success,
|
||
});
|
||
}
|
||
|
||
// Apply duration
|
||
ToolExecutionOutcome { duration, ..result }
|
||
}
|
||
|
||
/// Internal tool execution without event tracking.
|
||
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
|
||
let normalized_arguments = normalize_tool_arguments(&tool_call.arguments);
|
||
|
||
let tool = match self.tools.get(&tool_call.name) {
|
||
Some(t) => t,
|
||
None => {
|
||
tracing::warn!(tool = %tool_call.name, "Tool not found");
|
||
let skill_hint = self.skills
|
||
.as_ref()
|
||
.and_then(|s| s.matching_skill_summary(&tool_call.name));
|
||
let error = match skill_hint {
|
||
Some(summary) => format!(
|
||
"Tool '{}' not found. A skill with the same name exists: {}. Skills are not tools. Call skill_activate with {{\"name\": \"{}\"}} first.",
|
||
tool_call.name, summary, tool_call.name
|
||
),
|
||
None => format!("Tool '{}' not found", tool_call.name),
|
||
};
|
||
return ToolExecutionOutcome::failure(
|
||
format!("Error: {}", error),
|
||
Some(error),
|
||
);
|
||
}
|
||
};
|
||
|
||
match tool
|
||
.execute_with_context(&self.tool_context, normalized_arguments.clone())
|
||
.await
|
||
{
|
||
Ok(result) => {
|
||
if result.success {
|
||
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
|
||
ToolExecutionOutcome::pending(pending_output)
|
||
} else {
|
||
ToolExecutionOutcome::success(result.output)
|
||
}
|
||
} else {
|
||
let error = result.error.unwrap_or_default();
|
||
tracing::error!(
|
||
tool = %tool_call.name,
|
||
args = %truncate_args(&tool_call.arguments, 4_000),
|
||
normalized_args = %truncate_args(&normalized_arguments, 4_000),
|
||
error = %error,
|
||
output = %result.output,
|
||
"Tool returned an error result"
|
||
);
|
||
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::error!(
|
||
tool = %tool_call.name,
|
||
args = %truncate_args(&tool_call.arguments, 4_000),
|
||
normalized_args = %truncate_args(&normalized_arguments, 4_000),
|
||
error = %e,
|
||
error_details = %format!("{:#}", e),
|
||
"Tool execution failed"
|
||
);
|
||
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
use crate::config::LLMProviderConfig;
|
||
use crate::observability::{MultiObserver, Observer};
|
||
use tempfile::tempdir;
|
||
|
||
struct TestObserver {
|
||
events: std::sync::Mutex<Vec<ObserverEvent>>,
|
||
}
|
||
|
||
impl TestObserver {
|
||
fn new() -> Self {
|
||
Self {
|
||
events: std::sync::Mutex::new(Vec::new()),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Observer for TestObserver {
|
||
fn record_event(&self, event: &ObserverEvent) {
|
||
self.events.lock().unwrap().push(event.clone());
|
||
}
|
||
|
||
fn name(&self) -> &str {
|
||
"test_observer"
|
||
}
|
||
}
|
||
|
||
struct TestSkillProvider;
|
||
|
||
impl SkillProvider for TestSkillProvider {
|
||
fn system_index_prompt(&self) -> Option<String> {
|
||
None
|
||
}
|
||
|
||
fn matching_skill_summary(&self, name: &str) -> Option<String> {
|
||
(name == "baidu-search").then(|| "用于百度搜索和天气查询的技能".to_string())
|
||
}
|
||
}
|
||
|
||
fn test_runtime_config() -> LLMProviderConfig {
|
||
LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "test".to_string(),
|
||
base_url: "http://localhost".to_string(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: std::collections::HashMap::new(),
|
||
llm_timeout_secs: 120,
|
||
memory_maintenance_timeout_secs: 600,
|
||
model_id: "test-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: std::collections::HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_missing_tool_with_same_name_skill_returns_activation_hint() {
|
||
let loop_instance = AgentLoop::with_tools_and_skill_provider(
|
||
test_runtime_config(),
|
||
Arc::new(ToolRegistry::new()),
|
||
Arc::new(TestSkillProvider),
|
||
)
|
||
.unwrap();
|
||
|
||
let outcome = loop_instance
|
||
.execute_tool_internal(&ToolCall {
|
||
id: "call-1".to_string(),
|
||
name: "baidu-search".to_string(),
|
||
arguments: serde_json::json!({
|
||
"queries": "佛山今天几点下雨"
|
||
}),
|
||
})
|
||
.await;
|
||
|
||
assert_eq!(outcome.state, ToolExecutionState::Completed);
|
||
assert!(!outcome.success);
|
||
assert!(outcome.output.contains("技能"));
|
||
assert!(outcome.output.contains("skill_activate"));
|
||
assert!(outcome.output.contains("baidu-search"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_observer_receives_tool_events() {
|
||
// Verify MultiObserver works
|
||
let mut multi = MultiObserver::new();
|
||
multi.add_observer(Box::new(TestObserver::new()));
|
||
|
||
let event = ObserverEvent::ToolCallStart {
|
||
tool: "test".to_string(),
|
||
arguments: Some("{}".to_string()),
|
||
};
|
||
multi.record_event(&event);
|
||
|
||
// Just verify the structure works
|
||
assert_eq!(multi.len(), 1);
|
||
}
|
||
|
||
#[test]
|
||
fn test_should_execute_in_parallel_single_tool() {
|
||
// Would need a proper setup with AgentLoop to test fully
|
||
// For now, just verify the logic: single tool should return false
|
||
let calls = vec![ToolCall {
|
||
id: "1".to_string(),
|
||
name: "test".to_string(),
|
||
arguments: serde_json::json!({}),
|
||
}];
|
||
|
||
// If there's only 1 tool, should return false regardless
|
||
assert_eq!(calls.len() <= 1, true);
|
||
}
|
||
|
||
#[test]
fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
    let call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({ "expression": "2+2" }),
    };
    let source = ChatMessage::assistant_with_tool_calls("calling tool", vec![call]);

    let mut budget = ImageInlineBudget::new(0, 0);
    let converted = chat_message_to_llm_message(&source, &mut budget);

    assert_eq!(converted.role, "assistant");
    let calls = converted.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].id, "call_1");
    assert_eq!(calls[0].name, "calculator");
}
|
||
|
||
#[test]
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
    let source = ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");
    let mut budget = ImageInlineBudget::new(0, 0);

    let converted = chat_message_to_llm_message(&source, &mut budget);

    assert_eq!(converted.role, "assistant");
    assert_eq!(
        converted.reasoning_content.as_deref(),
        Some("hidden chain of thought")
    );
}
|
||
|
||
#[test]
fn test_truncate_tool_result_handles_utf8_char_boundaries() {
    // Each '范' is 3 bytes in UTF-8, and 100_000 is not a multiple of 3, so a
    // byte-based cut would split a character and panic. Reaching the
    // assertions shows the truncation respects character boundaries.
    let input = "范".repeat(100_500);

    let output = truncate_tool_result(&input, 100_000);

    assert!(output.contains("Output truncated"));
    assert!(output.contains('范'));
    assert!(output.chars().count() < input.chars().count());
}
|
||
|
||
#[test]
fn test_parse_pending_tool_output() {
    let marked = "__PICOBOT_PENDING_USER_ACTION__\n请完成授权";
    let parsed = parse_pending_tool_output(marked);
    assert_eq!(parsed.as_deref(), Some("请完成授权"));

    // Output without the marker is not a pending action.
    let unmarked = parse_pending_tool_output("normal output");
    assert!(unmarked.is_none());
}
|
||
|
||
#[test]
fn test_normalize_tool_arguments_parses_stringified_json() {
    let raw = serde_json::Value::String("{\"command\":\"ls -la\"}".to_string());
    let expected = serde_json::json!({ "command": "ls -la" });

    assert_eq!(normalize_tool_arguments(&raw), expected);
}
|
||
|
||
#[test]
fn test_normalize_tool_arguments_keeps_plain_string() {
    // "plain text" is not valid JSON, so it must come back unchanged.
    let raw = serde_json::Value::String("plain text".to_string());

    assert_eq!(normalize_tool_arguments(&raw), raw);
}
|
||
|
||
#[test]
fn test_build_content_blocks_skips_non_image_media_refs() {
    let dir = tempdir().unwrap();
    let pdf = dir.path().join("demo.pdf");
    std::fs::write(&pdf, b"%PDF-1.4").unwrap();
    let media = [pdf.to_string_lossy().to_string()];

    let mut budget = ImageInlineBudget::new(1_000, 0);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    // Only the text block is produced for a non-image attachment.
    assert_eq!(blocks.len(), 1);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
}
|
||
|
||
#[test]
fn test_build_content_blocks_keeps_supported_images() {
    let dir = tempdir().unwrap();
    let jpg = dir.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg).unwrap();
    let media = [jpg.to_string_lossy().to_string()];

    let mut budget = ImageInlineBudget::new(10_000, 1);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    let inlined_as_jpeg = matches!(
        &blocks[1],
        ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")
    );
    assert!(inlined_as_jpeg);
}
|
||
|
||
#[test]
fn test_build_content_blocks_compresses_images_to_budget() {
    let dir = tempdir().unwrap();
    let png = dir.path().join("large.png");
    image::DynamicImage::new_rgb8(512, 512).save(&png).unwrap();
    let media = [png.to_string_lossy().to_string()];

    // A 512x512 PNG with a tight 512-token budget is still expected to be
    // inlined, re-encoded as JPEG.
    let mut budget = ImageInlineBudget::new(512, 1);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    let inlined_as_jpeg = matches!(
        &blocks[1],
        ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")
    );
    assert!(inlined_as_jpeg);
}
|
||
|
||
#[test]
fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() {
    let dir = tempdir().unwrap();
    let jpg = dir.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg).unwrap();
    let media = [jpg.to_string_lossy().to_string()];

    // Zero token budget: the image cannot be inlined, so a text notice replaces it.
    let mut budget = ImageInlineBudget::new(0, 1);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    let notice_added = matches!(
        &blocks[1],
        ContentBlock::Text { text }
            if text.contains("图片未能成功入模") && text.contains("demo.jpg")
    );
    assert!(notice_added);
}
|
||
|
||
#[test]
fn test_filter_images_by_age_and_count_no_images() {
    let messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant("hi"),
        ChatMessage::user("how are you"),
    ];

    let filtered = filter_images_by_age_and_count(&messages, 10, 5);

    // Without images the messages pass through unchanged.
    assert_eq!(filtered.len(), 3);
    for (message, expected) in filtered.iter().zip(["hello", "hi", "how are you"]) {
        assert_eq!(message.content, expected);
    }
}
|
||
|
||
#[test]
fn test_filter_images_by_age_limit() {
    let dir = tempdir().unwrap();
    let jpg = dir.path().join("test.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg).unwrap();
    let jpg_ref = jpg.to_string_lossy().to_string();

    // 12 messages. Only the oldest (index 0) carries an image; the rest
    // alternate between user (even index) and assistant (odd index).
    let mut messages = vec![ChatMessage::user_with_media(
        "first message with image",
        vec![jpg_ref],
    )];
    for i in 1..12 {
        let message = if i % 2 == 0 {
            ChatMessage::user(format!("user message {}", i))
        } else {
            ChatMessage::assistant(format!("assistant message {}", i))
        };
        messages.push(message);
    }

    // With max_age_rounds = 10, message 0 is 11 positions behind the newest
    // (index 11), so its image must be filtered out.
    let filtered = filter_images_by_age_and_count(&messages, 10, 100);

    assert_eq!(filtered[0].media_refs.len(), 0);
    assert!(filtered[0].content.contains("图片已过期"));
}
|
||
|
||
#[test]
fn test_filter_images_by_count_limit() {
    let dir = tempdir().unwrap();

    // Three messages, each with its own distinct image file.
    let mut messages: Vec<ChatMessage> = Vec::new();
    for i in 0..3 {
        let path = dir.path().join(format!("test{}.jpg", i));
        image::DynamicImage::new_rgb8(8, 8).save(&path).unwrap();
        messages.push(ChatMessage::user_with_media(
            format!("message {}", i),
            vec![path.to_string_lossy().to_string()],
        ));
    }

    // max_images = 1: only the newest message (index 2) keeps its image.
    let filtered = filter_images_by_age_and_count(&messages, 100, 1);

    assert_eq!(filtered[2].media_refs.len(), 1);
    for dropped in filtered.iter().take(2) {
        assert_eq!(dropped.media_refs.len(), 0);
        assert!(dropped.content.contains("超出最大图片数量限制"));
    }
}
|
||
|
||
#[test]
fn test_filter_images_combined_limits() {
    let dir = tempdir().unwrap();

    // Five distinct image files.
    let jpg_paths: Vec<String> = (0..5)
        .map(|i| {
            let path = dir.path().join(format!("test{}.jpg", i));
            image::DynamicImage::new_rgb8(8, 8).save(&path).unwrap();
            path.to_string_lossy().into_owned()
        })
        .collect();

    // 20 messages. Images sit at indices 0, 5, 10, 15 and 19; the rest
    // alternate between user (even index) and assistant (odd index).
    let image_slots: [usize; 5] = [0, 5, 10, 15, 19];
    let messages: Vec<ChatMessage> = (0..20usize)
        .map(|i| {
            if image_slots.contains(&i) {
                ChatMessage::user_with_media(
                    format!("message {} with image", i),
                    vec![jpg_paths[(i / 5).min(4)].clone()],
                )
            } else if i % 2 == 0 {
                ChatMessage::user(format!("user message {}", i))
            } else {
                ChatMessage::assistant(format!("assistant message {}", i))
            }
        })
        .collect();

    // max_age_rounds = 10 and max_images = 3. Ages relative to the newest
    // message: 19 -> 0, 15 -> 4, 10 -> 9, 5 -> 14, 0 -> 19. Indices 5 and 0
    // exceed the age limit; 19, 15 and 10 keep their images (3, within the cap).
    let filtered = filter_images_by_age_and_count(&messages, 10, 3);

    assert!(!filtered[19].media_refs.is_empty(), "最新消息应保留图片");
    assert!(!filtered[15].media_refs.is_empty(), "age=4 的消息应保留图片");
    assert!(!filtered[10].media_refs.is_empty(), "age=9 的消息应保留图片");
    assert_eq!(filtered[5].media_refs.len(), 0, "age=14 的消息图片应被过滤");
    assert!(filtered[5].content.contains("超出 10 条消息范围"));
    assert_eq!(filtered[0].media_refs.len(), 0, "age=19 的消息图片应被过滤");
    assert!(filtered[0].content.contains("超出 10 条消息范围"));
}
|
||
|
||
// ====================
|
||
// sanitize_incomplete_tool_call_sequences tests
|
||
// ====================
|
||
|
||
#[test]
fn test_sanitize_removes_trailing_incomplete_tool_call_sequence() {
    let pending_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    // The assistant requests call_1, but no tool result follows.
    let mut messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant_with_tool_calls("calling tool", vec![pending_call]),
    ];

    let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

    assert_eq!(removed, 1);
    assert_eq!(messages.len(), 1);
    assert_eq!(messages[0].role, "user");
}
|
||
|
||
#[test]
fn test_sanitize_preserves_complete_tool_call_sequence() {
    let calculator_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let mut messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant_with_tool_calls("calling tool", vec![calculator_call]),
        ChatMessage::tool("call_1", "calculator", "2"),
    ];

    let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

    // The call has its result, so nothing is removed.
    assert_eq!(removed, 0);
    assert_eq!(messages.len(), 3);
}
|
||
|
||
#[test]
fn test_sanitize_removes_multiple_incomplete_sequences() {
    let first_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let second_call = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    // Neither call_1 nor call_2 has a tool result.
    let mut messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant_with_tool_calls("first tool call", vec![first_call]),
        ChatMessage::user("second question"),
        ChatMessage::assistant_with_tool_calls("second tool call", vec![second_call]),
    ];

    let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

    // Both assistant messages with incomplete tool calls are dropped; the user
    // messages remain in order.
    assert_eq!(removed, 2);
    assert_eq!(messages.len(), 2);
    for (message, expected) in messages.iter().zip(["hello", "second question"]) {
        assert_eq!(message.role, "user");
        assert_eq!(message.content, expected);
    }
}
|
||
|
||
#[test]
fn test_sanitize_removes_assistant_when_partial_tool_results() {
    let answered = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let unanswered = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    // Two tool calls, but only call_1 has a result.
    let mut messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant_with_tool_calls("calling two tools", vec![answered, unanswered]),
        ChatMessage::tool("call_1", "calculator", "2"),
    ];

    let removed_count = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

    // First the assistant message is removed, because call_2 has no result.
    // Then call_1's result is removed as an orphan, since its parent
    // assistant message is gone.
    assert_eq!(removed_count, 2);
    assert_eq!(messages.len(), 1);
    assert_eq!(messages[0].role, "user");
}
|
||
|
||
#[test]
fn test_sanitize_preserves_messages_without_tool_calls() {
    let mut messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant("hi there"),
        ChatMessage::user("how are you"),
    ];
    let original_len = messages.len();

    let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

    assert_eq!(removed, 0);
    assert_eq!(messages.len(), original_len);
}
|
||
|
||
#[test]
fn test_sanitize_handles_empty_messages() {
    let mut messages = Vec::<ChatMessage>::new();

    let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

    assert_eq!(removed, 0);
}
|
||
|
||
#[test]
fn test_sanitize_removes_orphaned_tool_messages() {
    // A tool result with no preceding assistant `tool_calls` turn has no
    // parent to attach to, so the sanitizer must discard it.
    let orphan = ChatMessage::tool("call_1", "calculator", "2");
    let mut history = vec![orphan];

    let dropped = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut history);

    assert_eq!(dropped, 1);
    assert!(history.is_empty());
}
|
||
|
||
#[test]
fn test_sanitize_preserves_complete_sequence_with_multiple_tool_calls() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    // Every requested call has a matching result, so nothing may be removed.
    let calculator_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let read_call = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    let mut history = vec![
        ChatMessage::user("do two things"),
        ChatMessage::assistant_with_tool_calls(
            "calling two tools",
            vec![calculator_call, read_call],
        ),
        ChatMessage::tool("call_1", "calculator", "2"),
        ChatMessage::tool("call_2", "read", "contents of README"),
    ];

    let dropped = sanitize(&mut history);

    assert_eq!(dropped, 0);
    assert_eq!(history.len(), 4);
}
|
||
|
||
#[test]
fn test_sanitize_only_trims_trailing_incomplete_sequence() {
    // A complete tool-call sequence followed by an incomplete one: only the
    // trailing incomplete assistant turn should be removed.
    let mut messages = vec![
        ChatMessage::user("first question"),
        ChatMessage::assistant_with_tool_calls(
            "first tool call",
            vec![ToolCall {
                id: "call_1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"expression": "1+1"}),
            }],
        ),
        ChatMessage::tool("call_1", "calculator", "2"),
        ChatMessage::assistant("the answer is 2"),
        ChatMessage::user("second question"),
        ChatMessage::assistant_with_tool_calls(
            "second tool call",
            vec![ToolCall {
                id: "call_2".to_string(),
                name: "read".to_string(),
                arguments: serde_json::json!({"path": "README.md"}),
            }],
        ),
        // Missing tool result for call_2 — only THIS sequence should be trimmed.
    ];

    let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

    assert_eq!(removed, 1);
    // The first complete sequence (4 messages) plus the second user message
    // survive: 6 original messages - 1 removed = 5.
    assert_eq!(messages.len(), 5);
    assert_eq!(messages[0].content, "first question");
    // The complete call_1 assistant/tool pair must be kept intact.
    assert_eq!(messages[1].role, "assistant");
    assert_eq!(messages[2].role, "tool");
    assert_eq!(messages[2].tool_call_id.as_deref(), Some("call_1"));
    assert_eq!(messages[3].content, "the answer is 2");
    assert_eq!(messages[4].content, "second question");
}
|
||
|
||
#[test]
fn test_sanitize_removes_mid_history_orphaned_tool_calls() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    // Regression: an unanswered tool call sits in the MIDDLE of the history,
    // ahead of a complete sequence. A sanitizer that only inspects the tail
    // stops at the complete sequence and never reaches the orphan.
    let orphaned_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let answered_call = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    let mut history = vec![
        ChatMessage::user("first question"),
        ChatMessage::assistant_with_tool_calls("orphaned tool call", vec![orphaned_call]),
        // No result for call_1 — the orphan in the middle.
        ChatMessage::user("second question"),
        ChatMessage::assistant_with_tool_calls("valid tool call", vec![answered_call]),
        ChatMessage::tool("call_2", "read", "file contents"),
    ];

    let dropped = sanitize(&mut history);

    // Only the orphaned call_1 assistant turn goes; the other four stay.
    assert_eq!(dropped, 1);
    assert_eq!(history.len(), 4);

    // Both user turns survive in their original order...
    assert_eq!(history[0].role, "user");
    assert_eq!(history[0].content, "first question");
    assert_eq!(history[1].role, "user");
    assert_eq!(history[1].content, "second question");

    // ...followed by the intact call_2 assistant/tool pair.
    assert_eq!(history[2].role, "assistant");
    assert_eq!(history[3].role, "tool");
}
|
||
|
||
#[test]
fn test_sanitize_removes_multiple_mid_history_orphans() {
    // Two unanswered tool calls are scattered through the history; both
    // assistant turns must go while every user turn stays.
    let first_orphan = vec![ToolCall {
        id: "orphan_1".to_string(),
        name: "tool_a".to_string(),
        arguments: serde_json::json!({}),
    }];
    let second_orphan = vec![ToolCall {
        id: "orphan_2".to_string(),
        name: "tool_b".to_string(),
        arguments: serde_json::json!({}),
    }];
    let mut history = vec![
        ChatMessage::user("first"),
        ChatMessage::assistant_with_tool_calls("orphan 1", first_orphan),
        // No result for orphan_1.
        ChatMessage::user("second"),
        ChatMessage::assistant_with_tool_calls("orphan 2", second_orphan),
        // No result for orphan_2.
        ChatMessage::user("third"),
    ];

    let dropped = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut history);

    assert_eq!(dropped, 2);
    assert_eq!(history.len(), 3);
    for (idx, expected) in ["first", "second", "third"].iter().copied().enumerate() {
        assert_eq!(history[idx].content, expected);
    }
}
|
||
|
||
#[test]
fn test_sanitize_removes_orphaned_tool_results_for_removed_mid_assistant() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    // A mid-history assistant turn issues two calls but only one is answered.
    // The turn is therefore incomplete, and once it is removed its single
    // tool result is orphaned and must be removed as well.
    let partially_answered = vec![
        ToolCall {
            id: "call_has_result".to_string(),
            name: "tool_a".to_string(),
            arguments: serde_json::json!({}),
        },
        ToolCall {
            id: "call_no_result".to_string(),
            name: "tool_b".to_string(),
            arguments: serde_json::json!({}),
        },
    ];
    let fully_answered = vec![ToolCall {
        id: "call_valid".to_string(),
        name: "good_tool".to_string(),
        arguments: serde_json::json!({}),
    }];

    let mut history = vec![
        ChatMessage::user("first question"),
        ChatMessage::assistant_with_tool_calls(
            "two tool calls, only one has result",
            partially_answered,
        ),
        ChatMessage::tool("call_has_result", "tool_a", "some result"),
        // No result for call_no_result, so the sequence above is incomplete.
        ChatMessage::user("second question"),
        ChatMessage::assistant_with_tool_calls("valid tool call", fully_answered),
        ChatMessage::tool("call_valid", "good_tool", "valid result"),
    ];

    let dropped = sanitize(&mut history);

    // The incomplete assistant turn plus its orphaned tool result: 2 removed.
    assert_eq!(dropped, 2);
    assert_eq!(history.len(), 4);
    assert_eq!(history[0].content, "first question");
    assert_eq!(history[1].content, "second question");
    assert_eq!(history[2].role, "assistant");
    assert_eq!(history[3].role, "tool");
    // The remaining tool result belongs to the complete call_valid sequence.
    assert_eq!(history[3].tool_call_id.as_deref(), Some("call_valid"));
}
|
||
|
||
#[test]
fn test_sanitize_handles_complex_interleaved_history() {
    // Complete → orphaned → complete: a realistic shape after history
    // compaction or an interrupted run. Only the orphaned middle assistant
    // turn should be removed; both complete sequences must survive intact.
    let mut messages = vec![
        // Task 1 — complete sequence.
        ChatMessage::user("task 1"),
        ChatMessage::assistant_with_tool_calls(
            "doing task 1",
            vec![ToolCall {
                id: "t1_call".to_string(),
                name: "read".to_string(),
                arguments: serde_json::json!({"path": "a.txt"}),
            }],
        ),
        ChatMessage::tool("t1_call", "read", "content A"),
        ChatMessage::assistant("task 1 is done"),
        // Task 2 — orphaned: neither tool result was recorded because the
        // process was killed mid-turn.
        ChatMessage::user("task 2"),
        ChatMessage::assistant_with_tool_calls(
            "doing task 2 — this got interrupted",
            vec![
                ToolCall {
                    id: "t2_call_1".to_string(),
                    name: "write".to_string(),
                    arguments: serde_json::json!({"path": "b.txt"}),
                },
                ToolCall {
                    id: "t2_call_2".to_string(),
                    name: "calculator".to_string(),
                    arguments: serde_json::json!({"expression": "2+2"}),
                },
            ],
        ),
        // Task 3 — complete sequence.
        ChatMessage::user("task 3"),
        ChatMessage::assistant_with_tool_calls(
            "doing task 3",
            vec![ToolCall {
                id: "t3_call".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "hello"}),
            }],
        ),
        ChatMessage::tool("t3_call", "search", "found results"),
        ChatMessage::assistant("task 3 is done"),
    ];

    let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

    // Only the task 2 assistant turn (t2_call_1/t2_call_2) is removed.
    assert_eq!(removed, 1);
    // Original 10 messages - 1 = 9
    assert_eq!(messages.len(), 9);

    // Task 1 sequence is untouched (indices 0..=3).
    assert_eq!(messages[0].content, "task 1");
    assert_eq!(messages[2].tool_call_id.as_deref(), Some("t1_call"));
    assert_eq!(messages[3].content, "task 1 is done");

    // The task 2 user turn survives even though its assistant turn does not.
    assert_eq!(messages[4].content, "task 2");

    // Task 3 sequence is untouched (indices 5..=8).
    assert_eq!(messages[5].content, "task 3");
    assert_eq!(messages[7].tool_call_id.as_deref(), Some("t3_call"));
    assert_eq!(messages[8].content, "task 3 is done");
}
|
||
}
|
||
|
||
/// Error type for agent operations.
///
/// Every variant wraps a free-form message. The `Display` output prepends a
/// category label for the provider-creation and LLM variants and prints
/// `Other` messages verbatim.
#[derive(Debug)]
pub enum AgentError {
    /// Creating a provider failed; displayed as `Provider creation error: <msg>`.
    ProviderCreation(String),
    /// An LLM-related failure; displayed as `LLM error: <msg>`.
    LlmError(String),
    /// Any other failure; displayed as `<msg>` with no label.
    Other(String),
}

impl std::fmt::Display for AgentError {
    /// Writes `<label><message>`, where the label is empty for `Other`.
    ///
    /// Formatter flags (width, fill, alignment) are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, message) = match self {
            Self::ProviderCreation(msg) => ("Provider creation error: ", msg),
            Self::LlmError(msg) => ("LLM error: ", msg),
            Self::Other(msg) => ("", msg),
        };
        f.write_str(label)?;
        f.write_str(message)
    }
}

// The default `source()` (always `None`) is kept because every variant holds
// only a `String`, so there is no underlying error to chain to.
impl std::error::Error for AgentError {}
|