feat: 增强OpenAIProvider以支持工具调用参数的JSON格式,优化请求体构建逻辑,排除内部模型额外键
This commit is contained in:
parent
60cc8e507c
commit
9e17cd35da
@ -9,6 +9,11 @@ use crate::bus::message::ContentBlock;
|
|||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
|
|
||||||
|
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[
|
||||||
|
"tool_call_arguments_json",
|
||||||
|
"mock_response_content",
|
||||||
|
];
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||||
if blocks.len() == 1 {
|
if blocks.len() == 1 {
|
||||||
if let ContentBlock::Text { text } = &blocks[0] {
|
if let ContentBlock::Text { text } = &blocks[0] {
|
||||||
@ -81,16 +86,34 @@ impl OpenAIProvider {
|
|||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn serialize_tool_arguments(&self, arguments: &Value) -> Value {
|
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||||
if self.uses_json_tool_arguments() {
|
match arguments {
|
||||||
arguments.clone()
|
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
||||||
} else {
|
_ => arguments.clone(),
|
||||||
Value::String(
|
|
||||||
serde_json::to_string(arguments).unwrap_or_else(|_| "null".to_string()),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn serialize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||||
|
let normalized = self.normalize_tool_arguments(arguments);
|
||||||
|
|
||||||
|
if self.uses_json_tool_arguments() {
|
||||||
|
normalized
|
||||||
|
} else {
|
||||||
|
match normalized {
|
||||||
|
Value::String(raw) => Value::String(raw),
|
||||||
|
value => Value::String(
|
||||||
|
serde_json::to_string(&value).unwrap_or_else(|_| "null".to_string()),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn request_model_extra(&self) -> impl Iterator<Item = (&String, &Value)> {
|
||||||
|
self.model_extra.iter().filter(|(key, _)| {
|
||||||
|
!INTERNAL_MODEL_EXTRA_KEYS.iter().any(|internal| internal == &key.as_str())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
@ -142,7 +165,7 @@ impl OpenAIProvider {
|
|||||||
"max_tokens": request.max_tokens.or(self.max_tokens),
|
"max_tokens": request.max_tokens.or(self.max_tokens),
|
||||||
});
|
});
|
||||||
|
|
||||||
for (key, value) in &self.model_extra {
|
for (key, value) in self.request_model_extra() {
|
||||||
body[key] = value.clone();
|
body[key] = value.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -408,6 +431,78 @@ mod tests {
|
|||||||
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||||
|
|
||||||
assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"}));
|
assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"}));
|
||||||
|
assert!(body.get("tool_call_arguments_json").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_request_body_preserves_raw_json_string_arguments() {
|
||||||
|
let provider = OpenAIProvider::new(
|
||||||
|
"test".to_string(),
|
||||||
|
"key".to_string(),
|
||||||
|
"https://example.com/v1".to_string(),
|
||||||
|
HashMap::new(),
|
||||||
|
120,
|
||||||
|
"gpt-test".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
HashMap::new(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![ContentBlock::text("calling tool")],
|
||||||
|
reasoning_content: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: Some(vec![ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: Value::String("{\"expression\":\"1+1\"}".to_string()),
|
||||||
|
}]),
|
||||||
|
}],
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: None,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let body = provider.build_request_body(&request);
|
||||||
|
let messages = body["messages"].as_array().unwrap();
|
||||||
|
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_request_body_omits_internal_model_extra_keys() {
|
||||||
|
let provider = OpenAIProvider::new(
|
||||||
|
"test".to_string(),
|
||||||
|
"key".to_string(),
|
||||||
|
"https://example.com/v1".to_string(),
|
||||||
|
HashMap::new(),
|
||||||
|
120,
|
||||||
|
"gpt-test".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
HashMap::from([
|
||||||
|
("tool_call_arguments_json".to_string(), Value::Bool(true)),
|
||||||
|
("mock_response_content".to_string(), Value::String("stub".to_string())),
|
||||||
|
("parallel_tool_calls".to_string(), Value::Bool(true)),
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message::user("hello")],
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: None,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let body = provider.build_request_body(&request);
|
||||||
|
|
||||||
|
assert!(body.get("tool_call_arguments_json").is_none());
|
||||||
|
assert!(body.get("mock_response_content").is_none());
|
||||||
|
assert_eq!(body["parallel_tool_calls"], Value::Bool(true));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user