PicoBot/src/agent/agent_loop.rs
ooodc e37dea886b refactor(tools): 优化交互式进程输出捕获逻辑
- 删除了 PendingUserAction 相关的冗余辅助消息发送代码
- 引入自适应 drain_until_stable 函数循环读取输出直到稳定
- 用 drain_until_stable 替代固定延时等待以捕获最终提示内容
- 确保进程等待 stdin 时完整且及时地捕获所有输出数据
- 移除过时的常量和注释,简化代码逻辑
- 保持对最大循环次数和间隔时间的限制防止死循环
2026-06-13 17:40:44 +08:00

2196 lines
82 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::agent::AgentRuntimeConfig;
use crate::agent::{SystemPromptContext, SystemPromptProvider};
use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::storage::ConversationRepository;
use crate::domain::messages::{ContentBlock, ToolCall};
use crate::observability::{
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
use crate::tools::{ToolContext, ToolRegistry};
use async_trait::async_trait;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use std::io::Read;
use std::sync::Arc;
use std::time::Instant;
/// Number of trailing characters `truncate_tool_result` keeps when it preserves the tail of an
/// oversized tool output.
const TRUNCATION_SUFFIX_LEN: usize = 200;
/// Prefix that `parse_pending_tool_output` strips from a tool output to recognise a pending user action.
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
/// Assistant-facing text returned by `recoverable_llm_message` for errors classified as recoverable.
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
/// MIME essences that `supported_image_mime_type` accepts as inlinable images.
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
/// Divisor used to turn a serialized length (and a base64 length) into an approximate token count.
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
/// Multiplier applied on top of the raw token estimate as a safety margin.
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
/// Fraction of the input window (context window minus completion reserve) considered usable.
const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9;
/// Completion-token reserve used when the provider configuration sets no `max_tokens`.
const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048;
/// Tokens subtracted from an image's budget before converting it to base64 characters / bytes.
const DATA_URL_OVERHEAD_TOKENS: usize = 16;
/// JPEG quality values tried, in this order, for every candidate size in `compress_image_to_jpeg_budget`.
const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36];
/// Lower bound for the longest side when `compress_image_to_jpeg_budget` shrinks an image.
const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64;
/// First line of the text block that lists images which could not be inlined into the request.
const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:";
/// Build the content blocks for one message from its text and media paths.
///
/// Thin wrapper: forwards all three arguments unchanged to
/// `build_content_blocks_with_image_budget` and returns its result.
fn build_content_blocks(
text: &str,
media_paths: &[String],
budget: &mut ImageInlineBudget,
) -> Vec<ContentBlock> {
build_content_blocks_with_image_budget(text, media_paths, budget)
}
fn build_content_blocks_with_image_budget(
text: &str,
media_paths: &[String],
budget: &mut ImageInlineBudget,
) -> Vec<ContentBlock> {
let mut blocks = Vec::new();
let mut skipped_image_notices = Vec::new();
// Add text block if there's text
if !text.is_empty() {
blocks.push(ContentBlock::text(text));
}
// Add image blocks for media paths
for path in media_paths {
if supported_image_mime_type(path).is_none() {
tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block");
continue;
}
let Some(target_tokens) = budget.take_next_image_tokens() else {
tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains");
skipped_image_notices.push(format!(
"- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
display_media_name(path)
));
continue;
};
match encode_image_to_base64_with_budget(path, target_tokens) {
Ok((mime_type, base64_data)) => {
let url = format!("data:{};base64,{}", mime_type, base64_data);
blocks.push(ContentBlock::image_url(url));
}
Err(err) => {
tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed");
skipped_image_notices.push(format!(
"- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
display_media_name(path)
));
continue;
}
}
}
if !skipped_image_notices.is_empty() {
let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX);
notice.push('\n');
notice.push_str(&skipped_image_notices.join("\n"));
blocks.push(ContentBlock::text(notice));
}
// If nothing, add empty text block
if blocks.is_empty() {
blocks.push(ContentBlock::text(""));
}
blocks
}
/// Return the MIME essence guessed from `path` when it is one of `SUPPORTED_IMAGE_MIME_TYPES`,
/// otherwise `None`. The guess comes from `mime_guess::from_path`; the file itself is not opened.
fn supported_image_mime_type(path: &str) -> Option<String> {
    let guessed = mime_guess::from_path(path).first_or_octet_stream();
    SUPPORTED_IMAGE_MIME_TYPES
        .iter()
        .copied()
        .find(|supported| *supported == guessed.essence_str())
        .map(str::to_owned)
}
/// Name shown to the model for a media path: the final path component when it exists and is
/// valid UTF-8, otherwise the whole `path` unchanged.
fn display_media_name(path: &str) -> String {
    let file_name = std::path::Path::new(path)
        .file_name()
        .and_then(std::ffi::OsStr::to_str);
    match file_name {
        Some(name) => name.to_owned(),
        None => path.to_owned(),
    }
}
/// Token budget shared by all images that are still to be inlined into one request.
#[derive(Debug, Clone)]
struct ImageInlineBudget {
    /// Tokens not yet handed out.
    remaining_tokens: usize,
    /// Images that have not yet asked for their share.
    remaining_images: usize,
}

impl ImageInlineBudget {
    /// Create a budget of `total_tokens` to be split across `image_count` images.
    fn new(total_tokens: usize, image_count: usize) -> Self {
        ImageInlineBudget {
            remaining_tokens: total_tokens,
            remaining_images: image_count,
        }
    }

