From c971bc3639791fe7e691118401e2514252cc8e30 Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Sun, 12 Apr 2026 18:38:38 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=B8=8A=E4=B8=8B?= =?UTF-8?q?=E6=96=87=E5=8E=8B=E7=BC=A9=E5=8A=9F=E8=83=BD=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E6=B6=88=E6=81=AF=E5=8E=86=E5=8F=B2=E7=AE=A1=E7=90=86?= =?UTF-8?q?=E5=92=8C=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=E6=97=A5=E5=BF=97?= =?UTF-8?q?=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 +- Cargo.toml | 1 + src/agent/agent_loop.rs | 7 + src/agent/context_compressor.rs | 372 ++++++++++++++++++++++++++++++++ src/agent/mod.rs | 2 + src/config/mod.rs | 8 + src/gateway/session.rs | 21 +- src/gateway/ws.rs | 12 ++ src/tools/calculator.rs | 335 +++++++++------------------- 9 files changed, 529 insertions(+), 234 deletions(-) create mode 100644 src/agent/context_compressor.rs diff --git a/.gitignore b/.gitignore index f053f61..16751f8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ /target reference/** .env -*.env \ No newline at end of file +*.env +AGENTS.md +CLAUDE.md +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index 0e73bba..8dd2225 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,4 @@ anyhow = "1.0" mime_guess = "2.0" base64 = "0.22" tempfile = "3" +meval = "0.2" diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 60be481..8207a71 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -330,6 +330,13 @@ impl AgentLoop { let tool_results = self.execute_tools(&response.tool_calls).await; for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { + // Log function call with name and arguments + let args_str = match &tool_call.arguments { + serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(), + other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()), + }; + tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool"); + // Truncate tool result if too large let truncated_output = truncate_tool_result(&result.output); diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs new file mode 100644 index 0000000..d53224c --- /dev/null +++ b/src/agent/context_compressor.rs @@ -0,0 +1,372 @@ +use crate::bus::ChatMessage; +use crate::config::LLMProviderConfig; +use crate::providers::{create_provider, ChatCompletionRequest, Message}; + +use crate::agent::AgentError; + +/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin. +pub fn estimate_tokens(messages: &[ChatMessage]) -> usize { + let raw: usize = messages + .iter() + .map(|m| m.content.len().div_ceil(4) + 4) + .sum(); + (raw as f64 * 1.2) as usize +} + +/// Configuration for context compression. +#[derive(Debug, Clone)] +pub struct ContextCompressionConfig { + /// Protect first N messages (system prompt, etc.) + pub protect_first_n: usize, + /// Protect last N messages (recent context) + pub protect_last_n: usize, + /// Maximum compression passes + pub max_passes: u32, + /// Maximum characters in summary + pub summary_max_chars: usize, + /// Characters to keep when trimming tool results + pub tool_result_trim_chars: usize, +} + +impl Default for ContextCompressionConfig { + fn default() -> Self { + Self { + protect_first_n: 1, + protect_last_n: 4, + max_passes: 3, + summary_max_chars: 4000, + tool_result_trim_chars: 2000, + } + } +} + +/// Context compressor that reduces message history when it exceeds token limits. +pub struct ContextCompressor { + config: ContextCompressionConfig, + context_window: usize, + /// Threshold ratio to trigger compression (50% of context window) + threshold_ratio: f64, +} + +impl ContextCompressor { + /// Create a new compressor with the given context window size. + pub fn new(context_window: usize) -> Self { + Self { + config: ContextCompressionConfig::default(), + context_window, + threshold_ratio: 0.5, + } + } + + /// Create with custom configuration. + pub fn with_config(context_window: usize, config: ContextCompressionConfig) -> Self { + Self { + config, + context_window, + threshold_ratio: 0.5, + } + } + + /// Get the compression threshold in tokens. + fn threshold(&self) -> usize { + (self.context_window as f64 * self.threshold_ratio) as usize + } + + /// Fast-path: trim oversized tool results without LLM call. + /// Returns the number of messages modified. + fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize { + let limit = self.config.tool_result_trim_chars; + let mut modified = 0; + + for msg in messages.iter_mut() { + if msg.role == "tool" && msg.content.len() > limit { + let removed = msg.content.len() - limit; + msg.content = format!( + "{}...\n\n[Output truncated - {} characters removed]", + &msg.content[..limit.min(msg.content.len())], + removed + ); + modified += 1; + } + } + + modified + } + + /// Main entry point - compresses history if over threshold. + pub async fn compress_if_needed( + &self, + history: Vec, + provider_config: &LLMProviderConfig, + ) -> Result, AgentError> { + // Check if compression is needed + let tokens = estimate_tokens(&history); + if tokens <= self.threshold() { + #[cfg(debug_assertions)] + tracing::info!( + tokens = tokens, + threshold = self.threshold(), + msg_count = history.len(), + "Context compression not needed" + ); + return Ok(history); + } + + tracing::info!( + tokens = tokens, + threshold = self.threshold(), + msg_count = history.len(), + "Starting context compression" + ); + + // Fast trim pass first + let trimmed = self.fast_trim_tool_results(&mut history.clone()); + if trimmed > 0 { + let tokens_after = estimate_tokens(&history); + #[cfg(debug_assertions)] + tracing::debug!( + trimmed_messages = trimmed, + tokens_after = tokens_after, + "Fast trim completed" + ); + if tokens_after <= self.threshold() { + return Ok(history); + } + } + + // LLM summarization pass + let mut current_history = history; + for pass in 0..self.config.max_passes { + let tokens = estimate_tokens(¤t_history); + if tokens <= self.threshold() { + break; + } + + #[cfg(debug_assertions)] + tracing::debug!( + pass = pass + 1, + tokens = tokens, + "Compression pass" + ); + + match self.compress_once(¤t_history, provider_config).await { + Ok(Some(compressed)) => { + current_history = compressed; + } + Ok(None) => { + // No more compressible content + break; + } + Err(e) => { + tracing::warn!(error = %e, "Compression pass failed, using current history"); + break; + } + } + } + + tracing::info!( + final_tokens = estimate_tokens(¤t_history), + final_msg_count = current_history.len(), + "Context compression completed" + ); + + Ok(current_history) + } + + /// Single compression pass - summarize middle messages between user turns. + /// Returns Some(compressed) if compression happened, None if nothing to compress. + async fn compress_once( + &self, + history: &[ChatMessage], + provider_config: &LLMProviderConfig, + ) -> Result>, AgentError> { + if history.len() <= self.config.protect_first_n + self.config.protect_last_n { + return Ok(None); + } + + // Find user message indices (excluding protected first messages) + let user_indices: Vec = history + .iter() + .enumerate() + .skip(self.config.protect_first_n) + .filter(|(_, m)| m.role == "user") + .map(|(i, _)| i) + .collect(); + + // Need at least one user message and content between users to compress + if user_indices.len() < 2 { + return Ok(None); + } + + // Build segments: user -> (assistant turns) -> next user + // We'll summarize the assistant turns between consecutive user messages + let mut new_messages = history[..=user_indices[0]].to_vec(); + + for i in 0..user_indices.len() - 1 { + let user_idx = user_indices[i]; + let next_user_idx = user_indices[i + 1]; + + new_messages.push(history[user_idx].clone()); + + // Check if there's assistant content between these two user messages + let between_start = user_idx + 1; + let between_end = next_user_idx; + + if between_start < between_end { + let between = &history[between_start..between_end]; + let summary = self.summarize_segment(between, provider_config).await?; + + // Add summary as a special user message + new_messages.push(ChatMessage::user(format!( + "[Context Summary]\n\n{}", + summary + ))); + } + } + + // Add last user and everything after (protected) + let last_user_idx = user_indices[user_indices.len() - 1]; + if last_user_idx < history.len() - 1 { + // Add everything from last user onwards (protected) + for i in last_user_idx..history.len() { + new_messages.push(history[i].clone()); + } + } + + // If nothing changed, return None + if new_messages.len() == history.len() { + return Ok(None); + } + + Ok(Some(new_messages)) + } + + /// Summarize a segment of messages using LLM. + async fn summarize_segment( + &self, + messages: &[ChatMessage], + provider_config: &LLMProviderConfig, + ) -> Result { + if messages.is_empty() { + return Ok(String::new()); + } + + // Build transcript for summarization + let transcript = messages + .iter() + .map(|m| { + let role = match m.role.as_str() { + "assistant" => "Assistant", + "tool" => "Tool", + _ => m.role.as_str(), + }; + let name = m.tool_name + .as_ref() + .map(|n| format!(" ({})", n)) + .unwrap_or_default(); + format!("{}: {}{}", role, m.content, name) + }) + .collect::>() + .join("\n\n"); + + // Truncate transcript if too long + let transcript = if transcript.len() > self.config.summary_max_chars { + format!( + "{}...\n\n[Transcript truncated - {} characters removed]", + &transcript[..self.config.summary_max_chars], + transcript.len() - self.config.summary_max_chars + ) + } else { + transcript + }; + + let prompt = format!( + r#"You are a conversation compaction engine. Summarize the following conversation segment. + +PRESERVE: +- All identifiers (UUIDs, hashes, file paths, URLs) +- Actions taken (tool calls, file operations, commands) +- Key information obtained (results, data, errors) +- Decisions and user preferences +- Current task status + +OMIT: +- Verbose tool output (keep key results only) +- Repeated greetings or filler + +Be concise, aim for {} characters or less. + +--- + +{} + +"#, + self.config.summary_max_chars, transcript + ); + + // Create provider and call LLM + let provider = create_provider(provider_config.clone()) + .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; + + let request = ChatCompletionRequest { + messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], + temperature: Some(0.3), + max_tokens: Some(1000), + tools: None, + }; + + match provider.chat(request).await { + Ok(response) => Ok(response.content), + Err(e) => { + // Fallback: just truncate the transcript + tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript"); + Ok(transcript[..transcript.len().min(2000)].to_string()) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_estimate_tokens() { + let messages = vec![ + ChatMessage::user("Hello"), + ChatMessage::assistant("Hi there!"), + ChatMessage::user("How are you?"), + ]; + + let tokens = estimate_tokens(&messages); + // "Hello" (5) -> ceil(5/4)+4 = 2+4 = 6 + // "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6 + // "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7 + // raw = 19, with 1.2x = ~23 + assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens); + } + + #[test] + fn test_fast_trim() { + let config = ContextCompressionConfig { + tool_result_trim_chars: 50, + ..Default::default() + }; + let compressor = ContextCompressor::with_config(100_000, config); + + let mut messages = vec![ + ChatMessage::user("Hello"), + ChatMessage::tool("call1", "bash", &"x".repeat(200)), + ]; + + let modified = compressor.fast_trim_tool_results(&mut messages); + assert_eq!(modified, 1); + assert!(messages[1].content.len() < 100); + } + + #[test] + fn test_threshold() { + let compressor = ContextCompressor::new(128_000); + assert_eq!(compressor.threshold(), 64_000); + } +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 7c84e22..3b2c508 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,3 +1,5 @@ pub mod agent_loop; +pub mod context_compressor; pub use agent_loop::{AgentLoop, AgentError}; +pub use context_compressor::ContextCompressor; diff --git a/src/config/mod.rs b/src/config/mod.rs index 0f66974..7268084 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -75,12 +75,18 @@ pub struct AgentConfig { pub model: String, #[serde(default = "default_max_tool_iterations")] pub max_tool_iterations: usize, + #[serde(default = "default_token_limit")] + pub token_limit: usize, } fn default_max_tool_iterations() -> usize { 20 } +fn default_token_limit() -> usize { + 128_000 +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct GatewayConfig { #[serde(default = "default_gateway_host")] @@ -139,6 +145,7 @@ pub struct LLMProviderConfig { pub max_tokens: Option, pub model_extra: HashMap, pub max_tool_iterations: usize, + pub token_limit: usize, } fn get_default_config_path() -> PathBuf { @@ -199,6 +206,7 @@ impl Config { max_tokens: model.max_tokens, model_extra: model.extra.clone(), max_tool_iterations: agent.max_tool_iterations, + token_limit: agent.token_limit, }) } } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 431c4a4..42ca962 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -5,7 +5,7 @@ use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; -use crate::agent::{AgentLoop, AgentError}; +use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::protocol::WsOutbound; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, @@ -22,6 +22,7 @@ pub struct Session { pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, tools: Arc, + compressor: ContextCompressor, } impl Session { @@ -36,8 +37,9 @@ impl Session { channel_name, chat_histories: HashMap::new(), user_tx, - provider_config, + provider_config: provider_config.clone(), tools, + compressor: ContextCompressor::new(provider_config.token_limit), }) } @@ -94,6 +96,16 @@ impl Session { let _ = self.user_tx.send(msg).await; } + /// 获取 provider_config 引用 + pub fn provider_config(&self) -> &LLMProviderConfig { + &self.provider_config + } + + /// 获取 compressor 引用 + pub fn compressor(&self) -> &ContextCompressor { + &self.compressor + } + /// 创建一个临时的 AgentLoop 实例来处理消息 pub fn create_agent(&self) -> Result { AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone()) @@ -252,6 +264,11 @@ impl SessionManager { // 获取完整历史 let history = session_guard.get_or_create_history(chat_id).clone(); + // 压缩历史(如果需要) + let history = session_guard.compressor + .compress_if_needed(history, &session_guard.provider_config) + .await?; + // 创建 agent 并处理 let agent = session_guard.create_agent()?; let response = agent.process(history).await?; diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 3c6ce8a..5507d62 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -130,6 +130,18 @@ async fn handle_inbound(session: &Arc>, inbound: WsInbound) { // 获取完整历史 let history = session_guard.get_or_create_history(&chat_id).clone(); + // 压缩历史(如果需要) + let history = match session_guard.compressor() + .compress_if_needed(history, session_guard.provider_config()) + .await + { + Ok(h) => h, + Err(e) => { + tracing::warn!(chat_id = %chat_id, error = %e, "Compression failed, using original history"); + session_guard.get_or_create_history(&chat_id).clone() + } + }; + // 创建 agent 并处理 let agent = match session_guard.create_agent() { Ok(a) => a, diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index cff93f9..8f033dc 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -23,9 +23,8 @@ impl Tool for CalculatorTool { } fn description(&self) -> &str { - "Perform arithmetic and statistical calculations. Supports 25 functions: \ - add, subtract, divide, multiply, pow, sqrt, abs, modulo, round, \ - log, ln, exp, factorial, sum, average, median, mode, min, max, \ + "Perform arithmetic and statistical calculations. Supports expression evaluation (evaluate) and functions: \ + round, log, factorial, sum, average, median, mode, min, max, \ range, variance, stdev, percentile, count, percentage_change, clamp. \ Use this tool whenever you need to compute a numeric result instead of guessing." } @@ -41,35 +40,36 @@ impl Tool for CalculatorTool { "function": { "type": "string", "description": "Calculation to perform. \ - Arithmetic: add(values), subtract(values), divide(values), multiply(values), pow(a,b), sqrt(x), abs(x), modulo(a,b), round(x,decimals). \ - Logarithmic/exponential: log(x,base?), ln(x), exp(x), factorial(x). \ + Expression: evaluate(expression) - supports +, -, *, /, %, ^, parentheses, and functions like sqrt, abs, exp, ln, sin, cos, tan, round, floor, ceil, max, min, etc. \ + Rounding: round(x, decimals). \ + Logarithmic: log(x, base?) - base defaults to 10. \ + Special: factorial(x). \ Aggregation: sum(values), average(values), count(values), min(values), max(values), range(values). \ Statistics: median(values), mode(values), variance(values), stdev(values), percentile(values,p). \ Utility: percentage_change(a,b), clamp(x,min_val,max_val).", "enum": [ - "add", "subtract", "divide", "multiply", "pow", "sqrt", - "abs", "modulo", "round", "log", "ln", "exp", "factorial", + "round", "log", "factorial", "sum", "average", "median", "mode", "min", "max", "range", "variance", "stdev", "percentile", "count", - "percentage_change", "clamp" + "percentage_change", "clamp", "evaluate" ] }, "values": { "type": "array", "items": { "type": "number" }, - "description": "Array of numeric values. Required for: add, subtract, divide, multiply, sum, average, median, mode, min, max, range, variance, stdev, percentile, count." + "description": "Array of numeric values. Required for: sum, average, median, mode, min, max, range, variance, stdev, percentile, count." }, "a": { "type": "number", - "description": "First operand. Required for: pow, modulo, percentage_change." + "description": "First operand. Required for: percentage_change." }, "b": { "type": "number", - "description": "Second operand. Required for: pow, modulo, percentage_change." + "description": "Second operand. Required for: percentage_change." }, "x": { "type": "number", - "description": "Input number. Required for: sqrt, abs, exp, ln, log, factorial." + "description": "Input number. Required for: log, factorial, round, clamp." }, "base": { "type": "number", @@ -90,6 +90,10 @@ impl Tool for CalculatorTool { "max_val": { "type": "number", "description": "Maximum bound. Required for: clamp." + }, + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate. Supports: +, -, *, /, %, ^ (power), parentheses. Functions: sqrt, abs, exp, ln, sin, cos, tan, asin, acos, atan, atan2, sinh, cosh, tanh, asinh, acosh, atanh, floor, ceil, round, signum, max, min. Constants: pi, e. Variables: x, weight, etc. Example: '15*3+5^(2+1)', 'sin(pi/2)', 'max(1,2,3)'" } }, "required": ["function"] @@ -109,18 +113,8 @@ impl Tool for CalculatorTool { }; let result = match function { - "add" => calc_add(&args), - "subtract" => calc_subtract(&args), - "divide" => calc_divide(&args), - "multiply" => calc_multiply(&args), - "pow" => calc_pow(&args), - "sqrt" => calc_sqrt(&args), - "abs" => calc_abs(&args), - "modulo" => calc_modulo(&args), "round" => calc_round(&args), "log" => calc_log(&args), - "ln" => calc_ln(&args), - "exp" => calc_exp(&args), "factorial" => calc_factorial(&args), "sum" => calc_sum(&args), "average" => calc_average(&args), @@ -135,6 +129,7 @@ impl Tool for CalculatorTool { "count" => calc_count(&args), "percentage_change" => calc_percentage_change(&args), "clamp" => calc_clamp(&args), + "evaluate" => calc_evaluate(&args), other => Err(format!("Unknown function: {other}")), }; @@ -196,71 +191,6 @@ fn format_num(n: f64) -> String { } } -fn calc_add(args: &serde_json::Value) -> Result { - let values = extract_values(args, 2)?; - Ok(format_num(values.iter().sum())) -} - -fn calc_subtract(args: &serde_json::Value) -> Result { - let values = extract_values(args, 2)?; - let mut iter = values.iter(); - let mut result = *iter.next().unwrap(); - for v in iter { - result -= v; - } - Ok(format_num(result)) -} - -fn calc_divide(args: &serde_json::Value) -> Result { - let values = extract_values(args, 2)?; - let mut iter = values.iter(); - let mut result = *iter.next().unwrap(); - for v in iter { - if *v == 0.0 { - return Err("Division by zero".to_string()); - } - result /= v; - } - Ok(format_num(result)) -} - -fn calc_multiply(args: &serde_json::Value) -> Result { - let values = extract_values(args, 2)?; - let mut result = 1.0; - for v in &values { - result *= v; - } - Ok(format_num(result)) -} - -fn calc_pow(args: &serde_json::Value) -> Result { - let base = extract_f64(args, "a", "a (base)")?; - let exp = extract_f64(args, "b", "b (exponent)")?; - Ok(format_num(base.powf(exp))) -} - -fn calc_sqrt(args: &serde_json::Value) -> Result { - let x = extract_f64(args, "x", "x")?; - if x < 0.0 { - return Err("Cannot compute square root of a negative number".to_string()); - } - Ok(format_num(x.sqrt())) -} - -fn calc_abs(args: &serde_json::Value) -> Result { - let x = extract_f64(args, "x", "x")?; - Ok(format_num(x.abs())) -} - -fn calc_modulo(args: &serde_json::Value) -> Result { - let a = extract_f64(args, "a", "a")?; - let b = extract_f64(args, "b", "b")?; - if b == 0.0 { - return Err("Modulo by zero".to_string()); - } - Ok(format_num(a % b)) -} - fn calc_round(args: &serde_json::Value) -> Result { let x = extract_f64(args, "x", "x")?; let decimals = extract_i64(args, "decimals", "decimals")?; @@ -283,19 +213,6 @@ fn calc_log(args: &serde_json::Value) -> Result { Ok(format_num(x.log(base))) } -fn calc_ln(args: &serde_json::Value) -> Result { - let x = extract_f64(args, "x", "x")?; - if x <= 0.0 { - return Err("Natural logarithm requires a positive number".to_string()); - } - Ok(format_num(x.ln())) -} - -fn calc_exp(args: &serde_json::Value) -> Result { - let x = extract_f64(args, "x", "x")?; - Ok(format_num(x.exp())) -} - fn calc_factorial(args: &serde_json::Value) -> Result { let x = extract_f64(args, "x", "x")?; if x < 0.0 || x != x.floor() { @@ -457,119 +374,21 @@ fn calc_clamp(args: &serde_json::Value) -> Result { Ok(format_num(x.clamp(min_val, max_val))) } +fn calc_evaluate(args: &serde_json::Value) -> Result { + let expression = args + .get("expression") + .and_then(|v| v.as_str()) + .ok_or_else(|| "Missing required parameter: expression".to_string())?; + + meval::eval_str(expression) + .map(format_num) + .map_err(|e| format!("Expression evaluation error: {e}")) +} + #[cfg(test)] mod tests { use super::*; - #[tokio::test] - async fn test_add() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "add", "values": [1.0, 2.0, 3.5]})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "6.5"); - } - - #[tokio::test] - async fn test_subtract() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "subtract", "values": [10.0, 3.0, 1.5]})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "5.5"); - } - - #[tokio::test] - async fn test_divide() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "divide", "values": [100.0, 4.0]})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "25"); - } - - #[tokio::test] - async fn test_divide_by_zero() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "divide", "values": [10.0, 0.0]})) - .await - .unwrap(); - assert!(!result.success); - assert!(result.error.as_ref().unwrap().contains("zero")); - } - - #[tokio::test] - async fn test_multiply() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "multiply", "values": [3.0, 4.0, 5.0]})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "60"); - } - - #[tokio::test] - async fn test_pow() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "pow", "a": 2.0, "b": 10.0})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "1024"); - } - - #[tokio::test] - async fn test_sqrt() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "sqrt", "x": 144.0})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "12"); - } - - #[tokio::test] - async fn test_sqrt_negative() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "sqrt", "x": -4.0})) - .await - .unwrap(); - assert!(!result.success); - } - - #[tokio::test] - async fn test_abs() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "abs", "x": -42.5})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "42.5"); - } - - #[tokio::test] - async fn test_modulo() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "modulo", "a": 17.0, "b": 5.0})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "2"); - } - #[tokio::test] async fn test_round() { let tool = CalculatorTool::new(); @@ -603,28 +422,6 @@ mod tests { assert_eq!(result.output, "3"); } - #[tokio::test] - async fn test_ln() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "ln", "x": 1.0})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "0"); - } - - #[tokio::test] - async fn test_exp() { - let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "exp", "x": 0.0})) - .await - .unwrap(); - assert!(result.success); - assert_eq!(result.output, "1"); - } - #[tokio::test] async fn test_factorial() { let tool = CalculatorTool::new(); @@ -825,4 +622,80 @@ mod tests { assert!(result.success); assert_eq!(result.output, "15"); } + + #[tokio::test] + async fn test_evaluate_simple() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "evaluate", "expression": "15*3+5^(2+1)"})) + .await + .unwrap(); + assert!(result.success); + // 15*3 + 5^(2+1) = 45 + 5^3 = 45 + 125 = 170 + assert_eq!(result.output, "170"); + } + + #[tokio::test] + async fn test_evaluate_complex() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "evaluate", "expression": "(10-2)*(3+1)"})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "32"); + } + + #[tokio::test] + async fn test_evaluate_invalid() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "evaluate", "expression": "invalid"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn test_evaluate_missing_expression() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "evaluate"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn test_evaluate_with_functions() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "evaluate", "expression": "sqrt(144)"})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "12"); + } + + #[tokio::test] + async fn test_evaluate_with_constants() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "evaluate", "expression": "pi * 2"})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "6.283185307179586"); + } + + #[tokio::test] + async fn test_evaluate_modulo() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "evaluate", "expression": "17 % 5"})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "2"); + } }