PicoBot/src/agent/agent_loop.rs

1377 lines
48 KiB
Rust

use crate::agent::AgentRuntimeConfig;
use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::domain::messages::{ContentBlock, ToolCall};
use crate::observability::{
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
use crate::tools::{ToolContext, ToolRegistry};
use async_trait::async_trait;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use std::io::Read;
use std::sync::Arc;
use std::time::Instant;
/// Character count used by `truncate_tool_result` as slack above the configured limit
/// when it decides how (and how much of the end of) an oversized tool output to keep.
const TRUNCATION_SUFFIX_LEN: usize = 200;
/// Prefix a tool puts at the very start of its output to signal that it is waiting for
/// the user; `parse_pending_tool_output` strips it and trims the remainder.
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
/// Assistant reply used when a pending tool's output has no non-blank first line.
/// (Roughly: "The tool has started and is waiting for you. Finish the external step,
/// then tell me to continue.")
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
"工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
/// Reply used when `is_recoverable_llm_error` classifies an LLM failure as transient.
/// (Roughly: "The model service is temporarily unavailable or timed out. Please retry.")
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
/// MIME types (guessed from the file extension) that may be inlined as image blocks.
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
/// Heuristic: serialized bytes (or base64 characters, for images) per token.
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
/// Multiplier applied to serialized-text token estimates so they err on the high side.
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
/// Fraction of the input window (context window minus completion reserve) that a
/// request is allowed to fill.
const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9;
/// Tokens reserved for the completion when the provider config sets no `max_tokens`.
const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048;
/// Tokens subtracted from each image's budget before sizing its base64 payload;
/// presumably covers the `data:<mime>;base64,` prefix — TODO confirm.
const DATA_URL_OVERHEAD_TOKENS: usize = 16;
/// JPEG qualities tried, highest first, when re-encoding an over-budget image.
const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36];
/// Lower bound, in pixels, on the longest side when downscaling during compression.
const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64;
/// Header of the text block listing images that could not be inlined.
/// (Roughly: "[System notice] The following images could not be passed to the model:")
const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:";
/// Build the content blocks for a message that has media attachments.
///
/// Entry point kept for call-site readability; it forwards to
/// `build_content_blocks_with_image_budget`, which consumes `budget` for every
/// supported image it encounters.
fn build_content_blocks(
    text: &str,
    media_paths: &[String],
    budget: &mut ImageInlineBudget,
) -> Vec<ContentBlock> {
    build_content_blocks_with_image_budget(text, media_paths, &mut *budget)
}
fn build_content_blocks_with_image_budget(
text: &str,
media_paths: &[String],
budget: &mut ImageInlineBudget,
) -> Vec<ContentBlock> {
let mut blocks = Vec::new();
let mut skipped_image_notices = Vec::new();
// Add text block if there's text
if !text.is_empty() {
blocks.push(ContentBlock::text(text));
}
// Add image blocks for media paths
for path in media_paths {
if supported_image_mime_type(path).is_none() {
tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block");
continue;
}
let Some(target_tokens) = budget.take_next_image_tokens() else {
tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains");
skipped_image_notices.push(format!(
"- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
display_media_name(path)
));
continue;
};
match encode_image_to_base64_with_budget(path, target_tokens) {
Ok((mime_type, base64_data)) => {
let url = format!("data:{};base64,{}", mime_type, base64_data);
blocks.push(ContentBlock::image_url(url));
}
Err(err) => {
tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed");
skipped_image_notices.push(format!(
"- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
display_media_name(path)
));
continue;
}
}
}
if !skipped_image_notices.is_empty() {
let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX);
notice.push('\n');
notice.push_str(&skipped_image_notices.join("\n"));
blocks.push(ContentBlock::text(notice));
}
// If nothing, add empty text block
if blocks.is_empty() {
blocks.push(ContentBlock::text(""));
}
blocks
}
/// Return the MIME type guessed from `path`'s extension when it is listed in
/// `SUPPORTED_IMAGE_MIME_TYPES`, otherwise `None`.
///
/// Unknown extensions fall back to `application/octet-stream`, which is not listed and
/// therefore yields `None`.
fn supported_image_mime_type(path: &str) -> Option<String> {
    let guessed = mime_guess::from_path(path).first_or_octet_stream();
    SUPPORTED_IMAGE_MIME_TYPES
        .iter()
        .copied()
        .find(|&candidate| candidate == guessed.essence_str())
        .map(String::from)
}
/// Short, user-facing name for a media path.
///
/// Returns the final path component when there is one and it is valid UTF-8, otherwise
/// the full path unchanged (e.g. for `/` or paths ending in `..`).
fn display_media_name(path: &str) -> String {
    let file_name = std::path::Path::new(path)
        .file_name()
        .and_then(std::ffi::OsStr::to_str);
    match file_name {
        Some(name) => name.to_owned(),
        None => path.to_owned(),
    }
}
/// Token budget for inline images in one request, split across the expected images.
#[derive(Debug, Clone)]
struct ImageInlineBudget {
    /// Tokens not yet handed out.
    remaining_tokens: usize,
    /// Images that may still ask for a share.
    remaining_images: usize,
}
impl ImageInlineBudget {
    /// Create a budget of `total_tokens` to be shared by `image_count` images.
    fn new(total_tokens: usize, image_count: usize) -> Self {
        Self {
            remaining_images: image_count,
            remaining_tokens: total_tokens,
        }
    }

