feat: 优化 MemorySearchTool 的查询参数处理,支持多种格式的查询输入

This commit is contained in:
oudecheng 2026-05-21 11:35:17 +08:00
parent 9adaa93ecc
commit 32a9e2946e
2 changed files with 70 additions and 25 deletions

View File

@ -323,13 +323,17 @@ impl OpenAIProvider {
let line = buffer[..newline_pos].to_string(); let line = buffer[..newline_pos].to_string();
buffer = buffer[newline_pos + 1..].to_string(); buffer = buffer[newline_pos + 1..].to_string();
let line = line.trim(); let line_trimmed = line.trim();
if line.is_empty() || line.starts_with(':') {
if line_trimmed.is_empty() || line_trimmed.starts_with(':') {
continue; continue;
} }
// SSE 格式: data: {...} // SSE 格式: data: {...} 或 data:{...}(某些 API 如 139 云没有空格)
if let Some(data) = line.strip_prefix("data: ") { let data_opt = line_trimmed.strip_prefix("data: ")
.or_else(|| line_trimmed.strip_prefix("data:"));
if let Some(data) = data_opt {
if data == "[DONE]" { if data == "[DONE]" {
// 流结束 // 流结束
done_received = true; done_received = true;
@ -347,6 +351,7 @@ impl OpenAIProvider {
// 提取 choices // 提取 choices
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
for choice in choices { for choice in choices {
// 尝试从 delta 提取(标准 OpenAI 流式格式)
if let Some(delta) = choice.get("delta") { if let Some(delta) = choice.get("delta") {
// 提取内容增量 // 提取内容增量
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
@ -360,7 +365,6 @@ impl OpenAIProvider {
// 提取工具调用增量 // 提取工具调用增量
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) { if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
tracing::debug!(tool_calls_count = tool_calls.len(), "Received tool_calls in delta");
for tool_call in tool_calls { for tool_call in tool_calls {
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize; let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
@ -372,18 +376,19 @@ impl OpenAIProvider {
.and_then(|f| f.get("arguments")) .and_then(|f| f.get("arguments"))
.and_then(|a| a.as_str()); .and_then(|a| a.as_str());
tracing::debug!(
index = index,
id = ?id,
name = ?name,
arguments = ?arguments,
"Tool call delta received"
);
accumulator.add_tool_call(index, id, name, arguments); accumulator.add_tool_call(index, id, name, arguments);
} }
} }
} }
// 尝试从 message 提取(某些非标准 API 格式)
else if let Some(message) = choice.get("message") {
if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content);
}
if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) {
accumulator.add_reasoning_content(reasoning);
}
}
} }
} }
} }
@ -405,12 +410,16 @@ impl OpenAIProvider {
// 处理缓冲区中剩余的内容 // 处理缓冲区中剩余的内容
for line in buffer.lines() { for line in buffer.lines() {
let line = line.trim(); let line_trimmed = line.trim();
if line.is_empty() || line.starts_with(':') { if line_trimmed.is_empty() || line_trimmed.starts_with(':') {
continue; continue;
} }
if let Some(data) = line.strip_prefix("data: ") { // 同样支持 data: {...} 和 data:{...} 两种格式
let data_opt = line_trimmed.strip_prefix("data: ")
.or_else(|| line_trimmed.strip_prefix("data:"));
if let Some(data) = data_opt {
if data == "[DONE]" { if data == "[DONE]" {
break; break;
} }
@ -422,6 +431,7 @@ impl OpenAIProvider {
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
for choice in choices { for choice in choices {
// 尝试从 delta 提取
if let Some(delta) = choice.get("delta") { if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content); accumulator.add_content(content);
@ -443,6 +453,15 @@ impl OpenAIProvider {
} }
} }
} }
// 尝试从 message 提取(某些非标准 API 格式)
else if let Some(message) = choice.get("message") {
if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content);
}
if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) {
accumulator.add_reasoning_content(reasoning);
}
}
} }
} }
} }

View File

@ -99,15 +99,41 @@ impl Tool for MemorySearchTool {
}) })
} }
"search" => { "search" => {
let queries = match args.get("queries").and_then(|value| value.as_array()) { let queries = match args.get("queries") {
Some(queries) => queries Some(value) => {
.iter() // 支持两种格式:实际数组 或 字符串化的数组
.filter_map(|value| value.as_str()) if let Some(arr) = value.as_array() {
.map(str::trim) arr
.filter(|value| !value.is_empty()) .iter()
.map(ToOwned::to_owned) .filter_map(|v| v.as_str())
.collect::<Vec<_>>(), .map(str::trim)
None => return Ok(error_result("Missing required parameter: queries")), .filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
} else if let Some(s) = value.as_str() {
// 尝试解析字符串化的 JSON 数组
match serde_json::from_str::<Vec<serde_json::Value>>(s) {
Ok(arr) => arr
.iter()
.filter_map(|v| v.as_str())
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.collect::<Vec<_>>(),
Err(_) => {
// 如果不是 JSON 数组,尝试按逗号分割
s.split(',')
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
}
}
} else {
vec![]
}
}
None => vec![]
}; };
if queries.is_empty() { if queries.is_empty() {
return Ok(error_result("Missing required parameter: queries")); return Ok(error_result("Missing required parameter: queries"));