diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 3549e91..21a2d55 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -9,6 +9,11 @@ use crate::bus::message::ContentBlock; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::traits::Usage; +const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[ + "tool_call_arguments_json", + "mock_response_content", +]; + fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { if blocks.len() == 1 { if let ContentBlock::Text { text } = &blocks[0] { @@ -81,16 +86,34 @@ impl OpenAIProvider { .unwrap_or(false) } - fn serialize_tool_arguments(&self, arguments: &Value) -> Value { - if self.uses_json_tool_arguments() { - arguments.clone() - } else { - Value::String( - serde_json::to_string(arguments).unwrap_or_else(|_| "null".to_string()), - ) + fn normalize_tool_arguments(&self, arguments: &Value) -> Value { + match arguments { + Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()), + _ => arguments.clone(), } } + fn serialize_tool_arguments(&self, arguments: &Value) -> Value { + let normalized = self.normalize_tool_arguments(arguments); + + if self.uses_json_tool_arguments() { + normalized + } else { + match normalized { + Value::String(raw) => Value::String(raw), + value => Value::String( + serde_json::to_string(&value).unwrap_or_else(|_| "null".to_string()), + ), + } + } + } + + fn request_model_extra(&self) -> impl Iterator { + self.model_extra.iter().filter(|(key, _)| { + !INTERNAL_MODEL_EXTRA_KEYS.iter().any(|internal| internal == &key.as_str()) + }) + } + fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { let mut body = json!({ "model": self.model_id, @@ -142,7 +165,7 @@ impl OpenAIProvider { "max_tokens": request.max_tokens.or(self.max_tokens), }); - for (key, value) in &self.model_extra { + for (key, value) in self.request_model_extra() { body[key] = value.clone(); } @@ -408,6 +431,78 @@ mod tests { let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"})); + assert!(body.get("tool_call_arguments_json").is_none()); + } + + #[test] + fn test_build_request_body_preserves_raw_json_string_arguments() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::new(), + ); + + let request = ChatCompletionRequest { + messages: vec![Message { + role: "assistant".to_string(), + content: vec![ContentBlock::text("calling tool")], + reasoning_content: None, + tool_call_id: None, + name: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "calculator".to_string(), + arguments: Value::String("{\"expression\":\"1+1\"}".to_string()), + }]), + }], + temperature: None, + max_tokens: None, + tools: None, + }; + + let body = provider.build_request_body(&request); + let messages = body["messages"].as_array().unwrap(); + let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); + + assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); + } + + #[test] + fn test_build_request_body_omits_internal_model_extra_keys() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::from([ + ("tool_call_arguments_json".to_string(), Value::Bool(true)), + ("mock_response_content".to_string(), Value::String("stub".to_string())), + ("parallel_tool_calls".to_string(), Value::Bool(true)), + ]), + ); + + let request = ChatCompletionRequest { + messages: vec![Message::user("hello")], + temperature: None, + max_tokens: None, + tools: None, + }; + + let body = provider.build_request_body(&request); + + assert!(body.get("tool_call_arguments_json").is_none()); + assert!(body.get("mock_response_content").is_none()); + assert_eq!(body["parallel_tool_calls"], Value::Bool(true)); } #[test]