diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 04baac6..45e12d2 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -51,11 +51,17 @@ impl StreamingAccumulator { fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) { let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default); + // 只在 id 非空时才更新,防止流式响应中后续 chunk 的空 id 覆盖之前的值 if let Some(id) = id { - entry.id = id.to_string(); + if !id.is_empty() { + entry.id = id.to_string(); + } } + // 只在 name 非空时才更新,防止流式响应中后续 chunk 的 None 覆盖之前的值 if let Some(name) = name { - entry.name = name.to_string(); + if !name.is_empty() { + entry.name = name.to_string(); + } } if let Some(args) = arguments { entry.arguments.push_str(args); @@ -1006,4 +1012,48 @@ mod tests { OAIFunctionArguments::String(_) => panic!("expected JSON tool arguments"), } } + + #[test] + fn test_streaming_accumulator_preserves_tool_call_id_with_empty_subsequent_chunks() { + // 模拟阿里云等云服务商的流式响应行为: + // 第一个 chunk 包含 id 和 name,后续 chunk 的 id 为空字符串、name 为 None + let mut accumulator = StreamingAccumulator::new(); + + // 第一个 chunk:包含完整的 id 和 name + accumulator.add_tool_call(0, Some("call_abc123"), Some("memory_search"), Some("{\"action\":\"")); + // 第二个 chunk:只有参数增量 + accumulator.add_tool_call(0, None, None, Some("list")); + // 第三个 chunk:参数继续 + accumulator.add_tool_call(0, None, None, Some("\"")); + // 第四个 chunk:id 为空字符串(某些云服务商的行为) + accumulator.add_tool_call(0, Some(""), None, Some(", \"limit\": 20")); + // 最后一个 chunk:name 为 None + accumulator.add_tool_call(0, None, None, Some("}")); + + let response = accumulator.build_response("test-model".to_string()); + + // 验证工具调用被正确保留,id 没有被空字符串覆盖 + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_calls[0].id, "call_abc123"); + assert_eq!(response.tool_calls[0].name, "memory_search"); + assert_eq!(response.tool_calls[0].arguments, json!({"action":"list", "limit": 20})); + } + + #[test] + fn test_streaming_accumulator_handles_multiple_tool_calls() { + let mut accumulator = StreamingAccumulator::new(); + + // 第一个工具调用 + accumulator.add_tool_call(0, Some("call_1"), Some("calculator"), Some("{\"expr\": \"1+1\"}")); + // 第二个工具调用(id 和 name 只在第一个 chunk 出现) + accumulator.add_tool_call(1, Some("call_2"), Some("get_time"), Some("{}")); + + let response = accumulator.build_response("test-model".to_string()); + + assert_eq!(response.tool_calls.len(), 2); + assert_eq!(response.tool_calls[0].id, "call_1"); + assert_eq!(response.tool_calls[0].name, "calculator"); + assert_eq!(response.tool_calls[1].id, "call_2"); + assert_eq!(response.tool_calls[1].name, "get_time"); + } }