Revert "feat(agent): add parallel tool execution with concurrency-safe batching"

This reverts commit 0c0d0c14436a230ca34c0cd9647c8fc14f62d7cc.
This commit is contained in:
xiaoxixi 2026-04-12 09:54:38 +08:00
parent 0c0d0c1443
commit 862eb1115a
3 changed files with 109 additions and 274 deletions

View File

@ -51,22 +51,26 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error
pub struct AgentLoop { pub struct AgentLoop {
provider: Box<dyn LLMProvider>, provider: Box<dyn LLMProvider>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
max_iterations: u32,
} }
impl AgentLoop { impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> { pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
Self::with_tools(provider_config, Arc::new(ToolRegistry::new())) let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider,
tools: Arc::new(ToolRegistry::new()),
})
} }
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> { pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
let provider = create_provider(provider_config.clone()) let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?; .map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self { Ok(Self {
provider, provider,
tools, tools,
max_iterations: provider_config.max_iterations,
}) })
} }
@ -76,16 +80,30 @@ impl AgentLoop {
/// Process a message using the provided conversation history. /// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager. /// History management is handled externally by SessionManager.
/// Returns (final_response, complete_message_history) where the history includes pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
/// all tool calls and results for proper session continuity. let messages_for_llm: Vec<Message> = messages
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<(ChatMessage, Vec<ChatMessage>), AgentError> { .iter()
let mut messages = messages; .map(|m| {
let mut final_content: String = String::new(); let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
#[cfg(debug_assertions)]
tracing::debug!(media_refs = ?m.media_refs, "Building content blocks with media");
build_content_blocks(&m.content, &m.media_refs)
};
#[cfg(debug_assertions)]
tracing::debug!(role = %m.role, content_len = %m.content.len(), media_refs_len = %m.media_refs.len(), "ChatMessage converted to LLM Message");
Message {
role: m.role.clone(),
content,
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
}
})
.collect();
for iteration in 0..self.max_iterations { #[cfg(debug_assertions)]
tracing::debug!(iteration, history_len = messages.len(), "Starting iteration"); tracing::debug!(history_len = messages.len(), "Sending request to LLM");
let messages_for_llm = self.build_messages_for_llm(&messages);
let tools = if self.tools.has_tools() { let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions()) Some(self.tools.get_definitions())
@ -106,6 +124,7 @@ impl AgentLoop {
AgentError::LlmError(e.to_string()) AgentError::LlmError(e.to_string())
})?; })?;
#[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(
response_len = response.content.len(), response_len = response.content.len(),
tool_calls_len = response.tool_calls.len(), tool_calls_len = response.tool_calls.len(),
@ -113,10 +132,11 @@ impl AgentLoop {
); );
if !response.tool_calls.is_empty() { if !response.tool_calls.is_empty() {
tracing::info!(count = response.tool_calls.len(), iteration, tools = ?response.tool_calls.iter().map(|tc| &tc.name).collect::<Vec<_>>(), "Tool calls detected, executing tools"); tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools");
let mut updated_messages = messages.clone();
let assistant_message = ChatMessage::assistant(response.content.clone()); let assistant_message = ChatMessage::assistant(response.content.clone());
messages.push(assistant_message); updated_messages.push(assistant_message.clone());
let tool_results = self.execute_tools(&response.tool_calls).await; let tool_results = self.execute_tools(&response.tool_calls).await;
@ -126,41 +146,25 @@ impl AgentLoop {
tool_call.name.clone(), tool_call.name.clone(),
result.clone(), result.clone(),
); );
messages.push(tool_message); updated_messages.push(tool_message);
} }
tracing::debug!(iteration, "Tool execution completed, continuing to next iteration"); return self.continue_with_tool_results(updated_messages).await;
continue;
} }
tracing::debug!(iteration, "No tool calls in response, agent loop ending"); let assistant_message = ChatMessage::assistant(response.content);
final_content = response.content; Ok(assistant_message)
break;
} }
if final_content.is_empty() { async fn continue_with_tool_results(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
tracing::warn!(iterations = self.max_iterations, "Max iterations reached without final response"); let messages_for_llm: Vec<Message> = messages
final_content = format!("Error: Max iterations ({}) reached without final response", self.max_iterations);
}
let final_message = ChatMessage::assistant(final_content);
// Return both the final message and the complete history for session persistence
Ok((final_message, messages))
}
fn build_messages_for_llm(&self, messages: &[ChatMessage]) -> Vec<Message> {
messages
.iter() .iter()
.map(|m| { .map(|m| {
let content = if m.media_refs.is_empty() { let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)] vec![ContentBlock::text(&m.content)]
} else { } else {
#[cfg(debug_assertions)]
tracing::debug!(media_refs = ?m.media_refs, "Building content blocks with media");
build_content_blocks(&m.content, &m.media_refs) build_content_blocks(&m.content, &m.media_refs)
}; };
#[cfg(debug_assertions)]
tracing::debug!(role = %m.role, content_len = %m.content.len(), media_refs_len = %m.media_refs.len(), "ChatMessage converted to LLM Message");
Message { Message {
role: m.role.clone(), role: m.role.clone(),
content, content,
@ -168,62 +172,42 @@ impl AgentLoop {
name: m.tool_name.clone(), name: m.tool_name.clone(),
} }
}) })
.collect() .collect();
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
} else {
None
};
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM continuation request failed");
AgentError::LlmError(e.to_string())
})?;
let assistant_message = ChatMessage::assistant(response.content);
Ok(assistant_message)
} }
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> { async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> {
let batches = self.partition_tool_batches(tool_calls);
let mut results = Vec::with_capacity(tool_calls.len()); let mut results = Vec::with_capacity(tool_calls.len());
for batch in batches { for tool_call in tool_calls {
if batch.len() == 1 { let result = self.execute_tool(tool_call).await;
// Single tool — run directly (no spawn overhead) results.push(result);
results.push(self.execute_tool(&batch[0]).await);
} else {
// Multiple tools — run in parallel via join_all
use futures_util::future::join_all;
let futures = batch.iter().map(|tc| self.execute_tool(tc));
let batch_results = join_all(futures).await;
results.extend(batch_results);
}
} }
results results
} }
/// Partition tool calls into batches based on concurrency safety.
///
/// Consecutive `concurrency_safe` tools are merged into a single batch so
/// they can execute in parallel; every exclusive tool is isolated into a
/// batch of its own. This matches the approach used in Nanobot's
/// `_partition_tool_batches` and Zeroclaw's `parallel_tools` config.
fn partition_tool_batches(&self, tool_calls: &[ToolCall]) -> Vec<Vec<ToolCall>> {
    let mut batches: Vec<Vec<ToolCall>> = Vec::new();
    // Safe calls accumulate here until an exclusive call forces a flush.
    let mut pending_safe: Vec<ToolCall> = Vec::new();
    for call in tool_calls {
        // Tools missing from the registry default to exclusive (not safe).
        let is_safe = self
            .tools
            .get(&call.name)
            .map_or(false, |tool| tool.concurrency_safe());
        if is_safe {
            pending_safe.push(call.clone());
            continue;
        }
        // Flush the accumulated safe batch before the exclusive call so
        // relative ordering between batches is preserved.
        if !pending_safe.is_empty() {
            batches.push(std::mem::take(&mut pending_safe));
        }
        batches.push(vec![call.clone()]);
    }
    if !pending_safe.is_empty() {
        batches.push(pending_safe);
    }
    batches
}
async fn execute_tool(&self, tool_call: &ToolCall) -> String { async fn execute_tool(&self, tool_call: &ToolCall) -> String {
let tool = match self.tools.get(&tool_call.name) { let tool = match self.tools.get(&tool_call.name) {
Some(t) => t, Some(t) => t,
@ -267,140 +251,3 @@ impl std::fmt::Display for AgentError {
} }
impl std::error::Error for AgentError {} impl std::error::Error for AgentError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::ToolCall;
    use crate::tools::ToolRegistry;
    use crate::tools::CalculatorTool;
    use crate::tools::BashTool;
    use crate::tools::FileReadTool;
    use std::sync::Arc;
    use serde_json::json;

    /// Build a `ToolCall` with a deterministic id derived from the tool name.
    fn make_tc(name: &str, args: serde_json::Value) -> ToolCall {
        ToolCall {
            id: format!("tc_{}", name),
            name: name.to_string(),
            arguments: args,
        }
    }

    /// Collect the tool names of every batch, for compact assertions.
    fn batch_names(batches: &[Vec<ToolCall>]) -> Vec<Vec<&str>> {
        batches
            .iter()
            .map(|batch| batch.iter().map(|tc| tc.name.as_str()).collect())
            .collect()
    }

    /// Verify that partition_tool_batches groups concurrency-safe tools together
    /// and isolates exclusive tools, matching the nanobot/zeroclaw approach.
    #[test]
    fn test_partition_batches_mixes_safe_and_exclusive() {
        let mut reg = ToolRegistry::new();
        reg.register(CalculatorTool::new()); // concurrency_safe = true
        reg.register(BashTool::new()); // concurrency_safe = false (exclusive)
        reg.register(FileReadTool::new()); // concurrency_safe = true
        let registry = Arc::new(reg);
        // agent_loop needs a provider to construct; test the partitioning logic directly
        let tcs = vec![
            make_tc("calculator", json!({})),
            make_tc("bash", json!({"command": "ls"})),
            make_tc("file_read", json!({"path": "/tmp/foo"})),
            make_tc("calculator", json!({})),
        ];
        let batches = partition_for_test(&registry, &tcs);
        // calculator flushed alone before the exclusive bash; the trailing
        // safe pair (file_read + calculator) runs together.
        assert_eq!(
            batch_names(&batches),
            vec![
                vec!["calculator"],
                vec!["bash"],
                vec!["file_read", "calculator"],
            ]
        );
    }

    /// All-safe tool calls should produce a single batch (parallel execution).
    #[test]
    fn test_partition_batches_all_safe_single_batch() {
        let mut reg = ToolRegistry::new();
        reg.register(CalculatorTool::new());
        reg.register(FileReadTool::new());
        let registry = Arc::new(reg);
        let tcs = vec![
            make_tc("calculator", json!({})),
            make_tc("file_read", json!({"path": "/tmp/foo"})),
        ];
        let batches = partition_for_test(&registry, &tcs);
        assert_eq!(batch_names(&batches), vec![vec!["calculator", "file_read"]]);
    }

    /// All-exclusive tool calls should each get their own batch (sequential execution).
    #[test]
    fn test_partition_batches_all_exclusive_separate_batches() {
        let mut reg = ToolRegistry::new();
        reg.register(BashTool::new());
        let registry = Arc::new(reg);
        let tcs = vec![
            make_tc("bash", json!({"command": "ls"})),
            make_tc("bash", json!({"command": "pwd"})),
        ];
        let batches = partition_for_test(&registry, &tcs);
        assert_eq!(batch_names(&batches), vec![vec!["bash"], vec!["bash"]]);
    }

    /// Unknown tools (not in registry) default to non-concurrency-safe (single batch).
    #[test]
    fn test_partition_batches_unknown_tool_gets_own_batch() {
        let registry = Arc::new(ToolRegistry::new());
        let tcs = vec![
            make_tc("calculator", json!({})),
            make_tc("unknown_tool", json!({})),
        ];
        let batches = partition_for_test(&registry, &tcs);
        assert_eq!(batches.len(), 2);
    }

    /// Expose partition logic for testing without needing a full AgentLoop.
    /// NOTE(review): this duplicates `AgentLoop::partition_tool_batches` rather
    /// than exercising it, so a regression in the production method would not
    /// be caught here — consider extracting the logic to a free function.
    fn partition_for_test(registry: &Arc<ToolRegistry>, tool_calls: &[ToolCall]) -> Vec<Vec<ToolCall>> {
        let mut batches: Vec<Vec<ToolCall>> = Vec::new();
        let mut pending_safe: Vec<ToolCall> = Vec::new();
        for call in tool_calls {
            let is_safe = registry
                .get(&call.name)
                .map_or(false, |tool| tool.concurrency_safe());
            if is_safe {
                pending_safe.push(call.clone());
                continue;
            }
            if !pending_safe.is_empty() {
                batches.push(std::mem::take(&mut pending_safe));
            }
            batches.push(vec![call.clone()]);
        }
        if !pending_safe.is_empty() {
            batches.push(pending_safe);
        }
        batches
    }
}

View File

@ -45,7 +45,7 @@ fn default_media_dir() -> String {
} }
fn default_reaction_emoji() -> String { fn default_reaction_emoji() -> String {
"Typing".to_string() "THUMBSUP".to_string()
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
@ -73,12 +73,6 @@ pub struct ModelConfig {
pub struct AgentConfig { pub struct AgentConfig {
pub provider: String, pub provider: String,
pub model: String, pub model: String,
#[serde(default = "default_max_iterations")]
pub max_iterations: u32,
}
fn default_max_iterations() -> u32 {
15
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
@ -138,7 +132,6 @@ pub struct LLMProviderConfig {
pub temperature: Option<f32>, pub temperature: Option<f32>,
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
pub model_extra: HashMap<String, serde_json::Value>, pub model_extra: HashMap<String, serde_json::Value>,
pub max_iterations: u32,
} }
fn get_default_config_path() -> PathBuf { fn get_default_config_path() -> PathBuf {
@ -198,7 +191,6 @@ impl Config {
temperature: model.temperature, temperature: model.temperature,
max_tokens: model.max_tokens, max_tokens: model.max_tokens,
model_extra: model.extra.clone(), model_extra: model.extra.clone(),
max_iterations: agent.max_iterations,
}) })
} }
} }

View File

@ -92,10 +92,6 @@ impl Tool for CalculatorTool {
}) })
} }
fn read_only(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> { async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let function = match args.get("function").and_then(|v| v.as_str()) { let function = match args.get("function").and_then(|v| v.as_str()) {
Some(f) => f, Some(f) => f,