Compare commits

...

5 Commits

16 changed files with 494 additions and 104 deletions

View File

@ -1384,7 +1384,16 @@ fn parse_post_content(content: &str) -> String {
}
"code_block" => {
let lang = el.get("language").and_then(|l| l.as_str()).unwrap_or("");
let code_text = el.get("text").and_then(|t| t.as_str()).unwrap_or("");
let code_text = if let Some(content_arr) = el.get("content").and_then(|c| c.as_array()) {
content_arr
.iter()
.filter_map(|item| item.get("text").and_then(|t| t.as_str()))
.collect::<Vec<_>>()
.join("")
} else {
// Fallback to text field for backwards compatibility
el.get("text").and_then(|t| t.as_str()).unwrap_or("").to_string()
};
out.push(format!("\n```{}\n{}\n```\n", lang, code_text));
}
_ => {
@ -2190,7 +2199,7 @@ fn sanitize_download_file_name(file_name: &str) -> String {
mod tests {
use super::{
FeishuChannel, MsgFormat, extract_file_name_from_content_disposition,
infer_download_filename, sanitize_download_file_name,
infer_download_filename, parse_post_content, sanitize_download_file_name,
};
#[test]
@ -2279,6 +2288,42 @@ mod tests {
let file_name = extract_file_name_from_content_disposition(&headers);
assert_eq!(file_name.as_deref(), Some("archive.zip"));
}
#[test]
fn parse_post_content_handles_code_block_with_content_array() {
// Test parsing code_block with content array (standard Feishu format)
let post_json = r#"{"post":{"zh_cn":{"content":[[{"tag":"code_block","language":"python","content":[{"tag":"text","text":"def hello():"},{"tag":"text","text":" print('world')"}]}]]}}}"#;
let result = parse_post_content(post_json);
assert!(result.contains("```python"));
assert!(result.contains("def hello():"));
assert!(result.contains("print('world')"));
}
#[test]
fn parse_post_content_handles_code_block_with_fallback_text() {
// Backwards compatibility: some formats might use text field directly
let post_json = r#"{"post":{"zh_cn":{"content":[[{"tag":"code_block","language":"rust","text":"fn main() {}"}]]}}}"#;
let result = parse_post_content(post_json);
assert!(result.contains("```rust"));
assert!(result.contains("fn main() {}"));
}
#[test]
fn parse_post_content_handles_code_block_without_language() {
// Test code_block without language field
let post_json = r#"{"post":{"zh_cn":{"content":[[{"tag":"code_block","content":[{"tag":"text","text":"plain text"}]}]]}}}"#;
let result = parse_post_content(post_json);
assert!(result.contains("```"));
assert!(result.contains("plain text"));
}
#[test]
fn parse_post_content_handles_empty_code_block() {
// Test code_block with empty content
let post_json = r#"{"post":{"zh_cn":{"content":[[{"tag":"code_block","language":"go"}]]}}}"#;
let result = parse_post_content(post_json);
assert!(result.contains("```go"));
}
}
#[async_trait]

View File

@ -278,7 +278,7 @@ mod tests {
assert!(result.is_some());
let cmd = result.unwrap();
assert!(matches!(cmd, Command::SaveTopic { filepath: None }));
assert!(matches!(cmd, Command::SaveTopic { filepath: None, .. }));
}
#[test]
@ -294,6 +294,7 @@ mod tests {
cmd,
Command::SaveTopic {
filepath: Some(ref p),
..
} if p == "./debug/topic.md"
));
}
@ -307,7 +308,7 @@ mod tests {
assert!(result.is_some());
let cmd = result.unwrap();
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false }));
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false, .. }));
}
#[test]
@ -324,6 +325,7 @@ mod tests {
Command::SaveSession {
filepath: Some(ref p),
include_all: false,
..
} if p == "./debug/session.md"
));
}
@ -337,7 +339,7 @@ mod tests {
assert!(result.is_some());
let cmd = result.unwrap();
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: true }));
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: true, .. }));
}
#[test]
@ -354,6 +356,7 @@ mod tests {
Command::SaveSession {
filepath: Some(ref p),
include_all: true,
..
} if p == "./debug/session.md"
));
}

View File

