use crate::agent::AgentRuntimeConfig;
use crate::agent::{SystemPromptContext, SystemPromptProvider};
use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::storage::ConversationRepository;
use crate::domain::messages::{ContentBlock, ToolCall};
use crate::observability::{
    Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, StreamDelta, StreamCallback, create_provider};
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
use crate::tools::{ToolContext, ToolRegistry};
use async_trait::async_trait;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use std::io::Read;
use std::sync::Arc;
use std::time::Instant;

/// Minimum characters to keep when truncating
const TRUNCATION_SUFFIX_LEN: usize = 200;
/// Marker prefix recognised (and stripped) by `parse_pending_tool_output`;
/// presumably emitted by tools that are waiting on a user action.
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
/// Reply used by `recoverable_llm_message` for timeout-like LLM errors.
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
/// MIME types (guessed from the file extension) that may be sent as image blocks.
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
/// Heuristic divisor: serialized length / this value approximates a token count.
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
/// Inflates the heuristic token estimate so it errs towards over-counting.
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
/// Fraction of the input window (context minus completion reserve) treated as usable.
const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9;
/// Completion tokens reserved when the provider config sets no `max_tokens`.
const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048;
/// Tokens subtracted from each image's budget before sizing its base64 payload
/// (presumably to cover the `data:<mime>;base64,` prefix).
const DATA_URL_OVERHEAD_TOKENS: usize = 16;
/// JPEG qualities tried, in order, at each downscaling step.
const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36];
/// Lower bound for the longest image side while downscaling during compression.
const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64;
/// Heading of the text block that lists images which could not be inlined.
const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:";

/// Build content blocks from text and media paths.
///
/// Thin wrapper that delegates to `build_content_blocks_with_image_budget`.
fn build_content_blocks(
    text: &str,
    media_paths: &[String],
    budget: &mut ImageInlineBudget,
) -> Vec {
    build_content_blocks_with_image_budget(text, media_paths, budget)
}

/// Build the content blocks for one message.
///
/// The result holds, in order:
/// - a text block (only if `text` is non-empty);
/// - one `data:` URL image block per supported image that could be encoded
///   within its share of `budget`;
/// - a trailing text notice listing every supported image that was skipped
///   (budget exhausted or encoding/compression failed).
///
/// Non-image media refs are skipped with only a debug log. The list is never
/// empty: an empty text block is pushed as a fallback.
fn build_content_blocks_with_image_budget(
    text: &str,
    media_paths: &[String],
    budget: &mut ImageInlineBudget,
) -> Vec {
    let mut blocks = Vec::new();
    let mut skipped_image_notices = Vec::new();
    // Add text block if there's text
    if !text.is_empty() {
        blocks.push(ContentBlock::text(text));
    }
    // Add image blocks for media paths
    for path in media_paths {
        if supported_image_mime_type(path).is_none() {
            tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block");
            continue;
        }
        // Each supported image claims its share of the remaining token budget;
        // `None` means no budget is left for it.
        let Some(target_tokens) = budget.take_next_image_tokens() else {
            tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains");
            skipped_image_notices.push(format!(
                "- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
                display_media_name(path)
            ));
            continue;
        };
        match encode_image_to_base64_with_budget(path, target_tokens) {
            Ok((mime_type, base64_data)) => {
                let url = format!("data:{};base64,{}", mime_type, base64_data);
                blocks.push(ContentBlock::image_url(url));
            }
            Err(err) => {
                tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed");
                skipped_image_notices.push(format!(
                    "- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
                    display_media_name(path)
                ));
                continue;
            }
        }
    }
    // Tell the model which images were dropped so it can inform the user.
    if !skipped_image_notices.is_empty() {
        let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX);
        notice.push('\n');
        notice.push_str(&skipped_image_notices.join("\n"));
        blocks.push(ContentBlock::text(notice));
    }
    // If nothing, add empty text block
    if blocks.is_empty() {
        blocks.push(ContentBlock::text(""));
    }
    blocks
}

/// Return the MIME essence guessed from `path`'s extension if it is one of
/// `SUPPORTED_IMAGE_MIME_TYPES`, otherwise `None`. The file is not opened.
fn supported_image_mime_type(path: &str) -> Option {
    let mime = mime_guess::from_path(path).first_or_octet_stream();
    let essence = mime.essence_str();
    if SUPPORTED_IMAGE_MIME_TYPES.contains(&essence) {
        Some(essence.to_string())
    } else {
        None
    }
}

/// File name component of `path` for user-facing notices; falls back to the
/// whole path when there is no (UTF-8) file name.
fn display_media_name(path: &str) -> String {
    std::path::Path::new(path)
        .file_name()
        .and_then(|name| name.to_str())
        .map(ToOwned::to_owned)
        .unwrap_or_else(|| path.to_string())
}

/// Token budget for inline images, shared out across the images still to be inlined.
#[derive(Debug, Clone)]
struct ImageInlineBudget {
    remaining_tokens: usize,
    remaining_images: usize,
}

impl ImageInlineBudget {
    fn new(total_tokens: usize, image_count: usize) -> Self {
        Self {
            remaining_tokens: total_tokens,
            remaining_images: image_count,
        }
    }

    /// Claim the next image's share: `ceil(remaining_tokens / remaining_images)`.
    ///
    /// Returns `None` when either counter is zero; an image slot is still
    /// consumed in that case.
    fn take_next_image_tokens(&mut self) -> Option {
        if self.remaining_tokens == 0 || self.remaining_images == 0 {
            self.remaining_images = self.remaining_images.saturating_sub(1);
            return None;
        }
        let target = self.remaining_tokens.div_ceil(self.remaining_images);
        self.remaining_tokens = self.remaining_tokens.saturating_sub(target);
        self.remaining_images = self.remaining_images.saturating_sub(1);
        Some(target)
    }
}

/// Heuristic token count: `ceil(serialized_json_len / 4) * 1.2`, truncated.
/// Uses the byte length of the JSON string. A serialization failure counts as 0.
fn estimate_tokens_from_serialized_json(value: &T) -> usize {
    let raw_len = serde_json::to_string(value)
        .map(|serialized| serialized.len())
        .unwrap_or_default();
    ((raw_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64) * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
}

/// Tokens left for inline images in one request.
///
/// Computed as `(context_window_tokens - completion_reserve) * CONTEXT_INPUT_SAFETY_RATIO`
/// minus the heuristic estimate of the text-only messages and the tool
/// definitions. It saturates at 0. The completion reserve is the provider's
/// `max_tokens`, or `DEFAULT_COMPLETION_TOKEN_RESERVE` when that is unset.
fn image_token_budget_for_request(
    runtime_config: &AgentRuntimeConfig,
    text_only_messages: &[Message],
    tools: Option<&Vec>,
) -> usize {
    let completion_reserve = runtime_config
        .provider
        .max_tokens
        .map(|tokens| tokens as usize)
        .unwrap_or(DEFAULT_COMPLETION_TOKEN_RESERVE);
    let input_window = runtime_config
        .context_window_tokens
        .saturating_sub(completion_reserve);
    let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize;
    let text_tokens = estimate_tokens_from_serialized_json(&text_only_messages)
        + tools
            .map(estimate_tokens_from_serialized_json)
            .unwrap_or_default();
    safe_input_window.saturating_sub(text_tokens)
}

/// Count media refs across `messages` whose extension maps to a supported image type.
fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
    messages
        .iter()
        .flat_map(|message| message.media_refs.iter())
        .filter(|path| supported_image_mime_type(path).is_some())
        .count()
}

