403 lines
13 KiB
Rust
403 lines
13 KiB
Rust
use crate::bus::message::ContentBlock;
|
|
use crate::bus::ChatMessage;
|
|
use crate::config::LLMProviderConfig;
|
|
use crate::observability::{
|
|
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
|
|
};
|
|
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
|
use crate::tools::ToolRegistry;
|
|
use std::io::Read;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
/// Build content blocks from text and media paths
|
|
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
|
let mut blocks = Vec::new();
|
|
|
|
// Add text block if there's text
|
|
if !text.is_empty() {
|
|
blocks.push(ContentBlock::text(text));
|
|
}
|
|
|
|
// Add image blocks for media paths
|
|
for path in media_paths {
|
|
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
|
|
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
blocks.push(ContentBlock::image_url(url));
|
|
}
|
|
}
|
|
|
|
// If nothing, add empty text block
|
|
if blocks.is_empty() {
|
|
blocks.push(ContentBlock::text(""));
|
|
}
|
|
|
|
blocks
|
|
}
|
|
|
|
/// Encode an image file to base64 data URL
|
|
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
|
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
|
|
|
let mut file = std::fs::File::open(path)?;
|
|
let mut buffer = Vec::new();
|
|
file.read_to_end(&mut buffer)?;
|
|
|
|
let mime = mime_guess::from_path(path)
|
|
.first_or_octet_stream()
|
|
.to_string();
|
|
|
|
let encoded = STANDARD.encode(&buffer);
|
|
Ok((mime, encoded))
|
|
}
|
|
|
|
/// Convert ChatMessage to LLM Message format
|
|
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
|
let content = if m.media_refs.is_empty() {
|
|
vec![ContentBlock::text(&m.content)]
|
|
} else {
|
|
build_content_blocks(&m.content, &m.media_refs)
|
|
};
|
|
|
|
Message {
|
|
role: m.role.clone(),
|
|
content,
|
|
tool_call_id: m.tool_call_id.clone(),
|
|
name: m.tool_name.clone(),
|
|
}
|
|
}
|
|
|
|
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
|
/// History is managed externally by SessionManager.
|
|
pub struct AgentLoop {
|
|
provider: Box<dyn LLMProvider>,
|
|
tools: Arc<ToolRegistry>,
|
|
observer: Option<Arc<dyn Observer>>,
|
|
max_iterations: usize,
|
|
}
|
|
|
|
impl AgentLoop {
|
|
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
|
let max_iterations = provider_config.max_tool_iterations;
|
|
let provider = create_provider(provider_config)
|
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
|
|
Ok(Self {
|
|
provider,
|
|
tools: Arc::new(ToolRegistry::new()),
|
|
observer: None,
|
|
max_iterations,
|
|
})
|
|
}
|
|
|
|
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
|
let max_iterations = provider_config.max_tool_iterations;
|
|
let provider = create_provider(provider_config)
|
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
|
|
Ok(Self {
|
|
provider,
|
|
tools,
|
|
observer: None,
|
|
max_iterations,
|
|
})
|
|
}
|
|
|
|
/// Set an observer for tracking events.
|
|
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
|
self.observer = Some(observer);
|
|
self
|
|
}
|
|
|
|
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
|
&self.tools
|
|
}
|
|
|
|
/// Process a message using the provided conversation history.
|
|
/// History management is handled externally by SessionManager.
|
|
///
|
|
/// This method supports multi-round tool calling: after executing tools,
|
|
/// it loops back to the LLM with the tool results until either:
|
|
/// - The LLM returns no more tool calls (final response)
|
|
/// - Maximum iterations are reached
|
|
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
|
|
|
for iteration in 0..self.max_iterations {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(iteration, "Agent iteration started");
|
|
|
|
// Convert messages to LLM format
|
|
let messages_for_llm: Vec<Message> = messages
|
|
.iter()
|
|
.map(chat_message_to_llm_message)
|
|
.collect();
|
|
|
|
// Build request
|
|
let tools = if self.tools.has_tools() {
|
|
Some(self.tools.get_definitions())
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let request = ChatCompletionRequest {
|
|
messages: messages_for_llm,
|
|
temperature: None,
|
|
max_tokens: None,
|
|
tools,
|
|
};
|
|
|
|
// Call LLM
|
|
let response = (*self.provider).chat(request).await
|
|
.map_err(|e| {
|
|
tracing::error!(error = %e, "LLM request failed");
|
|
AgentError::LlmError(e.to_string())
|
|
})?;
|
|
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(
|
|
iteration,
|
|
response_len = response.content.len(),
|
|
tool_calls_len = response.tool_calls.len(),
|
|
"LLM response received"
|
|
);
|
|
|
|
// If no tool calls, this is the final response
|
|
if response.tool_calls.is_empty() {
|
|
let assistant_message = ChatMessage::assistant(response.content);
|
|
return Ok(assistant_message);
|
|
}
|
|
|
|
// Execute tool calls
|
|
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
|
|
|
// Add assistant message with tool calls
|
|
let assistant_message = ChatMessage::assistant(response.content.clone());
|
|
messages.push(assistant_message.clone());
|
|
|
|
// Execute tools and add results to messages
|
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
|
|
|
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
|
let tool_message = ChatMessage::tool(
|
|
tool_call.id.clone(),
|
|
tool_call.name.clone(),
|
|
result.output.clone(),
|
|
);
|
|
messages.push(tool_message);
|
|
}
|
|
|
|
// Loop continues to next iteration with updated messages
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
|
}
|
|
|
|
// Max iterations reached
|
|
let final_message = ChatMessage::assistant(
|
|
format!("I reached the maximum number of tool call iterations ({}) without completing the task. Please try breaking the task into smaller steps.", self.max_iterations)
|
|
);
|
|
Ok(final_message)
|
|
}
|
|
|
|
/// Determine whether to execute tools in parallel or sequentially.
|
|
///
|
|
/// Returns true if:
|
|
/// - There are multiple tool calls
|
|
/// - None of the tools require sequential execution (tool_search, non-concurrency-safe)
|
|
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
|
|
if tool_calls.len() <= 1 {
|
|
return false;
|
|
}
|
|
|
|
// tool_search must run sequentially to avoid MCP activation race conditions
|
|
if tool_calls.iter().any(|tc| tc.name == "tool_search") {
|
|
return false;
|
|
}
|
|
|
|
// All tools must be concurrency-safe to run in parallel
|
|
tool_calls.iter().all(|tc| {
|
|
self.tools
|
|
.get(&tc.name)
|
|
.map(|t| t.concurrency_safe())
|
|
.unwrap_or(false)
|
|
})
|
|
}
|
|
|
|
/// Execute multiple tool calls, choosing parallel or sequential based on conditions.
|
|
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
|
if self.should_execute_in_parallel(tool_calls) {
|
|
tracing::debug!("Executing {} tools in parallel", tool_calls.len());
|
|
self.execute_tools_parallel(tool_calls).await
|
|
} else {
|
|
tracing::debug!("Executing {} tools sequentially", tool_calls.len());
|
|
self.execute_tools_sequential(tool_calls).await
|
|
}
|
|
}
|
|
|
|
/// Execute tools in parallel using join_all.
|
|
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
|
let futures: Vec<_> = tool_calls
|
|
.iter()
|
|
.map(|tc| self.execute_one_tool(tc))
|
|
.collect();
|
|
|
|
futures_util::future::join_all(futures).await
|
|
}
|
|
|
|
/// Execute tools sequentially.
|
|
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
|
let mut outcomes = Vec::with_capacity(tool_calls.len());
|
|
|
|
for tool_call in tool_calls {
|
|
outcomes.push(self.execute_one_tool(tool_call).await);
|
|
}
|
|
|
|
outcomes
|
|
}
|
|
|
|
/// Execute a single tool and return the outcome with event tracking.
|
|
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
|
|
let start = Instant::now();
|
|
let tool_name = tool_call.name.clone();
|
|
|
|
// Record ToolCallStart event
|
|
if let Some(ref observer) = self.observer {
|
|
observer.record_event(&ObserverEvent::ToolCallStart {
|
|
tool: tool_name.clone(),
|
|
arguments: Some(truncate_args(&tool_call.arguments, 300)),
|
|
});
|
|
}
|
|
|
|
let result = self.execute_tool_internal(tool_call).await;
|
|
let duration = start.elapsed();
|
|
|
|
// Record ToolCall event
|
|
if let Some(ref observer) = self.observer {
|
|
observer.record_event(&ObserverEvent::ToolCall {
|
|
tool: tool_name.clone(),
|
|
duration,
|
|
success: result.success,
|
|
});
|
|
}
|
|
|
|
// Apply duration
|
|
ToolExecutionOutcome {
|
|
duration,
|
|
..result
|
|
}
|
|
}
|
|
|
|
/// Internal tool execution without event tracking.
|
|
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
|
|
let tool = match self.tools.get(&tool_call.name) {
|
|
Some(t) => t,
|
|
None => {
|
|
tracing::warn!(tool = %tool_call.name, "Tool not found");
|
|
return ToolExecutionOutcome::failure(
|
|
format!("Error: Tool '{}' not found", tool_call.name),
|
|
Some(format!("Tool '{}' not found", tool_call.name)),
|
|
);
|
|
}
|
|
};
|
|
|
|
match tool.execute(tool_call.arguments.clone()).await {
|
|
Ok(result) => {
|
|
if result.success {
|
|
ToolExecutionOutcome::success(result.output)
|
|
} else {
|
|
let error = result.error.unwrap_or_default();
|
|
ToolExecutionOutcome::failure(
|
|
format!("Error: {}", error),
|
|
Some(error),
|
|
)
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
|
ToolExecutionOutcome::failure(
|
|
format!("Error: {}", e),
|
|
Some(e.to_string()),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observability::{MultiObserver, Observer};

    /// Observer that appends every event to a shared buffer, so a test can
    /// still inspect the events after handing ownership of the observer away.
    struct TestObserver {
        events: std::sync::Arc<std::sync::Mutex<Vec<ObserverEvent>>>,
    }

    impl TestObserver {
        fn new() -> Self {
            Self {
                events: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
            }
        }

        /// A second handle to the recorded-events buffer.
        fn events(&self) -> std::sync::Arc<std::sync::Mutex<Vec<ObserverEvent>>> {
            std::sync::Arc::clone(&self.events)
        }
    }

    impl Observer for TestObserver {
        fn record_event(&self, event: &ObserverEvent) {
            self.events.lock().unwrap().push(event.clone());
        }

        fn name(&self) -> &str {
            "test_observer"
        }
    }

    #[test]
    fn test_observer_receives_tool_events() {
        let observer = TestObserver::new();
        let events = observer.events();

        let mut multi = MultiObserver::new();
        multi.add_observer(Box::new(observer));
        assert_eq!(multi.len(), 1);

        let event = ObserverEvent::ToolCallStart {
            tool: "test".to_string(),
            arguments: Some("{}".to_string()),
        };
        multi.record_event(&event);

        // The event must actually reach the registered observer.
        let recorded = events.lock().unwrap();
        assert_eq!(recorded.len(), 1, "MultiObserver should forward the event to its observer");
        assert!(matches!(
            &recorded[0],
            ObserverEvent::ToolCallStart { tool, .. } if tool == "test"
        ));
    }

    #[test]
    fn test_should_execute_in_parallel_single_tool() {
        // NOTE: this does not call `should_execute_in_parallel`. That needs an
        // `AgentLoop`, which in turn needs a real provider. It only pins the
        // precondition the method checks first: a batch of at most one call
        // is never run in parallel.
        let calls = [ToolCall {
            id: "1".to_string(),
            name: "test".to_string(),
            arguments: serde_json::json!({}),
        }];

        assert!(calls.len() <= 1);
    }
}
|
|
|
|
/// Errors produced while constructing or running an [`AgentLoop`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentError {
    /// The LLM provider could not be created from its configuration.
    ProviderCreation(String),
    /// A request to the LLM provider failed.
    LlmError(String),
    /// Any other failure. The message is displayed verbatim, without a prefix.
    Other(String),
}

impl std::fmt::Display for AgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
            Self::LlmError(e) => write!(f, "LLM error: {}", e),
            Self::Other(e) => f.write_str(e),
        }
    }
}

impl std::error::Error for AgentError {}
|