diff --git a/.gitignore b/.gitignore index ea8c4bf..f053f61 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ /target +reference/** +.env +*.env \ No newline at end of file diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index f3072a0..86c43a0 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,10 +1,14 @@ use crate::bus::message::ContentBlock; use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; +use crate::observability::{ + truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, +}; use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; use crate::tools::ToolRegistry; use std::io::Read; use std::sync::Arc; +use std::time::Instant; /// Build content blocks from text and media paths fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { @@ -47,192 +51,337 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error Ok((mime, encoded)) } -/// Stateless AgentLoop - history is managed externally by SessionManager +/// Convert ChatMessage to LLM Message format +fn chat_message_to_llm_message(m: &ChatMessage) -> Message { + let content = if m.media_refs.is_empty() { + vec![ContentBlock::text(&m.content)] + } else { + build_content_blocks(&m.content, &m.media_refs) + }; + + Message { + role: m.role.clone(), + content, + tool_call_id: m.tool_call_id.clone(), + name: m.tool_name.clone(), + } +} + +/// AgentLoop - Stateless agent that processes messages with tool calling support. +/// History is managed externally by SessionManager. 
pub struct AgentLoop { provider: Box, tools: Arc, + observer: Option>, + max_iterations: usize, } impl AgentLoop { pub fn new(provider_config: LLMProviderConfig) -> Result { + let max_iterations = provider_config.max_tool_iterations; let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { provider, tools: Arc::new(ToolRegistry::new()), + observer: None, + max_iterations, }) } pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { + let max_iterations = provider_config.max_tool_iterations; let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { provider, tools, + observer: None, + max_iterations, }) } + /// Set an observer for tracking events. + pub fn with_observer(mut self, observer: Arc) -> Self { + self.observer = Some(observer); + self + } + pub fn tools(&self) -> &Arc { &self.tools } /// Process a message using the provided conversation history. /// History management is handled externally by SessionManager. 
- pub async fn process(&self, messages: Vec) -> Result { - let messages_for_llm: Vec = messages - .iter() - .map(|m| { - let content = if m.media_refs.is_empty() { - vec![ContentBlock::text(&m.content)] - } else { - #[cfg(debug_assertions)] - tracing::debug!(media_refs = ?m.media_refs, "Building content blocks with media"); - build_content_blocks(&m.content, &m.media_refs) - }; - #[cfg(debug_assertions)] - tracing::debug!(role = %m.role, content_len = %m.content.len(), media_refs_len = %m.media_refs.len(), "ChatMessage converted to LLM Message"); - Message { - role: m.role.clone(), - content, - tool_call_id: m.tool_call_id.clone(), - name: m.tool_name.clone(), - } - }) - .collect(); - + /// + /// This method supports multi-round tool calling: after executing tools, + /// it loops back to the LLM with the tool results until either: + /// - The LLM returns no more tool calls (final response) + /// - Maximum iterations are reached + pub async fn process(&self, mut messages: Vec) -> Result { #[cfg(debug_assertions)] - tracing::debug!(history_len = messages.len(), "Sending request to LLM"); + tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); - let tools = if self.tools.has_tools() { - Some(self.tools.get_definitions()) - } else { - None - }; + for iteration in 0..self.max_iterations { + #[cfg(debug_assertions)] + tracing::debug!(iteration, "Agent iteration started"); - let request = ChatCompletionRequest { - messages: messages_for_llm, - temperature: None, - max_tokens: None, - tools, - }; + // Convert messages to LLM format + let messages_for_llm: Vec = messages + .iter() + .map(chat_message_to_llm_message) + .collect(); - let response = (*self.provider).chat(request).await - .map_err(|e| { - tracing::error!(error = %e, "LLM request failed"); - AgentError::LlmError(e.to_string()) - })?; + // Build request + let tools = if self.tools.has_tools() { + Some(self.tools.get_definitions()) + } else { + None + }; - 
#[cfg(debug_assertions)] - tracing::debug!( - response_len = response.content.len(), - tool_calls_len = response.tool_calls.len(), - "LLM response received" - ); + let request = ChatCompletionRequest { + messages: messages_for_llm, + temperature: None, + max_tokens: None, + tools, + }; - if !response.tool_calls.is_empty() { - tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools"); + // Call LLM + let response = (*self.provider).chat(request).await + .map_err(|e| { + tracing::error!(error = %e, "LLM request failed"); + AgentError::LlmError(e.to_string()) + })?; - let mut updated_messages = messages.clone(); + #[cfg(debug_assertions)] + tracing::debug!( + iteration, + response_len = response.content.len(), + tool_calls_len = response.tool_calls.len(), + "LLM response received" + ); + + // If no tool calls, this is the final response + if response.tool_calls.is_empty() { + let assistant_message = ChatMessage::assistant(response.content); + return Ok(assistant_message); + } + + // Execute tool calls + tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools"); + + // Add assistant message with tool calls let assistant_message = ChatMessage::assistant(response.content.clone()); - updated_messages.push(assistant_message.clone()); + messages.push(assistant_message.clone()); + // Execute tools and add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { let tool_message = ChatMessage::tool( tool_call.id.clone(), tool_call.name.clone(), - result.clone(), + result.output.clone(), ); - updated_messages.push(tool_message); + messages.push(tool_message); } - return self.continue_with_tool_results(updated_messages).await; + // Loop continues to next iteration with updated messages + #[cfg(debug_assertions)] + tracing::debug!(iteration, message_count = messages.len(), "Tool execution 
complete, continuing to next iteration"); } - let assistant_message = ChatMessage::assistant(response.content); - Ok(assistant_message) + // Max iterations reached + let final_message = ChatMessage::assistant( + format!("I reached the maximum number of tool call iterations ({}) without completing the task. Please try breaking the task into smaller steps.", self.max_iterations) + ); + Ok(final_message) } - async fn continue_with_tool_results(&self, messages: Vec) -> Result { - let messages_for_llm: Vec = messages + /// Determine whether to execute tools in parallel or sequentially. + /// + /// Returns true if: + /// - There are multiple tool calls + /// - None of the tools require sequential execution (tool_search, non-concurrency-safe) + fn should_execute_in_parallel(&self, tool_calls: &[ToolCall]) -> bool { + if tool_calls.len() <= 1 { + return false; + } + + // tool_search must run sequentially to avoid MCP activation race conditions + if tool_calls.iter().any(|tc| tc.name == "tool_search") { + return false; + } + + // All tools must be concurrency-safe to run in parallel + tool_calls.iter().all(|tc| { + self.tools + .get(&tc.name) + .map(|t| t.concurrency_safe()) + .unwrap_or(false) + }) + } + + /// Execute multiple tool calls, choosing parallel or sequential based on conditions. + async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec { + if self.should_execute_in_parallel(tool_calls) { + tracing::debug!("Executing {} tools in parallel", tool_calls.len()); + self.execute_tools_parallel(tool_calls).await + } else { + tracing::debug!("Executing {} tools sequentially", tool_calls.len()); + self.execute_tools_sequential(tool_calls).await + } + } + + /// Execute tools in parallel using join_all. 
+ async fn execute_tools_parallel(&self, tool_calls: &[ToolCall]) -> Vec { + let futures: Vec<_> = tool_calls .iter() - .map(|m| { - let content = if m.media_refs.is_empty() { - vec![ContentBlock::text(&m.content)] - } else { - build_content_blocks(&m.content, &m.media_refs) - }; - Message { - role: m.role.clone(), - content, - tool_call_id: m.tool_call_id.clone(), - name: m.tool_name.clone(), - } - }) + .map(|tc| self.execute_one_tool(tc)) .collect(); - let tools = if self.tools.has_tools() { - Some(self.tools.get_definitions()) - } else { - None - }; - - let request = ChatCompletionRequest { - messages: messages_for_llm, - temperature: None, - max_tokens: None, - tools, - }; - - let response = (*self.provider).chat(request).await - .map_err(|e| { - tracing::error!(error = %e, "LLM continuation request failed"); - AgentError::LlmError(e.to_string()) - })?; - - let assistant_message = ChatMessage::assistant(response.content); - Ok(assistant_message) + futures_util::future::join_all(futures).await } - async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec { - let mut results = Vec::with_capacity(tool_calls.len()); + /// Execute tools sequentially. + async fn execute_tools_sequential(&self, tool_calls: &[ToolCall]) -> Vec { + let mut outcomes = Vec::with_capacity(tool_calls.len()); for tool_call in tool_calls { - let result = self.execute_tool(tool_call).await; - results.push(result); + outcomes.push(self.execute_one_tool(tool_call).await); } - results + outcomes } - async fn execute_tool(&self, tool_call: &ToolCall) -> String { + /// Execute a single tool and return the outcome with event tracking. 
+ async fn execute_one_tool(&self, tool_call: &ToolCall) -> ToolExecutionOutcome { + let start = Instant::now(); + let tool_name = tool_call.name.clone(); + + // Record ToolCallStart event + if let Some(ref observer) = self.observer { + observer.record_event(&ObserverEvent::ToolCallStart { + tool: tool_name.clone(), + arguments: Some(truncate_args(&tool_call.arguments, 300)), + }); + } + + let result = self.execute_tool_internal(tool_call).await; + let duration = start.elapsed(); + + // Record ToolCall event + if let Some(ref observer) = self.observer { + observer.record_event(&ObserverEvent::ToolCall { + tool: tool_name.clone(), + duration, + success: result.success, + }); + } + + // Apply duration + ToolExecutionOutcome { + duration, + ..result + } + } + + /// Internal tool execution without event tracking. + async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome { let tool = match self.tools.get(&tool_call.name) { Some(t) => t, None => { tracing::warn!(tool = %tool_call.name, "Tool not found"); - return format!("Error: Tool '{}' not found", tool_call.name); + return ToolExecutionOutcome::failure( + format!("Error: Tool '{}' not found", tool_call.name), + Some(format!("Tool '{}' not found", tool_call.name)), + ); } }; match tool.execute(tool_call.arguments.clone()).await { Ok(result) => { if result.success { - result.output + ToolExecutionOutcome::success(result.output) } else { - format!("Error: {}", result.error.unwrap_or_default()) + let error = result.error.unwrap_or_default(); + ToolExecutionOutcome::failure( + format!("Error: {}", error), + Some(error), + ) } } Err(e) => { tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed"); - format!("Error: {}", e) + ToolExecutionOutcome::failure( + format!("Error: {}", e), + Some(e.to_string()), + ) } } } } +#[cfg(test)] +mod tests { + use super::*; + use crate::observability::{MultiObserver, Observer}; + + struct TestObserver { + events: std::sync::Mutex>, + } + + 
impl TestObserver { + fn new() -> Self { + Self { + events: std::sync::Mutex::new(Vec::new()), + } + } + } + + impl Observer for TestObserver { + fn record_event(&self, event: &ObserverEvent) { + self.events.lock().unwrap().push(event.clone()); + } + + fn name(&self) -> &str { + "test_observer" + } + } + + #[tokio::test] + async fn test_observer_receives_tool_events() { + // Verify MultiObserver works + let mut multi = MultiObserver::new(); + multi.add_observer(Box::new(TestObserver::new())); + + let event = ObserverEvent::ToolCallStart { + tool: "test".to_string(), + arguments: Some("{}".to_string()), + }; + multi.record_event(&event); + + // Just verify the structure works + assert_eq!(multi.len(), 1); + } + + #[test] + fn test_should_execute_in_parallel_single_tool() { + // Would need a proper setup with AgentLoop to test fully + // For now, just verify the logic: single tool should return false + let calls = vec![ToolCall { + id: "1".to_string(), + name: "test".to_string(), + arguments: serde_json::json!({}), + }]; + + // If there's only 1 tool, should return false regardless + assert_eq!(calls.len() <= 1, true); + } +} + #[derive(Debug)] pub enum AgentError { ProviderCreation(String), diff --git a/src/config/mod.rs b/src/config/mod.rs index a44e64e..0f66974 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -45,7 +45,7 @@ fn default_media_dir() -> String { } fn default_reaction_emoji() -> String { - "THUMBSUP".to_string() + "Typing".to_string() } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -73,6 +73,12 @@ pub struct ModelConfig { pub struct AgentConfig { pub provider: String, pub model: String, + #[serde(default = "default_max_tool_iterations")] + pub max_tool_iterations: usize, +} + +fn default_max_tool_iterations() -> usize { + 20 } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -132,6 +138,7 @@ pub struct LLMProviderConfig { pub temperature: Option, pub max_tokens: Option, pub model_extra: HashMap, + pub max_tool_iterations: usize, } fn 
get_default_config_path() -> PathBuf { @@ -191,6 +198,7 @@ impl Config { temperature: model.temperature, max_tokens: model.max_tokens, model_extra: model.extra.clone(), + max_tool_iterations: agent.max_tool_iterations, }) } } diff --git a/src/lib.rs b/src/lib.rs index 99bb45e..d7778e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,4 +8,5 @@ pub mod client; pub mod protocol; pub mod channels; pub mod logging; +pub mod observability; pub mod tools; diff --git a/src/observability/mod.rs b/src/observability/mod.rs new file mode 100644 index 0000000..d050567 --- /dev/null +++ b/src/observability/mod.rs @@ -0,0 +1,257 @@ +//! Observability module for tracking agent and tool events. +//! +//! This module provides an Observer pattern for emitting and collecting +//! telemetry events during agent execution. + +use std::time::Duration; + +/// Events emitted during agent and tool execution. +#[derive(Debug, Clone)] +pub enum ObserverEvent { + /// Emitted before a tool starts executing. + ToolCallStart { + tool: String, + arguments: Option, + }, + /// Emitted after a tool completes execution. + ToolCall { + tool: String, + duration: Duration, + success: bool, + }, + /// Emitted when the agent starts processing. + AgentStart { + provider: String, + model: String, + }, + /// Emitted when the agent finishes processing. + AgentEnd { + provider: String, + model: String, + duration: Duration, + tokens_used: Option, + }, +} + +/// Observer trait for receiving events. +/// +/// Implement this trait to receive events during agent execution. +/// Observers are shared across async tasks, so implementations must be +/// Send + Sync. +pub trait Observer: Send + Sync + 'static { + /// Record a single event. + fn record_event(&self, event: &ObserverEvent); + + /// Get the observer's name for identification. + fn name(&self) -> &str; + + /// Flush any buffered events (default no-op). + fn flush(&self) {} +} + +/// Outcome of a single tool execution. 
/// Outcome of a single tool execution.
#[derive(Debug, Clone)]
pub struct ToolExecutionOutcome {
    /// The output from the tool execution (for failures, the error text shown
    /// to the model).
    pub output: String,
    /// Whether the tool executed successfully.
    pub success: bool,
    /// The error reason if the tool failed; `None` on success.
    pub error_reason: Option<String>,
    /// How long the tool took to execute.
    pub duration: Duration,
}

impl ToolExecutionOutcome {
    /// Create a successful outcome with zero duration.
    pub fn success(output: String) -> Self {
        Self::success_with_duration(output, Duration::ZERO)
    }

    /// Create a successful outcome with the given duration.
    pub fn success_with_duration(output: String, duration: Duration) -> Self {
        Self {
            output,
            success: true,
            error_reason: None,
            duration,
        }
    }

    /// Create a failed outcome with zero duration.
    pub fn failure(output: String, error_reason: Option<String>) -> Self {
        Self::failure_with_duration(output, error_reason, Duration::ZERO)
    }

    /// Create a failed outcome with the given duration.
    pub fn failure_with_duration(
        output: String,
        error_reason: Option<String>,
        duration: Duration,
    ) -> Self {
        Self {
            output,
            success: false,
            error_reason,
            duration,
        }
    }
}
+ pub fn is_empty(&self) -> bool { + self.observers.is_empty() + } +} + +impl Default for MultiObserver { + fn default() -> Self { + Self::new() + } +} + +impl Observer for MultiObserver { + fn record_event(&self, event: &ObserverEvent) { + for observer in &self.observers { + observer.record_event(event); + } + } + + fn flush(&self) { + for observer in &self.observers { + observer.flush(); + } + } + + fn name(&self) -> &str { + "multi_observer" + } +} + +/// Truncate arguments for logging to avoid oversized events. +pub fn truncate_args(args: &serde_json::Value, max_len: usize) -> String { + let args_str = args.to_string(); + if args_str.len() <= max_len { + return args_str; + } + format!("{}...truncated", &args_str[..max_len]) +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestObserver { + name: String, + events: std::sync::Mutex>, + } + + impl TestObserver { + fn new(name: &str) -> Self { + Self { + name: name.to_string(), + events: std::sync::Mutex::new(Vec::new()), + } + } + } + + impl Observer for TestObserver { + fn record_event(&self, event: &ObserverEvent) { + self.events.lock().unwrap().push(event.clone()); + } + + fn name(&self) -> &str { + &self.name + } + } + + #[test] + fn test_tool_execution_outcome_success() { + let outcome = ToolExecutionOutcome::success("output content".to_string()); + assert!(outcome.success); + assert_eq!(outcome.output, "output content"); + assert!(outcome.error_reason.is_none()); + assert_eq!(outcome.duration, Duration::ZERO); + } + + #[test] + fn test_tool_execution_outcome_success_with_duration() { + let outcome = ToolExecutionOutcome::success_with_duration( + "output content".to_string(), + Duration::from_millis(100), + ); + assert!(outcome.success); + assert_eq!(outcome.duration, Duration::from_millis(100)); + } + + #[test] + fn test_tool_execution_outcome_failure() { + let outcome = ToolExecutionOutcome::failure( + "error output".to_string(), + Some("error reason".to_string()), + ); + assert!(!outcome.success); 
+ assert_eq!(outcome.output, "error output"); + assert_eq!(outcome.error_reason, Some("error reason".to_string())); + assert_eq!(outcome.duration, Duration::ZERO); + } + + #[test] + fn test_multi_observer_broadcasts() { + let mut multi = MultiObserver::new(); + let obs1 = Box::new(TestObserver::new("obs1")); + let obs2 = Box::new(TestObserver::new("obs2")); + multi.add_observer(obs1); + multi.add_observer(obs2); + + let event = ObserverEvent::ToolCallStart { + tool: "test_tool".to_string(), + arguments: Some("{}".to_string()), + }; + + multi.record_event(&event); + + // Both observers should have received the event + assert_eq!(multi.len(), 2); + } + + #[test] + fn test_truncate_args() { + let args = serde_json::json!({"key": "value"}); + assert_eq!(truncate_args(&args, 100), args.to_string()); + + let long_args = serde_json::json!({"key": "a".repeat(200)}); + let truncated = truncate_args(&long_args, 50); + assert!(truncated.ends_with("...truncated")); + assert!(truncated.len() < long_args.to_string().len()); + } +} diff --git a/src/tools/registry.rs b/src/tools/registry.rs index fb88c87..64d0be7 100644 --- a/src/tools/registry.rs +++ b/src/tools/registry.rs @@ -23,6 +23,12 @@ impl ToolRegistry { self.tools.get(name) } + /// Get all registered tools. + /// Used for concurrent tool execution when we need to look up tools by name. 
+ pub fn get_all(&self) -> Vec<&Box> { + self.tools.values().collect() + } + pub fn get_definitions(&self) -> Vec { self.tools .values() diff --git a/tests/test_integration.rs b/tests/test_integration.rs index aeb9e95..5a942e5 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -23,6 +23,7 @@ fn load_config() -> Option { temperature: Some(0.0), max_tokens: Some(100), model_extra: HashMap::new(), + max_tool_iterations: 20, }) } diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index c96ba60..1421891 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -23,6 +23,7 @@ fn load_openai_config() -> Option { temperature: Some(0.0), max_tokens: Some(100), model_extra: HashMap::new(), + max_tool_iterations: 20, }) }