/// Drop images that are too old, and cap the total number of images.
///
/// # Parameters
/// - `messages`: the original message list (oldest first).
/// - `max_age_rounds`: images in messages at least this many messages before the
///   newest one are no longer sent. The count runs back from the newest message
///   and includes all roles. `0` disables the age limit.
/// - `max_images`: the maximum number of images to send. The most recent
///   messages' images are preferred.
///
/// # Returns
/// The filtered message list. When a message loses images, a
/// "[图片已过期:…]" notice is appended to its text content. Non-image media
/// refs are always kept.
fn filter_images_by_age_and_count(
    messages: &[ChatMessage],
    max_age_rounds: usize,
    max_images: usize,
) -> Vec {
    // NOTE(review): `max_images == 0` acts as "no limit" only when
    // `max_age_rounds` is also 0. With an age limit set, `max_images == 0` makes
    // `can_keep` below always 0, which drops every image. Confirm the intended
    // meaning of 0.
    if messages.is_empty() || (max_age_rounds == 0 && max_images == 0) {
        return messages.to_vec();
    }
    let total_images = count_supported_image_media_refs(messages);
    if total_images == 0 {
        return messages.to_vec();
    }
    // Walk from the newest message backwards so the newest images are kept first.
    // Message order is [old, ..., new], so the last element is the newest.
    let msg_count = messages.len();
    // Pass 1 (newest -> oldest): decide how many images each message may keep.
    // Per-message count of images to keep, indexed like `messages`.
    let mut images_to_keep_per_msg: Vec = vec![0; msg_count];
    let mut images_kept = 0usize;
    // Iterate from newest to oldest.
    for (idx, message) in messages.iter().enumerate().rev() {
        // Distance from the newest message (the last message has age 0).
        let age_from_end = msg_count.saturating_sub(idx).saturating_sub(1);
        // Check the age limit.
        let exceeds_age_limit = max_age_rounds > 0 && age_from_end >= max_age_rounds;
        if exceeds_age_limit {
            continue; // Images older than the age limit are not kept.
        }
        // Count the supported images in this message.
        let image_count_in_msg = message.media_refs.iter()
            .filter(|p| supported_image_mime_type(p).is_some())
            .count();
        if image_count_in_msg == 0 {
            continue;
        }
        // How many of them still fit under `max_images`.
        let can_keep = std::cmp::min(image_count_in_msg, max_images.saturating_sub(images_kept));
        if can_keep > 0 {
            images_to_keep_per_msg[idx] = can_keep;
            images_kept += can_keep;
        }
    }
    // Pass 2 (oldest -> newest): build the filtered message list.
    let mut filtered = Vec::with_capacity(msg_count);
    for (idx, message) in messages.iter().enumerate() {
        // Distance from the newest message (the last message has age 0).
        let age_from_end = msg_count.saturating_sub(idx).saturating_sub(1);
        // Check the age limit (only used to pick the notice text below).
        let exceeds_age_limit = max_age_rounds > 0 && age_from_end >= max_age_rounds;
        let keep_count = images_to_keep_per_msg[idx];
        // Keep all non-image media plus the first `keep_count` images.
        // NOTE(review): the first images in `media_refs` order are kept, not
        // necessarily the most recent ones within this message.
        let mut images_kept_in_msg = 0usize;
        let filtered_media_refs: Vec = message.media_refs.iter()
            .filter_map(|path| {
                if supported_image_mime_type(path).is_some() {
                    if images_kept_in_msg < keep_count {
                        images_kept_in_msg += 1;
                        Some(path.clone())
                    } else {
                        None
                    }
                } else {
                    Some(path.clone()) // Non-image media is always kept.
                }
            })
            .collect();
        // If any image was dropped, append a notice to the text content.
        let original_image_count = message.media_refs.iter()
            .filter(|p| supported_image_mime_type(p).is_some())
            .count();
        let filtered_image_count = filtered_media_refs.iter()
            .filter(|p| supported_image_mime_type(p).is_some())
            .count();
        let content = if original_image_count > filtered_image_count {
            let notice = if exceeds_age_limit {
                format!("{} [图片已过期:超出 {} 条消息范围]", message.content, max_age_rounds)
            } else {
                format!("{} [图片已过期:超出最大图片数量限制]", message.content)
            };
            notice
        } else {
            message.content.clone()
        };
        filtered.push(ChatMessage {
            id: message.id.clone(),
            role: message.role.clone(),
            content,
            media_refs: filtered_media_refs,
            timestamp: message.timestamp,
            system_context: message.system_context.clone(),
            reasoning_content: message.reasoning_content.clone(),
            tool_call_id: message.tool_call_id.clone(),
            tool_name: message.tool_name.clone(),
            tool_state: message.tool_state.clone(),
            tool_duration_ms: message.tool_duration_ms,
            tool_calls: message.tool_calls.clone(),
        });
    }
    filtered
}

/// Raw image byte budget for a token budget.
///
/// Computed as `(target_tokens - DATA_URL_OVERHEAD_TOKENS) * TOKEN_ESTIMATE_CHARS_PER_TOKEN`
/// base64 characters, times 3/4, because base64 encodes 3 bytes as 4 characters.
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
    target_tokens
        .saturating_sub(DATA_URL_OVERHEAD_TOKENS)
        .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
        .saturating_mul(3)
        / 4
}

/// Read an image file and return `(mime_type, base64_payload)`.
///
/// This does not build the `data:` URL itself; callers do that. Fails with
/// `InvalidInput` when the extension is not a supported image type, and
/// propagates any I/O error from opening or reading the file.
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};

    let mime = supported_image_mime_type(path).ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("unsupported image media type for path: {}", path),
        )
    })?;
    let mut file = std::fs::File::open(path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;
    let encoded = STANDARD.encode(&buffer);
    Ok((mime, encoded))
}

/// Encode an image so its base64 payload fits `target_tokens`, if possible.
///
/// The original file is returned unchanged when its base64 form already fits.
/// Otherwise the image is decoded and re-encoded as JPEG via
/// `compress_image_to_jpeg_budget`, and the MIME type becomes `image/jpeg`.
fn encode_image_to_base64_with_budget(
    path: &str,
    target_tokens: usize,
) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};

    let (mime, encoded) = encode_image_to_base64(path)?;
    let target_base64_chars = target_tokens
        .saturating_sub(DATA_URL_OVERHEAD_TOKENS)
        .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    if encoded.len() <= target_base64_chars {
        return Ok((mime, encoded));
    }
    let target_bytes = target_image_bytes_for_tokens(target_tokens);
    // NOTE(review): the file was already read inside `encode_image_to_base64`;
    // this reads it a second time.
    let image_bytes = std::fs::read(path)?;
    let image = image::load_from_memory(&image_bytes)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
    let compressed = compress_image_to_jpeg_budget(&image, target_bytes)?;
    Ok(("image/jpeg".to_string(), STANDARD.encode(compressed)))
}

/// Re-encode `image` as JPEG, aiming for at most `target_bytes`.
///
/// Starting at full size, each quality in `JPEG_QUALITY_STEPS` is tried in
/// turn. The longest side is then shrunk to 3/4 (never below
/// `MIN_COMPRESSED_IMAGE_SIDE`) and the qualities are tried again. The first
/// encoding within `target_bytes` is returned. If none fits, the smallest
/// encoding produced is returned instead, and it may exceed `target_bytes`.
fn compress_image_to_jpeg_budget(
    image: &image::DynamicImage,
    target_bytes: usize,
) -> Result, std::io::Error> {
    use image::codecs::jpeg::JpegEncoder;
    use image::imageops::FilterType;

    let mut best: Option> = None;
    let mut max_side = image
        .width()
        .max(image.height())
        .max(MIN_COMPRESSED_IMAGE_SIDE);
    loop {
        let candidate = if image.width().max(image.height()) > max_side {
            image.resize(max_side, max_side, FilterType::Triangle)
        } else {
            image.clone()
        };
        for quality in JPEG_QUALITY_STEPS {
            let mut encoded = Vec::new();
            let mut encoder = JpegEncoder::new_with_quality(&mut encoded, *quality);
            // NOTE(review): the JPEG encoder rejects some pixel formats (e.g.
            // 16-bit images, and alpha-channel images in newer `image` releases).
            // Converting `candidate` to RGB8 first would avoid skipping such
            // images. Verify against the `image` crate version in use.
            encoder
                .encode_image(&candidate)
                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
            if encoded.len() <= target_bytes {
                return Ok(encoded);
            }
            // Remember the smallest encoding as a best-effort fallback.
            if best
                .as_ref()
                .map(|current| encoded.len() < current.len())
                .unwrap_or(true)
            {
                best = Some(encoded);
            }
        }
        if max_side <= MIN_COMPRESSED_IMAGE_SIDE {
            break;
        }
        max_side = (max_side * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE);
    }
    best.ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed")
    })
}

