feat: 添加上下文压缩功能,优化消息历史管理和工具调用日志记录

This commit is contained in:
xiaoxixi 2026-04-12 18:38:38 +08:00
parent dcf04279a7
commit c971bc3639
9 changed files with 529 additions and 234 deletions

5
.gitignore vendored
View File

@ -1,4 +1,7 @@
/target /target
reference/** reference/**
.env .env
*.env *.env
AGENTS.md
CLAUDE.md
Cargo.lock

View File

@ -26,3 +26,4 @@ anyhow = "1.0"
mime_guess = "2.0" mime_guess = "2.0"
base64 = "0.22" base64 = "0.22"
tempfile = "3" tempfile = "3"
meval = "0.2"

View File

@ -330,6 +330,13 @@ impl AgentLoop {
let tool_results = self.execute_tools(&response.tool_calls).await; let tool_results = self.execute_tools(&response.tool_calls).await;
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
// Log function call with name and arguments
let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
};
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
// Truncate tool result if too large // Truncate tool result if too large
let truncated_output = truncate_tool_result(&result.output); let truncated_output = truncate_tool_result(&result.output);

View File

@ -0,0 +1,372 @@
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, ChatCompletionRequest, Message};
use crate::agent::AgentError;
/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin.
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
let raw: usize = messages
.iter()
.map(|m| m.content.len().div_ceil(4) + 4)
.sum();
(raw as f64 * 1.2) as usize
}
/// Configuration for context compression.
#[derive(Debug, Clone)]
pub struct ContextCompressionConfig {
/// Protect first N messages (system prompt, etc.)
pub protect_first_n: usize,
/// Protect last N messages (recent context)
pub protect_last_n: usize,
/// Maximum compression passes
pub max_passes: u32,
/// Maximum characters in summary
pub summary_max_chars: usize,
/// Characters to keep when trimming tool results
pub tool_result_trim_chars: usize,
}
impl Default for ContextCompressionConfig {
fn default() -> Self {
Self {
protect_first_n: 1,
protect_last_n: 4,
max_passes: 3,
summary_max_chars: 4000,
tool_result_trim_chars: 2000,
}
}
}
/// Context compressor that reduces message history when it exceeds token limits.
pub struct ContextCompressor {
config: ContextCompressionConfig,
context_window: usize,
/// Threshold ratio to trigger compression (50% of context window)
threshold_ratio: f64,
}
impl ContextCompressor {
/// Create a new compressor with the given context window size.
pub fn new(context_window: usize) -> Self {
Self {
config: ContextCompressionConfig::default(),
context_window,
threshold_ratio: 0.5,
}
}
/// Create with custom configuration.
pub fn with_config(context_window: usize, config: ContextCompressionConfig) -> Self {
Self {
config,
context_window,
threshold_ratio: 0.5,
}
}
/// Get the compression threshold in tokens.
fn threshold(&self) -> usize {
(self.context_window as f64 * self.threshold_ratio) as usize
}
/// Fast-path: trim oversized tool results without LLM call.
/// Returns the number of messages modified.
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize {
let limit = self.config.tool_result_trim_chars;
let mut modified = 0;
for msg in messages.iter_mut() {
if msg.role == "tool" && msg.content.len() > limit {
let removed = msg.content.len() - limit;
msg.content = format!(
"{}...\n\n[Output truncated - {} characters removed]",
&msg.content[..limit.min(msg.content.len())],
removed
);
modified += 1;
}
}
modified
}
/// Main entry point - compresses history if over threshold.
pub async fn compress_if_needed(
&self,
history: Vec<ChatMessage>,
provider_config: &LLMProviderConfig,
) -> Result<Vec<ChatMessage>, AgentError> {
// Check if compression is needed
let tokens = estimate_tokens(&history);
if tokens <= self.threshold() {
#[cfg(debug_assertions)]
tracing::info!(
tokens = tokens,
threshold = self.threshold(),
msg_count = history.len(),
"Context compression not needed"
);
return Ok(history);
}
tracing::info!(
tokens = tokens,
threshold = self.threshold(),
msg_count = history.len(),
"Starting context compression"
);
// Fast trim pass first
let trimmed = self.fast_trim_tool_results(&mut history.clone());
if trimmed > 0 {
let tokens_after = estimate_tokens(&history);
#[cfg(debug_assertions)]
tracing::debug!(
trimmed_messages = trimmed,
tokens_after = tokens_after,
"Fast trim completed"
);
if tokens_after <= self.threshold() {
return Ok(history);
}
}
// LLM summarization pass
let mut current_history = history;
for pass in 0..self.config.max_passes {
let tokens = estimate_tokens(&current_history);
if tokens <= self.threshold() {
break;
}
#[cfg(debug_assertions)]
tracing::debug!(
pass = pass + 1,
tokens = tokens,
"Compression pass"
);
match self.compress_once(&current_history, provider_config).await {
Ok(Some(compressed)) => {
current_history = compressed;
}
Ok(None) => {
// No more compressible content
break;
}
Err(e) => {
tracing::warn!(error = %e, "Compression pass failed, using current history");
break;
}
}
}
tracing::info!(
final_tokens = estimate_tokens(&current_history),
final_msg_count = current_history.len(),
"Context compression completed"
);
Ok(current_history)
}
/// Single compression pass - summarize middle messages between user turns.
/// Returns Some(compressed) if compression happened, None if nothing to compress.
async fn compress_once(
&self,
history: &[ChatMessage],
provider_config: &LLMProviderConfig,
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
return Ok(None);
}
// Find user message indices (excluding protected first messages)
let user_indices: Vec<usize> = history
.iter()
.enumerate()
.skip(self.config.protect_first_n)
.filter(|(_, m)| m.role == "user")
.map(|(i, _)| i)
.collect();
// Need at least one user message and content between users to compress
if user_indices.len() < 2 {
return Ok(None);
}
// Build segments: user -> (assistant turns) -> next user
// We'll summarize the assistant turns between consecutive user messages
let mut new_messages = history[..=user_indices[0]].to_vec();
for i in 0..user_indices.len() - 1 {
let user_idx = user_indices[i];
let next_user_idx = user_indices[i + 1];
new_messages.push(history[user_idx].clone());
// Check if there's assistant content between these two user messages
let between_start = user_idx + 1;
let between_end = next_user_idx;
if between_start < between_end {
let between = &history[between_start..between_end];
let summary = self.summarize_segment(between, provider_config).await?;
// Add summary as a special user message
new_messages.push(ChatMessage::user(format!(
"[Context Summary]\n\n{}",
summary
)));
}
}
// Add last user and everything after (protected)
let last_user_idx = user_indices[user_indices.len() - 1];
if last_user_idx < history.len() - 1 {
// Add everything from last user onwards (protected)
for i in last_user_idx..history.len() {
new_messages.push(history[i].clone());
}
}
// If nothing changed, return None
if new_messages.len() == history.len() {
return Ok(None);
}
Ok(Some(new_messages))
}
/// Summarize a segment of messages using LLM.
async fn summarize_segment(
&self,
messages: &[ChatMessage],
provider_config: &LLMProviderConfig,
) -> Result<String, AgentError> {
if messages.is_empty() {
return Ok(String::new());
}
// Build transcript for summarization
let transcript = messages
.iter()
.map(|m| {
let role = match m.role.as_str() {
"assistant" => "Assistant",
"tool" => "Tool",
_ => m.role.as_str(),
};
let name = m.tool_name
.as_ref()
.map(|n| format!(" ({})", n))
.unwrap_or_default();
format!("{}: {}{}", role, m.content, name)
})
.collect::<Vec<_>>()
.join("\n\n");
// Truncate transcript if too long
let transcript = if transcript.len() > self.config.summary_max_chars {
format!(
"{}...\n\n[Transcript truncated - {} characters removed]",
&transcript[..self.config.summary_max_chars],
transcript.len() - self.config.summary_max_chars
)
} else {
transcript
};
let prompt = format!(
r#"You are a conversation compaction engine. Summarize the following conversation segment.
PRESERVE:
- All identifiers (UUIDs, hashes, file paths, URLs)
- Actions taken (tool calls, file operations, commands)
- Key information obtained (results, data, errors)
- Decisions and user preferences
- Current task status
OMIT:
- Verbose tool output (keep key results only)
- Repeated greetings or filler
Be concise, aim for {} characters or less.
---
{}
"#,
self.config.summary_max_chars, transcript
);
// Create provider and call LLM
let provider = create_provider(provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
let request = ChatCompletionRequest {
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
temperature: Some(0.3),
max_tokens: Some(1000),
tools: None,
};
match provider.chat(request).await {
Ok(response) => Ok(response.content),
Err(e) => {
// Fallback: just truncate the transcript
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
Ok(transcript[..transcript.len().min(2000)].to_string())
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_estimate_tokens() {
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
ChatMessage::user("How are you?"),
];
let tokens = estimate_tokens(&messages);
// "Hello" (5) -> ceil(5/4)+4 = 2+4 = 6
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
// raw = 19, with 1.2x = ~23
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
}
#[test]
fn test_fast_trim() {
let config = ContextCompressionConfig {
tool_result_trim_chars: 50,
..Default::default()
};
let compressor = ContextCompressor::with_config(100_000, config);
let mut messages = vec![
ChatMessage::user("Hello"),
ChatMessage::tool("call1", "bash", &"x".repeat(200)),
];
let modified = compressor.fast_trim_tool_results(&mut messages);
assert_eq!(modified, 1);
assert!(messages[1].content.len() < 100);
}
#[test]
fn test_threshold() {
let compressor = ContextCompressor::new(128_000);
assert_eq!(compressor.threshold(), 64_000);
}
}

View File

@ -1,3 +1,5 @@
pub mod agent_loop; pub mod agent_loop;
pub mod context_compressor;
pub use agent_loop::{AgentLoop, AgentError}; pub use agent_loop::{AgentLoop, AgentError};
pub use context_compressor::ContextCompressor;

View File

@ -75,12 +75,18 @@ pub struct AgentConfig {
pub model: String, pub model: String,
#[serde(default = "default_max_tool_iterations")] #[serde(default = "default_max_tool_iterations")]
pub max_tool_iterations: usize, pub max_tool_iterations: usize,
#[serde(default = "default_token_limit")]
pub token_limit: usize,
} }
fn default_max_tool_iterations() -> usize { fn default_max_tool_iterations() -> usize {
20 20
} }
fn default_token_limit() -> usize {
128_000
}
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GatewayConfig { pub struct GatewayConfig {
#[serde(default = "default_gateway_host")] #[serde(default = "default_gateway_host")]
@ -139,6 +145,7 @@ pub struct LLMProviderConfig {
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
pub model_extra: HashMap<String, serde_json::Value>, pub model_extra: HashMap<String, serde_json::Value>,
pub max_tool_iterations: usize, pub max_tool_iterations: usize,
pub token_limit: usize,
} }
fn get_default_config_path() -> PathBuf { fn get_default_config_path() -> PathBuf {
@ -199,6 +206,7 @@ impl Config {
max_tokens: model.max_tokens, max_tokens: model.max_tokens,
model_extra: model.extra.clone(), model_extra: model.extra.clone(),
max_tool_iterations: agent.max_tool_iterations, max_tool_iterations: agent.max_tool_iterations,
token_limit: agent.token_limit,
}) })
} }
} }

View File

@ -5,7 +5,7 @@ use tokio::sync::{Mutex, mpsc};
use uuid::Uuid; use uuid::Uuid;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError}; use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use crate::tools::{ use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
@ -22,6 +22,7 @@ pub struct Session {
pub user_tx: mpsc::Sender<WsOutbound>, pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
compressor: ContextCompressor,
} }
impl Session { impl Session {
@ -36,8 +37,9 @@ impl Session {
channel_name, channel_name,
chat_histories: HashMap::new(), chat_histories: HashMap::new(),
user_tx, user_tx,
provider_config, provider_config: provider_config.clone(),
tools, tools,
compressor: ContextCompressor::new(provider_config.token_limit),
}) })
} }
@ -94,6 +96,16 @@ impl Session {
let _ = self.user_tx.send(msg).await; let _ = self.user_tx.send(msg).await;
} }
/// 获取 provider_config 引用
pub fn provider_config(&self) -> &LLMProviderConfig {
&self.provider_config
}
/// 获取 compressor 引用
pub fn compressor(&self) -> &ContextCompressor {
&self.compressor
}
/// 创建一个临时的 AgentLoop 实例来处理消息 /// 创建一个临时的 AgentLoop 实例来处理消息
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> { pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone()) AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())
@ -252,6 +264,11 @@ impl SessionManager {
// 获取完整历史 // 获取完整历史
let history = session_guard.get_or_create_history(chat_id).clone(); let history = session_guard.get_or_create_history(chat_id).clone();
// 压缩历史(如果需要)
let history = session_guard.compressor
.compress_if_needed(history, &session_guard.provider_config)
.await?;
// 创建 agent 并处理 // 创建 agent 并处理
let agent = session_guard.create_agent()?; let agent = session_guard.create_agent()?;
let response = agent.process(history).await?; let response = agent.process(history).await?;

View File

@ -130,6 +130,18 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
// 获取完整历史 // 获取完整历史
let history = session_guard.get_or_create_history(&chat_id).clone(); let history = session_guard.get_or_create_history(&chat_id).clone();
// 压缩历史(如果需要)
let history = match session_guard.compressor()
.compress_if_needed(history, session_guard.provider_config())
.await
{
Ok(h) => h,
Err(e) => {
tracing::warn!(chat_id = %chat_id, error = %e, "Compression failed, using original history");
session_guard.get_or_create_history(&chat_id).clone()
}
};
// 创建 agent 并处理 // 创建 agent 并处理
let agent = match session_guard.create_agent() { let agent = match session_guard.create_agent() {
Ok(a) => a, Ok(a) => a,

View File

@ -23,9 +23,8 @@ impl Tool for CalculatorTool {
} }
fn description(&self) -> &str { fn description(&self) -> &str {
"Perform arithmetic and statistical calculations. Supports 25 functions: \ "Perform arithmetic and statistical calculations. Supports expression evaluation (evaluate) and functions: \
add, subtract, divide, multiply, pow, sqrt, abs, modulo, round, \ round, log, factorial, sum, average, median, mode, min, max, \
log, ln, exp, factorial, sum, average, median, mode, min, max, \
range, variance, stdev, percentile, count, percentage_change, clamp. \ range, variance, stdev, percentile, count, percentage_change, clamp. \
Use this tool whenever you need to compute a numeric result instead of guessing." Use this tool whenever you need to compute a numeric result instead of guessing."
} }
@ -41,35 +40,36 @@ impl Tool for CalculatorTool {
"function": { "function": {
"type": "string", "type": "string",
"description": "Calculation to perform. \ "description": "Calculation to perform. \
Arithmetic: add(values), subtract(values), divide(values), multiply(values), pow(a,b), sqrt(x), abs(x), modulo(a,b), round(x,decimals). \ Expression: evaluate(expression) - supports +, -, *, /, %, ^, parentheses, and functions like sqrt, abs, exp, ln, sin, cos, tan, round, floor, ceil, max, min, etc. \
Logarithmic/exponential: log(x,base?), ln(x), exp(x), factorial(x). \ Rounding: round(x, decimals). \
Logarithmic: log(x, base?) - base defaults to 10. \
Special: factorial(x). \
Aggregation: sum(values), average(values), count(values), min(values), max(values), range(values). \ Aggregation: sum(values), average(values), count(values), min(values), max(values), range(values). \
Statistics: median(values), mode(values), variance(values), stdev(values), percentile(values,p). \ Statistics: median(values), mode(values), variance(values), stdev(values), percentile(values,p). \
Utility: percentage_change(a,b), clamp(x,min_val,max_val).", Utility: percentage_change(a,b), clamp(x,min_val,max_val).",
"enum": [ "enum": [
"add", "subtract", "divide", "multiply", "pow", "sqrt", "round", "log", "factorial",
"abs", "modulo", "round", "log", "ln", "exp", "factorial",
"sum", "average", "median", "mode", "min", "max", "range", "sum", "average", "median", "mode", "min", "max", "range",
"variance", "stdev", "percentile", "count", "variance", "stdev", "percentile", "count",
"percentage_change", "clamp" "percentage_change", "clamp", "evaluate"
] ]
}, },
"values": { "values": {
"type": "array", "type": "array",
"items": { "type": "number" }, "items": { "type": "number" },
"description": "Array of numeric values. Required for: add, subtract, divide, multiply, sum, average, median, mode, min, max, range, variance, stdev, percentile, count." "description": "Array of numeric values. Required for: sum, average, median, mode, min, max, range, variance, stdev, percentile, count."
}, },
"a": { "a": {
"type": "number", "type": "number",
"description": "First operand. Required for: pow, modulo, percentage_change." "description": "First operand. Required for: percentage_change."
}, },
"b": { "b": {
"type": "number", "type": "number",
"description": "Second operand. Required for: pow, modulo, percentage_change." "description": "Second operand. Required for: percentage_change."
}, },
"x": { "x": {
"type": "number", "type": "number",
"description": "Input number. Required for: sqrt, abs, exp, ln, log, factorial." "description": "Input number. Required for: log, factorial, round, clamp."
}, },
"base": { "base": {
"type": "number", "type": "number",
@ -90,6 +90,10 @@ impl Tool for CalculatorTool {
"max_val": { "max_val": {
"type": "number", "type": "number",
"description": "Maximum bound. Required for: clamp." "description": "Maximum bound. Required for: clamp."
},
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate. Supports: +, -, *, /, %, ^ (power), parentheses. Functions: sqrt, abs, exp, ln, sin, cos, tan, asin, acos, atan, atan2, sinh, cosh, tanh, asinh, acosh, atanh, floor, ceil, round, signum, max, min. Constants: pi, e. Variables: x, weight, etc. Example: '15*3+5^(2+1)', 'sin(pi/2)', 'max(1,2,3)'"
} }
}, },
"required": ["function"] "required": ["function"]
@ -109,18 +113,8 @@ impl Tool for CalculatorTool {
}; };
let result = match function { let result = match function {
"add" => calc_add(&args),
"subtract" => calc_subtract(&args),
"divide" => calc_divide(&args),
"multiply" => calc_multiply(&args),
"pow" => calc_pow(&args),
"sqrt" => calc_sqrt(&args),
"abs" => calc_abs(&args),
"modulo" => calc_modulo(&args),
"round" => calc_round(&args), "round" => calc_round(&args),
"log" => calc_log(&args), "log" => calc_log(&args),
"ln" => calc_ln(&args),
"exp" => calc_exp(&args),
"factorial" => calc_factorial(&args), "factorial" => calc_factorial(&args),
"sum" => calc_sum(&args), "sum" => calc_sum(&args),
"average" => calc_average(&args), "average" => calc_average(&args),
@ -135,6 +129,7 @@ impl Tool for CalculatorTool {
"count" => calc_count(&args), "count" => calc_count(&args),
"percentage_change" => calc_percentage_change(&args), "percentage_change" => calc_percentage_change(&args),
"clamp" => calc_clamp(&args), "clamp" => calc_clamp(&args),
"evaluate" => calc_evaluate(&args),
other => Err(format!("Unknown function: {other}")), other => Err(format!("Unknown function: {other}")),
}; };
@ -196,71 +191,6 @@ fn format_num(n: f64) -> String {
} }
} }
fn calc_add(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 2)?;
Ok(format_num(values.iter().sum()))
}
fn calc_subtract(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 2)?;
let mut iter = values.iter();
let mut result = *iter.next().unwrap();
for v in iter {
result -= v;
}
Ok(format_num(result))
}
fn calc_divide(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 2)?;
let mut iter = values.iter();
let mut result = *iter.next().unwrap();
for v in iter {
if *v == 0.0 {
return Err("Division by zero".to_string());
}
result /= v;
}
Ok(format_num(result))
}
fn calc_multiply(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 2)?;
let mut result = 1.0;
for v in &values {
result *= v;
}
Ok(format_num(result))
}
fn calc_pow(args: &serde_json::Value) -> Result<String, String> {
let base = extract_f64(args, "a", "a (base)")?;
let exp = extract_f64(args, "b", "b (exponent)")?;
Ok(format_num(base.powf(exp)))
}
fn calc_sqrt(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
if x < 0.0 {
return Err("Cannot compute square root of a negative number".to_string());
}
Ok(format_num(x.sqrt()))
}
fn calc_abs(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
Ok(format_num(x.abs()))
}
fn calc_modulo(args: &serde_json::Value) -> Result<String, String> {
let a = extract_f64(args, "a", "a")?;
let b = extract_f64(args, "b", "b")?;
if b == 0.0 {
return Err("Modulo by zero".to_string());
}
Ok(format_num(a % b))
}
fn calc_round(args: &serde_json::Value) -> Result<String, String> { fn calc_round(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?; let x = extract_f64(args, "x", "x")?;
let decimals = extract_i64(args, "decimals", "decimals")?; let decimals = extract_i64(args, "decimals", "decimals")?;
@ -283,19 +213,6 @@ fn calc_log(args: &serde_json::Value) -> Result<String, String> {
Ok(format_num(x.log(base))) Ok(format_num(x.log(base)))
} }
fn calc_ln(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
if x <= 0.0 {
return Err("Natural logarithm requires a positive number".to_string());
}
Ok(format_num(x.ln()))
}
fn calc_exp(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
Ok(format_num(x.exp()))
}
fn calc_factorial(args: &serde_json::Value) -> Result<String, String> { fn calc_factorial(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?; let x = extract_f64(args, "x", "x")?;
if x < 0.0 || x != x.floor() { if x < 0.0 || x != x.floor() {
@ -457,119 +374,21 @@ fn calc_clamp(args: &serde_json::Value) -> Result<String, String> {
Ok(format_num(x.clamp(min_val, max_val))) Ok(format_num(x.clamp(min_val, max_val)))
} }
fn calc_evaluate(args: &serde_json::Value) -> Result<String, String> {
let expression = args
.get("expression")
.and_then(|v| v.as_str())
.ok_or_else(|| "Missing required parameter: expression".to_string())?;
meval::eval_str(expression)
.map(format_num)
.map_err(|e| format!("Expression evaluation error: {e}"))
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[tokio::test]
async fn test_add() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "add", "values": [1.0, 2.0, 3.5]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "6.5");
}
#[tokio::test]
async fn test_subtract() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "subtract", "values": [10.0, 3.0, 1.5]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "5.5");
}
#[tokio::test]
async fn test_divide() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "divide", "values": [100.0, 4.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "25");
}
#[tokio::test]
async fn test_divide_by_zero() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "divide", "values": [10.0, 0.0]}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("zero"));
}
#[tokio::test]
async fn test_multiply() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "multiply", "values": [3.0, 4.0, 5.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "60");
}
#[tokio::test]
async fn test_pow() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "pow", "a": 2.0, "b": 10.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "1024");
}
#[tokio::test]
async fn test_sqrt() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "sqrt", "x": 144.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "12");
}
#[tokio::test]
async fn test_sqrt_negative() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "sqrt", "x": -4.0}))
.await
.unwrap();
assert!(!result.success);
}
#[tokio::test]
async fn test_abs() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "abs", "x": -42.5}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "42.5");
}
#[tokio::test]
async fn test_modulo() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "modulo", "a": 17.0, "b": 5.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2");
}
#[tokio::test] #[tokio::test]
async fn test_round() { async fn test_round() {
let tool = CalculatorTool::new(); let tool = CalculatorTool::new();
@ -603,28 +422,6 @@ mod tests {
assert_eq!(result.output, "3"); assert_eq!(result.output, "3");
} }
#[tokio::test]
async fn test_ln() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "ln", "x": 1.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "0");
}
#[tokio::test]
async fn test_exp() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "exp", "x": 0.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "1");
}
#[tokio::test] #[tokio::test]
async fn test_factorial() { async fn test_factorial() {
let tool = CalculatorTool::new(); let tool = CalculatorTool::new();
@ -825,4 +622,80 @@ mod tests {
assert!(result.success); assert!(result.success);
assert_eq!(result.output, "15"); assert_eq!(result.output, "15");
} }
#[tokio::test]
async fn test_evaluate_simple() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate", "expression": "15*3+5^(2+1)"}))
.await
.unwrap();
assert!(result.success);
// 15*3 + 5^(2+1) = 45 + 5^3 = 45 + 125 = 170
assert_eq!(result.output, "170");
}
#[tokio::test]
async fn test_evaluate_complex() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate", "expression": "(10-2)*(3+1)"}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "32");
}
#[tokio::test]
async fn test_evaluate_invalid() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate", "expression": "invalid"}))
.await
.unwrap();
assert!(!result.success);
}
#[tokio::test]
async fn test_evaluate_missing_expression() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate"}))
.await
.unwrap();
assert!(!result.success);
}
#[tokio::test]
async fn test_evaluate_with_functions() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate", "expression": "sqrt(144)"}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "12");
}
#[tokio::test]
async fn test_evaluate_with_constants() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate", "expression": "pi * 2"}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "6.283185307179586");
}
#[tokio::test]
async fn test_evaluate_modulo() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate", "expression": "17 % 5"}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2");
}
} }