    /// Hand out the next image's share of the budget.
    ///
    /// The share is the remaining tokens divided evenly (rounded up) among the remaining images.
    /// Returns `None` when no tokens or no image slots are left. One image slot is consumed on
    /// every call, whether or not a share was granted.
    fn take_next_image_tokens(&mut self) -> Option<usize> {
        let tokens_left = self.remaining_tokens;
        let images_left = self.remaining_images;
        self.remaining_images = images_left.saturating_sub(1);

        if tokens_left == 0 || images_left == 0 {
            return None;
        }

        let share = tokens_left.div_ceil(images_left);
        self.remaining_tokens = tokens_left.saturating_sub(share);
        Some(share)
    }
}
/// Rough token estimate for any serializable value.
///
/// The value is serialized to JSON; its byte length is divided by
/// `TOKEN_ESTIMATE_CHARS_PER_TOKEN` (rounded up) and scaled by
/// `TOKEN_ESTIMATE_SAFETY_MULTIPLIER`. A serialization failure counts as zero bytes.
fn estimate_tokens_from_serialized_json<T: serde::Serialize>(value: &T) -> usize {
    let serialized_len = match serde_json::to_string(value) {
        Ok(json) => json.len(),
        Err(_) => 0,
    };
    let unpadded_tokens = serialized_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    ((unpadded_tokens as f64) * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
}
/// Compute how many tokens of the context window are left for inlined images.
///
/// Steps: reserve the completion tokens (`provider.max_tokens`, or
/// `DEFAULT_COMPLETION_TOKEN_RESERVE` when unset), keep `CONTEXT_INPUT_SAFETY_RATIO` of what is
/// left, then subtract the estimated size of the text-only messages and the tool definitions.
/// All subtractions saturate at zero.
fn image_token_budget_for_request(
    runtime_config: &AgentRuntimeConfig,
    text_only_messages: &[Message],
    tools: Option<&Vec<crate::domain::tools::Tool>>,
) -> usize {
    let completion_reserve = match runtime_config.provider.max_tokens {
        Some(tokens) => tokens as usize,
        None => DEFAULT_COMPLETION_TOKEN_RESERVE,
    };

    let input_window = runtime_config
        .context_window_tokens
        .saturating_sub(completion_reserve);
    let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize;

    let message_tokens = estimate_tokens_from_serialized_json(&text_only_messages);
    let tool_tokens = match tools {
        Some(definitions) => estimate_tokens_from_serialized_json(definitions),
        None => 0,
    };

    safe_input_window.saturating_sub(message_tokens + tool_tokens)
}
/// Count, across all messages, the media references whose path maps to a supported image type.
fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
    messages
        .iter()
        .map(|message| {
            message
                .media_refs
                .iter()
                .filter(|path| supported_image_mime_type(path).is_some())
                .count()
        })
        .sum()
}
/// Drop images that are too old or exceed the image cap, keeping the most recent ones.
///
/// # Parameters
/// - `messages`: conversation history, ordered oldest to newest.
/// - `max_age_rounds`: images in messages that are this many (or more) positions away from the
///   newest message are dropped; every message counts, regardless of role. `0` disables the limit.
/// - `max_images`: maximum number of images kept overall, newest messages first. `0` disables
///   the limit.
///
/// # Returns
/// A copy of `messages` in which dropped images are removed from `media_refs` and a short notice
/// explaining the removal is appended to the message content. Non-image media references are
/// always kept.
fn filter_images_by_age_and_count(
    messages: &[ChatMessage],
    max_age_rounds: usize,
    max_images: usize,
) -> Vec<ChatMessage> {
    if messages.is_empty() || (max_age_rounds == 0 && max_images == 0) {
        return messages.to_vec();
    }
    if count_supported_image_media_refs(messages) == 0 {
        return messages.to_vec();
    }

    let msg_count = messages.len();
    // `0` means "no limit", matching both `max_age_rounds` and the early return above. Without
    // this, an age limit combined with `max_images == 0` stripped every image from the history.
    let image_cap = if max_images == 0 { usize::MAX } else { max_images };
    // Age is the distance from the newest message (the last element has age 0).
    let is_expired = |idx: usize| max_age_rounds > 0 && msg_count - 1 - idx >= max_age_rounds;

    // Pass 1 (newest to oldest): decide how many images each message may keep.
    let mut keep_per_msg = vec![0usize; msg_count];
    let mut kept_total = 0usize;
    for (idx, message) in messages.iter().enumerate().rev() {
        if is_expired(idx) {
            continue;
        }
        let images_here = message
            .media_refs
            .iter()
            .filter(|path| supported_image_mime_type(path).is_some())
            .count();
        // `kept_total` never exceeds `image_cap`, so the subtraction cannot underflow.
        let keep = images_here.min(image_cap - kept_total);
        keep_per_msg[idx] = keep;
        kept_total += keep;
    }

    // Pass 2 (oldest to newest): rebuild the messages with the surviving media references.
    messages
        .iter()
        .zip(keep_per_msg)
        .enumerate()
        .map(|(idx, (message, keep_count))| {
            let mut kept_here = 0usize;
            let mut dropped_any = false;
            let media_refs: Vec<String> = message
                .media_refs
                .iter()
                .filter(|path| {
                    if supported_image_mime_type(path).is_none() {
                        return true; // non-image media is never filtered
                    }
                    if kept_here < keep_count {
                        kept_here += 1;
                        true
                    } else {
                        dropped_any = true;
                        false
                    }
                })
                .cloned()
                .collect();

            // Tell the model why an image it may remember is no longer attached.
            let content = if !dropped_any {
                message.content.clone()
            } else if is_expired(idx) {
                format!("{} [图片已过期:超出 {} 条消息范围]", message.content, max_age_rounds)
            } else {
                format!("{} [图片已过期:超出最大图片数量限制]", message.content)
            };

            ChatMessage {
                id: message.id.clone(),
                role: message.role.clone(),
                content,
                media_refs,
                timestamp: message.timestamp,
                system_context: message.system_context.clone(),
                reasoning_content: message.reasoning_content.clone(),
                tool_call_id: message.tool_call_id.clone(),
                tool_name: message.tool_name.clone(),
                tool_state: message.tool_state.clone(),
                tool_duration_ms: message.tool_duration_ms,
                tool_calls: message.tool_calls.clone(),
            }
        })
        .collect()
}
/// Convert an image's token budget into the maximum number of raw image bytes.
///
/// The data-URL overhead is subtracted first, the rest is converted to base64 characters, and
/// base64 characters are converted to bytes at a 4:3 ratio. All steps saturate instead of
/// overflowing.
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
    let payload_tokens = target_tokens.saturating_sub(DATA_URL_OVERHEAD_TOKENS);
    let base64_chars = payload_tokens.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    base64_chars.saturating_mul(3) / 4
}
/// Read an image file and base64-encode its bytes.
///
/// Returns `(mime_type, base64_payload)`; the caller assembles the data URL.
///
/// # Errors
/// `InvalidInput` when the path does not map to a supported image type, or any I/O error from
/// opening or reading the file.
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};

    let Some(mime) = supported_image_mime_type(path) else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("unsupported image media type for path: {}", path),
        ));
    };

    let mut raw = Vec::new();
    std::fs::File::open(path)?.read_to_end(&mut raw)?;
    Ok((mime, STANDARD.encode(&raw)))
}
/// Base64-encode an image so that it fits the given token budget.
///
/// The original file is returned untouched when its base64 form already fits. Otherwise the
/// image is decoded and recompressed as JPEG via `compress_image_to_jpeg_budget`, and the result
/// is reported with the `image/jpeg` MIME type.
///
/// # Errors
/// Propagates I/O errors from reading the file; decoding failures are mapped to `InvalidData`.
fn encode_image_to_base64_with_budget(
    path: &str,
    target_tokens: usize,
) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};

    let (mime, encoded) = encode_image_to_base64(path)?;

    let payload_tokens = target_tokens.saturating_sub(DATA_URL_OVERHEAD_TOKENS);
    let max_base64_chars = payload_tokens.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    if encoded.len() > max_base64_chars {
        // Too large as-is: recompress to the byte size that corresponds to the token budget.
        let raw = std::fs::read(path)?;
        let decoded = match image::load_from_memory(&raw) {
            Ok(decoded) => decoded,
            Err(err) => {
                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, err));
            }
        };
        let jpeg =
            compress_image_to_jpeg_budget(&decoded, target_image_bytes_for_tokens(target_tokens))?;
        return Ok(("image/jpeg".to_string(), STANDARD.encode(jpeg)));
    }

    Ok((mime, encoded))
}
/// Re-encode `image` as JPEG, trying to get at or below `target_bytes`.
///
/// For each candidate size (starting at the original size, then shrinking the longest side by
/// a quarter per round down to `MIN_COMPRESSED_IMAGE_SIDE`) every quality in
/// `JPEG_QUALITY_STEPS` is tried in order; the first encoding that fits is returned. If nothing
/// fits, the smallest encoding seen is returned, so the result may exceed `target_bytes`.
///
/// # Errors
/// Encoder failures are mapped to `InvalidData`.
fn compress_image_to_jpeg_budget(
    image: &image::DynamicImage,
    target_bytes: usize,
) -> Result<Vec<u8>, std::io::Error> {
    use image::codecs::jpeg::JpegEncoder;
    use image::imageops::FilterType;

    let longest_side = image.width().max(image.height());
    let mut side_limit = longest_side.max(MIN_COMPRESSED_IMAGE_SIDE);
    let mut smallest: Option<Vec<u8>> = None;

    loop {
        let scaled = if longest_side > side_limit {
            image.resize(side_limit, side_limit, FilterType::Triangle)
        } else {
            image.clone()
        };

        for &quality in JPEG_QUALITY_STEPS {
            let mut jpeg = Vec::new();
            let mut encoder = JpegEncoder::new_with_quality(&mut jpeg, quality);
            encoder
                .encode_image(&scaled)
                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;

            if jpeg.len() <= target_bytes {
                return Ok(jpeg);
            }

            // Remember the smallest attempt as a best-effort fallback.
            let is_smaller = match &smallest {
                Some(current) => jpeg.len() < current.len(),
                None => true,
            };
            if is_smaller {
                smallest = Some(jpeg);
            }
        }

        if side_limit <= MIN_COMPRESSED_IMAGE_SIDE {
            break;
        }
        side_limit = (side_limit * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE);
    }

    smallest.ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed")
    })
}
/// Truncate a tool result that exceeds `max_tool_result_chars` characters.
///
/// - Output within the limit is returned unchanged.
/// - Output that exceeds the limit by more than `TRUNCATION_SUFFIX_LEN` characters keeps its
///   first `max_tool_result_chars - 100` characters, followed by a truncation marker.
/// - Otherwise the truncation marker is followed by the last `TRUNCATION_SUFFIX_LEN` characters
///   (or the last `max_tool_result_chars`, whichever is smaller).
///
/// The marker always reports the number of characters actually removed.
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
    let total_chars = char_count(output);
    if total_chars <= max_tool_result_chars {
        return output.to_string();
    }

    let truncated_start_len = total_chars.saturating_sub(TRUNCATION_SUFFIX_LEN);
    if truncated_start_len > max_tool_result_chars {
        // Far over the limit: keep the head, leaving ~100 characters of room for the marker.
        let head_len = max_tool_result_chars.saturating_sub(100);
        let head = take_prefix_chars(output, head_len);
        // Derive the count from `head_len` so it stays correct when the limit is below 100
        // (the previous `total - max + 100` could report more characters than the input had).
        format!(
            "{}...\n\n[Output truncated - {} characters removed]",
            head,
            total_chars - head_len
        )
    } else {
        // Slightly over the limit: keep the tail, which usually holds the useful result.
        // Capping at the limit prevents tiny limits from returning the whole output with a
        // "0 characters removed" marker.
        let tail_len = TRUNCATION_SUFFIX_LEN.min(max_tool_result_chars);
        let tail = take_suffix_chars(output, tail_len);
        format!(
            "...\n\n[Output truncated - {} characters removed]\n\n{}",
            total_chars - tail_len,
            tail
        )
    }
}
/// If `output` starts with `PENDING_USER_ACTION_MARKER`, return the remainder with surrounding
/// whitespace trimmed; otherwise return `None`.
fn parse_pending_tool_output(output: &str) -> Option<String> {
    let remainder = output.strip_prefix(PENDING_USER_ACTION_MARKER)?;
    Some(remainder.trim().to_owned())
}
/// Normalize tool-call arguments that arrive as a JSON-encoded string.
///
/// A string value that parses as JSON is replaced by the parsed value; every other value
/// (including strings that are not valid JSON) is returned as a clone.
fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
    if let serde_json::Value::String(raw) = arguments {
        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(raw) {
            return parsed;
        }
    }
    arguments.clone()
}
/// Render an error and all of its `source()` ancestors, outermost first, joined by
/// `"\ncaused by: "`.
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    let mut chain: Vec<String> = Vec::new();
    let mut cursor = Some(error);
    while let Some(link) = cursor {
        chain.push(link.to_string());
        cursor = link.source();
    }
    chain.join("\ncaused by: ")
}
/// Classify an LLM error message as recoverable.
///
/// The check is a case-insensitive (ASCII) substring match against a fixed list of timeout and
/// gateway markers.
fn is_recoverable_llm_error(error: &str) -> bool {
    const RECOVERABLE_MARKERS: [&str; 5] = [
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];
    let lowered = error.to_ascii_lowercase();
    RECOVERABLE_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
}
fn recoverable_llm_message(error: &str) -> String {
if is_recoverable_llm_error(error) {
RECOVERABLE_LLM_ERROR_MESSAGE.to_string()
} else {
format!("模型请求失败:{}", error)
}
}
/// Outcome of recording one tool call in the `LoopDetector`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum LoopDetectionResult {
/// No warning needed.
Ok,
/// The same tool was called repeatedly with the same arguments; the payload is the warning
/// text to surface alongside the tool result.
Warning(String),
}
/// Settings for the `LoopDetector`.
#[derive(Debug, Clone)]
struct LoopDetectorConfig {
    /// Master switch; when `false`, recording never produces a warning.
    enabled: bool,
    /// A warning is produced each time the run of identical calls reaches a multiple of this.
    warn_every: usize,
}

