feat(agent): add loop detection and result truncation for tool calls
This commit is contained in:
parent
3d72f3dfa8
commit
fb0a9e06aa
@ -6,10 +6,18 @@ use crate::observability::{
|
|||||||
};
|
};
|
||||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
|
/// Maximum characters in a tool result before truncation.
|
||||||
|
/// Prevents context overflow from large tool outputs.
|
||||||
|
const MAX_TOOL_RESULT_CHARS: usize = 16_000;
|
||||||
|
/// Minimum characters to keep when truncating
|
||||||
|
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
||||||
|
|
||||||
/// Build content blocks from text and media paths
|
/// Build content blocks from text and media paths
|
||||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||||
let mut blocks = Vec::new();
|
let mut blocks = Vec::new();
|
||||||
@ -51,6 +59,145 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error
|
|||||||
Ok((mime, encoded))
|
Ok((mime, encoded))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Truncate tool result if it exceeds MAX_TOOL_RESULT_CHARS.
|
||||||
|
/// Preserves the end of the output as it often contains the conclusion/useful result.
|
||||||
|
fn truncate_tool_result(output: &str) -> String {
|
||||||
|
if output.len() <= MAX_TOOL_RESULT_CHARS {
|
||||||
|
return output.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let truncated_start_len = output.len().saturating_sub(TRUNCATION_SUFFIX_LEN);
|
||||||
|
if truncated_start_len > MAX_TOOL_RESULT_CHARS {
|
||||||
|
// Even after removing suffix, still too long - take from beginning
|
||||||
|
format!(
|
||||||
|
"{}...\n\n[Output truncated - {} characters removed]",
|
||||||
|
&output[..MAX_TOOL_RESULT_CHARS - 100],
|
||||||
|
output.len() - MAX_TOOL_RESULT_CHARS + 100
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Keep most of the end which usually contains the useful result
|
||||||
|
format!(
|
||||||
|
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
||||||
|
truncated_start_len,
|
||||||
|
&output[truncated_start_len..]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Loop detection result.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
enum LoopDetectionResult {
|
||||||
|
/// No warning needed.
|
||||||
|
Ok,
|
||||||
|
/// Warning: same tool + args repeated N times.
|
||||||
|
Warning(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for loop detector.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct LoopDetectorConfig {
|
||||||
|
/// Master switch.
|
||||||
|
enabled: bool,
|
||||||
|
/// Warn every N consecutive identical calls.
|
||||||
|
warn_every: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for LoopDetectorConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: true,
|
||||||
|
warn_every: 5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single recorded tool invocation in the sliding window.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct ToolCallRecord {
|
||||||
|
name: String,
|
||||||
|
args_hash: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stateful loop detector that monitors for repetitive patterns.
|
||||||
|
struct LoopDetector {
|
||||||
|
config: LoopDetectorConfig,
|
||||||
|
window: VecDeque<ToolCallRecord>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LoopDetector {
|
||||||
|
fn new(config: LoopDetectorConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
window: VecDeque::with_capacity(config.warn_every * 2),
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record a completed tool call and check for loop patterns.
|
||||||
|
/// Returns Warning every `warn_every` consecutive identical calls.
|
||||||
|
fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult {
|
||||||
|
if !self.config.enabled {
|
||||||
|
return LoopDetectionResult::Ok;
|
||||||
|
}
|
||||||
|
|
||||||
|
let record = ToolCallRecord {
|
||||||
|
name: name.to_string(),
|
||||||
|
args_hash: hash_json_value(args),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Maintain sliding window
|
||||||
|
if self.window.len() >= self.config.warn_every * 2 {
|
||||||
|
self.window.pop_front();
|
||||||
|
}
|
||||||
|
self.window.push_back(record);
|
||||||
|
|
||||||
|
// Count consecutive identical calls
|
||||||
|
let last = self.window.back().unwrap();
|
||||||
|
let consecutive: usize = self
|
||||||
|
.window
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.take_while(|r| r.name == last.name && r.args_hash == last.args_hash)
|
||||||
|
.count();
|
||||||
|
|
||||||
|
// Warn every warn_every times
|
||||||
|
if consecutive > 0 && consecutive % self.config.warn_every == 0 {
|
||||||
|
LoopDetectionResult::Warning(format!(
|
||||||
|
"注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。",
|
||||||
|
last.name, consecutive
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
LoopDetectionResult::Ok
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hash a JSON value deterministically (key-order independent).
|
||||||
|
fn hash_json_value(value: &serde_json::Value) -> u64 {
|
||||||
|
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||||
|
let canonical = canonicalise_json(value);
|
||||||
|
canonical.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a clone of value with all object keys sorted recursively.
|
||||||
|
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
|
||||||
|
match value {
|
||||||
|
serde_json::Value::Object(map) => {
|
||||||
|
let mut sorted: Vec<(&String, &serde_json::Value)> = map.iter().collect();
|
||||||
|
sorted.sort_by_key(|(k, _)| *k);
|
||||||
|
let new_map: serde_json::Map<String, serde_json::Value> = sorted
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| (k.clone(), canonicalise_json(v)))
|
||||||
|
.collect();
|
||||||
|
serde_json::Value::Object(new_map)
|
||||||
|
}
|
||||||
|
serde_json::Value::Array(arr) => {
|
||||||
|
serde_json::Value::Array(arr.iter().map(canonicalise_json).collect())
|
||||||
|
}
|
||||||
|
other => other.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert ChatMessage to LLM Message format
|
/// Convert ChatMessage to LLM Message format
|
||||||
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
||||||
let content = if m.media_refs.is_empty() {
|
let content = if m.media_refs.is_empty() {
|
||||||
@ -124,6 +271,9 @@ impl AgentLoop {
|
|||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
||||||
|
|
||||||
|
// Track tool calls for loop detection
|
||||||
|
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||||||
|
|
||||||
for iteration in 0..self.max_iterations {
|
for iteration in 0..self.max_iterations {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(iteration, "Agent iteration started");
|
tracing::debug!(iteration, "Agent iteration started");
|
||||||
@ -180,25 +330,81 @@ impl AgentLoop {
|
|||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||||
|
|
||||||
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
||||||
|
// Truncate tool result if too large
|
||||||
|
let truncated_output = truncate_tool_result(&result.output);
|
||||||
|
|
||||||
|
// Record tool call and check for loops
|
||||||
|
let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments);
|
||||||
|
|
||||||
|
match loop_result {
|
||||||
|
LoopDetectionResult::Warning(msg) => {
|
||||||
|
// Add warning and proceed
|
||||||
|
tracing::warn!(
|
||||||
|
tool = %tool_call.name,
|
||||||
|
"Loop warning: {}",
|
||||||
|
msg
|
||||||
|
);
|
||||||
let tool_message = ChatMessage::tool(
|
let tool_message = ChatMessage::tool(
|
||||||
tool_call.id.clone(),
|
tool_call.id.clone(),
|
||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
result.output.clone(),
|
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
||||||
);
|
);
|
||||||
messages.push(tool_message);
|
messages.push(tool_message);
|
||||||
}
|
}
|
||||||
|
LoopDetectionResult::Ok => {
|
||||||
|
let tool_message = ChatMessage::tool(
|
||||||
|
tool_call.id.clone(),
|
||||||
|
tool_call.name.clone(),
|
||||||
|
truncated_output,
|
||||||
|
);
|
||||||
|
messages.push(tool_message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Loop continues to next iteration with updated messages
|
// Loop continues to next iteration with updated messages
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max iterations reached
|
// Max iterations reached - ask LLM for a summary based on completed work
|
||||||
|
tracing::warn!("Max iterations reached, requesting final summary from LLM");
|
||||||
|
|
||||||
|
// Add a message asking for summary
|
||||||
|
let summary_request = ChatMessage::user(
|
||||||
|
"You have reached the maximum number of tool call iterations. \
|
||||||
|
Please provide your best answer based on the work completed so far."
|
||||||
|
);
|
||||||
|
messages.push(summary_request);
|
||||||
|
|
||||||
|
// Convert messages to LLM format
|
||||||
|
let messages_for_llm: Vec<Message> = messages
|
||||||
|
.iter()
|
||||||
|
.map(chat_message_to_llm_message)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: messages_for_llm,
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: None,
|
||||||
|
tools: None, // No tools in final summary call
|
||||||
|
};
|
||||||
|
|
||||||
|
match (*self.provider).chat(request).await {
|
||||||
|
Ok(response) => {
|
||||||
|
let assistant_message = ChatMessage::assistant(response.content);
|
||||||
|
Ok(assistant_message)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Fallback if summary call fails
|
||||||
|
tracing::error!(error = %e, "Failed to get summary from LLM");
|
||||||
let final_message = ChatMessage::assistant(
|
let final_message = ChatMessage::assistant(
|
||||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. Please try breaking the task into smaller steps.", self.max_iterations)
|
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
||||||
);
|
);
|
||||||
Ok(final_message)
|
Ok(final_message)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Determine whether to execute tools in parallel or sequentially.
|
/// Determine whether to execute tools in parallel or sequentially.
|
||||||
///
|
///
|
||||||
|
|||||||
@ -30,6 +30,10 @@ impl Tool for CalculatorTool {
|
|||||||
Use this tool whenever you need to compute a numeric result instead of guessing."
|
Use this tool whenever you need to compute a numeric result instead of guessing."
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
json!({
|
json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user