From 9834bd75cf2a1c3c06c65c08fcb1f83fdb484f3e Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Mon, 6 Apr 2026 23:43:45 +0800 Subject: [PATCH] feat: add calculator tool and integrate with agent loop - Introduced a new CalculatorTool for performing various arithmetic and statistical calculations. - Enhanced the AgentLoop to support tool execution, including handling tool calls in the process flow. - Updated ChatMessage structure to include optional fields for tool call identification and names. - Modified the Session and SessionManager to manage tool registrations and pass them to agents. - Updated the OpenAIProvider to serialize tool-related message fields. - Added a ToolRegistry for managing multiple tools and their definitions. - Implemented tests for the CalculatorTool to ensure functionality and correctness. --- Cargo.toml | 1 + src/agent/agent_loop.rs | 123 +++++- src/bus/message.rs | 21 + src/gateway/session.rs | 20 +- src/gateway/ws.rs | 2 +- src/lib.rs | 1 + src/providers/openai.rs | 17 +- src/providers/traits.rs | 4 + src/tools/calculator.rs | 824 ++++++++++++++++++++++++++++++++++++++++ src/tools/mod.rs | 7 + src/tools/registry.rs | 53 +++ src/tools/traits.rs | 16 + 12 files changed, 1079 insertions(+), 10 deletions(-) create mode 100644 src/tools/calculator.rs create mode 100644 src/tools/mod.rs create mode 100644 src/tools/registry.rs create mode 100644 src/tools/traits.rs diff --git a/Cargo.toml b/Cargo.toml index d6633e5..9b5e6a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,4 @@ prost = "0.14" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } tracing-appender = "0.2" +anyhow = "1.0" diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 4555cc3..d6166fc 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,10 +1,13 @@ use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; -use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message}; +use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; +use crate::tools::ToolRegistry; +use std::sync::Arc; pub struct AgentLoop { provider: Box, history: Vec, + tools: Arc, } impl AgentLoop { @@ -15,9 +18,25 @@ impl AgentLoop { Ok(Self { provider, history: Vec::new(), + tools: Arc::new(ToolRegistry::new()), }) } + pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { + let provider = create_provider(provider_config) + .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; + + Ok(Self { + provider, + history: Vec::new(), + tools, + }) + } + + pub fn tools(&self) -> &Arc { + &self.tools + } + pub async fn process(&mut self, user_message: ChatMessage) -> Result { self.history.push(user_message.clone()); @@ -26,16 +45,24 @@ impl AgentLoop { .map(|m| Message { role: m.role.clone(), content: m.content.clone(), + tool_call_id: m.tool_call_id.clone(), + name: m.tool_name.clone(), }) .collect(); tracing::debug!(history_len = self.history.len(), "Sending request to LLM"); + let tools = if self.tools.has_tools() { + Some(self.tools.get_definitions()) + } else { + None + }; + let request = ChatCompletionRequest { messages, temperature: None, max_tokens: None, - tools: None, + tools, }; let response = (*self.provider).chat(request).await @@ -44,7 +71,26 @@ impl AgentLoop { AgentError::LlmError(e.to_string()) })?; - tracing::debug!(response_len = response.content.len(), "LLM response received"); + tracing::debug!(response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received"); + + if !response.tool_calls.is_empty() { + tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools"); + let assistant_message = ChatMessage::assistant(response.content.clone()); + self.history.push(assistant_message.clone()); + + let tool_results = self.execute_tools(&response.tool_calls).await; + + for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { + let tool_message = ChatMessage::tool( + tool_call.id.clone(), + tool_call.name.clone(), + result.clone(), + ); + self.history.push(tool_message); + } + + return self.continue_with_tool_results(response.content).await; + } let assistant_message = ChatMessage::assistant(response.content); self.history.push(assistant_message.clone()); @@ -52,6 +98,77 @@ impl AgentLoop { Ok(assistant_message) } + async fn continue_with_tool_results(&mut self, _original_content: String) -> Result { + let messages: Vec = self.history + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + tool_call_id: m.tool_call_id.clone(), + name: m.tool_name.clone(), + }) + .collect(); + + let tools = if self.tools.has_tools() { + Some(self.tools.get_definitions()) + } else { + None + }; + + let request = ChatCompletionRequest { + messages, + temperature: None, + max_tokens: None, + tools, + }; + + let response = (*self.provider).chat(request).await + .map_err(|e| { + tracing::error!(error = %e, "LLM continuation request failed"); + AgentError::LlmError(e.to_string()) + })?; + + let assistant_message = ChatMessage::assistant(response.content); + self.history.push(assistant_message.clone()); + + Ok(assistant_message) + } + + async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec { + let mut results = Vec::with_capacity(tool_calls.len()); + + for tool_call in tool_calls { + let result = self.execute_tool(tool_call).await; + results.push(result); + } + + results + } + + async fn execute_tool(&self, tool_call: &ToolCall) -> String { + let tool = match self.tools.get(&tool_call.name) { + Some(t) => t, + None => { + tracing::warn!(tool = %tool_call.name, "Tool not found"); + return format!("Error: Tool '{}' not found", tool_call.name); + } + }; + + match tool.execute(tool_call.arguments.clone()).await { + Ok(result) => { + if result.success { + result.output + } else { + format!("Error: {}", result.error.unwrap_or_default()) + } + } + Err(e) => { + tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed"); + format!("Error: {}", e) + } + } + } + pub fn clear_history(&mut self) { let len = self.history.len(); self.history.clear(); diff --git a/src/bus/message.rs b/src/bus/message.rs index 8700c5f..71f3294 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -6,6 +6,10 @@ pub struct ChatMessage { pub role: String, pub content: String, pub timestamp: i64, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option, } impl ChatMessage { @@ -15,6 +19,8 @@ impl ChatMessage { role: "user".to_string(), content: content.into(), timestamp: current_timestamp(), + tool_call_id: None, + tool_name: None, } } @@ -24,6 +30,8 @@ impl ChatMessage { role: "assistant".to_string(), content: content.into(), timestamp: current_timestamp(), + tool_call_id: None, + tool_name: None, } } @@ -33,6 +41,19 @@ impl ChatMessage { role: "system".to_string(), content: content.into(), timestamp: current_timestamp(), + tool_call_id: None, + tool_name: None, + } + } + + pub fn tool(tool_call_id: impl Into, tool_name: impl Into, content: impl Into) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + role: "tool".to_string(), + content: content.into(), + timestamp: current_timestamp(), + tool_call_id: Some(tool_call_id.into()), + tool_name: Some(tool_name.into()), } } } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 19a7ffd..82c5e11 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -7,6 +7,7 @@ use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError}; use crate::protocol::WsOutbound; +use crate::tools::{CalculatorTool, ToolRegistry}; /// Session 按 channel 隔离,每个 channel 一个 Session pub struct Session { @@ -16,6 +17,7 @@ pub struct Session { chat_agents: HashMap>>, pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, + tools: Arc, } impl Session { @@ -23,6 +25,7 @@ impl Session { channel_name: String, provider_config: LLMProviderConfig, user_tx: mpsc::Sender, + tools: Arc, ) -> Result { Ok(Self { id: Uuid::new_v4(), @@ -30,6 +33,7 @@ impl Session { chat_agents: HashMap::new(), user_tx, provider_config, + tools, }) } @@ -40,7 +44,7 @@ impl Session { return Ok(agent.clone()); } tracing::debug!(chat_id = %chat_id, "Creating new agent for chat"); - let agent = AgentLoop::new(self.provider_config.clone())?; + let agent = AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())?; let arc = Arc::new(Mutex::new(agent)); self.chat_agents.insert(chat_id.to_string(), arc.clone()); Ok(arc) @@ -76,6 +80,7 @@ impl Session { pub struct SessionManager { inner: Arc>, provider_config: LLMProviderConfig, + tools: Arc, } struct SessionManagerInner { @@ -84,6 +89,12 @@ struct SessionManagerInner { session_ttl: Duration, } +fn default_tools() -> ToolRegistry { + let mut registry = ToolRegistry::new(); + registry.register(CalculatorTool::new()); + registry +} + impl SessionManager { pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Self { Self { @@ -93,9 +104,14 @@ impl SessionManager { session_ttl: Duration::from_secs(session_ttl_hours * 3600), })), provider_config, + tools: Arc::new(default_tools()), } } + pub fn tools(&self) -> Arc { + self.tools.clone() + } + /// 确保 session 存在且未超时,超时则重建 pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { let mut inner = self.inner.lock().await; @@ -119,7 +135,7 @@ impl SessionManager { // 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS) let (user_tx, _rx) = mpsc::channel::(100); - let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx).await?; + let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx, self.tools.clone()).await?; let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(channel_name.to_string(), arc.clone()); diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index bdab4d1..a10d0b6 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -29,7 +29,7 @@ async fn handle_socket(ws: WebSocket, state: Arc) { let channel_name = format!("cli-{}", uuid::Uuid::new_v4()); // 创建 CLI session - let session = match Session::new(channel_name.clone(), provider_config, sender).await { + let session = match Session::new(channel_name.clone(), provider_config, sender, state.session_manager.tools()).await { Ok(s) => Arc::new(Mutex::new(s)), Err(e) => { tracing::error!(error = %e, "Failed to create session"); diff --git a/src/lib.rs b/src/lib.rs index 5e8f880..99bb45e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,3 +8,4 @@ pub mod client; pub mod protocol; pub mod channels; pub mod logging; +pub mod tools; diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 2367681..1854cb3 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -104,10 +104,19 @@ impl LLMProvider for OpenAIProvider { let mut body = json!({ "model": self.model_id, "messages": request.messages.iter().map(|m| { - json!({ - "role": m.role, - "content": m.content - }) + if m.role == "tool" { + json!({ + "role": m.role, + "content": m.content, + "tool_call_id": m.tool_call_id, + "name": m.name, + }) + } else { + json!({ + "role": m.role, + "content": m.content + }) + } }).collect::>(), "temperature": request.temperature.or(self.temperature).unwrap_or(0.7), "max_tokens": request.max_tokens.or(self.max_tokens), diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 843aabb..7aa4a51 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -5,6 +5,10 @@ use serde::{Deserialize, Serialize}; pub struct Message { pub role: String, pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs new file mode 100644 index 0000000..de29b73 --- /dev/null +++ b/src/tools/calculator.rs @@ -0,0 +1,824 @@ +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; + +pub struct CalculatorTool; + +impl CalculatorTool { + pub fn new() -> Self { + Self + } +} + +impl Default for CalculatorTool { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl Tool for CalculatorTool { + fn name(&self) -> &str { + "calculator" + } + + fn description(&self) -> &str { + "Perform arithmetic and statistical calculations. Supports 25 functions: \ + add, subtract, divide, multiply, pow, sqrt, abs, modulo, round, \ + log, ln, exp, factorial, sum, average, median, mode, min, max, \ + range, variance, stdev, percentile, count, percentage_change, clamp. \ + Use this tool whenever you need to compute a numeric result instead of guessing." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "function": { + "type": "string", + "description": "Calculation to perform. \ + Arithmetic: add(values), subtract(values), divide(values), multiply(values), pow(a,b), sqrt(x), abs(x), modulo(a,b), round(x,decimals). \ + Logarithmic/exponential: log(x,base?), ln(x), exp(x), factorial(x). \ + Aggregation: sum(values), average(values), count(values), min(values), max(values), range(values). \ + Statistics: median(values), mode(values), variance(values), stdev(values), percentile(values,p). \ + Utility: percentage_change(a,b), clamp(x,min_val,max_val).", + "enum": [ + "add", "subtract", "divide", "multiply", "pow", "sqrt", + "abs", "modulo", "round", "log", "ln", "exp", "factorial", + "sum", "average", "median", "mode", "min", "max", "range", + "variance", "stdev", "percentile", "count", + "percentage_change", "clamp" + ] + }, + "values": { + "type": "array", + "items": { "type": "number" }, + "description": "Array of numeric values. Required for: add, subtract, divide, multiply, sum, average, median, mode, min, max, range, variance, stdev, percentile, count." + }, + "a": { + "type": "number", + "description": "First operand. Required for: pow, modulo, percentage_change." + }, + "b": { + "type": "number", + "description": "Second operand. Required for: pow, modulo, percentage_change." + }, + "x": { + "type": "number", + "description": "Input number. Required for: sqrt, abs, exp, ln, log, factorial." + }, + "base": { + "type": "number", + "description": "Logarithm base (default: 10). Optional for: log." + }, + "decimals": { + "type": "integer", + "description": "Number of decimal places for rounding. Required for: round." + }, + "p": { + "type": "integer", + "description": "Percentile rank (0-100). Required for: percentile." + }, + "min_val": { + "type": "number", + "description": "Minimum bound. Required for: clamp." + }, + "max_val": { + "type": "number", + "description": "Maximum bound. Required for: clamp." + } + }, + "required": ["function"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let function = match args.get("function").and_then(|v| v.as_str()) { + Some(f) => f, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: function".to_string()), + }); + } + }; + + let result = match function { + "add" => calc_add(&args), + "subtract" => calc_subtract(&args), + "divide" => calc_divide(&args), + "multiply" => calc_multiply(&args), + "pow" => calc_pow(&args), + "sqrt" => calc_sqrt(&args), + "abs" => calc_abs(&args), + "modulo" => calc_modulo(&args), + "round" => calc_round(&args), + "log" => calc_log(&args), + "ln" => calc_ln(&args), + "exp" => calc_exp(&args), + "factorial" => calc_factorial(&args), + "sum" => calc_sum(&args), + "average" => calc_average(&args), + "median" => calc_median(&args), + "mode" => calc_mode(&args), + "min" => calc_min(&args), + "max" => calc_max(&args), + "range" => calc_range(&args), + "variance" => calc_variance(&args), + "stdev" => calc_stdev(&args), + "percentile" => calc_percentile(&args), + "count" => calc_count(&args), + "percentage_change" => calc_percentage_change(&args), + "clamp" => calc_clamp(&args), + other => Err(format!("Unknown function: {other}")), + }; + + match result { + Ok(output) => Ok(ToolResult { + success: true, + output, + error: None, + }), + Err(err) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(err), + }), + } + } +} + +fn extract_f64(args: &serde_json::Value, key: &str, name: &str) -> Result { + args.get(key) + .and_then(|v| v.as_f64()) + .ok_or_else(|| format!("Missing required parameter: {name}")) +} + +fn extract_i64(args: &serde_json::Value, key: &str, name: &str) -> Result { + args.get(key) + .and_then(|v| v.as_i64()) + .ok_or_else(|| format!("Missing required parameter: {name}")) +} + +fn extract_values(args: &serde_json::Value, min_len: usize) -> Result, String> { + let values = args + .get("values") + .and_then(|v| v.as_array()) + .ok_or_else(|| "Missing required parameter: values (array of numbers)".to_string())?; + if values.len() < min_len { + return Err(format!( + "Expected at least {min_len} value(s), got {}", + values.len() + )); + } + let mut nums = Vec::with_capacity(values.len()); + for (i, v) in values.iter().enumerate() { + match v.as_f64() { + Some(n) => nums.push(n), + None => return Err(format!("values[{i}] is not a valid number")), + } + } + Ok(nums) +} + +fn format_num(n: f64) -> String { + if n == n.floor() && n.abs() < 1e15 { + #[allow(clippy::cast_possible_truncation)] + let rounded = n.round() as i128; + format!("{rounded}") + } else { + format!("{n}") + } +} + +fn calc_add(args: &serde_json::Value) -> Result { + let values = extract_values(args, 2)?; + Ok(format_num(values.iter().sum())) +} + +fn calc_subtract(args: &serde_json::Value) -> Result { + let values = extract_values(args, 2)?; + let mut iter = values.iter(); + let mut result = *iter.next().unwrap(); + for v in iter { + result -= v; + } + Ok(format_num(result)) +} + +fn calc_divide(args: &serde_json::Value) -> Result { + let values = extract_values(args, 2)?; + let mut iter = values.iter(); + let mut result = *iter.next().unwrap(); + for v in iter { + if *v == 0.0 { + return Err("Division by zero".to_string()); + } + result /= v; + } + Ok(format_num(result)) +} + +fn calc_multiply(args: &serde_json::Value) -> Result { + let values = extract_values(args, 2)?; + let mut result = 1.0; + for v in &values { + result *= v; + } + Ok(format_num(result)) +} + +fn calc_pow(args: &serde_json::Value) -> Result { + let base = extract_f64(args, "a", "a (base)")?; + let exp = extract_f64(args, "b", "b (exponent)")?; + Ok(format_num(base.powf(exp))) +} + +fn calc_sqrt(args: &serde_json::Value) -> Result { + let x = extract_f64(args, "x", "x")?; + if x < 0.0 { + return Err("Cannot compute square root of a negative number".to_string()); + } + Ok(format_num(x.sqrt())) +} + +fn calc_abs(args: &serde_json::Value) -> Result { + let x = extract_f64(args, "x", "x")?; + Ok(format_num(x.abs())) +} + +fn calc_modulo(args: &serde_json::Value) -> Result { + let a = extract_f64(args, "a", "a")?; + let b = extract_f64(args, "b", "b")?; + if b == 0.0 { + return Err("Modulo by zero".to_string()); + } + Ok(format_num(a % b)) +} + +fn calc_round(args: &serde_json::Value) -> Result { + let x = extract_f64(args, "x", "x")?; + let decimals = extract_i64(args, "decimals", "decimals")?; + if decimals < 0 { + return Err("decimals must be non-negative".to_string()); + } + let multiplier = 10_f64.powi(i32::try_from(decimals).unwrap_or(i32::MAX)); + Ok(format_num((x * multiplier).round() / multiplier)) +} + +fn calc_log(args: &serde_json::Value) -> Result { + let x = extract_f64(args, "x", "x")?; + if x <= 0.0 { + return Err("Logarithm requires a positive number".to_string()); + } + let base = args.get("base").and_then(|v| v.as_f64()).unwrap_or(10.0); + if base <= 0.0 || base == 1.0 { + return Err("Logarithm base must be positive and not equal to 1".to_string()); + } + Ok(format_num(x.log(base))) +} + +fn calc_ln(args: &serde_json::Value) -> Result { + let x = extract_f64(args, "x", "x")?; + if x <= 0.0 { + return Err("Natural logarithm requires a positive number".to_string()); + } + Ok(format_num(x.ln())) +} + +fn calc_exp(args: &serde_json::Value) -> Result { + let x = extract_f64(args, "x", "x")?; + Ok(format_num(x.exp())) +} + +fn calc_factorial(args: &serde_json::Value) -> Result { + let x = extract_f64(args, "x", "x")?; + if x < 0.0 || x != x.floor() { + return Err("Factorial requires a non-negative integer".to_string()); + } + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + let n = x.round() as u128; + if n > 170 { + return Err("Factorial result exceeds f64 range (max input: 170)".to_string()); + } + let mut result: u128 = 1; + for i in 2..=n { + result *= i; + } + Ok(result.to_string()) +} + +fn calc_sum(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + Ok(format_num(values.iter().sum())) +} + +fn calc_average(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + if values.is_empty() { + return Err("Cannot compute average of an empty array".to_string()); + } + Ok(format_num(values.iter().sum::() / values.len() as f64)) +} + +fn calc_median(args: &serde_json::Value) -> Result { + let mut values = extract_values(args, 1)?; + if values.is_empty() { + return Err("Cannot compute median of an empty array".to_string()); + } + values.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let len = values.len(); + if len % 2 == 0 { + Ok(format_num(f64::midpoint( + values[len / 2 - 1], + values[len / 2], + ))) + } else { + Ok(format_num(values[len / 2])) + } +} + +fn calc_mode(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + if values.is_empty() { + return Err("Cannot compute mode of an empty array".to_string()); + } + let mut freq: std::collections::HashMap = std::collections::HashMap::new(); + for &v in &values { + let key = v.to_bits(); + *freq.entry(key).or_insert(0) += 1; + } + let max_freq = *freq.values().max().unwrap(); + let mut seen = std::collections::HashSet::new(); + let mut modes = Vec::new(); + for &v in &values { + let key = v.to_bits(); + if freq[&key] == max_freq && seen.insert(key) { + modes.push(v); + } + } + if modes.len() == 1 { + Ok(format_num(modes[0])) + } else { + let formatted: Vec = modes.iter().map(|v| format_num(*v)).collect(); + Ok(format!("Modes: {}", formatted.join(", "))) + } +} + +fn calc_min(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + let Some(min_val) = values.iter().copied().reduce(f64::min) else { + return Err("Cannot compute min of an empty array".to_string()); + }; + Ok(format_num(min_val)) +} + +fn calc_max(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + let Some(max_val) = values.iter().copied().reduce(f64::max) else { + return Err("Cannot compute max of an empty array".to_string()); + }; + Ok(format_num(max_val)) +} + +fn calc_range(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + if values.is_empty() { + return Err("Cannot compute range of an empty array".to_string()); + } + let min_val = values.iter().copied().fold(f64::INFINITY, f64::min); + let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max); + Ok(format_num(max_val - min_val)) +} + +fn calc_variance(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + if values.len() < 2 { + return Err("Variance requires at least 2 values".to_string()); + } + let mean = values.iter().sum::() / values.len() as f64; + let variance = values.iter().map(|v| (v - mean).powi(2)).sum::() / values.len() as f64; + Ok(format_num(variance)) +} + +fn calc_stdev(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + if values.len() < 2 { + return Err("Standard deviation requires at least 2 values".to_string()); + } + let mean = values.iter().sum::() / values.len() as f64; + let variance = values.iter().map(|v| (v - mean).powi(2)).sum::() / values.len() as f64; + Ok(format_num(variance.sqrt())) +} + +fn calc_percentile(args: &serde_json::Value) -> Result { + let mut values = extract_values(args, 1)?; + if values.is_empty() { + return Err("Cannot compute percentile of an empty array".to_string()); + } + let p = extract_i64(args, "p", "p (percentile rank 0-100)")?; + if !(0..=100).contains(&p) { + return Err("Percentile rank must be between 0 and 100".to_string()); + } + values.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + let idx_f = p as f64 / 100.0 * (values.len() - 1) as f64; + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + let index = idx_f.round().clamp(0.0, (values.len() - 1) as f64) as usize; + Ok(format_num(values[index])) +} + +fn calc_count(args: &serde_json::Value) -> Result { + let values = extract_values(args, 1)?; + Ok(values.len().to_string()) +} + +fn calc_percentage_change(args: &serde_json::Value) -> Result { + let old = extract_f64(args, "a", "a (old value)")?; + let new = extract_f64(args, "b", "b (new value)")?; + if old == 0.0 { + return Err("Cannot compute percentage change from zero".to_string()); + } + Ok(format_num((new - old) / old.abs() * 100.0)) +} + +fn calc_clamp(args: &serde_json::Value) -> Result { + let x = extract_f64(args, "x", "x")?; + let min_val = extract_f64(args, "min_val", "min_val")?; + let max_val = extract_f64(args, "max_val", "max_val")?; + if min_val > max_val { + return Err("min_val must be less than or equal to max_val".to_string()); + } + Ok(format_num(x.clamp(min_val, max_val))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_add() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "add", "values": [1.0, 2.0, 3.5]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "6.5"); + } + + #[tokio::test] + async fn test_subtract() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "subtract", "values": [10.0, 3.0, 1.5]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "5.5"); + } + + #[tokio::test] + async fn test_divide() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "divide", "values": [100.0, 4.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "25"); + } + + #[tokio::test] + async fn test_divide_by_zero() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "divide", "values": [10.0, 0.0]})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("zero")); + } + + #[tokio::test] + async fn test_multiply() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "multiply", "values": [3.0, 4.0, 5.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "60"); + } + + #[tokio::test] + async fn test_pow() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "pow", "a": 2.0, "b": 10.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "1024"); + } + + #[tokio::test] + async fn test_sqrt() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "sqrt", "x": 144.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "12"); + } + + #[tokio::test] + async fn test_sqrt_negative() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "sqrt", "x": -4.0})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn test_abs() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "abs", "x": -42.5})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "42.5"); + } + + #[tokio::test] + async fn test_modulo() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "modulo", "a": 17.0, "b": 5.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "2"); + } + + #[tokio::test] + async fn test_round() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "round", "x": 2.715, "decimals": 2})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "2.72"); + } + + #[tokio::test] + async fn test_log_base10() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "log", "x": 100.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "2"); + } + + #[tokio::test] + async fn test_log_custom_base() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "log", "x": 8.0, "base": 2.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "3"); + } + + #[tokio::test] + async fn test_ln() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "ln", "x": 1.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "0"); + } + + #[tokio::test] + async fn test_exp() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "exp", "x": 0.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "1"); + } + + #[tokio::test] + async fn test_factorial() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "factorial", "x": 5.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "120"); + } + + #[tokio::test] + async fn test_average() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "average", "values": [10.0, 20.0, 30.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "20"); + } + + #[tokio::test] + async fn test_median_odd() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "median", "values": [3.0, 1.0, 2.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "2"); + } + + #[tokio::test] + async fn test_median_even() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "median", "values": [4.0, 1.0, 3.0, 2.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "2.5"); + } + + #[tokio::test] + async fn test_mode() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "mode", "values": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "3"); + } + + #[tokio::test] + async fn test_min() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "min", "values": [5.0, 2.0, 8.0, 1.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "1"); + } + + #[tokio::test] + async fn test_max() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "max", "values": [5.0, 2.0, 8.0, 1.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "8"); + } + + #[tokio::test] + async fn test_range() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "range", "values": [1.0, 5.0, 10.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "9"); + } + + #[tokio::test] + async fn test_variance() { + let tool = CalculatorTool::new(); + let result = tool + .execute( + json!({"function": "variance", "values": [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]}), + ) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "4"); + } + + #[tokio::test] + async fn test_stdev() { + let tool = CalculatorTool::new(); + let result = tool + .execute( + json!({"function": "stdev", "values": [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]}), + ) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "2"); + } + + #[tokio::test] + async fn test_percentile_50() { + let tool = CalculatorTool::new(); + let result = tool + .execute( + json!({"function": "percentile", "values": [1.0, 2.0, 3.0, 4.0, 5.0], "p": 50}), + ) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "3"); + } + + #[tokio::test] + async fn test_count() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "count", "values": [1.0, 2.0, 3.0, 4.0, 5.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "5"); + } + + #[tokio::test] + async fn test_percentage_change() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "percentage_change", "a": 50.0, "b": 75.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "50"); + } + + #[tokio::test] + async fn test_clamp_within_range() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "clamp", "x": 5.0, "min_val": 1.0, "max_val": 10.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "5"); + } + + #[tokio::test] + async fn test_clamp_below_min() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "clamp", "x": -5.0, "min_val": 0.0, "max_val": 10.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "0"); + } + + #[tokio::test] + async fn test_clamp_above_max() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "clamp", "x": 15.0, "min_val": 0.0, "max_val": 10.0})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "10"); + } + + #[tokio::test] + async fn test_unknown_function() { + let tool = CalculatorTool::new(); + let result = tool.execute(json!({"function": "unknown"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Unknown function")); + } + + #[tokio::test] + async fn test_sum() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "sum", "values": [1.0, 2.0, 3.0, 4.0, 5.0]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "15"); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs new file mode 100644 index 0000000..07028d5 --- /dev/null +++ b/src/tools/mod.rs @@ -0,0 +1,7 @@ +pub mod calculator; +pub mod registry; +pub mod traits; + +pub use calculator::CalculatorTool; +pub use registry::ToolRegistry; +pub use traits::{Tool, ToolResult}; diff --git a/src/tools/registry.rs b/src/tools/registry.rs new file mode 100644 index 0000000..fb88c87 --- /dev/null +++ b/src/tools/registry.rs @@ -0,0 +1,53 @@ +use std::collections::HashMap; + +use crate::providers::{Tool, ToolFunction}; + +use super::traits::Tool as ToolTrait; + +pub struct ToolRegistry { + tools: HashMap>, +} + +impl ToolRegistry { + pub fn new() -> Self { + Self { + tools: HashMap::new(), + } + } + + pub fn register(&mut self, tool: T) { + self.tools.insert(tool.name().to_string(), Box::new(tool)); + } + + pub fn get(&self, name: &str) -> Option<&Box> { + self.tools.get(name) + } + + pub fn get_definitions(&self) -> Vec { + self.tools + .values() + .map(|tool| Tool { + tool_type: "function".to_string(), + function: ToolFunction { + name: tool.name().to_string(), + description: tool.description().to_string(), + parameters: tool.parameters_schema(), + }, + }) + .collect() + } + + pub fn has_tools(&self) -> bool { + !self.tools.is_empty() + } + + pub fn tool_names(&self) -> Vec { + self.tools.keys().cloned().collect() + } +} + +impl Default for ToolRegistry { + fn default() -> Self { + Self::new() + } +} diff --git a/src/tools/traits.rs b/src/tools/traits.rs new file mode 100644 index 0000000..10f0140 --- /dev/null +++ b/src/tools/traits.rs @@ -0,0 +1,16 @@ +use async_trait::async_trait; + +#[derive(Debug, Clone)] +pub struct ToolResult { + pub success: bool, + pub output: String, + pub error: Option, +} + +#[async_trait] +pub trait Tool: Send + Sync + 'static { + fn name(&self) -> &str; + fn description(&self) -> &str; + fn parameters_schema(&self) -> serde_json::Value; + async fn execute(&self, args: serde_json::Value) -> anyhow::Result; +}