diff --git a/src/command/adapters/cli.rs b/src/command/adapters/cli.rs index 2ab4628..9783a82 100644 --- a/src/command/adapters/cli.rs +++ b/src/command/adapters/cli.rs @@ -278,7 +278,7 @@ mod tests { assert!(result.is_some()); let cmd = result.unwrap(); - assert!(matches!(cmd, Command::SaveTopic { filepath: None })); + assert!(matches!(cmd, Command::SaveTopic { filepath: None, .. })); } #[test] @@ -294,6 +294,7 @@ mod tests { cmd, Command::SaveTopic { filepath: Some(ref p), + .. } if p == "./debug/topic.md" )); } @@ -307,7 +308,7 @@ mod tests { assert!(result.is_some()); let cmd = result.unwrap(); - assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false })); + assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false, .. })); } #[test] @@ -324,6 +325,7 @@ mod tests { Command::SaveSession { filepath: Some(ref p), include_all: false, + .. } if p == "./debug/session.md" )); } @@ -337,7 +339,7 @@ mod tests { assert!(result.is_some()); let cmd = result.unwrap(); - assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: true })); + assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: true, .. })); } #[test] @@ -354,6 +356,7 @@ mod tests { Command::SaveSession { filepath: Some(ref p), include_all: true, + .. } if p == "./debug/session.md" )); } diff --git a/src/command/handlers/save_session.rs b/src/command/handlers/save_session.rs index 2334e76..6180914 100644 --- a/src/command/handlers/save_session.rs +++ b/src/command/handlers/save_session.rs @@ -845,7 +845,7 @@ mod tests { assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false, include_subagents: false })); assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true, include_subagents: false })); assert!(!handler.can_handle(&Command::CreateSession { title: None })); - assert!(!handler.can_handle(&Command::SaveTopic { filepath: None })); + assert!(!handler.can_handle(&Command::SaveTopic { filepath: None, include_subagents: false })); } /// 测试用的系统提示词提供者 diff --git a/src/tools/bash.rs b/src/tools/bash.rs index 8810f87..2985547 100644 --- a/src/tools/bash.rs +++ b/src/tools/bash.rs @@ -12,6 +12,7 @@ use tokio::time::{Instant, sleep_until}; use crate::platform::{ShellInfo, dangerous_command_patterns}; use crate::tools::traits::{Tool, ToolResult}; +use crate::tools::{extract_u64, extract_bool, check_null_args}; const MAX_TIMEOUT_SECS: u64 = 600; const MAX_OUTPUT_CHARS: usize = 50_000; @@ -255,6 +256,10 @@ impl Tool for BashTool { } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if let Some(result) = check_null_args(&args, "bash") { + return Ok(result); + } + let command = match args.get("command").and_then(|v| v.as_str()) { Some(c) => c, None => { @@ -275,15 +280,10 @@ impl Tool for BashTool { }); } - let timeout_secs = args - .get("timeout") - .and_then(|v| v.as_u64()) + let timeout_secs = extract_u64(&args, "timeout") .unwrap_or(self.timeout_secs) .min(MAX_TIMEOUT_SECS); - let interactive = args - .get("interactive") - .and_then(|v| v.as_bool()) - .unwrap_or(false); + let interactive = extract_bool(&args, "interactive").unwrap_or(false); let cwd = self .working_dir diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index 2b42e8d..647f189 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -1,4 +1,6 @@ use super::traits::{Tool, ToolResult}; +use crate::tools::extract_f64 as extract_f64_opt; +use crate::tools::check_null_args; use async_trait::async_trait; use serde_json::json; @@ -101,6 +103,10 @@ impl Tool for CalculatorTool { } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if let Some(result) = check_null_args(&args, "calculator") { + return Ok(result); + } + let function = match args.get("function").and_then(|v| v.as_str()) { Some(f) => f, None => { @@ -149,15 +155,33 @@ impl Tool for CalculatorTool { } fn extract_f64(args: &serde_json::Value, key: &str, name: &str) -> Result { - args.get(key) - .and_then(|v| v.as_f64()) - .ok_or_else(|| format!("Missing required parameter: {name}")) + match args.get(key) { + None => Err(format!("Missing required parameter: {name}")), + Some(v) => { + if let Some(n) = v.as_f64() { + Ok(n) + } else if let Some(s) = v.as_str() { + s.parse::().map_err(|_| format!("{name} is not a valid number: {s}")) + } else { + Err(format!("{name} must be a number")) + } + } + } } fn extract_i64(args: &serde_json::Value, key: &str, name: &str) -> Result { - args.get(key) - .and_then(|v| v.as_i64()) - .ok_or_else(|| format!("Missing required parameter: {name}")) + match args.get(key) { + None => Err(format!("Missing required parameter: {name}")), + Some(v) => { + if let Some(n) = v.as_i64() { + Ok(n) + } else if let Some(s) = v.as_str() { + s.parse::().map_err(|_| format!("{name} is not a valid integer: {s}")) + } else { + Err(format!("{name} must be an integer")) + } + } + } } fn extract_values(args: &serde_json::Value, min_len: usize) -> Result, String> { @@ -173,10 +197,15 @@ fn extract_values(args: &serde_json::Value, min_len: usize) -> Result, } let mut nums = Vec::with_capacity(values.len()); for (i, v) in values.iter().enumerate() { - match v.as_f64() { - Some(n) => nums.push(n), - None => return Err(format!("values[{i}] is not a valid number")), - } + let n = if let Some(n) = v.as_f64() { + n + } else if let Some(s) = v.as_str() { + s.parse::() + .map_err(|_| format!("values[{i}] is not a valid number: {s}"))? + } else { + return Err(format!("values[{i}] is not a valid number")); + }; + nums.push(n); } Ok(nums) } @@ -206,7 +235,7 @@ fn calc_log(args: &serde_json::Value) -> Result { if x <= 0.0 { return Err("Logarithm requires a positive number".to_string()); } - let base = args.get("base").and_then(|v| v.as_f64()).unwrap_or(10.0); + let base = extract_f64_opt(args, "base").unwrap_or(10.0); if base <= 0.0 || base == 1.0 { return Err("Logarithm base must be positive and not equal to 1".to_string()); } @@ -695,4 +724,45 @@ mod tests { assert!(result.success); assert_eq!(result.output, "2"); } + + #[tokio::test] + async fn test_round_with_string_x() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "round", "x": "2.715", "decimals": "2"})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "2.72"); + } + + #[tokio::test] + async fn test_sum_with_string_values() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "sum", "values": ["1.5", "2.5", "3"]})) + .await + .unwrap(); + assert!(result.success); + assert_eq!(result.output, "7"); + } + + #[tokio::test] + async fn test_invalid_string_number_returns_error() { + let tool = CalculatorTool::new(); + let result = tool + .execute(json!({"function": "round", "x": "not_a_number", "decimals": 2})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("x is not a valid number")); + } + + #[tokio::test] + async fn test_null_args_returns_error() { + let tool = CalculatorTool::new(); + let result = tool.execute(serde_json::Value::Null).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Missing required parameters")); + } } diff --git a/src/tools/file_edit.rs b/src/tools/file_edit.rs index b3eefcb..b9f2e52 100644 --- a/src/tools/file_edit.rs +++ b/src/tools/file_edit.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use serde_json::json; use crate::tools::traits::{Tool, ToolResult}; +use crate::tools::extract_bool; pub struct FileEditTool { allowed_dir: Option, @@ -150,10 +151,7 @@ impl Tool for FileEditTool { } }; - let replace_all = args - .get("replace_all") - .and_then(|v| v.as_bool()) - .unwrap_or(false); + let replace_all = extract_bool(&args, "replace_all").unwrap_or(false); let resolved = match self.resolve_path(path) { Ok(p) => p, diff --git a/src/tools/file_read.rs b/src/tools/file_read.rs index a44c59f..4d6463a 100644 --- a/src/tools/file_read.rs +++ b/src/tools/file_read.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use serde_json::json; use crate::tools::traits::{Tool, ToolResult}; +use crate::tools::extract_u64; const MAX_CHARS: usize = 128_000; const DEFAULT_LIMIT: usize = 2000; @@ -103,15 +104,11 @@ impl Tool for FileReadTool { } }; - let offset = args - .get("offset") - .and_then(|v| v.as_u64()) + let offset = extract_u64(&args, "offset") .map(|v| v as usize) .unwrap_or(1); - let limit = args - .get("limit") - .and_then(|v| v.as_u64()) + let limit = extract_u64(&args, "limit") .map(|v| v as usize) .unwrap_or(DEFAULT_LIMIT); diff --git a/src/tools/memory_search.rs b/src/tools/memory_search.rs index 7b45040..7cb522e 100644 --- a/src/tools/memory_search.rs +++ b/src/tools/memory_search.rs @@ -5,6 +5,7 @@ use serde_json::json; use crate::storage::{MemoryRecord, MemoryRepository}; use crate::tools::traits::{Tool, ToolContext, ToolResult}; +use crate::tools::extract_u64; pub struct MemorySearchTool { memories: Arc, @@ -86,10 +87,7 @@ impl Tool for MemorySearchTool { let payload = match action { "list" => { - let limit = args - .get("limit") - .and_then(|value| value.as_u64()) - .unwrap_or(10) as usize; + let limit = extract_u64(&args, "limit").unwrap_or(10) as usize; let memories = self .memories .list_memories("user", &scope_key, namespace, limit)?; @@ -138,10 +136,7 @@ impl Tool for MemorySearchTool { if queries.is_empty() { return Ok(error_result("Missing required parameter: queries")); } - let limit = args - .get("limit") - .and_then(|value| value.as_u64()) - .unwrap_or(10) as usize; + let limit = extract_u64(&args, "limit").unwrap_or(10) as usize; let memories = self .memories .search_memories_any("user", &scope_key, &queries, namespace, limit)?; diff --git a/src/tools/mod.rs b/src/tools/mod.rs index b0fdee4..c3879dd 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -41,3 +41,131 @@ pub use task::{ pub use time::TimeTool; pub use traits::{Tool, ToolContext, ToolResult}; pub use web_fetch::WebFetchTool; + +/// Extract a string parameter from JSON args. +pub fn extract_string(args: &serde_json::Value, key: &str) -> Option { + args.get(key).and_then(|v| { + if let Some(s) = v.as_str() { + Some(s.to_string()) + } else if let Some(n) = v.as_number() { + // Handle case where LLM sends a number but we need a string + Some(n.to_string()) + } else { + None + } + }) +} + +/// Extract an f64 parameter from JSON args, handling both numbers and strings. +pub fn extract_f64(args: &serde_json::Value, key: &str) -> Option { + args.get(key).and_then(|v| { + if let Some(n) = v.as_f64() { + Some(n) + } else if let Some(s) = v.as_str() { + s.parse::().ok() + } else { + None + } + }) +} + +/// Extract an i64 parameter from JSON args, handling both numbers and strings. +pub fn extract_i64(args: &serde_json::Value, key: &str) -> Option { + args.get(key).and_then(|v| { + if let Some(n) = v.as_i64() { + Some(n) + } else if let Some(s) = v.as_str() { + s.parse::().ok() + } else { + None + } + }) +} + +/// Extract a u64 parameter from JSON args, handling both numbers and strings. +pub fn extract_u64(args: &serde_json::Value, key: &str) -> Option { + args.get(key).and_then(|v| { + if let Some(n) = v.as_u64() { + Some(n) + } else if let Some(s) = v.as_str() { + s.parse::().ok() + } else { + None + } + }) +} + +/// Extract a bool parameter from JSON args, handling both booleans and strings. +pub fn extract_bool(args: &serde_json::Value, key: &str) -> Option { + args.get(key).and_then(|v| { + if let Some(b) = v.as_bool() { + Some(b) + } else if let Some(s) = v.as_str() { + match s.to_lowercase().as_str() { + "true" | "1" | "yes" | "on" => Some(true), + "false" | "0" | "no" | "off" => Some(false), + _ => None, + } + } else { + None + } + }) +} + +/// Extract a required string parameter, returning an error message if missing. +pub fn require_string(args: &serde_json::Value, key: &str) -> Result { + extract_string(args, key) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| format!("Missing required parameter: {}", key)) +} + +/// Extract a required f64 parameter, returning an error message if missing. +pub fn require_f64(args: &serde_json::Value, key: &str) -> Result { + extract_f64(args, key) + .ok_or_else(|| format!("Missing required parameter: {}", key)) +} + +/// Extract a required i64 parameter, returning an error message if missing. +pub fn require_i64(args: &serde_json::Value, key: &str) -> Result { + extract_i64(args, key) + .ok_or_else(|| format!("Missing required parameter: {}", key)) +} + +/// Extract a required u64 parameter, returning an error message if missing. +pub fn require_u64(args: &serde_json::Value, key: &str) -> Result { + extract_u64(args, key) + .ok_or_else(|| format!("Missing required parameter: {}", key)) +} + +/// Extract a required bool parameter, returning an error message if missing. +pub fn require_bool(args: &serde_json::Value, key: &str) -> Result { + extract_bool(args, key) + .ok_or_else(|| format!("Missing required parameter: {}", key)) +} + +/// Check if args is null and return an error result if so. +/// Returns the provided error message if args is null. +pub fn check_null_args(args: &serde_json::Value, tool_name: &str) -> Option { + if args.is_null() { + return Some(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Missing required parameters: {} expects a JSON object with required fields. Check the tool schema and provide the necessary arguments.", + tool_name + )), + }); + } + if !args.is_object() { + return Some(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Invalid parameters: {} expects a JSON object, got {}", + tool_name, + if args.is_array() { "an array" } else if args.is_string() { "a string" } else { "an unexpected type" } + )), + }); + } + None +} diff --git a/src/tools/scheduler_manage.rs b/src/tools/scheduler_manage.rs index eabb884..46a4fc1 100644 --- a/src/tools/scheduler_manage.rs +++ b/src/tools/scheduler_manage.rs @@ -9,6 +9,7 @@ use crate::storage::{ SchedulerJobRecord, SchedulerJobRepository, SchedulerJobState, SchedulerJobUpsert, }; use crate::tools::traits::{Tool, ToolResult}; +use crate::tools::{extract_bool, extract_i64}; pub struct SchedulerManageTool { jobs: Arc, @@ -126,10 +127,7 @@ impl Tool for SchedulerManageTool { let output = match action { "list" => { - let enabled_only = args - .get("enabled_only") - .and_then(|value| value.as_bool()) - .unwrap_or(false); + let enabled_only = extract_bool(&args, "enabled_only").unwrap_or(false); let jobs = self.jobs.list_scheduler_jobs(enabled_only)?; json!(jobs.iter().map(record_to_json).collect::>()) } @@ -243,15 +241,8 @@ fn build_upsert( startup_delay_secs, target, payload, - enabled: args - .get("enabled") - .and_then(|value| value.as_bool()) - .unwrap_or(true), - state: if args - .get("enabled") - .and_then(|value| value.as_bool()) - .unwrap_or(true) - { + enabled: extract_bool(args, "enabled").unwrap_or(true), + state: if extract_bool(args, "enabled").unwrap_or(true) { SchedulerJobState::Scheduled } else { SchedulerJobState::Paused @@ -259,7 +250,7 @@ fn build_upsert( last_status: None, last_error: None, run_count: 0, - max_runs: args.get("max_runs").and_then(|value| value.as_i64()), + max_runs: extract_i64(args, "max_runs"), last_fired_at: None, next_fire_at: None, paused_at: None, @@ -302,7 +293,7 @@ fn build_update_upsert( upsert.payload = payload.clone(); } - if let Some(enabled) = args.get("enabled").and_then(|value| value.as_bool()) { + if let Some(enabled) = extract_bool(args, "enabled") { upsert.enabled = enabled; upsert.state = if enabled { SchedulerJobState::Scheduled @@ -319,7 +310,7 @@ fn build_update_upsert( } if args.get("max_runs").is_some() { - upsert.max_runs = args.get("max_runs").and_then(|value| value.as_i64()); + upsert.max_runs = extract_i64(args, "max_runs"); } if upsert.kind == "agent_task" || upsert.kind == "silent_agent_task" { diff --git a/src/tools/skill_activate.rs b/src/tools/skill_activate.rs index 46b95d5..52d3455 100644 --- a/src/tools/skill_activate.rs +++ b/src/tools/skill_activate.rs @@ -6,6 +6,7 @@ use serde_json::json; use crate::skills::SkillRuntime; use crate::storage::SkillEventRepository; use crate::tools::traits::{Tool, ToolContext, ToolResult}; +use crate::tools::check_null_args; pub struct SkillActivateTool { skills: Arc, @@ -68,6 +69,10 @@ impl Tool for SkillActivateTool { context: &ToolContext, args: serde_json::Value, ) -> anyhow::Result { + if let Some(result) = check_null_args(&args, "skill_activate") { + return Ok(result); + } + let skill_name = match args.get("name").and_then(|value| value.as_str()) { Some(name) if !name.trim().is_empty() => name, _ => { @@ -152,4 +157,24 @@ mod tests { assert_eq!(events[0].event_type, "activation_failed"); assert_eq!(events[0].skill_name.as_deref(), Some("demo")); } + + #[tokio::test] + async fn test_skill_activate_handles_null_args() { + let skills = Arc::new(SkillRuntime::default()); + let store = Arc::new(SessionStore::in_memory().unwrap()); + store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap(); + let tool = SkillActivateTool::new(skills, store.clone()); + let context = ToolContext { + session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), + ..ToolContext::default() + }; + + let result = tool + .execute_with_context(&context, serde_json::Value::Null) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("Missing required parameters")); + } } diff --git a/src/tools/time.rs b/src/tools/time.rs index da1c7a4..9df5aef 100644 --- a/src/tools/time.rs +++ b/src/tools/time.rs @@ -2,6 +2,8 @@ use async_trait::async_trait; use chrono::{DateTime, Days, Duration, Months, Utc}; use serde_json::{Value, json}; +use crate::tools::extract_u64; +use crate::tools::check_null_args; use super::traits::{Tool, ToolResult}; pub struct TimeTool { @@ -16,6 +18,10 @@ impl TimeTool { } fn execute_at(&self, now_utc: DateTime, args: Value) -> ToolResult { + if let Some(result) = check_null_args(&args, "get_time") { + return result; + } + match execute_time_request(now_utc, &self.default_timezone, args) { Ok(output) => ToolResult { success: true, @@ -156,7 +162,7 @@ fn execute_time_request( fn parse_offset_request(args: &Value) -> Result, String> { let direction = args.get("direction").and_then(Value::as_str); - let amount = args.get("amount"); + let amount = extract_u64(args, "amount"); let unit = args.get("unit").and_then(Value::as_str); if direction.is_none() && amount.is_none() && unit.is_none() { @@ -166,9 +172,11 @@ fn parse_offset_request(args: &Value) -> Result, String> { let direction = direction.ok_or_else(|| { "Missing required parameter: direction when requesting a relative time".to_string() })?; - let amount = amount.and_then(Value::as_u64).ok_or_else(|| { + let amount = amount.ok_or_else(|| { "Missing required parameter: amount when requesting a relative time".to_string() })?; + let amount = u32::try_from(amount) + .map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?; let amount = u32::try_from(amount) .map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?; let unit = unit.ok_or_else(|| { @@ -451,4 +459,51 @@ mod tests { assert_eq!(result_time.hour(), 12); assert_eq!(result_time.minute(), 30); } + + #[test] + fn test_amount_as_string_is_parsed() { + let tool = TimeTool::new("Asia/Shanghai"); + let result = tool.execute_at( + fixed_now(), + json!({"direction": "future", "amount": "7", "unit": "days"}), + ); + + assert!(result.success, "Expected success but got error: {:?}", result.error); + let payload: Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(payload["result_time"], "2026-05-04T12:30:00+08:00"); + assert_eq!(payload["offset"]["amount"], 7); + } + + #[test] + fn test_invalid_amount_string_returns_error() { + let tool = TimeTool::new("Asia/Shanghai"); + let result = tool.execute_at( + fixed_now(), + json!({"direction": "future", "amount": "not_a_number", "unit": "days"}), + ); + + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("Missing required parameter: amount")); + } + + #[test] + fn test_null_args_returns_current_time() { + let tool = TimeTool::new("Asia/Shanghai"); + let result = tool.execute_at(fixed_now(), serde_json::Value::Null); + + // Null args should return error (current time requires no params, but null is invalid) + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("Missing required parameters")); + } + + #[test] + fn test_empty_object_returns_current_time() { + let tool = TimeTool::new("Asia/Shanghai"); + let result = tool.execute_at(fixed_now(), json!({})); + + // Empty object should return current time + assert!(result.success); + let payload: Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(payload["kind"], "current"); + } }