987 lines
36 KiB
Rust
987 lines
36 KiB
Rust
use async_trait::async_trait;
|
|
use crate::bus::message::ContentBlock;
|
|
use crate::bus::ChatMessage;
|
|
use crate::bus::message::ToolMessageState;
|
|
use crate::config::LLMProviderConfig;
|
|
use crate::observability::{
|
|
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState,
|
|
};
|
|
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
|
use crate::skills::SkillRuntime;
|
|
use crate::storage::SessionStore;
|
|
use crate::tools::{ToolContext, ToolRegistry};
|
|
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
|
use std::collections::VecDeque;
|
|
use std::hash::{Hash, Hasher};
|
|
use std::io::Read;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
/// Minimum characters to keep when truncating
|
|
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
|
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str =
|
|
include_str!("memory_tool_usage_system_prompt.md");
|
|
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
|
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
|
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
|
|
|
/// MIME types (by essence, e.g. `image/png`) that may be forwarded to the LLM as inline
/// base64 image blocks; other media references are skipped.
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
|
|
|
|
/// Build content blocks from text and media paths
|
|
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
|
let mut blocks = Vec::new();
|
|
|
|
// Add text block if there's text
|
|
if !text.is_empty() {
|
|
blocks.push(ContentBlock::text(text));
|
|
}
|
|
|
|
// Add image blocks for media paths
|
|
for path in media_paths {
|
|
if supported_image_mime_type(path).is_none() {
|
|
tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block");
|
|
continue;
|
|
}
|
|
|
|
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
|
|
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
blocks.push(ContentBlock::image_url(url));
|
|
}
|
|
}
|
|
|
|
// If nothing, add empty text block
|
|
if blocks.is_empty() {
|
|
blocks.push(ContentBlock::text(""));
|
|
}
|
|
|
|
blocks
|
|
}
|
|
|
|
fn supported_image_mime_type(path: &str) -> Option<String> {
|
|
let mime = mime_guess::from_path(path).first_or_octet_stream();
|
|
let essence = mime.essence_str();
|
|
|
|
if SUPPORTED_IMAGE_MIME_TYPES.contains(&essence) {
|
|
Some(essence.to_string())
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Encode an image file to base64 data URL
|
|
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
|
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
|
|
|
let mime = supported_image_mime_type(path).ok_or_else(|| {
|
|
std::io::Error::new(
|
|
std::io::ErrorKind::InvalidInput,
|
|
format!("unsupported image media type for path: {}", path),
|
|
)
|
|
})?;
|
|
|
|
let mut file = std::fs::File::open(path)?;
|
|
let mut buffer = Vec::new();
|
|
file.read_to_end(&mut buffer)?;
|
|
|
|
let encoded = STANDARD.encode(&buffer);
|
|
Ok((mime, encoded))
|
|
}
|
|
|
|
/// Truncate tool result if it exceeds the configured limit.
|
|
/// Preserves the end of the output as it often contains the conclusion/useful result.
|
|
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
|
|
let total_chars = char_count(output);
|
|
if total_chars <= max_tool_result_chars {
|
|
return output.to_string();
|
|
}
|
|
|
|
let truncated_start_len = total_chars.saturating_sub(TRUNCATION_SUFFIX_LEN);
|
|
if truncated_start_len > max_tool_result_chars {
|
|
// Even after removing suffix, still too long - take from beginning
|
|
let head_len = max_tool_result_chars.saturating_sub(100);
|
|
let head = take_prefix_chars(output, head_len);
|
|
format!(
|
|
"{}...\n\n[Output truncated - {} characters removed]",
|
|
head,
|
|
total_chars - max_tool_result_chars + 100
|
|
)
|
|
} else {
|
|
// Keep most of the end which usually contains the useful result
|
|
let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len));
|
|
format!(
|
|
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
|
truncated_start_len,
|
|
tail
|
|
)
|
|
}
|
|
}
|
|
|
|
fn parse_pending_tool_output(output: &str) -> Option<String> {
|
|
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
|
|
}
|
|
|
|
/// Classify an LLM error message as recoverable (gateway/stream timeouts).
///
/// Matching is ASCII case-insensitive substring search for any of the needles below.
/// Note that `"timeout"` already subsumes the two more specific timeout phrases; they
/// are listed for readability.
fn is_recoverable_llm_error(error: &str) -> bool {
    const NEEDLES: [&str; 5] = ["504", "gateway timeout", "stream timeout", "timed out", "timeout"];

    let lowered = error.to_ascii_lowercase();
    NEEDLES.iter().any(|&needle| lowered.contains(needle))
}
|
|
|
|
fn recoverable_llm_message(error: &str) -> String {
|
|
if is_recoverable_llm_error(error) {
|
|
RECOVERABLE_LLM_ERROR_MESSAGE.to_string()
|
|
} else {
|
|
format!("模型请求失败:{}", error)
|
|
}
|
|
}
|
|
|
|
/// Verdict returned by `LoopDetector::record` for each tool call.
#[derive(Debug, Clone, PartialEq, Eq)]
enum LoopDetectionResult {
    /// Nothing suspicious; the tool result is used as-is.
    Ok,
    /// The same tool was called with the same arguments repeatedly; carries the
    /// warning text to prepend to the tool result.
    Warning(String),
}
|
|
|
|
/// Tunables for `LoopDetector`.
#[derive(Debug, Clone)]
struct LoopDetectorConfig {
    /// Master switch; when `false`, detection is skipped and every call is `Ok`.
    enabled: bool,
    /// Period, in consecutive identical calls, at which warnings are emitted.
    warn_every: usize,
}