/// Truncate tool result if it exceeds the configured limit.
/// Preserves the end of the output as it often contains the conclusion/useful result.
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String { let total_chars = char_count(output); if total_chars <= max_tool_result_chars { return output.to_string(); } let truncated_start_len = total_chars.saturating_sub(TRUNCATION_SUFFIX_LEN); if truncated_start_len > max_tool_result_chars { // Even after removing suffix, still too long - take from beginning let head_len = max_tool_result_chars.saturating_sub(100); let head = take_prefix_chars(output, head_len); format!( "{}...\n\n[Output truncated - {} characters removed]", head, total_chars - max_tool_result_chars + 100 ) } else { // Keep most of the end which usually contains the useful result let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len)); format!( "...\n\n[Output truncated - {} characters removed]\n\n{}", truncated_start_len, tail ) } } fn parse_pending_tool_output(output: &str) -> Option { output .strip_prefix(PENDING_USER_ACTION_MARKER) .map(|rest| rest.trim().to_string()) } fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value { match arguments { serde_json::Value::String(raw) => { serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()) } _ => arguments.clone(), } } fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; let mut current = error.source(); while let Some(source) = current { details.push(source.to_string()); current = source.source(); } details.join("\ncaused by: ") } fn is_recoverable_llm_error(error: &str) -> bool { let normalized = error.to_ascii_lowercase(); normalized.contains("504") || normalized.contains("gateway timeout") || normalized.contains("stream timeout") || normalized.contains("timed out") || normalized.contains("timeout") } fn recoverable_llm_message(error: &str) -> String { if is_recoverable_llm_error(error) { RECOVERABLE_LLM_ERROR_MESSAGE.to_string() } else { format!("模型请求失败:{}", error) } } /// Loop detection result. 
#[derive(Debug, Clone, PartialEq, Eq)] enum LoopDetectionResult { /// No warning needed. Ok, /// Warning: same tool + args repeated N times. Warning(String), } /// Configuration for loop detector. #[derive(Debug, Clone)] struct LoopDetectorConfig { /// Master switch. enabled: bool, /// Warn every N consecutive identical calls. warn_every: usize, } impl Default for LoopDetectorConfig { fn default() -> Self { Self { enabled: true, warn_every: 5, } } } /// A single recorded tool invocation in the sliding window. #[derive(Debug, Clone)] struct ToolCallRecord { name: String, args_hash: u64, } /// Stateful loop detector that monitors for repetitive patterns. struct LoopDetector { config: LoopDetectorConfig, window: VecDeque, } impl LoopDetector { fn new(config: LoopDetectorConfig) -> Self { Self { window: VecDeque::with_capacity(config.warn_every * 2), config, } } /// Record a completed tool call and check for loop patterns. /// Returns Warning every `warn_every` consecutive identical calls. fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult { if !self.config.enabled { return LoopDetectionResult::Ok; } let record = ToolCallRecord { name: name.to_string(), args_hash: hash_json_value(args), }; // Maintain sliding window if self.window.len() >= self.config.warn_every * 2 { self.window.pop_front(); } self.window.push_back(record); // Count consecutive identical calls let last = self.window.back().unwrap(); let consecutive: usize = self .window .iter() .rev() .take_while(|r| r.name == last.name && r.args_hash == last.args_hash) .count(); // Warn every warn_every times if consecutive > 0 && consecutive % self.config.warn_every == 0 { LoopDetectionResult::Warning(format!( "注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。", last.name, consecutive )) } else { LoopDetectionResult::Ok } } } /// Hash a JSON value deterministically (key-order independent). 
fn hash_json_value(value: &serde_json::Value) -> u64 {
    // Sort object keys first so logically equal objects hash identically,
    // whatever order their keys were written in.
    let canonical = canonicalise_json(value);
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    canonical.hash(&mut hasher);
    hasher.finish()
}

/// Return a deep copy of `value` in which every object, at any nesting depth,
/// has its keys inserted in sorted order. Arrays keep their element order, and
/// scalars are cloned unchanged.
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
    use serde_json::Value;

    match value {
        Value::Array(items) => Value::Array(items.iter().map(canonicalise_json).collect()),
        Value::Object(map) => {
            let mut entries: Vec<(&String, &Value)> = map.iter().collect();
            entries.sort_by(|(left, _), (right, _)| left.cmp(right));
            let sorted: serde_json::Map<String, Value> = entries
                .into_iter()
                .map(|(key, child)| (key.clone(), canonicalise_json(child)))
                .collect();
            Value::Object(sorted)
        }
        scalar => scalar.clone(),
    }
}

/// Convert a `ChatMessage` into the provider-facing `Message`.
///
/// A message without media refs becomes a single text block. Otherwise the
/// content comes from `build_content_blocks`, which may inline supported
/// images (drawing from `image_budget`) and append a notice for images it had
/// to skip.
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
    let mut message = chat_message_to_text_only_llm_message(m);
    if !m.media_refs.is_empty() {
        message.content = build_content_blocks(&m.content, &m.media_refs, image_budget);
    }
    message
}

/// Convert a `ChatMessage` into a `Message` with only its text content;
/// `media_refs` are ignored.
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
    let text_block = ContentBlock::text(&m.content);
    Message {
        role: m.role.clone(),
        content: vec![text_block],
        reasoning_content: m.reasoning_content.clone(),
        tool_call_id: m.tool_call_id.clone(),
        name: m.tool_name.clone(),
        tool_calls: m.tool_calls.clone(),
    }
}

