From 9adaa93eccc6d06e1b5d2c7ae70c6a97aedd66fb Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Thu, 21 May 2026 08:50:41 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=20StreamingAccumulat?= =?UTF-8?q?or=20=E4=B8=AD=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=E7=9A=84=20I?= =?UTF-8?q?D=20=E5=92=8C=E5=90=8D=E7=A7=B0=E6=9B=B4=E6=96=B0=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E7=A1=AE=E4=BF=9D=E7=A9=BA=E5=80=BC=E4=B8=8D?= =?UTF-8?q?=E4=BC=9A=E8=A6=86=E7=9B=96=E6=9C=89=E6=95=88=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E6=B7=BB=E5=8A=A0=E7=9B=B8=E5=85=B3=E5=8D=95?= =?UTF-8?q?=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/providers/openai.rs | 54 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 04baac6..45e12d2 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -51,11 +51,17 @@ impl StreamingAccumulator { fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) { let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default); + // 只在 id 非空时才更新,防止流式响应中后续 chunk 的空 id 覆盖之前的值 if let Some(id) = id { - entry.id = id.to_string(); + if !id.is_empty() { + entry.id = id.to_string(); + } } + // 只在 name 非空时才更新,防止流式响应中后续 chunk 的 None 覆盖之前的值 if let Some(name) = name { - entry.name = name.to_string(); + if !name.is_empty() { + entry.name = name.to_string(); + } } if let Some(args) = arguments { entry.arguments.push_str(args); @@ -1006,4 +1012,48 @@ mod tests { OAIFunctionArguments::String(_) => panic!("expected JSON tool arguments"), } } + + #[test] + fn test_streaming_accumulator_preserves_tool_call_id_with_empty_subsequent_chunks() { + // 模拟阿里云等云服务商的流式响应行为: + // 第一个 chunk 包含 id 和 name,后续 chunk 的 id 为空字符串、name 为 None + let mut accumulator = StreamingAccumulator::new(); + + // 第一个 chunk:包含完整的 id 和 name + accumulator.add_tool_call(0, Some("call_abc123"), Some("memory_search"), Some("{\"action\":\"")); + // 第二个 chunk:只有参数增量 + accumulator.add_tool_call(0, None, None, Some("list")); + // 第三个 chunk:参数继续 + accumulator.add_tool_call(0, None, None, Some("\"")); + // 第四个 chunk:id 为空字符串(某些云服务商的行为) + accumulator.add_tool_call(0, Some(""), None, Some(", \"limit\": 20")); + // 最后一个 chunk:name 为 None + accumulator.add_tool_call(0, None, None, Some("}")); + + let response = accumulator.build_response("test-model".to_string()); + + // 验证工具调用被正确保留,id 没有被空字符串覆盖 + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_calls[0].id, "call_abc123"); + assert_eq!(response.tool_calls[0].name, "memory_search"); + assert_eq!(response.tool_calls[0].arguments, json!({"action":"list", "limit": 20})); + } + + #[test] + fn test_streaming_accumulator_handles_multiple_tool_calls() { + let mut accumulator = StreamingAccumulator::new(); + + // 第一个工具调用 + accumulator.add_tool_call(0, Some("call_1"), Some("calculator"), Some("{\"expr\": \"1+1\"}")); + // 第二个工具调用(id 和 name 只在第一个 chunk 出现) + accumulator.add_tool_call(1, Some("call_2"), Some("get_time"), Some("{}")); + + let response = accumulator.build_response("test-model".to_string()); + + assert_eq!(response.tool_calls.len(), 2); + assert_eq!(response.tool_calls[0].id, "call_1"); + assert_eq!(response.tool_calls[0].name, "calculator"); + assert_eq!(response.tool_calls[1].id, "call_2"); + assert_eq!(response.tool_calls[1].name, "get_time"); + } }