@ -845,7 +845,7 @@ mod tests {
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false, include_subagents: false }));
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true, include_subagents: false }));
assert!(!handler.can_handle(&Command::CreateSession { title: None }));
assert!(!handler.can_handle(&Command::SaveTopic { filepath: None }));
assert!(!handler.can_handle(&Command::SaveTopic { filepath: None, include_subagents: false }));
}
/// 测试用的系统提示词提供者

View File

@ -110,15 +110,42 @@ impl AgentExecutionService {
// 将结果消息保存到确定的话题
if let Some(topic_id) = target_topic_id {
if let Err(err) = session.append_messages_to_topic(
if is_current_turn {
// 如果是最新回合,使用 append_persisted_messages 保存到数据库并更新内存历史
if let Err(err) = session.append_persisted_messages(
request.chat_id,
request.result.emitted_messages.clone(),
) {
tracing::error!(
error = %err,
chat_id = %request.chat_id,
"Failed to append messages to session history"
);
}
} else {
// 如果用户已切换话题,只保存到原始话题(不更新内存历史)
if let Err(err) = session.append_messages_to_topic(
request.chat_id,
topic_id,
&request.result.emitted_messages,
) {
tracing::error!(
error = %err,
topic_id = %topic_id,
"Failed to append messages to topic"
);
}
}
} else if is_current_turn {
// 如果没有话题直接更新内存历史append_persisted_messages 会处理持久化)
if let Err(err) = session.append_persisted_messages(
request.chat_id,
topic_id,
&request.result.emitted_messages,
request.result.emitted_messages.clone(),
) {
tracing::error!(
error = %err,
topic_id = %topic_id,
"Failed to append messages to topic"
chat_id = %request.chat_id,
"Failed to append messages to session history"
);
}
}

View File

@ -235,12 +235,13 @@ impl OpenAIProvider {
let normalized = self.normalize_tool_arguments(arguments);
if self.uses_json_tool_arguments() {
// Model expects JSON object format
// Model expects JSON object format (e.g., some code models)
normalized
} else {
// Standard OpenAI format: arguments as JSON string
// But ensure we serialize valid JSON, not null
match normalized {
Value::Null => Value::String("{}".to_string()),
Value::String(raw) => {
// If the string is already valid JSON, keep it as-is
// Otherwise, ensure it's a proper JSON string

View File

@ -12,6 +12,7 @@ use tokio::time::{Instant, sleep_until};
use crate::platform::{ShellInfo, dangerous_command_patterns};
use crate::tools::traits::{Tool, ToolResult};
use crate::tools::{extract_u64, extract_bool, check_null_args};
const MAX_TIMEOUT_SECS: u64 = 600;
const MAX_OUTPUT_CHARS: usize = 50_000;
@ -255,6 +256,10 @@ impl Tool for BashTool {
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
if let Some(result) = check_null_args(&args, "bash") {
return Ok(result);
}
let command = match args.get("command").and_then(|v| v.as_str()) {
Some(c) => c,
None => {
@ -275,15 +280,10 @@ impl Tool for BashTool {
});
}
let timeout_secs = args
.get("timeout")
.and_then(|v| v.as_u64())
let timeout_secs = extract_u64(&args, "timeout")
.unwrap_or(self.timeout_secs)
.min(MAX_TIMEOUT_SECS);
let interactive = args
.get("interactive")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let interactive = extract_bool(&args, "interactive").unwrap_or(false);
let cwd = self
.working_dir

View File

@ -1,4 +1,6 @@
use super::traits::{Tool, ToolResult};
use crate::tools::extract_f64 as extract_f64_opt;
use crate::tools::check_null_args;
use async_trait::async_trait;
use serde_json::json;
@ -101,6 +103,10 @@ impl Tool for CalculatorTool {
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
if let Some(result) = check_null_args(&args, "calculator") {
return Ok(result);
}
let function = match args.get("function").and_then(|v| v.as_str()) {
Some(f) => f,
None => {
@ -149,15 +155,33 @@ impl Tool for CalculatorTool {
}
fn extract_f64(args: &serde_json::Value, key: &str, name: &str) -> Result<f64, String> {
args.get(key)
.and_then(|v| v.as_f64())
.ok_or_else(|| format!("Missing required parameter: {name}"))
match args.get(key) {
None => Err(format!("Missing required parameter: {name}")),
Some(v) => {
if let Some(n) = v.as_f64() {
Ok(n)
} else if let Some(s) = v.as_str() {
s.parse::<f64>().map_err(|_| format!("{name} is not a valid number: {s}"))
} else {
Err(format!("{name} must be a number"))
}
}
}
}
fn extract_i64(args: &serde_json::Value, key: &str, name: &str) -> Result<i64, String> {
args.get(key)
.and_then(|v| v.as_i64())
.ok_or_else(|| format!("Missing required parameter: {name}"))
match args.get(key) {
None => Err(format!("Missing required parameter: {name}")),
Some(v) => {
if let Some(n) = v.as_i64() {
Ok(n)
} else if let Some(s) = v.as_str() {
s.parse::<i64>().map_err(|_| format!("{name} is not a valid integer: {s}"))
} else {
Err(format!("{name} must be an integer"))
}
}
}
}
fn extract_values(args: &serde_json::Value, min_len: usize) -> Result<Vec<f64>, String> {
@ -173,10 +197,15 @@ fn extract_values(args: &serde_json::Value, min_len: usize) -> Result<Vec<f64>,
}
let mut nums = Vec::with_capacity(values.len());
for (i, v) in values.iter().enumerate() {
match v.as_f64() {
Some(n) => nums.push(n),
None => return Err(format!("values[{i}] is not a valid number")),
}
let n = if let Some(n) = v.as_f64() {
n
} else if let Some(s) = v.as_str() {
s.parse::<f64>()
.map_err(|_| format!("values[{i}] is not a valid number: {s}"))?
} else {
return Err(format!("values[{i}] is not a valid number"));
};
nums.push(n);
}
Ok(nums)
}
@ -206,7 +235,7 @@ fn calc_log(args: &serde_json::Value) -> Result<String, String> {
if x <= 0.0 {
return Err("Logarithm requires a positive number".to_string());
}
let base = args.get("base").and_then(|v| v.as_f64()).unwrap_or(10.0);
let base = extract_f64_opt(args, "base").unwrap_or(10.0);
if base <= 0.0 || base == 1.0 {
return Err("Logarithm base must be positive and not equal to 1".to_string());
}
@ -695,4 +724,45 @@ mod tests {
assert!(result.success);
assert_eq!(result.output, "2");
}
#[tokio::test]
async fn test_round_with_string_x() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "round", "x": "2.715", "decimals": "2"}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2.72");
}
#[tokio::test]
async fn test_sum_with_string_values() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "sum", "values": ["1.5", "2.5", "3"]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "7");
}
#[tokio::test]
async fn test_invalid_string_number_returns_error() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "round", "x": "not_a_number", "decimals": 2}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("x is not a valid number"));
}
#[tokio::test]
async fn test_null_args_returns_error() {
let tool = CalculatorTool::new();
let result = tool.execute(serde_json::Value::Null).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("Missing required parameters"));
}
}

View File

@ -4,6 +4,7 @@ use async_trait::async_trait;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
use crate::tools::extract_bool;
pub struct FileEditTool {
allowed_dir: Option<String>,
@ -150,10 +151,7 @@ impl Tool for FileEditTool {
}
};
let replace_all = args
.get("replace_all")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let replace_all = extract_bool(&args, "replace_all").unwrap_or(false);
let resolved = match self.resolve_path(path) {
Ok(p) => p,

View File

@ -4,6 +4,7 @@ use async_trait::async_trait;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
use crate::tools::extract_u64;
const MAX_CHARS: usize = 128_000;
const DEFAULT_LIMIT: usize = 2000;
@ -103,15 +104,11 @@ impl Tool for FileReadTool {
}
};
let offset = args
.get("offset")
.and_then(|v| v.as_u64())
let offset = extract_u64(&args, "offset")
.map(|v| v as usize)
.unwrap_or(1);
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
let limit = extract_u64(&args, "limit")
.map(|v| v as usize)
.unwrap_or(DEFAULT_LIMIT);

View File

@ -5,6 +5,7 @@ use serde_json::json;
use crate::storage::{MemoryRecord, MemoryRepository};
use crate::tools::traits::{Tool, ToolContext, ToolResult};
use crate::tools::extract_u64;
pub struct MemorySearchTool {
memories: Arc<dyn MemoryRepository>,
@ -86,10 +87,7 @@ impl Tool for MemorySearchTool {
let payload = match action {
"list" => {
let limit = args
.get("limit")
.and_then(|value| value.as_u64())
.unwrap_or(10) as usize;
let limit = extract_u64(&args, "limit").unwrap_or(10) as usize;
let memories = self
.memories
.list_memories("user", &scope_key, namespace, limit)?;
@ -138,10 +136,7 @@ impl Tool for MemorySearchTool {
if queries.is_empty() {
return Ok(error_result("Missing required parameter: queries"));
}
let limit = args
.get("limit")
.and_then(|value| value.as_u64())
.unwrap_or(10) as usize;
let limit = extract_u64(&args, "limit").unwrap_or(10) as usize;
let memories = self
.memories
.search_memories_any("user", &scope_key, &queries, namespace, limit)?;

View File

@ -41,3 +41,169 @@ pub use task::{
pub use time::TimeTool;
pub use traits::{Tool, ToolContext, ToolResult};
pub use web_fetch::WebFetchTool;
/// Extract a string parameter from JSON args.
pub fn extract_string(args: &serde_json::Value, key: &str) -> Option<String> {
args.get(key).and_then(|v| {
if let Some(s) = v.as_str() {
Some(s.to_string())
} else if let Some(n) = v.as_number() {
// Handle case where LLM sends a number but we need a string
Some(n.to_string())
} else {
None
}
})
}
/// Extract an f64 parameter from JSON args, handling both numbers and strings.
pub fn extract_f64(args: &serde_json::Value, key: &str) -> Option<f64> {
args.get(key).and_then(|v| {
if let Some(n) = v.as_f64() {
Some(n)
} else if let Some(s) = v.as_str() {
s.parse::<f64>().ok()
} else {
None
}
})
}
/// Extract an i64 parameter from JSON args, handling both numbers and strings.
pub fn extract_i64(args: &serde_json::Value, key: &str) -> Option<i64> {
args.get(key).and_then(|v| {
if let Some(n) = v.as_i64() {
Some(n)
} else if let Some(s) = v.as_str() {
s.parse::<i64>().ok()
} else {
None
}
})
}
/// Extract a u64 parameter from JSON args, handling both numbers and strings.
pub fn extract_u64(args: &serde_json::Value, key: &str) -> Option<u64> {
args.get(key).and_then(|v| {
if let Some(n) = v.as_u64() {
Some(n)
} else if let Some(s) = v.as_str() {
s.parse::<u64>().ok()
} else {
None
}
})
}
/// Extract a bool parameter from JSON args, handling both booleans and strings.
pub fn extract_bool(args: &serde_json::Value, key: &str) -> Option<bool> {
args.get(key).and_then(|v| {
if let Some(b) = v.as_bool() {
Some(b)
} else if let Some(s) = v.as_str() {
match s.to_lowercase().as_str() {
"true" | "1" | "yes" | "on" => Some(true),
"false" | "0" | "no" | "off" => Some(false),
_ => None,
}
} else {
None
}
})
}
/// Extract a required string parameter, returning an error message if missing.
pub fn require_string(args: &serde_json::Value, key: &str) -> Result<String, String> {
extract_string(args, key)
.filter(|s| !s.trim().is_empty())
.ok_or_else(|| format!("Missing required parameter: {}", key))
}
/// Extract a required f64 parameter, returning an error message if missing.
pub fn require_f64(args: &serde_json::Value, key: &str) -> Result<f64, String> {
extract_f64(args, key)
.ok_or_else(|| format!("Missing required parameter: {}", key))
}
/// Extract a required i64 parameter, returning an error message if missing.
pub fn require_i64(args: &serde_json::Value, key: &str) -> Result<i64, String> {
extract_i64(args, key)
.ok_or_else(|| format!("Missing required parameter: {}", key))
}
/// Extract a required u64 parameter, returning an error message if missing.
pub fn require_u64(args: &serde_json::Value, key: &str) -> Result<u64, String> {
extract_u64(args, key)
.ok_or_else(|| format!("Missing required parameter: {}", key))
}
/// Extract a required bool parameter, returning an error message if missing.
pub fn require_bool(args: &serde_json::Value, key: &str) -> Result<bool, String> {
extract_bool(args, key)
.ok_or_else(|| format!("Missing required parameter: {}", key))
}
/// Extract a string array parameter, handling both actual arrays and stringified JSON arrays.
pub fn extract_string_array(args: &serde_json::Value, key: &str) -> Option<Vec<String>> {
args.get(key).and_then(|v| {
if let Some(arr) = v.as_array() {
Some(
arr.iter()
.filter_map(|item| item.as_str())
.map(str::trim)
.filter(|s| !s.is_empty())
.map(ToOwned::to_owned)
.collect(),
)
} else if let Some(s) = v.as_str() {
// Try to parse as JSON array string
serde_json::from_str::<Vec<serde_json::Value>>(s)
.ok()
.map(|arr| {
arr.iter()
.filter_map(|item| item.as_str())
.map(str::trim)
.filter(|s| !s.is_empty())
.map(ToOwned::to_owned)
.collect()
})
.filter(|result: &Vec<String>| !result.is_empty())
} else {
None
}
})
}
/// Extract a required string array parameter.
pub fn require_string_array(args: &serde_json::Value, key: &str) -> Result<Vec<String>, String> {
extract_string_array(args, key)
.filter(|arr| !arr.is_empty())
.ok_or_else(|| format!("Missing required parameter: {}", key))
}
/// Check if args is null and return an error result if so.
/// Returns the provided error message if args is null.
pub fn check_null_args(args: &serde_json::Value, tool_name: &str) -> Option<ToolResult> {
if args.is_null() {
return Some(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Missing required parameters: {} expects a JSON object with required fields. Check the tool schema and provide the necessary arguments.",
tool_name
)),
});
}
if !args.is_object() {
return Some(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid parameters: {} expects a JSON object, got {}",
tool_name,
if args.is_array() { "an array" } else if args.is_string() { "a string" } else { "an unexpected type" }
)),
});
}
None
}

View File

@ -9,6 +9,7 @@ use crate::storage::{
SchedulerJobRecord, SchedulerJobRepository, SchedulerJobState, SchedulerJobUpsert,
};
use crate::tools::traits::{Tool, ToolResult};
use crate::tools::{extract_bool, extract_i64};
pub struct SchedulerManageTool {
jobs: Arc<dyn SchedulerJobRepository>,
@ -126,10 +127,7 @@ impl Tool for SchedulerManageTool {
let output = match action {
"list" => {
let enabled_only = args
.get("enabled_only")
.and_then(|value| value.as_bool())
.unwrap_or(false);
let enabled_only = extract_bool(&args, "enabled_only").unwrap_or(false);
let jobs = self.jobs.list_scheduler_jobs(enabled_only)?;
json!(jobs.iter().map(record_to_json).collect::<Vec<_>>())
}
@ -243,15 +241,8 @@ fn build_upsert(
startup_delay_secs,
target,
payload,
enabled: args
.get("enabled")
.and_then(|value| value.as_bool())
.unwrap_or(true),
state: if args
.get("enabled")
.and_then(|value| value.as_bool())
.unwrap_or(true)
{
enabled: extract_bool(args, "enabled").unwrap_or(true),
state: if extract_bool(args, "enabled").unwrap_or(true) {
SchedulerJobState::Scheduled
} else {
SchedulerJobState::Paused
@ -259,7 +250,7 @@ fn build_upsert(
last_status: None,
last_error: None,
run_count: 0,
max_runs: args.get("max_runs").and_then(|value| value.as_i64()),
max_runs: extract_i64(args, "max_runs"),
last_fired_at: None,
next_fire_at: None,
paused_at: None,
@ -302,7 +293,7 @@ fn build_update_upsert(
upsert.payload = payload.clone();
}
if let Some(enabled) = args.get("enabled").and_then(|value| value.as_bool()) {
if let Some(enabled) = extract_bool(args, "enabled") {
upsert.enabled = enabled;
upsert.state = if enabled {
SchedulerJobState::Scheduled
@ -319,7 +310,7 @@ fn build_update_upsert(
}
if args.get("max_runs").is_some() {
upsert.max_runs = args.get("max_runs").and_then(|value| value.as_i64());
upsert.max_runs = extract_i64(args, "max_runs");
}
if upsert.kind == "agent_task" || upsert.kind == "silent_agent_task" {

View File

@ -164,21 +164,39 @@ fn validate_context(context: &ToolContext) -> anyhow::Result<()> {
}
fn parse_attachments(value: &serde_json::Value) -> anyhow::Result<Vec<MediaItem>> {
let attachment_paths = value
.as_array()
.ok_or_else(|| anyhow!("attachments must be an array of local file paths"))?;
// 支持两种格式:实际数组 或 字符串化的 JSON 数组
let paths = if let Some(arr) = value.as_array() {
arr
.iter()
.filter_map(|v| v.as_str())
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
} else if let Some(s) = value.as_str() {
// 尝试解析字符串化的 JSON 数组
serde_json::from_str::<Vec<serde_json::Value>>(s)
.ok()
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str())
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
})
.unwrap_or_default()
} else {
vec![]
};
let mut attachments = Vec::with_capacity(attachment_paths.len());
for path_value in attachment_paths {
let raw_path = path_value
.as_str()
.ok_or_else(|| anyhow!("attachments entries must be strings"))?
.trim();
if raw_path.is_empty() {
return Err(anyhow!("attachment paths must not be empty"));
}
if paths.is_empty() {
return Err(anyhow!("attachments must be an array of local file paths"));
}
let metadata = std::fs::metadata(raw_path)
let mut attachments = Vec::with_capacity(paths.len());
for raw_path in paths {
let metadata = std::fs::metadata(&raw_path)
.map_err(|err| anyhow!("failed to access attachment '{}': {}", raw_path, err))?;
if !metadata.is_file() {
return Err(anyhow!("attachment path is not a file: {}", raw_path));
@ -187,9 +205,9 @@ fn parse_attachments(value: &serde_json::Value) -> anyhow::Result<Vec<MediaItem>
return Err(anyhow!("attachment file is empty: {}", raw_path));
}
let media_type = infer_media_type(raw_path);
let media_type = infer_media_type(&raw_path);
let mut item = MediaItem::new(raw_path.to_string(), media_type);
item.mime_type = mime_guess::from_path(raw_path)
item.mime_type = mime_guess::from_path(&raw_path)
.first_raw()
.map(ToOwned::to_owned);
attachments.push(item);
@ -319,4 +337,20 @@ mod tests {
assert_eq!(attachments.len(), 1);
assert_eq!(attachments[0].media_type, "image");
}
#[test]
fn parse_attachments_handles_stringified_json_array() {
let file = NamedTempFile::new().unwrap();
std::fs::write(file.path(), b"demo").unwrap();
let txt_path = file.path().with_extension("txt");
std::fs::rename(file.path(), &txt_path).unwrap();
// Test with stringified JSON array (like LLM might send)
let path_str = txt_path.to_string_lossy().to_string().replace("\\", "\\\\");
let json_string = format!("[\"{}\"]", path_str);
let attachments = parse_attachments(&json!(json_string)).unwrap();
assert_eq!(attachments.len(), 1);
assert_eq!(attachments[0].media_type, "file");
}
}

View File

@ -6,6 +6,7 @@ use serde_json::json;
use crate::skills::SkillRuntime;
use crate::storage::SkillEventRepository;
use crate::tools::traits::{Tool, ToolContext, ToolResult};
use crate::tools::check_null_args;
pub struct SkillActivateTool {
skills: Arc<SkillRuntime>,
@ -68,6 +69,10 @@ impl Tool for SkillActivateTool {
context: &ToolContext,
args: serde_json::Value,
) -> anyhow::Result<ToolResult> {
if let Some(result) = check_null_args(&args, "skill_activate") {
return Ok(result);
}
let skill_name = match args.get("name").and_then(|value| value.as_str()) {
Some(name) if !name.trim().is_empty() => name,
_ => {
@ -152,4 +157,24 @@ mod tests {
assert_eq!(events[0].event_type, "activation_failed");
assert_eq!(events[0].skill_name.as_deref(), Some("demo"));
}
#[tokio::test]
async fn test_skill_activate_handles_null_args() {
let skills = Arc::new(SkillRuntime::default());
let store = Arc::new(SessionStore::in_memory().unwrap());
store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
let tool = SkillActivateTool::new(skills, store.clone());
let context = ToolContext {
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
..ToolContext::default()
};
let result = tool
.execute_with_context(&context, serde_json::Value::Null)
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Missing required parameters"));
}
}

View File

@ -4,6 +4,7 @@ use std::sync::Arc;
use crate::skills::{SkillRuntime, SkillScope};
use crate::tools::traits::{Tool, ToolResult};
use crate::tools::{extract_bool, extract_string_array};
pub struct SkillManageTool {
skills: Arc<SkillRuntime>,
@ -86,7 +87,7 @@ impl Tool for SkillManageTool {
}
};
let reload = args.get("reload").and_then(|v| v.as_bool()).unwrap_or(true);
let reload = extract_bool(&args, "reload").unwrap_or(true);
let scope = match args.get("scope").and_then(|v| v.as_str()) {
Some(value) => match SkillScope::parse(value) {
Some(scope) => scope,
@ -280,28 +281,10 @@ fn error_result(message: &str) -> ToolResult {
}
fn parse_disable_names(args: &serde_json::Value) -> Result<Vec<String>, String> {
let names = args
.get("names")
.ok_or_else(|| "disable requires names".to_string())?
.as_array()
.ok_or_else(|| "names must be an array of strings".to_string())?;
let mut parsed = Vec::new();
for item in names {
let name = item
.as_str()
.ok_or_else(|| "names must be an array of strings".to_string())?
.trim()
.to_string();
if name.is_empty() {
return Err("names must not contain empty values".to_string());
}
parsed.push(name);
}
if parsed.is_empty() {
return Err("names must not be empty".to_string());
}
Ok(parsed)
// 支持两种格式:实际数组 或 字符串化的 JSON 数组
extract_string_array(args, "names")
.filter(|arr| !arr.is_empty())
.ok_or_else(|| "disable requires names (array of strings)".to_string())
}
fn skill_change_payload(change: crate::skills::SkillAvailabilityChange) -> serde_json::Value {

View File

@ -2,6 +2,8 @@ use async_trait::async_trait;
use chrono::{DateTime, Days, Duration, Months, Utc};
use serde_json::{Value, json};
use crate::tools::extract_u64;
use crate::tools::check_null_args;
use super::traits::{Tool, ToolResult};
pub struct TimeTool {
@ -16,6 +18,10 @@ impl TimeTool {
}
fn execute_at(&self, now_utc: DateTime<Utc>, args: Value) -> ToolResult {
if let Some(result) = check_null_args(&args, "get_time") {
return result;
}
match execute_time_request(now_utc, &self.default_timezone, args) {
Ok(output) => ToolResult {
success: true,
@ -156,7 +162,7 @@ fn execute_time_request(
fn parse_offset_request(args: &Value) -> Result<Option<OffsetRequest>, String> {
let direction = args.get("direction").and_then(Value::as_str);
let amount = args.get("amount");
let amount = extract_u64(args, "amount");
let unit = args.get("unit").and_then(Value::as_str);
if direction.is_none() && amount.is_none() && unit.is_none() {
@ -166,9 +172,11 @@ fn parse_offset_request(args: &Value) -> Result<Option<OffsetRequest>, String> {
let direction = direction.ok_or_else(|| {
"Missing required parameter: direction when requesting a relative time".to_string()
})?;
let amount = amount.and_then(Value::as_u64).ok_or_else(|| {
let amount = amount.ok_or_else(|| {
"Missing required parameter: amount when requesting a relative time".to_string()
})?;
let amount = u32::try_from(amount)
.map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?;
let amount = u32::try_from(amount)
.map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?;
let unit = unit.ok_or_else(|| {
@ -451,4 +459,51 @@ mod tests {
assert_eq!(result_time.hour(), 12);
assert_eq!(result_time.minute(), 30);
}
#[test]
fn test_amount_as_string_is_parsed() {
let tool = TimeTool::new("Asia/Shanghai");
let result = tool.execute_at(
fixed_now(),
json!({"direction": "future", "amount": "7", "unit": "days"}),
);
assert!(result.success, "Expected success but got error: {:?}", result.error);
let payload: Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(payload["result_time"], "2026-05-04T12:30:00+08:00");
assert_eq!(payload["offset"]["amount"], 7);
}
#[test]
fn test_invalid_amount_string_returns_error() {
let tool = TimeTool::new("Asia/Shanghai");
let result = tool.execute_at(
fixed_now(),
json!({"direction": "future", "amount": "not_a_number", "unit": "days"}),
);
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("Missing required parameter: amount"));
}
#[test]
fn test_null_args_returns_current_time() {
let tool = TimeTool::new("Asia/Shanghai");
let result = tool.execute_at(fixed_now(), serde_json::Value::Null);
// Null args should return error (current time requires no params, but null is invalid)
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("Missing required parameters"));
}
#[test]
fn test_empty_object_returns_current_time() {
let tool = TimeTool::new("Asia/Shanghai");
let result = tool.execute_at(fixed_now(), json!({}));
// Empty object should return current time
assert!(result.success);
let payload: Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(payload["kind"], "current");
}
}