use async_trait::async_trait; use futures_util::StreamExt; use reqwest::Client; use serde::Deserialize; use serde_json::{Value, json}; use std::collections::HashMap; use std::time::Duration; use super::traits::Usage; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use crate::domain::messages::ContentBlock; const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content", "supported_content_types"]; /// 流式响应中的工具调用增量 #[derive(Debug, Default)] struct StreamingToolCall { id: String, name: String, arguments: String, } /// 流式响应累积器 #[derive(Debug, Default)] struct StreamingAccumulator { content: String, reasoning_content: Option, tool_calls: HashMap, response_id: String, } impl StreamingAccumulator { fn new() -> Self { Self::default() } /// 添加内容增量 fn add_content(&mut self, delta: &str) { self.content.push_str(delta); } /// 添加推理内容增量 fn add_reasoning_content(&mut self, delta: &str) { if self.reasoning_content.is_none() { self.reasoning_content = Some(String::new()); } self.reasoning_content.as_mut().unwrap().push_str(delta); } /// 添加工具调用增量 fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) { let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default); // 只在 id 非空时才更新,防止流式响应中后续 chunk 的空 id 覆盖之前的值 if let Some(id) = id { if !id.is_empty() { entry.id = id.to_string(); } } // 只在 name 非空时才更新,防止流式响应中后续 chunk 的 None 覆盖之前的值 if let Some(name) = name { if !name.is_empty() { entry.name = name.to_string(); } } if let Some(args) = arguments { entry.arguments.push_str(args); } } /// 设置响应 ID fn set_response_id(&mut self, id: String) { if self.response_id.is_empty() { self.response_id = id; } } /// 构建最终的 ChatCompletionResponse fn build_response(self, model: String) -> ChatCompletionResponse { let tool_calls: Vec = self.tool_calls .into_iter() .filter(|(_, call)| !call.id.is_empty() && !call.name.is_empty()) .map(|(_, call)| { let arguments = serde_json::from_str(&call.arguments) .unwrap_or_else(|_| serde_json::Value::Null); ToolCall { id: call.id, name: call.name, arguments, } }) .collect(); ChatCompletionResponse { id: if self.response_id.is_empty() { format!("stream-{}", uuid::Uuid::new_v4()) } else { self.response_id }, model, content: self.content, reasoning_content: self.reasoning_content, tool_calls, usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, }, } } } fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; let mut current = error.source(); while let Some(source) = current { details.push(source.to_string()); current = source.source(); } details.join("\ncaused by: ") } fn format_transport_error_context( provider_name: &str, model_id: &str, url: &str, timeout_secs: u64, error: &(dyn std::error::Error + 'static), ) -> String { format!( "transport error: provider={} model={} url={} timeout_secs={} details={}", provider_name, model_id, url, timeout_secs, format_error_chain(error) ) } fn convert_content_blocks( supports_images: bool, provider_name: &str, model_id: &str, blocks: &[ContentBlock], message_idx: usize, ) -> Value { // 检查是否有图片且模型不支持 if !supports_images { let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. })); if has_images { let image_count = blocks.iter() .filter(|b| matches!(b, ContentBlock::ImageUrl { .. })) .count(); tracing::warn!( provider = %provider_name, model = %model_id, filtered_images = image_count, message_idx, "模型不支持图片;将图片转换为通知文本" ); // 复用通知格式,将图片转换为文本通知 let mut converted_blocks: Vec = Vec::new(); let mut notices: Vec = Vec::new(); let mut image_idx = 0; for block in blocks.iter() { match block { ContentBlock::Text { text } => { converted_blocks.push(json!({ "type": "text", "text": text })); } ContentBlock::ImageUrl { .. } => { image_idx += 1; notices.push(format!( "- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。", image_idx )); } } } // 添加通知文本块 if !notices.is_empty() { let notice_text = format!( "[系统提示] 以下图片未能成功入模:\n{}", notices.join("\n") ); converted_blocks.push(json!({ "type": "text", "text": notice_text })); } // 如果只有一个文本块且没有通知,返回字符串形式 if converted_blocks.len() == 1 { if let Some(block) = converted_blocks.first() { if block.get("type").and_then(|t| t.as_str()) == Some("text") { if let Some(text) = block.get("text").and_then(|t| t.as_str()) { return Value::String(text.to_string()); } } } } return Value::Array(converted_blocks); } } // 原有逻辑 - 模型支持图片,正常转换 if blocks.len() == 1 { if let ContentBlock::Text { text } = &blocks[0] { return Value::String(text.clone()); } } Value::Array( blocks .iter() .map(|b| match b { ContentBlock::Text { text } => json!({ "type": "text", "text": text }), ContentBlock::ImageUrl { image_url } => { json!({ "type": "image_url", "image_url": { "url": image_url.url } }) } }) .collect(), ) } pub struct OpenAIProvider { client: Client, name: String, api_key: String, base_url: String, extra_headers: HashMap, llm_timeout_secs: u64, model_id: String, temperature: Option, max_tokens: Option, model_extra: HashMap, } #[derive(Deserialize)] #[serde(untagged)] enum OAIFunctionArguments { Json(Value), String(String), } impl OpenAIProvider { pub fn new( name: String, api_key: String, base_url: String, extra_headers: HashMap, llm_timeout_secs: u64, model_id: String, temperature: Option, max_tokens: Option, model_extra: HashMap, ) -> Self { let client = Client::builder() .timeout(Duration::from_secs(llm_timeout_secs)) .build() .unwrap_or_else(|_| Client::new()); Self { client, name, api_key, base_url, extra_headers, llm_timeout_secs, model_id, temperature, max_tokens, model_extra, } } fn uses_json_tool_arguments(&self) -> bool { self.model_extra .get("tool_call_arguments_json") .and_then(|value| value.as_bool()) .unwrap_or(false) } /// 检查是否启用流式输出,默认启用 fn is_streaming_enabled(&self) -> bool { self.model_extra .get("enable_streaming") .and_then(|value| value.as_bool()) .unwrap_or(true) } /// 检查模型是否支持指定内容类型 /// 默认支持所有类型(text, image) fn supports_content_type(&self, content_type: &str) -> bool { self.model_extra .get("supported_content_types") .and_then(|value| value.as_array()) .map(|types| { types.iter().any(|t| t.as_str() == Some(content_type)) }) .unwrap_or(true) } /// 检查模型是否支持图片 fn supports_images(&self) -> bool { self.supports_content_type("image") } fn normalize_tool_arguments(&self, arguments: &Value) -> Value { match arguments { Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()), _ => arguments.clone(), } } fn serialize_tool_arguments(&self, arguments: &Value) -> Value { let normalized = self.normalize_tool_arguments(arguments); if self.uses_json_tool_arguments() { // Model expects JSON object format (e.g., some code models) normalized } else { // Standard OpenAI format: arguments as JSON string // But ensure we serialize valid JSON, not null match normalized { Value::Null => Value::String("{}".to_string()), Value::String(raw) => { // If the string is already valid JSON, keep it as-is // Otherwise, ensure it's a proper JSON string if serde_json::from_str::(&raw).is_ok() { Value::String(raw) } else { // Invalid JSON string - wrap it as a proper JSON string Value::String(serde_json::to_string(&raw).unwrap_or_else(|_| "null".to_string())) } } value => Value::String( serde_json::to_string(&value).unwrap_or_else(|_| "{}".to_string()), ), } } } fn request_model_extra(&self) -> impl Iterator { self.model_extra.iter().filter(|(key, _)| { !INTERNAL_MODEL_EXTRA_KEYS .iter() .any(|internal| internal == &key.as_str()) }) } /// 内部流式聊天实现 async fn chat_streaming( &self, request: &ChatCompletionRequest, ) -> Result> { tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat"); let url = format!("{}/chat/completions", self.base_url); let mut body = self.build_request_body(request); // 启用流式输出 body["stream"] = json!(true); let mut req_builder = self .client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") .header("Accept", "text/event-stream"); for (key, value) in &self.extra_headers { req_builder = req_builder.header(key.as_str(), value.as_str()); } let resp = req_builder.json(&body).send().await.map_err(|err| { format_transport_error_context( &self.name, &self.model_id, &url, self.llm_timeout_secs, &err, ) })?; let status = resp.status(); if !status.is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("API error {}: {}", status, text).into()); } let mut accumulator = StreamingAccumulator::new(); // 读取 SSE 流 let mut stream = resp.bytes_stream(); let mut buffer = String::new(); let mut done_received = false; while let Some(chunk_result) = stream.next().await { let chunk = chunk_result?; let text = String::from_utf8_lossy(&chunk); buffer.push_str(&text); // 处理缓冲区中的完整行 while let Some(newline_pos) = buffer.find('\n') { let line = buffer[..newline_pos].to_string(); buffer = buffer[newline_pos + 1..].to_string(); let line_trimmed = line.trim(); if line_trimmed.is_empty() || line_trimmed.starts_with(':') { continue; } // SSE 格式: data: {...} 或 data:{...}(某些 API 如 139 云没有空格) let data_opt = line_trimmed.strip_prefix("data: ") .or_else(|| line_trimmed.strip_prefix("data:")); if let Some(data) = data_opt { if data == "[DONE]" { // 流结束 done_received = true; break; } // 解析 JSON match serde_json::from_str::(data) { Ok(json) => { // 提取响应 ID if let Some(id) = json.get("id").and_then(|v| v.as_str()) { accumulator.set_response_id(id.to_string()); } // 提取 choices if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { for choice in choices { // 尝试从 delta 提取(标准 OpenAI 流式格式) if let Some(delta) = choice.get("delta") { // 提取内容增量 if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { accumulator.add_content(content); } // 提取推理内容增量 if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) { accumulator.add_reasoning_content(reasoning); } // 提取工具调用增量 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) { for tool_call in tool_calls { let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize; let id = tool_call.get("id").and_then(|v| v.as_str()); let name = tool_call.get("function") .and_then(|f| f.get("name")) .and_then(|n| n.as_str()); let arguments = tool_call.get("function") .and_then(|f| f.get("arguments")) .and_then(|a| a.as_str()); accumulator.add_tool_call(index, id, name, arguments); } } } // 尝试从 message 提取(某些非标准 API 格式) else if let Some(message) = choice.get("message") { if let Some(content) = message.get("content").and_then(|c| c.as_str()) { accumulator.add_content(content); } if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) { accumulator.add_reasoning_content(reasoning); } } } } } Err(e) => { tracing::debug!( error = %e, data = %data, "Failed to parse SSE data" ); } } } } if done_received { break; } } // 处理缓冲区中剩余的内容 for line in buffer.lines() { let line_trimmed = line.trim(); if line_trimmed.is_empty() || line_trimmed.starts_with(':') { continue; } // 同样支持 data: {...} 和 data:{...} 两种格式 let data_opt = line_trimmed.strip_prefix("data: ") .or_else(|| line_trimmed.strip_prefix("data:")); if let Some(data) = data_opt { if data == "[DONE]" { break; } if let Ok(json) = serde_json::from_str::(data) { if let Some(id) = json.get("id").and_then(|v| v.as_str()) { accumulator.set_response_id(id.to_string()); } if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { for choice in choices { // 尝试从 delta 提取 if let Some(delta) = choice.get("delta") { if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { accumulator.add_content(content); } if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) { accumulator.add_reasoning_content(reasoning); } if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) { for tool_call in tool_calls { let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize; let id = tool_call.get("id").and_then(|v| v.as_str()); let name = tool_call.get("function") .and_then(|f| f.get("name")) .and_then(|n| n.as_str()); let arguments = tool_call.get("function") .and_then(|f| f.get("arguments")) .and_then(|a| a.as_str()); accumulator.add_tool_call(index, id, name, arguments); } } } // 尝试从 message 提取(某些非标准 API 格式) else if let Some(message) = choice.get("message") { if let Some(content) = message.get("content").and_then(|c| c.as_str()) { accumulator.add_content(content); } if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) { accumulator.add_reasoning_content(reasoning); } } } } } } } let response = accumulator.build_response(self.model_id.clone()); tracing::debug!( content_len = response.content.len(), tool_calls_count = response.tool_calls.len(), has_reasoning = response.reasoning_content.is_some(), "Streaming response built" ); Ok(response) } fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { let supports_images = self.supports_images(); let mut body = json!({ "model": self.model_id, "messages": request.messages.iter().enumerate().map(|(i, m)| { if m.role == "tool" { json!({ "role": m.role, "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i), "tool_call_id": m.tool_call_id, "name": m.name, }) } else if m.role == "assistant" && m.tool_calls.is_some() { let mut message = json!({ "role": m.role, "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i), "tool_calls": m.tool_calls.as_ref().map(|calls| { calls.iter().map(|call| json!({ "id": call.id, "type": "function", "function": { "name": call.name, "arguments": self.serialize_tool_arguments(&call.arguments) } })).collect::>() }) }); if let Some(reasoning_content) = &m.reasoning_content { message["reasoning_content"] = Value::String(reasoning_content.clone()); } message } else { let mut message = json!({ "role": m.role, "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i) }); if m.role == "assistant" { if let Some(reasoning_content) = &m.reasoning_content { message["reasoning_content"] = Value::String(reasoning_content.clone()); } } message } }).collect::>(), }); // 只有配置了才添加 temperature,否则让模型使用默认值 if let Some(temp) = request.temperature.or(self.temperature) { body["temperature"] = json!(temp); } // 只有配置了才添加 max_tokens if let Some(tokens) = request.max_tokens.or(self.max_tokens) { body["max_tokens"] = json!(tokens); } for (key, value) in self.request_model_extra() { body[key] = value.clone(); } if let Some(tools) = &request.tools { body["tools"] = json!(tools); } body } } #[derive(Deserialize)] struct OpenAIResponse { id: String, model: String, choices: Vec, #[serde(default)] usage: OpenAIUsage, } #[derive(Deserialize)] struct OpenAIChoice { message: OpenAIMessage, } #[derive(Deserialize)] struct OpenAIMessage { #[serde(default)] content: Option, #[serde(default)] reasoning_content: Option, #[allow(dead_code)] #[serde(default)] name: Option, #[serde(default)] tool_calls: Vec, } #[derive(Deserialize)] struct OpenAIToolCall { id: String, #[serde(rename = "function")] function: OAIFunction, #[allow(dead_code)] #[serde(default)] index: Option, } #[derive(Deserialize)] struct OAIFunction { name: String, arguments: OAIFunctionArguments, } #[derive(Deserialize, Default)] struct OpenAIUsage { #[serde(default)] prompt_tokens: u32, #[serde(default)] completion_tokens: u32, #[serde(default)] total_tokens: u32, } #[async_trait] impl LLMProvider for OpenAIProvider { async fn chat( &self, request: ChatCompletionRequest, ) -> Result> { // 检查是否启用流式输出 if self.is_streaming_enabled() { // 优先尝试流式输出 match self.chat_streaming(&request).await { Ok(response) => return Ok(response), Err(e) => { tracing::debug!( provider = %self.name, model = %self.model_id, error = %e, "Streaming failed, falling back to non-streaming" ); // 流式失败,回退到非流式 } } } else { tracing::debug!(provider = %self.name, model = %self.model_id, "Streaming disabled, using non-streaming"); } // 非流式回退实现 let url = format!("{}/chat/completions", self.base_url); let body = self.build_request_body(&request); // Debug: Log LLM request summary (only in debug builds) #[cfg(debug_assertions)] { // Log messages summary let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0); tracing::debug!(msg_count = msg_count, "LLM request messages count"); // Log first 20 bytes of base64 images (don't log full base64) if let Some(msgs) = body["messages"].as_array() { for (i, msg) in msgs.iter().enumerate() { if let Some(content) = msg.get("content").and_then(|c| c.as_array()) { for (j, item) in content.iter().enumerate() { if item.get("type").and_then(|t| t.as_str()) == Some("image_url") { if let Some(url_str) = item .get("image_url") .and_then(|u| u.get("url")) .and_then(|v| v.as_str()) { let prefix: String = url_str.chars().take(20).collect(); tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)"); } } } } } } } let mut req_builder = self .client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json"); for (key, value) in &self.extra_headers { req_builder = req_builder.header(key.as_str(), value.as_str()); } let resp = req_builder.json(&body).send().await.map_err(|err| { let error_context = format_transport_error_context( &self.name, &self.model_id, &url, self.llm_timeout_secs, &err, ); tracing::error!( provider = %self.name, model = %self.model_id, url = %url, base_url = %self.base_url, timeout_secs = self.llm_timeout_secs, error = %error_context, "OpenAI-compatible API transport request failed" ); error_context })?; let status = resp.status(); let text = resp.text().await?; // Debug: Log LLM response (only in debug builds) if !status.is_success() { tracing::error!( provider = %self.name, model = %self.model_id, url = %url, status = %status, response_len = text.len(), response_body = %text, "OpenAI-compatible API request failed" ); return Err(format!("API error {}: {}", status, text).into()); } #[cfg(debug_assertions)] { let resp_preview: String = text.chars().take(100).collect(); tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "LLM response (first 100 chars shown)"); } let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| { tracing::error!( provider = %self.name, model = %self.model_id, url = %url, error = %format_error_chain(&e), response_len = text.len(), response_body = %text, "Failed to decode OpenAI-compatible API response" ); format!("decode error: {} | body: {}", e, &text) })?; let content = openai_resp.choices[0] .message .content .as_ref() .unwrap_or(&String::new()) .clone(); let tool_calls: Vec = openai_resp.choices[0] .message .tool_calls .iter() .map(|tc| ToolCall { id: tc.id.clone(), name: tc.function.name.clone(), arguments: match &tc.function.arguments { OAIFunctionArguments::Json(arguments) => arguments.clone(), OAIFunctionArguments::String(arguments) => { serde_json::from_str(arguments).unwrap_or(serde_json::Value::Null) } }, }) .collect(); Ok(ChatCompletionResponse { id: openai_resp.id, model: openai_resp.model, content, reasoning_content: openai_resp.choices[0].message.reasoning_content.clone(), tool_calls, usage: Usage { prompt_tokens: openai_resp.usage.prompt_tokens, completion_tokens: openai_resp.usage.completion_tokens, total_tokens: openai_resp.usage.total_tokens, }, }) } fn ptype(&self) -> &str { "openai" } fn name(&self) -> &str { &self.name } fn model_id(&self) -> &str { &self.model_id } } #[cfg(test)] mod tests { use super::*; use crate::providers::Message; #[test] fn test_build_request_body_includes_assistant_tool_calls() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), 120, "gpt-test".to_string(), None, None, HashMap::new(), ); let request = ChatCompletionRequest { messages: vec![Message { role: "assistant".to_string(), content: vec![ContentBlock::text("calling tool")], reasoning_content: None, tool_call_id: None, name: None, tool_calls: Some(vec![ToolCall { id: "call_1".to_string(), name: "calculator".to_string(), arguments: json!({"expression": "1+1"}), }]), }], temperature: None, max_tokens: None, tools: None, }; let body = provider.build_request_body(&request); let messages = body["messages"].as_array().unwrap(); let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0]["id"], "call_1"); assert_eq!(tool_calls[0]["type"], "function"); assert_eq!(tool_calls[0]["function"]["name"], "calculator"); assert_eq!( tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}" ); } #[test] fn test_build_request_body_uses_json_tool_arguments_when_enabled() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), 120, "gpt-test".to_string(), None, None, HashMap::from([("tool_call_arguments_json".to_string(), Value::Bool(true))]), ); let request = ChatCompletionRequest { messages: vec![Message { role: "assistant".to_string(), content: vec![ContentBlock::text("calling tool")], reasoning_content: None, tool_call_id: None, name: None, tool_calls: Some(vec![ToolCall { id: "call_1".to_string(), name: "calculator".to_string(), arguments: json!({"expression": "1+1"}), }]), }], temperature: None, max_tokens: None, tools: None, }; let body = provider.build_request_body(&request); let messages = body["messages"].as_array().unwrap(); let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); assert_eq!( tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"}) ); assert!(body.get("tool_call_arguments_json").is_none()); } #[test] fn test_build_request_body_preserves_raw_json_string_arguments() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), 120, "gpt-test".to_string(), None, None, HashMap::new(), ); let request = ChatCompletionRequest { messages: vec![Message { role: "assistant".to_string(), content: vec![ContentBlock::text("calling tool")], reasoning_content: None, tool_call_id: None, name: None, tool_calls: Some(vec![ToolCall { id: "call_1".to_string(), name: "calculator".to_string(), arguments: Value::String("{\"expression\":\"1+1\"}".to_string()), }]), }], temperature: None, max_tokens: None, tools: None, }; let body = provider.build_request_body(&request); let messages = body["messages"].as_array().unwrap(); let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); assert_eq!( tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}" ); } #[test] fn test_build_request_body_omits_internal_model_extra_keys() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), 120, "gpt-test".to_string(), None, None, HashMap::from([ ("tool_call_arguments_json".to_string(), Value::Bool(true)), ( "mock_response_content".to_string(), Value::String("stub".to_string()), ), ("parallel_tool_calls".to_string(), Value::Bool(true)), ]), ); let request = ChatCompletionRequest { messages: vec![Message::user("hello")], temperature: None, max_tokens: None, tools: None, }; let body = provider.build_request_body(&request); assert!(body.get("tool_call_arguments_json").is_none()); assert!(body.get("mock_response_content").is_none()); assert_eq!(body["parallel_tool_calls"], Value::Bool(true)); } #[test] fn test_build_request_body_includes_assistant_reasoning_content() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), 120, "gpt-test".to_string(), None, None, HashMap::new(), ); let request = ChatCompletionRequest { messages: vec![Message { role: "assistant".to_string(), content: vec![ContentBlock::text("final answer")], reasoning_content: Some("step by step".to_string()), tool_call_id: None, name: None, tool_calls: None, }], temperature: None, max_tokens: None, tools: None, }; let body = provider.build_request_body(&request); let messages = body["messages"].as_array().unwrap(); assert_eq!(messages[0]["reasoning_content"], "step by step"); } #[test] fn test_openai_response_parses_reasoning_content() { let response: OpenAIResponse = serde_json::from_value(json!({ "id": "resp_1", "model": "gpt-test", "choices": [{ "message": { "content": "final answer", "reasoning_content": "hidden reasoning", "tool_calls": [] } }], "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 } })) .unwrap(); assert_eq!( response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning") ); } #[test] fn test_openai_response_parses_json_tool_arguments() { let response: OpenAIResponse = serde_json::from_value(json!({ "id": "resp_1", "model": "gpt-test", "choices": [{ "message": { "content": "", "tool_calls": [{ "id": "call_1", "function": { "name": "scheduler_manage", "arguments": {"action": "list"} } }] } }], "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 } })) .unwrap(); match &response.choices[0].message.tool_calls[0].function.arguments { OAIFunctionArguments::Json(arguments) => { assert_eq!(arguments, &json!({"action": "list"})); } OAIFunctionArguments::String(_) => panic!("expected JSON tool arguments"), } } #[test] fn test_streaming_accumulator_preserves_tool_call_id_with_empty_subsequent_chunks() { // 模拟阿里云等云服务商的流式响应行为: // 第一个 chunk 包含 id 和 name,后续 chunk 的 id 为空字符串、name 为 None let mut accumulator = StreamingAccumulator::new(); // 第一个 chunk:包含完整的 id 和 name accumulator.add_tool_call(0, Some("call_abc123"), Some("memory_search"), Some("{\"action\":\"")); // 第二个 chunk:只有参数增量 accumulator.add_tool_call(0, None, None, Some("list")); // 第三个 chunk:参数继续 accumulator.add_tool_call(0, None, None, Some("\"")); // 第四个 chunk:id 为空字符串(某些云服务商的行为) accumulator.add_tool_call(0, Some(""), None, Some(", \"limit\": 20")); // 最后一个 chunk:name 为 None accumulator.add_tool_call(0, None, None, Some("}")); let response = accumulator.build_response("test-model".to_string()); // 验证工具调用被正确保留,id 没有被空字符串覆盖 assert_eq!(response.tool_calls.len(), 1); assert_eq!(response.tool_calls[0].id, "call_abc123"); assert_eq!(response.tool_calls[0].name, "memory_search"); assert_eq!(response.tool_calls[0].arguments, json!({"action":"list", "limit": 20})); } #[test] fn test_streaming_accumulator_handles_multiple_tool_calls() { let mut accumulator = StreamingAccumulator::new(); // 第一个工具调用 accumulator.add_tool_call(0, Some("call_1"), Some("calculator"), Some("{\"expr\": \"1+1\"}")); // 第二个工具调用(id 和 name 只在第一个 chunk 出现) accumulator.add_tool_call(1, Some("call_2"), Some("get_time"), Some("{}")); let response = accumulator.build_response("test-model".to_string()); assert_eq!(response.tool_calls.len(), 2); assert_eq!(response.tool_calls[0].id, "call_1"); assert_eq!(response.tool_calls[0].name, "calculator"); assert_eq!(response.tool_calls[1].id, "call_2"); assert_eq!(response.tool_calls[1].name, "get_time"); } #[test] fn test_supports_images_default_true() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), 120, "gpt-test".to_string(), None, None, HashMap::new(), ); assert!(provider.supports_images()); } #[test] fn test_supports_images_disabled_via_config() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), 120, "gpt-test".to_string(), None, None, HashMap::from([( "supported_content_types".to_string(), Value::Array(vec![Value::String("text".to_string())]), )]), ); assert!(!provider.supports_images()); } #[test] fn test_convert_content_blocks_converts_images_to_notice_when_disabled() { let blocks = vec![ ContentBlock::text("hello"), ContentBlock::image_url("data:image/png;base64,abc123"), ContentBlock::text("world"), ]; let result = convert_content_blocks(false, "test", "test-model", &blocks, 0); // 应该是数组形式 let arr = result.as_array().unwrap(); assert_eq!(arr.len(), 3); // 两个文本块 + 一个通知块 // 检查通知内容 let notice_block = arr[2].as_object().unwrap(); assert_eq!(notice_block["type"], "text"); let notice_text = notice_block["text"].as_str().unwrap(); assert!(notice_text.contains("[系统提示] 以下图片未能成功入模")); assert!(notice_text.contains("第 1 张图片")); assert!(notice_text.contains("当前模型不支持图片输入")); } #[test] fn test_convert_content_blocks_keeps_images_when_enabled() { let blocks = vec![ ContentBlock::text("hello"), ContentBlock::image_url("data:image/png;base64,abc123"), ]; let result = convert_content_blocks(true, "test", "test-model", &blocks, 0); // 应该是数组形式,包含文本和图片 let arr = result.as_array().unwrap(); assert_eq!(arr.len(), 2); assert_eq!(arr[0]["type"], "text"); assert_eq!(arr[1]["type"], "image_url"); } #[test] fn test_build_request_body_omits_supported_content_types_from_api() { let provider = OpenAIProvider::new( "test".to_string(), "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), 120, "gpt-test".to_string(), None, None, HashMap::from([ ( "supported_content_types".to_string(), Value::Array(vec![Value::String("text".to_string())]), ), ("custom_param".to_string(), Value::String("value".to_string())), ]), ); let request = ChatCompletionRequest { messages: vec![Message::user("hello")], temperature: None, max_tokens: None, tools: None, }; let body = provider.build_request_body(&request); // supported_content_types 不应该发送到 API assert!(body.get("supported_content_types").is_none()); // custom_param 应该保留 assert_eq!(body["custom_param"], Value::String("value".to_string())); } }