PicoBot/src/agent/agent_loop.rs
oudecheng 06756a4816 fix: 修复消息持久化缺失 topic 关联和 assistant 文本丢失
- PersistingEmittedMessageHandler 新增 topic_id 参数,使用 append_message_with_topic 替代 append_message
- agent_loop 的所有退出路径中为最终 assistant 文本添加 emit_live_tool_call_message
- 更新 finalize_result filter,live_emitter 存在时抑制所有消息的 post-loop 广播

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-29 18:09:00 +08:00

1758 lines
64 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::agent::AgentRuntimeConfig;
use crate::agent::{SystemPromptContext, SystemPromptProvider};
use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::storage::ConversationRepository;
use crate::domain::messages::{ContentBlock, ToolCall};
use crate::observability::{
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
use crate::tools::{ToolContext, ToolRegistry};
use async_trait::async_trait;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use std::io::Read;
use std::sync::Arc;
use std::time::Instant;
/// Number of trailing characters kept by `truncate_tool_result` when it preserves
/// the end of an oversized tool output.
const TRUNCATION_SUFFIX_LEN: usize = 200;
/// Prefix that `parse_pending_tool_output` strips from a tool output to detect a
/// "waiting for the user" result.
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
/// Fallback assistant text used by `AgentLoop::process` when a pending tool's
/// output has no usable first line.
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
"工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
/// User-facing text returned by `recoverable_llm_message` for timeout-like errors.
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
/// MIME essences accepted by `supported_image_mime_type`; anything else is not
/// inlined as an image block.
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
/// Divisor used to turn a serialized length (bytes of JSON / base64 characters)
/// into an approximate token count.
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
/// Multiplier applied on top of the raw token estimate as head-room.
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
/// Fraction of the input window (context window minus completion reserve) that
/// `image_token_budget_for_request` treats as usable.
const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9;
/// Completion-token reserve used when the provider config has no `max_tokens`.
const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048;
/// Tokens subtracted from an image's budget before sizing its base64 payload
/// (presumably to cover the `data:<mime>;base64,` wrapper — TODO confirm).
const DATA_URL_OVERHEAD_TOKENS: usize = 16;
/// JPEG quality values tried, in this order, by `compress_image_to_jpeg_budget`.
const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36];
/// Smallest bounding-box side (in pixels) the compressor will shrink an image to.
const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64;
/// First line of the text block that lists images which could not be inlined.
const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:";
/// Build the LLM content blocks for one message from its text and media paths.
///
/// Thin wrapper: forwards all three arguments unchanged to
/// `build_content_blocks_with_image_budget` and returns its result.
fn build_content_blocks(
text: &str,
media_paths: &[String],
budget: &mut ImageInlineBudget,
) -> Vec<ContentBlock> {
build_content_blocks_with_image_budget(text, media_paths, budget)
}
fn build_content_blocks_with_image_budget(
text: &str,
media_paths: &[String],
budget: &mut ImageInlineBudget,
) -> Vec<ContentBlock> {
let mut blocks = Vec::new();
let mut skipped_image_notices = Vec::new();
// Add text block if there's text
if !text.is_empty() {
blocks.push(ContentBlock::text(text));
}
// Add image blocks for media paths
for path in media_paths {
if supported_image_mime_type(path).is_none() {
tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block");
continue;
}
let Some(target_tokens) = budget.take_next_image_tokens() else {
tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains");
skipped_image_notices.push(format!(
"- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
display_media_name(path)
));
continue;
};
match encode_image_to_base64_with_budget(path, target_tokens) {
Ok((mime_type, base64_data)) => {
let url = format!("data:{};base64,{}", mime_type, base64_data);
blocks.push(ContentBlock::image_url(url));
}
Err(err) => {
tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed");
skipped_image_notices.push(format!(
"- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
display_media_name(path)
));
continue;
}
}
}
if !skipped_image_notices.is_empty() {
let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX);
notice.push('\n');
notice.push_str(&skipped_image_notices.join("\n"));
blocks.push(ContentBlock::text(notice));
}
// If nothing, add empty text block
if blocks.is_empty() {
blocks.push(ContentBlock::text(""));
}
blocks
}
/// Return the MIME essence guessed from `path` when it is one of
/// `SUPPORTED_IMAGE_MIME_TYPES`, otherwise `None`.
///
/// The guess comes from `mime_guess::from_path`, falling back to
/// `application/octet-stream` (which is never in the supported list).
fn supported_image_mime_type(path: &str) -> Option<String> {
    let guessed = mime_guess::from_path(path).first_or_octet_stream();
    let essence = guessed.essence_str();
    SUPPORTED_IMAGE_MIME_TYPES
        .contains(&essence)
        .then(|| essence.to_string())
}
/// Short name for a media path, used in user-facing notices.
///
/// Returns the final path component when there is one and it is valid UTF-8;
/// otherwise returns `path` unchanged.
fn display_media_name(path: &str) -> String {
    let file_name = std::path::Path::new(path)
        .file_name()
        .and_then(std::ffi::OsStr::to_str);
    match file_name {
        Some(name) => name.to_owned(),
        None => path.to_owned(),
    }
}
/// Token budget shared by all images that are inlined into one LLM request.
///
/// Each call to [`ImageInlineBudget::take_next_image_tokens`] hands out an even
/// share of what is left to the next image.
#[derive(Debug, Clone)]
struct ImageInlineBudget {
    /// Tokens not yet handed out.
    remaining_tokens: usize,
    /// Images that have not yet asked for their share.
    remaining_images: usize,
}
impl ImageInlineBudget {
    /// Create a budget of `total_tokens` to be split across `image_count` images.
    fn new(total_tokens: usize, image_count: usize) -> Self {
        Self {
            remaining_tokens: total_tokens,
            remaining_images: image_count,
        }
    }