impl Default for LoopDetectorConfig {
    /// Detection enabled, warning period of 5 identical calls.
    fn default() -> Self {
        Self {
            warn_every: 5,
            enabled: true,
        }
    }
}
|
|
|
|
/// Identity of one tool invocation as tracked by `LoopDetector`: the tool name plus
/// a hash of its canonicalised (key-order independent) arguments.
#[derive(Debug, Clone)]
struct ToolCallRecord {
    /// Tool name as requested by the model.
    name: String,
    /// `hash_json_value` of the call's arguments.
    args_hash: u64,
}
|
|
|
|
/// Stateful loop detector that monitors for repetitive patterns.
|
|
struct LoopDetector {
|
|
config: LoopDetectorConfig,
|
|
window: VecDeque<ToolCallRecord>,
|
|
}
|
|
|
|
impl LoopDetector {
|
|
fn new(config: LoopDetectorConfig) -> Self {
|
|
Self {
|
|
window: VecDeque::with_capacity(config.warn_every * 2),
|
|
config,
|
|
}
|
|
}
|
|
|
|
/// Record a completed tool call and check for loop patterns.
|
|
/// Returns Warning every `warn_every` consecutive identical calls.
|
|
fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult {
|
|
if !self.config.enabled {
|
|
return LoopDetectionResult::Ok;
|
|
}
|
|
|
|
let record = ToolCallRecord {
|
|
name: name.to_string(),
|
|
args_hash: hash_json_value(args),
|
|
};
|
|
|
|
// Maintain sliding window
|
|
if self.window.len() >= self.config.warn_every * 2 {
|
|
self.window.pop_front();
|
|
}
|
|
self.window.push_back(record);
|
|
|
|
// Count consecutive identical calls
|
|
let last = self.window.back().unwrap();
|
|
let consecutive: usize = self
|
|
.window
|
|
.iter()
|
|
.rev()
|
|
.take_while(|r| r.name == last.name && r.args_hash == last.args_hash)
|
|
.count();
|
|
|
|
// Warn every warn_every times
|
|
if consecutive > 0 && consecutive % self.config.warn_every == 0 {
|
|
LoopDetectionResult::Warning(format!(
|
|
"注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。",
|
|
last.name, consecutive
|
|
))
|
|
} else {
|
|
LoopDetectionResult::Ok
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Hash a JSON value deterministically (key-order independent).
|
|
fn hash_json_value(value: &serde_json::Value) -> u64 {
|
|
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
|
let canonical = canonicalise_json(value);
|
|
canonical.hash(&mut hasher);
|
|
hasher.finish()
|
|
}
|
|
|
|
/// Return a clone of value with all object keys sorted recursively.
|
|
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
|
|
match value {
|
|
serde_json::Value::Object(map) => {
|
|
let mut sorted: Vec<(&String, &serde_json::Value)> = map.iter().collect();
|
|
sorted.sort_by_key(|(k, _)| *k);
|
|
let new_map: serde_json::Map<String, serde_json::Value> = sorted
|
|
.into_iter()
|
|
.map(|(k, v)| (k.clone(), canonicalise_json(v)))
|
|
.collect();
|
|
serde_json::Value::Object(new_map)
|
|
}
|
|
serde_json::Value::Array(arr) => {
|
|
serde_json::Value::Array(arr.iter().map(canonicalise_json).collect())
|
|
}
|
|
other => other.clone(),
|
|
}
|
|
}
|
|
|
|
/// Convert ChatMessage to LLM Message format
|
|
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
|
let content = if m.media_refs.is_empty() {
|
|
vec![ContentBlock::text(&m.content)]
|
|
} else {
|
|
build_content_blocks(&m.content, &m.media_refs)
|
|
};
|
|
|
|
Message {
|
|
role: m.role.clone(),
|
|
content,
|
|
reasoning_content: m.reasoning_content.clone(),
|
|
tool_call_id: m.tool_call_id.clone(),
|
|
name: m.tool_name.clone(),
|
|
tool_calls: m.tool_calls.clone(),
|
|
}
|
|
}
|
|
|
|
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
|
/// History is managed externally by SessionManager.
|
|
pub struct AgentLoop {
|
|
provider_config: LLMProviderConfig,
|
|
provider: Box<dyn LLMProvider>,
|
|
tools: Arc<ToolRegistry>,
|
|
skills: Arc<SkillRuntime>,
|
|
skill_event_store: Option<Arc<SessionStore>>,
|
|
skill_event_session_id: Option<String>,
|
|
tool_context: ToolContext,
|
|
observer: Option<Arc<dyn Observer>>,
|
|
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
|
|
max_iterations: usize,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct AgentProcessResult {
|
|
pub final_response: ChatMessage,
|
|
pub emitted_messages: Vec<ChatMessage>,
|
|
}
|
|
|
|
#[async_trait]
|
|
pub trait EmittedMessageHandler: Send + Sync + 'static {
|
|
async fn handle(&self, message: ChatMessage);
|
|
}
|
|
|
|
impl AgentLoop {
|
|
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
|
let max_iterations = provider_config.max_tool_iterations;
|
|
let provider = create_provider(provider_config.clone())
|
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
|
|
Ok(Self {
|
|
provider_config,
|
|
provider,
|
|
tools: Arc::new(ToolRegistry::new()),
|
|
skills: Arc::new(SkillRuntime::default()),
|
|
skill_event_store: None,
|
|
skill_event_session_id: None,
|
|
tool_context: ToolContext::default(),
|
|
observer: None,
|
|
emitted_message_handler: None,
|
|
max_iterations,
|
|
})
|
|
}
|
|
|
|
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
|
let max_iterations = provider_config.max_tool_iterations;
|
|
let provider = create_provider(provider_config.clone())
|
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
|
|
Ok(Self {
|
|
provider_config,
|
|
provider,
|
|
tools,
|
|
skills: Arc::new(SkillRuntime::default()),
|
|
skill_event_store: None,
|
|
skill_event_session_id: None,
|
|
tool_context: ToolContext::default(),
|
|
observer: None,
|
|
emitted_message_handler: None,
|
|
max_iterations,
|
|
})
|
|
}
|
|
|
|
pub fn with_tools_and_skills(
|
|
provider_config: LLMProviderConfig,
|
|
tools: Arc<ToolRegistry>,
|
|
skills: Arc<SkillRuntime>,
|
|
) -> Result<Self, AgentError> {
|
|
let max_iterations = provider_config.max_tool_iterations;
|
|
let provider = create_provider(provider_config.clone())
|
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
|
|
Ok(Self {
|
|
provider_config,
|
|
provider,
|
|
tools,
|
|
skills,
|
|
skill_event_store: None,
|
|
skill_event_session_id: None,
|
|
tool_context: ToolContext::default(),
|
|
observer: None,
|
|
emitted_message_handler: None,
|
|
max_iterations,
|
|
})
|
|
}
|
|
|
|
pub fn with_skill_event_store(mut self, store: Arc<SessionStore>, session_id: String) -> Self {
|
|
self.skill_event_store = Some(store);
|
|
self.skill_event_session_id = Some(session_id);
|
|
self
|
|
}
|
|
|
|
pub fn with_tool_context(mut self, context: ToolContext) -> Self {
|
|
self.tool_context = context;
|
|
self
|
|
}
|
|
|
|
/// Set an observer for tracking events.
|
|
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
|
self.observer = Some(observer);
|
|
self
|
|
}
|
|
|
|
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
|
|
self.emitted_message_handler = Some(handler);
|
|
self
|
|
}
|
|
|
|
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
|
&self.tools
|
|
}
|
|
|
|
/// Process a message using the provided conversation history.
|
|
/// History management is handled externally by SessionManager.
|
|
///
|
|
/// This method supports multi-round tool calling: after executing tools,
|
|
/// it loops back to the LLM with the tool results until either:
|
|
/// - The LLM returns no more tool calls (final response)
|
|
/// - Maximum iterations are reached
|
|
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
|
|
|
// Track tool calls for loop detection
|
|
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
|
let mut emitted_messages = Vec::new();
|
|
|
|
for iteration in 0..self.max_iterations {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(iteration, "Agent iteration started");
|
|
|
|
// Convert messages to LLM format
|
|
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
|
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
|
messages_for_llm.push(Message::system(skill_prompt));
|
|
}
|
|
messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
|
|
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
|
|
|
// Build request
|
|
let mut tool_defs = self.tools.get_definitions();
|
|
if let Some(skill_tool) = self.skills.skill_tool_definition() {
|
|
tool_defs.push(skill_tool);
|
|
}
|
|
let tools = if tool_defs.is_empty() { None } else { Some(tool_defs) };
|
|
|
|
let request = ChatCompletionRequest {
|
|
messages: messages_for_llm,
|
|
temperature: None,
|
|
max_tokens: None,
|
|
tools,
|
|
};
|
|
|
|
// Call LLM
|
|
let response = match (*self.provider).chat(request).await {
|
|
Ok(response) => response,
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "LLM request failed");
|
|
let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
|
emitted_messages.push(assistant_message.clone());
|
|
return Ok(AgentProcessResult {
|
|
final_response: assistant_message,
|
|
emitted_messages,
|
|
});
|
|
}
|
|
};
|
|
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(
|
|
iteration,
|
|
response_len = response.content.len(),
|
|
tool_calls_len = response.tool_calls.len(),
|
|
"LLM response received"
|
|
);
|
|
|
|
// If no tool calls, this is the final response
|
|
if response.tool_calls.is_empty() {
|
|
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
|
|
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
|
} else {
|
|
ChatMessage::assistant(response.content)
|
|
};
|
|
emitted_messages.push(assistant_message.clone());
|
|
return Ok(AgentProcessResult {
|
|
final_response: assistant_message,
|
|
emitted_messages,
|
|
});
|
|
}
|
|
|
|
// Execute tool calls
|
|
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
|
|
|
// Add assistant message with tool calls
|
|
let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() {
|
|
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
|
response.content.clone(),
|
|
response.tool_calls.clone(),
|
|
reasoning_content,
|
|
)
|
|
} else {
|
|
ChatMessage::assistant_with_tool_calls(
|
|
response.content.clone(),
|
|
response.tool_calls.clone(),
|
|
)
|
|
};
|
|
messages.push(assistant_message.clone());
|
|
emitted_messages.push(assistant_message);
|
|
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
|
|
|
|
// Execute tools and add results to messages
|
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
|
|
|
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
|
// Log function call with name and arguments
|
|
let args_str = match &tool_call.arguments {
|
|
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
|
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
|
};
|
|
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
|
|
|
// Truncate tool result if too large
|
|
let truncated_output = truncate_tool_result(
|
|
&result.output,
|
|
self.provider_config.tool_result_max_chars,
|
|
);
|
|
|
|
// Record tool call and check for loops
|
|
let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments);
|
|
|
|
match loop_result {
|
|
LoopDetectionResult::Warning(msg) => {
|
|
// Add warning and proceed
|
|
tracing::warn!(
|
|
tool = %tool_call.name,
|
|
"Loop warning: {}",
|
|
msg
|
|
);
|
|
let tool_message = ChatMessage::tool_with_state(
|
|
tool_call.id.clone(),
|
|
tool_call.name.clone(),
|
|
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
|
if result.state == ToolExecutionState::PendingUserAction {
|
|
ToolMessageState::PendingUserAction
|
|
} else {
|
|
ToolMessageState::Completed
|
|
},
|
|
);
|
|
messages.push(tool_message.clone());
|
|
emitted_messages.push(tool_message);
|
|
}
|
|
LoopDetectionResult::Ok => {
|
|
let tool_message = ChatMessage::tool_with_state(
|
|
tool_call.id.clone(),
|
|
tool_call.name.clone(),
|
|
truncated_output,
|
|
if result.state == ToolExecutionState::PendingUserAction {
|
|
ToolMessageState::PendingUserAction
|
|
} else {
|
|
ToolMessageState::Completed
|
|
},
|
|
);
|
|
messages.push(tool_message.clone());
|
|
emitted_messages.push(tool_message);
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some((tool_call, pending_result)) = response
|
|
.tool_calls
|
|
.iter()
|
|
.zip(tool_results.iter())
|
|
.find(|(_, result)| result.state == ToolExecutionState::PendingUserAction)
|
|
{
|
|
let assistant_message = ChatMessage::assistant(format!(
|
|
"{}\n\n当前等待中的工具: {}",
|
|
pending_result
|
|
.output
|
|
.lines()
|
|
.next()
|
|
.filter(|line| !line.trim().is_empty())
|
|
.unwrap_or(DEFAULT_PENDING_ASSISTANT_MESSAGE),
|
|
tool_call.name,
|
|
));
|
|
emitted_messages.push(assistant_message.clone());
|
|
return Ok(AgentProcessResult {
|
|
final_response: assistant_message,
|
|
emitted_messages,
|
|
});
|
|
}
|
|
|
|
// Loop continues to next iteration with updated messages
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
|
}
|
|
|
|
// Max iterations reached - ask LLM for a summary based on completed work
|
|
tracing::warn!("Max iterations reached, requesting final summary from LLM");
|
|
|
|
// Add a message asking for summary
|
|
let summary_request = ChatMessage::user(
|
|
"You have reached the maximum number of tool call iterations. \
|
|
Please provide your best answer based on the work completed so far."
|
|
);
|
|
messages.push(summary_request);
|
|
|
|
// Convert messages to LLM format
|
|
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
|
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
|
messages_for_llm.push(Message::system(skill_prompt));
|
|
}
|
|
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
|
|
|
let request = ChatCompletionRequest {
|
|
messages: messages_for_llm,
|
|
temperature: None,
|
|
max_tokens: None,
|
|
tools: None, // No tools in final summary call
|
|
};
|
|
|
|
match (*self.provider).chat(request).await {
|
|
Ok(response) => {
|
|
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
|
|
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
|
} else {
|
|
ChatMessage::assistant(response.content)
|
|
};
|
|
emitted_messages.push(assistant_message.clone());
|
|
Ok(AgentProcessResult {
|
|
final_response: assistant_message,
|
|
emitted_messages,
|
|
})
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to get summary from LLM");
|
|
let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
|
emitted_messages.push(final_message.clone());
|
|
Ok(AgentProcessResult {
|
|
final_response: final_message,
|
|
emitted_messages,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
|
|
if !message.is_assistant_tool_call_message() {
|
|
return;
|
|
}
|
|
|
|
if let Some(handler) = &self.emitted_message_handler {
|
|
handler.handle(message).await;
|
|
}
|
|
}
|
|
|
|
/// Determine whether to execute tools in parallel or sequentially.
|
|
///
|
|
/// Returns true if:
|
|
/// - There are multiple tool calls
|
|
/// - None of the tools require sequential execution (tool_search, non-concurrency-safe)
|
|
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
|
|
if tool_calls.len() <= 1 {
|
|
return false;
|
|
}
|
|
|
|
// tool_search must run sequentially to avoid MCP activation race conditions
|
|
if tool_calls.iter().any(|tc| tc.name == "tool_search") {
|
|
return false;
|
|
}
|
|
|
|
// All tools must be concurrency-safe to run in parallel
|
|
tool_calls.iter().all(|tc| {
|
|
self.tools
|
|
.get(&tc.name)
|
|
.map(|t| t.concurrency_safe())
|
|
.unwrap_or(false)
|
|
})
|
|
}
|
|
|
|
/// Execute multiple tool calls, choosing parallel or sequential based on conditions.
|
|
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
|
if self.should_execute_in_parallel(tool_calls) {
|
|
tracing::debug!("Executing {} tools in parallel", tool_calls.len());
|
|
self.execute_tools_parallel(tool_calls).await
|
|
} else {
|
|
tracing::debug!("Executing {} tools sequentially", tool_calls.len());
|
|
self.execute_tools_sequential(tool_calls).await
|
|
}
|
|
}
|
|
|
|
/// Execute tools in parallel using join_all.
|
|
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
|
let futures: Vec<_> = tool_calls
|
|
.iter()
|
|
.map(|tc| self.execute_one_tool(tc))
|
|
.collect();
|
|
|
|
futures_util::future::join_all(futures).await
|
|
}
|
|
|
|
/// Execute tools sequentially.
|
|
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
|
let mut outcomes = Vec::with_capacity(tool_calls.len());
|
|
|
|
for tool_call in tool_calls {
|
|
outcomes.push(self.execute_one_tool(tool_call).await);
|
|
}
|
|
|
|
outcomes
|
|
}
|
|
|
|
/// Execute a single tool and return the outcome with event tracking.
|
|
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
|
|
let start = Instant::now();
|
|
let tool_name = tool_call.name.clone();
|
|
|
|
// Record ToolCallStart event
|
|
if let Some(ref observer) = self.observer {
|
|
observer.record_event(&ObserverEvent::ToolCallStart {
|
|
tool: tool_name.clone(),
|
|
arguments: Some(truncate_args(&tool_call.arguments, 300)),
|
|
});
|
|
}
|
|
|
|
let result = self.execute_tool_internal(tool_call).await;
|
|
let duration = start.elapsed();
|
|
|
|
// Record ToolCall event
|
|
if let Some(ref observer) = self.observer {
|
|
observer.record_event(&ObserverEvent::ToolCall {
|
|
tool: tool_name.clone(),
|
|
duration,
|
|
success: result.success,
|
|
});
|
|
}
|
|
|
|
// Apply duration
|
|
ToolExecutionOutcome {
|
|
duration,
|
|
..result
|
|
}
|
|
}
|
|
|
|
/// Internal tool execution without event tracking.
|
|
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
|
|
if tool_call.name == "skill_activate" {
|
|
let skill_name = match tool_call.arguments.get("name").and_then(|v| v.as_str()) {
|
|
Some(name) if !name.trim().is_empty() => name,
|
|
_ => {
|
|
self.record_skill_event(
|
|
"activation_failed",
|
|
None,
|
|
serde_json::json!({
|
|
"reason": "missing_name",
|
|
"arguments": tool_call.arguments,
|
|
}),
|
|
);
|
|
return ToolExecutionOutcome::failure(
|
|
"Error: Missing required parameter: name".to_string(),
|
|
Some("Missing required parameter: name".to_string()),
|
|
);
|
|
}
|
|
};
|
|
|
|
return match self.skills.activation_payload(skill_name) {
|
|
Ok(output) => {
|
|
if let Ok(payload) = self.skills.activation_event_payload(skill_name) {
|
|
self.record_skill_event("activated", Some(skill_name), payload);
|
|
}
|
|
ToolExecutionOutcome::success(output)
|
|
}
|
|
Err(err) => {
|
|
self.record_skill_event(
|
|
"activation_failed",
|
|
Some(skill_name),
|
|
serde_json::json!({
|
|
"reason": err,
|
|
"arguments": tool_call.arguments,
|
|
}),
|
|
);
|
|
ToolExecutionOutcome::failure(
|
|
format!("Error: {}", err),
|
|
Some(err),
|
|
)
|
|
}
|
|
};
|
|
}
|
|
|
|
let tool = match self.tools.get(&tool_call.name) {
|
|
Some(t) => t,
|
|
None => {
|
|
tracing::warn!(tool = %tool_call.name, "Tool not found");
|
|
return ToolExecutionOutcome::failure(
|
|
format!("Error: Tool '{}' not found", tool_call.name),
|
|
Some(format!("Tool '{}' not found", tool_call.name)),
|
|
);
|
|
}
|
|
};
|
|
|
|
match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await {
|
|
Ok(result) => {
|
|
if result.success {
|
|
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
|
|
ToolExecutionOutcome::pending(pending_output)
|
|
} else {
|
|
ToolExecutionOutcome::success(result.output)
|
|
}
|
|
} else {
|
|
let error = result.error.unwrap_or_default();
|
|
ToolExecutionOutcome::failure(
|
|
format!("Error: {}", error),
|
|
Some(error),
|
|
)
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
|
ToolExecutionOutcome::failure(
|
|
format!("Error: {}", e),
|
|
Some(e.to_string()),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn record_skill_event(
|
|
&self,
|
|
event_type: &str,
|
|
skill_name: Option<&str>,
|
|
payload: serde_json::Value,
|
|
) {
|
|
let (Some(store), Some(session_id)) = (
|
|
self.skill_event_store.as_ref(),
|
|
self.skill_event_session_id.as_ref(),
|
|
) else {
|
|
return;
|
|
};
|
|
|
|
if let Err(err) = store.append_skill_event(Some(session_id), event_type, skill_name, &payload) {
|
|
tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event");
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observability::{MultiObserver, Observer};
    use tempfile::tempdir;

    /// Observer that stores every event it receives, for assertions.
    struct TestObserver {
        events: std::sync::Mutex<Vec<ObserverEvent>>,
    }

    impl TestObserver {
        fn new() -> Self {
            Self {
                events: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    impl Observer for TestObserver {
        fn record_event(&self, event: &ObserverEvent) {
            self.events.lock().unwrap().push(event.clone());
        }

        fn name(&self) -> &str {
            "test_observer"
        }
    }

    #[tokio::test]
    async fn test_observer_receives_tool_events() {
        let event = ObserverEvent::ToolCallStart {
            tool: "test".to_string(),
            arguments: Some("{}".to_string()),
        };

        // The observer itself captures what it is given.
        let observer = TestObserver::new();
        observer.record_event(&event);
        assert_eq!(observer.events.lock().unwrap().len(), 1);

        // MultiObserver accepts and fans out to registered observers.
        let mut multi = MultiObserver::new();
        multi.add_observer(Box::new(TestObserver::new()));
        multi.record_event(&event);
        assert_eq!(multi.len(), 1);
    }

    #[test]
    fn test_loop_detector_warns_on_fifth_identical_call() {
        let mut detector = LoopDetector::new(LoopDetectorConfig::default());
        let args = serde_json::json!({ "query": "x" });

        for _ in 0..4 {
            assert_eq!(detector.record("search", &args), LoopDetectionResult::Ok);
        }
        assert!(matches!(
            detector.record("search", &args),
            LoopDetectionResult::Warning(msg) if msg.contains("search") && msg.contains('5')
        ));
    }

    #[test]
    fn test_loop_detector_resets_on_different_call() {
        let mut detector = LoopDetector::new(LoopDetectorConfig::default());
        let args = serde_json::json!({ "a": 1, "b": 2 });
        let reordered = serde_json::json!({ "b": 2, "a": 1 });

        for _ in 0..4 {
            assert_eq!(detector.record("search", &args), LoopDetectionResult::Ok);
        }
        assert_eq!(detector.record("other", &args), LoopDetectionResult::Ok);
        // The run restarted, so four more identical calls (key order ignored) stay quiet.
        for _ in 0..4 {
            assert_eq!(detector.record("search", &reordered), LoopDetectionResult::Ok);
        }
    }

    #[test]
    fn test_loop_detector_disabled_never_warns() {
        let mut detector = LoopDetector::new(LoopDetectorConfig {
            enabled: false,
            warn_every: 1,
        });
        for _ in 0..3 {
            assert_eq!(
                detector.record("tool", &serde_json::json!({})),
                LoopDetectionResult::Ok
            );
        }
    }

    #[test]
    fn test_recoverable_llm_message() {
        assert_eq!(
            recoverable_llm_message("HTTP 504 Gateway Timeout"),
            RECOVERABLE_LLM_ERROR_MESSAGE
        );
        assert_eq!(
            recoverable_llm_message("invalid api key"),
            "模型请求失败:invalid api key"
        );
    }

    #[test]
    fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
        let chat_message = ChatMessage::assistant_with_tool_calls(
            "calling tool",
            vec![ToolCall {
                id: "call_1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({ "expression": "2+2" }),
            }],
        );

        let provider_message = chat_message_to_llm_message(&chat_message);

        assert_eq!(provider_message.role, "assistant");
        assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
        assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
        assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
    }

    #[test]
    fn test_chat_message_to_llm_message_preserves_reasoning_content() {
        let chat_message = ChatMessage::assistant_with_reasoning(
            "final answer",
            "hidden chain of thought",
        );

        let provider_message = chat_message_to_llm_message(&chat_message);

        assert_eq!(provider_message.role, "assistant");
        assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought"));
    }

    #[test]
    fn test_memory_prompt_requires_proactive_memory_search() {
        assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时"));
        assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("先使用长期记忆检索工具 memory_search"));
        assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索"));
    }

    #[test]
    fn test_memory_prompt_allows_parallel_independent_tool_calls() {
        assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("同一轮同时返回多个 tool calls"));
        assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("自动并行执行"));
        assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("只有当后一个工具的参数依赖"));
    }

    #[test]
    fn test_truncate_tool_result_keeps_short_output() {
        assert_eq!(truncate_tool_result("short", 100), "short");
        assert_eq!(truncate_tool_result("", 0), "");
    }

    #[test]
    fn test_truncate_tool_result_handles_utf8_char_boundaries() {
        let input = "范".repeat(20_500);

        let output = truncate_tool_result(&input, 20_000);

        assert!(output.contains("Output truncated"));
        assert!(output.is_char_boundary(output.len()));
    }

    #[test]
    fn test_parse_pending_tool_output() {
        let output = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权");
        assert_eq!(output.as_deref(), Some("请完成授权"));
        assert!(parse_pending_tool_output("normal output").is_none());
    }

    #[test]
    fn test_build_content_blocks_skips_non_image_media_refs() {
        let temp_dir = tempdir().unwrap();
        let pdf_path = temp_dir.path().join("demo.pdf");
        std::fs::write(&pdf_path, b"%PDF-1.4").unwrap();

        let blocks = build_content_blocks("hello", &[pdf_path.to_string_lossy().to_string()]);

        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    }

    #[test]
    fn test_build_content_blocks_keeps_supported_images() {
        let temp_dir = tempdir().unwrap();
        let jpg_path = temp_dir.path().join("demo.jpg");
        std::fs::write(&jpg_path, b"fake-jpeg-data").unwrap();

        let blocks = build_content_blocks("hello", &[jpg_path.to_string_lossy().to_string()]);

        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
        assert!(matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")));
    }
}
|
|
|
|
/// Errors surfaced by `AgentLoop`.
#[derive(Debug)]
pub enum AgentError {
    /// `create_provider` rejected the provider configuration.
    ProviderCreation(String),
    /// An error attributed to the LLM.
    LlmError(String),
    /// Any other failure; its message is displayed verbatim.
    Other(String),
}

impl std::fmt::Display for AgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::ProviderCreation(detail) => ("Provider creation error: ", detail),
            Self::LlmError(detail) => ("LLM error: ", detail),
            Self::Other(detail) => ("", detail),
        };
        write!(f, "{}{}", prefix, detail)
    }
}

impl std::error::Error for AgentError {}
|