/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop { runtime_config: AgentRuntimeConfig, provider: Box, tools: Arc, /// 系统提示词提供者(统一注入 Agent 和 Skill 提示词) system_prompt_provider: Option>, /// Skill 提供者(用于匹配错误提示) skills: Option>, tool_context: ToolContext, observer: Option>, emitted_message_handler: Option>, max_iterations: usize, /// 取消信号接收端:Agent 在每次迭代开始时检查是否被取消 cancel_token: Option>, } #[derive(Debug, Clone)] pub struct AgentProcessResult { pub final_response: ChatMessage, pub emitted_messages: Vec, } #[async_trait] pub trait EmittedMessageHandler: Send + Sync + 'static { async fn handle(&self, message: ChatMessage); /// Handle a tool result message with optional execution timing. /// Default implementation delegates to `handle()`, ignoring timing. async fn handle_tool_result(&self, message: ChatMessage, _duration_ms: Option) { self.handle(message).await; } /// Handle a streaming delta. Default is no-op. async fn handle_stream_delta(&self, _delta: &StreamDelta) { // Non-streaming channels ignore this } /// Set the message ID to use for stream deltas (so the final assistant message /// can share the same ID, enabling front-end replacement). 
async fn set_stream_message_id(&self, _id: &str) { // Default: no-op for handlers that don't stream } } /// 装饰器:在内部 emitter 广播前,先将消息持久化到 DB pub struct PersistingEmittedMessageHandler { inner: H, conversation_repository: Arc, session_id: String, topic_id: Option, } impl PersistingEmittedMessageHandler { pub fn new( inner: H, conversation_repository: Arc, session_id: impl Into, topic_id: Option, ) -> Self { Self { inner, conversation_repository, session_id: session_id.into(), topic_id } } } #[async_trait] impl EmittedMessageHandler for PersistingEmittedMessageHandler { async fn handle(&self, message: ChatMessage) { if let Err(e) = self.conversation_repository .append_message_with_topic(&self.session_id, self.topic_id.as_deref(), &message) { tracing::error!(error = %e, session_id = %self.session_id, "Failed to persist emitted message"); } self.inner.handle(message).await; } async fn handle_tool_result(&self, message: ChatMessage, duration_ms: Option) { // Persist the ChatMessage first (no duration field, same as before) if let Err(e) = self.conversation_repository .append_message_with_topic(&self.session_id, self.topic_id.as_deref(), &message) { tracing::error!(error = %e, session_id = %self.session_id, "Failed to persist emitted message"); } self.inner.handle_tool_result(message, duration_ms).await; } async fn set_stream_message_id(&self, id: &str) { self.inner.set_stream_message_id(id).await; } async fn handle_stream_delta(&self, delta: &StreamDelta) { // Deltas are transient — do NOT persist, just forward to inner handler self.inner.handle_stream_delta(delta).await; } } pub trait SkillProvider: Send + Sync + 'static { fn system_index_prompt(&self) -> Option; fn matching_skill_summary(&self, _name: &str) -> Option { None } } #[derive(Default)] #[allow(dead_code)] struct EmptySkillProvider; impl SkillProvider for EmptySkillProvider { fn system_index_prompt(&self) -> Option { None } } impl AgentLoop { pub fn new(config: impl Into) -> Result { let runtime_config = 
config.into(); let max_iterations = runtime_config.max_tool_iterations; let provider = create_provider(runtime_config.provider.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { runtime_config, provider, tools: Arc::new(ToolRegistry::new()), system_prompt_provider: None, skills: None, tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, cancel_token: None, max_iterations, }) } pub fn with_tools( config: impl Into, tools: Arc, ) -> Result { let runtime_config = config.into(); let max_iterations = runtime_config.max_tool_iterations; let provider = create_provider(runtime_config.provider.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { runtime_config, provider, tools, system_prompt_provider: None, skills: None, tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, cancel_token: None, max_iterations, }) } pub fn with_tools_and_skill_provider( config: impl Into, tools: Arc, skills: Arc, ) -> Result { let runtime_config = config.into(); let max_iterations = runtime_config.max_tool_iterations; let provider = create_provider(runtime_config.provider.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { runtime_config, provider, tools, system_prompt_provider: None, skills: Some(skills), tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, cancel_token: None, max_iterations, }) } /// 使用系统提示词提供者创建 AgentLoop /// /// 这是新的推荐方式,支持统一注入 Agent 和 Skill 提示词。 pub fn with_tools_and_system_prompt_provider( config: impl Into, tools: Arc, system_prompt_provider: Arc, skills: Option>, ) -> Result { let runtime_config = config.into(); let max_iterations = runtime_config.max_tool_iterations; let provider = create_provider(runtime_config.provider.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { runtime_config, provider, tools, system_prompt_provider: Some(system_prompt_provider), skills, 
tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, cancel_token: None, max_iterations, }) } pub fn with_tool_context(mut self, context: ToolContext) -> Self { self.tool_context = context; self } /// Set an observer for tracking events. pub fn with_observer(mut self, observer: Arc) -> Self { self.observer = Some(observer); self } pub fn with_emitted_message_handler(mut self, handler: Arc) -> Self { self.emitted_message_handler = Some(handler); self } /// 设置取消信号接收端。 /// /// Agent 在每次迭代开始时检查 `cancel_token.has_changed()`, /// 如果已收到取消信号则提前返回。 pub fn with_cancel_token(mut self, token: tokio::sync::watch::Receiver<()>) -> Self { self.cancel_token = Some(token); self } pub fn tools(&self) -> &Arc { &self.tools } /// Process a message using the provided conversation history. /// History management is handled externally by SessionManager. /// /// This method supports multi-round tool calling: after executing tools, /// it loops back to the LLM with the tool results until either: /// - The LLM returns no more tool calls (final response) /// - Maximum iterations are reached /// /// # 参数 /// - `messages`: 会话历史消息 /// - `system_prompt_context`: 系统提示词上下文(用于动态注入,可选) pub async fn process( &self, mut messages: Vec, system_prompt_context: Option<&SystemPromptContext>, ) -> Result { #[cfg(debug_assertions)] tracing::debug!( history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process" ); // Sanitize: remove any trailing incomplete tool call sequences // that may have been persisted before a process interruption. 
crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages); // Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut emitted_messages = Vec::new(); for iteration in 0..self.max_iterations { #[cfg(debug_assertions)] tracing::debug!(iteration, "Agent iteration started"); // 检查取消信号 if let Some(ref token) = self.cancel_token { if token.has_changed().unwrap_or(false) { tracing::info!(iteration, "Agent execution cancelled by user"); let cancel_message = format!( "\n\n[用户已取消执行。已迭代 {} 次,取消前共生成了 {} 条消息。]", iteration, emitted_messages.len() ); let assistant_message = ChatMessage::assistant(cancel_message); emitted_messages.push(assistant_message.clone()); self.emit_live_tool_call_message(assistant_message.clone()).await; return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }); } } // Build request let tool_defs = self.tools.get_definitions(); let tools = if tool_defs.is_empty() { None } else { Some(tool_defs) }; // 过滤超出轮次和数量限制的图片 let filtered_messages = filter_images_by_age_and_count( &messages, self.runtime_config.max_image_age_rounds, self.runtime_config.max_images_in_context, ); let image_count = count_supported_image_media_refs(&filtered_messages); // 构建系统提示词(统一注入 Agent 和 Skill 提示词) let system_prompt = system_prompt_context.and_then(|ctx| { self.system_prompt_provider .as_ref() .and_then(|provider| provider.build(ctx)) }); let mut text_only_messages: Vec = Vec::with_capacity(filtered_messages.len() + 2); if let Some(ref prompt) = system_prompt { text_only_messages.push(Message::system(prompt.content.clone())); } text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message)); let image_tokens = image_token_budget_for_request( &self.runtime_config, &text_only_messages, tools.as_ref(), ); let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); let mut messages_for_llm: Vec = 
Vec::with_capacity(filtered_messages.len() + 2); // 使用相同的系统提示词(已构建) if let Some(ref prompt) = system_prompt { messages_for_llm.push(Message::system(prompt.content.clone())); } messages_for_llm.extend( filtered_messages .iter() .map(|message| chat_message_to_llm_message(message, &mut image_budget)), ); let request = ChatCompletionRequest { messages: messages_for_llm, temperature: None, max_tokens: None, tools, }; // Set up streaming delta consumer // Pre-generate the message ID so stream deltas and the final assistant // message share the same ID — this lets the front-end replace the // streamed message with the authoritative response. let streaming_message_id = uuid::Uuid::new_v4().to_string(); if let Some(handler) = &self.emitted_message_handler { handler.set_stream_message_id(&streaming_message_id).await; } let (delta_tx, mut delta_rx) = tokio::sync::mpsc::channel::(256); let consumer_handler = self.emitted_message_handler.clone(); let consumer_task = tokio::spawn(async move { while let Some(delta) = delta_rx.recv().await { if let Some(ref handler) = consumer_handler { handler.handle_stream_delta(&delta).await; } } }); let stream_callback: StreamCallback = std::sync::Arc::new(move |delta: StreamDelta| { // try_send is non-blocking and safe to call from within a tokio runtime let _ = delta_tx.try_send(delta); }); let response = match (*self.provider).chat_with_streaming(request, stream_callback).await { Ok(response) => response, Err(e) => { // delta_tx is dropped with the callback; await consumer to finish let _ = consumer_task.await; tracing::error!( provider = %self.provider.name(), model = %self.provider.model_id(), error = %e, error_details = %format_error_chain(e.as_ref()), "LLM request failed" ); let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); emitted_messages.push(assistant_message.clone()); self.emit_live_tool_call_message(assistant_message.clone()).await; return Ok(AgentProcessResult { final_response: 
assistant_message, emitted_messages, }); } }; // Close delta channel and wait for consumer to finish processing // (delta_tx is dropped when the callback closure is dropped) let _ = consumer_task.await; // Signal stream end if handler exists let had_streaming = self.emitted_message_handler.is_some(); if had_streaming { let end_delta = StreamDelta { content: String::new(), reasoning_content: None, }; if let Some(handler) = &self.emitted_message_handler { handler.handle_stream_delta(&end_delta).await; } } #[cfg(debug_assertions)] tracing::debug!( iteration, response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received" ); // If no tool calls, this is the final response if response.tool_calls.is_empty() { let mut assistant_message = if let Some(reasoning_content) = response.reasoning_content { ChatMessage::assistant_with_reasoning(response.content, reasoning_content) } else { ChatMessage::assistant(response.content) }; // Use the same ID as the stream deltas so the front-end can replace // the streamed message with this authoritative response. 
if had_streaming { assistant_message.id = streaming_message_id; } emitted_messages.push(assistant_message.clone()); self.emit_live_tool_call_message(assistant_message.clone()).await; return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }); } // Execute tool calls tracing::info!( iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools" ); // Add assistant message with tool calls let mut assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() { ChatMessage::assistant_with_tool_calls_and_reasoning( response.content.clone(), response.tool_calls.clone(), reasoning_content, ) } else { ChatMessage::assistant_with_tool_calls( response.content.clone(), response.tool_calls.clone(), ) }; // Use the same ID as stream deltas so the front-end can replace // the streamed message with this authoritative response. if had_streaming { assistant_message.id = streaming_message_id; } messages.push(assistant_message.clone()); emitted_messages.push(assistant_message); self.emit_live_tool_call_message( emitted_messages .last() .expect("assistant message just pushed") .clone(), ) .await; // Execute tools and add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { // Truncate tool result if too large let truncated_output = truncate_tool_result(&result.output, self.runtime_config.tool_result_max_chars); // Record tool call and check for loops let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments); match loop_result { LoopDetectionResult::Warning(msg) => { // Add warning and proceed tracing::warn!( tool = %tool_call.name, "Loop warning: {}", msg ); let tool_message = ChatMessage::tool_with_state( tool_call.id.clone(), tool_call.name.clone(), format!("{}\n\n[上一条结果]\n{}", msg, truncated_output), if result.state == ToolExecutionState::PendingUserAction { 
ToolMessageState::PendingUserAction } else { ToolMessageState::Completed }, ) .with_tool_duration(result.duration.as_millis() as u64); messages.push(tool_message.clone()); emitted_messages.push(tool_message.clone()); let duration_ms = Some(result.duration.as_millis() as u64); self.emit_tool_result(tool_message, duration_ms).await; } LoopDetectionResult::Ok => { let tool_message = ChatMessage::tool_with_state( tool_call.id.clone(), tool_call.name.clone(), truncated_output, if result.state == ToolExecutionState::PendingUserAction { ToolMessageState::PendingUserAction } else { ToolMessageState::Completed }, ) .with_tool_duration(result.duration.as_millis() as u64); messages.push(tool_message.clone()); emitted_messages.push(tool_message.clone()); let duration_ms = Some(result.duration.as_millis() as u64); self.emit_tool_result(tool_message, duration_ms).await; } } } // Loop continues to next iteration with updated messages // PendingUserAction 工具的结果已在上方加入 messages, // 模型将在下一轮看到完整的终端输出并生成智能回复 #[cfg(debug_assertions)] tracing::debug!( iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration" ); } // Max iterations reached - ask LLM for a summary based on completed work tracing::warn!("Max iterations reached, requesting final summary from LLM"); // Add a message asking for summary let summary_request = ChatMessage::user( "You have reached the maximum number of tool call iterations. 
\ Please provide your best answer based on the work completed so far.", ); messages.push(summary_request); // 过滤超出轮次和数量限制的图片 let filtered_messages = filter_images_by_age_and_count( &messages, self.runtime_config.max_image_age_rounds, self.runtime_config.max_images_in_context, ); // Convert messages to LLM format (使用系统提示词提供者) let image_count = count_supported_image_media_refs(&filtered_messages); let mut text_only_messages: Vec = Vec::with_capacity(filtered_messages.len() + 1); if let Some(ref provider) = self.system_prompt_provider { if let Some(ctx) = system_prompt_context { if let Some(prompt) = provider.build(ctx) { text_only_messages.push(Message::system(prompt.content.clone())); } } } text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message)); let image_tokens = image_token_budget_for_request(&self.runtime_config, &text_only_messages, None); let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); let mut messages_for_llm: Vec = Vec::with_capacity(filtered_messages.len() + 1); if let Some(ref provider) = self.system_prompt_provider { if let Some(ctx) = system_prompt_context { if let Some(prompt) = provider.build(ctx) { messages_for_llm.push(Message::system(prompt.content.clone())); } } } messages_for_llm.extend( filtered_messages .iter() .map(|message| chat_message_to_llm_message(message, &mut image_budget)), ); let request = ChatCompletionRequest { messages: messages_for_llm, temperature: None, max_tokens: None, tools: None, // No tools in final summary call }; match (*self.provider).chat(request).await { Ok(response) => { let assistant_message = if let Some(reasoning_content) = response.reasoning_content { ChatMessage::assistant_with_reasoning(response.content, reasoning_content) } else { ChatMessage::assistant(response.content) }; emitted_messages.push(assistant_message.clone()); self.emit_live_tool_call_message(assistant_message.clone()).await; Ok(AgentProcessResult { final_response: assistant_message, 
emitted_messages, }) } Err(e) => { tracing::error!( provider = %self.provider.name(), model = %self.provider.model_id(), error = %e, error_details = %format_error_chain(e.as_ref()), "Failed to get summary from LLM" ); let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); emitted_messages.push(final_message.clone()); self.emit_live_tool_call_message(final_message.clone()).await; Ok(AgentProcessResult { final_response: final_message, emitted_messages, }) } } } async fn emit_live_tool_call_message(&self, message: ChatMessage) { if let Some(handler) = &self.emitted_message_handler { handler.handle(message).await; } } async fn emit_tool_result(&self, message: ChatMessage, duration_ms: Option) { if let Some(handler) = &self.emitted_message_handler { handler.handle_tool_result(message, duration_ms).await; } } /// Determine whether to execute tools in parallel or sequentially. /// /// Returns true if: /// - There are multiple tool calls /// - None of the tools require sequential execution (tool_search, non-concurrency-safe) fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool { if tool_calls.len() <= 1 { return false; } // tool_search must run sequentially to avoid MCP activation race conditions if tool_calls.iter().any(|tc| tc.name == "tool_search") { return false; } // Multiple Task tools can run in parallel // Task tools create independent subagents with isolated contexts, no shared state let task_count = tool_calls.iter().filter(|tc| tc.name == "task").count(); if task_count > 1 && task_count == tool_calls.len() { return true; } // When Task is mixed with other tools, keep sequential to avoid complexity if task_count > 0 { return false; } // All tools must be concurrency-safe to run in parallel tool_calls.iter().all(|tc| { self.tools .get(&tc.name) .map(|t| t.concurrency_safe()) .unwrap_or(false) }) } /// Execute multiple tool calls, choosing parallel or sequential based on conditions. 
///
/// The returned vector always holds exactly one outcome per entry of
/// `tool_calls`, in the same order (both the parallel and the sequential path
/// preserve input order), so callers can `zip` calls with their outcomes.
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let count = tool_calls.len();
    if self.should_execute_in_parallel(tool_calls) {
        tracing::debug!("Executing {} tools in parallel", count);
        self.execute_tools_parallel(tool_calls).await
    } else {
        tracing::debug!("Executing {} tools sequentially", count);
        self.execute_tools_sequential(tool_calls).await
    }
}

