From 9e17cd35daf327fe9a08085c315d2f7f5d5501c5 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Mon, 27 Apr 2026 09:21:50 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BAOpenAIProvider?= =?UTF-8?q?=E4=BB=A5=E6=94=AF=E6=8C=81=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=9A=84JSON=E6=A0=BC=E5=BC=8F=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=AF=B7=E6=B1=82=E4=BD=93=E6=9E=84=E5=BB=BA?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E6=8E=92=E9=99=A4=E5=86=85=E9=83=A8?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E9=A2=9D=E5=A4=96=E9=94=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/providers/openai.rs | 111 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 8 deletions(-) diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 3549e91..21a2d55 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -9,6 +9,11 @@ use crate::bus::message::ContentBlock; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::traits::Usage; +const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[ + "tool_call_arguments_json", + "mock_response_content", +]; + fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { if blocks.len() == 1 { if let ContentBlock::Text { text } = &blocks[0] { @@ -81,16 +86,34 @@ impl OpenAIProvider { .unwrap_or(false) } - fn serialize_tool_arguments(&self, arguments: &Value) -> Value { - if self.uses_json_tool_arguments() { - arguments.clone() - } else { - Value::String( - serde_json::to_string(arguments).unwrap_or_else(|_| "null".to_string()), - ) + fn normalize_tool_arguments(&self, arguments: &Value) -> Value { + match arguments { + Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()), + _ => arguments.clone(), } } + fn serialize_tool_arguments(&self, arguments: &Value) -> Value { + let normalized = self.normalize_tool_arguments(arguments); + + if self.uses_json_tool_arguments() { + normalized + } else { + match normalized { + Value::String(raw) => Value::String(raw), + value => Value::String( + serde_json::to_string(&value).unwrap_or_else(|_| "null".to_string()), + ), + } + } + } + + fn request_model_extra(&self) -> impl Iterator { + self.model_extra.iter().filter(|(key, _)| { + !INTERNAL_MODEL_EXTRA_KEYS.iter().any(|internal| internal == &key.as_str()) + }) + } + fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { let mut body = json!({ "model": self.model_id, @@ -142,7 +165,7 @@ impl OpenAIProvider { "max_tokens": request.max_tokens.or(self.max_tokens), }); - for (key, value) in &self.model_extra { + for (key, value) in self.request_model_extra() { body[key] = value.clone(); } @@ -408,6 +431,78 @@ mod tests { let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"})); + assert!(body.get("tool_call_arguments_json").is_none()); + } + + #[test] + fn test_build_request_body_preserves_raw_json_string_arguments() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::new(), + ); + + let request = ChatCompletionRequest { + messages: vec![Message { + role: "assistant".to_string(), + content: vec![ContentBlock::text("calling tool")], + reasoning_content: None, + tool_call_id: None, + name: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "calculator".to_string(), + arguments: Value::String("{\"expression\":\"1+1\"}".to_string()), + }]), + }], + temperature: None, + max_tokens: None, + tools: None, + }; + + let body = provider.build_request_body(&request); + let messages = body["messages"].as_array().unwrap(); + let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); + + assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); + } + + #[test] + fn test_build_request_body_omits_internal_model_extra_keys() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::from([ + ("tool_call_arguments_json".to_string(), Value::Bool(true)), + ("mock_response_content".to_string(), Value::String("stub".to_string())), + ("parallel_tool_calls".to_string(), Value::Bool(true)), + ]), + ); + + let request = ChatCompletionRequest { + messages: vec![Message::user("hello")], + temperature: None, + max_tokens: None, + tools: None, + }; + + let body = provider.build_request_body(&request); + + assert!(body.get("tool_call_arguments_json").is_none()); + assert!(body.get("mock_response_content").is_none()); + assert_eq!(body["parallel_tool_calls"], Value::Bool(true)); } #[test]