use crate::agent::AgentRuntimeConfig; use crate::bus::ChatMessage; use crate::bus::message::ToolMessageState; use crate::domain::messages::{ContentBlock, ToolCall}; use crate::observability::{ Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args, }; use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider}; use crate::text::{char_count, take_prefix_chars, take_suffix_chars}; use crate::tools::{ToolContext, ToolRegistry}; use async_trait::async_trait; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; use std::io::Read; use std::sync::Arc; use std::time::Instant; /// Minimum characters to keep when truncating const TRUNCATION_SUFFIX_LEN: usize = 200; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。"; const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"]; const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4; const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2; const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9; const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048; const DATA_URL_OVERHEAD_TOKENS: usize = 16; const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36]; const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64; const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:"; /// Build content blocks from text and media paths fn build_content_blocks( text: &str, media_paths: &[String], budget: &mut ImageInlineBudget, ) -> Vec { build_content_blocks_with_image_budget(text, media_paths, budget) } fn build_content_blocks_with_image_budget( text: &str, media_paths: &[String], budget: &mut ImageInlineBudget, ) -> Vec { let mut blocks = Vec::new(); let mut skipped_image_notices = Vec::new(); // Add text block if there's text if !text.is_empty() { blocks.push(ContentBlock::text(text)); } // Add image 
blocks for media paths for path in media_paths { if supported_image_mime_type(path).is_none() { tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block"); continue; } let Some(target_tokens) = budget.take_next_image_tokens() else { tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains"); skipped_image_notices.push(format!( "- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。", display_media_name(path) )); continue; }; match encode_image_to_base64_with_budget(path, target_tokens) { Ok((mime_type, base64_data)) => { let url = format!("data:{};base64,{}", mime_type, base64_data); blocks.push(ContentBlock::image_url(url)); } Err(err) => { tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed"); skipped_image_notices.push(format!( "- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。", display_media_name(path) )); continue; } } } if !skipped_image_notices.is_empty() { let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX); notice.push('\n'); notice.push_str(&skipped_image_notices.join("\n")); blocks.push(ContentBlock::text(notice)); } // If nothing, add empty text block if blocks.is_empty() { blocks.push(ContentBlock::text("")); } blocks } fn supported_image_mime_type(path: &str) -> Option { let mime = mime_guess::from_path(path).first_or_octet_stream(); let essence = mime.essence_str(); if SUPPORTED_IMAGE_MIME_TYPES.contains(&essence) { Some(essence.to_string()) } else { None } } fn display_media_name(path: &str) -> String { std::path::Path::new(path) .file_name() .and_then(|name| name.to_str()) .map(ToOwned::to_owned) .unwrap_or_else(|| path.to_string()) } #[derive(Debug, Clone)] struct ImageInlineBudget { remaining_tokens: usize, remaining_images: usize, } impl ImageInlineBudget { fn new(total_tokens: usize, image_count: usize) -> Self { Self { remaining_tokens: total_tokens, remaining_images: image_count, } } fn 
take_next_image_tokens(&mut self) -> Option { if self.remaining_tokens == 0 || self.remaining_images == 0 { self.remaining_images = self.remaining_images.saturating_sub(1); return None; } let target = self.remaining_tokens.div_ceil(self.remaining_images); self.remaining_tokens = self.remaining_tokens.saturating_sub(target); self.remaining_images = self.remaining_images.saturating_sub(1); Some(target) } } fn estimate_tokens_from_serialized_json(value: &T) -> usize { let raw_len = serde_json::to_string(value) .map(|serialized| serialized.len()) .unwrap_or_default(); ((raw_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64) * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize } fn image_token_budget_for_request( runtime_config: &AgentRuntimeConfig, text_only_messages: &[Message], tools: Option<&Vec>, ) -> usize { let completion_reserve = runtime_config .provider .max_tokens .map(|tokens| tokens as usize) .unwrap_or(DEFAULT_COMPLETION_TOKEN_RESERVE); let input_window = runtime_config .context_window_tokens .saturating_sub(completion_reserve); let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize; let text_tokens = estimate_tokens_from_serialized_json(&text_only_messages) + tools .map(estimate_tokens_from_serialized_json) .unwrap_or_default(); safe_input_window.saturating_sub(text_tokens) } fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize { messages .iter() .flat_map(|message| message.media_refs.iter()) .filter(|path| supported_image_mime_type(path).is_some()) .count() } fn target_image_bytes_for_tokens(target_tokens: usize) -> usize { target_tokens .saturating_sub(DATA_URL_OVERHEAD_TOKENS) .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN) .saturating_mul(3) / 4 } /// Encode an image file to base64 data URL fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { use base64::{Engine as _, engine::general_purpose::STANDARD}; let mime = supported_image_mime_type(path).ok_or_else(|| { 
std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("unsupported image media type for path: {}", path), ) })?; let mut file = std::fs::File::open(path)?; let mut buffer = Vec::new(); file.read_to_end(&mut buffer)?; let encoded = STANDARD.encode(&buffer); Ok((mime, encoded)) } fn encode_image_to_base64_with_budget( path: &str, target_tokens: usize, ) -> Result<(String, String), std::io::Error> { use base64::{Engine as _, engine::general_purpose::STANDARD}; let (mime, encoded) = encode_image_to_base64(path)?; let target_base64_chars = target_tokens .saturating_sub(DATA_URL_OVERHEAD_TOKENS) .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN); if encoded.len() <= target_base64_chars { return Ok((mime, encoded)); } let target_bytes = target_image_bytes_for_tokens(target_tokens); let image_bytes = std::fs::read(path)?; let image = image::load_from_memory(&image_bytes) .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?; let compressed = compress_image_to_jpeg_budget(&image, target_bytes)?; Ok(("image/jpeg".to_string(), STANDARD.encode(compressed))) } fn compress_image_to_jpeg_budget( image: &image::DynamicImage, target_bytes: usize, ) -> Result, std::io::Error> { use image::codecs::jpeg::JpegEncoder; use image::imageops::FilterType; let mut best: Option> = None; let mut max_side = image .width() .max(image.height()) .max(MIN_COMPRESSED_IMAGE_SIDE); loop { let candidate = if image.width().max(image.height()) > max_side { image.resize(max_side, max_side, FilterType::Triangle) } else { image.clone() }; for quality in JPEG_QUALITY_STEPS { let mut encoded = Vec::new(); let mut encoder = JpegEncoder::new_with_quality(&mut encoded, *quality); encoder .encode_image(&candidate) .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?; if encoded.len() <= target_bytes { return Ok(encoded); } if best .as_ref() .map(|current| encoded.len() < current.len()) .unwrap_or(true) { best = Some(encoded); } } if max_side <= 
MIN_COMPRESSED_IMAGE_SIDE { break; } max_side = (max_side * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE); } best.ok_or_else(|| { std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed") }) } /// Truncate tool result if it exceeds the configured limit. /// Preserves the end of the output as it often contains the conclusion/useful result. fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String { let total_chars = char_count(output); if total_chars <= max_tool_result_chars { return output.to_string(); } let truncated_start_len = total_chars.saturating_sub(TRUNCATION_SUFFIX_LEN); if truncated_start_len > max_tool_result_chars { // Even after removing suffix, still too long - take from beginning let head_len = max_tool_result_chars.saturating_sub(100); let head = take_prefix_chars(output, head_len); format!( "{}...\n\n[Output truncated - {} characters removed]", head, total_chars - max_tool_result_chars + 100 ) } else { // Keep most of the end which usually contains the useful result let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len)); format!( "...\n\n[Output truncated - {} characters removed]\n\n{}", truncated_start_len, tail ) } } fn parse_pending_tool_output(output: &str) -> Option { output .strip_prefix(PENDING_USER_ACTION_MARKER) .map(|rest| rest.trim().to_string()) } fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value { match arguments { serde_json::Value::String(raw) => { serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()) } _ => arguments.clone(), } } fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; let mut current = error.source(); while let Some(source) = current { details.push(source.to_string()); current = source.source(); } details.join("\ncaused by: ") } fn is_recoverable_llm_error(error: &str) -> bool { let normalized = error.to_ascii_lowercase(); normalized.contains("504") 
|| normalized.contains("gateway timeout") || normalized.contains("stream timeout") || normalized.contains("timed out") || normalized.contains("timeout") } fn recoverable_llm_message(error: &str) -> String { if is_recoverable_llm_error(error) { RECOVERABLE_LLM_ERROR_MESSAGE.to_string() } else { format!("模型请求失败:{}", error) } } /// Loop detection result. #[derive(Debug, Clone, PartialEq, Eq)] enum LoopDetectionResult { /// No warning needed. Ok, /// Warning: same tool + args repeated N times. Warning(String), } /// Configuration for loop detector. #[derive(Debug, Clone)] struct LoopDetectorConfig { /// Master switch. enabled: bool, /// Warn every N consecutive identical calls. warn_every: usize, } impl Default for LoopDetectorConfig { fn default() -> Self { Self { enabled: true, warn_every: 5, } } } /// A single recorded tool invocation in the sliding window. #[derive(Debug, Clone)] struct ToolCallRecord { name: String, args_hash: u64, } /// Stateful loop detector that monitors for repetitive patterns. struct LoopDetector { config: LoopDetectorConfig, window: VecDeque, } impl LoopDetector { fn new(config: LoopDetectorConfig) -> Self { Self { window: VecDeque::with_capacity(config.warn_every * 2), config, } } /// Record a completed tool call and check for loop patterns. /// Returns Warning every `warn_every` consecutive identical calls. 
fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult { if !self.config.enabled { return LoopDetectionResult::Ok; } let record = ToolCallRecord { name: name.to_string(), args_hash: hash_json_value(args), }; // Maintain sliding window if self.window.len() >= self.config.warn_every * 2 { self.window.pop_front(); } self.window.push_back(record); // Count consecutive identical calls let last = self.window.back().unwrap(); let consecutive: usize = self .window .iter() .rev() .take_while(|r| r.name == last.name && r.args_hash == last.args_hash) .count(); // Warn every warn_every times if consecutive > 0 && consecutive % self.config.warn_every == 0 { LoopDetectionResult::Warning(format!( "注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。", last.name, consecutive )) } else { LoopDetectionResult::Ok } } } /// Hash a JSON value deterministically (key-order independent). fn hash_json_value(value: &serde_json::Value) -> u64 { let mut hasher = std::collections::hash_map::DefaultHasher::new(); let canonical = canonicalise_json(value); canonical.hash(&mut hasher); hasher.finish() } /// Return a clone of value with all object keys sorted recursively. 
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
    use serde_json::Value;

    match value {
        // Arrays keep their element order; only nested objects are reordered.
        Value::Array(items) => Value::Array(items.iter().map(canonicalise_json).collect()),
        Value::Object(fields) => {
            let mut entries: Vec<_> = fields.iter().collect();
            entries.sort_by(|left, right| left.0.cmp(right.0));

            // Insert in sorted key order so the rebuilt map is canonical.
            let mut ordered = serde_json::Map::new();
            for (key, field) in entries {
                ordered.insert(key.clone(), canonicalise_json(field));
            }
            Value::Object(ordered)
        }
        scalar => scalar.clone(),
    }
}

/// Convert ChatMessage to LLM Message format
///
/// Messages without media refs become a single text block; otherwise the
/// content is built by `build_content_blocks`, drawing on `image_budget`.
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
    let has_media = !m.media_refs.is_empty();
    let content = if has_media {
        build_content_blocks(&m.content, &m.media_refs, image_budget)
    } else {
        vec![ContentBlock::text(&m.content)]
    };

    Message {
        role: m.role.clone(),
        content,
        reasoning_content: m.reasoning_content.clone(),
        tool_call_id: m.tool_call_id.clone(),
        name: m.tool_name.clone(),
        tool_calls: m.tool_calls.clone(),
    }
}

/// Convert a ChatMessage to an LLM Message carrying only its text; media refs
/// are ignored.
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
    let text_block = ContentBlock::text(&m.content);

    Message {
        role: m.role.clone(),
        content: vec![text_block],
        reasoning_content: m.reasoning_content.clone(),
        tool_call_id: m.tool_call_id.clone(),
        name: m.tool_name.clone(),
        tool_calls: m.tool_calls.clone(),
    }
}