    /// Reserve the next image's share of the budget.
    ///
    /// Returns `Some(tokens)` with `ceil(remaining_tokens / remaining_images)`,
    /// or `None` when either no tokens or no image slots are left. In both cases
    /// one image slot is consumed (saturating at zero).
    fn take_next_image_tokens(&mut self) -> Option<usize> {
        let images_before = self.remaining_images;
        self.remaining_images = images_before.saturating_sub(1);

        if self.remaining_tokens == 0 || images_before == 0 {
            return None;
        }

        let share = self.remaining_tokens.div_ceil(images_before);
        self.remaining_tokens = self.remaining_tokens.saturating_sub(share);
        Some(share)
    }
}
/// Rough token estimate for any serializable value.
///
/// Serializes `value` to JSON, divides the byte length by
/// `TOKEN_ESTIMATE_CHARS_PER_TOKEN` (rounding up) and scales the result by
/// `TOKEN_ESTIMATE_SAFETY_MULTIPLIER`. A serialization failure counts as zero.
fn estimate_tokens_from_serialized_json<T: serde::Serialize>(value: &T) -> usize {
    let serialized_len = match serde_json::to_string(value) {
        Ok(json) => json.len(),
        Err(_) => 0,
    };
    let base_tokens = serialized_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    (base_tokens as f64 * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
}
/// Compute how many tokens of one request may be spent on inlined images.
///
/// Starts from the configured context window, subtracts the completion reserve
/// (`provider.max_tokens`, or `DEFAULT_COMPLETION_TOKEN_RESERVE` when unset),
/// keeps `CONTEXT_INPUT_SAFETY_RATIO` of what remains, and finally subtracts the
/// estimated tokens of the text-only messages and the tool definitions.
/// All subtractions saturate at zero.
fn image_token_budget_for_request(
    runtime_config: &AgentRuntimeConfig,
    text_only_messages: &[Message],
    tools: Option<&Vec<crate::domain::tools::Tool>>,
) -> usize {
    let completion_reserve = match runtime_config.provider.max_tokens {
        Some(tokens) => tokens as usize,
        None => DEFAULT_COMPLETION_TOKEN_RESERVE,
    };

    let input_window = runtime_config
        .context_window_tokens
        .saturating_sub(completion_reserve);
    let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize;

    let message_tokens = estimate_tokens_from_serialized_json(&text_only_messages);
    let tool_tokens = match tools {
        Some(definitions) => estimate_tokens_from_serialized_json(definitions),
        None => 0,
    };

    safe_input_window.saturating_sub(message_tokens + tool_tokens)
}
/// Count, across all `messages`, the media refs whose path maps to a supported
/// image MIME type.
fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
    messages
        .iter()
        .map(|message| {
            message
                .media_refs
                .iter()
                .filter(|path| supported_image_mime_type(path).is_some())
                .count()
        })
        .sum()
}
/// Drop images that are too old or exceed the total image cap.
///
/// # Parameters
/// - `messages`: original message list, ordered oldest → newest.
/// - `max_age_rounds`: images in messages that are at least this many messages
///   away from the newest one (counting every role) are dropped; `0` disables
///   the age check.
/// - `max_images`: maximum number of images kept overall; the newest messages
///   are served first.
///
/// # Returns
/// A copy of `messages` in which dropped images are removed from `media_refs`
/// and the message text gets an "[图片已过期:…]" suffix explaining why.
/// Non-image media refs are always kept. When the list is empty, both limits
/// are `0`, or there are no supported images, the input is returned unchanged.
///
/// NOTE(review): `max_images == 0` only means "unlimited" when `max_age_rounds`
/// is also `0`; with a non-zero age limit it removes every image. Confirm that
/// this asymmetry is intended.
fn filter_images_by_age_and_count(
    messages: &[ChatMessage],
    max_age_rounds: usize,
    max_images: usize,
) -> Vec<ChatMessage> {
    let limits_disabled = max_age_rounds == 0 && max_images == 0;
    if messages.is_empty() || limits_disabled {
        return messages.to_vec();
    }
    if count_supported_image_media_refs(messages) == 0 {
        return messages.to_vec();
    }

    // Age of a message = distance from the newest message (newest has age 0).
    let newest_idx = messages.len() - 1;
    let is_too_old = |idx: usize| max_age_rounds > 0 && newest_idx - idx >= max_age_rounds;

    // Pass 1, newest → oldest: decide how many images each message may keep so
    // that the most recent images win when the global cap is hit.
    let mut quota_per_message = vec![0usize; messages.len()];
    let mut quota_left = max_images;
    for (idx, message) in messages.iter().enumerate().rev() {
        if is_too_old(idx) {
            continue;
        }
        let images_here = message
            .media_refs
            .iter()
            .filter(|path| supported_image_mime_type(path).is_some())
            .count();
        let granted = images_here.min(quota_left);
        quota_per_message[idx] = granted;
        quota_left -= granted;
    }

    // Pass 2, oldest → newest: rebuild every message with its surviving media.
    messages
        .iter()
        .zip(quota_per_message)
        .enumerate()
        .map(|(idx, (message, quota))| {
            let mut slots_left = quota;
            let mut dropped_image = false;
            let media_refs: Vec<String> = message
                .media_refs
                .iter()
                .filter(|path| {
                    if supported_image_mime_type(path).is_none() {
                        // Non-image media is never filtered here.
                        return true;
                    }
                    if slots_left > 0 {
                        slots_left -= 1;
                        true
                    } else {
                        dropped_image = true;
                        false
                    }
                })
                .cloned()
                .collect();

            // Tell the model why an image it might expect is missing.
            let content = if !dropped_image {
                message.content.clone()
            } else if is_too_old(idx) {
                format!("{} [图片已过期:超出 {} 条消息范围]", message.content, max_age_rounds)
            } else {
                format!("{} [图片已过期:超出最大图片数量限制]", message.content)
            };

            ChatMessage {
                id: message.id.clone(),
                role: message.role.clone(),
                content,
                media_refs,
                timestamp: message.timestamp,
                system_context: message.system_context.clone(),
                reasoning_content: message.reasoning_content.clone(),
                tool_call_id: message.tool_call_id.clone(),
                tool_name: message.tool_name.clone(),
                tool_state: message.tool_state.clone(),
                tool_calls: message.tool_calls.clone(),
            }
        })
        .collect()
}
/// Convert an image's token budget into a raw byte budget.
///
/// Subtracts `DATA_URL_OVERHEAD_TOKENS`, converts the rest to base64 characters
/// via `TOKEN_ESTIMATE_CHARS_PER_TOKEN`, then applies the 3/4 base64 → bytes
/// ratio. Every step saturates instead of overflowing.
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
    let payload_tokens = target_tokens.saturating_sub(DATA_URL_OVERHEAD_TOKENS);
    let base64_chars = payload_tokens.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    base64_chars.saturating_mul(3) / 4
}
/// Read an image file and base64-encode its bytes.
///
/// Returns `(mime_type, base64_payload)`; the caller is responsible for wrapping
/// the pair into a data URL.
///
/// # Errors
/// - `InvalidInput` when the path does not map to a supported image MIME type
///   (checked before the file is opened).
/// - Any I/O error from opening or reading the file.
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};

    let Some(mime) = supported_image_mime_type(path) else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("unsupported image media type for path: {}", path),
        ));
    };

    let mut raw = Vec::new();
    std::fs::File::open(path)?.read_to_end(&mut raw)?;
    Ok((mime, STANDARD.encode(&raw)))
}
/// Base64-encode an image so that it fits into `target_tokens`, recompressing
/// it as JPEG when the original file is too large.
///
/// Returns `(mime_type, base64_payload)`. If the untouched file already fits the
/// character budget it is returned as-is with its own MIME type; otherwise the
/// image is decoded, passed to `compress_image_to_jpeg_budget`, and returned as
/// `image/jpeg`.
///
/// # Errors
/// Propagates errors from `encode_image_to_base64`, from re-reading the file,
/// and maps decode/encode failures to `InvalidData`.
///
/// NOTE(review): on the compression path the file is read from disk a second
/// time (`std::fs::read`) although its bytes were already loaded by
/// `encode_image_to_base64`.
fn encode_image_to_base64_with_budget(
    path: &str,
    target_tokens: usize,
) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};

    let (mime, encoded) = encode_image_to_base64(path)?;

    // Budget expressed in base64 characters, after reserving the URL overhead.
    let char_budget = target_tokens
        .saturating_sub(DATA_URL_OVERHEAD_TOKENS)
        .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);

    if encoded.len() > char_budget {
        let byte_budget = target_image_bytes_for_tokens(target_tokens);
        let raw = std::fs::read(path)?;
        let decoded = match image::load_from_memory(&raw) {
            Ok(picture) => picture,
            Err(err) => {
                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, err));
            }
        };
        let jpeg = compress_image_to_jpeg_budget(&decoded, byte_budget)?;
        return Ok(("image/jpeg".to_string(), STANDARD.encode(jpeg)));
    }

    Ok((mime, encoded))
}
/// Re-encode `image` as JPEG, trying to get at or below `target_bytes`.
///
/// Strategy: at the current size, try every quality in `JPEG_QUALITY_STEPS` in
/// order and return the first encoding that fits. If none fits, shrink the
/// bounding box to 3/4 of its side (never below `MIN_COMPRESSED_IMAGE_SIDE`)
/// and try again.
///
/// NOTE(review): when nothing fits even at the minimum size, the smallest
/// encoding seen so far is returned, so the result may exceed `target_bytes`.
///
/// # Errors
/// `InvalidData` when the JPEG encoder fails.
fn compress_image_to_jpeg_budget(
    image: &image::DynamicImage,
    target_bytes: usize,
) -> Result<Vec<u8>, std::io::Error> {
    use image::codecs::jpeg::JpegEncoder;
    use image::imageops::FilterType;

    let longest_side = image.width().max(image.height());
    let mut side_limit = longest_side.max(MIN_COMPRESSED_IMAGE_SIDE);
    let mut smallest: Option<Vec<u8>> = None;

    loop {
        // Only resize when the source is actually larger than the current limit.
        let candidate = if longest_side > side_limit {
            image.resize(side_limit, side_limit, FilterType::Triangle)
        } else {
            image.clone()
        };

        for &quality in JPEG_QUALITY_STEPS {
            let mut jpeg = Vec::new();
            let mut encoder = JpegEncoder::new_with_quality(&mut jpeg, quality);
            encoder
                .encode_image(&candidate)
                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;

            if jpeg.len() <= target_bytes {
                return Ok(jpeg);
            }

            // Remember the smallest over-budget attempt as a fallback.
            let is_improvement = match &smallest {
                Some(current) => jpeg.len() < current.len(),
                None => true,
            };
            if is_improvement {
                smallest = Some(jpeg);
            }
        }

        if side_limit <= MIN_COMPRESSED_IMAGE_SIDE {
            break;
        }
        side_limit = (side_limit * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE);
    }

    smallest.ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed")
    })
}
/// Truncate a tool result that exceeds `max_tool_result_chars` characters.
///
/// - Output within the limit is returned unchanged.
/// - Output that is only moderately over the limit (at most
///   `TRUNCATION_SUFFIX_LEN` characters over) keeps its *tail*, because the end
///   of a tool's output usually carries the conclusion. The tail is
///   `TRUNCATION_SUFFIX_LEN` characters long, capped at the limit itself.
/// - Output far over the limit keeps its *head* (`max_tool_result_chars - 100`
///   characters, saturating at zero).
///
/// In both truncation cases a `[Output truncated - N characters removed]`
/// banner is added, where `N` is exactly the number of characters dropped from
/// the original output. The banner itself is not counted against the limit.
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
    let total_chars = char_count(output);
    if total_chars <= max_tool_result_chars {
        return output.to_string();
    }

    let truncated_start_len = total_chars.saturating_sub(TRUNCATION_SUFFIX_LEN);
    if truncated_start_len > max_tool_result_chars {
        // Even after removing the suffix, still too long - take from beginning.
        let head_len = max_tool_result_chars.saturating_sub(100);
        let head = take_prefix_chars(output, head_len);
        // Report what was really dropped; `total - max + 100` over-counts when
        // the limit is below 100 and `head_len` saturated to zero.
        format!(
            "{}...\n\n[Output truncated - {} characters removed]",
            head,
            total_chars - head_len
        )
    } else {
        // Keep the end, but never more than the limit allows. Without the cap a
        // small limit (< TRUNCATION_SUFFIX_LEN) returned the whole output with
        // a misleading "0 characters removed" banner.
        let tail_len = TRUNCATION_SUFFIX_LEN.min(max_tool_result_chars);
        let tail = take_suffix_chars(output, tail_len);
        format!(
            "...\n\n[Output truncated - {} characters removed]\n\n{}",
            total_chars - tail_len,
            tail
        )
    }
}
/// Detect a "pending user action" tool output.
///
/// Returns the trimmed text that follows `PENDING_USER_ACTION_MARKER` when
/// `output` starts with that marker, otherwise `None`.
fn parse_pending_tool_output(output: &str) -> Option<String> {
    let rest = output.strip_prefix(PENDING_USER_ACTION_MARKER)?;
    Some(rest.trim().to_string())
}
/// Normalise tool-call arguments that arrive as a JSON-encoded string.
///
/// A `Value::String` whose content parses as JSON is replaced by the parsed
/// value; everything else (including strings that are not valid JSON) is
/// returned as an unchanged clone.
fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
    if let serde_json::Value::String(raw) = arguments {
        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(raw) {
            return parsed;
        }
    }
    arguments.clone()
}
/// Render an error together with its full `source()` chain.
///
/// The outermost error comes first; every underlying cause is appended on its
/// own line, prefixed with `caused by: `.
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    let mut rendered = error.to_string();
    let mut cause = error.source();
    while let Some(inner) = cause {
        rendered.push_str("\ncaused by: ");
        rendered.push_str(&inner.to_string());
        cause = inner.source();
    }
    rendered
}
/// Heuristically decide whether an LLM error message describes a transient
/// timeout-style failure.
///
/// The check is a case-insensitive (ASCII) substring search for any of a fixed
/// set of markers such as `504` or `timeout`.
fn is_recoverable_llm_error(error: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];
    let lowered = error.to_ascii_lowercase();
    MARKERS.iter().any(|marker| lowered.contains(*marker))
}
fn recoverable_llm_message(error: &str) -> String {
if is_recoverable_llm_error(error) {
RECOVERABLE_LLM_ERROR_MESSAGE.to_string()
} else {
format!("模型请求失败:{}", error)
}
}
/// Outcome of feeding one tool call into the `LoopDetector`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum LoopDetectionResult {
/// No warning needed.
Ok,
/// The same tool was called with the same arguments N times in a row; the
/// payload is the warning text to surface alongside the tool result.
Warning(String),
}
/// Configuration for the `LoopDetector`.
#[derive(Debug, Clone)]
struct LoopDetectorConfig {
/// Master switch; when `false`, `LoopDetector::record` always returns `Ok`.
enabled: bool,
/// Warn every N consecutive identical calls. The detector also sizes its
/// history window as `warn_every * 2`.
warn_every: usize,
}
/// Defaults: detection enabled, warning on every 5th consecutive identical call.
impl Default for LoopDetectorConfig {
fn default() -> Self {
Self {
enabled: true,
warn_every: 5,
}
}
}
/// A single recorded tool invocation in the loop detector's sliding window.
#[derive(Debug, Clone)]
struct ToolCallRecord {
/// Name of the tool that was called.
name: String,
/// Hash of the call's arguments as produced by `hash_json_value`
/// (independent of object key order).
args_hash: u64,
}
/// Stateful loop detector that watches for the same tool being called with the
/// same arguments over and over.
struct LoopDetector {
    config: LoopDetectorConfig,
    /// Bounded history of the most recent calls (at most `warn_every * 2`).
    window: VecDeque<ToolCallRecord>,
    /// Length of the current run of identical calls.
    ///
    /// Tracked separately from `window`: the window is capped, so counting its
    /// trailing identical entries saturates at the cap, which made the detector
    /// warn on *every* call (with a stale count) once the cap was reached.
    consecutive: usize,
}
impl LoopDetector {
    /// Create a detector with an empty history.
    fn new(config: LoopDetectorConfig) -> Self {
        Self {
            window: VecDeque::with_capacity(config.warn_every * 2),
            config,
            consecutive: 0,
        }
    }

