feat: 优化 MemorySearchTool 的查询参数处理,支持多种格式的查询输入
This commit is contained in:
parent
9adaa93ecc
commit
32a9e2946e
@ -323,13 +323,17 @@ impl OpenAIProvider {
|
||||
let line = buffer[..newline_pos].to_string();
|
||||
buffer = buffer[newline_pos + 1..].to_string();
|
||||
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with(':') {
|
||||
let line_trimmed = line.trim();
|
||||
|
||||
if line_trimmed.is_empty() || line_trimmed.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
// SSE 格式: data: {...}
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
// SSE 格式: data: {...} 或 data:{...}(某些 API 如 139 云没有空格)
|
||||
let data_opt = line_trimmed.strip_prefix("data: ")
|
||||
.or_else(|| line_trimmed.strip_prefix("data:"));
|
||||
|
||||
if let Some(data) = data_opt {
|
||||
if data == "[DONE]" {
|
||||
// 流结束
|
||||
done_received = true;
|
||||
@ -347,6 +351,7 @@ impl OpenAIProvider {
|
||||
// 提取 choices
|
||||
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
|
||||
for choice in choices {
|
||||
// 尝试从 delta 提取(标准 OpenAI 流式格式)
|
||||
if let Some(delta) = choice.get("delta") {
|
||||
// 提取内容增量
|
||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||
@ -360,7 +365,6 @@ impl OpenAIProvider {
|
||||
|
||||
// 提取工具调用增量
|
||||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||||
tracing::debug!(tool_calls_count = tool_calls.len(), "Received tool_calls in delta");
|
||||
for tool_call in tool_calls {
|
||||
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
|
||||
|
||||
@ -372,18 +376,19 @@ impl OpenAIProvider {
|
||||
.and_then(|f| f.get("arguments"))
|
||||
.and_then(|a| a.as_str());
|
||||
|
||||
tracing::debug!(
|
||||
index = index,
|
||||
id = ?id,
|
||||
name = ?name,
|
||||
arguments = ?arguments,
|
||||
"Tool call delta received"
|
||||
);
|
||||
|
||||
accumulator.add_tool_call(index, id, name, arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
// 尝试从 message 提取(某些非标准 API 格式)
|
||||
else if let Some(message) = choice.get("message") {
|
||||
if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
|
||||
accumulator.add_content(content);
|
||||
}
|
||||
if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) {
|
||||
accumulator.add_reasoning_content(reasoning);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -405,12 +410,16 @@ impl OpenAIProvider {
|
||||
|
||||
// 处理缓冲区中剩余的内容
|
||||
for line in buffer.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with(':') {
|
||||
let line_trimmed = line.trim();
|
||||
if line_trimmed.is_empty() || line_trimmed.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
// 同样支持 data: {...} 和 data:{...} 两种格式
|
||||
let data_opt = line_trimmed.strip_prefix("data: ")
|
||||
.or_else(|| line_trimmed.strip_prefix("data:"));
|
||||
|
||||
if let Some(data) = data_opt {
|
||||
if data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
@ -422,6 +431,7 @@ impl OpenAIProvider {
|
||||
|
||||
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
|
||||
for choice in choices {
|
||||
// 尝试从 delta 提取
|
||||
if let Some(delta) = choice.get("delta") {
|
||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||
accumulator.add_content(content);
|
||||
@ -443,6 +453,15 @@ impl OpenAIProvider {
|
||||
}
|
||||
}
|
||||
}
|
||||
// 尝试从 message 提取(某些非标准 API 格式)
|
||||
else if let Some(message) = choice.get("message") {
|
||||
if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
|
||||
accumulator.add_content(content);
|
||||
}
|
||||
if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) {
|
||||
accumulator.add_reasoning_content(reasoning);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,15 +99,41 @@ impl Tool for MemorySearchTool {
|
||||
})
|
||||
}
|
||||
"search" => {
|
||||
let queries = match args.get("queries").and_then(|value| value.as_array()) {
|
||||
Some(queries) => queries
|
||||
let queries = match args.get("queries") {
|
||||
Some(value) => {
|
||||
// 支持两种格式:实际数组 或 字符串化的数组
|
||||
if let Some(arr) = value.as_array() {
|
||||
arr
|
||||
.iter()
|
||||
.filter_map(|value| value.as_str())
|
||||
.filter_map(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.collect::<Vec<_>>()
|
||||
} else if let Some(s) = value.as_str() {
|
||||
// 尝试解析字符串化的 JSON 数组
|
||||
match serde_json::from_str::<Vec<serde_json::Value>>(s) {
|
||||
Ok(arr) => arr
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.collect::<Vec<_>>(),
|
||||
None => return Ok(error_result("Missing required parameter: queries")),
|
||||
Err(_) => {
|
||||
// 如果不是 JSON 数组,尝试按逗号分割
|
||||
s.split(',')
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
None => vec![]
|
||||
};
|
||||
if queries.is_empty() {
|
||||
return Ok(error_result("Missing required parameter: queries"));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user