/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop { runtime_config: AgentRuntimeConfig, provider: Box, tools: Arc, skills: Arc, tool_context: ToolContext, observer: Option>, emitted_message_handler: Option>, max_iterations: usize, } #[derive(Debug, Clone)] pub struct AgentProcessResult { pub final_response: ChatMessage, pub emitted_messages: Vec, } #[async_trait] pub trait EmittedMessageHandler: Send + Sync + 'static { async fn handle(&self, message: ChatMessage); } pub trait SkillProvider: Send + Sync + 'static { fn system_index_prompt(&self) -> Option; fn matching_skill_summary(&self, _name: &str) -> Option { None } } #[derive(Default)] struct EmptySkillProvider; impl SkillProvider for EmptySkillProvider { fn system_index_prompt(&self) -> Option { None } } impl AgentLoop { pub fn new(config: impl Into) -> Result { let runtime_config = config.into(); let max_iterations = runtime_config.max_tool_iterations; let provider = create_provider(runtime_config.provider.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { runtime_config, provider, tools: Arc::new(ToolRegistry::new()), skills: Arc::new(EmptySkillProvider), tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, max_iterations, }) } pub fn with_tools( config: impl Into, tools: Arc, ) -> Result { let runtime_config = config.into(); let max_iterations = runtime_config.max_tool_iterations; let provider = create_provider(runtime_config.provider.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { runtime_config, provider, tools, skills: Arc::new(EmptySkillProvider), tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, max_iterations, }) } pub fn with_tools_and_skill_provider( config: impl Into, tools: Arc, skills: Arc, ) -> Result { let runtime_config = config.into(); let max_iterations = runtime_config.max_tool_iterations; let provider = create_provider(runtime_config.provider.clone()) .map_err(|e| 
AgentError::ProviderCreation(e.to_string()))?; Ok(Self { runtime_config, provider, tools, skills, tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, max_iterations, }) } pub fn with_tool_context(mut self, context: ToolContext) -> Self { self.tool_context = context; self } /// Set an observer for tracking events. pub fn with_observer(mut self, observer: Arc) -> Self { self.observer = Some(observer); self } pub fn with_emitted_message_handler(mut self, handler: Arc) -> Self { self.emitted_message_handler = Some(handler); self } pub fn tools(&self) -> &Arc { &self.tools } /// Process a message using the provided conversation history. /// History management is handled externally by SessionManager. /// /// This method supports multi-round tool calling: after executing tools, /// it loops back to the LLM with the tool results until either: /// - The LLM returns no more tool calls (final response) /// - Maximum iterations are reached pub async fn process( &self, mut messages: Vec, ) -> Result { #[cfg(debug_assertions)] tracing::debug!( history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process" ); // Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut emitted_messages = Vec::new(); for iteration in 0..self.max_iterations { #[cfg(debug_assertions)] tracing::debug!(iteration, "Agent iteration started"); // Build request let tool_defs = self.tools.get_definitions(); let tools = if tool_defs.is_empty() { None } else { Some(tool_defs) }; let image_count = count_supported_image_media_refs(&messages); let mut text_only_messages: Vec = Vec::with_capacity(messages.len() + 2); if let Some(skill_prompt) = self.skills.system_index_prompt() { text_only_messages.push(Message::system(skill_prompt.clone())); } text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message)); let image_tokens = image_token_budget_for_request( 
&self.runtime_config, &text_only_messages, tools.as_ref(), ); let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 2); if let Some(skill_prompt) = self.skills.system_index_prompt() { messages_for_llm.push(Message::system(skill_prompt)); } messages_for_llm.extend( messages .iter() .map(|message| chat_message_to_llm_message(message, &mut image_budget)), ); let request = ChatCompletionRequest { messages: messages_for_llm, temperature: None, max_tokens: None, tools, }; let response = match (*self.provider).chat(request).await { Ok(response) => response, Err(e) => { tracing::error!( provider = %self.provider.name(), model = %self.provider.model_id(), error = %e, error_details = %format_error_chain(e.as_ref()), "LLM request failed" ); let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }); } }; #[cfg(debug_assertions)] tracing::debug!( iteration, response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received" ); // If no tool calls, this is the final response if response.tool_calls.is_empty() { let assistant_message = if let Some(reasoning_content) = response.reasoning_content { ChatMessage::assistant_with_reasoning(response.content, reasoning_content) } else { ChatMessage::assistant(response.content) }; emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }); } // Execute tool calls tracing::info!( iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools" ); // Add assistant message with tool calls let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() { ChatMessage::assistant_with_tool_calls_and_reasoning( 
response.content.clone(), response.tool_calls.clone(), reasoning_content, ) } else { ChatMessage::assistant_with_tool_calls( response.content.clone(), response.tool_calls.clone(), ) }; messages.push(assistant_message.clone()); emitted_messages.push(assistant_message); self.emit_live_tool_call_message( emitted_messages .last() .expect("assistant message just pushed") .clone(), ) .await; // Execute tools and add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { // Log function call with name and arguments let args_str = match &tool_call.arguments { serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(), other => { serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()) } }; tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool"); // Truncate tool result if too large let truncated_output = truncate_tool_result(&result.output, self.runtime_config.tool_result_max_chars); // Record tool call and check for loops let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments); match loop_result { LoopDetectionResult::Warning(msg) => { // Add warning and proceed tracing::warn!( tool = %tool_call.name, "Loop warning: {}", msg ); let tool_message = ChatMessage::tool_with_state( tool_call.id.clone(), tool_call.name.clone(), format!("{}\n\n[上一条结果]\n{}", msg, truncated_output), if result.state == ToolExecutionState::PendingUserAction { ToolMessageState::PendingUserAction } else { ToolMessageState::Completed }, ); messages.push(tool_message.clone()); emitted_messages.push(tool_message); } LoopDetectionResult::Ok => { let tool_message = ChatMessage::tool_with_state( tool_call.id.clone(), tool_call.name.clone(), truncated_output, if result.state == ToolExecutionState::PendingUserAction { ToolMessageState::PendingUserAction } else { ToolMessageState::Completed }, ); messages.push(tool_message.clone()); 
emitted_messages.push(tool_message); } } } if let Some((tool_call, pending_result)) = response .tool_calls .iter() .zip(tool_results.iter()) .find(|(_, result)| result.state == ToolExecutionState::PendingUserAction) { let assistant_message = ChatMessage::assistant(format!( "{}\n\n当前等待中的工具: {}", pending_result .output .lines() .next() .filter(|line| !line.trim().is_empty()) .unwrap_or(DEFAULT_PENDING_ASSISTANT_MESSAGE), tool_call.name, )); emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }); } // Loop continues to next iteration with updated messages #[cfg(debug_assertions)] tracing::debug!( iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration" ); } // Max iterations reached - ask LLM for a summary based on completed work tracing::warn!("Max iterations reached, requesting final summary from LLM"); // Add a message asking for summary let summary_request = ChatMessage::user( "You have reached the maximum number of tool call iterations. 
\ Please provide your best answer based on the work completed so far.", ); messages.push(summary_request); // Convert messages to LLM format let image_count = count_supported_image_media_refs(&messages); let mut text_only_messages: Vec = Vec::with_capacity(messages.len() + 1); if let Some(skill_prompt) = self.skills.system_index_prompt() { text_only_messages.push(Message::system(skill_prompt)); } text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message)); let image_tokens = image_token_budget_for_request(&self.runtime_config, &text_only_messages, None); let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 1); if let Some(skill_prompt) = self.skills.system_index_prompt() { messages_for_llm.push(Message::system(skill_prompt)); } messages_for_llm.extend( messages .iter() .map(|message| chat_message_to_llm_message(message, &mut image_budget)), ); let request = ChatCompletionRequest { messages: messages_for_llm, temperature: None, max_tokens: None, tools: None, // No tools in final summary call }; match (*self.provider).chat(request).await { Ok(response) => { let assistant_message = if let Some(reasoning_content) = response.reasoning_content { ChatMessage::assistant_with_reasoning(response.content, reasoning_content) } else { ChatMessage::assistant(response.content) }; emitted_messages.push(assistant_message.clone()); Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }) } Err(e) => { tracing::error!( provider = %self.provider.name(), model = %self.provider.model_id(), error = %e, error_details = %format_error_chain(e.as_ref()), "Failed to get summary from LLM" ); let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); emitted_messages.push(final_message.clone()); Ok(AgentProcessResult { final_response: final_message, emitted_messages, }) } } } async fn emit_live_tool_call_message(&self, message: 
ChatMessage) { if !message.is_assistant_tool_call_message() { return; } if let Some(handler) = &self.emitted_message_handler { handler.handle(message).await; } } /// Determine whether to execute tools in parallel or sequentially. /// /// Returns true if: /// - There are multiple tool calls /// - None of the tools require sequential execution (tool_search, non-concurrency-safe) fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool { if tool_calls.len() <= 1 { return false; } // tool_search must run sequentially to avoid MCP activation race conditions if tool_calls.iter().any(|tc| tc.name == "tool_search") { return false; } // All tools must be concurrency-safe to run in parallel tool_calls.iter().all(|tc| { self.tools .get(&tc.name) .map(|t| t.concurrency_safe()) .unwrap_or(false) }) } /// Execute multiple tool calls, choosing parallel or sequential based on conditions. async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec { if self.should_execute_in_parallel(tool_calls) { tracing::debug!("Executing {} tools in parallel", tool_calls.len()); self.execute_tools_parallel(tool_calls).await } else { tracing::debug!("Executing {} tools sequentially", tool_calls.len()); self.execute_tools_sequential(tool_calls).await } } /// Execute tools in parallel using join_all. async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec { let futures: Vec<_> = tool_calls .iter() .map(|tc| self.execute_one_tool(tc)) .collect(); futures_util::future::join_all(futures).await } /// Execute tools sequentially. async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec { let mut outcomes = Vec::with_capacity(tool_calls.len()); for tool_call in tool_calls { outcomes.push(self.execute_one_tool(tool_call).await); } outcomes } /// Execute a single tool and return the outcome with event tracking. 
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome { let start = Instant::now(); let tool_name = tool_call.name.clone(); // Record ToolCallStart event if let Some(ref observer) = self.observer { observer.record_event(&ObserverEvent::ToolCallStart { tool: tool_name.clone(), arguments: Some(truncate_args(&tool_call.arguments, 300)), }); } let result = self.execute_tool_internal(tool_call).await; let duration = start.elapsed(); // Record ToolCall event if let Some(ref observer) = self.observer { observer.record_event(&ObserverEvent::ToolCall { tool: tool_name.clone(), duration, success: result.success, }); } // Apply duration ToolExecutionOutcome { duration, ..result } } /// Internal tool execution without event tracking. async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome { let normalized_arguments = normalize_tool_arguments(&tool_call.arguments); let tool = match self.tools.get(&tool_call.name) { Some(t) => t, None => { tracing::warn!(tool = %tool_call.name, "Tool not found"); let skill_hint = self.skills.matching_skill_summary(&tool_call.name); let error = match skill_hint { Some(summary) => format!( "Tool '{}' not found. A skill with the same name exists: {}. Skills are not tools. 
Call skill_activate with {{\"name\": \"{}\"}} first.", tool_call.name, summary, tool_call.name ), None => format!("Tool '{}' not found", tool_call.name), }; return ToolExecutionOutcome::failure( format!("Error: {}", error), Some(error), ); } }; match tool .execute_with_context(&self.tool_context, normalized_arguments.clone()) .await { Ok(result) => { if result.success { if let Some(pending_output) = parse_pending_tool_output(&result.output) { ToolExecutionOutcome::pending(pending_output) } else { ToolExecutionOutcome::success(result.output) } } else { let error = result.error.unwrap_or_default(); tracing::error!( tool = %tool_call.name, args = %truncate_args(&tool_call.arguments, 4_000), normalized_args = %truncate_args(&normalized_arguments, 4_000), error = %error, output = %result.output, "Tool returned an error result" ); ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error)) } } Err(e) => { tracing::error!( tool = %tool_call.name, args = %truncate_args(&tool_call.arguments, 4_000), normalized_args = %truncate_args(&normalized_arguments, 4_000), error = %e, error_details = %format!("{:#}", e), "Tool execution failed" ); ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string())) } } } } #[cfg(test)] mod tests { use super::*; use crate::config::LLMProviderConfig; use crate::observability::{MultiObserver, Observer}; use tempfile::tempdir; struct TestObserver { events: std::sync::Mutex>, } impl TestObserver { fn new() -> Self { Self { events: std::sync::Mutex::new(Vec::new()), } } } impl Observer for TestObserver { fn record_event(&self, event: &ObserverEvent) { self.events.lock().unwrap().push(event.clone()); } fn name(&self) -> &str { "test_observer" } } struct TestSkillProvider; impl SkillProvider for TestSkillProvider { fn system_index_prompt(&self) -> Option { None } fn matching_skill_summary(&self, name: &str) -> Option { (name == "baidu-search").then(|| "用于百度搜索和天气查询的技能".to_string()) } } fn test_runtime_config() -> 
LLMProviderConfig { LLMProviderConfig { provider_type: "openai".to_string(), name: "test".to_string(), base_url: "http://localhost".to_string(), api_key: "test-key".to_string(), extra_headers: std::collections::HashMap::new(), llm_timeout_secs: 120, model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: std::collections::HashMap::new(), max_tool_iterations: 1, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, } } #[tokio::test] async fn test_missing_tool_with_same_name_skill_returns_activation_hint() { let loop_instance = AgentLoop::with_tools_and_skill_provider( test_runtime_config(), Arc::new(ToolRegistry::new()), Arc::new(TestSkillProvider), ) .unwrap(); let outcome = loop_instance .execute_tool_internal(&ToolCall { id: "call-1".to_string(), name: "baidu-search".to_string(), arguments: serde_json::json!({ "queries": "佛山今天几点下雨" }), }) .await; assert_eq!(outcome.state, ToolExecutionState::Completed); assert!(!outcome.success); assert!(outcome.output.contains("技能")); assert!(outcome.output.contains("skill_activate")); assert!(outcome.output.contains("baidu-search")); } #[tokio::test] async fn test_observer_receives_tool_events() { // Verify MultiObserver works let mut multi = MultiObserver::new(); multi.add_observer(Box::new(TestObserver::new())); let event = ObserverEvent::ToolCallStart { tool: "test".to_string(), arguments: Some("{}".to_string()), }; multi.record_event(&event); // Just verify the structure works assert_eq!(multi.len(), 1); } #[test] fn test_should_execute_in_parallel_single_tool() { // Would need a proper setup with AgentLoop to test fully // For now, just verify the logic: single tool should return false let calls = vec![ToolCall { id: "1".to_string(), name: "test".to_string(), arguments: serde_json::json!({}), }]; // If there's only 1 tool, should return false regardless assert_eq!(calls.len() <= 1, true); } #[test] fn 
test_chat_message_to_llm_message_preserves_assistant_tool_calls() { let chat_message = ChatMessage::assistant_with_tool_calls( "calling tool", vec![ToolCall { id: "call_1".to_string(), name: "calculator".to_string(), arguments: serde_json::json!({ "expression": "2+2" }), }], ); let mut image_budget = ImageInlineBudget::new(0, 0); let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget); assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); assert_eq!( provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1" ); assert_eq!( provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator" ); } #[test] fn test_chat_message_to_llm_message_preserves_reasoning_content() { let chat_message = ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought"); let mut image_budget = ImageInlineBudget::new(0, 0); let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget); assert_eq!(provider_message.role, "assistant"); assert_eq!( provider_message.reasoning_content.as_deref(), Some("hidden chain of thought") ); } #[test] fn test_truncate_tool_result_handles_utf8_char_boundaries() { let input = "范".repeat(20_500); let output = truncate_tool_result(&input, 20_000); assert!(output.contains("Output truncated")); assert!(output.is_char_boundary(output.len())); } #[test] fn test_parse_pending_tool_output() { let output = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权"); assert_eq!(output.as_deref(), Some("请完成授权")); assert!(parse_pending_tool_output("normal output").is_none()); } #[test] fn test_normalize_tool_arguments_parses_stringified_json() { let normalized = normalize_tool_arguments(&serde_json::Value::String( "{\"command\":\"ls -la\"}".to_string(), )); assert_eq!(normalized, serde_json::json!({ "command": "ls -la" })); } #[test] fn test_normalize_tool_arguments_keeps_plain_string() { let normalized = 
normalize_tool_arguments(&serde_json::Value::String("plain text".to_string())); assert_eq!( normalized, serde_json::Value::String("plain text".to_string()) ); } #[test] fn test_build_content_blocks_skips_non_image_media_refs() { let temp_dir = tempdir().unwrap(); let pdf_path = temp_dir.path().join("demo.pdf"); std::fs::write(&pdf_path, b"%PDF-1.4").unwrap(); let mut budget = ImageInlineBudget::new(1_000, 0); let blocks = build_content_blocks( "hello", &[pdf_path.to_string_lossy().to_string()], &mut budget, ); assert_eq!(blocks.len(), 1); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); } #[test] fn test_build_content_blocks_keeps_supported_images() { let temp_dir = tempdir().unwrap(); let jpg_path = temp_dir.path().join("demo.jpg"); let image = image::DynamicImage::new_rgb8(8, 8); image.save(&jpg_path).unwrap(); let mut budget = ImageInlineBudget::new(10_000, 1); let blocks = build_content_blocks( "hello", &[jpg_path.to_string_lossy().to_string()], &mut budget, ); assert_eq!(blocks.len(), 2); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); assert!( matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")) ); } #[test] fn test_build_content_blocks_compresses_images_to_budget() { let temp_dir = tempdir().unwrap(); let png_path = temp_dir.path().join("large.png"); let image = image::DynamicImage::new_rgb8(512, 512); image.save(&png_path).unwrap(); let mut budget = ImageInlineBudget::new(512, 1); let blocks = build_content_blocks( "hello", &[png_path.to_string_lossy().to_string()], &mut budget, ); assert_eq!(blocks.len(), 2); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); assert!( matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")) ); } #[test] fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() { let temp_dir = 
tempdir().unwrap(); let jpg_path = temp_dir.path().join("demo.jpg"); let image = image::DynamicImage::new_rgb8(8, 8); image.save(&jpg_path).unwrap(); let mut budget = ImageInlineBudget::new(0, 1); let blocks = build_content_blocks( "hello", &[jpg_path.to_string_lossy().to_string()], &mut budget, ); assert_eq!(blocks.len(), 2); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); assert!(matches!( &blocks[1], ContentBlock::Text { text } if text.contains("图片未能成功入模") && text.contains("demo.jpg") )); } } #[derive(Debug)] pub enum AgentError { ProviderCreation(String), LlmError(String), Other(String), } impl std::fmt::Display for AgentError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e), AgentError::LlmError(e) => write!(f, "LLM error: {}", e), AgentError::Other(e) => write!(f, "{}", e), } } } impl std::error::Error for AgentError {}