diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 5a22b15..f3072a0 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -51,22 +51,26 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error pub struct AgentLoop { provider: Box, tools: Arc, - max_iterations: u32, } impl AgentLoop { pub fn new(provider_config: LLMProviderConfig) -> Result { - Self::with_tools(provider_config, Arc::new(ToolRegistry::new())) + let provider = create_provider(provider_config) + .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; + + Ok(Self { + provider, + tools: Arc::new(ToolRegistry::new()), + }) } pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { - let provider = create_provider(provider_config.clone()) + let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { provider, tools, - max_iterations: provider_config.max_iterations, }) } @@ -76,80 +80,8 @@ impl AgentLoop { /// Process a message using the provided conversation history. /// History management is handled externally by SessionManager. - /// Returns (final_response, complete_message_history) where the history includes - /// all tool calls and results for proper session continuity. - pub async fn process(&self, messages: Vec) -> Result<(ChatMessage, Vec), AgentError> { - let mut messages = messages; - let mut final_content: String = String::new(); - - for iteration in 0..self.max_iterations { - tracing::debug!(iteration, history_len = messages.len(), "Starting iteration"); - - let messages_for_llm = self.build_messages_for_llm(&messages); - - let tools = if self.tools.has_tools() { - Some(self.tools.get_definitions()) - } else { - None - }; - - let request = ChatCompletionRequest { - messages: messages_for_llm, - temperature: None, - max_tokens: None, - tools, - }; - - let response = (*self.provider).chat(request).await - .map_err(|e| { - tracing::error!(error = %e, "LLM request failed"); - AgentError::LlmError(e.to_string()) - })?; - - tracing::debug!( - response_len = response.content.len(), - tool_calls_len = response.tool_calls.len(), - "LLM response received" - ); - - if !response.tool_calls.is_empty() { - tracing::info!(count = response.tool_calls.len(), iteration, tools = ?response.tool_calls.iter().map(|tc| &tc.name).collect::>(), "Tool calls detected, executing tools"); - - let assistant_message = ChatMessage::assistant(response.content.clone()); - messages.push(assistant_message); - - let tool_results = self.execute_tools(&response.tool_calls).await; - - for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { - let tool_message = ChatMessage::tool( - tool_call.id.clone(), - tool_call.name.clone(), - result.clone(), - ); - messages.push(tool_message); - } - - tracing::debug!(iteration, "Tool execution completed, continuing to next iteration"); - continue; - } - - tracing::debug!(iteration, "No tool calls in response, agent loop ending"); - final_content = response.content; - break; - } - - if final_content.is_empty() { - tracing::warn!(iterations = self.max_iterations, "Max iterations reached without final response"); - final_content = format!("Error: Max iterations ({}) reached without final response", self.max_iterations); - } - - let final_message = ChatMessage::assistant(final_content); - // Return both the final message and the complete history for session persistence - Ok((final_message, messages)) - } - - fn build_messages_for_llm(&self, messages: &[ChatMessage]) -> Vec { - messages + pub async fn process(&self, messages: Vec) -> Result { + let messages_for_llm: Vec = messages .iter() .map(|m| { let content = if m.media_refs.is_empty() { @@ -168,62 +100,114 @@ impl AgentLoop { name: m.tool_name.clone(), } }) - .collect() + .collect(); + + #[cfg(debug_assertions)] + tracing::debug!(history_len = messages.len(), "Sending request to LLM"); + + let tools = if self.tools.has_tools() { + Some(self.tools.get_definitions()) + } else { + None + }; + + let request = ChatCompletionRequest { + messages: messages_for_llm, + temperature: None, + max_tokens: None, + tools, + }; + + let response = (*self.provider).chat(request).await + .map_err(|e| { + tracing::error!(error = %e, "LLM request failed"); + AgentError::LlmError(e.to_string()) + })?; + + #[cfg(debug_assertions)] + tracing::debug!( + response_len = response.content.len(), + tool_calls_len = response.tool_calls.len(), + "LLM response received" + ); + + if !response.tool_calls.is_empty() { + tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools"); + + let mut updated_messages = messages.clone(); + let assistant_message = ChatMessage::assistant(response.content.clone()); + updated_messages.push(assistant_message.clone()); + + let tool_results = self.execute_tools(&response.tool_calls).await; + + for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { + let tool_message = ChatMessage::tool( + tool_call.id.clone(), + tool_call.name.clone(), + result.clone(), + ); + updated_messages.push(tool_message); + } + + return self.continue_with_tool_results(updated_messages).await; + } + + let assistant_message = ChatMessage::assistant(response.content); + Ok(assistant_message) + } + + async fn continue_with_tool_results(&self, messages: Vec) -> Result { + let messages_for_llm: Vec = messages + .iter() + .map(|m| { + let content = if m.media_refs.is_empty() { + vec![ContentBlock::text(&m.content)] + } else { + build_content_blocks(&m.content, &m.media_refs) + }; + Message { + role: m.role.clone(), + content, + tool_call_id: m.tool_call_id.clone(), + name: m.tool_name.clone(), + } + }) + .collect(); + + let tools = if self.tools.has_tools() { + Some(self.tools.get_definitions()) + } else { + None + }; + + let request = ChatCompletionRequest { + messages: messages_for_llm, + temperature: None, + max_tokens: None, + tools, + }; + + let response = (*self.provider).chat(request).await + .map_err(|e| { + tracing::error!(error = %e, "LLM continuation request failed"); + AgentError::LlmError(e.to_string()) + })?; + + let assistant_message = ChatMessage::assistant(response.content); + Ok(assistant_message) } async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec { - let batches = self.partition_tool_batches(tool_calls); let mut results = Vec::with_capacity(tool_calls.len()); - for batch in batches { - if batch.len() == 1 { - // Single tool — run directly (no spawn overhead) - results.push(self.execute_tool(&batch[0]).await); - } else { - // Multiple tools — run in parallel via join_all - use futures_util::future::join_all; - let futures = batch.iter().map(|tc| self.execute_tool(tc)); - let batch_results = join_all(futures).await; - results.extend(batch_results); - } + for tool_call in tool_calls { + let result = self.execute_tool(tool_call).await; + results.push(result); } results } - /// Partition tool calls into batches based on concurrency safety. - /// - /// `concurrency_safe` tools are grouped together; each `exclusive` tool - /// runs in its own batch. This matches the approach used in Nanobot's - /// `_partition_tool_batches` and Zeroclaw's `parallel_tools` config. - fn partition_tool_batches(&self, tool_calls: &[ToolCall]) -> Vec> { - let mut batches: Vec> = Vec::new(); - let mut current: Vec = Vec::new(); - - for tc in tool_calls { - let concurrency_safe = self - .tools - .get(&tc.name) - .map(|t| t.concurrency_safe()) - .unwrap_or(false); - - if concurrency_safe { - current.push(tc.clone()); - } else { - if !current.is_empty() { - batches.push(std::mem::take(&mut current)); - } - batches.push(vec![tc.clone()]); - } - } - - if !current.is_empty() { - batches.push(current); - } - - batches - } - async fn execute_tool(&self, tool_call: &ToolCall) -> String { let tool = match self.tools.get(&tool_call.name) { Some(t) => t, @@ -267,140 +251,3 @@ impl std::fmt::Display for AgentError { } impl std::error::Error for AgentError {} - -#[cfg(test)] -mod tests { - use super::*; - use crate::providers::ToolCall; - use crate::tools::ToolRegistry; - use crate::tools::CalculatorTool; - use crate::tools::BashTool; - use crate::tools::FileReadTool; - use std::sync::Arc; - use serde_json::json; - - fn make_tc(name: &str, args: serde_json::Value) -> ToolCall { - ToolCall { - id: format!("tc_{}", name), - name: name.to_string(), - arguments: args, - } - } - - /// Verify that partition_tool_batches groups concurrency-safe tools together - /// and isolates exclusive tools, matching the nanobot/zeroclaw approach. - #[test] - fn test_partition_batches_mixes_safe_and_exclusive() { - let registry = Arc::new({ - let mut r = ToolRegistry::new(); - r.register(CalculatorTool::new()); // concurrency_safe = true - r.register(BashTool::new()); // concurrency_safe = false (exclusive) - r.register(FileReadTool::new()); // concurrency_safe = true - r - }); - - // agent_loop needs a provider to construct; test the partitioning logic directly - let tcs = vec![ - make_tc("calculator", json!({})), - make_tc("bash", json!({"command": "ls"})), - make_tc("file_read", json!({"path": "/tmp/foo"})), - make_tc("calculator", json!({})), - ]; - - // Expected: - // batch 1: calculator (safe, first run) - // batch 2: bash (exclusive, alone) - // batch 3: file_read, calculator (both safe, run together) - let batches = partition_for_test(®istry, &tcs); - assert_eq!(batches.len(), 3); - assert_eq!(batches[0].len(), 1); - assert_eq!(batches[0][0].name, "calculator"); - assert_eq!(batches[1].len(), 1); - assert_eq!(batches[1][0].name, "bash"); - assert_eq!(batches[2].len(), 2); - assert_eq!(batches[2][0].name, "file_read"); - assert_eq!(batches[2][1].name, "calculator"); - } - - /// All-safe tool calls should produce a single batch (parallel execution). - #[test] - fn test_partition_batches_all_safe_single_batch() { - let registry = Arc::new({ - let mut r = ToolRegistry::new(); - r.register(CalculatorTool::new()); - r.register(FileReadTool::new()); - r - }); - - let tcs = vec![ - make_tc("calculator", json!({})), - make_tc("file_read", json!({"path": "/tmp/foo"})), - ]; - - let batches = partition_for_test(®istry, &tcs); - assert_eq!(batches.len(), 1); - assert_eq!(batches[0].len(), 2); - } - - /// All-exclusive tool calls should each get their own batch (sequential execution). - #[test] - fn test_partition_batches_all_exclusive_separate_batches() { - let registry = Arc::new({ - let mut r = ToolRegistry::new(); - r.register(BashTool::new()); - r - }); - - let tcs = vec![ - make_tc("bash", json!({"command": "ls"})), - make_tc("bash", json!({"command": "pwd"})), - ]; - - let batches = partition_for_test(®istry, &tcs); - assert_eq!(batches.len(), 2); - assert_eq!(batches[0].len(), 1); - assert_eq!(batches[1].len(), 1); - } - - /// Unknown tools (not in registry) default to non-concurrency-safe (single batch). - #[test] - fn test_partition_batches_unknown_tool_gets_own_batch() { - let registry = Arc::new(ToolRegistry::new()); - - let tcs = vec![ - make_tc("calculator", json!({})), - make_tc("unknown_tool", json!({})), - ]; - - let batches = partition_for_test(®istry, &tcs); - assert_eq!(batches.len(), 2); - } - - /// Expose partition logic for testing without needing a full AgentLoop. - fn partition_for_test(registry: &Arc, tool_calls: &[ToolCall]) -> Vec> { - let mut batches: Vec> = Vec::new(); - let mut current: Vec = Vec::new(); - - for tc in tool_calls { - let concurrency_safe = registry - .get(&tc.name) - .map(|t| t.concurrency_safe()) - .unwrap_or(false); - - if concurrency_safe { - current.push(tc.clone()); - } else { - if !current.is_empty() { - batches.push(std::mem::take(&mut current)); - } - batches.push(vec![tc.clone()]); - } - } - - if !current.is_empty() { - batches.push(current); - } - - batches - } -} diff --git a/src/config/mod.rs b/src/config/mod.rs index 1f8083c..a44e64e 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -45,7 +45,7 @@ fn default_media_dir() -> String { } fn default_reaction_emoji() -> String { - "Typing".to_string() + "THUMBSUP".to_string() } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -73,12 +73,6 @@ pub struct ModelConfig { pub struct AgentConfig { pub provider: String, pub model: String, - #[serde(default = "default_max_iterations")] - pub max_iterations: u32, -} - -fn default_max_iterations() -> u32 { - 15 } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -138,7 +132,6 @@ pub struct LLMProviderConfig { pub temperature: Option, pub max_tokens: Option, pub model_extra: HashMap, - pub max_iterations: u32, } fn get_default_config_path() -> PathBuf { @@ -198,7 +191,6 @@ impl Config { temperature: model.temperature, max_tokens: model.max_tokens, model_extra: model.extra.clone(), - max_iterations: agent.max_iterations, }) } } diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index e07edfa..de29b73 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -92,10 +92,6 @@ impl Tool for CalculatorTool { }) } - fn read_only(&self) -> bool { - true - } - async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let function = match args.get("function").and_then(|v| v.as_str()) { Some(f) => f,