Compare commits
5 Commits
32a9e2946e
...
c817b1dde1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c817b1dde1 | ||
|
|
159c1bbb7a | ||
|
|
3128abe3c6 | ||
|
|
efc8af12eb | ||
|
|
da9cec6d35 |
@ -1384,7 +1384,16 @@ fn parse_post_content(content: &str) -> String {
|
||||
}
|
||||
"code_block" => {
|
||||
let lang = el.get("language").and_then(|l| l.as_str()).unwrap_or("");
|
||||
let code_text = el.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||
let code_text = if let Some(content_arr) = el.get("content").and_then(|c| c.as_array()) {
|
||||
content_arr
|
||||
.iter()
|
||||
.filter_map(|item| item.get("text").and_then(|t| t.as_str()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("")
|
||||
} else {
|
||||
// Fallback to text field for backwards compatibility
|
||||
el.get("text").and_then(|t| t.as_str()).unwrap_or("").to_string()
|
||||
};
|
||||
out.push(format!("\n```{}\n{}\n```\n", lang, code_text));
|
||||
}
|
||||
_ => {
|
||||
@ -2190,7 +2199,7 @@ fn sanitize_download_file_name(file_name: &str) -> String {
|
||||
mod tests {
|
||||
use super::{
|
||||
FeishuChannel, MsgFormat, extract_file_name_from_content_disposition,
|
||||
infer_download_filename, sanitize_download_file_name,
|
||||
infer_download_filename, parse_post_content, sanitize_download_file_name,
|
||||
};
|
||||
|
||||
#[test]
|
||||
@ -2279,6 +2288,42 @@ mod tests {
|
||||
let file_name = extract_file_name_from_content_disposition(&headers);
|
||||
assert_eq!(file_name.as_deref(), Some("archive.zip"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_post_content_handles_code_block_with_content_array() {
|
||||
// Test parsing code_block with content array (standard Feishu format)
|
||||
let post_json = r#"{"post":{"zh_cn":{"content":[[{"tag":"code_block","language":"python","content":[{"tag":"text","text":"def hello():"},{"tag":"text","text":" print('world')"}]}]]}}}"#;
|
||||
let result = parse_post_content(post_json);
|
||||
assert!(result.contains("```python"));
|
||||
assert!(result.contains("def hello():"));
|
||||
assert!(result.contains("print('world')"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_post_content_handles_code_block_with_fallback_text() {
|
||||
// Backwards compatibility: some formats might use text field directly
|
||||
let post_json = r#"{"post":{"zh_cn":{"content":[[{"tag":"code_block","language":"rust","text":"fn main() {}"}]]}}}"#;
|
||||
let result = parse_post_content(post_json);
|
||||
assert!(result.contains("```rust"));
|
||||
assert!(result.contains("fn main() {}"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_post_content_handles_code_block_without_language() {
|
||||
// Test code_block without language field
|
||||
let post_json = r#"{"post":{"zh_cn":{"content":[[{"tag":"code_block","content":[{"tag":"text","text":"plain text"}]}]]}}}"#;
|
||||
let result = parse_post_content(post_json);
|
||||
assert!(result.contains("```"));
|
||||
assert!(result.contains("plain text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_post_content_handles_empty_code_block() {
|
||||
// Test code_block with empty content
|
||||
let post_json = r#"{"post":{"zh_cn":{"content":[[{"tag":"code_block","language":"go"}]]}}}"#;
|
||||
let result = parse_post_content(post_json);
|
||||
assert!(result.contains("```go"));
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
@ -278,7 +278,7 @@ mod tests {
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(cmd, Command::SaveTopic { filepath: None }));
|
||||
assert!(matches!(cmd, Command::SaveTopic { filepath: None, .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -294,6 +294,7 @@ mod tests {
|
||||
cmd,
|
||||
Command::SaveTopic {
|
||||
filepath: Some(ref p),
|
||||
..
|
||||
} if p == "./debug/topic.md"
|
||||
));
|
||||
}
|
||||
@ -307,7 +308,7 @@ mod tests {
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false }));
|
||||
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false, .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -324,6 +325,7 @@ mod tests {
|
||||
Command::SaveSession {
|
||||
filepath: Some(ref p),
|
||||
include_all: false,
|
||||
..
|
||||
} if p == "./debug/session.md"
|
||||
));
|
||||
}
|
||||
@ -337,7 +339,7 @@ mod tests {
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: true }));
|
||||
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: true, .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -354,6 +356,7 @@ mod tests {
|
||||
Command::SaveSession {
|
||||
filepath: Some(ref p),
|
||||
include_all: true,
|
||||
..
|
||||
} if p == "./debug/session.md"
|
||||
));
|
||||
}
|
||||
|
||||
@ -845,7 +845,7 @@ mod tests {
|
||||
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false, include_subagents: false }));
|
||||
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true, include_subagents: false }));
|
||||
assert!(!handler.can_handle(&Command::CreateSession { title: None }));
|
||||
assert!(!handler.can_handle(&Command::SaveTopic { filepath: None }));
|
||||
assert!(!handler.can_handle(&Command::SaveTopic { filepath: None, include_subagents: false }));
|
||||
}
|
||||
|
||||
/// 测试用的系统提示词提供者
|
||||
|
||||
@ -110,6 +110,20 @@ impl AgentExecutionService {
|
||||
|
||||
// 将结果消息保存到确定的话题
|
||||
if let Some(topic_id) = target_topic_id {
|
||||
if is_current_turn {
|
||||
// 如果是最新回合,使用 append_persisted_messages 保存到数据库并更新内存历史
|
||||
if let Err(err) = session.append_persisted_messages(
|
||||
request.chat_id,
|
||||
request.result.emitted_messages.clone(),
|
||||
) {
|
||||
tracing::error!(
|
||||
error = %err,
|
||||
chat_id = %request.chat_id,
|
||||
"Failed to append messages to session history"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// 如果用户已切换话题,只保存到原始话题(不更新内存历史)
|
||||
if let Err(err) = session.append_messages_to_topic(
|
||||
request.chat_id,
|
||||
topic_id,
|
||||
@ -122,6 +136,19 @@ impl AgentExecutionService {
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if is_current_turn {
|
||||
// 如果没有话题,直接更新内存历史(append_persisted_messages 会处理持久化)
|
||||
if let Err(err) = session.append_persisted_messages(
|
||||
request.chat_id,
|
||||
request.result.emitted_messages.clone(),
|
||||
) {
|
||||
tracing::error!(
|
||||
error = %err,
|
||||
chat_id = %request.chat_id,
|
||||
"Failed to append messages to session history"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// 只有当是最新回合时才发送 outbound 消息给用户
|
||||
// 如果用户已经切换到其他话题,只保存结果,不发送消息(避免打扰)
|
||||
|
||||
@ -235,12 +235,13 @@ impl OpenAIProvider {
|
||||
let normalized = self.normalize_tool_arguments(arguments);
|
||||
|
||||
if self.uses_json_tool_arguments() {
|
||||
// Model expects JSON object format
|
||||
// Model expects JSON object format (e.g., some code models)
|
||||
normalized
|
||||
} else {
|
||||
// Standard OpenAI format: arguments as JSON string
|
||||
// But ensure we serialize valid JSON, not null
|
||||
match normalized {
|
||||
Value::Null => Value::String("{}".to_string()),
|
||||
Value::String(raw) => {
|
||||
// If the string is already valid JSON, keep it as-is
|
||||
// Otherwise, ensure it's a proper JSON string
|
||||
|
||||
@ -12,6 +12,7 @@ use tokio::time::{Instant, sleep_until};
|
||||
|
||||
use crate::platform::{ShellInfo, dangerous_command_patterns};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use crate::tools::{extract_u64, extract_bool, check_null_args};
|
||||
|
||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||
const MAX_OUTPUT_CHARS: usize = 50_000;
|
||||
@ -255,6 +256,10 @@ impl Tool for BashTool {
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
if let Some(result) = check_null_args(&args, "bash") {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
@ -275,15 +280,10 @@ impl Tool for BashTool {
|
||||
});
|
||||
}
|
||||
|
||||
let timeout_secs = args
|
||||
.get("timeout")
|
||||
.and_then(|v| v.as_u64())
|
||||
let timeout_secs = extract_u64(&args, "timeout")
|
||||
.unwrap_or(self.timeout_secs)
|
||||
.min(MAX_TIMEOUT_SECS);
|
||||
let interactive = args
|
||||
.get("interactive")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
let interactive = extract_bool(&args, "interactive").unwrap_or(false);
|
||||
|
||||
let cwd = self
|
||||
.working_dir
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::tools::extract_f64 as extract_f64_opt;
|
||||
use crate::tools::check_null_args;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
@ -101,6 +103,10 @@ impl Tool for CalculatorTool {
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
if let Some(result) = check_null_args(&args, "calculator") {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let function = match args.get("function").and_then(|v| v.as_str()) {
|
||||
Some(f) => f,
|
||||
None => {
|
||||
@ -149,15 +155,33 @@ impl Tool for CalculatorTool {
|
||||
}
|
||||
|
||||
fn extract_f64(args: &serde_json::Value, key: &str, name: &str) -> Result<f64, String> {
|
||||
args.get(key)
|
||||
.and_then(|v| v.as_f64())
|
||||
.ok_or_else(|| format!("Missing required parameter: {name}"))
|
||||
match args.get(key) {
|
||||
None => Err(format!("Missing required parameter: {name}")),
|
||||
Some(v) => {
|
||||
if let Some(n) = v.as_f64() {
|
||||
Ok(n)
|
||||
} else if let Some(s) = v.as_str() {
|
||||
s.parse::<f64>().map_err(|_| format!("{name} is not a valid number: {s}"))
|
||||
} else {
|
||||
Err(format!("{name} must be a number"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_i64(args: &serde_json::Value, key: &str, name: &str) -> Result<i64, String> {
|
||||
args.get(key)
|
||||
.and_then(|v| v.as_i64())
|
||||
.ok_or_else(|| format!("Missing required parameter: {name}"))
|
||||
match args.get(key) {
|
||||
None => Err(format!("Missing required parameter: {name}")),
|
||||
Some(v) => {
|
||||
if let Some(n) = v.as_i64() {
|
||||
Ok(n)
|
||||
} else if let Some(s) = v.as_str() {
|
||||
s.parse::<i64>().map_err(|_| format!("{name} is not a valid integer: {s}"))
|
||||
} else {
|
||||
Err(format!("{name} must be an integer"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_values(args: &serde_json::Value, min_len: usize) -> Result<Vec<f64>, String> {
|
||||
@ -173,10 +197,15 @@ fn extract_values(args: &serde_json::Value, min_len: usize) -> Result<Vec<f64>,
|
||||
}
|
||||
let mut nums = Vec::with_capacity(values.len());
|
||||
for (i, v) in values.iter().enumerate() {
|
||||
match v.as_f64() {
|
||||
Some(n) => nums.push(n),
|
||||
None => return Err(format!("values[{i}] is not a valid number")),
|
||||
}
|
||||
let n = if let Some(n) = v.as_f64() {
|
||||
n
|
||||
} else if let Some(s) = v.as_str() {
|
||||
s.parse::<f64>()
|
||||
.map_err(|_| format!("values[{i}] is not a valid number: {s}"))?
|
||||
} else {
|
||||
return Err(format!("values[{i}] is not a valid number"));
|
||||
};
|
||||
nums.push(n);
|
||||
}
|
||||
Ok(nums)
|
||||
}
|
||||
@ -206,7 +235,7 @@ fn calc_log(args: &serde_json::Value) -> Result<String, String> {
|
||||
if x <= 0.0 {
|
||||
return Err("Logarithm requires a positive number".to_string());
|
||||
}
|
||||
let base = args.get("base").and_then(|v| v.as_f64()).unwrap_or(10.0);
|
||||
let base = extract_f64_opt(args, "base").unwrap_or(10.0);
|
||||
if base <= 0.0 || base == 1.0 {
|
||||
return Err("Logarithm base must be positive and not equal to 1".to_string());
|
||||
}
|
||||
@ -695,4 +724,45 @@ mod tests {
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_round_with_string_x() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "round", "x": "2.715", "decimals": "2"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "2.72");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sum_with_string_values() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "sum", "values": ["1.5", "2.5", "3"]}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "7");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_string_number_returns_error() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "round", "x": "not_a_number", "decimals": 2}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("x is not a valid number"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_null_args_returns_error() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool.execute(serde_json::Value::Null).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("Missing required parameters"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use crate::tools::extract_bool;
|
||||
|
||||
pub struct FileEditTool {
|
||||
allowed_dir: Option<String>,
|
||||
@ -150,10 +151,7 @@ impl Tool for FileEditTool {
|
||||
}
|
||||
};
|
||||
|
||||
let replace_all = args
|
||||
.get("replace_all")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
let replace_all = extract_bool(&args, "replace_all").unwrap_or(false);
|
||||
|
||||
let resolved = match self.resolve_path(path) {
|
||||
Ok(p) => p,
|
||||
|
||||
@ -4,6 +4,7 @@ use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use crate::tools::extract_u64;
|
||||
|
||||
const MAX_CHARS: usize = 128_000;
|
||||
const DEFAULT_LIMIT: usize = 2000;
|
||||
@ -103,15 +104,11 @@ impl Tool for FileReadTool {
|
||||
}
|
||||
};
|
||||
|
||||
let offset = args
|
||||
.get("offset")
|
||||
.and_then(|v| v.as_u64())
|
||||
let offset = extract_u64(&args, "offset")
|
||||
.map(|v| v as usize)
|
||||
.unwrap_or(1);
|
||||
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(|v| v.as_u64())
|
||||
let limit = extract_u64(&args, "limit")
|
||||
.map(|v| v as usize)
|
||||
.unwrap_or(DEFAULT_LIMIT);
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ use serde_json::json;
|
||||
|
||||
use crate::storage::{MemoryRecord, MemoryRepository};
|
||||
use crate::tools::traits::{Tool, ToolContext, ToolResult};
|
||||
use crate::tools::extract_u64;
|
||||
|
||||
pub struct MemorySearchTool {
|
||||
memories: Arc<dyn MemoryRepository>,
|
||||
@ -86,10 +87,7 @@ impl Tool for MemorySearchTool {
|
||||
|
||||
let payload = match action {
|
||||
"list" => {
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(|value| value.as_u64())
|
||||
.unwrap_or(10) as usize;
|
||||
let limit = extract_u64(&args, "limit").unwrap_or(10) as usize;
|
||||
let memories = self
|
||||
.memories
|
||||
.list_memories("user", &scope_key, namespace, limit)?;
|
||||
@ -138,10 +136,7 @@ impl Tool for MemorySearchTool {
|
||||
if queries.is_empty() {
|
||||
return Ok(error_result("Missing required parameter: queries"));
|
||||
}
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(|value| value.as_u64())
|
||||
.unwrap_or(10) as usize;
|
||||
let limit = extract_u64(&args, "limit").unwrap_or(10) as usize;
|
||||
let memories = self
|
||||
.memories
|
||||
.search_memories_any("user", &scope_key, &queries, namespace, limit)?;
|
||||
|
||||
166
src/tools/mod.rs
166
src/tools/mod.rs
@ -41,3 +41,169 @@ pub use task::{
|
||||
pub use time::TimeTool;
|
||||
pub use traits::{Tool, ToolContext, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
/// Extract a string parameter from JSON args.
|
||||
pub fn extract_string(args: &serde_json::Value, key: &str) -> Option<String> {
|
||||
args.get(key).and_then(|v| {
|
||||
if let Some(s) = v.as_str() {
|
||||
Some(s.to_string())
|
||||
} else if let Some(n) = v.as_number() {
|
||||
// Handle case where LLM sends a number but we need a string
|
||||
Some(n.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract an f64 parameter from JSON args, handling both numbers and strings.
|
||||
pub fn extract_f64(args: &serde_json::Value, key: &str) -> Option<f64> {
|
||||
args.get(key).and_then(|v| {
|
||||
if let Some(n) = v.as_f64() {
|
||||
Some(n)
|
||||
} else if let Some(s) = v.as_str() {
|
||||
s.parse::<f64>().ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract an i64 parameter from JSON args, handling both numbers and strings.
|
||||
pub fn extract_i64(args: &serde_json::Value, key: &str) -> Option<i64> {
|
||||
args.get(key).and_then(|v| {
|
||||
if let Some(n) = v.as_i64() {
|
||||
Some(n)
|
||||
} else if let Some(s) = v.as_str() {
|
||||
s.parse::<i64>().ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract a u64 parameter from JSON args, handling both numbers and strings.
|
||||
pub fn extract_u64(args: &serde_json::Value, key: &str) -> Option<u64> {
|
||||
args.get(key).and_then(|v| {
|
||||
if let Some(n) = v.as_u64() {
|
||||
Some(n)
|
||||
} else if let Some(s) = v.as_str() {
|
||||
s.parse::<u64>().ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract a bool parameter from JSON args, handling both booleans and strings.
|
||||
pub fn extract_bool(args: &serde_json::Value, key: &str) -> Option<bool> {
|
||||
args.get(key).and_then(|v| {
|
||||
if let Some(b) = v.as_bool() {
|
||||
Some(b)
|
||||
} else if let Some(s) = v.as_str() {
|
||||
match s.to_lowercase().as_str() {
|
||||
"true" | "1" | "yes" | "on" => Some(true),
|
||||
"false" | "0" | "no" | "off" => Some(false),
|
||||
_ => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract a required string parameter, returning an error message if missing.
|
||||
pub fn require_string(args: &serde_json::Value, key: &str) -> Result<String, String> {
|
||||
extract_string(args, key)
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.ok_or_else(|| format!("Missing required parameter: {}", key))
|
||||
}
|
||||
|
||||
/// Extract a required f64 parameter, returning an error message if missing.
|
||||
pub fn require_f64(args: &serde_json::Value, key: &str) -> Result<f64, String> {
|
||||
extract_f64(args, key)
|
||||
.ok_or_else(|| format!("Missing required parameter: {}", key))
|
||||
}
|
||||
|
||||
/// Extract a required i64 parameter, returning an error message if missing.
|
||||
pub fn require_i64(args: &serde_json::Value, key: &str) -> Result<i64, String> {
|
||||
extract_i64(args, key)
|
||||
.ok_or_else(|| format!("Missing required parameter: {}", key))
|
||||
}
|
||||
|
||||
/// Extract a required u64 parameter, returning an error message if missing.
|
||||
pub fn require_u64(args: &serde_json::Value, key: &str) -> Result<u64, String> {
|
||||
extract_u64(args, key)
|
||||
.ok_or_else(|| format!("Missing required parameter: {}", key))
|
||||
}
|
||||
|
||||
/// Extract a required bool parameter, returning an error message if missing.
|
||||
pub fn require_bool(args: &serde_json::Value, key: &str) -> Result<bool, String> {
|
||||
extract_bool(args, key)
|
||||
.ok_or_else(|| format!("Missing required parameter: {}", key))
|
||||
}
|
||||
|
||||
/// Extract a string array parameter, handling both actual arrays and stringified JSON arrays.
|
||||
pub fn extract_string_array(args: &serde_json::Value, key: &str) -> Option<Vec<String>> {
|
||||
args.get(key).and_then(|v| {
|
||||
if let Some(arr) = v.as_array() {
|
||||
Some(
|
||||
arr.iter()
|
||||
.filter_map(|item| item.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.collect(),
|
||||
)
|
||||
} else if let Some(s) = v.as_str() {
|
||||
// Try to parse as JSON array string
|
||||
serde_json::from_str::<Vec<serde_json::Value>>(s)
|
||||
.ok()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|item| item.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.collect()
|
||||
})
|
||||
.filter(|result: &Vec<String>| !result.is_empty())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract a required string array parameter.
|
||||
pub fn require_string_array(args: &serde_json::Value, key: &str) -> Result<Vec<String>, String> {
|
||||
extract_string_array(args, key)
|
||||
.filter(|arr| !arr.is_empty())
|
||||
.ok_or_else(|| format!("Missing required parameter: {}", key))
|
||||
}
|
||||
|
||||
/// Check if args is null and return an error result if so.
|
||||
/// Returns the provided error message if args is null.
|
||||
pub fn check_null_args(args: &serde_json::Value, tool_name: &str) -> Option<ToolResult> {
|
||||
if args.is_null() {
|
||||
return Some(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Missing required parameters: {} expects a JSON object with required fields. Check the tool schema and provide the necessary arguments.",
|
||||
tool_name
|
||||
)),
|
||||
});
|
||||
}
|
||||
if !args.is_object() {
|
||||
return Some(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid parameters: {} expects a JSON object, got {}",
|
||||
tool_name,
|
||||
if args.is_array() { "an array" } else if args.is_string() { "a string" } else { "an unexpected type" }
|
||||
)),
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ use crate::storage::{
|
||||
SchedulerJobRecord, SchedulerJobRepository, SchedulerJobState, SchedulerJobUpsert,
|
||||
};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use crate::tools::{extract_bool, extract_i64};
|
||||
|
||||
pub struct SchedulerManageTool {
|
||||
jobs: Arc<dyn SchedulerJobRepository>,
|
||||
@ -126,10 +127,7 @@ impl Tool for SchedulerManageTool {
|
||||
|
||||
let output = match action {
|
||||
"list" => {
|
||||
let enabled_only = args
|
||||
.get("enabled_only")
|
||||
.and_then(|value| value.as_bool())
|
||||
.unwrap_or(false);
|
||||
let enabled_only = extract_bool(&args, "enabled_only").unwrap_or(false);
|
||||
let jobs = self.jobs.list_scheduler_jobs(enabled_only)?;
|
||||
json!(jobs.iter().map(record_to_json).collect::<Vec<_>>())
|
||||
}
|
||||
@ -243,15 +241,8 @@ fn build_upsert(
|
||||
startup_delay_secs,
|
||||
target,
|
||||
payload,
|
||||
enabled: args
|
||||
.get("enabled")
|
||||
.and_then(|value| value.as_bool())
|
||||
.unwrap_or(true),
|
||||
state: if args
|
||||
.get("enabled")
|
||||
.and_then(|value| value.as_bool())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
enabled: extract_bool(args, "enabled").unwrap_or(true),
|
||||
state: if extract_bool(args, "enabled").unwrap_or(true) {
|
||||
SchedulerJobState::Scheduled
|
||||
} else {
|
||||
SchedulerJobState::Paused
|
||||
@ -259,7 +250,7 @@ fn build_upsert(
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
run_count: 0,
|
||||
max_runs: args.get("max_runs").and_then(|value| value.as_i64()),
|
||||
max_runs: extract_i64(args, "max_runs"),
|
||||
last_fired_at: None,
|
||||
next_fire_at: None,
|
||||
paused_at: None,
|
||||
@ -302,7 +293,7 @@ fn build_update_upsert(
|
||||
upsert.payload = payload.clone();
|
||||
}
|
||||
|
||||
if let Some(enabled) = args.get("enabled").and_then(|value| value.as_bool()) {
|
||||
if let Some(enabled) = extract_bool(args, "enabled") {
|
||||
upsert.enabled = enabled;
|
||||
upsert.state = if enabled {
|
||||
SchedulerJobState::Scheduled
|
||||
@ -319,7 +310,7 @@ fn build_update_upsert(
|
||||
}
|
||||
|
||||
if args.get("max_runs").is_some() {
|
||||
upsert.max_runs = args.get("max_runs").and_then(|value| value.as_i64());
|
||||
upsert.max_runs = extract_i64(args, "max_runs");
|
||||
}
|
||||
|
||||
if upsert.kind == "agent_task" || upsert.kind == "silent_agent_task" {
|
||||
|
||||
@ -164,21 +164,39 @@ fn validate_context(context: &ToolContext) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
fn parse_attachments(value: &serde_json::Value) -> anyhow::Result<Vec<MediaItem>> {
|
||||
let attachment_paths = value
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow!("attachments must be an array of local file paths"))?;
|
||||
// 支持两种格式:实际数组 或 字符串化的 JSON 数组
|
||||
let paths = if let Some(arr) = value.as_array() {
|
||||
arr
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.collect::<Vec<_>>()
|
||||
} else if let Some(s) = value.as_str() {
|
||||
// 尝试解析字符串化的 JSON 数组
|
||||
serde_json::from_str::<Vec<serde_json::Value>>(s)
|
||||
.ok()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let mut attachments = Vec::with_capacity(attachment_paths.len());
|
||||
for path_value in attachment_paths {
|
||||
let raw_path = path_value
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("attachments entries must be strings"))?
|
||||
.trim();
|
||||
if raw_path.is_empty() {
|
||||
return Err(anyhow!("attachment paths must not be empty"));
|
||||
if paths.is_empty() {
|
||||
return Err(anyhow!("attachments must be an array of local file paths"));
|
||||
}
|
||||
|
||||
let metadata = std::fs::metadata(raw_path)
|
||||
let mut attachments = Vec::with_capacity(paths.len());
|
||||
for raw_path in paths {
|
||||
let metadata = std::fs::metadata(&raw_path)
|
||||
.map_err(|err| anyhow!("failed to access attachment '{}': {}", raw_path, err))?;
|
||||
if !metadata.is_file() {
|
||||
return Err(anyhow!("attachment path is not a file: {}", raw_path));
|
||||
@ -187,9 +205,9 @@ fn parse_attachments(value: &serde_json::Value) -> anyhow::Result<Vec<MediaItem>
|
||||
return Err(anyhow!("attachment file is empty: {}", raw_path));
|
||||
}
|
||||
|
||||
let media_type = infer_media_type(raw_path);
|
||||
let media_type = infer_media_type(&raw_path);
|
||||
let mut item = MediaItem::new(raw_path.to_string(), media_type);
|
||||
item.mime_type = mime_guess::from_path(raw_path)
|
||||
item.mime_type = mime_guess::from_path(&raw_path)
|
||||
.first_raw()
|
||||
.map(ToOwned::to_owned);
|
||||
attachments.push(item);
|
||||
@ -319,4 +337,20 @@ mod tests {
|
||||
assert_eq!(attachments.len(), 1);
|
||||
assert_eq!(attachments[0].media_type, "image");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_attachments_handles_stringified_json_array() {
|
||||
let file = NamedTempFile::new().unwrap();
|
||||
std::fs::write(file.path(), b"demo").unwrap();
|
||||
let txt_path = file.path().with_extension("txt");
|
||||
std::fs::rename(file.path(), &txt_path).unwrap();
|
||||
|
||||
// Test with stringified JSON array (like LLM might send)
|
||||
let path_str = txt_path.to_string_lossy().to_string().replace("\\", "\\\\");
|
||||
let json_string = format!("[\"{}\"]", path_str);
|
||||
let attachments = parse_attachments(&json!(json_string)).unwrap();
|
||||
|
||||
assert_eq!(attachments.len(), 1);
|
||||
assert_eq!(attachments[0].media_type, "file");
|
||||
}
|
||||
}
|
||||
@ -6,6 +6,7 @@ use serde_json::json;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::SkillEventRepository;
|
||||
use crate::tools::traits::{Tool, ToolContext, ToolResult};
|
||||
use crate::tools::check_null_args;
|
||||
|
||||
pub struct SkillActivateTool {
|
||||
skills: Arc<SkillRuntime>,
|
||||
@ -68,6 +69,10 @@ impl Tool for SkillActivateTool {
|
||||
context: &ToolContext,
|
||||
args: serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
if let Some(result) = check_null_args(&args, "skill_activate") {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let skill_name = match args.get("name").and_then(|value| value.as_str()) {
|
||||
Some(name) if !name.trim().is_empty() => name,
|
||||
_ => {
|
||||
@ -152,4 +157,24 @@ mod tests {
|
||||
assert_eq!(events[0].event_type, "activation_failed");
|
||||
assert_eq!(events[0].skill_name.as_deref(), Some("demo"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_skill_activate_handles_null_args() {
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
|
||||
let tool = SkillActivateTool::new(skills, store.clone());
|
||||
let context = ToolContext {
|
||||
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||
..ToolContext::default()
|
||||
};
|
||||
|
||||
let result = tool
|
||||
.execute_with_context(&context, serde_json::Value::Null)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Missing required parameters"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||
|
||||
use crate::skills::{SkillRuntime, SkillScope};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use crate::tools::{extract_bool, extract_string_array};
|
||||
|
||||
pub struct SkillManageTool {
|
||||
skills: Arc<SkillRuntime>,
|
||||
@ -86,7 +87,7 @@ impl Tool for SkillManageTool {
|
||||
}
|
||||
};
|
||||
|
||||
let reload = args.get("reload").and_then(|v| v.as_bool()).unwrap_or(true);
|
||||
let reload = extract_bool(&args, "reload").unwrap_or(true);
|
||||
let scope = match args.get("scope").and_then(|v| v.as_str()) {
|
||||
Some(value) => match SkillScope::parse(value) {
|
||||
Some(scope) => scope,
|
||||
@ -280,28 +281,10 @@ fn error_result(message: &str) -> ToolResult {
|
||||
}
|
||||
|
||||
fn parse_disable_names(args: &serde_json::Value) -> Result<Vec<String>, String> {
|
||||
let names = args
|
||||
.get("names")
|
||||
.ok_or_else(|| "disable requires names".to_string())?
|
||||
.as_array()
|
||||
.ok_or_else(|| "names must be an array of strings".to_string())?;
|
||||
|
||||
let mut parsed = Vec::new();
|
||||
for item in names {
|
||||
let name = item
|
||||
.as_str()
|
||||
.ok_or_else(|| "names must be an array of strings".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
if name.is_empty() {
|
||||
return Err("names must not contain empty values".to_string());
|
||||
}
|
||||
parsed.push(name);
|
||||
}
|
||||
if parsed.is_empty() {
|
||||
return Err("names must not be empty".to_string());
|
||||
}
|
||||
Ok(parsed)
|
||||
// 支持两种格式:实际数组 或 字符串化的 JSON 数组
|
||||
extract_string_array(args, "names")
|
||||
.filter(|arr| !arr.is_empty())
|
||||
.ok_or_else(|| "disable requires names (array of strings)".to_string())
|
||||
}
|
||||
|
||||
fn skill_change_payload(change: crate::skills::SkillAvailabilityChange) -> serde_json::Value {
|
||||
|
||||
@ -2,6 +2,8 @@ use async_trait::async_trait;
|
||||
use chrono::{DateTime, Days, Duration, Months, Utc};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
use crate::tools::extract_u64;
|
||||
use crate::tools::check_null_args;
|
||||
use super::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct TimeTool {
|
||||
@ -16,6 +18,10 @@ impl TimeTool {
|
||||
}
|
||||
|
||||
fn execute_at(&self, now_utc: DateTime<Utc>, args: Value) -> ToolResult {
|
||||
if let Some(result) = check_null_args(&args, "get_time") {
|
||||
return result;
|
||||
}
|
||||
|
||||
match execute_time_request(now_utc, &self.default_timezone, args) {
|
||||
Ok(output) => ToolResult {
|
||||
success: true,
|
||||
@ -156,7 +162,7 @@ fn execute_time_request(
|
||||
|
||||
fn parse_offset_request(args: &Value) -> Result<Option<OffsetRequest>, String> {
|
||||
let direction = args.get("direction").and_then(Value::as_str);
|
||||
let amount = args.get("amount");
|
||||
let amount = extract_u64(args, "amount");
|
||||
let unit = args.get("unit").and_then(Value::as_str);
|
||||
|
||||
if direction.is_none() && amount.is_none() && unit.is_none() {
|
||||
@ -166,9 +172,11 @@ fn parse_offset_request(args: &Value) -> Result<Option<OffsetRequest>, String> {
|
||||
let direction = direction.ok_or_else(|| {
|
||||
"Missing required parameter: direction when requesting a relative time".to_string()
|
||||
})?;
|
||||
let amount = amount.and_then(Value::as_u64).ok_or_else(|| {
|
||||
let amount = amount.ok_or_else(|| {
|
||||
"Missing required parameter: amount when requesting a relative time".to_string()
|
||||
})?;
|
||||
let amount = u32::try_from(amount)
|
||||
.map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?;
|
||||
let amount = u32::try_from(amount)
|
||||
.map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?;
|
||||
let unit = unit.ok_or_else(|| {
|
||||
@ -451,4 +459,51 @@ mod tests {
|
||||
assert_eq!(result_time.hour(), 12);
|
||||
assert_eq!(result_time.minute(), 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_amount_as_string_is_parsed() {
|
||||
let tool = TimeTool::new("Asia/Shanghai");
|
||||
let result = tool.execute_at(
|
||||
fixed_now(),
|
||||
json!({"direction": "future", "amount": "7", "unit": "days"}),
|
||||
);
|
||||
|
||||
assert!(result.success, "Expected success but got error: {:?}", result.error);
|
||||
let payload: Value = serde_json::from_str(&result.output).unwrap();
|
||||
assert_eq!(payload["result_time"], "2026-05-04T12:30:00+08:00");
|
||||
assert_eq!(payload["offset"]["amount"], 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_amount_string_returns_error() {
|
||||
let tool = TimeTool::new("Asia/Shanghai");
|
||||
let result = tool.execute_at(
|
||||
fixed_now(),
|
||||
json!({"direction": "future", "amount": "not_a_number", "unit": "days"}),
|
||||
);
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("Missing required parameter: amount"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_null_args_returns_current_time() {
|
||||
let tool = TimeTool::new("Asia/Shanghai");
|
||||
let result = tool.execute_at(fixed_now(), serde_json::Value::Null);
|
||||
|
||||
// Null args should return error (current time requires no params, but null is invalid)
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("Missing required parameters"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_object_returns_current_time() {
|
||||
let tool = TimeTool::new("Asia/Shanghai");
|
||||
let result = tool.execute_at(fixed_now(), json!({}));
|
||||
|
||||
// Empty object should return current time
|
||||
assert!(result.success);
|
||||
let payload: Value = serde_json::from_str(&result.output).unwrap();
|
||||
assert_eq!(payload["kind"], "current");
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user