impl Default for LoopDetectorConfig {
    /// Detection enabled, warning on every fifth consecutive identical call.
    fn default() -> Self {
        LoopDetectorConfig {
            warn_every: 5,
            enabled: true,
        }
    }
}
/// A single recorded tool invocation in the sliding window.
#[derive(Debug, Clone)]
struct ToolCallRecord {
/// Name of the tool that was called.
name: String,
/// Hash of the call's arguments as computed by `hash_json_value`.
args_hash: u64,
}
/// Stateful loop detector that watches for the same tool being called with the same arguments
/// over and over.
struct LoopDetector {
    config: LoopDetectorConfig,
    /// Most recent calls, bounded to `2 * warn_every` entries.
    window: VecDeque<ToolCallRecord>,
    /// Length of the current run of identical calls. Tracked separately because the bounded
    /// window cannot represent runs longer than its capacity.
    consecutive: usize,
}

impl LoopDetector {
    /// Create a detector with an empty history.
    fn new(config: LoopDetectorConfig) -> Self {
        Self {
            window: VecDeque::with_capacity(config.warn_every.saturating_mul(2)),
            config,
            consecutive: 0,
        }
    }

    /// Record a completed tool call and check for loop patterns.
    ///
    /// Returns `Warning` each time the number of consecutive identical calls (same name, same
    /// argument hash) reaches a multiple of `warn_every`, and `Ok` otherwise. A disabled
    /// detector, or one configured with `warn_every == 0`, never warns.
    fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult {
        // `warn_every == 0` would otherwise panic in the modulo below.
        if !self.config.enabled || self.config.warn_every == 0 {
            return LoopDetectionResult::Ok;
        }

        let record = ToolCallRecord {
            name: name.to_string(),
            args_hash: hash_json_value(args),
        };

        // Counting matches inside the bounded window made the run length saturate at the
        // window size, so every call after that point triggered a warning with a stale count.
        let repeats_previous = self
            .window
            .back()
            .is_some_and(|prev| prev.name == record.name && prev.args_hash == record.args_hash);
        self.consecutive = if repeats_previous {
            self.consecutive + 1
        } else {
            1
        };

        // Maintain the sliding window.
        if self.window.len() >= self.config.warn_every.saturating_mul(2) {
            self.window.pop_front();
        }
        self.window.push_back(record);

        if self.consecutive % self.config.warn_every == 0 {
            LoopDetectionResult::Warning(format!(
                "注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。",
                name, self.consecutive
            ))
        } else {
            LoopDetectionResult::Ok
        }
    }
}
/// Hash a JSON value so that objects differing only in key order produce the same hash.
///
/// The value is canonicalised with `canonicalise_json` first and then fed to the standard
/// library's `DefaultHasher`.
fn hash_json_value(value: &serde_json::Value) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    canonicalise_json(value).hash(&mut state);
    state.finish()
}
/// Return a clone of value with all object keys sorted recursively.
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
match value {
serde_json::Value::Object(map) => {
let mut sorted: Vec<(&String, &serde_json::Value)> = map.iter().collect();
sorted.sort_by_key(|(k, _)| *k);
let new_map: serde_json::Map<String, serde_json::Value> = sorted
.into_iter()
.map(|(k, v)| (k.clone(), canonicalise_json(v)))
.collect();
serde_json::Value::Object(new_map)
}
serde_json::Value::Array(arr) => {
serde_json::Value::Array(arr.iter().map(canonicalise_json).collect())
}
other => other.clone(),
}
}
/// Convert a `ChatMessage` into the provider-facing `Message`.
///
/// Messages without media references become a single text block; otherwise the content is
/// built by `build_content_blocks`, which draws on `image_budget` for every inlined image.
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
    let content = match m.media_refs.as_slice() {
        [] => vec![ContentBlock::text(&m.content)],
        _ => build_content_blocks(&m.content, &m.media_refs, image_budget),
    };

    Message {
        role: m.role.clone(),
        content,
        reasoning_content: m.reasoning_content.clone(),
        tool_call_id: m.tool_call_id.clone(),
        name: m.tool_name.clone(),
        tool_calls: m.tool_calls.clone(),
    }
}
/// Convert a `ChatMessage` into a `Message` that carries only its text, ignoring any media
/// references. Used to estimate the text portion of a request.
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
    let role = m.role.clone();
    let content = vec![ContentBlock::text(&m.content)];
    let reasoning_content = m.reasoning_content.clone();
    let tool_call_id = m.tool_call_id.clone();
    let name = m.tool_name.clone();
    let tool_calls = m.tool_calls.clone();

    Message {
        role,
        content,
        reasoning_content,
        tool_call_id,
        name,
        tool_calls,
    }
}
/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop {
/// Runtime settings (provider config, context window, image and tool-result limits).
runtime_config: AgentRuntimeConfig,
/// LLM backend created from `runtime_config.provider`.
provider: Box<dyn LLMProvider>,
/// Registry of tools offered to the model.
tools: Arc<ToolRegistry>,
/// System prompt provider (single injection point for agent and skill prompts).
system_prompt_provider: Option<Arc<dyn SystemPromptProvider>>,
/// Skill provider (used for matching error hints).
skills: Option<Arc<dyn SkillProvider>>,
/// Context set through `with_tool_context`.
tool_context: ToolContext,
/// Optional observer set through `with_observer`.
observer: Option<Arc<dyn Observer>>,
/// Optional sink that receives every message as soon as it is emitted.
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
/// Upper bound on LLM/tool rounds in `process`; initialised from `max_tool_iterations`.
max_iterations: usize,
/// Cancellation signal receiver; the agent checks it at the start of every iteration.
cancel_token: Option<tokio::sync::watch::Receiver<()>>,
}
/// Result of one `AgentLoop::process` run.
#[derive(Debug, Clone)]
pub struct AgentProcessResult {
/// The last assistant message produced by the run.
pub final_response: ChatMessage,
/// Every message emitted during the run, in order, including `final_response`.
pub emitted_messages: Vec<ChatMessage>,
}
/// Receiver for messages emitted by the agent while it is still running.
#[async_trait]
pub trait EmittedMessageHandler: Send + Sync + 'static {
/// Handle one emitted message.
async fn handle(&self, message: ChatMessage);
/// Handle a tool result message with optional execution timing.
/// Default implementation delegates to `handle()`, ignoring timing.
async fn handle_tool_result(&self, message: ChatMessage, _duration_ms: Option<u64>) {
self.handle(message).await;
}
}
/// Decorator: persists each message to the database before the inner emitter broadcasts it.
pub struct PersistingEmittedMessageHandler<H: EmittedMessageHandler> {
/// Wrapped handler that receives the message after the persistence attempt.
inner: H,
/// Repository the messages are appended to.
conversation_repository: Arc<dyn ConversationRepository>,
/// Session the messages are appended under.
session_id: String,
/// Optional topic passed along with every append.
topic_id: Option<String>,
}
impl<H: EmittedMessageHandler> PersistingEmittedMessageHandler<H> {
    /// Wrap `inner` so that every message is appended to `conversation_repository` under
    /// `session_id` (and `topic_id`, when given) before being forwarded.
    pub fn new(
        inner: H,
        conversation_repository: Arc<dyn ConversationRepository>,
        session_id: impl Into<String>,
        topic_id: Option<String>,
    ) -> Self {
        let session_id: String = session_id.into();
        Self {
            inner,
            conversation_repository,
            session_id,
            topic_id,
        }
    }
}
#[async_trait]
impl<H: EmittedMessageHandler> EmittedMessageHandler for PersistingEmittedMessageHandler<H> {
    /// Persist the message, then forward it to the inner handler. A persistence failure is
    /// logged and does not stop the forwarding.
    async fn handle(&self, message: ChatMessage) {
        let topic = self.topic_id.as_deref();
        if let Some(e) = self
            .conversation_repository
            .append_message_with_topic(&self.session_id, topic, &message)
            .err()
        {
            tracing::error!(error = %e, session_id = %self.session_id,
                "Failed to persist emitted message");
        }
        self.inner.handle(message).await;
    }