    /// Hand out the next image's share: the remaining tokens divided evenly (rounding
    /// up) among the remaining images.
    ///
    /// Every call uses up one image slot, even when it returns `None` because no tokens
    /// or no image slots are left.
    fn take_next_image_tokens(&mut self) -> Option<usize> {
        let images_left = self.remaining_images;
        self.remaining_images = images_left.saturating_sub(1);
        if images_left == 0 || self.remaining_tokens == 0 {
            return None;
        }
        let share = self.remaining_tokens.div_ceil(images_left);
        self.remaining_tokens = self.remaining_tokens.saturating_sub(share);
        Some(share)
    }
}
/// Rough token estimate for `value`, based on the byte length of its JSON serialization.
///
/// The length is divided by `TOKEN_ESTIMATE_CHARS_PER_TOKEN` (rounding up), then scaled
/// by `TOKEN_ESTIMATE_SAFETY_MULTIPLIER`. A value that fails to serialize counts as 0.
fn estimate_tokens_from_serialized_json<T: serde::Serialize>(value: &T) -> usize {
    let serialized_len = match serde_json::to_string(value) {
        Ok(json) => json.len(),
        Err(_) => 0,
    };
    let raw_tokens = serialized_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64;
    (raw_tokens * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
}
/// Estimate how many tokens of a request remain available for inline images.
///
/// The calculation runs in four steps:
/// 1. Start from the configured context window.
/// 2. Subtract the completion reserve: `provider.max_tokens`, or
///    `DEFAULT_COMPLETION_TOKEN_RESERVE` when that is unset.
/// 3. Keep only `CONTEXT_INPUT_SAFETY_RATIO` of what is left.
/// 4. Subtract the estimated size of the text-only messages and the tool definitions.
///
/// Saturates at zero.
fn image_token_budget_for_request(
    runtime_config: &AgentRuntimeConfig,
    text_only_messages: &[Message],
    tools: Option<&Vec<crate::domain::tools::Tool>>,
) -> usize {
    let reserved_for_completion = match runtime_config.provider.max_tokens {
        Some(limit) => limit as usize,
        None => DEFAULT_COMPLETION_TOKEN_RESERVE,
    };
    let input_tokens = runtime_config
        .context_window_tokens
        .saturating_sub(reserved_for_completion);
    let usable_input_tokens = (input_tokens as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize;
    let message_tokens = estimate_tokens_from_serialized_json(&text_only_messages);
    let tool_tokens = tools.map_or(0, estimate_tokens_from_serialized_json);
    usable_input_tokens.saturating_sub(message_tokens + tool_tokens)
}
/// Count media references, across all messages, whose extension maps to a supported
/// image MIME type (i.e. the images that will compete for the inline-image budget).
fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
    let mut total = 0;
    for message in messages {
        total += message
            .media_refs
            .iter()
            .filter(|media_ref| supported_image_mime_type(media_ref).is_some())
            .count();
    }
    total
}
/// Largest raw image size, in bytes, whose base64 encoding fits in `target_tokens`,
/// after reserving `DATA_URL_OVERHEAD_TOKENS` for the data-URL prefix.
///
/// Base64 emits 4 characters per 3 input bytes (`4 * ceil(n / 3)` in total). The
/// character budget is therefore rounded down to whole 4-character groups before
/// multiplying by 3. This keeps the result exact even when
/// `TOKEN_ESTIMATE_CHARS_PER_TOKEN` is not a multiple of 4, and no intermediate
/// multiplication can saturate.
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
    let base64_chars = target_tokens
        .saturating_sub(DATA_URL_OVERHEAD_TOKENS)
        .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    base64_chars / 4 * 3
}
/// Read the image at `path` and return `(mime_type, base64_payload)` with no resizing
/// or re-encoding; the caller builds the `data:` URL.
///
/// # Errors
/// `InvalidInput` when the path's extension is not a supported image type; otherwise
/// any error from opening or reading the file.
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};
    let Some(mime) = supported_image_mime_type(path) else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("unsupported image media type for path: {}", path),
        ));
    };
    let mut raw = Vec::new();
    std::fs::File::open(path)?.read_to_end(&mut raw)?;
    Ok((mime, STANDARD.encode(&raw)))
}
/// Encode the image at `path` as base64 so that it fits in about `target_tokens`.
///
/// Cost is estimated as base64 characters divided by `TOKEN_ESTIMATE_CHARS_PER_TOKEN`,
/// after reserving `DATA_URL_OVERHEAD_TOKENS`. If the original file already fits, it is
/// returned unchanged with its own MIME type. Otherwise it is decoded and re-encoded by
/// `compress_image_to_jpeg_budget`, and the MIME type becomes `"image/jpeg"`.
///
/// NOTE(review): the whole file is base64-encoded just to measure it, and on the
/// compression path it is read from disk a second time.
///
/// # Errors
/// Propagates errors from `encode_image_to_base64`, from reading the file, from decoding
/// it (`InvalidData`), and from `compress_image_to_jpeg_budget`.
fn encode_image_to_base64_with_budget(
    path: &str,
    target_tokens: usize,
) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};
    let (mime, encoded) = encode_image_to_base64(path)?;
    let char_budget = target_tokens
        .saturating_sub(DATA_URL_OVERHEAD_TOKENS)
        .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    if encoded.len() > char_budget {
        let byte_budget = target_image_bytes_for_tokens(target_tokens);
        let raw = std::fs::read(path)?;
        let decoded = image::load_from_memory(&raw)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
        let jpeg = compress_image_to_jpeg_budget(&decoded, byte_budget)?;
        return Ok((String::from("image/jpeg"), STANDARD.encode(jpeg)));
    }
    Ok((mime, encoded))
}
/// Re-encode `image` as a JPEG of at most `target_bytes` bytes.
///
/// Each pass tries every quality in `JPEG_QUALITY_STEPS` at the current size. If none
/// fits, the longest side is shrunk to 3/4, never below `MIN_COMPRESSED_IMAGE_SIDE`,
/// and the next pass begins. `resize` keeps the aspect ratio. The first encoding that
/// fits is returned.
///
/// The image is converted to 8-bit RGB before encoding. JPEG has no alpha channel, and
/// some `image` releases refuse to JPEG-encode the RGBA buffers that PNG/GIF/WebP
/// sources usually decode to.
///
/// # Errors
/// `InvalidData` if encoding fails. Also `InvalidData` if no candidate fits in
/// `target_bytes`, even at the minimum size and lowest quality: returning an oversized
/// image would defeat the caller's context budget.
fn compress_image_to_jpeg_budget(
    image: &image::DynamicImage,
    target_bytes: usize,
) -> Result<Vec<u8>, std::io::Error> {
    use image::codecs::jpeg::JpegEncoder;
    use image::imageops::FilterType;
    let original_side = image.width().max(image.height());
    let mut max_side = original_side.max(MIN_COMPRESSED_IMAGE_SIDE);
    // Smallest size seen so far, kept only for the error message.
    let mut smallest_len: Option<usize> = None;
    loop {
        // Only resample when the image is larger than the current bound.
        let candidate = if original_side > max_side {
            image.resize(max_side, max_side, FilterType::Triangle).to_rgb8()
        } else {
            image.to_rgb8()
        };
        for &quality in JPEG_QUALITY_STEPS {
            let mut encoded = Vec::new();
            let mut encoder = JpegEncoder::new_with_quality(&mut encoded, quality);
            encoder
                .encode_image(&candidate)
                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
            if encoded.len() <= target_bytes {
                return Ok(encoded);
            }
            smallest_len = Some(smallest_len.map_or(encoded.len(), |len| len.min(encoded.len())));
        }
        if max_side <= MIN_COMPRESSED_IMAGE_SIDE {
            break;
        }
        max_side = (max_side * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE);
    }
    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidData,
        format!(
            "image compression failed: smallest JPEG was {} bytes, budget is {} bytes",
            smallest_len.unwrap_or(0),
            target_bytes
        ),
    ))
}
/// Truncate a tool result that exceeds `max_tool_result_chars` characters.
///
/// Lengths are counted in `char`s rather than bytes, so multi-byte UTF-8 text is never
/// split mid-character. Output within the limit is returned unchanged. Otherwise 100
/// characters of the budget are reserved for the truncation marker, and:
///
/// - if the output is over the limit by more than `TRUNCATION_SUFFIX_LEN` characters,
///   the first `max_tool_result_chars - 100` characters are kept, followed by the marker;
/// - if it is over by at most `TRUNCATION_SUFFIX_LEN` characters, the last
///   `max_tool_result_chars - 100` characters are kept, preceded by the marker, because
///   the end of an output often holds its conclusion.
///
/// The marker always states exactly how many characters were dropped.
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
    // Characters of the budget left free for the marker text.
    const MARKER_RESERVE: usize = 100;
    let total_chars = char_count(output);
    if total_chars <= max_tool_result_chars {
        return output.to_string();
    }
    let kept_len = max_tool_result_chars.saturating_sub(MARKER_RESERVE);
    // kept_len <= max_tool_result_chars < total_chars, so this cannot underflow.
    let removed = total_chars - kept_len;
    if total_chars.saturating_sub(TRUNCATION_SUFFIX_LEN) > max_tool_result_chars {
        let head = take_prefix_chars(output, kept_len);
        format!(
            "{}...\n\n[Output truncated - {} characters removed]",
            head, removed
        )
    } else {
        // Keep as much of the end as the budget allows, not just the last
        // TRUNCATION_SUFFIX_LEN characters, so a slightly oversized output is not gutted.
        let tail = take_suffix_chars(output, kept_len);
        format!(
            "...\n\n[Output truncated - {} characters removed]\n\n{}",
            removed, tail
        )
    }
}
/// If `output` begins with `PENDING_USER_ACTION_MARKER`, return the rest of it with
/// surrounding whitespace trimmed; otherwise `None`. The marker must be at the very
/// start, with no leading whitespace.
fn parse_pending_tool_output(output: &str) -> Option<String> {
    let payload = output.strip_prefix(PENDING_USER_ACTION_MARKER)?;
    Some(payload.trim().to_owned())
}
/// Normalize tool-call arguments for execution.
///
/// Arguments that arrive as a JSON *string* containing valid JSON are parsed into that
/// value. Any other value, or a string that is not valid JSON, is returned as a clone.
fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
    if let serde_json::Value::String(raw) = arguments {
        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(raw) {
            return parsed;
        }
    }
    arguments.clone()
}
/// Render `error` and each of its `source()` causes, outermost first, one per line
/// joined by `"\ncaused by: "`.
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    std::iter::successors(Some(error), |current| current.source())
        .map(|cause| cause.to_string())
        .collect::<Vec<_>>()
        .join("\ncaused by: ")
}
/// Heuristically decide whether an LLM error message describes a transient failure
/// (gateway or stream timeout) that the user should simply retry.
///
/// Matching is ASCII case-insensitive. An HTTP `504` is recognised only when it stands
/// alone as a number. Otherwise unrelated figures, such as token counts (`15043`) or
/// byte sizes, would turn a hard failure into a misleading "please retry" message and
/// hide the real error.
fn is_recoverable_llm_error(error: &str) -> bool {
    let normalized = error.to_ascii_lowercase();
    // "timeout" also covers "gateway timeout" and "stream timeout".
    contains_standalone_number(&normalized, "504")
        || normalized.contains("timed out")
        || normalized.contains("timeout")
}

