From 32a9e2946e956dd054807e4e9c1f280615d5e6a1 Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Thu, 21 May 2026 11:35:17 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=20MemorySearchTool?= =?UTF-8?q?=20=E7=9A=84=E6=9F=A5=E8=AF=A2=E5=8F=82=E6=95=B0=E5=A4=84?= =?UTF-8?q?=E7=90=86=EF=BC=8C=E6=94=AF=E6=8C=81=E5=A4=9A=E7=A7=8D=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E7=9A=84=E6=9F=A5=E8=AF=A2=E8=BE=93=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/providers/openai.rs | 51 ++++++++++++++++++++++++++------------ src/tools/memory_search.rs | 44 +++++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 45e12d2..1766a82 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -323,13 +323,17 @@ impl OpenAIProvider { let line = buffer[..newline_pos].to_string(); buffer = buffer[newline_pos + 1..].to_string(); - let line = line.trim(); - if line.is_empty() || line.starts_with(':') { + let line_trimmed = line.trim(); + + if line_trimmed.is_empty() || line_trimmed.starts_with(':') { continue; } - // SSE 格式: data: {...} - if let Some(data) = line.strip_prefix("data: ") { + // SSE 格式: data: {...} 或 data:{...}(某些 API 如 139 云没有空格) + let data_opt = line_trimmed.strip_prefix("data: ") + .or_else(|| line_trimmed.strip_prefix("data:")); + + if let Some(data) = data_opt { if data == "[DONE]" { // 流结束 done_received = true; @@ -347,6 +351,7 @@ impl OpenAIProvider { // 提取 choices if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { for choice in choices { + // 尝试从 delta 提取(标准 OpenAI 流式格式) if let Some(delta) = choice.get("delta") { // 提取内容增量 if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { @@ -360,7 +365,6 @@ impl OpenAIProvider { // 提取工具调用增量 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) { - tracing::debug!(tool_calls_count = tool_calls.len(), "Received tool_calls in delta"); for tool_call in tool_calls { let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize; @@ -372,18 +376,19 @@ impl OpenAIProvider { .and_then(|f| f.get("arguments")) .and_then(|a| a.as_str()); - tracing::debug!( - index = index, - id = ?id, - name = ?name, - arguments = ?arguments, - "Tool call delta received" - ); - accumulator.add_tool_call(index, id, name, arguments); } } } + // 尝试从 message 提取(某些非标准 API 格式) + else if let Some(message) = choice.get("message") { + if let Some(content) = message.get("content").and_then(|c| c.as_str()) { + accumulator.add_content(content); + } + if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) { + accumulator.add_reasoning_content(reasoning); + } + } } } } @@ -405,12 +410,16 @@ impl OpenAIProvider { // 处理缓冲区中剩余的内容 for line in buffer.lines() { - let line = line.trim(); - if line.is_empty() || line.starts_with(':') { + let line_trimmed = line.trim(); + if line_trimmed.is_empty() || line_trimmed.starts_with(':') { continue; } - if let Some(data) = line.strip_prefix("data: ") { + // 同样支持 data: {...} 和 data:{...} 两种格式 + let data_opt = line_trimmed.strip_prefix("data: ") + .or_else(|| line_trimmed.strip_prefix("data:")); + + if let Some(data) = data_opt { if data == "[DONE]" { break; } @@ -422,6 +431,7 @@ impl OpenAIProvider { if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { for choice in choices { + // 尝试从 delta 提取 if let Some(delta) = choice.get("delta") { if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { accumulator.add_content(content); @@ -443,6 +453,15 @@ impl OpenAIProvider { } } } + // 尝试从 message 提取(某些非标准 API 格式) + else if let Some(message) = choice.get("message") { + if let Some(content) = message.get("content").and_then(|c| c.as_str()) { + accumulator.add_content(content); + } + if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) { + accumulator.add_reasoning_content(reasoning); + } + } } } } diff --git a/src/tools/memory_search.rs b/src/tools/memory_search.rs index 3a3c3bb..7b45040 100644 --- a/src/tools/memory_search.rs +++ b/src/tools/memory_search.rs @@ -99,15 +99,41 @@ impl Tool for MemorySearchTool { }) } "search" => { - let queries = match args.get("queries").and_then(|value| value.as_array()) { - Some(queries) => queries - .iter() - .filter_map(|value| value.as_str()) - .map(str::trim) - .filter(|value| !value.is_empty()) - .map(ToOwned::to_owned) - .collect::>(), - None => return Ok(error_result("Missing required parameter: queries")), + let queries = match args.get("queries") { + Some(value) => { + // 支持两种格式:实际数组 或 字符串化的数组 + if let Some(arr) = value.as_array() { + arr + .iter() + .filter_map(|v| v.as_str()) + .map(str::trim) + .filter(|v| !v.is_empty()) + .map(ToOwned::to_owned) + .collect::>() + } else if let Some(s) = value.as_str() { + // 尝试解析字符串化的 JSON 数组 + match serde_json::from_str::>(s) { + Ok(arr) => arr + .iter() + .filter_map(|v| v.as_str()) + .map(str::trim) + .filter(|v| !v.is_empty()) + .map(ToOwned::to_owned) + .collect::>(), + Err(_) => { + // 如果不是 JSON 数组,尝试按逗号分割 + s.split(',') + .map(str::trim) + .filter(|v| !v.is_empty()) + .map(ToOwned::to_owned) + .collect::>() + } + } + } else { + vec![] + } + } + None => vec![] }; if queries.is_empty() { return Ok(error_result("Missing required parameter: queries"));