feat: 添加上下文压缩功能,优化消息历史管理和工具调用日志记录
This commit is contained in:
parent
dcf04279a7
commit
c971bc3639
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,4 +1,7 @@
|
||||
/target
|
||||
reference/**
|
||||
.env
|
||||
*.env
|
||||
*.env
|
||||
AGENTS.md
|
||||
CLAUDE.md
|
||||
Cargo.lock
|
||||
|
||||
@ -26,3 +26,4 @@ anyhow = "1.0"
|
||||
mime_guess = "2.0"
|
||||
base64 = "0.22"
|
||||
tempfile = "3"
|
||||
meval = "0.2"
|
||||
|
||||
@ -330,6 +330,13 @@ impl AgentLoop {
|
||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||
|
||||
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
||||
// Log function call with name and arguments
|
||||
let args_str = match &tool_call.arguments {
|
||||
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
||||
};
|
||||
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
||||
|
||||
// Truncate tool result if too large
|
||||
let truncated_output = truncate_tool_result(&result.output);
|
||||
|
||||
|
||||
372
src/agent/context_compressor.rs
Normal file
372
src/agent/context_compressor.rs
Normal file
@ -0,0 +1,372 @@
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{create_provider, ChatCompletionRequest, Message};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
|
||||
/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin.
|
||||
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
||||
let raw: usize = messages
|
||||
.iter()
|
||||
.map(|m| m.content.len().div_ceil(4) + 4)
|
||||
.sum();
|
||||
(raw as f64 * 1.2) as usize
|
||||
}
|
||||
|
||||
/// Configuration for context compression.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextCompressionConfig {
|
||||
/// Protect first N messages (system prompt, etc.)
|
||||
pub protect_first_n: usize,
|
||||
/// Protect last N messages (recent context)
|
||||
pub protect_last_n: usize,
|
||||
/// Maximum compression passes
|
||||
pub max_passes: u32,
|
||||
/// Maximum characters in summary
|
||||
pub summary_max_chars: usize,
|
||||
/// Characters to keep when trimming tool results
|
||||
pub tool_result_trim_chars: usize,
|
||||
}
|
||||
|
||||
impl Default for ContextCompressionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
protect_first_n: 1,
|
||||
protect_last_n: 4,
|
||||
max_passes: 3,
|
||||
summary_max_chars: 4000,
|
||||
tool_result_trim_chars: 2000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context compressor that reduces message history when it exceeds token limits.
|
||||
pub struct ContextCompressor {
|
||||
config: ContextCompressionConfig,
|
||||
context_window: usize,
|
||||
/// Threshold ratio to trigger compression (50% of context window)
|
||||
threshold_ratio: f64,
|
||||
}
|
||||
|
||||
impl ContextCompressor {
|
||||
/// Create a new compressor with the given context window size.
|
||||
pub fn new(context_window: usize) -> Self {
|
||||
Self {
|
||||
config: ContextCompressionConfig::default(),
|
||||
context_window,
|
||||
threshold_ratio: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom configuration.
|
||||
pub fn with_config(context_window: usize, config: ContextCompressionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
context_window,
|
||||
threshold_ratio: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the compression threshold in tokens.
|
||||
fn threshold(&self) -> usize {
|
||||
(self.context_window as f64 * self.threshold_ratio) as usize
|
||||
}
|
||||
|
||||
/// Fast-path: trim oversized tool results without LLM call.
|
||||
/// Returns the number of messages modified.
|
||||
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize {
|
||||
let limit = self.config.tool_result_trim_chars;
|
||||
let mut modified = 0;
|
||||
|
||||
for msg in messages.iter_mut() {
|
||||
if msg.role == "tool" && msg.content.len() > limit {
|
||||
let removed = msg.content.len() - limit;
|
||||
msg.content = format!(
|
||||
"{}...\n\n[Output truncated - {} characters removed]",
|
||||
&msg.content[..limit.min(msg.content.len())],
|
||||
removed
|
||||
);
|
||||
modified += 1;
|
||||
}
|
||||
}
|
||||
|
||||
modified
|
||||
}
|
||||
|
||||
/// Main entry point - compresses history if over threshold.
|
||||
pub async fn compress_if_needed(
|
||||
&self,
|
||||
history: Vec<ChatMessage>,
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<Vec<ChatMessage>, AgentError> {
|
||||
// Check if compression is needed
|
||||
let tokens = estimate_tokens(&history);
|
||||
if tokens <= self.threshold() {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::info!(
|
||||
tokens = tokens,
|
||||
threshold = self.threshold(),
|
||||
msg_count = history.len(),
|
||||
"Context compression not needed"
|
||||
);
|
||||
return Ok(history);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
tokens = tokens,
|
||||
threshold = self.threshold(),
|
||||
msg_count = history.len(),
|
||||
"Starting context compression"
|
||||
);
|
||||
|
||||
// Fast trim pass first
|
||||
let trimmed = self.fast_trim_tool_results(&mut history.clone());
|
||||
if trimmed > 0 {
|
||||
let tokens_after = estimate_tokens(&history);
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
trimmed_messages = trimmed,
|
||||
tokens_after = tokens_after,
|
||||
"Fast trim completed"
|
||||
);
|
||||
if tokens_after <= self.threshold() {
|
||||
return Ok(history);
|
||||
}
|
||||
}
|
||||
|
||||
// LLM summarization pass
|
||||
let mut current_history = history;
|
||||
for pass in 0..self.config.max_passes {
|
||||
let tokens = estimate_tokens(¤t_history);
|
||||
if tokens <= self.threshold() {
|
||||
break;
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
pass = pass + 1,
|
||||
tokens = tokens,
|
||||
"Compression pass"
|
||||
);
|
||||
|
||||
match self.compress_once(¤t_history, provider_config).await {
|
||||
Ok(Some(compressed)) => {
|
||||
current_history = compressed;
|
||||
}
|
||||
Ok(None) => {
|
||||
// No more compressible content
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Compression pass failed, using current history");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
final_tokens = estimate_tokens(¤t_history),
|
||||
final_msg_count = current_history.len(),
|
||||
"Context compression completed"
|
||||
);
|
||||
|
||||
Ok(current_history)
|
||||
}
|
||||
|
||||
/// Single compression pass - summarize middle messages between user turns.
|
||||
/// Returns Some(compressed) if compression happened, None if nothing to compress.
|
||||
async fn compress_once(
|
||||
&self,
|
||||
history: &[ChatMessage],
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
|
||||
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Find user message indices (excluding protected first messages)
|
||||
let user_indices: Vec<usize> = history
|
||||
.iter()
|
||||
.enumerate()
|
||||
.skip(self.config.protect_first_n)
|
||||
.filter(|(_, m)| m.role == "user")
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
// Need at least one user message and content between users to compress
|
||||
if user_indices.len() < 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Build segments: user -> (assistant turns) -> next user
|
||||
// We'll summarize the assistant turns between consecutive user messages
|
||||
let mut new_messages = history[..=user_indices[0]].to_vec();
|
||||
|
||||
for i in 0..user_indices.len() - 1 {
|
||||
let user_idx = user_indices[i];
|
||||
let next_user_idx = user_indices[i + 1];
|
||||
|
||||
new_messages.push(history[user_idx].clone());
|
||||
|
||||
// Check if there's assistant content between these two user messages
|
||||
let between_start = user_idx + 1;
|
||||
let between_end = next_user_idx;
|
||||
|
||||
if between_start < between_end {
|
||||
let between = &history[between_start..between_end];
|
||||
let summary = self.summarize_segment(between, provider_config).await?;
|
||||
|
||||
// Add summary as a special user message
|
||||
new_messages.push(ChatMessage::user(format!(
|
||||
"[Context Summary]\n\n{}",
|
||||
summary
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Add last user and everything after (protected)
|
||||
let last_user_idx = user_indices[user_indices.len() - 1];
|
||||
if last_user_idx < history.len() - 1 {
|
||||
// Add everything from last user onwards (protected)
|
||||
for i in last_user_idx..history.len() {
|
||||
new_messages.push(history[i].clone());
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing changed, return None
|
||||
if new_messages.len() == history.len() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(new_messages))
|
||||
}
|
||||
|
||||
/// Summarize a segment of messages using LLM.
|
||||
async fn summarize_segment(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<String, AgentError> {
|
||||
if messages.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
// Build transcript for summarization
|
||||
let transcript = messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let role = match m.role.as_str() {
|
||||
"assistant" => "Assistant",
|
||||
"tool" => "Tool",
|
||||
_ => m.role.as_str(),
|
||||
};
|
||||
let name = m.tool_name
|
||||
.as_ref()
|
||||
.map(|n| format!(" ({})", n))
|
||||
.unwrap_or_default();
|
||||
format!("{}: {}{}", role, m.content, name)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n");
|
||||
|
||||
// Truncate transcript if too long
|
||||
let transcript = if transcript.len() > self.config.summary_max_chars {
|
||||
format!(
|
||||
"{}...\n\n[Transcript truncated - {} characters removed]",
|
||||
&transcript[..self.config.summary_max_chars],
|
||||
transcript.len() - self.config.summary_max_chars
|
||||
)
|
||||
} else {
|
||||
transcript
|
||||
};
|
||||
|
||||
let prompt = format!(
|
||||
r#"You are a conversation compaction engine. Summarize the following conversation segment.
|
||||
|
||||
PRESERVE:
|
||||
- All identifiers (UUIDs, hashes, file paths, URLs)
|
||||
- Actions taken (tool calls, file operations, commands)
|
||||
- Key information obtained (results, data, errors)
|
||||
- Decisions and user preferences
|
||||
- Current task status
|
||||
|
||||
OMIT:
|
||||
- Verbose tool output (keep key results only)
|
||||
- Repeated greetings or filler
|
||||
|
||||
Be concise, aim for {} characters or less.
|
||||
|
||||
---
|
||||
|
||||
{}
|
||||
|
||||
"#,
|
||||
self.config.summary_max_chars, transcript
|
||||
);
|
||||
|
||||
// Create provider and call LLM
|
||||
let provider = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
|
||||
temperature: Some(0.3),
|
||||
max_tokens: Some(1000),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
match provider.chat(request).await {
|
||||
Ok(response) => Ok(response.content),
|
||||
Err(e) => {
|
||||
// Fallback: just truncate the transcript
|
||||
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
|
||||
Ok(transcript[..transcript.len().min(2000)].to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens() {
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
ChatMessage::user("How are you?"),
|
||||
];
|
||||
|
||||
let tokens = estimate_tokens(&messages);
|
||||
// "Hello" (5) -> ceil(5/4)+4 = 2+4 = 6
|
||||
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
||||
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
||||
// raw = 19, with 1.2x = ~23
|
||||
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fast_trim() {
|
||||
let config = ContextCompressionConfig {
|
||||
tool_result_trim_chars: 50,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(100_000, config);
|
||||
|
||||
let mut messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::tool("call1", "bash", &"x".repeat(200)),
|
||||
];
|
||||
|
||||
let modified = compressor.fast_trim_tool_results(&mut messages);
|
||||
assert_eq!(modified, 1);
|
||||
assert!(messages[1].content.len() < 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_threshold() {
|
||||
let compressor = ContextCompressor::new(128_000);
|
||||
assert_eq!(compressor.threshold(), 64_000);
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,5 @@
|
||||
pub mod agent_loop;
|
||||
pub mod context_compressor;
|
||||
|
||||
pub use agent_loop::{AgentLoop, AgentError};
|
||||
pub use context_compressor::ContextCompressor;
|
||||
|
||||
@ -75,12 +75,18 @@ pub struct AgentConfig {
|
||||
pub model: String,
|
||||
#[serde(default = "default_max_tool_iterations")]
|
||||
pub max_tool_iterations: usize,
|
||||
#[serde(default = "default_token_limit")]
|
||||
pub token_limit: usize,
|
||||
}
|
||||
|
||||
fn default_max_tool_iterations() -> usize {
|
||||
20
|
||||
}
|
||||
|
||||
fn default_token_limit() -> usize {
|
||||
128_000
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GatewayConfig {
|
||||
#[serde(default = "default_gateway_host")]
|
||||
@ -139,6 +145,7 @@ pub struct LLMProviderConfig {
|
||||
pub max_tokens: Option<u32>,
|
||||
pub model_extra: HashMap<String, serde_json::Value>,
|
||||
pub max_tool_iterations: usize,
|
||||
pub token_limit: usize,
|
||||
}
|
||||
|
||||
fn get_default_config_path() -> PathBuf {
|
||||
@ -199,6 +206,7 @@ impl Config {
|
||||
max_tokens: model.max_tokens,
|
||||
model_extra: model.extra.clone(),
|
||||
max_tool_iterations: agent.max_tool_iterations,
|
||||
token_limit: agent.token_limit,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::agent::{AgentLoop, AgentError};
|
||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||
@ -22,6 +22,7 @@ pub struct Session {
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
compressor: ContextCompressor,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@ -36,8 +37,9 @@ impl Session {
|
||||
channel_name,
|
||||
chat_histories: HashMap::new(),
|
||||
user_tx,
|
||||
provider_config,
|
||||
provider_config: provider_config.clone(),
|
||||
tools,
|
||||
compressor: ContextCompressor::new(provider_config.token_limit),
|
||||
})
|
||||
}
|
||||
|
||||
@ -94,6 +96,16 @@ impl Session {
|
||||
let _ = self.user_tx.send(msg).await;
|
||||
}
|
||||
|
||||
/// 获取 provider_config 引用
|
||||
pub fn provider_config(&self) -> &LLMProviderConfig {
|
||||
&self.provider_config
|
||||
}
|
||||
|
||||
/// 获取 compressor 引用
|
||||
pub fn compressor(&self) -> &ContextCompressor {
|
||||
&self.compressor
|
||||
}
|
||||
|
||||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||||
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
||||
AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())
|
||||
@ -252,6 +264,11 @@ impl SessionManager {
|
||||
// 获取完整历史
|
||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||
|
||||
// 压缩历史(如果需要)
|
||||
let history = session_guard.compressor
|
||||
.compress_if_needed(history, &session_guard.provider_config)
|
||||
.await?;
|
||||
|
||||
// 创建 agent 并处理
|
||||
let agent = session_guard.create_agent()?;
|
||||
let response = agent.process(history).await?;
|
||||
|
||||
@ -130,6 +130,18 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
||||
// 获取完整历史
|
||||
let history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
|
||||
// 压缩历史(如果需要)
|
||||
let history = match session_guard.compressor()
|
||||
.compress_if_needed(history, session_guard.provider_config())
|
||||
.await
|
||||
{
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
tracing::warn!(chat_id = %chat_id, error = %e, "Compression failed, using original history");
|
||||
session_guard.get_or_create_history(&chat_id).clone()
|
||||
}
|
||||
};
|
||||
|
||||
// 创建 agent 并处理
|
||||
let agent = match session_guard.create_agent() {
|
||||
Ok(a) => a,
|
||||
|
||||
@ -23,9 +23,8 @@ impl Tool for CalculatorTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Perform arithmetic and statistical calculations. Supports 25 functions: \
|
||||
add, subtract, divide, multiply, pow, sqrt, abs, modulo, round, \
|
||||
log, ln, exp, factorial, sum, average, median, mode, min, max, \
|
||||
"Perform arithmetic and statistical calculations. Supports expression evaluation (evaluate) and functions: \
|
||||
round, log, factorial, sum, average, median, mode, min, max, \
|
||||
range, variance, stdev, percentile, count, percentage_change, clamp. \
|
||||
Use this tool whenever you need to compute a numeric result instead of guessing."
|
||||
}
|
||||
@ -41,35 +40,36 @@ impl Tool for CalculatorTool {
|
||||
"function": {
|
||||
"type": "string",
|
||||
"description": "Calculation to perform. \
|
||||
Arithmetic: add(values), subtract(values), divide(values), multiply(values), pow(a,b), sqrt(x), abs(x), modulo(a,b), round(x,decimals). \
|
||||
Logarithmic/exponential: log(x,base?), ln(x), exp(x), factorial(x). \
|
||||
Expression: evaluate(expression) - supports +, -, *, /, %, ^, parentheses, and functions like sqrt, abs, exp, ln, sin, cos, tan, round, floor, ceil, max, min, etc. \
|
||||
Rounding: round(x, decimals). \
|
||||
Logarithmic: log(x, base?) - base defaults to 10. \
|
||||
Special: factorial(x). \
|
||||
Aggregation: sum(values), average(values), count(values), min(values), max(values), range(values). \
|
||||
Statistics: median(values), mode(values), variance(values), stdev(values), percentile(values,p). \
|
||||
Utility: percentage_change(a,b), clamp(x,min_val,max_val).",
|
||||
"enum": [
|
||||
"add", "subtract", "divide", "multiply", "pow", "sqrt",
|
||||
"abs", "modulo", "round", "log", "ln", "exp", "factorial",
|
||||
"round", "log", "factorial",
|
||||
"sum", "average", "median", "mode", "min", "max", "range",
|
||||
"variance", "stdev", "percentile", "count",
|
||||
"percentage_change", "clamp"
|
||||
"percentage_change", "clamp", "evaluate"
|
||||
]
|
||||
},
|
||||
"values": {
|
||||
"type": "array",
|
||||
"items": { "type": "number" },
|
||||
"description": "Array of numeric values. Required for: add, subtract, divide, multiply, sum, average, median, mode, min, max, range, variance, stdev, percentile, count."
|
||||
"description": "Array of numeric values. Required for: sum, average, median, mode, min, max, range, variance, stdev, percentile, count."
|
||||
},
|
||||
"a": {
|
||||
"type": "number",
|
||||
"description": "First operand. Required for: pow, modulo, percentage_change."
|
||||
"description": "First operand. Required for: percentage_change."
|
||||
},
|
||||
"b": {
|
||||
"type": "number",
|
||||
"description": "Second operand. Required for: pow, modulo, percentage_change."
|
||||
"description": "Second operand. Required for: percentage_change."
|
||||
},
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "Input number. Required for: sqrt, abs, exp, ln, log, factorial."
|
||||
"description": "Input number. Required for: log, factorial, round, clamp."
|
||||
},
|
||||
"base": {
|
||||
"type": "number",
|
||||
@ -90,6 +90,10 @@ impl Tool for CalculatorTool {
|
||||
"max_val": {
|
||||
"type": "number",
|
||||
"description": "Maximum bound. Required for: clamp."
|
||||
},
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Mathematical expression to evaluate. Supports: +, -, *, /, %, ^ (power), parentheses. Functions: sqrt, abs, exp, ln, sin, cos, tan, asin, acos, atan, atan2, sinh, cosh, tanh, asinh, acosh, atanh, floor, ceil, round, signum, max, min. Constants: pi, e. Variables: x, weight, etc. Example: '15*3+5^(2+1)', 'sin(pi/2)', 'max(1,2,3)'"
|
||||
}
|
||||
},
|
||||
"required": ["function"]
|
||||
@ -109,18 +113,8 @@ impl Tool for CalculatorTool {
|
||||
};
|
||||
|
||||
let result = match function {
|
||||
"add" => calc_add(&args),
|
||||
"subtract" => calc_subtract(&args),
|
||||
"divide" => calc_divide(&args),
|
||||
"multiply" => calc_multiply(&args),
|
||||
"pow" => calc_pow(&args),
|
||||
"sqrt" => calc_sqrt(&args),
|
||||
"abs" => calc_abs(&args),
|
||||
"modulo" => calc_modulo(&args),
|
||||
"round" => calc_round(&args),
|
||||
"log" => calc_log(&args),
|
||||
"ln" => calc_ln(&args),
|
||||
"exp" => calc_exp(&args),
|
||||
"factorial" => calc_factorial(&args),
|
||||
"sum" => calc_sum(&args),
|
||||
"average" => calc_average(&args),
|
||||
@ -135,6 +129,7 @@ impl Tool for CalculatorTool {
|
||||
"count" => calc_count(&args),
|
||||
"percentage_change" => calc_percentage_change(&args),
|
||||
"clamp" => calc_clamp(&args),
|
||||
"evaluate" => calc_evaluate(&args),
|
||||
other => Err(format!("Unknown function: {other}")),
|
||||
};
|
||||
|
||||
@ -196,71 +191,6 @@ fn format_num(n: f64) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn calc_add(args: &serde_json::Value) -> Result<String, String> {
|
||||
let values = extract_values(args, 2)?;
|
||||
Ok(format_num(values.iter().sum()))
|
||||
}
|
||||
|
||||
fn calc_subtract(args: &serde_json::Value) -> Result<String, String> {
|
||||
let values = extract_values(args, 2)?;
|
||||
let mut iter = values.iter();
|
||||
let mut result = *iter.next().unwrap();
|
||||
for v in iter {
|
||||
result -= v;
|
||||
}
|
||||
Ok(format_num(result))
|
||||
}
|
||||
|
||||
fn calc_divide(args: &serde_json::Value) -> Result<String, String> {
|
||||
let values = extract_values(args, 2)?;
|
||||
let mut iter = values.iter();
|
||||
let mut result = *iter.next().unwrap();
|
||||
for v in iter {
|
||||
if *v == 0.0 {
|
||||
return Err("Division by zero".to_string());
|
||||
}
|
||||
result /= v;
|
||||
}
|
||||
Ok(format_num(result))
|
||||
}
|
||||
|
||||
fn calc_multiply(args: &serde_json::Value) -> Result<String, String> {
|
||||
let values = extract_values(args, 2)?;
|
||||
let mut result = 1.0;
|
||||
for v in &values {
|
||||
result *= v;
|
||||
}
|
||||
Ok(format_num(result))
|
||||
}
|
||||
|
||||
fn calc_pow(args: &serde_json::Value) -> Result<String, String> {
|
||||
let base = extract_f64(args, "a", "a (base)")?;
|
||||
let exp = extract_f64(args, "b", "b (exponent)")?;
|
||||
Ok(format_num(base.powf(exp)))
|
||||
}
|
||||
|
||||
fn calc_sqrt(args: &serde_json::Value) -> Result<String, String> {
|
||||
let x = extract_f64(args, "x", "x")?;
|
||||
if x < 0.0 {
|
||||
return Err("Cannot compute square root of a negative number".to_string());
|
||||
}
|
||||
Ok(format_num(x.sqrt()))
|
||||
}
|
||||
|
||||
fn calc_abs(args: &serde_json::Value) -> Result<String, String> {
|
||||
let x = extract_f64(args, "x", "x")?;
|
||||
Ok(format_num(x.abs()))
|
||||
}
|
||||
|
||||
fn calc_modulo(args: &serde_json::Value) -> Result<String, String> {
|
||||
let a = extract_f64(args, "a", "a")?;
|
||||
let b = extract_f64(args, "b", "b")?;
|
||||
if b == 0.0 {
|
||||
return Err("Modulo by zero".to_string());
|
||||
}
|
||||
Ok(format_num(a % b))
|
||||
}
|
||||
|
||||
fn calc_round(args: &serde_json::Value) -> Result<String, String> {
|
||||
let x = extract_f64(args, "x", "x")?;
|
||||
let decimals = extract_i64(args, "decimals", "decimals")?;
|
||||
@ -283,19 +213,6 @@ fn calc_log(args: &serde_json::Value) -> Result<String, String> {
|
||||
Ok(format_num(x.log(base)))
|
||||
}
|
||||
|
||||
fn calc_ln(args: &serde_json::Value) -> Result<String, String> {
|
||||
let x = extract_f64(args, "x", "x")?;
|
||||
if x <= 0.0 {
|
||||
return Err("Natural logarithm requires a positive number".to_string());
|
||||
}
|
||||
Ok(format_num(x.ln()))
|
||||
}
|
||||
|
||||
fn calc_exp(args: &serde_json::Value) -> Result<String, String> {
|
||||
let x = extract_f64(args, "x", "x")?;
|
||||
Ok(format_num(x.exp()))
|
||||
}
|
||||
|
||||
fn calc_factorial(args: &serde_json::Value) -> Result<String, String> {
|
||||
let x = extract_f64(args, "x", "x")?;
|
||||
if x < 0.0 || x != x.floor() {
|
||||
@ -457,119 +374,21 @@ fn calc_clamp(args: &serde_json::Value) -> Result<String, String> {
|
||||
Ok(format_num(x.clamp(min_val, max_val)))
|
||||
}
|
||||
|
||||
fn calc_evaluate(args: &serde_json::Value) -> Result<String, String> {
|
||||
let expression = args
|
||||
.get("expression")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "Missing required parameter: expression".to_string())?;
|
||||
|
||||
meval::eval_str(expression)
|
||||
.map(format_num)
|
||||
.map_err(|e| format!("Expression evaluation error: {e}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "add", "values": [1.0, 2.0, 3.5]}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "6.5");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_subtract() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "subtract", "values": [10.0, 3.0, 1.5]}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "5.5");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_divide() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "divide", "values": [100.0, 4.0]}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "25");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_divide_by_zero() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "divide", "values": [10.0, 0.0]}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("zero"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiply() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "multiply", "values": [3.0, 4.0, 5.0]}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "60");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pow() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "pow", "a": 2.0, "b": 10.0}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "1024");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sqrt() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "sqrt", "x": 144.0}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "12");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sqrt_negative() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "sqrt", "x": -4.0}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abs() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "abs", "x": -42.5}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "42.5");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_modulo() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "modulo", "a": 17.0, "b": 5.0}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_round() {
|
||||
let tool = CalculatorTool::new();
|
||||
@ -603,28 +422,6 @@ mod tests {
|
||||
assert_eq!(result.output, "3");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ln() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "ln", "x": 1.0}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "0");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exp() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "exp", "x": 0.0}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_factorial() {
|
||||
let tool = CalculatorTool::new();
|
||||
@ -825,4 +622,80 @@ mod tests {
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "15");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_simple() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate", "expression": "15*3+5^(2+1)"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
// 15*3 + 5^(2+1) = 45 + 5^3 = 45 + 125 = 170
|
||||
assert_eq!(result.output, "170");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_complex() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate", "expression": "(10-2)*(3+1)"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "32");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_invalid() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate", "expression": "invalid"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_missing_expression() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_with_functions() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate", "expression": "sqrt(144)"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "12");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_with_constants() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate", "expression": "pi * 2"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "6.283185307179586");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_modulo() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate", "expression": "17 % 5"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "2");
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user