feat: 优化 StreamingAccumulator 中工具调用的 ID 和名称更新逻辑,确保空值不会覆盖有效数据,并添加相关单元测试
This commit is contained in:
parent
1c6ee160e5
commit
9adaa93ecc
@ -51,12 +51,18 @@ impl StreamingAccumulator {
|
|||||||
fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) {
|
fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) {
|
||||||
let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default);
|
let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default);
|
||||||
|
|
||||||
|
// 只在 id 非空时才更新,防止流式响应中后续 chunk 的空 id 覆盖之前的值
|
||||||
if let Some(id) = id {
|
if let Some(id) = id {
|
||||||
|
if !id.is_empty() {
|
||||||
entry.id = id.to_string();
|
entry.id = id.to_string();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
// 只在 name 非空时才更新,防止流式响应中后续 chunk 的 None 覆盖之前的值
|
||||||
if let Some(name) = name {
|
if let Some(name) = name {
|
||||||
|
if !name.is_empty() {
|
||||||
entry.name = name.to_string();
|
entry.name = name.to_string();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if let Some(args) = arguments {
|
if let Some(args) = arguments {
|
||||||
entry.arguments.push_str(args);
|
entry.arguments.push_str(args);
|
||||||
}
|
}
|
||||||
@ -1006,4 +1012,48 @@ mod tests {
|
|||||||
OAIFunctionArguments::String(_) => panic!("expected JSON tool arguments"),
|
OAIFunctionArguments::String(_) => panic!("expected JSON tool arguments"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_streaming_accumulator_preserves_tool_call_id_with_empty_subsequent_chunks() {
|
||||||
|
// 模拟阿里云等云服务商的流式响应行为:
|
||||||
|
// 第一个 chunk 包含 id 和 name,后续 chunk 的 id 为空字符串、name 为 None
|
||||||
|
let mut accumulator = StreamingAccumulator::new();
|
||||||
|
|
||||||
|
// 第一个 chunk:包含完整的 id 和 name
|
||||||
|
accumulator.add_tool_call(0, Some("call_abc123"), Some("memory_search"), Some("{\"action\":\""));
|
||||||
|
// 第二个 chunk:只有参数增量
|
||||||
|
accumulator.add_tool_call(0, None, None, Some("list"));
|
||||||
|
// 第三个 chunk:参数继续
|
||||||
|
accumulator.add_tool_call(0, None, None, Some("\""));
|
||||||
|
// 第四个 chunk:id 为空字符串(某些云服务商的行为)
|
||||||
|
accumulator.add_tool_call(0, Some(""), None, Some(", \"limit\": 20"));
|
||||||
|
// 最后一个 chunk:name 为 None
|
||||||
|
accumulator.add_tool_call(0, None, None, Some("}"));
|
||||||
|
|
||||||
|
let response = accumulator.build_response("test-model".to_string());
|
||||||
|
|
||||||
|
// 验证工具调用被正确保留,id 没有被空字符串覆盖
|
||||||
|
assert_eq!(response.tool_calls.len(), 1);
|
||||||
|
assert_eq!(response.tool_calls[0].id, "call_abc123");
|
||||||
|
assert_eq!(response.tool_calls[0].name, "memory_search");
|
||||||
|
assert_eq!(response.tool_calls[0].arguments, json!({"action":"list", "limit": 20}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_streaming_accumulator_handles_multiple_tool_calls() {
|
||||||
|
let mut accumulator = StreamingAccumulator::new();
|
||||||
|
|
||||||
|
// 第一个工具调用
|
||||||
|
accumulator.add_tool_call(0, Some("call_1"), Some("calculator"), Some("{\"expr\": \"1+1\"}"));
|
||||||
|
// 第二个工具调用(id 和 name 只在第一个 chunk 出现)
|
||||||
|
accumulator.add_tool_call(1, Some("call_2"), Some("get_time"), Some("{}"));
|
||||||
|
|
||||||
|
let response = accumulator.build_response("test-model".to_string());
|
||||||
|
|
||||||
|
assert_eq!(response.tool_calls.len(), 2);
|
||||||
|
assert_eq!(response.tool_calls[0].id, "call_1");
|
||||||
|
assert_eq!(response.tool_calls[0].name, "calculator");
|
||||||
|
assert_eq!(response.tool_calls[1].id, "call_2");
|
||||||
|
assert_eq!(response.tool_calls[1].name, "get_time");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user