feat: agent loop improvements

1. AgentLoop supports multi-round tool calling
2. Concurrent tool execution
3. Observability improvements
This commit is contained in:
xiaoxixi 2026-04-12 11:02:48 +08:00
parent 862eb1115a
commit 394b5fdd6a
8 changed files with 528 additions and 102 deletions

3
.gitignore vendored
View File

@ -1 +1,4 @@
/target
reference/**
.env
*.env

View File

@ -1,10 +1,14 @@
use crate::bus::message::ContentBlock;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::observability::{
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
};
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::tools::ToolRegistry;
use std::io::Read;
use std::sync::Arc;
use std::time::Instant;
/// Build content blocks from text and media paths
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
@ -47,192 +51,337 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error
Ok((mime, encoded))
}
/// Stateless AgentLoop - history is managed externally by SessionManager
/// Convert ChatMessage to LLM Message format.
///
/// A message without media references becomes a single text content block;
/// one with media references is expanded via `build_content_blocks`.
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
    let content = match m.media_refs.as_slice() {
        [] => vec![ContentBlock::text(&m.content)],
        refs => build_content_blocks(&m.content, refs),
    };
    Message {
        role: m.role.clone(),
        content,
        tool_call_id: m.tool_call_id.clone(),
        name: m.tool_name.clone(),
    }
}
/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop {
    /// LLM backend used for chat completions.
    provider: Box<dyn LLMProvider>,
    /// Registry of tools the agent may invoke; shared across tasks.
    tools: Arc<ToolRegistry>,
    /// Optional sink for telemetry events (tool-call lifecycle).
    observer: Option<Arc<dyn Observer>>,
    /// Upper bound on LLM -> tool -> LLM rounds per process() call.
    max_iterations: usize,
}
impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider,
tools: Arc::new(ToolRegistry::new()),
observer: None,
max_iterations,
})
}
    /// Create an AgentLoop backed by the given shared tool registry.
    ///
    /// # Errors
    /// Returns `AgentError::ProviderCreation` if the provider cannot be
    /// built from `provider_config`.
    pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
        // Copy the iteration cap out first: create_provider consumes the config.
        let max_iterations = provider_config.max_tool_iterations;
        let provider = create_provider(provider_config)
            .map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
        Ok(Self {
            provider,
            tools,
            observer: None,
            max_iterations,
        })
    }
    /// Set an observer for tracking events.
    ///
    /// Builder-style: consumes and returns `self` so it can be chained
    /// after construction.
    pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
        self.observer = Some(observer);
        self
    }
    /// Shared tool registry backing this agent.
    pub fn tools(&self) -> &Arc<ToolRegistry> {
        &self.tools
    }
/// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager.
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
let messages_for_llm: Vec<Message> = messages
.iter()
.map(|m| {
let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
#[cfg(debug_assertions)]
tracing::debug!(media_refs = ?m.media_refs, "Building content blocks with media");
build_content_blocks(&m.content, &m.media_refs)
};
#[cfg(debug_assertions)]
tracing::debug!(role = %m.role, content_len = %m.content.len(), media_refs_len = %m.media_refs.len(), "ChatMessage converted to LLM Message");
Message {
role: m.role.clone(),
content,
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
}
})
.collect();
///
/// This method supports multi-round tool calling: after executing tools,
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(history_len = messages.len(), "Sending request to LLM");
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
} else {
None
};
for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)]
tracing::debug!(iteration, "Agent iteration started");
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
// Convert messages to LLM format
let messages_for_llm: Vec<Message> = messages
.iter()
.map(chat_message_to_llm_message)
.collect();
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
// Build request
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
} else {
None
};
#[cfg(debug_assertions)]
tracing::debug!(
response_len = response.content.len(),
tool_calls_len = response.tool_calls.len(),
"LLM response received"
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
if !response.tool_calls.is_empty() {
tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools");
// Call LLM
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
let mut updated_messages = messages.clone();
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
response_len = response.content.len(),
tool_calls_len = response.tool_calls.len(),
"LLM response received"
);
// If no tool calls, this is the final response
if response.tool_calls.is_empty() {
let assistant_message = ChatMessage::assistant(response.content);
return Ok(assistant_message);
}
// Execute tool calls
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
// Add assistant message with tool calls
let assistant_message = ChatMessage::assistant(response.content.clone());
updated_messages.push(assistant_message.clone());
messages.push(assistant_message.clone());
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
let tool_message = ChatMessage::tool(
tool_call.id.clone(),
tool_call.name.clone(),
result.clone(),
result.output.clone(),
);
updated_messages.push(tool_message);
messages.push(tool_message);
}
return self.continue_with_tool_results(updated_messages).await;
// Loop continues to next iteration with updated messages
#[cfg(debug_assertions)]
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
}
let assistant_message = ChatMessage::assistant(response.content);
Ok(assistant_message)
// Max iterations reached
let final_message = ChatMessage::assistant(
format!("I reached the maximum number of tool call iterations ({}) without completing the task. Please try breaking the task into smaller steps.", self.max_iterations)
);
Ok(final_message)
}
async fn continue_with_tool_results(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
let messages_for_llm: Vec<Message> = messages
/// Determine whether to execute tools in parallel or sequentially.
///
/// Returns true if:
/// - There are multiple tool calls
/// - None of the tools require sequential execution (tool_search, non-concurrency-safe)
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
if tool_calls.len() <= 1 {
return false;
}
// tool_search must run sequentially to avoid MCP activation race conditions
if tool_calls.iter().any(|tc| tc.name == "tool_search") {
return false;
}
// All tools must be concurrency-safe to run in parallel
tool_calls.iter().all(|tc| {
self.tools
.get(&tc.name)
.map(|t| t.concurrency_safe())
.unwrap_or(false)
})
}
    /// Execute multiple tool calls, choosing parallel or sequential based on conditions.
    ///
    /// Dispatch is decided by `should_execute_in_parallel`; outcomes are
    /// returned in the same order as `tool_calls` on both paths (join_all
    /// preserves input order).
    async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
        if self.should_execute_in_parallel(tool_calls) {
            tracing::debug!("Executing {} tools in parallel", tool_calls.len());
            self.execute_tools_parallel(tool_calls).await
        } else {
            tracing::debug!("Executing {} tools sequentially", tool_calls.len());
            self.execute_tools_sequential(tool_calls).await
        }
    }
/// Execute tools in parallel using join_all.
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
let futures: Vec<_> = tool_calls
.iter()
.map(|m| {
let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
build_content_blocks(&m.content, &m.media_refs)
};
Message {
role: m.role.clone(),
content,
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
}
})
.map(|tc| self.execute_one_tool(tc))
.collect();
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
} else {
None
};
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM continuation request failed");
AgentError::LlmError(e.to_string())
})?;
let assistant_message = ChatMessage::assistant(response.content);
Ok(assistant_message)
futures_util::future::join_all(futures).await
}
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> {
let mut results = Vec::with_capacity(tool_calls.len());
/// Execute tools sequentially.
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
let mut outcomes = Vec::with_capacity(tool_calls.len());
for tool_call in tool_calls {
let result = self.execute_tool(tool_call).await;
results.push(result);
outcomes.push(self.execute_one_tool(tool_call).await);
}
results
outcomes
}
async fn execute_tool(&self, tool_call: &ToolCall) -> String {
    /// Execute a single tool and return the outcome with event tracking.
    ///
    /// Emits `ToolCallStart` before execution and `ToolCall` (with measured
    /// duration and success flag) after, when an observer is configured.
    /// Arguments in the start event are truncated to 300 chars to bound
    /// event size.
    async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
        let start = Instant::now();
        let tool_name = tool_call.name.clone();
        // Record ToolCallStart event
        if let Some(ref observer) = self.observer {
            observer.record_event(&ObserverEvent::ToolCallStart {
                tool: tool_name.clone(),
                arguments: Some(truncate_args(&tool_call.arguments, 300)),
            });
        }
        let result = self.execute_tool_internal(tool_call).await;
        let duration = start.elapsed();
        // Record ToolCall event
        if let Some(ref observer) = self.observer {
            observer.record_event(&ObserverEvent::ToolCall {
                tool: tool_name.clone(),
                duration,
                success: result.success,
            });
        }
        // Apply duration: the inner result carries Duration::ZERO (see the
        // ToolExecutionOutcome constructors); stamp the measured value here.
        ToolExecutionOutcome {
            duration,
            ..result
        }
    }
/// Internal tool execution without event tracking.
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let tool = match self.tools.get(&tool_call.name) {
Some(t) => t,
None => {
tracing::warn!(tool = %tool_call.name, "Tool not found");
return format!("Error: Tool '{}' not found", tool_call.name);
return ToolExecutionOutcome::failure(
format!("Error: Tool '{}' not found", tool_call.name),
Some(format!("Tool '{}' not found", tool_call.name)),
);
}
};
match tool.execute(tool_call.arguments.clone()).await {
Ok(result) => {
if result.success {
result.output
ToolExecutionOutcome::success(result.output)
} else {
format!("Error: {}", result.error.unwrap_or_default())
let error = result.error.unwrap_or_default();
ToolExecutionOutcome::failure(
format!("Error: {}", error),
Some(error),
)
}
}
Err(e) => {
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
format!("Error: {}", e)
ToolExecutionOutcome::failure(
format!("Error: {}", e),
Some(e.to_string()),
)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observability::{MultiObserver, Observer};
    /// Minimal observer that records every event into a mutex-guarded Vec.
    struct TestObserver {
        // Written by record_event; NOTE(review): never read back in these
        // tests, so actual event delivery is not asserted yet.
        events: std::sync::Mutex<Vec<ObserverEvent>>,
    }
    impl TestObserver {
        fn new() -> Self {
            Self {
                events: std::sync::Mutex::new(Vec::new()),
            }
        }
    }
    impl Observer for TestObserver {
        fn record_event(&self, event: &ObserverEvent) {
            self.events.lock().unwrap().push(event.clone());
        }
        fn name(&self) -> &str {
            "test_observer"
        }
    }
    #[tokio::test]
    async fn test_observer_receives_tool_events() {
        // Verify MultiObserver works
        let mut multi = MultiObserver::new();
        multi.add_observer(Box::new(TestObserver::new()));
        let event = ObserverEvent::ToolCallStart {
            tool: "test".to_string(),
            arguments: Some("{}".to_string()),
        };
        multi.record_event(&event);
        // Just verify the structure works
        // NOTE(review): only checks registration count, not that the event
        // reached the observer; expose recorded events to assert delivery.
        assert_eq!(multi.len(), 1);
    }
    #[test]
    fn test_should_execute_in_parallel_single_tool() {
        // Would need a proper setup with AgentLoop to test fully
        // For now, just verify the logic: single tool should return false
        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "test".to_string(),
            arguments: serde_json::json!({}),
        }];
        // If there's only 1 tool, should return false regardless
        // NOTE(review): this asserts a tautology (len <= 1) instead of
        // calling should_execute_in_parallel; needs an AgentLoop fixture.
        assert_eq!(calls.len() <= 1, true);
    }
}
#[derive(Debug)]
pub enum AgentError {
ProviderCreation(String),

View File

@ -45,7 +45,7 @@ fn default_media_dir() -> String {
}
fn default_reaction_emoji() -> String {
"THUMBSUP".to_string()
"Typing".to_string()
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -73,6 +73,12 @@ pub struct ModelConfig {
pub struct AgentConfig {
pub provider: String,
pub model: String,
#[serde(default = "default_max_tool_iterations")]
pub max_tool_iterations: usize,
}
/// Serde default for `AgentConfig::max_tool_iterations`: caps the number of
/// LLM -> tool -> LLM rounds per agent process() call.
fn default_max_tool_iterations() -> usize {
    20
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -132,6 +138,7 @@ pub struct LLMProviderConfig {
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub model_extra: HashMap<String, serde_json::Value>,
pub max_tool_iterations: usize,
}
fn get_default_config_path() -> PathBuf {
@ -191,6 +198,7 @@ impl Config {
temperature: model.temperature,
max_tokens: model.max_tokens,
model_extra: model.extra.clone(),
max_tool_iterations: agent.max_tool_iterations,
})
}
}

View File

@ -8,4 +8,5 @@ pub mod client;
pub mod protocol;
pub mod channels;
pub mod logging;
pub mod observability;
pub mod tools;

257
src/observability/mod.rs Normal file
View File

@ -0,0 +1,257 @@
//! Observability module for tracking agent and tool events.
//!
//! This module provides an Observer pattern for emitting and collecting
//! telemetry events during agent execution.
use std::time::Duration;
/// Events emitted during agent and tool execution.
///
/// Variants cover the tool-call lifecycle (start/completion) and the
/// agent run lifecycle (start/end).
#[derive(Debug, Clone)]
pub enum ObserverEvent {
    /// Emitted before a tool starts executing.
    ToolCallStart {
        /// Name of the tool being invoked.
        tool: String,
        /// Stringified (possibly truncated) call arguments, if captured.
        arguments: Option<String>,
    },
    /// Emitted after a tool completes execution.
    ToolCall {
        /// Name of the tool that ran.
        tool: String,
        /// Wall-clock time the tool took.
        duration: Duration,
        /// Whether the tool reported success.
        success: bool,
    },
    /// Emitted when the agent starts processing.
    AgentStart {
        provider: String,
        model: String,
    },
    /// Emitted when the agent finishes processing.
    AgentEnd {
        provider: String,
        model: String,
        /// Total processing time for the agent run.
        duration: Duration,
        /// Token usage, when the provider reports it.
        tokens_used: Option<u64>,
    },
}
/// Observer trait for receiving events.
///
/// Implement this trait to receive events during agent execution.
/// Observers are shared across async tasks, so implementations must be
/// Send + Sync.
pub trait Observer: Send + Sync + 'static {
    /// Record a single event.
    ///
    /// Called synchronously on the tool-execution path, so implementations
    /// should be cheap or buffer internally (and drain in `flush`).
    fn record_event(&self, event: &ObserverEvent);
    /// Get the observer's name for identification.
    fn name(&self) -> &str;
    /// Flush any buffered events (default no-op).
    fn flush(&self) {}
}
/// Outcome of a single tool execution.
#[derive(Debug, Clone)]
pub struct ToolExecutionOutcome {
    /// The output from the tool execution.
    /// On failure this holds the "Error: ..." text sent back to the LLM.
    pub output: String,
    /// Whether the tool executed successfully.
    pub success: bool,
    /// The error reason if the tool failed.
    pub error_reason: Option<String>,
    /// How long the tool took to execute.
    /// Constructors set Duration::ZERO; the caller stamps the measured value.
    pub duration: Duration,
}
impl ToolExecutionOutcome {
    /// Create a successful outcome with zero duration.
    pub fn success(output: String) -> Self {
        Self::success_with_duration(output, Duration::ZERO)
    }
    /// Create a successful outcome with duration.
    pub fn success_with_duration(output: String, duration: Duration) -> Self {
        Self {
            output,
            success: true,
            error_reason: None,
            duration,
        }
    }
    /// Create a failed outcome with zero duration.
    pub fn failure(output: String, error_reason: Option<String>) -> Self {
        Self::failure_with_duration(output, error_reason, Duration::ZERO)
    }
    /// Create a failed outcome with duration.
    pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
        Self {
            output,
            success: false,
            error_reason,
            duration,
        }
    }
}
/// MultiObserver broadcasts events to multiple observers.
pub struct MultiObserver {
    /// Registered observers, notified in insertion order.
    observers: Vec<Box<dyn Observer>>,
}
impl MultiObserver {
    /// Create a new MultiObserver.
    pub fn new() -> Self {
        Self {
            observers: Vec::new(),
        }
    }
    /// Add an observer.
    pub fn add_observer(&mut self, observer: Box<dyn Observer>) {
        self.observers.push(observer);
    }
    /// Get the number of registered observers.
    pub fn len(&self) -> usize {
        self.observers.len()
    }
    /// Check if there are no observers.
    pub fn is_empty(&self) -> bool {
        self.observers.is_empty()
    }
}
impl Default for MultiObserver {
    /// Equivalent to `MultiObserver::new()`: no observers registered.
    fn default() -> Self {
        Self::new()
    }
}
impl Observer for MultiObserver {
    /// Broadcast the event to every registered observer, in insertion order.
    fn record_event(&self, event: &ObserverEvent) {
        self.observers
            .iter()
            .for_each(|observer| observer.record_event(event));
    }
    /// Flush every registered observer.
    fn flush(&self) {
        self.observers.iter().for_each(|observer| observer.flush());
    }
    fn name(&self) -> &str {
        "multi_observer"
    }
}
/// Truncate arguments for logging to avoid oversized events.
///
/// Serializes `args` to JSON and, if the string exceeds `max_len` bytes,
/// cuts it at the largest UTF-8 character boundary not past `max_len`
/// and appends a "...truncated" marker.
pub fn truncate_args(args: &serde_json::Value, max_len: usize) -> String {
    let args_str = args.to_string();
    if args_str.len() <= max_len {
        return args_str;
    }
    // Byte-slicing at an arbitrary index panics if it falls inside a
    // multi-byte UTF-8 sequence (JSON string values may contain any
    // Unicode), so back up to the nearest char boundary first.
    let mut cut = max_len;
    while !args_str.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...truncated", &args_str[..cut])
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Named observer that records every event into a mutex-guarded Vec.
    struct TestObserver {
        name: String,
        // Written by record_event; NOTE(review): never read back in these
        // tests, so broadcast delivery is not actually asserted.
        events: std::sync::Mutex<Vec<ObserverEvent>>,
    }
    impl TestObserver {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                events: std::sync::Mutex::new(Vec::new()),
            }
        }
    }
    impl Observer for TestObserver {
        fn record_event(&self, event: &ObserverEvent) {
            self.events.lock().unwrap().push(event.clone());
        }
        fn name(&self) -> &str {
            &self.name
        }
    }
    #[test]
    fn test_tool_execution_outcome_success() {
        let outcome = ToolExecutionOutcome::success("output content".to_string());
        assert!(outcome.success);
        assert_eq!(outcome.output, "output content");
        assert!(outcome.error_reason.is_none());
        // Zero-duration constructor: caller is expected to stamp duration later.
        assert_eq!(outcome.duration, Duration::ZERO);
    }
    #[test]
    fn test_tool_execution_outcome_success_with_duration() {
        let outcome = ToolExecutionOutcome::success_with_duration(
            "output content".to_string(),
            Duration::from_millis(100),
        );
        assert!(outcome.success);
        assert_eq!(outcome.duration, Duration::from_millis(100));
    }
    #[test]
    fn test_tool_execution_outcome_failure() {
        let outcome = ToolExecutionOutcome::failure(
            "error output".to_string(),
            Some("error reason".to_string()),
        );
        assert!(!outcome.success);
        assert_eq!(outcome.output, "error output");
        assert_eq!(outcome.error_reason, Some("error reason".to_string()));
        assert_eq!(outcome.duration, Duration::ZERO);
    }
    #[test]
    fn test_multi_observer_broadcasts() {
        let mut multi = MultiObserver::new();
        let obs1 = Box::new(TestObserver::new("obs1"));
        let obs2 = Box::new(TestObserver::new("obs2"));
        multi.add_observer(obs1);
        multi.add_observer(obs2);
        let event = ObserverEvent::ToolCallStart {
            tool: "test_tool".to_string(),
            arguments: Some("{}".to_string()),
        };
        multi.record_event(&event);
        // Both observers should have received the event
        // NOTE(review): only the observer count is asserted; keep handles to
        // the observers (e.g. via Arc) to verify each actually got the event.
        assert_eq!(multi.len(), 2);
    }
    #[test]
    fn test_truncate_args() {
        // Short payload passes through unchanged.
        let args = serde_json::json!({"key": "value"});
        assert_eq!(truncate_args(&args, 100), args.to_string());
        // Long payload is cut and marked.
        let long_args = serde_json::json!({"key": "a".repeat(200)});
        let truncated = truncate_args(&long_args, 50);
        assert!(truncated.ends_with("...truncated"));
        assert!(truncated.len() < long_args.to_string().len());
    }
}

View File

@ -23,6 +23,12 @@ impl ToolRegistry {
self.tools.get(name)
}
    /// Get all registered tools.
    /// Used for concurrent tool execution when we need to look up tools by name.
    // NOTE(review): Vec<&Box<dyn ToolTrait>> triggers clippy::borrowed_box;
    // Vec<&dyn ToolTrait> would be more idiomatic, but that changes the
    // public return type, so it is left as-is pending a coordinated rename.
    pub fn get_all(&self) -> Vec<&Box<dyn ToolTrait>> {
        self.tools.values().collect()
    }
pub fn get_definitions(&self) -> Vec<Tool> {
self.tools
.values()

View File

@ -23,6 +23,7 @@ fn load_config() -> Option<LLMProviderConfig> {
temperature: Some(0.0),
max_tokens: Some(100),
model_extra: HashMap::new(),
max_tool_iterations: 20,
})
}

View File

@ -23,6 +23,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
temperature: Some(0.0),
max_tokens: Some(100),
model_extra: HashMap::new(),
max_tool_iterations: 20,
})
}