use async_trait::async_trait; use reqwest::Client; use serde::Deserialize; use serde_json::{json, Value}; use std::collections::HashMap; use crate::bus::message::ContentBlock; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::traits::Usage; fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { if blocks.len() == 1 { if let ContentBlock::Text { text } = &blocks[0] { return Value::String(text.clone()); } } Value::Array(blocks.iter().map(|b| match b { ContentBlock::Text { text } => json!({ "type": "text", "text": text }), ContentBlock::ImageUrl { image_url } => { json!({ "type": "image_url", "image_url": { "url": image_url.url } }) } }).collect()) } pub struct OpenAIProvider { client: Client, name: String, api_key: String, base_url: String, extra_headers: HashMap, model_id: String, temperature: Option, max_tokens: Option, model_extra: HashMap, } impl OpenAIProvider { pub fn new( name: String, api_key: String, base_url: String, extra_headers: HashMap, model_id: String, temperature: Option, max_tokens: Option, model_extra: HashMap, ) -> Self { Self { client: Client::new(), name, api_key, base_url, extra_headers, model_id, temperature, max_tokens, model_extra, } } fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { let mut body = json!({ "model": self.model_id, "messages": request.messages.iter().map(|m| { if m.role == "tool" { json!({ "role": m.role, "content": convert_content_blocks(&m.content), "tool_call_id": m.tool_call_id, "name": m.name, }) } else if m.role == "assistant" && m.tool_calls.is_some() { json!({ "role": m.role, "content": convert_content_blocks(&m.content), "tool_calls": m.tool_calls.as_ref().map(|calls| { calls.iter().map(|call| json!({ "id": call.id, "type": "function", "function": { "name": call.name, "arguments": serde_json::to_string(&call.arguments).unwrap_or_else(|_| "null".to_string()) } })).collect::>() }) }) } else { json!({ "role": m.role, "content": convert_content_blocks(&m.content) }) } }).collect::>(), "temperature": request.temperature.or(self.temperature).unwrap_or(0.7), "max_tokens": request.max_tokens.or(self.max_tokens), }); for (key, value) in &self.model_extra { body[key] = value.clone(); } if let Some(tools) = &request.tools { body["tools"] = json!(tools); } body } } #[derive(Deserialize)] struct OpenAIResponse { id: String, model: String, choices: Vec, #[serde(default)] usage: OpenAIUsage, } #[derive(Deserialize)] struct OpenAIChoice { message: OpenAIMessage, } #[derive(Deserialize)] struct OpenAIMessage { #[serde(default)] content: Option, #[serde(default)] name: Option, #[serde(default)] tool_calls: Vec, } #[derive(Deserialize)] struct OpenAIToolCall { id: String, #[serde(rename = "function")] function: OAIFunction, #[serde(default)] index: Option, } #[derive(Deserialize)] struct OAIFunction { name: String, arguments: String, } #[derive(Deserialize, Default)] struct OpenAIUsage { #[serde(default)] prompt_tokens: u32, #[serde(default)] completion_tokens: u32, #[serde(default)] total_tokens: u32, } #[async_trait] impl LLMProvider for OpenAIProvider { async fn chat( &self, request: ChatCompletionRequest, ) -> Result> { let url = format!("{}/chat/completions", self.base_url); let body = self.build_request_body(&request); // Debug: Log LLM request summary (only in debug builds) #[cfg(debug_assertions)] { // Log messages summary let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0); tracing::debug!(msg_count = msg_count, "LLM request messages count"); // Log first 20 bytes of base64 images (don't log full base64) if let Some(msgs) = body["messages"].as_array() { for (i, msg) in msgs.iter().enumerate() { if let Some(content) = msg.get("content").and_then(|c| c.as_array()) { for (j, item) in content.iter().enumerate() { if item.get("type").and_then(|t| t.as_str()) == Some("image_url") { if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) { let prefix: String = url_str.chars().take(20).collect(); tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)"); } } } } } } } let mut req_builder = self .client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json"); for (key, value) in &self.extra_headers { req_builder = req_builder.header(key.as_str(), value.as_str()); } let resp = req_builder.json(&body).send().await?; let status = resp.status(); let text = resp.text().await?; // Debug: Log LLM response (only in debug builds) #[cfg(debug_assertions)] { let resp_preview: String = text.chars().take(100).collect(); tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)"); } if !status.is_success() { return Err(format!("API error {}: {}", status, text).into()); } let openai_resp: OpenAIResponse = serde_json::from_str(&text) .map_err(|e| format!("decode error: {} | body: {}", e, &text))?; let content = openai_resp.choices[0] .message .content .as_ref() .unwrap_or(&String::new()) .clone(); let tool_calls: Vec = openai_resp.choices[0] .message .tool_calls .iter() .map(|tc| ToolCall { id: tc.id.clone(), name: tc.function.name.clone(), arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null), }) .collect(); Ok(ChatCompletionResponse { id: openai_resp.id, model: openai_resp.model, content, tool_calls, usage: Usage { prompt_tokens: openai_resp.usage.prompt_tokens, completion_tokens: openai_resp.usage.completion_tokens, total_tokens: openai_resp.usage.total_tokens, }, }) } fn ptype(&self) -> &str { "openai" } fn name(&self) -> &str { &self.name } fn model_id(&self) -> &str { &self.model_id } } #[cfg(test)] mod tests { use super::*; use crate::providers::Message; #[test] fn test_build_request_body_includes_assistant_tool_calls() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), "gpt-test".to_string(), None, None, HashMap::new(), ); let request = ChatCompletionRequest { messages: vec![Message { role: "assistant".to_string(), content: vec![ContentBlock::text("calling tool")], tool_call_id: None, name: None, tool_calls: Some(vec![ToolCall { id: "call_1".to_string(), name: "calculator".to_string(), arguments: json!({"expression": "1+1"}), }]), }], temperature: None, max_tokens: None, tools: None, }; let body = provider.build_request_body(&request); let messages = body["messages"].as_array().unwrap(); let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0]["id"], "call_1"); assert_eq!(tool_calls[0]["type"], "function"); assert_eq!(tool_calls[0]["function"]["name"], "calculator"); assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); } }