From fb0a9e06aa4136c8c42e3945a46606d680b3f6a2 Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Sun, 12 Apr 2026 13:18:16 +0800 Subject: [PATCH] feat(agent): add loop detection and result truncation for tool calls --- src/agent/agent_loop.rs | 226 ++++++++++++++++++++++++++++++++++++++-- src/tools/calculator.rs | 4 + 2 files changed, 220 insertions(+), 10 deletions(-) diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 86c43a0..60be481 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -6,10 +6,18 @@ use crate::observability::{ }; use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; use crate::tools::ToolRegistry; +use std::collections::VecDeque; +use std::hash::{Hash, Hasher}; use std::io::Read; use std::sync::Arc; use std::time::Instant; +/// Maximum characters in a tool result before truncation. +/// Prevents context overflow from large tool outputs. +const MAX_TOOL_RESULT_CHARS: usize = 16_000; +/// Minimum characters to keep when truncating +const TRUNCATION_SUFFIX_LEN: usize = 200; + /// Build content blocks from text and media paths fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { let mut blocks = Vec::new(); @@ -51,6 +59,145 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error Ok((mime, encoded)) } +/// Truncate tool result if it exceeds MAX_TOOL_RESULT_CHARS. +/// Preserves the end of the output as it often contains the conclusion/useful result. +fn truncate_tool_result(output: &str) -> String { + if output.len() <= MAX_TOOL_RESULT_CHARS { + return output.to_string(); + } + + let truncated_start_len = output.len().saturating_sub(TRUNCATION_SUFFIX_LEN); + if truncated_start_len > MAX_TOOL_RESULT_CHARS { + // Even after removing suffix, still too long - take from beginning + format!( + "{}...\n\n[Output truncated - {} characters removed]", + &output[..MAX_TOOL_RESULT_CHARS - 100], + output.len() - MAX_TOOL_RESULT_CHARS + 100 + ) + } else { + // Keep most of the end which usually contains the useful result + format!( + "...\n\n[Output truncated - {} characters removed]\n\n{}", + truncated_start_len, + &output[truncated_start_len..] + ) + } +} + +/// Loop detection result. +#[derive(Debug, Clone, PartialEq, Eq)] +enum LoopDetectionResult { + /// No warning needed. + Ok, + /// Warning: same tool + args repeated N times. + Warning(String), +} + +/// Configuration for loop detector. +#[derive(Debug, Clone)] +struct LoopDetectorConfig { + /// Master switch. + enabled: bool, + /// Warn every N consecutive identical calls. + warn_every: usize, +} + +impl Default for LoopDetectorConfig { + fn default() -> Self { + Self { + enabled: true, + warn_every: 5, + } + } +} + +/// A single recorded tool invocation in the sliding window. +#[derive(Debug, Clone)] +struct ToolCallRecord { + name: String, + args_hash: u64, +} + +/// Stateful loop detector that monitors for repetitive patterns. +struct LoopDetector { + config: LoopDetectorConfig, + window: VecDeque, +} + +impl LoopDetector { + fn new(config: LoopDetectorConfig) -> Self { + Self { + window: VecDeque::with_capacity(config.warn_every * 2), + config, + } + } + + /// Record a completed tool call and check for loop patterns. + /// Returns Warning every `warn_every` consecutive identical calls. + fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult { + if !self.config.enabled { + return LoopDetectionResult::Ok; + } + + let record = ToolCallRecord { + name: name.to_string(), + args_hash: hash_json_value(args), + }; + + // Maintain sliding window + if self.window.len() >= self.config.warn_every * 2 { + self.window.pop_front(); + } + self.window.push_back(record); + + // Count consecutive identical calls + let last = self.window.back().unwrap(); + let consecutive: usize = self + .window + .iter() + .rev() + .take_while(|r| r.name == last.name && r.args_hash == last.args_hash) + .count(); + + // Warn every warn_every times + if consecutive > 0 && consecutive % self.config.warn_every == 0 { + LoopDetectionResult::Warning(format!( + "注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。", + last.name, consecutive + )) + } else { + LoopDetectionResult::Ok + } + } +} + +/// Hash a JSON value deterministically (key-order independent). +fn hash_json_value(value: &serde_json::Value) -> u64 { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + let canonical = canonicalise_json(value); + canonical.hash(&mut hasher); + hasher.finish() +} + +/// Return a clone of value with all object keys sorted recursively. +fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value { + match value { + serde_json::Value::Object(map) => { + let mut sorted: Vec<(&String, &serde_json::Value)> = map.iter().collect(); + sorted.sort_by_key(|(k, _)| *k); + let new_map: serde_json::Map = sorted + .into_iter() + .map(|(k, v)| (k.clone(), canonicalise_json(v))) + .collect(); + serde_json::Value::Object(new_map) + } + serde_json::Value::Array(arr) => { + serde_json::Value::Array(arr.iter().map(canonicalise_json).collect()) + } + other => other.clone(), + } +} + /// Convert ChatMessage to LLM Message format fn chat_message_to_llm_message(m: &ChatMessage) -> Message { let content = if m.media_refs.is_empty() { @@ -124,6 +271,9 @@ impl AgentLoop { #[cfg(debug_assertions)] tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); + // Track tool calls for loop detection + let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); + for iteration in 0..self.max_iterations { #[cfg(debug_assertions)] tracing::debug!(iteration, "Agent iteration started"); @@ -180,12 +330,36 @@ impl AgentLoop { let tool_results = self.execute_tools(&response.tool_calls).await; for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { - let tool_message = ChatMessage::tool( - tool_call.id.clone(), - tool_call.name.clone(), - result.output.clone(), - ); - messages.push(tool_message); + // Truncate tool result if too large + let truncated_output = truncate_tool_result(&result.output); + + // Record tool call and check for loops + let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments); + + match loop_result { + LoopDetectionResult::Warning(msg) => { + // Add warning and proceed + tracing::warn!( + tool = %tool_call.name, + "Loop warning: {}", + msg + ); + let tool_message = ChatMessage::tool( + tool_call.id.clone(), + tool_call.name.clone(), + format!("{}\n\n[上一条结果]\n{}", msg, truncated_output), + ); + messages.push(tool_message); + } + LoopDetectionResult::Ok => { + let tool_message = ChatMessage::tool( + tool_call.id.clone(), + tool_call.name.clone(), + truncated_output, + ); + messages.push(tool_message); + } + } } // Loop continues to next iteration with updated messages @@ -193,11 +367,43 @@ impl AgentLoop { tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration"); } - // Max iterations reached - let final_message = ChatMessage::assistant( - format!("I reached the maximum number of tool call iterations ({}) without completing the task. Please try breaking the task into smaller steps.", self.max_iterations) + // Max iterations reached - ask LLM for a summary based on completed work + tracing::warn!("Max iterations reached, requesting final summary from LLM"); + + // Add a message asking for summary + let summary_request = ChatMessage::user( + "You have reached the maximum number of tool call iterations. \ + Please provide your best answer based on the work completed so far." ); - Ok(final_message) + messages.push(summary_request); + + // Convert messages to LLM format + let messages_for_llm: Vec = messages + .iter() + .map(chat_message_to_llm_message) + .collect(); + + let request = ChatCompletionRequest { + messages: messages_for_llm, + temperature: None, + max_tokens: None, + tools: None, // No tools in final summary call + }; + + match (*self.provider).chat(request).await { + Ok(response) => { + let assistant_message = ChatMessage::assistant(response.content); + Ok(assistant_message) + } + Err(e) => { + // Fallback if summary call fails + tracing::error!(error = %e, "Failed to get summary from LLM"); + let final_message = ChatMessage::assistant( + format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations) + ); + Ok(final_message) + } + } } /// Determine whether to execute tools in parallel or sequentially. diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index de29b73..cff93f9 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -30,6 +30,10 @@ impl Tool for CalculatorTool { Use this tool whenever you need to compute a numeric result instead of guessing." } + fn read_only(&self) -> bool { + true + } + fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object",