use async_trait::async_trait; use crate::bus::message::ContentBlock; use crate::bus::ChatMessage; use crate::bus::message::ToolMessageState; use crate::config::LLMProviderConfig; use crate::observability::{ truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, }; use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; use crate::skills::SkillRuntime; use crate::storage::SessionStore; use crate::tools::{ToolContext, ToolRegistry}; use crate::text::{char_count, take_prefix_chars, take_suffix_chars}; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; use std::io::Read; use std::sync::Arc; use std::time::Instant; /// Minimum characters to keep when truncating const TRUNCATION_SUFFIX_LEN: usize = 200; const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = include_str!("memory_tool_usage_system_prompt.md"); const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。"; const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &[ "image/jpeg", "image/png", "image/gif", "image/webp", ]; /// Build content blocks from text and media paths fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { let mut blocks = Vec::new(); // Add text block if there's text if !text.is_empty() { blocks.push(ContentBlock::text(text)); } // Add image blocks for media paths for path in media_paths { if supported_image_mime_type(path).is_none() { tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block"); continue; } if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) { let url = format!("data:{};base64,{}", mime_type, base64_data); blocks.push(ContentBlock::image_url(url)); } } // If nothing, add empty text block if blocks.is_empty() { blocks.push(ContentBlock::text("")); } blocks } fn 
supported_image_mime_type(path: &str) -> Option { let mime = mime_guess::from_path(path).first_or_octet_stream(); let essence = mime.essence_str(); if SUPPORTED_IMAGE_MIME_TYPES.contains(&essence) { Some(essence.to_string()) } else { None } } /// Encode an image file to base64 data URL fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { use base64::{Engine as _, engine::general_purpose::STANDARD}; let mime = supported_image_mime_type(path).ok_or_else(|| { std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("unsupported image media type for path: {}", path), ) })?; let mut file = std::fs::File::open(path)?; let mut buffer = Vec::new(); file.read_to_end(&mut buffer)?; let encoded = STANDARD.encode(&buffer); Ok((mime, encoded)) } /// Truncate tool result if it exceeds the configured limit. /// Preserves the end of the output as it often contains the conclusion/useful result. fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String { let total_chars = char_count(output); if total_chars <= max_tool_result_chars { return output.to_string(); } let truncated_start_len = total_chars.saturating_sub(TRUNCATION_SUFFIX_LEN); if truncated_start_len > max_tool_result_chars { // Even after removing suffix, still too long - take from beginning let head_len = max_tool_result_chars.saturating_sub(100); let head = take_prefix_chars(output, head_len); format!( "{}...\n\n[Output truncated - {} characters removed]", head, total_chars - max_tool_result_chars + 100 ) } else { // Keep most of the end which usually contains the useful result let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len)); format!( "...\n\n[Output truncated - {} characters removed]\n\n{}", truncated_start_len, tail ) } } fn parse_pending_tool_output(output: &str) -> Option { output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string()) } fn normalize_tool_arguments(arguments: &serde_json::Value) -> 
serde_json::Value { match arguments { serde_json::Value::String(raw) => { serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()) } _ => arguments.clone(), } } fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; let mut current = error.source(); while let Some(source) = current { details.push(source.to_string()); current = source.source(); } details.join("\ncaused by: ") } fn is_recoverable_llm_error(error: &str) -> bool { let normalized = error.to_ascii_lowercase(); normalized.contains("504") || normalized.contains("gateway timeout") || normalized.contains("stream timeout") || normalized.contains("timed out") || normalized.contains("timeout") } fn recoverable_llm_message(error: &str) -> String { if is_recoverable_llm_error(error) { RECOVERABLE_LLM_ERROR_MESSAGE.to_string() } else { format!("模型请求失败:{}", error) } } /// Loop detection result. #[derive(Debug, Clone, PartialEq, Eq)] enum LoopDetectionResult { /// No warning needed. Ok, /// Warning: same tool + args repeated N times. Warning(String), } /// Configuration for loop detector. #[derive(Debug, Clone)] struct LoopDetectorConfig { /// Master switch. enabled: bool, /// Warn every N consecutive identical calls. warn_every: usize, } impl Default for LoopDetectorConfig { fn default() -> Self { Self { enabled: true, warn_every: 5, } } } /// A single recorded tool invocation in the sliding window. #[derive(Debug, Clone)] struct ToolCallRecord { name: String, args_hash: u64, } /// Stateful loop detector that monitors for repetitive patterns. struct LoopDetector { config: LoopDetectorConfig, window: VecDeque, } impl LoopDetector { fn new(config: LoopDetectorConfig) -> Self { Self { window: VecDeque::with_capacity(config.warn_every * 2), config, } } /// Record a completed tool call and check for loop patterns. /// Returns Warning every `warn_every` consecutive identical calls. 
fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult { if !self.config.enabled { return LoopDetectionResult::Ok; } let record = ToolCallRecord { name: name.to_string(), args_hash: hash_json_value(args), }; // Maintain sliding window if self.window.len() >= self.config.warn_every * 2 { self.window.pop_front(); } self.window.push_back(record); // Count consecutive identical calls let last = self.window.back().unwrap(); let consecutive: usize = self .window .iter() .rev() .take_while(|r| r.name == last.name && r.args_hash == last.args_hash) .count(); // Warn every warn_every times if consecutive > 0 && consecutive % self.config.warn_every == 0 { LoopDetectionResult::Warning(format!( "注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。", last.name, consecutive )) } else { LoopDetectionResult::Ok } } } /// Hash a JSON value deterministically (key-order independent). fn hash_json_value(value: &serde_json::Value) -> u64 { let mut hasher = std::collections::hash_map::DefaultHasher::new(); let canonical = canonicalise_json(value); canonical.hash(&mut hasher); hasher.finish() } /// Return a clone of value with all object keys sorted recursively. 
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
    // Objects are rebuilt with keys inserted in sorted order. Arrays keep their element
    // order (it is significant) but their elements are canonicalised too. Scalars are
    // cloned as-is.
    if let serde_json::Value::Object(map) = value {
        let mut entries: Vec<(&String, &serde_json::Value)> = map.iter().collect();
        entries.sort_by(|(left, _), (right, _)| left.cmp(right));

        let mut canonical = serde_json::Map::new();
        for (key, child) in entries {
            canonical.insert(key.clone(), canonicalise_json(child));
        }
        return serde_json::Value::Object(canonical);
    }

    if let serde_json::Value::Array(items) = value {
        return serde_json::Value::Array(items.iter().map(canonicalise_json).collect());
    }

    value.clone()
}