/// Execute all tool calls concurrently and wait for every one of them.
///
/// `join_all` yields outputs in the same order as its input futures, so
/// outcome `i` belongs to `tool_calls[i]`. The futures are polled concurrently
/// on the current task; nothing is spawned.
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    futures_util::future::join_all(tool_calls.iter().map(|tc| self.execute_one_tool(tc))).await
}

/// Execute tool calls one at a time, awaiting each before starting the next.
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let mut outcomes = Vec::with_capacity(tool_calls.len());
    for tool_call in tool_calls {
        outcomes.push(self.execute_one_tool(tool_call).await);
    }
    outcomes
}

/// Execute a single tool call, logging it and reporting observer events.
///
/// When an observer is configured, `ObserverEvent::ToolCallStart` is recorded
/// before execution and `ObserverEvent::ToolCall` (duration + success flag)
/// afterwards. The returned outcome's `duration` is replaced with the
/// wall-clock time measured here.
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
    let start = Instant::now();

    // Pretty-print the arguments for the log line. `to_string_pretty` already
    // renders an empty object as `{}`; the compact fallback only matters if
    // pretty serialization fails.
    let args_str = serde_json::to_string_pretty(&tool_call.arguments)
        .unwrap_or_else(|_| tool_call.arguments.to_string());
    tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");

    if let Some(observer) = &self.observer {
        observer.record_event(&ObserverEvent::ToolCallStart {
            tool: tool_call.name.clone(),
            arguments: Some(truncate_args(&tool_call.arguments, 300)),
        });
    }

    let result = self.execute_tool_internal(tool_call).await;
    let duration = start.elapsed();

    if let Some(observer) = &self.observer {
        observer.record_event(&ObserverEvent::ToolCall {
            tool: tool_call.name.clone(),
            duration,
            success: result.success,
        });
    }

    // Overwrite whatever duration the inner execution reported with the one
    // measured around the whole call.
    ToolExecutionOutcome { duration, ..result }
}

