PicoBot/src/agent/agent_loop.rs

900 lines
32 KiB
Rust

use crate::agent::context_compressor::estimate_tokens;
use crate::agent::media_handler::MediaHandlerRegistry;
use crate::agent::system_prompt::build_system_prompt;
use crate::bus::message::ContentBlock;
use crate::bus::{ChatMessage, MediaRef};
use crate::config::LLMProviderConfig;
use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
use crate::tools::ToolRegistry;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
/// Maximum size of a tool result before it is truncated.
///
/// Measured in bytes (`str::len`), not Unicode characters, so multi-byte
/// (e.g. CJK) output reaches the limit with fewer characters. Prevents context
/// overflow from large tool outputs.
const MAX_TOOL_RESULT_CHARS: usize = 16_000;
/// Byte length of the output tail that `truncate_tool_result` considers for
/// preservation when cutting an oversized result (the end of a tool's output
/// often holds its conclusion).
const TRUNCATION_SUFFIX_LEN: usize = 200;
/// Build provider content blocks for one message from its text and attachments.
///
/// Non-empty `text` becomes the first block. Each entry of `media_refs` is then
/// handled in order:
/// - if its `media_type` is listed in `input_types`, `registry` converts it into
///   blocks; if that conversion fails, a text placeholder naming the file and
///   the error is emitted instead;
/// - otherwise a text placeholder naming the file is emitted.
///
/// The result is never empty: a single empty text block is returned when there
/// is nothing else to send.
fn build_content_blocks(
    text: &str,
    media_refs: &[MediaRef],
    input_types: &[String],
    registry: &MediaHandlerRegistry,
) -> Vec<ContentBlock> {
    let mut parts: Vec<ContentBlock> = Vec::new();
    if !text.is_empty() {
        parts.push(ContentBlock::text(text));
    }
    for media in media_refs {
        // Media the model cannot ingest is described in text instead.
        if !input_types.contains(&media.media_type) {
            tracing::debug!(
                path = %media.path,
                media_type = %media.media_type,
                model_input_types = ?input_types,
                "Media type not supported by model, using text placeholder"
            );
            parts.push(ContentBlock::text(format!(
                "[用户发来了一个文件: {}]",
                media.path
            )));
            continue;
        }
        match registry.handle(&media.media_type, &media.path) {
            Ok(converted) => parts.extend(converted),
            Err(e) => {
                tracing::warn!(
                    path = %media.path,
                    media_type = %media.media_type,
                    error = %e,
                    "Media handler failed, falling back to text placeholder"
                );
                parts.push(ContentBlock::text(format!(
                    "[用户发来了一个文件,但处理失败: {}, 错误: {}]",
                    media.path, e
                )));
            }
        }
    }
    if parts.is_empty() {
        parts.push(ContentBlock::text(""));
    }
    parts
}
/// Truncate tool result if it exceeds MAX_TOOL_RESULT_CHARS.
/// Preserves the end of the output as it often contains the conclusion/useful result.
fn truncate_tool_result(output: &str) -> String {
if output.len() <= MAX_TOOL_RESULT_CHARS {
return output.to_string();
}
let truncated_start_len = output.len().saturating_sub(TRUNCATION_SUFFIX_LEN);
if truncated_start_len > MAX_TOOL_RESULT_CHARS {
// Even after removing suffix, still too long - take from beginning
format!(
"{}...\n\n[Output truncated - {} characters removed]",
&output[..output.ceil_char_boundary(MAX_TOOL_RESULT_CHARS - 100)],
output.len() - MAX_TOOL_RESULT_CHARS + 100
)
} else {
// Keep most of the end which usually contains the useful result
format!(
"...\n\n[Output truncated - {} characters removed]\n\n{}",
truncated_start_len,
&output[output.floor_char_boundary(truncated_start_len)..]
)
}
}
/// Outcome of recording a single tool call with `LoopDetector::record`.
#[derive(Clone, Debug, Eq, PartialEq)]
enum LoopDetectionResult {
    /// The call does not trigger a warning.
    Ok,
    /// The same tool was called with the same arguments often enough to warn;
    /// carries the warning text that is shown to the model.
    Warning(String),
}
/// Tunables for `LoopDetector`.
#[derive(Clone, Debug)]
struct LoopDetectorConfig {
    /// When `false`, recording a call never produces a warning.
    enabled: bool,
    /// Warning period: a warning is due each time the run of consecutive
    /// identical calls reaches a multiple of this value. Also sizes the
    /// detector's sliding window (`2 * warn_every` entries).
    warn_every: usize,
}
impl Default for LoopDetectorConfig {
    /// Detection enabled, warning every 5 identical consecutive calls.
    fn default() -> Self {
        LoopDetectorConfig {
            warn_every: 5,
            enabled: true,
        }
    }
}
/// One tool invocation as remembered by the loop detector's sliding window.
#[derive(Clone, Debug)]
struct ToolCallRecord {
    /// Name of the invoked tool.
    name: String,
    /// Hash of the call's arguments (computed by `hash_json_value`, which
    /// ignores object key order).
    args_hash: u64,
}
/// Stateful loop detector that watches for the same tool being called with the
/// same arguments over and over.
struct LoopDetector {
    config: LoopDetectorConfig,
    /// Most recent calls, oldest first, bounded to `2 * warn_every` entries
    /// (at least one entry is retained once a call has been recorded).
    window: VecDeque<ToolCallRecord>,
    /// Length of the current run of identical consecutive calls. Tracked
    /// separately from `window` so it is not capped by the window size.
    run_len: usize,
}
impl LoopDetector {
    /// Create a detector with an empty history.
    fn new(config: LoopDetectorConfig) -> Self {
        Self {
            window: VecDeque::with_capacity(config.warn_every * 2),
            run_len: 0,
            config,
        }
    }
    /// Record a completed tool call and check for loop patterns.
    ///
    /// Returns `Warning` each time the run of consecutive identical calls
    /// (same `name`, same key-order-normalised `args`) reaches a multiple of
    /// `warn_every`, e.g. on the 5th, 10th, 15th... repeat with the default
    /// config. Returns `Ok` otherwise, and always when detection is disabled.
    fn record(&mut self, name: &str, args: &serde_json::Value) -> LoopDetectionResult {
        if !self.config.enabled {
            return LoopDetectionResult::Ok;
        }
        let record = ToolCallRecord {
            name: name.to_string(),
            args_hash: hash_json_value(args),
        };
        // Extend the current run or start a new one. Counting here, rather than
        // scanning the window, matters: a scan is capped at the window size
        // (2 * warn_every). Once a run fills the window, the scanned count
        // stays at a multiple of warn_every, so every later call would warn.
        let repeats_previous = self
            .window
            .back()
            .is_some_and(|prev| prev.name == record.name && prev.args_hash == record.args_hash);
        self.run_len = if repeats_previous {
            self.run_len.saturating_add(1)
        } else {
            1
        };
        // Maintain sliding window
        if self.window.len() >= self.config.warn_every * 2 {
            self.window.pop_front();
        }
        self.window.push_back(record);
        // With warn_every == 0 this never warns (run_len is always >= 1 here).
        if self.config.warn_every > 0 && self.run_len.is_multiple_of(self.config.warn_every) {
            LoopDetectionResult::Warning(format!(
                "注意: 工具 '{}' 已连续执行 {} 次,参数相同。如果任务没有进展,请尝试其他方法。",
                name, self.run_len
            ))
        } else {
            LoopDetectionResult::Ok
        }
    }
}
/// Hash a JSON value so that objects differing only in key order hash equal.
///
/// `DefaultHasher` output is not guaranteed to be stable across Rust releases,
/// so the hash is only suitable for comparisons within one process.
fn hash_json_value(value: &serde_json::Value) -> u64 {
    let normalised = canonicalise_json(value);
    let mut state = std::collections::hash_map::DefaultHasher::new();
    normalised.hash(&mut state);
    state.finish()
}
/// Return a clone of value with all object keys sorted recursively.
fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
match value {
serde_json::Value::Object(map) => {
let mut sorted: Vec<(&String, &serde_json::Value)> = map.iter().collect();
sorted.sort_by_key(|(k, _)| *k);
let new_map: serde_json::Map<String, serde_json::Value> = sorted
.into_iter()
.map(|(k, v)| (k.clone(), canonicalise_json(v)))
.collect();
serde_json::Value::Object(new_map)
}
serde_json::Value::Array(arr) => {
serde_json::Value::Array(arr.iter().map(canonicalise_json).collect())
}
other => other.clone(),
}
}
/// AgentLoop - stateless agent that processes messages with tool-calling support.
///
/// The loop keeps no conversation history of its own; history is managed
/// externally by SessionManager and passed in on every `process` call.
pub struct AgentLoop {
    /// LLM backend used for every chat request.
    provider: Arc<dyn LLMProvider>,
    /// Tools offered to the LLM and executed on its behalf.
    tools: Arc<ToolRegistry>,
    /// Optional sink for tool-call start/finish events.
    observer: Option<Arc<dyn Observer>>,
    /// Maximum number of tool-calling rounds before a final summary is requested.
    max_iterations: usize,
    /// Passed to `build_system_prompt` when the history has no system message.
    workspace_dir: PathBuf,
    /// Passed to `build_system_prompt` when the history has no system message.
    model_name: String,
    /// Context window size used for preemptive trimming; 0 disables trimming.
    /// Compared against `estimate_tokens`, so presumably measured in tokens.
    context_window: usize,
    /// Optional channel that receives a short notice for each tool call.
    notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
    /// Media types the model accepts as input; other attachments become text
    /// placeholders.
    input_types: Vec<String>,
    /// Converts supported attachments into content blocks.
    media_registry: MediaHandlerRegistry,
}
/// Everything produced by one `AgentLoop::process` call.
#[derive(Clone, Debug)]
pub struct AgentProcessResult {
    /// The final assistant message (also the last entry of `emitted_messages`).
    pub final_response: ChatMessage,
    /// Messages generated during the call, in order: assistant turns that
    /// requested tools, the tool results, and the final assistant message.
    pub emitted_messages: Vec<ChatMessage>,
    /// Sum of the provider-reported `total_tokens` across all LLM calls, when known.
    pub total_tokens: Option<u32>,
}
impl AgentLoop {
/// Create a new AgentLoop with a provider created from config.
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let model_name = provider_config.model_id.clone();
let workspace_dir = provider_config.workspace_dir.clone();
let input_types = provider_config.input_types.clone();
let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider: Arc::from(provider),
tools: Arc::new(ToolRegistry::new()),
observer: None,
notify_tx: None,
context_window: 0,
max_iterations,
workspace_dir,
model_name,
input_types,
media_registry: MediaHandlerRegistry::with_defaults(),
})
}
/// Create a new AgentLoop with provider created from config and given tools.
pub fn with_tools(
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let model_name = provider_config.model_id.clone();
let workspace_dir = provider_config.workspace_dir.clone();
let input_types = provider_config.input_types.clone();
let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider: Arc::from(provider),
tools,
observer: None,
notify_tx: None,
context_window: 0,
max_iterations,
workspace_dir,
model_name,
input_types,
media_registry: MediaHandlerRegistry::with_defaults(),
})
}
/// Create a new AgentLoop with an existing shared provider.
pub fn with_provider(
provider: Arc<dyn LLMProvider>,
max_iterations: usize,
model_name: String,
workspace_dir: PathBuf,
input_types: Vec<String>,
) -> Self {
Self {
provider,
tools: Arc::new(ToolRegistry::new()),
observer: None,
notify_tx: None,
context_window: 0,
max_iterations,
workspace_dir,
model_name,
input_types,
media_registry: MediaHandlerRegistry::with_defaults(),
}
}
/// Create a new AgentLoop from an already-built shared provider and tool registry.
///
/// The loop starts with no observer, no notification channel, context-window
/// trimming disabled (`context_window == 0`), and the default media handlers.
pub fn with_provider_and_tools(
    provider: Arc<dyn LLMProvider>,
    tools: Arc<ToolRegistry>,
    max_iterations: usize,
    model_name: String,
    workspace_dir: PathBuf,
    input_types: Vec<String>,
) -> Self {
    let media_registry = MediaHandlerRegistry::with_defaults();
    Self {
        provider,
        tools,
        max_iterations,
        model_name,
        workspace_dir,
        input_types,
        media_registry,
        // Optional integrations start disabled; see the `with_*` builders.
        observer: None,
        notify_tx: None,
        // 0 disables preemptive trimming in `process`.
        context_window: 0,
    }
}
/// Set the context window size for preemptive trimming.
pub fn with_context_window(mut self, window: usize) -> Self {
self.context_window = window;
self
}
/// Set the workspace directory.
pub fn with_workspace_dir(mut self, dir: PathBuf) -> Self {
self.workspace_dir = dir;
self
}
/// Set an observer for tracking events.
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
self.observer = Some(observer);
self
}
pub fn with_notify(mut self, tx: tokio::sync::mpsc::UnboundedSender<String>) -> Self {
self.notify_tx = Some(tx);
self
}
/// Preemptive trim: replace oversized *old* tool results in place with a short
/// placeholder to relieve pressure on the context window.
///
/// Which messages are examined:
/// - only `messages[1 .. len - keep_recent]`;
/// - index 0 is always skipped, since it normally holds the system prompt;
/// - the last `keep_recent` messages are never modified.
///
/// Within that range, every `"tool"` message whose content is longer than
/// `max_chars` bytes is replaced by a placeholder naming the tool and the
/// original character count. Smaller results, and recent results of any size,
/// are left untouched.
///
/// Returns the number of messages replaced.
fn preemptive_trim_old_tool_results(
    &self,
    messages: &mut [ChatMessage],
    max_chars: usize,
    keep_recent: usize,
) -> usize {
    let end = messages.len().saturating_sub(keep_recent);
    let mut modified = 0;
    // `take(end)` protects the recent tail; `skip(1)` protects index 0.
    for message in messages.iter_mut().take(end).skip(1) {
        if message.role != "tool" || message.content.len() <= max_chars {
            continue;
        }
        let tool_name = message.tool_name.as_deref().unwrap_or("unknown");
        // Count characters, not bytes: byte length overstates multi-byte (e.g.
        // CJK) text by up to 4x, which would make the label misleading.
        let chars = message.content.chars().count();
        message.content = format!(
            "[Tool output ({}) — {} chars, omitted from context]",
            tool_name, chars
        );
        modified += 1;
    }
    modified
}
/// Borrow the shared tool registry; clone the `Arc` to keep a handle.
pub fn tools(&self) -> &Arc<ToolRegistry> {
    &self.tools
}
/// Convert a stored `ChatMessage` into the provider-facing `Message`.
///
/// Messages with attachments go through `build_content_blocks`, which respects
/// the model's `input_types`. Plain messages become a single text block, even
/// when the text is empty. Role, reasoning, tool-call id, tool name and tool
/// calls are copied across unchanged.
fn chat_message_to_llm_message(&self, m: &ChatMessage) -> Message {
    let has_media = !m.media_refs.is_empty();
    let content = if has_media {
        build_content_blocks(
            &m.content,
            &m.media_refs,
            &self.input_types,
            &self.media_registry,
        )
    } else {
        vec![ContentBlock::text(&m.content)]
    };
    Message {
        role: m.role.clone(),
        content,
        name: m.tool_name.clone(),
        tool_call_id: m.tool_call_id.clone(),
        tool_calls: m.tool_calls.clone(),
        reasoning_content: m.reasoning_content.clone(),
    }
}
/// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager.
///
/// This method supports multi-round tool calling: after executing tools,
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached
pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
// Build and inject system prompt if not present
let has_system = messages.first().is_some_and(|m| m.role == "system");
if !has_system {
let system_prompt = build_system_prompt(
&self.workspace_dir,
&self.model_name,
&self.tools,
None,
None,
false,
);
#[cfg(debug_assertions)]
tracing::debug!("System prompt injected:\n{}", system_prompt);
messages.insert(0, ChatMessage::system(system_prompt));
}
// Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new();
let mut accumulated_tokens: u32 = 0;
for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)]
tracing::debug!(iteration, "Agent iteration started");
// Preemptive context check: trim old tool results if token estimate
// exceeds 80% of context window to prevent mid-loop overflow.
if self.context_window > 0 {
let estimated = estimate_tokens(&messages);
let danger = (self.context_window as f64 * 0.8) as usize;
if estimated > danger {
let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
if trimmed > 0 {
#[cfg(debug_assertions)]
tracing::debug!(
estimated,
danger,
trimmed_msgs = trimmed,
"Preemptive tool-result trim applied in loop"
);
}
}
}
// Convert messages to LLM format
let messages_for_llm: Vec<Message> = messages
.iter()
.map(|m| self.chat_message_to_llm_message(m))
.collect();
// Build request
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
} else {
None
};
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
// Call LLM
let response = (*self.provider).chat(request).await.map_err(|e| {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
accumulated_tokens += response.usage.total_tokens;
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
response_len = response.content.len(),
tool_calls_len = response.tool_calls.len(),
"LLM response received"
);
// If no tool calls, this is the final response
if response.tool_calls.is_empty() {
let mut assistant_message = ChatMessage::assistant(response.content);
assistant_message.reasoning_content = response.reasoning_content;
emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
total_tokens: Some(accumulated_tokens),
});
}
// Execute tool calls — log and notify immediately
{
let tools_info: Vec<String> = response
.tool_calls
.iter()
.map(|tc| {
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
let s = format!("{}:{}", tc.name, args);
if let Some(ref tx) = self.notify_tx {
let _ = tx.send(format!("调用工具 {}", s));
}
s
})
.collect();
tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools");
}
// Add assistant message with tool calls
let mut assistant_message = ChatMessage::assistant_with_tool_calls(
response.content.clone(),
response.tool_calls.clone(),
);
assistant_message.reasoning_content = response.reasoning_content;
messages.push(assistant_message.clone());
emitted_messages.push(assistant_message);
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
// Log function call with name and arguments
let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
};
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
// Truncate tool result if too large
let truncated_output = truncate_tool_result(&result.output);
// Record tool call and check for loops
let loop_result = loop_detector.record(&tool_call.name, &tool_call.arguments);
match loop_result {
LoopDetectionResult::Warning(msg) => {
// Add warning and proceed
tracing::warn!(
tool = %tool_call.name,
"Loop warning: {}",
msg
);
let tool_message = ChatMessage::tool(
tool_call.id.clone(),
tool_call.name.clone(),
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
);
messages.push(tool_message.clone());
emitted_messages.push(tool_message);
}
LoopDetectionResult::Ok => {
let tool_message = ChatMessage::tool(
tool_call.id.clone(),
tool_call.name.clone(),
truncated_output,
);
messages.push(tool_message.clone());
emitted_messages.push(tool_message);
}
}
}
// Loop continues to next iteration with updated messages
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
}
// Max iterations reached - ask LLM for a summary based on completed work
tracing::warn!("Max iterations reached, requesting final summary from LLM");
// Add a message asking for summary
let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far.",
);
messages.push(summary_request);
// Convert messages to LLM format
let messages_for_llm: Vec<Message> = messages
.iter()
.map(|m| self.chat_message_to_llm_message(m))
.collect();
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools: None, // No tools in final summary call
};
match (*self.provider).chat(request).await {
Ok(response) => {
accumulated_tokens += response.usage.total_tokens;
let mut assistant_message = ChatMessage::assistant(response.content);
assistant_message.reasoning_content = response.reasoning_content;
emitted_messages.push(assistant_message.clone());
Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
total_tokens: Some(accumulated_tokens),
})
}
Err(e) => {
// Fallback if summary call fails
tracing::error!(error = %e, "Failed to get summary from LLM");
let final_message = ChatMessage::assistant(format!(
"I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
self.max_iterations
));
emitted_messages.push(final_message.clone());
Ok(AgentProcessResult {
final_response: final_message,
emitted_messages,
total_tokens: if accumulated_tokens > 0 {
Some(accumulated_tokens)
} else {
None
},
})
}
}
}
/// Decide whether a batch of tool calls may run concurrently.
///
/// Concurrency is allowed only when all of the following hold:
/// - the batch has more than one call;
/// - none of the calls is `tool_search`;
/// - every call targets a registered tool that reports itself
///   `concurrency_safe()`.
///
/// Unknown tools count as unsafe.
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
    // A single call (or none) gains nothing from concurrency.
    if tool_calls.len() < 2 {
        return false;
    }
    // tool_search must run sequentially to avoid MCP activation race conditions
    let has_tool_search = tool_calls.iter().any(|tc| tc.name == "tool_search");
    !has_tool_search
        && tool_calls
            .iter()
            .all(|tc| self.tools.get(&tc.name).is_some_and(|t| t.concurrency_safe()))
}
/// Run `tool_calls` and return one outcome per call, in call order.
///
/// Calls run concurrently only when `should_execute_in_parallel` approves the
/// whole batch; otherwise they run one after another.
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let count = tool_calls.len();
    if !self.should_execute_in_parallel(tool_calls) {
        tracing::debug!("Executing {} tools sequentially", count);
        return self.execute_tools_sequential(tool_calls).await;
    }
    tracing::debug!("Executing {} tools in parallel", count);
    self.execute_tools_parallel(tool_calls).await
}
/// Drive every call concurrently on the current task and wait for all of
/// them; `join_all` returns the outputs in input order.
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let pending = tool_calls.iter().map(|call| self.execute_one_tool(call));
    futures_util::future::join_all(pending).await
}
/// Run the calls one at a time, awaiting each before starting the next.
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
    let mut results = Vec::with_capacity(tool_calls.len());
    for call in tool_calls {
        let outcome = self.execute_one_tool(call).await;
        results.push(outcome);
    }
    results
}
/// Execute a single tool and return the outcome with event tracking.
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let start = Instant::now();
let tool_name = tool_call.name.clone();
// Record ToolCallStart event
if let Some(ref observer) = self.observer {
observer.record_event(&ObserverEvent::ToolCallStart {
tool: tool_name.clone(),
arguments: Some(truncate_args(&tool_call.arguments, 300)),
});
}
let result = self.execute_tool_internal(tool_call).await;
let duration = start.elapsed();
// Record ToolCall event
if let Some(ref observer) = self.observer {
observer.record_event(&ObserverEvent::ToolCall {
tool: tool_name.clone(),
duration,
success: result.success,
});
}
// Apply duration
ToolExecutionOutcome { duration, ..result }
}
/// Internal tool execution without event tracking.
///
/// Possible outcomes:
/// - Unknown tools yield a failure outcome.
/// - Successful results are wrapped as-is.
/// - Failed results and execution errors become failures whose output is
///   prefixed with `"Error: "`.
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
    let Some(tool) = self.tools.get(&tool_call.name) else {
        tracing::warn!(tool = %tool_call.name, "Tool not found");
        return ToolExecutionOutcome::failure(
            format!("Error: Tool '{}' not found", tool_call.name),
            Some(format!("Tool '{}' not found", tool_call.name)),
        );
    };
    match tool.execute(tool_call.arguments.clone()).await {
        Ok(result) if result.success => ToolExecutionOutcome::success(result.output),
        Ok(result) => {
            // A failure without an error string would otherwise reach the
            // model as a bare "Error: ". Fall back to the tool's own output so
            // the failure detail is not lost.
            let error = match result.error {
                Some(e) if !e.is_empty() => e,
                _ => result.output,
            };
            ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
        }
        Err(e) => {
            tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
            ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
        }
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observability::{MultiObserver, Observer};
    /// Observer that stores every event it receives.
    struct TestObserver {
        events: std::sync::Mutex<Vec<ObserverEvent>>,
    }
    impl TestObserver {
        fn new() -> Self {
            Self {
                events: std::sync::Mutex::new(Vec::new()),
            }
        }
    }
    impl Observer for TestObserver {
        fn record_event(&self, event: &ObserverEvent) {
            self.events.lock().unwrap().push(event.clone());
        }
        fn name(&self) -> &str {
            "test_observer"
        }
    }
    #[tokio::test]
    async fn test_observer_receives_tool_events() {
        // Verify MultiObserver works
        let mut multi = MultiObserver::new();
        multi.add_observer(Box::new(TestObserver::new()));
        let event = ObserverEvent::ToolCallStart {
            tool: "test".to_string(),
            arguments: Some("{}".to_string()),
        };
        multi.record_event(&event);
        // Just verify the structure works
        assert_eq!(multi.len(), 1);
    }
    #[test]
    fn test_should_execute_in_parallel_single_tool() {
        // NOTE: exercising `should_execute_in_parallel` itself needs an
        // `AgentLoop` (and therefore a provider); this only pins the
        // single-call precondition the method checks first.
        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "test".to_string(),
            arguments: serde_json::json!({}),
        }];
        assert!(calls.len() <= 1);
    }
    #[test]
    fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
        use crate::providers::Message;
        // Mirrors the plain-text path of `chat_message_to_llm_message` inline,
        // since calling the method requires an `AgentLoop`.
        let chat_message = ChatMessage::assistant_with_tool_calls(
            "calling tool",
            vec![ToolCall {
                id: "call_1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({ "expression": "2+2" }),
            }],
        );
        let content = vec![ContentBlock::text(&chat_message.content)];
        let provider_message = Message {
            role: chat_message.role.clone(),
            content,
            reasoning_content: None,
            tool_call_id: chat_message.tool_call_id.clone(),
            name: chat_message.tool_name.clone(),
            tool_calls: chat_message.tool_calls.clone(),
        };
        assert_eq!(provider_message.role, "assistant");
        assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
        assert_eq!(
            provider_message.tool_calls.as_ref().unwrap()[0].id,
            "call_1"
        );
        assert_eq!(
            provider_message.tool_calls.as_ref().unwrap()[0].name,
            "calculator"
        );
    }
    #[test]
    fn test_build_content_blocks_keeps_text_with_media() {
        let registry = MediaHandlerRegistry::new();
        let blocks = build_content_blocks(
            "先看这段文字",
            &[MediaRef {
                path: "missing.png".to_string(),
                media_type: "image".to_string(),
            }],
            &[],
            &registry,
        );
        assert!(matches!(blocks.first(), Some(ContentBlock::Text { text }) if text == "先看这段文字"));
        assert!(matches!(blocks.get(1), Some(ContentBlock::Text { text }) if text.contains("用户发来了一个文件")));
    }
    #[test]
    fn test_truncate_tool_result_keeps_output_within_limit() {
        assert_eq!(truncate_tool_result("hello"), "hello");
        let at_limit = "a".repeat(MAX_TOOL_RESULT_CHARS);
        assert_eq!(truncate_tool_result(&at_limit), at_limit);
    }
    #[test]
    fn test_truncate_tool_result_multibyte_is_bounded() {
        // 40_000 bytes of 2-byte characters: slicing must respect boundaries.
        let big = "é".repeat(20_000);
        let out = truncate_tool_result(&big);
        assert!(out.len() <= MAX_TOOL_RESULT_CHARS);
        assert!(out.contains("[Output truncated - "));
    }
    #[test]
    fn test_loop_detector_warns_on_fifth_identical_call() {
        let mut detector = LoopDetector::new(LoopDetectorConfig::default());
        let args = serde_json::json!({ "path": "a.txt" });
        for _ in 0..4 {
            assert_eq!(detector.record("read_file", &args), LoopDetectionResult::Ok);
        }
        assert!(matches!(
            detector.record("read_file", &args),
            LoopDetectionResult::Warning(_)
        ));
        assert_eq!(detector.record("read_file", &args), LoopDetectionResult::Ok);
    }
    #[test]
    fn test_loop_detector_resets_on_different_call() {
        let mut detector = LoopDetector::new(LoopDetectorConfig::default());
        let a = serde_json::json!({ "path": "a.txt" });
        let b = serde_json::json!({ "path": "b.txt" });
        for _ in 0..4 {
            assert_eq!(detector.record("read_file", &a), LoopDetectionResult::Ok);
        }
        assert_eq!(detector.record("read_file", &b), LoopDetectionResult::Ok);
        for _ in 0..4 {
            assert_eq!(detector.record("read_file", &a), LoopDetectionResult::Ok);
        }
        assert!(matches!(
            detector.record("read_file", &a),
            LoopDetectionResult::Warning(_)
        ));
    }
    #[test]
    fn test_loop_detector_disabled_never_warns() {
        let mut detector = LoopDetector::new(LoopDetectorConfig {
            enabled: false,
            warn_every: 1,
        });
        let args = serde_json::json!({});
        for _ in 0..3 {
            assert_eq!(detector.record("noop", &args), LoopDetectionResult::Ok);
        }
    }
    #[test]
    fn test_hash_json_value_ignores_key_order() {
        let x = serde_json::json!({ "a": 1, "b": [1, 2], "c": { "d": true, "e": null } });
        let y = serde_json::json!({ "c": { "e": null, "d": true }, "b": [1, 2], "a": 1 });
        assert_eq!(hash_json_value(&x), hash_json_value(&y));
        assert_ne!(
            hash_json_value(&serde_json::json!({ "a": 1 })),
            hash_json_value(&serde_json::json!({ "a": 2 }))
        );
    }
    #[test]
    fn test_canonicalise_json_sorts_nested_keys() {
        let v = serde_json::json!({ "b": 1, "a": { "d": 2, "c": 3 } });
        let canonical = serde_json::to_string(&canonicalise_json(&v)).unwrap();
        assert_eq!(canonical, r#"{"a":{"c":3,"d":2},"b":1}"#);
    }
}
/// Errors returned while constructing or running an `AgentLoop`.
#[derive(Debug)]
pub enum AgentError {
    /// The LLM provider could not be created from its configuration.
    ProviderCreation(String),
    /// A request to the LLM provider failed.
    LlmError(String),
    /// Catch-all; its message is displayed without a prefix.
    Other(String),
}
impl std::fmt::Display for AgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::ProviderCreation(e) => ("Provider creation error: ", e),
            Self::LlmError(e) => ("LLM error: ", e),
            Self::Other(e) => ("", e),
        };
        write!(f, "{}{}", prefix, detail)
    }
}
impl std::error::Error for AgentError {}