Compare commits
No commits in common. "394b5fdd6a27dd2b7d0dd5a08fbcc452faa46cc1" and "0c0d0c14436a230ca34c0cd9647c8fc14f62d7cc" have entirely different histories.
394b5fdd6a
...
0c0d0c1443
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,4 +1 @@
|
|||||||
/target
|
/target
|
||||||
reference/**
|
|
||||||
.env
|
|
||||||
*.env
|
|
||||||
@ -1,14 +1,10 @@
|
|||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::observability::{
|
|
||||||
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
|
|
||||||
};
|
|
||||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
/// Build content blocks from text and media paths
|
/// Build content blocks from text and media paths
|
||||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||||
@ -51,90 +47,46 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error
|
|||||||
Ok((mime, encoded))
|
Ok((mime, encoded))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert ChatMessage to LLM Message format
|
/// Stateless AgentLoop - history is managed externally by SessionManager
|
||||||
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
|
||||||
let content = if m.media_refs.is_empty() {
|
|
||||||
vec![ContentBlock::text(&m.content)]
|
|
||||||
} else {
|
|
||||||
build_content_blocks(&m.content, &m.media_refs)
|
|
||||||
};
|
|
||||||
|
|
||||||
Message {
|
|
||||||
role: m.role.clone(),
|
|
||||||
content,
|
|
||||||
tool_call_id: m.tool_call_id.clone(),
|
|
||||||
name: m.tool_name.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
|
||||||
/// History is managed externally by SessionManager.
|
|
||||||
pub struct AgentLoop {
|
pub struct AgentLoop {
|
||||||
provider: Box<dyn LLMProvider>,
|
provider: Box<dyn LLMProvider>,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
observer: Option<Arc<dyn Observer>>,
|
max_iterations: u32,
|
||||||
max_iterations: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
Self::with_tools(provider_config, Arc::new(ToolRegistry::new()))
|
||||||
let provider = create_provider(provider_config)
|
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
provider,
|
|
||||||
tools: Arc::new(ToolRegistry::new()),
|
|
||||||
observer: None,
|
|
||||||
max_iterations,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let provider = create_provider(provider_config.clone())
|
||||||
let provider = create_provider(provider_config)
|
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
provider,
|
provider,
|
||||||
tools,
|
tools,
|
||||||
observer: None,
|
max_iterations: provider_config.max_iterations,
|
||||||
max_iterations,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set an observer for tracking events.
|
|
||||||
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
|
||||||
self.observer = Some(observer);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
||||||
&self.tools
|
&self.tools
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process a message using the provided conversation history.
|
/// Process a message using the provided conversation history.
|
||||||
/// History management is handled externally by SessionManager.
|
/// History management is handled externally by SessionManager.
|
||||||
///
|
/// Returns (final_response, complete_message_history) where the history includes
|
||||||
/// This method supports multi-round tool calling: after executing tools,
|
/// all tool calls and results for proper session continuity.
|
||||||
/// it loops back to the LLM with the tool results until either:
|
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<(ChatMessage, Vec<ChatMessage>), AgentError> {
|
||||||
/// - The LLM returns no more tool calls (final response)
|
let mut messages = messages;
|
||||||
/// - Maximum iterations are reached
|
let mut final_content: String = String::new();
|
||||||
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
|
||||||
|
|
||||||
for iteration in 0..self.max_iterations {
|
for iteration in 0..self.max_iterations {
|
||||||
#[cfg(debug_assertions)]
|
tracing::debug!(iteration, history_len = messages.len(), "Starting iteration");
|
||||||
tracing::debug!(iteration, "Agent iteration started");
|
|
||||||
|
|
||||||
// Convert messages to LLM format
|
let messages_for_llm = self.build_messages_for_llm(&messages);
|
||||||
let messages_for_llm: Vec<Message> = messages
|
|
||||||
.iter()
|
|
||||||
.map(chat_message_to_llm_message)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Build request
|
|
||||||
let tools = if self.tools.has_tools() {
|
let tools = if self.tools.has_tools() {
|
||||||
Some(self.tools.get_definitions())
|
Some(self.tools.get_definitions())
|
||||||
} else {
|
} else {
|
||||||
@ -148,240 +100,155 @@ impl AgentLoop {
|
|||||||
tools,
|
tools,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Call LLM
|
|
||||||
let response = (*self.provider).chat(request).await
|
let response = (*self.provider).chat(request).await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
tracing::error!(error = %e, "LLM request failed");
|
tracing::error!(error = %e, "LLM request failed");
|
||||||
AgentError::LlmError(e.to_string())
|
AgentError::LlmError(e.to_string())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
iteration,
|
|
||||||
response_len = response.content.len(),
|
response_len = response.content.len(),
|
||||||
tool_calls_len = response.tool_calls.len(),
|
tool_calls_len = response.tool_calls.len(),
|
||||||
"LLM response received"
|
"LLM response received"
|
||||||
);
|
);
|
||||||
|
|
||||||
// If no tool calls, this is the final response
|
if !response.tool_calls.is_empty() {
|
||||||
if response.tool_calls.is_empty() {
|
tracing::info!(count = response.tool_calls.len(), iteration, tools = ?response.tool_calls.iter().map(|tc| &tc.name).collect::<Vec<_>>(), "Tool calls detected, executing tools");
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
|
||||||
return Ok(assistant_message);
|
let assistant_message = ChatMessage::assistant(response.content.clone());
|
||||||
|
messages.push(assistant_message);
|
||||||
|
|
||||||
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||||
|
|
||||||
|
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
||||||
|
let tool_message = ChatMessage::tool(
|
||||||
|
tool_call.id.clone(),
|
||||||
|
tool_call.name.clone(),
|
||||||
|
result.clone(),
|
||||||
|
);
|
||||||
|
messages.push(tool_message);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::debug!(iteration, "Tool execution completed, continuing to next iteration");
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute tool calls
|
tracing::debug!(iteration, "No tool calls in response, agent loop ending");
|
||||||
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
final_content = response.content;
|
||||||
|
break;
|
||||||
// Add assistant message with tool calls
|
|
||||||
let assistant_message = ChatMessage::assistant(response.content.clone());
|
|
||||||
messages.push(assistant_message.clone());
|
|
||||||
|
|
||||||
// Execute tools and add results to messages
|
|
||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
|
||||||
|
|
||||||
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
|
||||||
let tool_message = ChatMessage::tool(
|
|
||||||
tool_call.id.clone(),
|
|
||||||
tool_call.name.clone(),
|
|
||||||
result.output.clone(),
|
|
||||||
);
|
|
||||||
messages.push(tool_message);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop continues to next iteration with updated messages
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max iterations reached
|
if final_content.is_empty() {
|
||||||
let final_message = ChatMessage::assistant(
|
tracing::warn!(iterations = self.max_iterations, "Max iterations reached without final response");
|
||||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. Please try breaking the task into smaller steps.", self.max_iterations)
|
final_content = format!("Error: Max iterations ({}) reached without final response", self.max_iterations);
|
||||||
);
|
}
|
||||||
Ok(final_message)
|
|
||||||
|
let final_message = ChatMessage::assistant(final_content);
|
||||||
|
// Return both the final message and the complete history for session persistence
|
||||||
|
Ok((final_message, messages))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Determine whether to execute tools in parallel or sequentially.
|
fn build_messages_for_llm(&self, messages: &[ChatMessage]) -> Vec<Message> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
let content = if m.media_refs.is_empty() {
|
||||||
|
vec![ContentBlock::text(&m.content)]
|
||||||
|
} else {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(media_refs = ?m.media_refs, "Building content blocks with media");
|
||||||
|
build_content_blocks(&m.content, &m.media_refs)
|
||||||
|
};
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(role = %m.role, content_len = %m.content.len(), media_refs_len = %m.media_refs.len(), "ChatMessage converted to LLM Message");
|
||||||
|
Message {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content,
|
||||||
|
tool_call_id: m.tool_call_id.clone(),
|
||||||
|
name: m.tool_name.clone(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> {
|
||||||
|
let batches = self.partition_tool_batches(tool_calls);
|
||||||
|
let mut results = Vec::with_capacity(tool_calls.len());
|
||||||
|
|
||||||
|
for batch in batches {
|
||||||
|
if batch.len() == 1 {
|
||||||
|
// Single tool — run directly (no spawn overhead)
|
||||||
|
results.push(self.execute_tool(&batch[0]).await);
|
||||||
|
} else {
|
||||||
|
// Multiple tools — run in parallel via join_all
|
||||||
|
use futures_util::future::join_all;
|
||||||
|
let futures = batch.iter().map(|tc| self.execute_tool(tc));
|
||||||
|
let batch_results = join_all(futures).await;
|
||||||
|
results.extend(batch_results);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Partition tool calls into batches based on concurrency safety.
|
||||||
///
|
///
|
||||||
/// Returns true if:
|
/// `concurrency_safe` tools are grouped together; each `exclusive` tool
|
||||||
/// - There are multiple tool calls
|
/// runs in its own batch. This matches the approach used in Nanobot's
|
||||||
/// - None of the tools require sequential execution (tool_search, non-concurrency-safe)
|
/// `_partition_tool_batches` and Zeroclaw's `parallel_tools` config.
|
||||||
fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool {
|
fn partition_tool_batches(&self, tool_calls: &[ToolCall]) -> Vec<Vec<ToolCall>> {
|
||||||
if tool_calls.len() <= 1 {
|
let mut batches: Vec<Vec<ToolCall>> = Vec::new();
|
||||||
return false;
|
let mut current: Vec<ToolCall> = Vec::new();
|
||||||
}
|
|
||||||
|
|
||||||
// tool_search must run sequentially to avoid MCP activation race conditions
|
for tc in tool_calls {
|
||||||
if tool_calls.iter().any(|tc| tc.name == "tool_search") {
|
let concurrency_safe = self
|
||||||
return false;
|
.tools
|
||||||
}
|
|
||||||
|
|
||||||
// All tools must be concurrency-safe to run in parallel
|
|
||||||
tool_calls.iter().all(|tc| {
|
|
||||||
self.tools
|
|
||||||
.get(&tc.name)
|
.get(&tc.name)
|
||||||
.map(|t| t.concurrency_safe())
|
.map(|t| t.concurrency_safe())
|
||||||
.unwrap_or(false)
|
.unwrap_or(false);
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute multiple tool calls, choosing parallel or sequential based on conditions.
|
if concurrency_safe {
|
||||||
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
current.push(tc.clone());
|
||||||
if self.should_execute_in_parallel(tool_calls) {
|
} else {
|
||||||
tracing::debug!("Executing {} tools in parallel", tool_calls.len());
|
if !current.is_empty() {
|
||||||
self.execute_tools_parallel(tool_calls).await
|
batches.push(std::mem::take(&mut current));
|
||||||
} else {
|
}
|
||||||
tracing::debug!("Executing {} tools sequentially", tool_calls.len());
|
batches.push(vec![tc.clone()]);
|
||||||
self.execute_tools_sequential(tool_calls).await
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute tools in parallel using join_all.
|
|
||||||
async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
|
||||||
let futures: Vec<_> = tool_calls
|
|
||||||
.iter()
|
|
||||||
.map(|tc| self.execute_one_tool(tc))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
futures_util::future::join_all(futures).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute tools sequentially.
|
|
||||||
async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec<ToolExecutionOutcome> {
|
|
||||||
let mut outcomes = Vec::with_capacity(tool_calls.len());
|
|
||||||
|
|
||||||
for tool_call in tool_calls {
|
|
||||||
outcomes.push(self.execute_one_tool(tool_call).await);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
outcomes
|
if !current.is_empty() {
|
||||||
|
batches.push(current);
|
||||||
|
}
|
||||||
|
|
||||||
|
batches
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute a single tool and return the outcome with event tracking.
|
async fn execute_tool(&self, tool_call: &ToolCall) -> String {
|
||||||
async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
|
|
||||||
let start = Instant::now();
|
|
||||||
let tool_name = tool_call.name.clone();
|
|
||||||
|
|
||||||
// Record ToolCallStart event
|
|
||||||
if let Some(ref observer) = self.observer {
|
|
||||||
observer.record_event(&ObserverEvent::ToolCallStart {
|
|
||||||
tool: tool_name.clone(),
|
|
||||||
arguments: Some(truncate_args(&tool_call.arguments, 300)),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = self.execute_tool_internal(tool_call).await;
|
|
||||||
let duration = start.elapsed();
|
|
||||||
|
|
||||||
// Record ToolCall event
|
|
||||||
if let Some(ref observer) = self.observer {
|
|
||||||
observer.record_event(&ObserverEvent::ToolCall {
|
|
||||||
tool: tool_name.clone(),
|
|
||||||
duration,
|
|
||||||
success: result.success,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply duration
|
|
||||||
ToolExecutionOutcome {
|
|
||||||
duration,
|
|
||||||
..result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Internal tool execution without event tracking.
|
|
||||||
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
|
|
||||||
let tool = match self.tools.get(&tool_call.name) {
|
let tool = match self.tools.get(&tool_call.name) {
|
||||||
Some(t) => t,
|
Some(t) => t,
|
||||||
None => {
|
None => {
|
||||||
tracing::warn!(tool = %tool_call.name, "Tool not found");
|
tracing::warn!(tool = %tool_call.name, "Tool not found");
|
||||||
return ToolExecutionOutcome::failure(
|
return format!("Error: Tool '{}' not found", tool_call.name);
|
||||||
format!("Error: Tool '{}' not found", tool_call.name),
|
|
||||||
Some(format!("Tool '{}' not found", tool_call.name)),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match tool.execute(tool_call.arguments.clone()).await {
|
match tool.execute(tool_call.arguments.clone()).await {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
if result.success {
|
if result.success {
|
||||||
ToolExecutionOutcome::success(result.output)
|
result.output
|
||||||
} else {
|
} else {
|
||||||
let error = result.error.unwrap_or_default();
|
format!("Error: {}", result.error.unwrap_or_default())
|
||||||
ToolExecutionOutcome::failure(
|
|
||||||
format!("Error: {}", error),
|
|
||||||
Some(error),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
||||||
ToolExecutionOutcome::failure(
|
format!("Error: {}", e)
|
||||||
format!("Error: {}", e),
|
|
||||||
Some(e.to_string()),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::observability::{MultiObserver, Observer};
|
|
||||||
|
|
||||||
struct TestObserver {
|
|
||||||
events: std::sync::Mutex<Vec<ObserverEvent>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestObserver {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
events: std::sync::Mutex::new(Vec::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Observer for TestObserver {
|
|
||||||
fn record_event(&self, event: &ObserverEvent) {
|
|
||||||
self.events.lock().unwrap().push(event.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"test_observer"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_observer_receives_tool_events() {
|
|
||||||
// Verify MultiObserver works
|
|
||||||
let mut multi = MultiObserver::new();
|
|
||||||
multi.add_observer(Box::new(TestObserver::new()));
|
|
||||||
|
|
||||||
let event = ObserverEvent::ToolCallStart {
|
|
||||||
tool: "test".to_string(),
|
|
||||||
arguments: Some("{}".to_string()),
|
|
||||||
};
|
|
||||||
multi.record_event(&event);
|
|
||||||
|
|
||||||
// Just verify the structure works
|
|
||||||
assert_eq!(multi.len(), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_should_execute_in_parallel_single_tool() {
|
|
||||||
// Would need a proper setup with AgentLoop to test fully
|
|
||||||
// For now, just verify the logic: single tool should return false
|
|
||||||
let calls = vec![ToolCall {
|
|
||||||
id: "1".to_string(),
|
|
||||||
name: "test".to_string(),
|
|
||||||
arguments: serde_json::json!({}),
|
|
||||||
}];
|
|
||||||
|
|
||||||
// If there's only 1 tool, should return false regardless
|
|
||||||
assert_eq!(calls.len() <= 1, true);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum AgentError {
|
pub enum AgentError {
|
||||||
ProviderCreation(String),
|
ProviderCreation(String),
|
||||||
@ -400,3 +267,140 @@ impl std::fmt::Display for AgentError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for AgentError {}
|
impl std::error::Error for AgentError {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::providers::ToolCall;
|
||||||
|
use crate::tools::ToolRegistry;
|
||||||
|
use crate::tools::CalculatorTool;
|
||||||
|
use crate::tools::BashTool;
|
||||||
|
use crate::tools::FileReadTool;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
fn make_tc(name: &str, args: serde_json::Value) -> ToolCall {
|
||||||
|
ToolCall {
|
||||||
|
id: format!("tc_{}", name),
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments: args,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verify that partition_tool_batches groups concurrency-safe tools together
|
||||||
|
/// and isolates exclusive tools, matching the nanobot/zeroclaw approach.
|
||||||
|
#[test]
|
||||||
|
fn test_partition_batches_mixes_safe_and_exclusive() {
|
||||||
|
let registry = Arc::new({
|
||||||
|
let mut r = ToolRegistry::new();
|
||||||
|
r.register(CalculatorTool::new()); // concurrency_safe = true
|
||||||
|
r.register(BashTool::new()); // concurrency_safe = false (exclusive)
|
||||||
|
r.register(FileReadTool::new()); // concurrency_safe = true
|
||||||
|
r
|
||||||
|
});
|
||||||
|
|
||||||
|
// agent_loop needs a provider to construct; test the partitioning logic directly
|
||||||
|
let tcs = vec![
|
||||||
|
make_tc("calculator", json!({})),
|
||||||
|
make_tc("bash", json!({"command": "ls"})),
|
||||||
|
make_tc("file_read", json!({"path": "/tmp/foo"})),
|
||||||
|
make_tc("calculator", json!({})),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Expected:
|
||||||
|
// batch 1: calculator (safe, first run)
|
||||||
|
// batch 2: bash (exclusive, alone)
|
||||||
|
// batch 3: file_read, calculator (both safe, run together)
|
||||||
|
let batches = partition_for_test(&registry, &tcs);
|
||||||
|
assert_eq!(batches.len(), 3);
|
||||||
|
assert_eq!(batches[0].len(), 1);
|
||||||
|
assert_eq!(batches[0][0].name, "calculator");
|
||||||
|
assert_eq!(batches[1].len(), 1);
|
||||||
|
assert_eq!(batches[1][0].name, "bash");
|
||||||
|
assert_eq!(batches[2].len(), 2);
|
||||||
|
assert_eq!(batches[2][0].name, "file_read");
|
||||||
|
assert_eq!(batches[2][1].name, "calculator");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// All-safe tool calls should produce a single batch (parallel execution).
|
||||||
|
#[test]
|
||||||
|
fn test_partition_batches_all_safe_single_batch() {
|
||||||
|
let registry = Arc::new({
|
||||||
|
let mut r = ToolRegistry::new();
|
||||||
|
r.register(CalculatorTool::new());
|
||||||
|
r.register(FileReadTool::new());
|
||||||
|
r
|
||||||
|
});
|
||||||
|
|
||||||
|
let tcs = vec![
|
||||||
|
make_tc("calculator", json!({})),
|
||||||
|
make_tc("file_read", json!({"path": "/tmp/foo"})),
|
||||||
|
];
|
||||||
|
|
||||||
|
let batches = partition_for_test(&registry, &tcs);
|
||||||
|
assert_eq!(batches.len(), 1);
|
||||||
|
assert_eq!(batches[0].len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// All-exclusive tool calls should each get their own batch (sequential execution).
|
||||||
|
#[test]
|
||||||
|
fn test_partition_batches_all_exclusive_separate_batches() {
|
||||||
|
let registry = Arc::new({
|
||||||
|
let mut r = ToolRegistry::new();
|
||||||
|
r.register(BashTool::new());
|
||||||
|
r
|
||||||
|
});
|
||||||
|
|
||||||
|
let tcs = vec![
|
||||||
|
make_tc("bash", json!({"command": "ls"})),
|
||||||
|
make_tc("bash", json!({"command": "pwd"})),
|
||||||
|
];
|
||||||
|
|
||||||
|
let batches = partition_for_test(&registry, &tcs);
|
||||||
|
assert_eq!(batches.len(), 2);
|
||||||
|
assert_eq!(batches[0].len(), 1);
|
||||||
|
assert_eq!(batches[1].len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unknown tools (not in registry) default to non-concurrency-safe (single batch).
|
||||||
|
#[test]
|
||||||
|
fn test_partition_batches_unknown_tool_gets_own_batch() {
|
||||||
|
let registry = Arc::new(ToolRegistry::new());
|
||||||
|
|
||||||
|
let tcs = vec![
|
||||||
|
make_tc("calculator", json!({})),
|
||||||
|
make_tc("unknown_tool", json!({})),
|
||||||
|
];
|
||||||
|
|
||||||
|
let batches = partition_for_test(&registry, &tcs);
|
||||||
|
assert_eq!(batches.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expose partition logic for testing without needing a full AgentLoop.
|
||||||
|
fn partition_for_test(registry: &Arc<ToolRegistry>, tool_calls: &[ToolCall]) -> Vec<Vec<ToolCall>> {
|
||||||
|
let mut batches: Vec<Vec<ToolCall>> = Vec::new();
|
||||||
|
let mut current: Vec<ToolCall> = Vec::new();
|
||||||
|
|
||||||
|
for tc in tool_calls {
|
||||||
|
let concurrency_safe = registry
|
||||||
|
.get(&tc.name)
|
||||||
|
.map(|t| t.concurrency_safe())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if concurrency_safe {
|
||||||
|
current.push(tc.clone());
|
||||||
|
} else {
|
||||||
|
if !current.is_empty() {
|
||||||
|
batches.push(std::mem::take(&mut current));
|
||||||
|
}
|
||||||
|
batches.push(vec![tc.clone()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.is_empty() {
|
||||||
|
batches.push(current);
|
||||||
|
}
|
||||||
|
|
||||||
|
batches
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -73,12 +73,12 @@ pub struct ModelConfig {
|
|||||||
pub struct AgentConfig {
|
pub struct AgentConfig {
|
||||||
pub provider: String,
|
pub provider: String,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
#[serde(default = "default_max_tool_iterations")]
|
#[serde(default = "default_max_iterations")]
|
||||||
pub max_tool_iterations: usize,
|
pub max_iterations: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_max_tool_iterations() -> usize {
|
fn default_max_iterations() -> u32 {
|
||||||
20
|
15
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@ -138,7 +138,7 @@ pub struct LLMProviderConfig {
|
|||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
pub model_extra: HashMap<String, serde_json::Value>,
|
pub model_extra: HashMap<String, serde_json::Value>,
|
||||||
pub max_tool_iterations: usize,
|
pub max_iterations: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_default_config_path() -> PathBuf {
|
fn get_default_config_path() -> PathBuf {
|
||||||
@ -198,7 +198,7 @@ impl Config {
|
|||||||
temperature: model.temperature,
|
temperature: model.temperature,
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
model_extra: model.extra.clone(),
|
model_extra: model.extra.clone(),
|
||||||
max_tool_iterations: agent.max_tool_iterations,
|
max_iterations: agent.max_iterations,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,5 +8,4 @@ pub mod client;
|
|||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
pub mod channels;
|
pub mod channels;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod observability;
|
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
|||||||
@ -1,257 +0,0 @@
|
|||||||
//! Observability module for tracking agent and tool events.
|
|
||||||
//!
|
|
||||||
//! This module provides an Observer pattern for emitting and collecting
|
|
||||||
//! telemetry events during agent execution.
|
|
||||||
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
/// Events emitted during agent and tool execution.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum ObserverEvent {
|
|
||||||
/// Emitted before a tool starts executing.
|
|
||||||
ToolCallStart {
|
|
||||||
tool: String,
|
|
||||||
arguments: Option<String>,
|
|
||||||
},
|
|
||||||
/// Emitted after a tool completes execution.
|
|
||||||
ToolCall {
|
|
||||||
tool: String,
|
|
||||||
duration: Duration,
|
|
||||||
success: bool,
|
|
||||||
},
|
|
||||||
/// Emitted when the agent starts processing.
|
|
||||||
AgentStart {
|
|
||||||
provider: String,
|
|
||||||
model: String,
|
|
||||||
},
|
|
||||||
/// Emitted when the agent finishes processing.
|
|
||||||
AgentEnd {
|
|
||||||
provider: String,
|
|
||||||
model: String,
|
|
||||||
duration: Duration,
|
|
||||||
tokens_used: Option<u64>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Observer trait for receiving events.
|
|
||||||
///
|
|
||||||
/// Implement this trait to receive events during agent execution.
|
|
||||||
/// Observers are shared across async tasks, so implementations must be
|
|
||||||
/// Send + Sync.
|
|
||||||
pub trait Observer: Send + Sync + 'static {
|
|
||||||
/// Record a single event.
|
|
||||||
fn record_event(&self, event: &ObserverEvent);
|
|
||||||
|
|
||||||
/// Get the observer's name for identification.
|
|
||||||
fn name(&self) -> &str;
|
|
||||||
|
|
||||||
/// Flush any buffered events (default no-op).
|
|
||||||
fn flush(&self) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Outcome of a single tool execution.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ToolExecutionOutcome {
|
|
||||||
/// The output from the tool execution.
|
|
||||||
pub output: String,
|
|
||||||
/// Whether the tool executed successfully.
|
|
||||||
pub success: bool,
|
|
||||||
/// The error reason if the tool failed.
|
|
||||||
pub error_reason: Option<String>,
|
|
||||||
/// How long the tool took to execute.
|
|
||||||
pub duration: Duration,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ToolExecutionOutcome {
|
|
||||||
/// Create a successful outcome with zero duration.
|
|
||||||
pub fn success(output: String) -> Self {
|
|
||||||
Self {
|
|
||||||
output,
|
|
||||||
success: true,
|
|
||||||
error_reason: None,
|
|
||||||
duration: Duration::ZERO,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a successful outcome with duration.
|
|
||||||
pub fn success_with_duration(output: String, duration: Duration) -> Self {
|
|
||||||
Self {
|
|
||||||
output,
|
|
||||||
success: true,
|
|
||||||
error_reason: None,
|
|
||||||
duration,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a failed outcome with zero duration.
|
|
||||||
pub fn failure(output: String, error_reason: Option<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
output,
|
|
||||||
success: false,
|
|
||||||
error_reason,
|
|
||||||
duration: Duration::ZERO,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a failed outcome with duration.
|
|
||||||
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
|
|
||||||
Self {
|
|
||||||
output,
|
|
||||||
success: false,
|
|
||||||
error_reason,
|
|
||||||
duration,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// MultiObserver broadcasts events to multiple observers.
|
|
||||||
pub struct MultiObserver {
|
|
||||||
observers: Vec<Box<dyn Observer>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MultiObserver {
|
|
||||||
/// Create a new MultiObserver.
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
observers: Vec::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add an observer.
|
|
||||||
pub fn add_observer(&mut self, observer: Box<dyn Observer>) {
|
|
||||||
self.observers.push(observer);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the number of registered observers.
|
|
||||||
pub fn len(&self) -> usize {
|
|
||||||
self.observers.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if there are no observers.
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.observers.is_empty()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for MultiObserver {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Observer for MultiObserver {
|
|
||||||
fn record_event(&self, event: &ObserverEvent) {
|
|
||||||
for observer in &self.observers {
|
|
||||||
observer.record_event(event);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&self) {
|
|
||||||
for observer in &self.observers {
|
|
||||||
observer.flush();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"multi_observer"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Truncate arguments for logging to avoid oversized events.
|
|
||||||
pub fn truncate_args(args: &serde_json::Value, max_len: usize) -> String {
|
|
||||||
let args_str = args.to_string();
|
|
||||||
if args_str.len() <= max_len {
|
|
||||||
return args_str;
|
|
||||||
}
|
|
||||||
format!("{}...truncated", &args_str[..max_len])
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared log of the events a `TestObserver` has recorded.
    type EventLog = Arc<Mutex<Vec<ObserverEvent>>>;

    /// Observer that appends a clone of every event it receives to a shared log.
    struct TestObserver {
        name: String,
        events: EventLog,
    }

    impl TestObserver {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                events: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Handle to the event log that stays usable after the observer has
        /// been boxed and moved into a `MultiObserver`.
        fn events(&self) -> EventLog {
            Arc::clone(&self.events)
        }
    }

    impl Observer for TestObserver {
        fn record_event(&self, event: &ObserverEvent) {
            self.events.lock().unwrap().push(event.clone());
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_tool_execution_outcome_success() {
        let outcome = ToolExecutionOutcome::success("output content".to_string());
        assert!(outcome.success);
        assert_eq!(outcome.output, "output content");
        assert!(outcome.error_reason.is_none());
        assert_eq!(outcome.duration, Duration::ZERO);
    }

    #[test]
    fn test_tool_execution_outcome_success_with_duration() {
        let outcome = ToolExecutionOutcome::success_with_duration(
            "output content".to_string(),
            Duration::from_millis(100),
        );
        assert!(outcome.success);
        assert_eq!(outcome.duration, Duration::from_millis(100));
    }

    #[test]
    fn test_tool_execution_outcome_failure() {
        let outcome = ToolExecutionOutcome::failure(
            "error output".to_string(),
            Some("error reason".to_string()),
        );
        assert!(!outcome.success);
        assert_eq!(outcome.output, "error output");
        assert_eq!(outcome.error_reason, Some("error reason".to_string()));
        assert_eq!(outcome.duration, Duration::ZERO);
    }

    #[test]
    fn test_tool_execution_outcome_failure_with_duration() {
        let outcome = ToolExecutionOutcome::failure_with_duration(
            "error output".to_string(),
            Some("error reason".to_string()),
            Duration::from_millis(250),
        );
        assert!(!outcome.success);
        assert_eq!(outcome.output, "error output");
        assert_eq!(outcome.error_reason, Some("error reason".to_string()));
        assert_eq!(outcome.duration, Duration::from_millis(250));
    }

    #[test]
    fn test_multi_observer_broadcasts() {
        let mut multi = MultiObserver::new();
        let obs1 = TestObserver::new("obs1");
        let obs2 = TestObserver::new("obs2");
        // Keep handles to the logs before ownership moves into `multi`.
        let logs = [obs1.events(), obs2.events()];
        multi.add_observer(Box::new(obs1));
        multi.add_observer(Box::new(obs2));

        let event = ObserverEvent::ToolCallStart {
            tool: "test_tool".to_string(),
            arguments: Some("{}".to_string()),
        };

        multi.record_event(&event);

        assert_eq!(multi.len(), 2);

        // Both observers should have received exactly the broadcast event.
        for log in &logs {
            let seen = log.lock().unwrap();
            assert_eq!(seen.len(), 1);
            assert!(matches!(
                &seen[0],
                ObserverEvent::ToolCallStart { tool, .. } if tool.as_str() == "test_tool"
            ));
        }
    }

    #[test]
    fn test_truncate_args() {
        let args = serde_json::json!({"key": "value"});
        assert_eq!(truncate_args(&args, 100), args.to_string());

        let long_args = serde_json::json!({"key": "a".repeat(200)});
        let truncated = truncate_args(&long_args, 50);
        assert!(truncated.ends_with("...truncated"));
        assert!(truncated.len() < long_args.to_string().len());
    }
}
|
|
||||||
@ -92,6 +92,10 @@ impl Tool for CalculatorTool {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
    /// Always returns `true` for this tool.
    ///
    /// NOTE(review): presumably this tells the caller the tool has no side
    /// effects — confirm against the `Tool` trait definition.
    fn read_only(&self) -> bool {
|
||||||
|
        true
|
||||||
|
    }
|
||||||
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
let function = match args.get("function").and_then(|v| v.as_str()) {
|
let function = match args.get("function").and_then(|v| v.as_str()) {
|
||||||
Some(f) => f,
|
Some(f) => f,
|
||||||
|
|||||||
@ -23,12 +23,6 @@ impl ToolRegistry {
|
|||||||
self.tools.get(name)
|
self.tools.get(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get all registered tools.
|
|
||||||
/// Used for concurrent tool execution when we need to look up tools by name.
|
|
||||||
pub fn get_all(&self) -> Vec<&Box<dyn ToolTrait>> {
|
|
||||||
self.tools.values().collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_definitions(&self) -> Vec<Tool> {
|
pub fn get_definitions(&self) -> Vec<Tool> {
|
||||||
self.tools
|
self.tools
|
||||||
.values()
|
.values()
|
||||||
|
|||||||
@ -23,7 +23,6 @@ fn load_config() -> Option<LLMProviderConfig> {
|
|||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 20,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,6 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
|||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 20,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user