/// Convert a stored [`ChatMessage`] into the provider-facing [`Message`].
///
/// Messages without media refs become a single text block. Otherwise
/// [`build_content_blocks`] inlines the supported images. Role, reasoning content,
/// tool-call id, tool name and tool calls are copied across.
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
    let blocks = match m.media_refs.as_slice() {
        [] => vec![ContentBlock::text(&m.content)],
        refs => build_content_blocks(&m.content, refs),
    };

    Message {
        role: m.role.clone(),
        content: blocks,
        reasoning_content: m.reasoning_content.clone(),
        tool_call_id: m.tool_call_id.clone(),
        name: m.tool_name.clone(),
        tool_calls: m.tool_calls.clone(),
    }
}

/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop { provider_config: LLMProviderConfig, provider: Box, tools: Arc, skills: Arc, skill_event_store: Option>, skill_event_session_id: Option, tool_context: ToolContext, observer: Option>, emitted_message_handler: Option>, max_iterations: usize, } #[derive(Debug, Clone)] pub struct AgentProcessResult { pub final_response: ChatMessage, pub emitted_messages: Vec, } #[async_trait] pub trait EmittedMessageHandler: Send + Sync + 'static { async fn handle(&self, message: ChatMessage); } impl AgentLoop { pub fn new(provider_config: LLMProviderConfig) -> Result { let max_iterations = provider_config.max_tool_iterations; let provider = create_provider(provider_config.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { provider_config, provider, tools: Arc::new(ToolRegistry::new()), skills: Arc::new(SkillRuntime::default()), skill_event_store: None, skill_event_session_id: None, tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, max_iterations, }) } pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { let max_iterations = provider_config.max_tool_iterations; let provider = create_provider(provider_config.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { provider_config, provider, tools, skills: Arc::new(SkillRuntime::default()), skill_event_store: None, skill_event_session_id: None, tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, max_iterations, }) } pub fn with_tools_and_skills( provider_config: LLMProviderConfig, tools: Arc, skills: Arc, ) -> Result { let max_iterations = provider_config.max_tool_iterations; let provider = create_provider(provider_config.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { provider_config, provider, tools, skills, skill_event_store: None, skill_event_session_id: None, tool_context: ToolContext::default(), observer: None, emitted_message_handler: 
None, max_iterations, }) } pub fn with_skill_event_store(mut self, store: Arc, session_id: String) -> Self { self.skill_event_store = Some(store); self.skill_event_session_id = Some(session_id); self } pub fn with_tool_context(mut self, context: ToolContext) -> Self { self.tool_context = context; self } /// Set an observer for tracking events. pub fn with_observer(mut self, observer: Arc) -> Self { self.observer = Some(observer); self } pub fn with_emitted_message_handler(mut self, handler: Arc) -> Self { self.emitted_message_handler = Some(handler); self } pub fn tools(&self) -> &Arc { &self.tools } /// Process a message using the provided conversation history. /// History management is handled externally by SessionManager. /// /// This method supports multi-round tool calling: after executing tools, /// it loops back to the LLM with the tool results until either: /// - The LLM returns no more tool calls (final response) /// - Maximum iterations are reached pub async fn process(&self, mut messages: Vec) -> Result { #[cfg(debug_assertions)] tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); // Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut emitted_messages = Vec::new(); for iteration in 0..self.max_iterations { #[cfg(debug_assertions)] tracing::debug!(iteration, "Agent iteration started"); // Convert messages to LLM format let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 1); if let Some(skill_prompt) = self.skills.system_index_prompt() { messages_for_llm.push(Message::system(skill_prompt)); } messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT)); messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message)); // Build request let mut tool_defs = self.tools.get_definitions(); if let Some(skill_tool) = self.skills.skill_tool_definition() { tool_defs.push(skill_tool); } let tools = if 
tool_defs.is_empty() { None } else { Some(tool_defs) }; let request = ChatCompletionRequest { messages: messages_for_llm, temperature: None, max_tokens: None, tools, }; // Call LLM let response = match (*self.provider).chat(request).await { Ok(response) => response, Err(e) => { tracing::error!( provider = %self.provider.name(), model = %self.provider.model_id(), error = %e, error_details = %format_error_chain(e.as_ref()), "LLM request failed" ); let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }); } }; #[cfg(debug_assertions)] tracing::debug!( iteration, response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received" ); // If no tool calls, this is the final response if response.tool_calls.is_empty() { let assistant_message = if let Some(reasoning_content) = response.reasoning_content { ChatMessage::assistant_with_reasoning(response.content, reasoning_content) } else { ChatMessage::assistant(response.content) }; emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }); } // Execute tool calls tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools"); // Add assistant message with tool calls let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() { ChatMessage::assistant_with_tool_calls_and_reasoning( response.content.clone(), response.tool_calls.clone(), reasoning_content, ) } else { ChatMessage::assistant_with_tool_calls( response.content.clone(), response.tool_calls.clone(), ) }; messages.push(assistant_message.clone()); emitted_messages.push(assistant_message); self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await; // Execute tools and 
add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { // Log function call with name and arguments let args_str = match &tool_call.arguments { serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(), other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()), }; tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool"); // Truncate tool result if too large let truncated_output = truncate_tool_result( &result.output, self.provider_config.tool_result_max_chars, ); // Record tool call and check for loops let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments); match loop_result { LoopDetectionResult::Warning(msg) => { // Add warning and proceed tracing::warn!( tool = %tool_call.name, "Loop warning: {}", msg ); let tool_message = ChatMessage::tool_with_state( tool_call.id.clone(), tool_call.name.clone(), format!("{}\n\n[上一条结果]\n{}", msg, truncated_output), if result.state == ToolExecutionState::PendingUserAction { ToolMessageState::PendingUserAction } else { ToolMessageState::Completed }, ); messages.push(tool_message.clone()); emitted_messages.push(tool_message); } LoopDetectionResult::Ok => { let tool_message = ChatMessage::tool_with_state( tool_call.id.clone(), tool_call.name.clone(), truncated_output, if result.state == ToolExecutionState::PendingUserAction { ToolMessageState::PendingUserAction } else { ToolMessageState::Completed }, ); messages.push(tool_message.clone()); emitted_messages.push(tool_message); } } } if let Some((tool_call, pending_result)) = response .tool_calls .iter() .zip(tool_results.iter()) .find(|(_, result)| result.state == ToolExecutionState::PendingUserAction) { let assistant_message = ChatMessage::assistant(format!( "{}\n\n当前等待中的工具: {}", pending_result .output .lines() .next() .filter(|line| !line.trim().is_empty()) 
.unwrap_or(DEFAULT_PENDING_ASSISTANT_MESSAGE), tool_call.name, )); emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }); } // Loop continues to next iteration with updated messages #[cfg(debug_assertions)] tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration"); } // Max iterations reached - ask LLM for a summary based on completed work tracing::warn!("Max iterations reached, requesting final summary from LLM"); // Add a message asking for summary let summary_request = ChatMessage::user( "You have reached the maximum number of tool call iterations. \ Please provide your best answer based on the work completed so far." ); messages.push(summary_request); // Convert messages to LLM format let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 1); if let Some(skill_prompt) = self.skills.system_index_prompt() { messages_for_llm.push(Message::system(skill_prompt)); } messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message)); let request = ChatCompletionRequest { messages: messages_for_llm, temperature: None, max_tokens: None, tools: None, // No tools in final summary call }; match (*self.provider).chat(request).await { Ok(response) => { let assistant_message = if let Some(reasoning_content) = response.reasoning_content { ChatMessage::assistant_with_reasoning(response.content, reasoning_content) } else { ChatMessage::assistant(response.content) }; emitted_messages.push(assistant_message.clone()); Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, }) } Err(e) => { tracing::error!( provider = %self.provider.name(), model = %self.provider.model_id(), error = %e, error_details = %format_error_chain(e.as_ref()), "Failed to get summary from LLM" ); let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); emitted_messages.push(final_message.clone()); 
Ok(AgentProcessResult { final_response: final_message, emitted_messages, }) } } } async fn emit_live_tool_call_message(&self, message: ChatMessage) { if !message.is_assistant_tool_call_message() { return; } if let Some(handler) = &self.emitted_message_handler { handler.handle(message).await; } } /// Determine whether to execute tools in parallel or sequentially. /// /// Returns true if: /// - There are multiple tool calls /// - None of the tools require sequential execution (tool_search, non-concurrency-safe) fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool { if tool_calls.len() <= 1 { return false; } // tool_search must run sequentially to avoid MCP activation race conditions if tool_calls.iter().any(|tc| tc.name == "tool_search") { return false; } // All tools must be concurrency-safe to run in parallel tool_calls.iter().all(|tc| { self.tools .get(&tc.name) .map(|t| t.concurrency_safe()) .unwrap_or(false) }) } /// Execute multiple tool calls, choosing parallel or sequential based on conditions. async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec { if self.should_execute_in_parallel(tool_calls) { tracing::debug!("Executing {} tools in parallel", tool_calls.len()); self.execute_tools_parallel(tool_calls).await } else { tracing::debug!("Executing {} tools sequentially", tool_calls.len()); self.execute_tools_sequential(tool_calls).await } } /// Execute tools in parallel using join_all. async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec { let futures: Vec<_> = tool_calls .iter() .map(|tc| self.execute_one_tool(tc)) .collect(); futures_util::future::join_all(futures).await } /// Execute tools sequentially. async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec { let mut outcomes = Vec::with_capacity(tool_calls.len()); for tool_call in tool_calls { outcomes.push(self.execute_one_tool(tool_call).await); } outcomes } /// Execute a single tool and return the outcome with event tracking. 
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome { let start = Instant::now(); let tool_name = tool_call.name.clone(); // Record ToolCallStart event if let Some(ref observer) = self.observer { observer.record_event(&ObserverEvent::ToolCallStart { tool: tool_name.clone(), arguments: Some(truncate_args(&tool_call.arguments, 300)), }); } let result = self.execute_tool_internal(tool_call).await; let duration = start.elapsed(); // Record ToolCall event if let Some(ref observer) = self.observer { observer.record_event(&ObserverEvent::ToolCall { tool: tool_name.clone(), duration, success: result.success, }); } // Apply duration ToolExecutionOutcome { duration, ..result } } /// Internal tool execution without event tracking. async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome { let normalized_arguments = normalize_tool_arguments(&tool_call.arguments); if tool_call.name == "skill_activate" { let skill_name = match normalized_arguments.get("name").and_then(|v| v.as_str()) { Some(name) if !name.trim().is_empty() => name, _ => { self.record_skill_event( "activation_failed", None, serde_json::json!({ "reason": "missing_name", "arguments": normalized_arguments, }), ); return ToolExecutionOutcome::failure( "Error: Missing required parameter: name".to_string(), Some("Missing required parameter: name".to_string()), ); } }; return match self.skills.activation_payload(skill_name) { Ok(output) => { if let Ok(payload) = self.skills.activation_event_payload(skill_name) { self.record_skill_event("activated", Some(skill_name), payload); } ToolExecutionOutcome::success(output) } Err(err) => { self.record_skill_event( "activation_failed", Some(skill_name), serde_json::json!({ "reason": err, "arguments": normalized_arguments, }), ); ToolExecutionOutcome::failure( format!("Error: {}", err), Some(err), ) } }; } let tool = match self.tools.get(&tool_call.name) { Some(t) => t, None => { tracing::warn!(tool = %tool_call.name, "Tool not 
found"); return ToolExecutionOutcome::failure( format!("Error: Tool '{}' not found", tool_call.name), Some(format!("Tool '{}' not found", tool_call.name)), ); } }; match tool.execute_with_context(&self.tool_context, normalized_arguments.clone()).await { Ok(result) => { if result.success { if let Some(pending_output) = parse_pending_tool_output(&result.output) { ToolExecutionOutcome::pending(pending_output) } else { ToolExecutionOutcome::success(result.output) } } else { let error = result.error.unwrap_or_default(); tracing::error!( tool = %tool_call.name, args = %truncate_args(&tool_call.arguments, 2_000), normalized_args = %truncate_args(&normalized_arguments, 2_000), error = %error, output = %result.output, "Tool returned an error result" ); ToolExecutionOutcome::failure( format!("Error: {}", error), Some(error), ) } } Err(e) => { tracing::error!( tool = %tool_call.name, args = %truncate_args(&tool_call.arguments, 2_000), normalized_args = %truncate_args(&normalized_arguments, 2_000), error = %e, error_details = %format!("{:#}", e), "Tool execution failed" ); ToolExecutionOutcome::failure( format!("Error: {}", e), Some(e.to_string()), ) } } } fn record_skill_event( &self, event_type: &str, skill_name: Option<&str>, payload: serde_json::Value, ) { let (Some(store), Some(session_id)) = ( self.skill_event_store.as_ref(), self.skill_event_session_id.as_ref(), ) else { return; }; if let Err(err) = store.append_skill_event(Some(session_id), event_type, skill_name, &payload) { tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event"); } } } #[cfg(test)] mod tests { use super::*; use crate::observability::{MultiObserver, Observer}; use tempfile::tempdir; struct TestObserver { events: std::sync::Mutex>, } impl TestObserver { fn new() -> Self { Self { events: std::sync::Mutex::new(Vec::new()), } } } impl Observer for TestObserver { fn record_event(&self, event: &ObserverEvent) { self.events.lock().unwrap().push(event.clone()); } fn name(&self) 
-> &str { "test_observer" } } #[tokio::test] async fn test_observer_receives_tool_events() { // Verify MultiObserver works let mut multi = MultiObserver::new(); multi.add_observer(Box::new(TestObserver::new())); let event = ObserverEvent::ToolCallStart { tool: "test".to_string(), arguments: Some("{}".to_string()), }; multi.record_event(&event); // Just verify the structure works assert_eq!(multi.len(), 1); } #[test] fn test_should_execute_in_parallel_single_tool() { // Would need a proper setup with AgentLoop to test fully // For now, just verify the logic: single tool should return false let calls = vec![ToolCall { id: "1".to_string(), name: "test".to_string(), arguments: serde_json::json!({}), }]; // If there's only 1 tool, should return false regardless assert_eq!(calls.len() <= 1, true); } #[test] fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() { let chat_message = ChatMessage::assistant_with_tool_calls( "calling tool", vec![ToolCall { id: "call_1".to_string(), name: "calculator".to_string(), arguments: serde_json::json!({ "expression": "2+2" }), }], ); let provider_message = chat_message_to_llm_message(&chat_message); assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1"); assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); } #[test] fn test_chat_message_to_llm_message_preserves_reasoning_content() { let chat_message = ChatMessage::assistant_with_reasoning( "final answer", "hidden chain of thought", ); let provider_message = chat_message_to_llm_message(&chat_message); assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought")); } #[test] fn test_memory_prompt_requires_proactive_memory_search() { assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时")); 
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("先使用长期记忆检索工具 memory_search")); assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索")); } #[test] fn test_memory_prompt_allows_parallel_independent_tool_calls() { assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("同一轮同时返回多个 tool calls")); assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("自动并行执行")); assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("只有当后一个工具的参数依赖")); } #[test] fn test_truncate_tool_result_handles_utf8_char_boundaries() { let input = "范".repeat(20_500); let output = truncate_tool_result(&input, 20_000); assert!(output.contains("Output truncated")); assert!(output.is_char_boundary(output.len())); } #[test] fn test_parse_pending_tool_output() { let output = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权"); assert_eq!(output.as_deref(), Some("请完成授权")); assert!(parse_pending_tool_output("normal output").is_none()); } #[test] fn test_normalize_tool_arguments_parses_stringified_json() { let normalized = normalize_tool_arguments(&serde_json::Value::String( "{\"command\":\"ls -la\"}".to_string(), )); assert_eq!(normalized, serde_json::json!({ "command": "ls -la" })); } #[test] fn test_normalize_tool_arguments_keeps_plain_string() { let normalized = normalize_tool_arguments(&serde_json::Value::String("plain text".to_string())); assert_eq!(normalized, serde_json::Value::String("plain text".to_string())); } #[test] fn test_build_content_blocks_skips_non_image_media_refs() { let temp_dir = tempdir().unwrap(); let pdf_path = temp_dir.path().join("demo.pdf"); std::fs::write(&pdf_path, b"%PDF-1.4").unwrap(); let blocks = build_content_blocks("hello", &[pdf_path.to_string_lossy().to_string()]); assert_eq!(blocks.len(), 1); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); } #[test] fn test_build_content_blocks_keeps_supported_images() { let temp_dir = tempdir().unwrap(); let jpg_path = temp_dir.path().join("demo.jpg"); std::fs::write(&jpg_path, 
b"fake-jpeg-data").unwrap(); let blocks = build_content_blocks("hello", &[jpg_path.to_string_lossy().to_string()]); assert_eq!(blocks.len(), 2); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); assert!(matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))); } } #[derive(Debug)] pub enum AgentError { ProviderCreation(String), LlmError(String), Other(String), } impl std::fmt::Display for AgentError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e), AgentError::LlmError(e) => write!(f, "LLM error: {}", e), AgentError::Other(e) => write!(f, "{}", e), } } } impl std::error::Error for AgentError {}