    /// Same as `handle`, but forwards through the inner handler's `handle_tool_result` so the
    /// timing reaches it. Only the message itself is persisted; `duration_ms` is not passed to
    /// the repository.
    async fn handle_tool_result(&self, message: ChatMessage, duration_ms: Option<u64>) {
        let topic = self.topic_id.as_deref();
        if let Some(e) = self
            .conversation_repository
            .append_message_with_topic(&self.session_id, topic, &message)
            .err()
        {
            tracing::error!(error = %e, session_id = %self.session_id,
                "Failed to persist emitted message");
        }
        self.inner.handle_tool_result(message, duration_ms).await;
    }
}
/// Source of skill-related prompt text for the agent.
pub trait SkillProvider: Send + Sync + 'static {
/// Prompt text indexing the available skills, or `None` when there is nothing to inject.
fn system_index_prompt(&self) -> Option<String>;
/// Summary of the skill matching `_name`, if any. The default implementation returns `None`.
fn matching_skill_summary(&self, _name: &str) -> Option<String> {
None
}
}
/// `SkillProvider` that offers no skills: it never supplies an index prompt and relies on the
/// trait's default `matching_skill_summary`.
#[derive(Default)]
#[allow(dead_code)]
struct EmptySkillProvider;
impl SkillProvider for EmptySkillProvider {
/// Always `None`.
fn system_index_prompt(&self) -> Option<String> {
None
}
}
impl AgentLoop {
/// Create an agent with an empty tool registry and no prompt provider, skills, observer,
/// emitted-message handler or cancel token.
///
/// # Errors
/// `AgentError::ProviderCreation` when the LLM provider cannot be created from the config.
pub fn new(config: impl Into<AgentRuntimeConfig>) -> Result<Self, AgentError> {
    let runtime_config: AgentRuntimeConfig = config.into();
    let provider = match create_provider(runtime_config.provider.clone()) {
        Ok(provider) => provider,
        Err(e) => return Err(AgentError::ProviderCreation(e.to_string())),
    };
    let max_iterations = runtime_config.max_tool_iterations;

    Ok(Self {
        max_iterations,
        runtime_config,
        provider,
        tools: Arc::new(ToolRegistry::new()),
        tool_context: ToolContext::default(),
        system_prompt_provider: None,
        skills: None,
        observer: None,
        emitted_message_handler: None,
        cancel_token: None,
    })
}
/// Create an agent that uses the given tool registry; everything else is left unset.
///
/// # Errors
/// `AgentError::ProviderCreation` when the LLM provider cannot be created from the config.
pub fn with_tools(
    config: impl Into<AgentRuntimeConfig>,
    tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
    let runtime_config: AgentRuntimeConfig = config.into();
    let provider = match create_provider(runtime_config.provider.clone()) {
        Ok(provider) => provider,
        Err(e) => return Err(AgentError::ProviderCreation(e.to_string())),
    };
    let max_iterations = runtime_config.max_tool_iterations;

    Ok(Self {
        max_iterations,
        runtime_config,
        provider,
        tools,
        tool_context: ToolContext::default(),
        system_prompt_provider: None,
        skills: None,
        observer: None,
        emitted_message_handler: None,
        cancel_token: None,
    })
}
/// Create an agent with the given tool registry and skill provider; no system prompt provider
/// is configured.
///
/// # Errors
/// `AgentError::ProviderCreation` when the LLM provider cannot be created from the config.
pub fn with_tools_and_skill_provider(
    config: impl Into<AgentRuntimeConfig>,
    tools: Arc<ToolRegistry>,
    skills: Arc<dyn SkillProvider>,
) -> Result<Self, AgentError> {
    let runtime_config: AgentRuntimeConfig = config.into();
    let provider = match create_provider(runtime_config.provider.clone()) {
        Ok(provider) => provider,
        Err(e) => return Err(AgentError::ProviderCreation(e.to_string())),
    };
    let max_iterations = runtime_config.max_tool_iterations;

    Ok(Self {
        max_iterations,
        runtime_config,
        provider,
        tools,
        skills: Some(skills),
        tool_context: ToolContext::default(),
        system_prompt_provider: None,
        observer: None,
        emitted_message_handler: None,
        cancel_token: None,
    })
}
/// Create an agent that uses a system prompt provider.
///
/// This is the recommended constructor: the provider is the single place where agent and skill
/// prompts are injected.
///
/// # Errors
/// `AgentError::ProviderCreation` when the LLM provider cannot be created from the config.
pub fn with_tools_and_system_prompt_provider(
    config: impl Into<AgentRuntimeConfig>,
    tools: Arc<ToolRegistry>,
    system_prompt_provider: Arc<dyn SystemPromptProvider>,
    skills: Option<Arc<dyn SkillProvider>>,
) -> Result<Self, AgentError> {
    let runtime_config: AgentRuntimeConfig = config.into();
    let provider = match create_provider(runtime_config.provider.clone()) {
        Ok(provider) => provider,
        Err(e) => return Err(AgentError::ProviderCreation(e.to_string())),
    };
    let max_iterations = runtime_config.max_tool_iterations;

    Ok(Self {
        max_iterations,
        runtime_config,
        provider,
        tools,
        skills,
        system_prompt_provider: Some(system_prompt_provider),
        tool_context: ToolContext::default(),
        observer: None,
        emitted_message_handler: None,
        cancel_token: None,
    })
}
/// Builder-style setter: replace the agent's tool context and return the agent.
pub fn with_tool_context(mut self, context: ToolContext) -> Self {
self.tool_context = context;
self
}
/// Builder-style setter: attach an observer for tracking events and return the agent.
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
self.observer = Some(observer);
self
}
/// Builder-style setter: attach the handler that receives each message as it is emitted
/// during `process`, and return the agent.
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
self.emitted_message_handler = Some(handler);
self
}
/// Set the receiving end of the cancellation signal.
///
/// At the start of every iteration the agent checks `cancel_token.has_changed()`
/// and returns early when a cancellation has been signalled.
pub fn with_cancel_token(mut self, token: tokio::sync::watch::Receiver<()>) -> Self {
self.cancel_token = Some(token);
self
}
/// Borrow the shared tool registry used by this agent.
pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools
}
/// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager.
///
/// This method supports multi-round tool calling: after executing tools,
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached
///
/// # Parameters
/// - `messages`: conversation history of the session.
/// - `system_prompt_context`: optional context used to build the system prompt dynamically;
///   without it (or without a system prompt provider) no system message is sent.
///
/// # Returns
/// Always `Ok` on the paths visible here: LLM failures and cancellation are turned into an
/// assistant message that becomes `final_response`. Every message produced along the way is
/// collected in `emitted_messages` and also pushed to the emitted-message handler, if one is set.
pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
system_prompt_context: Option<&SystemPromptContext>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
// Sanitize: remove any trailing incomplete tool call sequences
// that may have been persisted before a process interruption.
crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);
// Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new();
for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)]
tracing::debug!(iteration, "Agent iteration started");
// Check the cancellation signal. An `Err` from `has_changed()` is treated as "not cancelled".
if let Some(ref token) = self.cancel_token {
if token.has_changed().unwrap_or(false) {
tracing::info!(iteration, "Agent execution cancelled by user");
let cancel_message = format!(
"\n\n[用户已取消执行。已迭代 {} 次,取消前共生成了 {} 条消息。]",
iteration,
emitted_messages.len()
);
let assistant_message = ChatMessage::assistant(cancel_message);
emitted_messages.push(assistant_message.clone());
self.emit_live_tool_call_message(assistant_message.clone()).await;
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
}
// Build request
let tool_defs = self.tools.get_definitions();
let tools = if tool_defs.is_empty() {
None
} else {
Some(tool_defs)
};
// Drop images that exceed the age and count limits
let filtered_messages = filter_images_by_age_and_count(
&messages,
self.runtime_config.max_image_age_rounds,
self.runtime_config.max_images_in_context,
);
let image_count = count_supported_image_media_refs(&filtered_messages);
// Build the system prompt (single injection point for agent and skill prompts)
let system_prompt = system_prompt_context.and_then(|ctx| {
self.system_prompt_provider
.as_ref()
.and_then(|provider| provider.build(ctx))
});
// Text-only copy of the request, used solely to estimate how many tokens remain for images.
let mut text_only_messages: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 2);
if let Some(ref prompt) = system_prompt {
text_only_messages.push(Message::system(prompt.content.clone()));
}
text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens = image_token_budget_for_request(
&self.runtime_config,
&text_only_messages,
tools.as_ref(),
);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 2);
// Reuse the system prompt built above
if let Some(ref prompt) = system_prompt {
messages_for_llm.push(Message::system(prompt.content.clone()));
}
messages_for_llm.extend(
filtered_messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
let response = match (*self.provider).chat(request).await {
Ok(response) => response,
Err(e) => {
tracing::error!(
provider = %self.provider.name(),
model = %self.provider.model_id(),
error = %e,
error_details = %format_error_chain(e.as_ref()),
"LLM request failed"
);
// The failure is reported to the user as an assistant message rather than as an `Err`.
let assistant_message =
ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(assistant_message.clone());
self.emit_live_tool_call_message(assistant_message.clone()).await;
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
};
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
response_len = response.content.len(),
tool_calls_len = response.tool_calls.len(),
"LLM response received"
);
// If no tool calls, this is the final response
if response.tool_calls.is_empty() {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
};
emitted_messages.push(assistant_message.clone());
self.emit_live_tool_call_message(assistant_message.clone()).await;
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
// Execute tool calls
tracing::info!(
iteration,
count = response.tool_calls.len(),
"Tool calls detected, executing tools"
);
// Add assistant message with tool calls
let assistant_message =
if let Some(reasoning_content) = response.reasoning_content.clone() {
ChatMessage::assistant_with_tool_calls_and_reasoning(
response.content.clone(),
response.tool_calls.clone(),
reasoning_content,
)
} else {
ChatMessage::assistant_with_tool_calls(
response.content.clone(),
response.tool_calls.clone(),
)
};
messages.push(assistant_message.clone());
emitted_messages.push(assistant_message);
self.emit_live_tool_call_message(
emitted_messages
.last()
.expect("assistant message just pushed")
.clone(),
)
.await;
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
// Calls and results are paired positionally; `zip` stops at the shorter of the two.
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
// Truncate tool result if too large
let truncated_output =
truncate_tool_result(&result.output, self.runtime_config.tool_result_max_chars);
// Record tool call and check for loops
let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments);
match loop_result {
LoopDetectionResult::Warning(msg) => {
// Add warning and proceed: the warning is prepended to the tool output
tracing::warn!(
tool = %tool_call.name,
"Loop warning: {}",
msg
);
let tool_message = ChatMessage::tool_with_state(
tool_call.id.clone(),
tool_call.name.clone(),
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
if result.state == ToolExecutionState::PendingUserAction {
ToolMessageState::PendingUserAction
} else {
ToolMessageState::Completed
},
)
.with_tool_duration(result.duration.as_millis() as u64);
messages.push(tool_message.clone());
emitted_messages.push(tool_message.clone());
let duration_ms = Some(result.duration.as_millis() as u64);
self.emit_tool_result(tool_message, duration_ms).await;
}
LoopDetectionResult::Ok => {
let tool_message = ChatMessage::tool_with_state(
tool_call.id.clone(),
tool_call.name.clone(),
truncated_output,
if result.state == ToolExecutionState::PendingUserAction {
ToolMessageState::PendingUserAction
} else {
ToolMessageState::Completed
},
)
.with_tool_duration(result.duration.as_millis() as u64);
messages.push(tool_message.clone());
emitted_messages.push(tool_message.clone());
let duration_ms = Some(result.duration.as_millis() as u64);
self.emit_tool_result(tool_message, duration_ms).await;
}
}
}
// Loop continues to next iteration with updated messages
// Results of PendingUserAction tools were already appended to `messages` above;
// the model sees the full terminal output on the next round and can reply accordingly.
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
}
// Max iterations reached - ask LLM for a summary based on completed work
tracing::warn!("Max iterations reached, requesting final summary from LLM");
// Add a message asking for summary
let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far.",
);
messages.push(summary_request);
// Drop images that exceed the age and count limits
let filtered_messages = filter_images_by_age_and_count(
&messages,
self.runtime_config.max_image_age_rounds,
self.runtime_config.max_images_in_context,
);
// Convert messages to LLM format (using the system prompt provider)
let image_count = count_supported_image_media_refs(&filtered_messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 1);
if let Some(ref provider) = self.system_prompt_provider {
if let Some(ctx) = system_prompt_context {
if let Some(prompt) = provider.build(ctx) {
text_only_messages.push(Message::system(prompt.content.clone()));
}
}
}
text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message));
// No tool definitions are sent with the summary request, so none are counted against the budget.
let image_tokens =
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 1);
// NOTE(review): the system prompt is built a second time here instead of reusing the one above.
if let Some(ref provider) = self.system_prompt_provider {
if let Some(ctx) = system_prompt_context {
if let Some(prompt) = provider.build(ctx) {
messages_for_llm.push(Message::system(prompt.content.clone()));
}
}
}
messages_for_llm.extend(
filtered_messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools: None, // No tools in final summary call
};
match (*self.provider).chat(request).await {
Ok(response) => {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
};
emitted_messages.push(assistant_message.clone());
self.emit_live_tool_call_message(assistant_message.clone()).await;
Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
})
}
Err(e) => {
tracing::error!(
provider = %self.provider.name(),
model = %self.provider.model_id(),
error = %e,
error_details = %format_error_chain(e.as_ref()),
"Failed to get summary from LLM"
);
// As in the main loop, the failure becomes the final assistant message.
let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(final_message.clone());
self.emit_live_tool_call_message(final_message.clone()).await;
Ok(AgentProcessResult {
final_response: final_message,
emitted_messages,
})
}
}
}
/// Forward `message` to the emitted-message handler, when one is configured.
///
/// Without a handler the message is dropped and the call returns immediately.
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
    if let Some(sink) = self.emitted_message_handler.as_ref() {
        sink.handle(message).await;
    }
}
/// Forward a tool-result `message` and its optional duration (milliseconds)
/// to the emitted-message handler, when one is configured.
async fn emit_tool_result(&self, message: ChatMessage, duration_ms: Option<u64>) {
    if let Some(sink) = self.emitted_message_handler.as_ref() {
        sink.handle_tool_result(message, duration_ms).await;
    }
}
/// Decide whether a batch of tool calls may run in parallel.
///
/// Returns `false` when:
/// - the batch has fewer than two calls,
/// - any call targets `tool_search` (kept sequential to avoid MCP
///   activation race conditions),
/// - `task` calls are mixed with other tools.
///
/// Returns `true` when every call is a `task` call. Otherwise the batch is
/// parallel only if every tool is registered and reports
/// `concurrency_safe()`; an unregistered tool counts as unsafe.
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
    // A single call (or none) has nothing to overlap with.
    if tool_calls.len() < 2 {
        return false;
    }

    // One scan: bail out on `tool_search`, count `task` calls.
    let mut task_calls = 0usize;
    for call in tool_calls {
        match call.name.as_str() {
            "tool_search" => return false,
            "task" => task_calls += 1,
            _ => {}
        }
    }

    // Only `task` calls: subagents are independent, so run them together.
    // The batch has at least two calls here, so this implies more than one.
    if task_calls == tool_calls.len() {
        return true;
    }

    // `task` mixed with other tools stays sequential.
    if task_calls != 0 {
        return false;
    }

    // Parallel only if no call is unregistered or non-concurrency-safe.
    !tool_calls.iter().any(|call| {
        !self
            .tools
            .get(&call.name)
            .map_or(false, |tool| tool.concurrency_safe())
    })
}
/// Run a batch of tool calls, in parallel or sequentially according to
/// `should_execute_in_parallel`.
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let run_concurrently = self.should_execute_in_parallel(tool_calls);

    if !run_concurrently {
        tracing::debug!("Executing {} tools sequentially", tool_calls.len());
        return self.execute_tools_sequential(tool_calls).await;
    }

    tracing::debug!("Executing {} tools in parallel", tool_calls.len());
    self.execute_tools_parallel(tool_calls).await
}
/// Run every tool call through `execute_one_tool` and await them together
/// with `join_all`, which yields outcomes in the order of `tool_calls`.
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let in_flight = tool_calls.iter().map(|call| self.execute_one_tool(call));
    futures_util::future::join_all(in_flight).await
}
/// Run the tool calls one after another, awaiting each before starting the
/// next, and return the outcomes in call order.
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let mut collected = Vec::with_capacity(tool_calls.len());
    for call in tool_calls.iter() {
        let outcome = self.execute_one_tool(call).await;
        collected.push(outcome);
    }
    collected
}
/// Execute a single tool and return the outcome with event tracking.
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let start = Instant::now();
let tool_name = tool_call.name.clone();
// Log function call with name and arguments before execution
let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
};
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
// Record ToolCallStart event
if let Some(ref observer) = self.observer {
observer.record_event(&ObserverEvent::ToolCallStart {
tool: tool_name.clone(),
arguments: Some(truncate_args(&tool_call.arguments, 300)),
});
}
let result = self.execute_tool_internal(tool_call).await;
let duration = start.elapsed();
// Record ToolCall event
if let Some(ref observer) = self.observer {
observer.record_event(&ObserverEvent::ToolCall {
tool: tool_name.clone(),
duration,
success: result.success,
});
}
// Apply duration
ToolExecutionOutcome { duration, ..result }
}
/// Internal tool execution without event tracking.
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let normalized_arguments = normalize_tool_arguments(&tool_call.arguments);
let tool = match self.tools.get(&tool_call.name) {
Some(t) => t,
None => {
tracing::warn!(tool = %tool_call.name, "Tool not found");
let skill_hint = self.skills
.as_ref()
.and_then(|s| s.matching_skill_summary(&tool_call.name));
let error = match skill_hint {
Some(summary) => format!(
"Tool '{}' not found. A skill with the same name exists: {}. Skills are not tools. Call skill_activate with {{\"name\": \"{}\"}} first.",
tool_call.name, summary, tool_call.name
),
None => format!("Tool '{}' not found", tool_call.name),
};
return ToolExecutionOutcome::failure(
format!("Error: {}", error),
Some(error),
);
}
};
match tool
.execute_with_context(&self.tool_context, normalized_arguments.clone())
.await
{
Ok(result) => {
if result.success {
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
ToolExecutionOutcome::pending(pending_output)
} else {
ToolExecutionOutcome::success(result.output)
}
} else {
let error = result.error.unwrap_or_default();
tracing::error!(
tool = %tool_call.name,
args = %truncate_args(&tool_call.arguments, 4_000),
normalized_args = %truncate_args(&normalized_arguments, 4_000),
error = %error,
output = %result.output,
"Tool returned an error result"
);
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
}
}
Err(e) => {
tracing::error!(
tool = %tool_call.name,
args = %truncate_args(&tool_call.arguments, 4_000),
normalized_args = %truncate_args(&normalized_arguments, 4_000),
error = %e,
error_details = %format!("{:#}", e),
"Tool execution failed"
);
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::LLMProviderConfig;
use crate::observability::{MultiObserver, Observer};
use tempfile::tempdir;
/// Observer double that stores a clone of every event it is handed.
struct TestObserver {
    events: std::sync::Mutex<Vec<ObserverEvent>>,
}

impl TestObserver {
    /// Create an observer with an empty event log.
    fn new() -> Self {
        let events = std::sync::Mutex::new(Vec::new());
        Self { events }
    }
}

impl Observer for TestObserver {
    /// Append a clone of `event` to the log; panics if the mutex is poisoned.
    fn record_event(&self, event: &ObserverEvent) {
        let mut log = self.events.lock().unwrap();
        log.push(event.clone());
    }

    fn name(&self) -> &str {
        "test_observer"
    }
}
/// Skill provider double that only knows a skill named `baidu-search`.
struct TestSkillProvider;

impl SkillProvider for TestSkillProvider {
    fn system_index_prompt(&self) -> Option<String> {
        None
    }

    /// Return a summary for `baidu-search`, `None` for any other name.
    fn matching_skill_summary(&self, name: &str) -> Option<String> {
        if name == "baidu-search" {
            Some(String::from("用于百度搜索和天气查询的技能"))
        } else {
            None
        }
    }
}
/// Build the provider configuration shared by the tests in this module:
/// an `openai`-type provider pointed at `http://localhost`, a single tool
/// iteration, and at most one image kept in context.
fn test_runtime_config() -> LLMProviderConfig {
    use std::collections::HashMap;

    LLMProviderConfig {
        provider_type: String::from("openai"),
        name: String::from("test"),
        base_url: String::from("http://localhost"),
        api_key: String::from("test-key"),
        extra_headers: HashMap::new(),
        llm_timeout_secs: 120,
        memory_maintenance_timeout_secs: 600,
        model_id: String::from("test-model"),
        temperature: Some(0.0),
        max_tokens: Some(32),
        context_window_tokens: None,
        model_extra: HashMap::new(),
        max_tool_iterations: 1,
        tool_result_max_chars: 100_000,
        context_tool_result_trim_chars: 100_000,
        max_images_in_context: 1,
        max_image_age_rounds: 10,
    }
}
/// Calling an unregistered tool whose name matches a known skill must fail
/// with a message that includes the skill summary and points at
/// `skill_activate`.
#[tokio::test]
async fn test_missing_tool_with_same_name_skill_returns_activation_hint() {
    let empty_registry = Arc::new(ToolRegistry::new());
    let agent = AgentLoop::with_tools_and_skill_provider(
        test_runtime_config(),
        empty_registry,
        Arc::new(TestSkillProvider),
    )
    .unwrap();

    let skill_named_call = ToolCall {
        id: "call-1".to_string(),
        name: "baidu-search".to_string(),
        arguments: serde_json::json!({
            "queries": "佛山今天几点下雨"
        }),
    };
    let outcome = agent.execute_tool_internal(&skill_named_call).await;

    assert_eq!(outcome.state, ToolExecutionState::Completed);
    assert!(!outcome.success);
    assert!(outcome.output.contains("技能"));
    assert!(outcome.output.contains("skill_activate"));
    assert!(outcome.output.contains("baidu-search"));
}
/// Smoke test: a `MultiObserver` accepts a registered observer and an event.
///
/// Only the observer count is asserted; delivery of the event to the inner
/// observer is not inspected.
#[tokio::test]
async fn test_observer_receives_tool_events() {
    let mut fan_out = MultiObserver::new();
    fan_out.add_observer(Box::new(TestObserver::new()));

    let start_event = ObserverEvent::ToolCallStart {
        tool: String::from("test"),
        arguments: Some(String::from("{}")),
    };
    fan_out.record_event(&start_event);

    // Structural check only.
    assert_eq!(fan_out.len(), 1);
}
#[test]
fn test_should_execute_in_parallel_single_tool() {
// Would need a proper setup with AgentLoop to test fully
// For now, just verify the logic: single tool should return false
let calls = vec![ToolCall {
id: "1".to_string(),
name: "test".to_string(),
arguments: serde_json::json!({}),
}];
// If there's only 1 tool, should return false regardless
assert_eq!(calls.len() <= 1, true);
}
/// Converting an assistant message must carry its tool calls (id and name)
/// over to the provider message.
#[test]
fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
    let calculator_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({ "expression": "2+2" }),
    };
    let chat_message =
        ChatMessage::assistant_with_tool_calls("calling tool", vec![calculator_call]);

    let mut no_image_budget = ImageInlineBudget::new(0, 0);
    let converted = chat_message_to_llm_message(&chat_message, &mut no_image_budget);

    assert_eq!(converted.role, "assistant");
    let forwarded = converted.tool_calls.as_ref().unwrap();
    assert_eq!(forwarded.len(), 1);
    assert_eq!(forwarded[0].id, "call_1");
    assert_eq!(forwarded[0].name, "calculator");
}
/// Converting an assistant message must keep its reasoning content.
#[test]
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
    let chat_message =
        ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");

    let mut no_image_budget = ImageInlineBudget::new(0, 0);
    let converted = chat_message_to_llm_message(&chat_message, &mut no_image_budget);

    assert_eq!(converted.role, "assistant");
    let reasoning = converted.reasoning_content.as_deref();
    assert_eq!(reasoning, Some("hidden chain of thought"));
}
/// Truncating a long string of multi-byte characters must not panic and
/// must report that the output was truncated.
#[test]
fn test_truncate_tool_result_handles_utf8_char_boundaries() {
    // The input must be made of multi-byte characters: with an empty
    // literal the repeated string is empty and nothing is truncated.
    // "你" is three bytes in UTF-8.
    let input = "你".repeat(100_500);
    let output = truncate_tool_result(&input, 100_000);
    assert!(output.contains("Output truncated"));
    assert!(output.is_char_boundary(output.len()));
}
/// Output starting with the pending-user-action marker yields the text after
/// the marker line; unmarked output yields `None`.
#[test]
fn test_parse_pending_tool_output() {
    let marked = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权");
    assert_eq!(marked.as_deref(), Some("请完成授权"));

    let unmarked = parse_pending_tool_output("normal output");
    assert!(unmarked.is_none());
}
/// A JSON string that itself contains a JSON object is parsed into that
/// object.
#[test]
fn test_normalize_tool_arguments_parses_stringified_json() {
    let stringified = serde_json::json!("{\"command\":\"ls -la\"}");
    let normalized = normalize_tool_arguments(&stringified);
    assert_eq!(normalized, serde_json::json!({ "command": "ls -la" }));
}
/// A JSON string that is not itself JSON is returned unchanged.
#[test]
fn test_normalize_tool_arguments_keeps_plain_string() {
    let plain = serde_json::json!("plain text");
    let normalized = normalize_tool_arguments(&plain);
    assert_eq!(normalized, plain);
}
/// A PDF media reference produces no extra block: only the text block
/// remains.
#[test]
fn test_build_content_blocks_skips_non_image_media_refs() {
    let scratch = tempdir().unwrap();
    let pdf_path = scratch.path().join("demo.pdf");
    std::fs::write(&pdf_path, b"%PDF-1.4").unwrap();

    let media_refs = [pdf_path.to_string_lossy().into_owned()];
    let mut budget = ImageInlineBudget::new(1_000, 0);
    let blocks = build_content_blocks("hello", &media_refs, &mut budget);

    assert_eq!(blocks.len(), 1);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
}
/// A small JPEG within budget is inlined as a `data:image/jpeg;base64,` URL
/// after the text block.
#[test]
fn test_build_content_blocks_keeps_supported_images() {
    let scratch = tempdir().unwrap();
    let jpg_path = scratch.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg_path).unwrap();

    let media_refs = [jpg_path.to_string_lossy().into_owned()];
    let mut budget = ImageInlineBudget::new(10_000, 1);
    let blocks = build_content_blocks("hello", &media_refs, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(
        matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
    );
}
/// A 512x512 PNG with a 512-token budget is still sent, as an inlined JPEG
/// data URL.
#[test]
fn test_build_content_blocks_compresses_images_to_budget() {
    let scratch = tempdir().unwrap();
    let png_path = scratch.path().join("large.png");
    image::DynamicImage::new_rgb8(512, 512)
        .save(&png_path)
        .unwrap();

    let media_refs = [png_path.to_string_lossy().into_owned()];
    let mut budget = ImageInlineBudget::new(512, 1);
    let blocks = build_content_blocks("hello", &media_refs, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(
        matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
    );
}
/// With a zero token budget the image is replaced by a text notice that
/// names the file.
#[test]
fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() {
    let scratch = tempdir().unwrap();
    let jpg_path = scratch.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg_path).unwrap();

    let media_refs = [jpg_path.to_string_lossy().into_owned()];
    let mut exhausted_budget = ImageInlineBudget::new(0, 1);
    let blocks = build_content_blocks("hello", &media_refs, &mut exhausted_budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(matches!(
        &blocks[1],
        ContentBlock::Text { text }
            if text.contains("图片未能成功入模") && text.contains("demo.jpg")
    ));
}
/// Messages without images pass through the filter with content unchanged.
#[test]
fn test_filter_images_by_age_and_count_no_images() {
    let messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant("hi"),
        ChatMessage::user("how are you"),
    ];

    let filtered = filter_images_by_age_and_count(&messages, 10, 5);

    assert_eq!(filtered.len(), 3);
    let expected_contents = ["hello", "hi", "how are you"];
    for (idx, expected) in expected_contents.iter().enumerate() {
        assert_eq!(filtered[idx].content, *expected);
    }
}
/// An image older than `max_age_rounds` messages is dropped and replaced by
/// an "expired" note in the message content.
#[test]
fn test_filter_images_by_age_limit() {
    let scratch = tempdir().unwrap();
    let jpg_path = scratch.path().join("test.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg_path).unwrap();
    let image_ref = jpg_path.to_string_lossy().into_owned();

    // Twelve messages; only the oldest (index 0) carries an image.
    let messages: Vec<ChatMessage> = (0..12)
        .map(|i| match i {
            0 => ChatMessage::user_with_media(
                "first message with image",
                vec![image_ref.clone()],
            ),
            n if n % 2 == 0 => ChatMessage::user(format!("user message {}", n)),
            n => ChatMessage::assistant(format!("assistant message {}", n)),
        })
        .collect();

    // With max_age_rounds = 10, index 0 is 11 messages behind the newest
    // message (index 11), so its image must be filtered out.
    let filtered = filter_images_by_age_and_count(&messages, 10, 100);

    let oldest = &filtered[0];
    assert_eq!(oldest.media_refs.len(), 0);
    assert!(oldest.content.contains("图片已过期"));
}
/// With `max_images = 1` only the newest image survives; older messages get
/// an "over the image count limit" note instead.
#[test]
fn test_filter_images_by_count_limit() {
    let scratch = tempdir().unwrap();

    // Three distinct image files.
    let mut jpg_paths: Vec<String> = Vec::with_capacity(3);
    for i in 0..3 {
        let path = scratch.path().join(format!("test{}.jpg", i));
        image::DynamicImage::new_rgb8(8, 8).save(&path).unwrap();
        jpg_paths.push(path.to_string_lossy().into_owned());
    }

    // Three messages, one image each.
    let messages: Vec<ChatMessage> = jpg_paths
        .iter()
        .enumerate()
        .map(|(i, path)| {
            ChatMessage::user_with_media(format!("message {}", i), vec![path.clone()])
        })
        .collect();

    // max_images = 1: keep only the newest image (index 2).
    let filtered = filter_images_by_age_and_count(&messages, 100, 1);

    assert_eq!(filtered[2].media_refs.len(), 1); // newest is kept
    for &older in &[1usize, 0] {
        assert_eq!(filtered[older].media_refs.len(), 0); // filtered out
        assert!(filtered[older].content.contains("超出最大图片数量限制"));
    }
}
/// Age and count limits applied together: images within the last 10
/// messages are kept (three of them, which does not exceed `max_images = 3`);
/// older images are dropped with an "out of range" note.
#[test]
fn test_filter_images_combined_limits() {
    // Message indices that carry an image, oldest first.
    const IMAGE_POSITIONS: [usize; 5] = [0, 5, 10, 15, 19];

    let temp_dir = tempdir().unwrap();

    // One distinct image file per image-bearing message.
    let jpg_paths: Vec<String> = (0..IMAGE_POSITIONS.len())
        .map(|i| {
            let path = temp_dir.path().join(format!("test{}.jpg", i));
            let image = image::DynamicImage::new_rgb8(8, 8);
            image.save(&path).unwrap();
            path.to_string_lossy().to_string()
        })
        .collect();

    // Twenty messages. The k-th image position gets the k-th image file, so
    // all five files are used; `i / 5` would send index 19 to image 3 and
    // leave the fifth file unreferenced.
    let messages: Vec<ChatMessage> = (0..20usize)
        .map(|i| {
            if let Some(image_idx) = IMAGE_POSITIONS.iter().position(|&p| p == i) {
                ChatMessage::user_with_media(
                    format!("message {} with image", i),
                    vec![jpg_paths[image_idx].clone()],
                )
            } else if i % 2 == 0 {
                ChatMessage::user(format!("user message {}", i))
            } else {
                ChatMessage::assistant(format!("assistant message {}", i))
            }
        })
        .collect();

    // max_age_rounds = 10 (keep images within the last 10 messages),
    // max_images = 3.
    // Ages: index 19 -> 0, 15 -> 4, 10 -> 9, 5 -> 14, 0 -> 19.
    // Indices 5 and 0 exceed the age limit; 19, 15 and 10 stay (3 images).
    let filtered = filter_images_by_age_and_count(&messages, 10, 3);

    assert!(!filtered[19].media_refs.is_empty(), "最新消息应保留图片");
    assert!(!filtered[15].media_refs.is_empty(), "age=4 的消息应保留图片");
    assert!(!filtered[10].media_refs.is_empty(), "age=9 的消息应保留图片");
    assert_eq!(filtered[5].media_refs.len(), 0, "age=14 的消息图片应被过滤");
    assert!(filtered[5].content.contains("超出 10 条消息范围"));
    assert_eq!(filtered[0].media_refs.len(), 0, "age=19 的消息图片应被过滤");
    assert!(filtered[0].content.contains("超出 10 条消息范围"));
}
// ====================
// sanitize_incomplete_tool_call_sequences tests
// ====================
/// A trailing assistant message whose tool call has no result is removed.
#[test]
fn test_sanitize_removes_trailing_incomplete_tool_call_sequence() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let unanswered_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let mut history = vec![
        ChatMessage::user("hello"),
        // No tool result follows for call_1: the sequence is incomplete.
        ChatMessage::assistant_with_tool_calls("calling tool", vec![unanswered_call]),
    ];

    let removed = sanitize(&mut history);

    assert_eq!(removed, 1);
    assert_eq!(history.len(), 1);
    assert_eq!(history[0].role, "user");
}
/// An assistant tool call followed by its tool result is left untouched.
#[test]
fn test_sanitize_preserves_complete_tool_call_sequence() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let answered_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let mut history = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant_with_tool_calls("calling tool", vec![answered_call]),
        ChatMessage::tool("call_1", "calculator", "2"),
    ];

    let removed = sanitize(&mut history);

    assert_eq!(removed, 0);
    assert_eq!(history.len(), 3);
}
/// Two assistant messages with unanswered tool calls are both removed; the
/// user messages around them stay.
#[test]
fn test_sanitize_removes_multiple_incomplete_sequences() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let first_unanswered = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let second_unanswered = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    let mut history = vec![
        ChatMessage::user("hello"),
        // call_1 never gets a result.
        ChatMessage::assistant_with_tool_calls("first tool call", vec![first_unanswered]),
        ChatMessage::user("second question"),
        // call_2 never gets a result either.
        ChatMessage::assistant_with_tool_calls("second tool call", vec![second_unanswered]),
    ];

    let removed = sanitize(&mut history);

    // Both assistant messages with incomplete tool calls are dropped.
    assert_eq!(removed, 2);
    assert_eq!(history.len(), 2);
    assert_eq!(history[0].role, "user");
    assert_eq!(history[0].content, "hello");
    assert_eq!(history[1].role, "user");
    assert_eq!(history[1].content, "second question");
}
/// An assistant message with two tool calls but only one result is removed
/// together with that lone result.
#[test]
fn test_sanitize_removes_assistant_when_partial_tool_results() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let answered = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let unanswered = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    let mut history = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant_with_tool_calls("calling two tools", vec![answered, unanswered]),
        ChatMessage::tool("call_1", "calculator", "2"),
        // No result for call_2.
    ];

    let removed_count = sanitize(&mut history);

    // The assistant message goes because call_2 has no result, and the
    // result for call_1 goes because its parent assistant was removed.
    assert_eq!(removed_count, 2);
    assert_eq!(history.len(), 1);
    assert_eq!(history[0].role, "user");
}
/// A history without any tool calls is left as it is.
#[test]
fn test_sanitize_preserves_messages_without_tool_calls() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let mut history = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant("hi there"),
        ChatMessage::user("how are you"),
    ];

    let removed = sanitize(&mut history);

    assert_eq!(removed, 0);
    assert_eq!(history.len(), 3);
}
/// Sanitizing an empty history removes nothing.
#[test]
fn test_sanitize_handles_empty_messages() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let mut history: Vec<ChatMessage> = Vec::new();
    let removed = sanitize(&mut history);
    assert_eq!(removed, 0);
}
/// A tool message with no preceding assistant tool call is orphaned and is
/// removed.
#[test]
fn test_sanitize_removes_orphaned_tool_messages() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let orphan = ChatMessage::tool("call_1", "calculator", "2");
    let mut history = vec![orphan];

    let removed = sanitize(&mut history);

    assert_eq!(removed, 1);
    assert!(history.is_empty());
}
/// An assistant message with two tool calls, both answered, is kept whole.
#[test]
fn test_sanitize_preserves_complete_sequence_with_multiple_tool_calls() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let calculator_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let read_call = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    let mut history = vec![
        ChatMessage::user("do two things"),
        ChatMessage::assistant_with_tool_calls(
            "calling two tools",
            vec![calculator_call, read_call],
        ),
        ChatMessage::tool("call_1", "calculator", "2"),
        ChatMessage::tool("call_2", "read", "contents of README"),
    ];

    let removed = sanitize(&mut history);

    assert_eq!(removed, 0);
    assert_eq!(history.len(), 4);
}
/// A complete sequence followed by an incomplete one: only the trailing
/// incomplete assistant message is removed.
#[test]
fn test_sanitize_only_trims_trailing_incomplete_sequence() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let answered_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let unanswered_call = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    let mut history = vec![
        ChatMessage::user("first question"),
        ChatMessage::assistant_with_tool_calls("first tool call", vec![answered_call]),
        ChatMessage::tool("call_1", "calculator", "2"),
        ChatMessage::assistant("the answer is 2"),
        ChatMessage::user("second question"),
        // call_2 has no result: only this message should be trimmed.
        ChatMessage::assistant_with_tool_calls("second tool call", vec![unanswered_call]),
    ];

    let removed = sanitize(&mut history);

    assert_eq!(removed, 1);
    // The first four messages and the second user question remain.
    assert_eq!(history.len(), 5);
    assert_eq!(history[0].content, "first question");
    assert_eq!(history[3].content, "the answer is 2");
    assert_eq!(history[4].content, "second question");
}
/// An unanswered tool call in the middle of the history is removed even when
/// a complete sequence follows it.
#[test]
fn test_sanitize_removes_mid_history_orphaned_tool_calls() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let orphaned_call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "1+1"}),
    };
    let valid_call = ToolCall {
        id: "call_2".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "README.md"}),
    };
    let mut history = vec![
        ChatMessage::user("first question"),
        // call_1 has no result: an orphan in the middle of the history.
        ChatMessage::assistant_with_tool_calls("orphaned tool call", vec![orphaned_call]),
        ChatMessage::user("second question"),
        ChatMessage::assistant_with_tool_calls("valid tool call", vec![valid_call]),
        ChatMessage::tool("call_2", "read", "file contents"),
    ];

    let removed = sanitize(&mut history);

    // Only the call_1 assistant message is removed; four messages remain.
    assert_eq!(removed, 1);
    assert_eq!(history.len(), 4);
    assert_eq!(history[0].role, "user");
    assert_eq!(history[0].content, "first question");
    assert_eq!(history[1].role, "user");
    assert_eq!(history[1].content, "second question");
    // The complete call_2 sequence is intact.
    assert_eq!(history[2].role, "assistant");
    assert_eq!(history[3].role, "tool");
}
/// Several unanswered tool calls scattered through the history are all
/// removed; the user messages stay in order.
#[test]
fn test_sanitize_removes_multiple_mid_history_orphans() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let first_orphan = ToolCall {
        id: "orphan_1".to_string(),
        name: "tool_a".to_string(),
        arguments: serde_json::json!({}),
    };
    let second_orphan = ToolCall {
        id: "orphan_2".to_string(),
        name: "tool_b".to_string(),
        arguments: serde_json::json!({}),
    };
    let mut history = vec![
        ChatMessage::user("first"),
        // No result for orphan_1.
        ChatMessage::assistant_with_tool_calls("orphan 1", vec![first_orphan]),
        ChatMessage::user("second"),
        // No result for orphan_2.
        ChatMessage::assistant_with_tool_calls("orphan 2", vec![second_orphan]),
        ChatMessage::user("third"),
    ];

    let removed = sanitize(&mut history);

    assert_eq!(removed, 2);
    assert_eq!(history.len(), 3);
    assert_eq!(history[0].content, "first");
    assert_eq!(history[1].content, "second");
    assert_eq!(history[2].content, "third");
}
/// A mid-history assistant message with two tool calls and only one result
/// is removed along with that result; a later complete sequence is kept.
#[test]
fn test_sanitize_removes_orphaned_tool_results_for_removed_mid_assistant() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let answered = ToolCall {
        id: "call_has_result".to_string(),
        name: "tool_a".to_string(),
        arguments: serde_json::json!({}),
    };
    let unanswered = ToolCall {
        id: "call_no_result".to_string(),
        name: "tool_b".to_string(),
        arguments: serde_json::json!({}),
    };
    let valid = ToolCall {
        id: "call_valid".to_string(),
        name: "good_tool".to_string(),
        arguments: serde_json::json!({}),
    };
    let mut history = vec![
        ChatMessage::user("first question"),
        ChatMessage::assistant_with_tool_calls(
            "two tool calls, only one has result",
            vec![answered, unanswered],
        ),
        ChatMessage::tool("call_has_result", "tool_a", "some result"),
        // No result for call_no_result: the sequence above is incomplete.
        ChatMessage::user("second question"),
        ChatMessage::assistant_with_tool_calls("valid tool call", vec![valid]),
        ChatMessage::tool("call_valid", "good_tool", "valid result"),
    ];

    let removed = sanitize(&mut history);

    // The incomplete assistant message plus its lone tool result.
    assert_eq!(removed, 2);
    assert_eq!(history.len(), 4);
    assert_eq!(history[0].content, "first question");
    assert_eq!(history[1].content, "second question");
    assert_eq!(history[2].role, "assistant");
    assert_eq!(history[3].role, "tool");
    // The remaining tool result belongs to call_valid.
    assert_eq!(history[3].tool_call_id.as_deref(), Some("call_valid"));
}
/// Complete, orphaned, complete: only the orphaned assistant message in the
/// middle is removed.
#[test]
fn test_sanitize_handles_complex_interleaved_history() {
    use crate::bus::message::sanitize_incomplete_tool_call_sequences as sanitize;

    let task1_read = ToolCall {
        id: "t1_call".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "a.txt"}),
    };
    let task2_write = ToolCall {
        id: "t2_call_1".to_string(),
        name: "write".to_string(),
        arguments: serde_json::json!({"path": "b.txt"}),
    };
    let task2_calc = ToolCall {
        id: "t2_call_2".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({"expression": "2+2"}),
    };
    let task3_search = ToolCall {
        id: "t3_call".to_string(),
        name: "search".to_string(),
        arguments: serde_json::json!({"query": "hello"}),
    };

    let mut history = vec![
        // Task 1: complete sequence.
        ChatMessage::user("task 1"),
        ChatMessage::assistant_with_tool_calls("doing task 1", vec![task1_read]),
        ChatMessage::tool("t1_call", "read", "content A"),
        ChatMessage::assistant("task 1 is done"),
        // Task 2: both tool results are missing.
        ChatMessage::user("task 2"),
        ChatMessage::assistant_with_tool_calls(
            "doing task 2 — this got interrupted",
            vec![task2_write, task2_calc],
        ),
        // Task 3: complete sequence.
        ChatMessage::user("task 3"),
        ChatMessage::assistant_with_tool_calls("doing task 3", vec![task3_search]),
        ChatMessage::tool("t3_call", "search", "found results"),
        ChatMessage::assistant("task 3 is done"),
    ];

    let removed = sanitize(&mut history);

    // Only the task 2 assistant message is removed: 10 - 1 = 9 remain.
    assert_eq!(removed, 1);
    assert_eq!(history.len(), 9);
    assert_eq!(history[4].content, "task 2");
    assert_eq!(history[5].content, "task 3");
}
}
/// Error type of the agent module; every variant carries a detail message.
#[derive(Debug)]
pub enum AgentError {
    /// Displayed as `Provider creation error: <detail>`.
    ProviderCreation(String),
    /// Displayed as `LLM error: <detail>`.
    LlmError(String),
    /// Displayed as the bare detail message.
    Other(String),
}

impl std::fmt::Display for AgentError {
    /// Write the variant's prefix (if any) followed by its detail message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::ProviderCreation(detail) => ("Provider creation error: ", detail),
            Self::LlmError(detail) => ("LLM error: ", detail),
            Self::Other(detail) => ("", detail),
        };
        write!(f, "{}{}", prefix, detail)
    }
}

impl std::error::Error for AgentError {}