///
/// Internal tool execution without event tracking.
///
/// Looks up `tool_call.name` in the registry and runs the tool with normalized
/// arguments. Every path yields a `ToolExecutionOutcome` rather than an error,
/// so a failing tool never aborts the agent loop:
///
/// - unknown tool → failure outcome, with a `skill_activate` hint when a skill
///   of the same name exists;
/// - successful result whose output carries the pending-user-action marker →
///   pending outcome (marker stripped by `parse_pending_tool_output`);
/// - any other successful result → success outcome;
/// - tool-reported failure or execution error → failure outcome whose output is
///   prefixed with `Error: `.
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
    // A JSON-encoded string argument is parsed into a JSON value; any other
    // value passes through unchanged (see the normalize_tool_arguments tests).
    let normalized_arguments = normalize_tool_arguments(&tool_call.arguments);

    let Some(tool) = self.tools.get(&tool_call.name) else {
        tracing::warn!(tool = %tool_call.name, "Tool not found");
        // Skills are not callable as tools; if a skill shares the requested
        // name, tell the model to activate it via `skill_activate` instead.
        let skill_hint = self
            .skills
            .as_ref()
            .and_then(|s| s.matching_skill_summary(&tool_call.name));
        let error = match skill_hint {
            Some(summary) => format!(
                "Tool '{}' not found. A skill with the same name exists: {}. Skills are not tools. Call skill_activate with {{\"name\": \"{}\"}} first.",
                tool_call.name, summary, tool_call.name
            ),
            None => format!("Tool '{}' not found", tool_call.name),
        };
        return ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error));
    };

    // The normalized arguments are cloned because they are logged on failure.
    match tool
        .execute_with_context(&self.tool_context, normalized_arguments.clone())
        .await
    {
        Ok(result) if result.success => match parse_pending_tool_output(&result.output) {
            Some(pending_output) => ToolExecutionOutcome::pending(pending_output),
            None => ToolExecutionOutcome::success(result.output),
        },
        Ok(result) => {
            let error = result.error.unwrap_or_default();
            tracing::error!(
                tool = %tool_call.name,
                args = %truncate_args(&tool_call.arguments, 4_000),
                normalized_args = %truncate_args(&normalized_arguments, 4_000),
                error = %error,
                output = %result.output,
                "Tool returned an error result"
            );
            ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
        }
        Err(e) => {
            tracing::error!(
                tool = %tool_call.name,
                args = %truncate_args(&tool_call.arguments, 4_000),
                normalized_args = %truncate_args(&normalized_arguments, 4_000),
                error = %e,
                error_details = %format!("{:#}", e),
                "Tool execution failed"
            );
            ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
        }
    }
}
} // end of the impl block holding the AgentLoop tool-execution methods

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LLMProviderConfig;
    use crate::observability::{MultiObserver, Observer};
    use tempfile::tempdir;

    /// Observer that records every event it receives.
    struct TestObserver {
        events: std::sync::Mutex<Vec<ObserverEvent>>,
    }

    impl TestObserver {
        fn new() -> Self {
            Self {
                events: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    impl Observer for TestObserver {
        fn record_event(&self, event: &ObserverEvent) {
            self.events.lock().unwrap().push(event.clone());
        }

        fn name(&self) -> &str {
            "test_observer"
        }
    }

    /// Skill provider that knows exactly one skill, `baidu-search`.
    struct TestSkillProvider;

    // NOTE(review): the `Option<String>` generic arguments were restored here;
    // confirm they match the `SkillProvider` trait signatures.
    impl SkillProvider for TestSkillProvider {
        fn system_index_prompt(&self) -> Option<String> {
            None
        }

        fn matching_skill_summary(&self, name: &str) -> Option<String> {
            (name == "baidu-search").then(|| "用于百度搜索和天气查询的技能".to_string())
        }
    }

    fn test_runtime_config() -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "test".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: std::collections::HashMap::new(),
            llm_timeout_secs: 120,
            memory_maintenance_timeout_secs: 600,
            model_id: "test-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            context_window_tokens: None,
            model_extra: std::collections::HashMap::new(),
            max_tool_iterations: 1,
            tool_result_max_chars: 100_000,
            context_tool_result_trim_chars: 100_000,
            max_images_in_context: 1,
            max_image_age_rounds: 10,
        }
    }

    /// Build a `ToolCall` with empty arguments, for dispatch-mode tests.
    fn tool_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn test_missing_tool_with_same_name_skill_returns_activation_hint() {
        let loop_instance = AgentLoop::with_tools_and_skill_provider(
            test_runtime_config(),
            Arc::new(ToolRegistry::new()),
            Arc::new(TestSkillProvider),
        )
        .unwrap();

        let outcome = loop_instance
            .execute_tool_internal(&ToolCall {
                id: "call-1".to_string(),
                name: "baidu-search".to_string(),
                arguments: serde_json::json!({ "queries": "佛山今天几点下雨" }),
            })
            .await;

        assert_eq!(outcome.state, ToolExecutionState::Completed);
        assert!(!outcome.success);
        assert!(outcome.output.contains("技能"));
        assert!(outcome.output.contains("skill_activate"));
        assert!(outcome.output.contains("baidu-search"));
    }

    #[tokio::test]
    async fn test_observer_receives_tool_events() {
        // Verify MultiObserver works
        let mut multi = MultiObserver::new();
        multi.add_observer(Box::new(TestObserver::new()));

        let event = ObserverEvent::ToolCallStart {
            tool: "test".to_string(),
            arguments: Some("{}".to_string()),
        };
        multi.record_event(&event);

        // Just verify the structure works
        assert_eq!(multi.len(), 1);
    }

    #[tokio::test]
    async fn test_should_execute_in_parallel_single_tool() {
        let loop_instance = AgentLoop::with_tools_and_skill_provider(
            test_runtime_config(),
            Arc::new(ToolRegistry::new()),
            Arc::new(TestSkillProvider),
        )
        .unwrap();

        // A batch of zero or one call never runs in parallel, whatever the tool.
        assert!(!loop_instance.should_execute_in_parallel(&[]));
        assert!(!loop_instance.should_execute_in_parallel(&[tool_call("1", "test")]));
        assert!(!loop_instance.should_execute_in_parallel(&[tool_call("1", "task")]));
    }

    #[tokio::test]
    async fn test_should_execute_in_parallel_multiple_task_tools() {
        let loop_instance = AgentLoop::with_tools_and_skill_provider(
            test_runtime_config(),
            Arc::new(ToolRegistry::new()),
            Arc::new(TestSkillProvider),
        )
        .unwrap();

        // A batch made only of `task` calls runs concurrently.
        let calls = [tool_call("1", "task"), tool_call("2", "task")];
        assert!(loop_instance.should_execute_in_parallel(&calls));
    }

    #[tokio::test]
    async fn test_should_execute_in_parallel_sequential_cases() {
        let loop_instance = AgentLoop::with_tools_and_skill_provider(
            test_runtime_config(),
            Arc::new(ToolRegistry::new()),
            Arc::new(TestSkillProvider),
        )
        .unwrap();

        // `tool_search` anywhere in the batch forces sequential execution.
        assert!(!loop_instance.should_execute_in_parallel(&[
            tool_call("1", "task"),
            tool_call("2", "tool_search"),
        ]));
        // `task` mixed with another tool stays sequential.
        assert!(!loop_instance.should_execute_in_parallel(&[
            tool_call("1", "task"),
            tool_call("2", "unregistered_tool"),
        ]));
        // Tools missing from the registry are not treated as concurrency-safe.
        assert!(!loop_instance.should_execute_in_parallel(&[
            tool_call("1", "unregistered_tool_a"),
            tool_call("2", "unregistered_tool_b"),
        ]));
    }

    #[test]
    fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
        let chat_message = ChatMessage::assistant_with_tool_calls(
            "calling tool",
            vec![ToolCall {
                id: "call_1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({ "expression": "2+2" }),
            }],
        );

        let mut image_budget = ImageInlineBudget::new(0, 0);
        let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);

        assert_eq!(provider_message.role, "assistant");
        assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
        assert_eq!(
            provider_message.tool_calls.as_ref().unwrap()[0].id,
            "call_1"
        );
        assert_eq!(
            provider_message.tool_calls.as_ref().unwrap()[0].name,
            "calculator"
        );
    }

    #[test]
    fn test_chat_message_to_llm_message_preserves_reasoning_content() {
        let chat_message =
            ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");

        let mut image_budget = ImageInlineBudget::new(0, 0);
        let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);

        assert_eq!(provider_message.role, "assistant");
        assert_eq!(
            provider_message.reasoning_content.as_deref(),
            Some("hidden chain of thought")
        );
    }

    #[test]
    fn test_truncate_tool_result_handles_utf8_char_boundaries() {
        // Every char is 3 bytes, so cutting at a raw byte offset would panic;
        // producing `output` at all is the main check here.
        let input = "范".repeat(100_500);
        let output = truncate_tool_result(&input, 100_000);
        assert!(output.contains("Output truncated"));
        assert!(output.chars().count() < input.chars().count());
    }

    #[test]
    fn test_parse_pending_tool_output() {
        let output = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权");
        assert_eq!(output.as_deref(), Some("请完成授权"));
        assert!(parse_pending_tool_output("normal output").is_none());
    }

    #[test]
    fn test_normalize_tool_arguments_parses_stringified_json() {
        let normalized = normalize_tool_arguments(&serde_json::Value::String(
            "{\"command\":\"ls -la\"}".to_string(),
        ));
        assert_eq!(normalized, serde_json::json!({ "command": "ls -la" }));
    }

    #[test]
    fn test_normalize_tool_arguments_keeps_plain_string() {
        let normalized =
            normalize_tool_arguments(&serde_json::Value::String("plain text".to_string()));
        assert_eq!(
            normalized,
            serde_json::Value::String("plain text".to_string())
        );
    }

    #[test]
    fn test_build_content_blocks_skips_non_image_media_refs() {
        let temp_dir = tempdir().unwrap();
        let pdf_path = temp_dir.path().join("demo.pdf");
        std::fs::write(&pdf_path, b"%PDF-1.4").unwrap();

        let mut budget = ImageInlineBudget::new(1_000, 0);
        let blocks = build_content_blocks(
            "hello",
            &[pdf_path.to_string_lossy().to_string()],
            &mut budget,
        );

        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    }

    #[test]
    fn test_build_content_blocks_keeps_supported_images() {
        let temp_dir = tempdir().unwrap();
        let jpg_path = temp_dir.path().join("demo.jpg");
        let image = image::DynamicImage::new_rgb8(8, 8);
        image.save(&jpg_path).unwrap();

        let mut budget = ImageInlineBudget::new(10_000, 1);
        let blocks = build_content_blocks(
            "hello",
            &[jpg_path.to_string_lossy().to_string()],
            &mut budget,
        );

        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
        assert!(
            matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
        );
    }

    #[test]
    fn test_build_content_blocks_compresses_images_to_budget() {
        let temp_dir = tempdir().unwrap();
        let png_path = temp_dir.path().join("large.png");
        let image = image::DynamicImage::new_rgb8(512, 512);
        image.save(&png_path).unwrap();

        let mut budget = ImageInlineBudget::new(512, 1);
        let blocks = build_content_blocks(
            "hello",
            &[png_path.to_string_lossy().to_string()],
            &mut budget,
        );

        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
        assert!(
            matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
        );
    }

    #[test]
    fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() {
        let temp_dir = tempdir().unwrap();
        let jpg_path = temp_dir.path().join("demo.jpg");
        let image = image::DynamicImage::new_rgb8(8, 8);
        image.save(&jpg_path).unwrap();

        let mut budget = ImageInlineBudget::new(0, 1);
        let blocks = build_content_blocks(
            "hello",
            &[jpg_path.to_string_lossy().to_string()],
            &mut budget,
        );

        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
        assert!(matches!(
            &blocks[1],
            ContentBlock::Text { text }
                if text.contains("图片未能成功入模") && text.contains("demo.jpg")
        ));
    }

    #[test]
    fn test_filter_images_by_age_and_count_no_images() {
        let messages = vec![
            ChatMessage::user("hello"),
            ChatMessage::assistant("hi"),
            ChatMessage::user("how are you"),
        ];

        let filtered = filter_images_by_age_and_count(&messages, 10, 5);

        assert_eq!(filtered.len(), 3);
        assert_eq!(filtered[0].content, "hello");
        assert_eq!(filtered[1].content, "hi");
        assert_eq!(filtered[2].content, "how are you");
    }

    #[test]
    fn test_filter_images_by_age_limit() {
        let temp_dir = tempdir().unwrap();
        let jpg_path = temp_dir.path().join("test.jpg");
        let image = image::DynamicImage::new_rgb8(8, 8);
        image.save(&jpg_path).unwrap();

        // 12 messages; message 0 (the oldest) carries an image.
        let messages: Vec<ChatMessage> = (0..12)
            .map(|i| {
                if i == 0 {
                    ChatMessage::user_with_media(
                        "first message with image",
                        vec![jpg_path.to_string_lossy().to_string()],
                    )
                } else if i % 2 == 0 {
                    ChatMessage::user(format!("user message {}", i))
                } else {
                    ChatMessage::assistant(format!("assistant message {}", i))
                }
            })
            .collect();

        // max_age_rounds = 10: message 0 is 11 messages away from the newest
        // one (index 11), so its image should be filtered out.
        let filtered = filter_images_by_age_and_count(&messages, 10, 100);

        // The first message's image was dropped and replaced by a notice.
        assert_eq!(filtered[0].media_refs.len(), 0);
        assert!(filtered[0].content.contains("图片已过期"));
    }

    #[test]
    fn test_filter_images_by_count_limit() {
        let temp_dir = tempdir().unwrap();

        // Three distinct images.
        let jpg_paths: Vec<String> = (0..3)
            .map(|i| {
                let path = temp_dir.path().join(format!("test{}.jpg", i));
                let image = image::DynamicImage::new_rgb8(8, 8);
                image.save(&path).unwrap();
                path.to_string_lossy().to_string()
            })
            .collect();

        // Three messages, each carrying one image.
        let messages: Vec<ChatMessage> = (0..3)
            .map(|i| {
                ChatMessage::user_with_media(
                    format!("message {}", i),
                    vec![jpg_paths[i].clone()],
                )
            })
            .collect();

        // max_images = 1: only the newest image (index 2) is kept.
        let filtered = filter_images_by_age_and_count(&messages, 100, 1);

        // The newest message keeps its image; older images are filtered.
        assert_eq!(filtered[2].media_refs.len(), 1); // newest, kept
        assert_eq!(filtered[1].media_refs.len(), 0); // filtered
        assert!(filtered[1].content.contains("超出最大图片数量限制"));
        assert_eq!(filtered[0].media_refs.len(), 0); // filtered
        assert!(filtered[0].content.contains("超出最大图片数量限制"));
    }

    #[test]
    fn test_filter_images_combined_limits() {
        let temp_dir = tempdir().unwrap();

        // Five images.
        let jpg_paths: Vec<String> = (0..5)
            .map(|i| {
                let path = temp_dir.path().join(format!("test{}.jpg", i));
                let image = image::DynamicImage::new_rgb8(8, 8);
                image.save(&path).unwrap();
                path.to_string_lossy().to_string()
            })
            .collect();

        // 20 messages with images at specific positions.
        let messages: Vec<ChatMessage> = (0..20)
            .map(|i| {
                // Images at indices 0, 5, 10, 15 and 19 (i / 5 is always < 5;
                // index 19 reuses image 3).
                if matches!(i, 0 | 5 | 10 | 15 | 19) {
                    ChatMessage::user_with_media(
                        format!("message {} with image", i),
                        vec![jpg_paths[i / 5].clone()],
                    )
                } else if i % 2 == 0 {
                    ChatMessage::user(format!("user message {}", i))
                } else {
                    ChatMessage::assistant(format!("assistant message {}", i))
                }
            })
            .collect();

        // max_age_rounds = 10 (keep images within the 10 most recent messages)
        // max_images = 3 (at most 3 images)
        // Ages: index 19 → 0, 15 → 4, 10 → 9, 5 → 14, 0 → 19.
        // Indices 5 and 0 exceed the age limit; the images at 19, 15 and 10
        // are kept (3 images, within max_images).
        let filtered = filter_images_by_age_and_count(&messages, 10, 3);

        assert!(!filtered[19].media_refs.is_empty(), "最新消息应保留图片");
        assert!(!filtered[15].media_refs.is_empty(), "age=4 的消息应保留图片");
        assert!(!filtered[10].media_refs.is_empty(), "age=9 的消息应保留图片");
        assert_eq!(filtered[5].media_refs.len(), 0, "age=14 的消息图片应被过滤");
        assert!(filtered[5].content.contains("超出 10 条消息范围"));
        assert_eq!(filtered[0].media_refs.len(), 0, "age=19 的消息图片应被过滤");
        assert!(filtered[0].content.contains("超出 10 条消息范围"));
    }

    // ====================
    // sanitize_incomplete_tool_call_sequences tests
    // ====================

    #[test]
    fn test_sanitize_removes_trailing_incomplete_tool_call_sequence() {
        let mut messages = vec![
            ChatMessage::user("hello"),
            ChatMessage::assistant_with_tool_calls(
                "calling tool",
                vec![ToolCall {
                    id: "call_1".to_string(),
                    name: "calculator".to_string(),
                    arguments: serde_json::json!({"expression": "1+1"}),
                }],
            ),
            // Tool result for call_1 is MISSING — incomplete sequence
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        assert_eq!(removed, 1);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");
    }

    #[test]
    fn test_sanitize_preserves_complete_tool_call_sequence() {
        let mut messages = vec![
            ChatMessage::user("hello"),
            ChatMessage::assistant_with_tool_calls(
                "calling tool",
                vec![ToolCall {
                    id: "call_1".to_string(),
                    name: "calculator".to_string(),
                    arguments: serde_json::json!({"expression": "1+1"}),
                }],
            ),
            ChatMessage::tool("call_1", "calculator", "2"),
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        assert_eq!(removed, 0);
        assert_eq!(messages.len(), 3);
    }

    #[test]
    fn test_sanitize_removes_multiple_incomplete_sequences() {
        let mut messages = vec![
            ChatMessage::user("hello"),
            ChatMessage::assistant_with_tool_calls(
                "first tool call",
                vec![ToolCall {
                    id: "call_1".to_string(),
                    name: "calculator".to_string(),
                    arguments: serde_json::json!({"expression": "1+1"}),
                }],
            ),
            // Missing tool result for call_1
            ChatMessage::user("second question"),
            ChatMessage::assistant_with_tool_calls(
                "second tool call",
                vec![ToolCall {
                    id: "call_2".to_string(),
                    name: "read".to_string(),
                    arguments: serde_json::json!({"path": "README.md"}),
                }],
            ),
            // Also missing tool result for call_2
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        // Both assistant messages whose tool calls have no results are removed
        assert_eq!(removed, 2);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "hello");
        assert_eq!(messages[1].role, "user");
        assert_eq!(messages[1].content, "second question");
    }

    #[test]
    fn test_sanitize_removes_assistant_when_partial_tool_results() {
        // Assistant makes 2 tool calls, but only 1 tool result exists
        let mut messages = vec![
            ChatMessage::user("hello"),
            ChatMessage::assistant_with_tool_calls(
                "calling two tools",
                vec![
                    ToolCall {
                        id: "call_1".to_string(),
                        name: "calculator".to_string(),
                        arguments: serde_json::json!({"expression": "1+1"}),
                    },
                    ToolCall {
                        id: "call_2".to_string(),
                        name: "read".to_string(),
                        arguments: serde_json::json!({"path": "README.md"}),
                    },
                ],
            ),
            ChatMessage::tool("call_1", "calculator", "2"),
            // Missing tool result for call_2
        ];

        let removed_count =
            crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        // Phase 1 removes the assistant message (call_2 has no result).
        // Phase 2 removes the orphaned tool result for call_1 (its parent
        // assistant was removed).
        assert_eq!(removed_count, 2);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");
    }

    #[test]
    fn test_sanitize_preserves_messages_without_tool_calls() {
        let mut messages = vec![
            ChatMessage::user("hello"),
            ChatMessage::assistant("hi there"),
            ChatMessage::user("how are you"),
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        assert_eq!(removed, 0);
        assert_eq!(messages.len(), 3);
    }

    #[test]
    fn test_sanitize_handles_empty_messages() {
        let mut messages: Vec<ChatMessage> = vec![];
        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);
        assert_eq!(removed, 0);
        assert!(messages.is_empty());
    }

    #[test]
    fn test_sanitize_removes_orphaned_tool_messages() {
        // A lone tool message without a preceding assistant tool_calls
        // is orphaned and should be removed.
        let mut messages = vec![ChatMessage::tool("call_1", "calculator", "2")];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        assert_eq!(removed, 1);
        assert!(messages.is_empty());
    }

    #[test]
    fn test_sanitize_preserves_complete_sequence_with_multiple_tool_calls() {
        let mut messages = vec![
            ChatMessage::user("do two things"),
            ChatMessage::assistant_with_tool_calls(
                "calling two tools",
                vec![
                    ToolCall {
                        id: "call_1".to_string(),
                        name: "calculator".to_string(),
                        arguments: serde_json::json!({"expression": "1+1"}),
                    },
                    ToolCall {
                        id: "call_2".to_string(),
                        name: "read".to_string(),
                        arguments: serde_json::json!({"path": "README.md"}),
                    },
                ],
            ),
            ChatMessage::tool("call_1", "calculator", "2"),
            ChatMessage::tool("call_2", "read", "contents of README"),
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        assert_eq!(removed, 0);
        assert_eq!(messages.len(), 4);
    }

    #[test]
    fn test_sanitize_only_trims_trailing_incomplete_sequence() {
        // Complete sequence followed by an incomplete one — only the
        // trailing incomplete one should be removed
        let mut messages = vec![
            ChatMessage::user("first question"),
            ChatMessage::assistant_with_tool_calls(
                "first tool call",
                vec![ToolCall {
                    id: "call_1".to_string(),
                    name: "calculator".to_string(),
                    arguments: serde_json::json!({"expression": "1+1"}),
                }],
            ),
            ChatMessage::tool("call_1", "calculator", "2"),
            ChatMessage::assistant("the answer is 2"),
            ChatMessage::user("second question"),
            ChatMessage::assistant_with_tool_calls(
                "second tool call",
                vec![ToolCall {
                    id: "call_2".to_string(),
                    name: "read".to_string(),
                    arguments: serde_json::json!({"path": "README.md"}),
                }],
            ),
            // Missing tool result for call_2 — only THIS sequence should be trimmed
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        assert_eq!(removed, 1);
        // The complete first sequence (4 messages) and the second user
        // message are preserved: 5 messages in total
        assert_eq!(messages.len(), 5);
        assert_eq!(messages[0].content, "first question");
        assert_eq!(messages[3].content, "the answer is 2");
        assert_eq!(messages[4].content, "second question");
    }

    #[test]
    fn test_sanitize_removes_mid_history_orphaned_tool_calls() {
        // Bug scenario: orphaned tool_calls in the MIDDLE of history,
        // followed by a complete sequence. The old trailing-only sanitizer
        // would stop at the complete sequence and never remove the orphan.
        let mut messages = vec![
            ChatMessage::user("first question"),
            ChatMessage::assistant_with_tool_calls(
                "orphaned tool call",
                vec![ToolCall {
                    id: "call_1".to_string(),
                    name: "calculator".to_string(),
                    arguments: serde_json::json!({"expression": "1+1"}),
                }],
            ),
            // Missing tool result for call_1 — ORPHAN in the middle
            ChatMessage::user("second question"),
            ChatMessage::assistant_with_tool_calls(
                "valid tool call",
                vec![ToolCall {
                    id: "call_2".to_string(),
                    name: "read".to_string(),
                    arguments: serde_json::json!({"path": "README.md"}),
                }],
            ),
            ChatMessage::tool("call_2", "read", "file contents"),
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        // call_1 assistant removed (1), rest preserved (4)
        assert_eq!(removed, 1);
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "first question");
        assert_eq!(messages[1].role, "user");
        assert_eq!(messages[1].content, "second question");
        // The complete call_2 sequence is preserved
        assert_eq!(messages[2].role, "assistant");
        assert_eq!(messages[3].role, "tool");
    }

    #[test]
    fn test_sanitize_removes_multiple_mid_history_orphans() {
        // Multiple orphaned tool_calls scattered throughout history
        let mut messages = vec![
            ChatMessage::user("first"),
            ChatMessage::assistant_with_tool_calls(
                "orphan 1",
                vec![ToolCall {
                    id: "orphan_1".to_string(),
                    name: "tool_a".to_string(),
                    arguments: serde_json::json!({}),
                }],
            ),
            // Missing result for orphan_1
            ChatMessage::user("second"),
            ChatMessage::assistant_with_tool_calls(
                "orphan 2",
                vec![ToolCall {
                    id: "orphan_2".to_string(),
                    name: "tool_b".to_string(),
                    arguments: serde_json::json!({}),
                }],
            ),
            // Missing result for orphan_2
            ChatMessage::user("third"),
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        assert_eq!(removed, 2);
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "first");
        assert_eq!(messages[1].content, "second");
        assert_eq!(messages[2].content, "third");
    }

    #[test]
    fn test_sanitize_removes_orphaned_tool_results_for_removed_mid_assistant() {
        // When removing a mid-history assistant with partial tool results,
        // both the assistant AND its orphaned tool results must be removed.
        // Assistant has 2 tool calls, only 1 has a result → assistant is
        // incomplete → both assistant and its lone tool result removed.
        let mut messages = vec![
            ChatMessage::user("first question"),
            ChatMessage::assistant_with_tool_calls(
                "two tool calls, only one has result",
                vec![
                    ToolCall {
                        id: "call_has_result".to_string(),
                        name: "tool_a".to_string(),
                        arguments: serde_json::json!({}),
                    },
                    ToolCall {
                        id: "call_no_result".to_string(),
                        name: "tool_b".to_string(),
                        arguments: serde_json::json!({}),
                    },
                ],
            ),
            ChatMessage::tool("call_has_result", "tool_a", "some result"),
            // Missing tool result for call_no_result → incomplete sequence
            ChatMessage::user("second question"),
            ChatMessage::assistant_with_tool_calls(
                "valid tool call",
                vec![ToolCall {
                    id: "call_valid".to_string(),
                    name: "good_tool".to_string(),
                    arguments: serde_json::json!({}),
                }],
            ),
            ChatMessage::tool("call_valid", "good_tool", "valid result"),
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        // call_has_result/call_no_result assistant (1) + its orphaned tool result (1) = 2 removed
        assert_eq!(removed, 2);
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].content, "first question");
        assert_eq!(messages[1].content, "second question");
        assert_eq!(messages[2].role, "assistant");
        assert_eq!(messages[3].role, "tool");
        // Verify the remaining tool belongs to call_valid
        assert_eq!(messages[3].tool_call_id.as_deref(), Some("call_valid"));
    }

    #[test]
    fn test_sanitize_handles_complex_interleaved_history() {
        // Complete → Orphaned → Complete: a realistic scenario after
        // history compaction
        let mut messages = vec![
            ChatMessage::user("task 1"),
            ChatMessage::assistant_with_tool_calls(
                "doing task 1",
                vec![ToolCall {
                    id: "t1_call".to_string(),
                    name: "read".to_string(),
                    arguments: serde_json::json!({"path": "a.txt"}),
                }],
            ),
            ChatMessage::tool("t1_call", "read", "content A"),
            ChatMessage::assistant("task 1 is done"),
            // End of task 1 — complete sequence
            ChatMessage::user("task 2"),
            ChatMessage::assistant_with_tool_calls(
                "doing task 2 — this got interrupted",
                vec![
                    ToolCall {
                        id: "t2_call_1".to_string(),
                        name: "write".to_string(),
                        arguments: serde_json::json!({"path": "b.txt"}),
                    },
                    ToolCall {
                        id: "t2_call_2".to_string(),
                        name: "calculator".to_string(),
                        arguments: serde_json::json!({"expression": "2+2"}),
                    },
                ],
            ),
            // Missing BOTH tool results — process was killed here
            // End of task 2 — orphaned sequence in the middle
            ChatMessage::user("task 3"),
            ChatMessage::assistant_with_tool_calls(
                "doing task 3",
                vec![ToolCall {
                    id: "t3_call".to_string(),
                    name: "search".to_string(),
                    arguments: serde_json::json!({"query": "hello"}),
                }],
            ),
            ChatMessage::tool("t3_call", "search", "found results"),
            ChatMessage::assistant("task 3 is done"),
        ];

        let removed = crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut messages);

        // Removed: assistant with t2_call_1/t2_call_2 (1 message)
        assert_eq!(removed, 1);
        // Original 10 messages - 1 = 9
        assert_eq!(messages.len(), 9);
        assert_eq!(messages[4].content, "task 2");
        assert_eq!(messages[5].content, "task 3");
    }
}

/// Errors surfaced by the agent.
#[derive(Debug)]
pub enum AgentError {
    /// Provider construction failed (displayed as `Provider creation error: …`).
    ProviderCreation(String),
    /// An LLM call failed (displayed as `LLM error: …`).
    LlmError(String),
    /// Any other failure; the message is displayed verbatim.
    Other(String),
}

impl std::fmt::Display for AgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
            AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
            AgentError::Other(e) => write!(f, "{}", e),
        }
    }
}

impl std::error::Error for AgentError {}