/// Return true if `number` occurs in `haystack` with no ASCII digit directly before or
/// after it.
fn contains_standalone_number(haystack: &str, number: &str) -> bool {
    haystack.match_indices(number).any(|(start, matched)| {
        let before = haystack[..start].chars().next_back();
        let after = haystack[start + matched.len()..].chars().next();
        !before.is_some_and(|c| c.is_ascii_digit()) && !after.is_some_and(|c| c.is_ascii_digit())
    })
}
fn recoverable_llm_message(error: &str) -> String {
if is_recoverable_llm_error(error) {
RECOVERABLE_LLM_ERROR_MESSAGE.to_string()
} else {
format!("模型请求失败:{}", error)
}
}
/// Outcome of recording one tool call with the loop detector.
#[derive(Debug, Clone, PartialEq, Eq)]
enum LoopDetectionResult {
/// No warning needed.
Ok,
/// The same tool was called repeatedly with identical arguments. Carries the warning
/// text, which the agent loop prepends to that call's tool result.
Warning(String),
}
/// Configuration for the tool-call loop detector.
#[derive(Debug, Clone)]
struct LoopDetectorConfig {
    /// Master switch; when `false`, every call is reported as `LoopDetectionResult::Ok`.
    enabled: bool,
    /// Warn every N consecutive identical calls.
    warn_every: usize,
}
impl Default for LoopDetectorConfig {
    /// Enabled, warning on every 5th identical call in a row.
    fn default() -> Self {
        let warn_every = 5;
        Self {
            warn_every,
            enabled: true,
        }
    }
}
/// A single recorded tool invocation in the sliding window.
#[derive(Debug, Clone)]
struct ToolCallRecord {
/// Tool name as requested by the LLM.
name: String,
/// Hash of the key-sorted JSON arguments (see `hash_json_value`).
args_hash: u64,
}
/// Stateful loop detector that monitors for repetitive tool calls.
///
/// It keeps a bounded history of recent calls (at most `2 * warn_every` entries) plus a
/// running count of how many times in a row the latest call has been repeated.
struct LoopDetector {
    config: LoopDetectorConfig,
    window: VecDeque<ToolCallRecord>,
    /// Length of the current run of identical calls (same name and argument hash).
    /// Tracked separately because the bounded window cannot count runs longer than
    /// itself; recounting from the window pinned the count at `2 * warn_every`.
    consecutive: usize,
}
impl LoopDetector {
    /// Create a detector with an empty history.
    fn new(config: LoopDetectorConfig) -> Self {
        Self {
            window: VecDeque::with_capacity(config.warn_every * 2),
            config,
            consecutive: 0,
        }
    }