    /// Record a completed tool call and check for loop patterns.
    ///
    /// Returns `Warning` on every `warn_every`-th consecutive identical call
    /// (same name and same argument hash) and `Ok` otherwise. A call that
    /// differs from the previous one restarts the run at 1.
    ///
    /// Nothing is recorded and `Ok` is returned when the detector is disabled
    /// or `warn_every` is zero (which would otherwise divide by zero).
    fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult {
        if !self.config.enabled || self.config.warn_every == 0 {
            return LoopDetectionResult::Ok;
        }

        let record = ToolCallRecord {
            name: name.to_string(),
            args_hash: hash_json_value(args),
        };

        // Extend the current run or start a new one.
        let repeats_previous = self
            .window
            .back()
            .is_some_and(|prev| prev.name == record.name && prev.args_hash == record.args_hash);
        self.consecutive = if repeats_previous {
            self.consecutive + 1
        } else {
            1
        };

        // Maintain the bounded history.
        if self.window.len() >= self.config.warn_every * 2 {
            self.window.pop_front();
        }
        self.window.push_back(record);

        if self.consecutive % self.config.warn_every == 0 {
            LoopDetectionResult::Warning(format!(
                "注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。",
                name, self.consecutive
            ))
        } else {
            LoopDetectionResult::Ok
        }
    }
}
/// Hash a JSON value so that objects differing only in key order hash equally.
///
/// The value is first canonicalised with `canonicalise_json`, then fed to the
/// standard library's `DefaultHasher`.
fn hash_json_value(value: &serde_json::Value) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    canonicalise_json(value).hash(&mut state);
    state.finish()
}
/// Produce a deep copy of `value` in which the keys of every object, at every
/// nesting level, are inserted in sorted order.
///
/// Arrays keep their element order; scalars are cloned as-is.
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
    match value {
        serde_json::Value::Array(items) => {
            serde_json::Value::Array(items.iter().map(canonicalise_json).collect())
        }
        serde_json::Value::Object(fields) => {
            let mut entries: Vec<(&String, &serde_json::Value)> = fields.iter().collect();
            entries.sort_by(|(left, _), (right, _)| left.cmp(right));

            let mut ordered = serde_json::Map::new();
            for (key, child) in entries {
                ordered.insert(key.clone(), canonicalise_json(child));
            }
            serde_json::Value::Object(ordered)
        }
        scalar => scalar.clone(),
    }
}
/// Convert a `ChatMessage` into the provider-facing `Message`, inlining media.
///
/// Messages without media refs become a single text block. Messages with media
/// refs go through `build_content_blocks`, which spends `image_budget` on any
/// supported images. All other fields are copied over unchanged.
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
    let content = match m.media_refs.as_slice() {
        [] => vec![ContentBlock::text(&m.content)],
        media => build_content_blocks(&m.content, media, image_budget),
    };

    Message {
        role: m.role.clone(),
        content,
        reasoning_content: m.reasoning_content.clone(),
        tool_call_id: m.tool_call_id.clone(),
        name: m.tool_name.clone(),
        tool_calls: m.tool_calls.clone(),
    }
}
/// Convert a `ChatMessage` into a provider-facing `Message` that carries only
/// its text.
///
/// Media refs are ignored entirely; `AgentLoop::process` uses this form to
/// estimate how many tokens the text occupies before budgeting images.
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
    let text_block = ContentBlock::text(&m.content);
    Message {
        role: m.role.clone(),
        content: vec![text_block],
        reasoning_content: m.reasoning_content.clone(),
        tool_call_id: m.tool_call_id.clone(),
        name: m.tool_name.clone(),
        tool_calls: m.tool_calls.clone(),
    }
}
/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop {
/// Runtime settings read by `process` (context window, image limits,
/// tool-result size limit, provider config).
runtime_config: AgentRuntimeConfig,
/// LLM backend created from `runtime_config.provider`.
provider: Box<dyn LLMProvider>,
/// Registry of tools offered to the model and executed on its behalf.
tools: Arc<ToolRegistry>,
/// System prompt provider (injects the Agent and Skill prompts in one place).
system_prompt_provider: Option<Arc<dyn SystemPromptProvider>>,
/// Skill provider (used to match error hints).
skills: Option<Arc<dyn SkillProvider>>,
/// Context set through `with_tool_context`; defaults to `ToolContext::default()`.
tool_context: ToolContext,
/// Optional observer for tracking events.
observer: Option<Arc<dyn Observer>>,
/// Optional sink that receives every emitted message as soon as it is produced.
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
/// Upper bound on LLM/tool round-trips per `process` call
/// (taken from `runtime_config.max_tool_iterations`).
max_iterations: usize,
}
/// Result of one `AgentLoop::process` call.
#[derive(Debug, Clone)]
pub struct AgentProcessResult {
/// The assistant message that ended the run.
pub final_response: ChatMessage,
/// Every message produced during the run, in order (assistant tool-call
/// messages, tool results, and the final response itself).
pub emitted_messages: Vec<ChatMessage>,
}
/// Receiver for messages emitted live while the agent loop is running.
#[async_trait]
pub trait EmittedMessageHandler: Send + Sync + 'static {
/// Handle one emitted message; the handler takes ownership of it.
async fn handle(&self, message: ChatMessage);
}
/// Decorator: persists each message to the DB before the inner emitter
/// broadcasts it.
pub struct PersistingEmittedMessageHandler<H: EmittedMessageHandler> {
/// Wrapped handler that receives the message after the persistence attempt.
inner: H,
/// Repository used to append the message.
conversation_repository: Arc<dyn ConversationRepository>,
/// Session the messages are appended to.
session_id: String,
/// Optional topic the messages are associated with.
topic_id: Option<String>,
}
impl<H: EmittedMessageHandler> PersistingEmittedMessageHandler<H> {
/// Wrap `inner` so that every handled message is first appended to
/// `conversation_repository` under `session_id` / `topic_id`.
pub fn new(
inner: H,
conversation_repository: Arc<dyn ConversationRepository>,
session_id: impl Into<String>,
topic_id: Option<String>,
) -> Self {
Self { inner, conversation_repository, session_id: session_id.into(), topic_id }
}
}
#[async_trait]
impl<H: EmittedMessageHandler> EmittedMessageHandler for PersistingEmittedMessageHandler<H> {
/// Persist `message`, then forward it to the inner handler.
///
/// A persistence failure is logged and does not stop the forward, so live
/// delivery is best-effort independent of storage.
async fn handle(&self, message: ChatMessage) {
// NOTE(review): this repository call is not awaited, so it looks synchronous;
// confirm it cannot block the async executor for long.
if let Err(e) = self.conversation_repository
.append_message_with_topic(&self.session_id, self.topic_id.as_deref(), &message)
{
tracing::error!(error = %e, session_id = %self.session_id,
"Failed to persist emitted message");
}
self.inner.handle(message).await;
}
}
/// Source of skill-related prompt text for the agent.
pub trait SkillProvider: Send + Sync + 'static {
/// Prompt text describing the available skills, or `None` when there is
/// nothing to inject.
fn system_index_prompt(&self) -> Option<String>;
/// Summary of the skill matching `_name`, if any.
/// The default implementation returns `None`.
fn matching_skill_summary(&self, _name: &str) -> Option<String> {
None
}
}
/// `SkillProvider` that offers no skills: no index prompt and (through the
/// trait default) no matching summaries.
#[derive(Default)]
#[allow(dead_code)]
struct EmptySkillProvider;
impl SkillProvider for EmptySkillProvider {
/// Always `None`.
fn system_index_prompt(&self) -> Option<String> {
None
}
}
impl AgentLoop {
pub fn new(config: impl Into<AgentRuntimeConfig>) -> Result<Self, AgentError> {
let runtime_config = config.into();
let max_iterations = runtime_config.max_tool_iterations;
let provider = create_provider(runtime_config.provider.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
runtime_config,
provider,
tools: Arc::new(ToolRegistry::new()),
system_prompt_provider: None,
skills: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
/// Create an agent loop that uses the given tool registry.
///
/// The iteration limit is taken from `max_tool_iterations` in the runtime
/// config. System prompt provider, skills, observer and emitted-message
/// handler start unset; the tool context starts as `ToolContext::default()`.
///
/// # Errors
/// `AgentError::ProviderCreation` when the LLM provider cannot be created
/// from the configuration.
pub fn with_tools(
    config: impl Into<AgentRuntimeConfig>,
    tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
    let runtime_config: AgentRuntimeConfig = config.into();
    let provider = match create_provider(runtime_config.provider.clone()) {
        Ok(provider) => provider,
        Err(e) => return Err(AgentError::ProviderCreation(e.to_string())),
    };
    let max_iterations = runtime_config.max_tool_iterations;

    Ok(Self {
        runtime_config,
        provider,
        tools,
        system_prompt_provider: None,
        skills: None,
        tool_context: ToolContext::default(),
        observer: None,
        emitted_message_handler: None,
        max_iterations,
    })
}
pub fn with_tools_and_skill_provider(
config: impl Into<AgentRuntimeConfig>,
tools: Arc<ToolRegistry>,
skills: Arc<dyn SkillProvider>,
) -> Result<Self, AgentError> {
let runtime_config = config.into();
let max_iterations = runtime_config.max_tool_iterations;
let provider = create_provider(runtime_config.provider.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
runtime_config,
provider,
tools,
system_prompt_provider: None,
skills: Some(skills),
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
/// 使用系统提示词提供者创建 AgentLoop
///
/// 这是新的推荐方式,支持统一注入 Agent 和 Skill 提示词。
pub fn with_tools_and_system_prompt_provider(
config: impl Into<AgentRuntimeConfig>,
tools: Arc<ToolRegistry>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
skills: Option<Arc<dyn SkillProvider>>,
) -> Result<Self, AgentError> {
let runtime_config = config.into();
let max_iterations = runtime_config.max_tool_iterations;
let provider = create_provider(runtime_config.provider.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
runtime_config,
provider,
tools,
system_prompt_provider: Some(system_prompt_provider),
skills,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
/// Builder-style setter: replace the tool context and return the agent.
pub fn with_tool_context(mut self, context: ToolContext) -> Self {
self.tool_context = context;
self
}
/// Builder-style setter: attach an observer for tracking events and return
/// the agent.
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
self.observer = Some(observer);
self
}
/// Builder-style setter: attach the handler that receives every message as
/// soon as `process` emits it, and return the agent.
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
self.emitted_message_handler = Some(handler);
self
}
/// Borrow the shared tool registry used by this agent.
pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools
}
/// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager.
///
/// This method supports multi-round tool calling: after executing tools,
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - A tool reports that it is waiting for a user action
/// - Maximum iterations are reached (a final tool-less summary is requested)
///
/// Every message produced along the way is appended to
/// `AgentProcessResult::emitted_messages` and also passed to the live
/// emitted-message handler, if one is configured.
///
/// # Parameters
/// - `messages`: conversation history messages
/// - `system_prompt_context`: system prompt context (used for dynamic injection, optional)
///
/// # Errors
/// In the code visible here, LLM request failures are not returned as `Err`:
/// they are logged and turned into an assistant message carried in an `Ok`
/// result.
pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
system_prompt_context: Option<&SystemPromptContext>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
// Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new();
for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)]
tracing::debug!(iteration, "Agent iteration started");
// Build request
let tool_defs = self.tools.get_definitions();
let tools = if tool_defs.is_empty() {
None
} else {
Some(tool_defs)
};
// Filter out images that exceed the age (rounds) and count limits
let filtered_messages = filter_images_by_age_and_count(
&messages,
self.runtime_config.max_image_age_rounds,
self.runtime_config.max_images_in_context,
);
let image_count = count_supported_image_media_refs(&filtered_messages);
// Build the system prompt (injects the Agent and Skill prompts in one place)
let system_prompt = system_prompt_context.and_then(|ctx| {
self.system_prompt_provider
.as_ref()
.and_then(|provider| provider.build(ctx))
});
// Text-only rendering of the request, used solely to estimate how many
// tokens are left over for inlined images.
let mut text_only_messages: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 2);
if let Some(ref prompt) = system_prompt {
text_only_messages.push(Message::system(prompt.content.clone()));
}
text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens = image_token_budget_for_request(
&self.runtime_config,
&text_only_messages,
tools.as_ref(),
);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 2);
// Reuse the same system prompt (already built above)
if let Some(ref prompt) = system_prompt {
messages_for_llm.push(Message::system(prompt.content.clone()));
}
messages_for_llm.extend(
filtered_messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
let response = match (*self.provider).chat(request).await {
Ok(response) => response,
Err(e) => {
tracing::error!(
provider = %self.provider.name(),
model = %self.provider.model_id(),
error = %e,
error_details = %format_error_chain(e.as_ref()),
"LLM request failed"
);
// Surface the failure to the user as an assistant message instead of
// propagating an error.
let assistant_message =
ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(assistant_message.clone());
self.emit_live_tool_call_message(assistant_message.clone()).await;
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
};
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
response_len = response.content.len(),
tool_calls_len = response.tool_calls.len(),
"LLM response received"
);
// If no tool calls, this is the final response
if response.tool_calls.is_empty() {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
};
emitted_messages.push(assistant_message.clone());
self.emit_live_tool_call_message(assistant_message.clone()).await;
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
// Execute tool calls
tracing::info!(
iteration,
count = response.tool_calls.len(),
"Tool calls detected, executing tools"
);
// Add assistant message with tool calls
let assistant_message =
if let Some(reasoning_content) = response.reasoning_content.clone() {
ChatMessage::assistant_with_tool_calls_and_reasoning(
response.content.clone(),
response.tool_calls.clone(),
reasoning_content,
)
} else {
ChatMessage::assistant_with_tool_calls(
response.content.clone(),
response.tool_calls.clone(),
)
};
messages.push(assistant_message.clone());
emitted_messages.push(assistant_message);
self.emit_live_tool_call_message(
emitted_messages
.last()
.expect("assistant message just pushed")
.clone(),
)
.await;
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
// Truncate tool result if too large
let truncated_output =
truncate_tool_result(&result.output, self.runtime_config.tool_result_max_chars);
// Record tool call and check for loops
let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments);
match loop_result {
LoopDetectionResult::Warning(msg) => {
// Add warning and proceed: the warning text is prepended to the
// tool result so the model sees it on the next iteration.
tracing::warn!(
tool = %tool_call.name,
"Loop warning: {}",
msg
);
let tool_message = ChatMessage::tool_with_state(
tool_call.id.clone(),
tool_call.name.clone(),
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
if result.state == ToolExecutionState::PendingUserAction {
ToolMessageState::PendingUserAction
} else {
ToolMessageState::Completed
},
);
messages.push(tool_message.clone());
emitted_messages.push(tool_message.clone());
self.emit_live_tool_call_message(tool_message).await;
}
LoopDetectionResult::Ok => {
let tool_message = ChatMessage::tool_with_state(
tool_call.id.clone(),
tool_call.name.clone(),
truncated_output,
if result.state == ToolExecutionState::PendingUserAction {
ToolMessageState::PendingUserAction
} else {
ToolMessageState::Completed
},
);
messages.push(tool_message.clone());
emitted_messages.push(tool_message.clone());
self.emit_live_tool_call_message(tool_message).await;
}
}
}
// If any tool is waiting for the user, stop here: answer with the first line
// of the first pending tool's output (or a default text when that line is
// missing or blank) instead of calling the LLM again.
if let Some((tool_call, pending_result)) = response
.tool_calls
.iter()
.zip(tool_results.iter())
.find(|(_, result)| result.state == ToolExecutionState::PendingUserAction)
{
let assistant_message = ChatMessage::assistant(format!(
"{}\n\n当前等待中的工具: {}",
pending_result
.output
.lines()
.next()
.filter(|line| !line.trim().is_empty())
.unwrap_or(DEFAULT_PENDING_ASSISTANT_MESSAGE),
tool_call.name,
));
emitted_messages.push(assistant_message.clone());
self.emit_live_tool_call_message(assistant_message.clone()).await;
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
// Loop continues to next iteration with updated messages
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
}
// Max iterations reached - ask LLM for a summary based on completed work
tracing::warn!("Max iterations reached, requesting final summary from LLM");
// Add a message asking for summary
let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far.",
);
messages.push(summary_request);
// Filter out images that exceed the age (rounds) and count limits
let filtered_messages = filter_images_by_age_and_count(
&messages,
self.runtime_config.max_image_age_rounds,
self.runtime_config.max_images_in_context,
);
// Convert messages to LLM format (using the system prompt provider)
let image_count = count_supported_image_media_refs(&filtered_messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 1);
if let Some(ref provider) = self.system_prompt_provider {
if let Some(ctx) = system_prompt_context {
if let Some(prompt) = provider.build(ctx) {
text_only_messages.push(Message::system(prompt.content.clone()));
}
}
}
text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens =
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 1);
// NOTE(review): the system prompt is built a second time here, whereas the
// main loop builds it once per iteration and reuses it.
if let Some(ref provider) = self.system_prompt_provider {
if let Some(ctx) = system_prompt_context {
if let Some(prompt) = provider.build(ctx) {
messages_for_llm.push(Message::system(prompt.content.clone()));
}
}
}
messages_for_llm.extend(
filtered_messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools: None, // No tools in final summary call
};
match (*self.provider).chat(request).await {
Ok(response) => {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
};
emitted_messages.push(assistant_message.clone());
self.emit_live_tool_call_message(assistant_message.clone()).await;
Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
})
}
Err(e) => {
tracing::error!(
provider = %self.provider.name(),
model = %self.provider.model_id(),
error = %e,
error_details = %format_error_chain(e.as_ref()),
"Failed to get summary from LLM"
);
// Same policy as in the main loop: report the failure as an assistant message.
let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(final_message.clone());
self.emit_live_tool_call_message(final_message.clone()).await;
Ok(AgentProcessResult {
final_response: final_message,
emitted_messages,
})
}
}
}
/// Forwards `message` to the emitted-message handler, if one is configured.
///
/// When `self.emitted_message_handler` is `None` this does nothing.
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
    let handler = match self.emitted_message_handler.as_ref() {
        Some(handler) => handler,
        None => return,
    };
    handler.handle(message).await;
}
/// Decide whether a batch of tool calls may be executed concurrently.
///
/// Returns `true` only when the batch holds at least two calls and either
/// - every call targets the `task` tool, or
/// - no call targets `task` or `tool_search`, and every named tool is
///   registered and reports itself as concurrency-safe.
///
/// Every other combination (including calls to unregistered tools) returns
/// `false`, i.e. sequential execution.
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
    let total = tool_calls.len();
    if total < 2 {
        return false;
    }

    let mut task_calls = 0usize;
    for call in tool_calls {
        match call.name.as_str() {
            // tool_search must run sequentially to avoid MCP activation race conditions.
            "tool_search" => return false,
            "task" => task_calls += 1,
            _ => {}
        }
    }

    // A batch made up solely of `task` calls runs in parallel; `task` mixed with
    // anything else stays sequential to avoid complexity.
    if task_calls > 0 {
        return task_calls == total;
    }

    // Otherwise every tool must be known and concurrency-safe.
    tool_calls.iter().all(|call| {
        self.tools
            .get(&call.name)
            .map_or(false, |tool| tool.concurrency_safe())
    })
}
/// Run a batch of tool calls, dispatching to the parallel or the sequential
/// executor according to `should_execute_in_parallel`.
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    if !self.should_execute_in_parallel(tool_calls) {
        tracing::debug!("Executing {} tools sequentially", tool_calls.len());
        return self.execute_tools_sequential(tool_calls).await;
    }

    tracing::debug!("Executing {} tools in parallel", tool_calls.len());
    self.execute_tools_parallel(tool_calls).await
}
/// Run every tool call concurrently via `join_all` and collect the outcomes.
///
/// One future is created per call, in input order, and all are awaited together.
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let pending = tool_calls.iter().map(|call| self.execute_one_tool(call));
    futures_util::future::join_all(pending).await
}
/// Run the tool calls one after another, awaiting each before starting the next.
///
/// The returned outcomes are in the same order as `tool_calls`.
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let mut collected = Vec::with_capacity(tool_calls.len());
    for call in tool_calls.iter() {
        let outcome = self.execute_one_tool(call).await;
        collected.push(outcome);
    }
    collected
}
/// Execute a single tool and return the outcome with event tracking.
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let start = Instant::now();
let tool_name = tool_call.name.clone();
// Log function call with name and arguments before execution
let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
};
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
// Record ToolCallStart event
if let Some(ref observer) = self.observer {
observer.record_event(&ObserverEvent::ToolCallStart {
tool: tool_name.clone(),
arguments: Some(truncate_args(&tool_call.arguments, 300)),
});
}
let result = self.execute_tool_internal(tool_call).await;
let duration = start.elapsed();
// Record ToolCall event
if let Some(ref observer) = self.observer {
observer.record_event(&ObserverEvent::ToolCall {
tool: tool_name.clone(),
duration,
success: result.success,
});
}
// Apply duration
ToolExecutionOutcome { duration, ..result }
}
/// Internal tool execution without event tracking.
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let normalized_arguments = normalize_tool_arguments(&tool_call.arguments);
let tool = match self.tools.get(&tool_call.name) {
Some(t) => t,
None => {
tracing::warn!(tool = %tool_call.name, "Tool not found");
let skill_hint = self.skills
.as_ref()
.and_then(|s| s.matching_skill_summary(&tool_call.name));
let error = match skill_hint {
Some(summary) => format!(
"Tool '{}' not found. A skill with the same name exists: {}. Skills are not tools. Call skill_activate with {{\"name\": \"{}\"}} first.",
tool_call.name, summary, tool_call.name
),
None => format!("Tool '{}' not found", tool_call.name),
};
return ToolExecutionOutcome::failure(
format!("Error: {}", error),
Some(error),
);
}
};
match tool
.execute_with_context(&self.tool_context, normalized_arguments.clone())
.await
{
Ok(result) => {
if result.success {
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
ToolExecutionOutcome::pending(pending_output)
} else {
ToolExecutionOutcome::success(result.output)
}
} else {
let error = result.error.unwrap_or_default();
tracing::error!(
tool = %tool_call.name,
args = %truncate_args(&tool_call.arguments, 4_000),
normalized_args = %truncate_args(&normalized_arguments, 4_000),
error = %error,
output = %result.output,
"Tool returned an error result"
);
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
}
}
Err(e) => {
tracing::error!(
tool = %tool_call.name,
args = %truncate_args(&tool_call.arguments, 4_000),
normalized_args = %truncate_args(&normalized_arguments, 4_000),
error = %e,
error_details = %format!("{:#}", e),
"Tool execution failed"
);
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::LLMProviderConfig;
use crate::observability::{MultiObserver, Observer};
use tempfile::tempdir;
/// Observer test double that stores a clone of every event it receives.
struct TestObserver {
    events: std::sync::Mutex<Vec<ObserverEvent>>,
}

impl TestObserver {
    /// Creates an observer with an empty event log.
    fn new() -> Self {
        let events = std::sync::Mutex::new(Vec::new());
        Self { events }
    }
}

impl Observer for TestObserver {
    /// Appends a clone of `event` to the log; panics if the mutex is poisoned.
    fn record_event(&self, event: &ObserverEvent) {
        let mut log = self.events.lock().unwrap();
        log.push(event.clone());
    }

    /// Fixed identifier of this test double.
    fn name(&self) -> &str {
        "test_observer"
    }
}
/// Skill provider test double that knows exactly one skill, `baidu-search`.
struct TestSkillProvider;

impl SkillProvider for TestSkillProvider {
    /// This double contributes no system index prompt.
    fn system_index_prompt(&self) -> Option<String> {
        None
    }

    /// Returns a summary only for the name `baidu-search`.
    fn matching_skill_summary(&self, name: &str) -> Option<String> {
        match name {
            "baidu-search" => Some("用于百度搜索和天气查询的技能".to_string()),
            _ => None,
        }
    }
}
/// Builds the provider configuration shared by the tests in this module.
///
/// Uses a placeholder local endpoint, a dummy API key and small limits
/// (one tool iteration, one image kept in context).
fn test_runtime_config() -> LLMProviderConfig {
    LLMProviderConfig {
        // Provider identity and endpoint.
        provider_type: String::from("openai"),
        name: String::from("test"),
        model_id: String::from("test-model"),
        base_url: String::from("http://localhost"),
        api_key: String::from("test-key"),
        extra_headers: Default::default(),
        model_extra: Default::default(),
        // Timeouts.
        llm_timeout_secs: 120,
        memory_maintenance_timeout_secs: 600,
        // Sampling and context size.
        temperature: Some(0.0),
        max_tokens: Some(32),
        context_window_tokens: None,
        // Agent-loop limits.
        max_tool_iterations: 1,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
        max_images_in_context: 1,
        max_image_age_rounds: 10,
    }
}
/// Calling an unregistered tool whose name matches a skill must fail with a
/// message that points the caller at `skill_activate`.
#[tokio::test]
async fn test_missing_tool_with_same_name_skill_returns_activation_hint() {
    let agent = AgentLoop::with_tools_and_skill_provider(
        test_runtime_config(),
        Arc::new(ToolRegistry::new()),
        Arc::new(TestSkillProvider),
    )
    .unwrap();

    let call = ToolCall {
        id: "call-1".to_string(),
        name: "baidu-search".to_string(),
        arguments: serde_json::json!({
            "queries": "佛山今天几点下雨"
        }),
    };
    let outcome = agent.execute_tool_internal(&call).await;

    assert_eq!(outcome.state, ToolExecutionState::Completed);
    assert!(!outcome.success);
    // The failure text carries the skill summary, the activation hint and the name.
    for expected in ["技能", "skill_activate", "baidu-search"] {
        assert!(outcome.output.contains(expected));
    }
}
/// Smoke test for `MultiObserver`: dispatching an event to a registered
/// observer works, and the observer count stays at one.
///
/// The recorded events themselves are not inspected here.
#[tokio::test]
async fn test_observer_receives_tool_events() {
    let mut fan_out = MultiObserver::new();
    fan_out.add_observer(Box::new(TestObserver::new()));

    fan_out.record_event(&ObserverEvent::ToolCallStart {
        tool: "test".to_string(),
        arguments: Some("{}".to_string()),
    });

    assert_eq!(fan_out.len(), 1);
}
#[test]
fn test_should_execute_in_parallel_single_tool() {
// Would need a proper setup with AgentLoop to test fully
// For now, just verify the logic: single tool should return false
let calls = vec![ToolCall {
id: "1".to_string(),
name: "test".to_string(),
arguments: serde_json::json!({}),
}];
// If there's only 1 tool, should return false regardless
assert_eq!(calls.len() <= 1, true);
}
/// Converting an assistant message with tool calls keeps the role and the
/// id/name of each tool call.
#[test]
fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
    let call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({ "expression": "2+2" }),
    };
    let chat_message = ChatMessage::assistant_with_tool_calls("calling tool", vec![call]);

    let mut image_budget = ImageInlineBudget::new(0, 0);
    let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);

    assert_eq!(provider_message.role, "assistant");
    let forwarded = provider_message.tool_calls.as_ref().unwrap();
    assert_eq!(forwarded.len(), 1);
    assert_eq!(forwarded[0].id, "call_1");
    assert_eq!(forwarded[0].name, "calculator");
}
/// Converting an assistant message that carries reasoning content keeps both
/// the role and the reasoning text.
#[test]
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
    let mut image_budget = ImageInlineBudget::new(0, 0);
    let source = ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");

    let converted = chat_message_to_llm_message(&source, &mut image_budget);

    assert_eq!(converted.role, "assistant");
    let reasoning = converted.reasoning_content.as_deref();
    assert_eq!(reasoning, Some("hidden chain of thought"));
}
/// Truncating a long run of multi-byte characters must not panic and must
/// produce the truncation notice.
///
/// The input has to be built from a multi-byte character: repeating an empty
/// string yields an empty input, which is never truncated and exercises no
/// UTF-8 boundary at all.
#[test]
fn test_truncate_tool_result_handles_utf8_char_boundaries() {
    // "中" occupies 3 bytes in UTF-8, so byte offsets and char offsets differ.
    let input = "中".repeat(20_500);
    let output = truncate_tool_result(&input, 20_000);
    assert!(output.contains("Output truncated"));
    // Reaching this point already shows that no slice split a character.
    assert!(output.is_char_boundary(output.len()));
}
/// Output prefixed with the pending-user-action marker yields the text after
/// the marker line; ordinary output yields `None`.
#[test]
fn test_parse_pending_tool_output() {
    let pending = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权");
    assert_eq!(pending.as_deref(), Some("请完成授权"));

    let plain = parse_pending_tool_output("normal output");
    assert!(plain.is_none());
}
/// A JSON string whose content is itself a JSON object is parsed into that object.
#[test]
fn test_normalize_tool_arguments_parses_stringified_json() {
    let stringified = serde_json::Value::from(r#"{"command":"ls -la"}"#);
    let expected = serde_json::json!({ "command": "ls -la" });

    assert_eq!(normalize_tool_arguments(&stringified), expected);
}
/// A JSON string that does not contain JSON is returned unchanged.
#[test]
fn test_normalize_tool_arguments_keeps_plain_string() {
    let plain = serde_json::Value::from("plain text");

    let normalized = normalize_tool_arguments(&plain);

    assert_eq!(normalized, serde_json::Value::from("plain text"));
}
/// A media reference to a PDF file produces no extra block: only the text
/// block remains.
#[test]
fn test_build_content_blocks_skips_non_image_media_refs() {
    let workspace = tempdir().unwrap();
    let pdf_path = workspace.path().join("demo.pdf");
    std::fs::write(&pdf_path, b"%PDF-1.4").unwrap();
    let media_refs = [pdf_path.to_string_lossy().into_owned()];

    let mut budget = ImageInlineBudget::new(1_000, 0);
    let blocks = build_content_blocks("hello", &media_refs, &mut budget);

    assert_eq!(blocks.len(), 1);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
}
/// A small JPEG within budget is inlined as a `data:image/jpeg;base64,` URL
/// after the text block.
#[test]
fn test_build_content_blocks_keeps_supported_images() {
    let workspace = tempdir().unwrap();
    let jpg_path = workspace.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg_path).unwrap();
    let media_refs = [jpg_path.to_string_lossy().into_owned()];

    let mut budget = ImageInlineBudget::new(10_000, 1);
    let blocks = build_content_blocks("hello", &media_refs, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(matches!(
        &blocks[1],
        ContentBlock::ImageUrl { image_url }
            if image_url.url.starts_with("data:image/jpeg;base64,")
    ));
}
/// A 512x512 PNG with an image budget of 512 is still emitted as an image
/// block, and that block is a JPEG data URL.
#[test]
fn test_build_content_blocks_compresses_images_to_budget() {
    let workspace = tempdir().unwrap();
    let png_path = workspace.path().join("large.png");
    image::DynamicImage::new_rgb8(512, 512).save(&png_path).unwrap();
    let media_refs = [png_path.to_string_lossy().into_owned()];

    let mut budget = ImageInlineBudget::new(512, 1);
    let blocks = build_content_blocks("hello", &media_refs, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(matches!(
        &blocks[1],
        ContentBlock::ImageUrl { image_url }
            if image_url.url.starts_with("data:image/jpeg;base64,")
    ));
}
/// With a zero image budget the image is not inlined; a text notice naming the
/// file is appended after the text block instead.
#[test]
fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() {
    let workspace = tempdir().unwrap();
    let jpg_path = workspace.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg_path).unwrap();
    let media_refs = [jpg_path.to_string_lossy().into_owned()];

    let mut budget = ImageInlineBudget::new(0, 1);
    let blocks = build_content_blocks("hello", &media_refs, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(matches!(
        &blocks[1],
        ContentBlock::Text { text }
            if text.contains("图片未能成功入模") && text.contains("demo.jpg")
    ));
}
/// Messages without media pass through the filter with their content unchanged.
#[test]
fn test_filter_images_by_age_and_count_no_images() {
    let messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant("hi"),
        ChatMessage::user("how are you"),
    ];

    let filtered = filter_images_by_age_and_count(&messages, 10, 5);

    let expected = ["hello", "hi", "how are you"];
    assert_eq!(filtered.len(), 3);
    for (message, want) in filtered.iter().zip(expected) {
        assert_eq!(message.content, want);
    }
}
/// An image on the oldest of 12 messages is dropped when `max_age_rounds` is 10,
/// and the message text receives an "expired" notice.
#[test]
fn test_filter_images_by_age_limit() {
    let workspace = tempdir().unwrap();
    let jpg_path = workspace.path().join("test.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg_path).unwrap();

    // 12 messages in total; only the oldest one (index 0) carries an image.
    let mut messages = vec![ChatMessage::user_with_media(
        "first message with image",
        vec![jpg_path.to_string_lossy().to_string()],
    )];
    messages.extend((1..12).map(|i| {
        if i % 2 == 0 {
            ChatMessage::user(format!("user message {}", i))
        } else {
            ChatMessage::assistant(format!("assistant message {}", i))
        }
    }));

    // With max_age_rounds = 10, index 0 is 11 messages behind the newest
    // message (index 11), so its image is expected to be filtered out.
    let filtered = filter_images_by_age_and_count(&messages, 10, 100);

    let oldest = &filtered[0];
    assert_eq!(oldest.media_refs.len(), 0);
    assert!(oldest.content.contains("图片已过期"));
}
/// With `max_images` = 1 only the newest image survives; older messages lose
/// their image and receive a "count limit exceeded" notice.
#[test]
fn test_filter_images_by_count_limit() {
    let workspace = tempdir().unwrap();

    // Three distinct image files.
    let jpg_paths: Vec<String> = (0..3)
        .map(|i| {
            let path = workspace.path().join(format!("test{}.jpg", i));
            image::DynamicImage::new_rgb8(8, 8).save(&path).unwrap();
            path.to_string_lossy().to_string()
        })
        .collect();

    // Three messages, each carrying its own image.
    let messages: Vec<ChatMessage> = jpg_paths
        .iter()
        .enumerate()
        .map(|(i, path)| {
            ChatMessage::user_with_media(format!("message {}", i), vec![path.clone()])
        })
        .collect();

    // max_images = 1: only the newest image (index 2) is kept.
    let filtered = filter_images_by_age_and_count(&messages, 100, 1);

    assert_eq!(filtered[2].media_refs.len(), 1); // newest: kept
    assert_eq!(filtered[1].media_refs.len(), 0); // filtered
    assert!(filtered[1].content.contains("超出最大图片数量限制"));
    assert_eq!(filtered[0].media_refs.len(), 0); // filtered
    assert!(filtered[0].content.contains("超出最大图片数量限制"));
}
/// Age and count limits applied together: of five images spread over 20
/// messages, the three newest are kept and the two oldest are dropped for age.
///
/// Each image-bearing message gets its own image file. Deriving the image index
/// as `i / 5` would give indices 15 and 19 the same file and leave the fifth
/// file unused, so the mapping is spelled out explicitly.
#[test]
fn test_filter_images_combined_limits() {
    // Message indices that carry an image; entry `k` uses image file `k`.
    const IMAGE_POSITIONS: [usize; 5] = [0, 5, 10, 15, 19];

    let temp_dir = tempdir().unwrap();
    // One image file per image-bearing message.
    let jpg_paths: Vec<String> = (0..IMAGE_POSITIONS.len())
        .map(|i| {
            let path = temp_dir.path().join(format!("test{}.jpg", i));
            let image = image::DynamicImage::new_rgb8(8, 8);
            image.save(&path).unwrap();
            path.to_string_lossy().to_string()
        })
        .collect();

    // 20 messages; images at IMAGE_POSITIONS, alternating user/assistant text elsewhere.
    let messages: Vec<ChatMessage> = (0..20usize)
        .map(|i| match IMAGE_POSITIONS.iter().position(|&p| p == i) {
            Some(image_idx) => ChatMessage::user_with_media(
                format!("message {} with image", i),
                vec![jpg_paths[image_idx].clone()],
            ),
            None if i % 2 == 0 => ChatMessage::user(format!("user message {}", i)),
            None => ChatMessage::assistant(format!("assistant message {}", i)),
        })
        .collect();

    // max_age_rounds = 10 and max_images = 3.
    // Distance from the newest message (index 19): index 19 -> 0, 15 -> 4,
    // 10 -> 9, 5 -> 14, 0 -> 19. Indices 5 and 0 are expected to exceed the age
    // limit; indices 19, 15 and 10 keep their images (3 images, within max_images).
    let filtered = filter_images_by_age_and_count(&messages, 10, 3);

    assert!(filtered[19].media_refs.len() > 0, "最新消息应保留图片");
    assert!(filtered[15].media_refs.len() > 0, "age=4 的消息应保留图片");
    assert!(filtered[10].media_refs.len() > 0, "age=9 的消息应保留图片");
    assert_eq!(filtered[5].media_refs.len(), 0, "age=14 的消息图片应被过滤");
    assert!(filtered[5].content.contains("超出 10 条消息范围"));
    assert_eq!(filtered[0].media_refs.len(), 0, "age=19 的消息图片应被过滤");
    assert!(filtered[0].content.contains("超出 10 条消息范围"));
}
}
/// Errors reported by the agent, each carrying a human-readable detail string.
#[derive(Debug)]
pub enum AgentError {
    /// Displayed with the prefix `Provider creation error: `.
    ProviderCreation(String),
    /// Displayed with the prefix `LLM error: `.
    LlmError(String),
    /// Displayed as the bare detail string.
    Other(String),
}

impl std::fmt::Display for AgentError {
    /// Writes the variant-specific prefix (none for `Other`) followed by the detail.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::ProviderCreation(detail) => ("Provider creation error: ", detail),
            Self::LlmError(detail) => ("LLM error: ", detail),
            Self::Other(detail) => ("", detail),
        };
        write!(f, "{}{}", prefix, detail)
    }
}

impl std::error::Error for AgentError {}