    /// Record a tool call and check for loop patterns.
    ///
    /// Returns `Warning` when the current run of identical calls reaches a multiple of
    /// `warn_every` (the 5th, 10th, 15th, … with the default config), and `Ok`
    /// otherwise. A `warn_every` of 0 disables warnings instead of dividing by zero.
    fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult {
        if !self.config.enabled {
            return LoopDetectionResult::Ok;
        }
        let record = ToolCallRecord {
            name: name.to_string(),
            args_hash: hash_json_value(args),
        };
        // Compare against the previous call before the window is trimmed.
        let repeats_previous = self
            .window
            .back()
            .is_some_and(|prev| prev.name == record.name && prev.args_hash == record.args_hash);
        self.consecutive = if repeats_previous {
            self.consecutive.saturating_add(1)
        } else {
            1
        };
        // Maintain the bounded history window.
        if self.window.len() >= self.config.warn_every * 2 {
            self.window.pop_front();
        }
        self.window.push_back(record);
        let warn_every = self.config.warn_every;
        if warn_every > 0 && self.consecutive % warn_every == 0 {
            LoopDetectionResult::Warning(format!(
                "注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。",
                name, self.consecutive
            ))
        } else {
            LoopDetectionResult::Ok
        }
    }
}
/// Hash a JSON value so that objects differing only in key order hash equally.
///
/// The value is first canonicalised (keys sorted recursively), then hashed with
/// `DefaultHasher::new()`.
fn hash_json_value(value: &serde_json::Value) -> u64 {
    let canonical = canonicalise_json(value);
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(&canonical, &mut state);
    state.finish()
}
/// Return a clone of value with all object keys sorted recursively.
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
match value {
serde_json::Value::Object(map) => {
let mut sorted: Vec<(&String, &serde_json::Value)> = map.iter().collect();
sorted.sort_by_key(|(k, _)| *k);
let new_map: serde_json::Map<String, serde_json::Value> = sorted
.into_iter()
.map(|(k, v)| (k.clone(), canonicalise_json(v)))
.collect();
serde_json::Value::Object(new_map)
}
serde_json::Value::Array(arr) => {
serde_json::Value::Array(arr.iter().map(canonicalise_json).collect())
}
other => other.clone(),
}
}
/// Convert a `ChatMessage` into a provider `Message`.
///
/// A message without media references becomes a single text block and leaves
/// `image_budget` untouched. A message with media goes through `build_content_blocks`,
/// which may inline images (consuming budget) and append a notice for dropped images.
/// Role, reasoning content, tool-call id, tool name and tool calls are copied as-is.
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
    let has_media = !m.media_refs.is_empty();
    let content = if has_media {
        build_content_blocks(&m.content, &m.media_refs, image_budget)
    } else {
        vec![ContentBlock::text(&m.content)]
    };
    Message {
        role: m.role.clone(),
        content,
        name: m.tool_name.clone(),
        tool_call_id: m.tool_call_id.clone(),
        tool_calls: m.tool_calls.clone(),
        reasoning_content: m.reasoning_content.clone(),
    }
}
/// Convert a `ChatMessage` into a provider `Message` that carries only its text.
///
/// Media references are ignored. `AgentLoop` uses this form to estimate a request's
/// token cost before any images are inlined.
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
    let text_block = ContentBlock::text(&m.content);
    Message {
        content: vec![text_block],
        role: m.role.clone(),
        name: m.tool_name.clone(),
        tool_call_id: m.tool_call_id.clone(),
        tool_calls: m.tool_calls.clone(),
        reasoning_content: m.reasoning_content.clone(),
    }
}
/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop {
/// Runtime settings: provider config, context-window size, tool-result limit and the
/// tool-iteration cap.
runtime_config: AgentRuntimeConfig,
/// LLM backend created from `runtime_config.provider` by `create_provider`.
provider: Box<dyn LLMProvider>,
/// Tools the LLM may call; their definitions are sent with every request.
tools: Arc<ToolRegistry>,
/// Supplies the optional skill-index system prompt and same-name skill hints for
/// unknown tools.
skills: Arc<dyn SkillProvider>,
/// Context passed to every tool execution.
tool_context: ToolContext,
/// Optional sink for `ToolCallStart` / `ToolCall` events.
observer: Option<Arc<dyn Observer>>,
/// Optional live callback for assistant tool-call messages.
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
/// Maximum number of tool-calling LLM rounds before a summary is requested; copied
/// from `max_tool_iterations` at construction.
max_iterations: usize,
}
/// Result of one `AgentLoop::process` call.
#[derive(Debug, Clone)]
pub struct AgentProcessResult {
/// The assistant message that ended the turn: the final answer, a pending-tool notice,
/// the max-iteration summary, or an LLM-error message.
pub final_response: ChatMessage,
/// Every message produced during the call, in order: assistant tool-call messages,
/// tool results, and finally `final_response` itself. The input history is not included.
pub emitted_messages: Vec<ChatMessage>,
}
/// Callback for messages emitted while `AgentLoop::process` runs.
///
/// In this file it is only invoked, via `emit_live_tool_call_message`, for assistant
/// messages that carry tool calls, before those tools execute.
#[async_trait]
pub trait EmittedMessageHandler: Send + Sync + 'static {
/// Handle one emitted message.
async fn handle(&self, message: ChatMessage);
}
/// Source of skill metadata consulted by the agent loop.
pub trait SkillProvider: Send + Sync + 'static {
/// Optional skill-index prompt; when `Some`, `AgentLoop` prepends it as a system
/// message to every LLM request.
fn system_index_prompt(&self) -> Option<String>;
/// Summary of a skill whose name equals `_name`, used to hint the LLM when it calls a
/// tool that does not exist under that name. Defaults to `None`.
fn matching_skill_summary(&self, _name: &str) -> Option<String> {
None
}
}
/// Skill provider used when none is supplied.
///
/// It has no index prompt and keeps the trait's default (`None`) for skill summaries.
#[derive(Default)]
struct EmptySkillProvider;
impl SkillProvider for EmptySkillProvider {
    fn system_index_prompt(&self) -> Option<String> {
        Option::default()
    }
}
impl AgentLoop {
pub fn new(config: impl Into<AgentRuntimeConfig>) -> Result<Self, AgentError> {
let runtime_config = config.into();
let max_iterations = runtime_config.max_tool_iterations;
let provider = create_provider(runtime_config.provider.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
runtime_config,
provider,
tools: Arc::new(ToolRegistry::new()),
skills: Arc::new(EmptySkillProvider),
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
pub fn with_tools(
config: impl Into<AgentRuntimeConfig>,
tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
let runtime_config = config.into();
let max_iterations = runtime_config.max_tool_iterations;
let provider = create_provider(runtime_config.provider.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
runtime_config,
provider,
tools,
skills: Arc::new(EmptySkillProvider),
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
pub fn with_tools_and_skill_provider(
config: impl Into<AgentRuntimeConfig>,
tools: Arc<ToolRegistry>,
skills: Arc<dyn SkillProvider>,
) -> Result<Self, AgentError> {
let runtime_config = config.into();
let max_iterations = runtime_config.max_tool_iterations;
let provider = create_provider(runtime_config.provider.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
runtime_config,
provider,
tools,
skills,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
pub fn with_tool_context(mut self, context: ToolContext) -> Self {
self.tool_context = context;
self
}
/// Set an observer for tracking events.
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
self.observer = Some(observer);
self
}
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
self.emitted_message_handler = Some(handler);
self
}
pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools
}
/// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager.
///
/// This method supports multi-round tool calling: after executing tools,
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached
pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
// Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new();
for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)]
tracing::debug!(iteration, "Agent iteration started");
// Build request
let tool_defs = self.tools.get_definitions();
let tools = if tool_defs.is_empty() {
None
} else {
Some(tool_defs)
};
let image_count = count_supported_image_media_refs(&messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 2);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
text_only_messages.push(Message::system(skill_prompt.clone()));
}
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens = image_token_budget_for_request(
&self.runtime_config,
&text_only_messages,
tools.as_ref(),
);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 2);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
}
messages_for_llm.extend(
messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
let response = match (*self.provider).chat(request).await {
Ok(response) => response,
Err(e) => {
tracing::error!(
provider = %self.provider.name(),
model = %self.provider.model_id(),
error = %e,
error_details = %format_error_chain(e.as_ref()),
"LLM request failed"
);
let assistant_message =
ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
};
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
response_len = response.content.len(),
tool_calls_len = response.tool_calls.len(),
"LLM response received"
);
// If no tool calls, this is the final response
if response.tool_calls.is_empty() {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
};
emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
// Execute tool calls
tracing::info!(
iteration,
count = response.tool_calls.len(),
"Tool calls detected, executing tools"
);
// Add assistant message with tool calls
let assistant_message =
if let Some(reasoning_content) = response.reasoning_content.clone() {
ChatMessage::assistant_with_tool_calls_and_reasoning(
response.content.clone(),
response.tool_calls.clone(),
reasoning_content,
)
} else {
ChatMessage::assistant_with_tool_calls(
response.content.clone(),
response.tool_calls.clone(),
)
};
messages.push(assistant_message.clone());
emitted_messages.push(assistant_message);
self.emit_live_tool_call_message(
emitted_messages
.last()
.expect("assistant message just pushed")
.clone(),
)
.await;
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
// Log function call with name and arguments
let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
};
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
// Truncate tool result if too large
let truncated_output =
truncate_tool_result(&result.output, self.runtime_config.tool_result_max_chars);
// Record tool call and check for loops
let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments);
match loop_result {
LoopDetectionResult::Warning(msg) => {
// Add warning and proceed
tracing::warn!(
tool = %tool_call.name,
"Loop warning: {}",
msg
);
let tool_message = ChatMessage::tool_with_state(
tool_call.id.clone(),
tool_call.name.clone(),
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
if result.state == ToolExecutionState::PendingUserAction {
ToolMessageState::PendingUserAction
} else {
ToolMessageState::Completed
},
);
messages.push(tool_message.clone());
emitted_messages.push(tool_message);
}
LoopDetectionResult::Ok => {
let tool_message = ChatMessage::tool_with_state(
tool_call.id.clone(),
tool_call.name.clone(),
truncated_output,
if result.state == ToolExecutionState::PendingUserAction {
ToolMessageState::PendingUserAction
} else {
ToolMessageState::Completed
},
);
messages.push(tool_message.clone());
emitted_messages.push(tool_message);
}
}
}
if let Some((tool_call, pending_result)) = response
.tool_calls
.iter()
.zip(tool_results.iter())
.find(|(_, result)| result.state == ToolExecutionState::PendingUserAction)
{
let assistant_message = ChatMessage::assistant(format!(
"{}\n\n当前等待中的工具: {}",
pending_result
.output
.lines()
.next()
.filter(|line| !line.trim().is_empty())
.unwrap_or(DEFAULT_PENDING_ASSISTANT_MESSAGE),
tool_call.name,
));
emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
// Loop continues to next iteration with updated messages
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
}
// Max iterations reached - ask LLM for a summary based on completed work
tracing::warn!("Max iterations reached, requesting final summary from LLM");
// Add a message asking for summary
let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far.",
);
messages.push(summary_request);
// Convert messages to LLM format
let image_count = count_supported_image_media_refs(&messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 1);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
text_only_messages.push(Message::system(skill_prompt));
}
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens =
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
}
messages_for_llm.extend(
messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools: None, // No tools in final summary call
};
match (*self.provider).chat(request).await {
Ok(response) => {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
};
emitted_messages.push(assistant_message.clone());
Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
})
}
Err(e) => {
tracing::error!(
provider = %self.provider.name(),
model = %self.provider.model_id(),
error = %e,
error_details = %format_error_chain(e.as_ref()),
"Failed to get summary from LLM"
);
let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(final_message.clone());
Ok(AgentProcessResult {
final_response: final_message,
emitted_messages,
})
}
}
}
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
if !message.is_assistant_tool_call_message() {
return;
}
if let Some(handler) = &self.emitted_message_handler {
handler.handle(message).await;
}
}
/// Determine whether to execute tools in parallel or sequentially.
///
/// Returns true if:
/// - There are multiple tool calls
/// - None of the tools require sequential execution (tool_search, non-concurrency-safe)
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
if tool_calls.len() <= 1 {
return false;
}
// tool_search must run sequentially to avoid MCP activation race conditions
if tool_calls.iter().any(|tc| tc.name == "tool_search") {
return false;
}
// All tools must be concurrency-safe to run in parallel
tool_calls.iter().all(|tc| {
self.tools
.get(&tc.name)
.map(|t| t.concurrency_safe())
.unwrap_or(false)
})
}
/// Execute multiple tool calls, choosing parallel or sequential based on conditions.
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
if self.should_execute_in_parallel(tool_calls) {
tracing::debug!("Executing {} tools in parallel", tool_calls.len());
self.execute_tools_parallel(tool_calls).await
} else {
tracing::debug!("Executing {} tools sequentially", tool_calls.len());
self.execute_tools_sequential(tool_calls).await
}
}
/// Execute tools in parallel using join_all.
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
let futures: Vec<_> = tool_calls
.iter()
.map(|tc| self.execute_one_tool(tc))
.collect();
futures_util::future::join_all(futures).await
}
/// Execute tools sequentially.
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
let mut outcomes = Vec::with_capacity(tool_calls.len());
for tool_call in tool_calls {
outcomes.push(self.execute_one_tool(tool_call).await);
}
outcomes
}
/// Execute a single tool and return the outcome with event tracking.
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let start = Instant::now();
let tool_name = tool_call.name.clone();
// Record ToolCallStart event
if let Some(ref observer) = self.observer {
observer.record_event(&ObserverEvent::ToolCallStart {
tool: tool_name.clone(),
arguments: Some(truncate_args(&tool_call.arguments, 300)),
});
}
let result = self.execute_tool_internal(tool_call).await;
let duration = start.elapsed();
// Record ToolCall event
if let Some(ref observer) = self.observer {
observer.record_event(&ObserverEvent::ToolCall {
tool: tool_name.clone(),
duration,
success: result.success,
});
}
// Apply duration
ToolExecutionOutcome { duration, ..result }
}
/// Internal tool execution without event tracking.
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let normalized_arguments = normalize_tool_arguments(&tool_call.arguments);
let tool = match self.tools.get(&tool_call.name) {
Some(t) => t,
None => {
tracing::warn!(tool = %tool_call.name, "Tool not found");
let skill_hint = self.skills.matching_skill_summary(&tool_call.name);
let error = match skill_hint {
Some(summary) => format!(
"Tool '{}' not found. A skill with the same name exists: {}. Skills are not tools. Call skill_activate with {{\"name\": \"{}\"}} first.",
tool_call.name, summary, tool_call.name
),
None => format!("Tool '{}' not found", tool_call.name),
};
return ToolExecutionOutcome::failure(
format!("Error: {}", error),
Some(error),
);
}
};
match tool
.execute_with_context(&self.tool_context, normalized_arguments.clone())
.await
{
Ok(result) => {
if result.success {
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
ToolExecutionOutcome::pending(pending_output)
} else {
ToolExecutionOutcome::success(result.output)
}
} else {
let error = result.error.unwrap_or_default();
tracing::error!(
tool = %tool_call.name,
args = %truncate_args(&tool_call.arguments, 4_000),
normalized_args = %truncate_args(&normalized_arguments, 4_000),
error = %error,
output = %result.output,
"Tool returned an error result"
);
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
}
}
Err(e) => {
tracing::error!(
tool = %tool_call.name,
args = %truncate_args(&tool_call.arguments, 4_000),
normalized_args = %truncate_args(&normalized_arguments, 4_000),
error = %e,
error_details = %format!("{:#}", e),
"Tool execution failed"
);
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::LLMProviderConfig;
use crate::observability::{MultiObserver, Observer};
use tempfile::tempdir;
struct TestObserver {
events: std::sync::Mutex<Vec<ObserverEvent>>,
}
impl TestObserver {
fn new() -> Self {
Self {
events: std::sync::Mutex::new(Vec::new()),
}
}
}
impl Observer for TestObserver {
fn record_event(&self, event: &ObserverEvent) {
self.events.lock().unwrap().push(event.clone());
}
fn name(&self) -> &str {
"test_observer"
}
}
struct TestSkillProvider;
impl SkillProvider for TestSkillProvider {
fn system_index_prompt(&self) -> Option<String> {
None
}
fn matching_skill_summary(&self, name: &str) -> Option<String> {
(name == "baidu-search").then(|| "用于百度搜索和天气查询的技能".to_string())
}
}
fn test_runtime_config() -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: "test".to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: std::collections::HashMap::new(),
llm_timeout_secs: 120,
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: std::collections::HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[tokio::test]
async fn test_missing_tool_with_same_name_skill_returns_activation_hint() {
let loop_instance = AgentLoop::with_tools_and_skill_provider(
test_runtime_config(),
Arc::new(ToolRegistry::new()),
Arc::new(TestSkillProvider),
)
.unwrap();
let outcome = loop_instance
.execute_tool_internal(&ToolCall {
id: "call-1".to_string(),
name: "baidu-search".to_string(),
arguments: serde_json::json!({
"queries": "佛山今天几点下雨"
}),
})
.await;
assert_eq!(outcome.state, ToolExecutionState::Completed);
assert!(!outcome.success);
assert!(outcome.output.contains("技能"));
assert!(outcome.output.contains("skill_activate"));
assert!(outcome.output.contains("baidu-search"));
}
#[tokio::test]
async fn test_observer_receives_tool_events() {
// Verify MultiObserver works
let mut multi = MultiObserver::new();
multi.add_observer(Box::new(TestObserver::new()));
let event = ObserverEvent::ToolCallStart {
tool: "test".to_string(),
arguments: Some("{}".to_string()),
};
multi.record_event(&event);
// Just verify the structure works
assert_eq!(multi.len(), 1);
}
#[test]
fn test_should_execute_in_parallel_single_tool() {
    // Placeholder: this does not build an `AgentLoop`, so it only pins the
    // precondition the parallel check relies on. A batch holding a single
    // tool call should not be executed in parallel.
    let calls = [ToolCall {
        id: "1".to_string(),
        name: "test".to_string(),
        arguments: serde_json::json!({}),
    }];
    assert!(calls.len() <= 1);
}
#[test]
fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
    let call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: serde_json::json!({ "expression": "2+2" }),
    };
    let chat_message = ChatMessage::assistant_with_tool_calls("calling tool", vec![call]);

    let mut image_budget = ImageInlineBudget::new(0, 0);
    let converted = chat_message_to_llm_message(&chat_message, &mut image_budget);

    assert_eq!(converted.role, "assistant");
    let tool_calls = converted
        .tool_calls
        .as_ref()
        .expect("assistant tool calls should survive conversion");
    assert_eq!(tool_calls.len(), 1);
    let first = &tool_calls[0];
    assert_eq!(first.id, "call_1");
    assert_eq!(first.name, "calculator");
}
#[test]
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
    let mut budget = ImageInlineBudget::new(0, 0);
    let source = ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");

    let converted = chat_message_to_llm_message(&source, &mut budget);

    assert_eq!(converted.role, "assistant");
    assert_eq!(converted.reasoning_content.as_deref(), Some("hidden chain of thought"));
}
#[test]
fn test_truncate_tool_result_handles_utf8_char_boundaries() {
    // Use a 3-byte UTF-8 character so that any byte-offset slicing inside
    // `truncate_tool_result` would land mid-character and panic. (An empty
    // string repeated is still empty and exercises nothing.)
    let input = "中".repeat(20_500);
    let output = truncate_tool_result(&input, 20_000);

    assert!(output.contains("Output truncated"));
    // Content from the original input must survive truncation...
    assert!(output.contains('中'));
    // ...and the result must actually be shorter than the input.
    assert!(output.len() < input.len());
}
#[test]
fn test_parse_pending_tool_output() {
    // Output prefixed with the marker line yields the remaining message.
    let pending = format!("{}\n请完成授权", PENDING_USER_ACTION_MARKER);
    assert_eq!(parse_pending_tool_output(&pending).as_deref(), Some("请完成授权"));

    // Output without the marker is not treated as pending.
    assert!(parse_pending_tool_output("normal output").is_none());
}
#[test]
fn test_normalize_tool_arguments_parses_stringified_json() {
    let stringified = serde_json::Value::String(r#"{"command":"ls -la"}"#.to_string());
    let expected = serde_json::json!({ "command": "ls -la" });
    assert_eq!(normalize_tool_arguments(&stringified), expected);
}
#[test]
fn test_normalize_tool_arguments_keeps_plain_string() {
    // A string that is not JSON must come back unchanged.
    let plain = serde_json::Value::from("plain text");
    assert_eq!(normalize_tool_arguments(&plain), plain);
}
#[test]
fn test_build_content_blocks_skips_non_image_media_refs() {
    let dir = tempdir().unwrap();
    let pdf = dir.path().join("demo.pdf");
    std::fs::write(&pdf, b"%PDF-1.4").unwrap();

    let media = [pdf.to_string_lossy().into_owned()];
    let mut budget = ImageInlineBudget::new(1_000, 0);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    // Only the text block remains; the PDF reference is dropped.
    assert_eq!(blocks.len(), 1);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
}
#[test]
fn test_build_content_blocks_keeps_supported_images() {
    let dir = tempdir().unwrap();
    let jpg = dir.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg).unwrap();

    let media = [jpg.to_string_lossy().into_owned()];
    let mut budget = ImageInlineBudget::new(10_000, 1);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    let is_inline_jpeg = matches!(
        &blocks[1],
        ContentBlock::ImageUrl { image_url }
            if image_url.url.starts_with("data:image/jpeg;base64,")
    );
    assert!(is_inline_jpeg);
}
#[test]
fn test_build_content_blocks_compresses_images_to_budget() {
    // A 512x512 PNG with a 512-token budget should still be inlined,
    // re-encoded as JPEG.
    let dir = tempdir().unwrap();
    let png = dir.path().join("large.png");
    image::DynamicImage::new_rgb8(512, 512).save(&png).unwrap();

    let media = [png.to_string_lossy().into_owned()];
    let mut budget = ImageInlineBudget::new(512, 1);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    let is_inline_jpeg = matches!(
        &blocks[1],
        ContentBlock::ImageUrl { image_url }
            if image_url.url.starts_with("data:image/jpeg;base64,")
    );
    assert!(is_inline_jpeg);
}
#[test]
fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() {
    let dir = tempdir().unwrap();
    let jpg = dir.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg).unwrap();

    // A zero token budget leaves no room for the image.
    let media = [jpg.to_string_lossy().into_owned()];
    let mut budget = ImageInlineBudget::new(0, 1);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    // The dropped image is reported in a trailing text notice naming the file.
    let has_notice = matches!(
        &blocks[1],
        ContentBlock::Text { text }
            if text.contains("图片未能成功入模") && text.contains("demo.jpg")
    );
    assert!(has_notice);
}
}
/// Error type returned by the agent loop.
#[derive(Debug)]
pub enum AgentError {
    /// The LLM provider could not be created; displayed with a
    /// `Provider creation error: ` prefix.
    ProviderCreation(String),
    /// An LLM-side failure; displayed with an `LLM error: ` prefix.
    LlmError(String),
    /// Any other failure; the message is displayed verbatim.
    Other(String),
}

impl std::fmt::Display for AgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::ProviderCreation(detail) => ("Provider creation error: ", detail),
            Self::LlmError(detail) => ("LLM error: ", detail),
            Self::Other(detail) => ("", detail),
        };
        write!(f, "{}{}", prefix, detail)
    }
}

impl std::error::Error for AgentError {}