PicoBot/src/providers/openai.rs
oudecheng b3dc207ad1 fix: 移除 temperature 和 max_tokens 的硬编码默认值
如果配置中没有设置 temperature 或 max_tokens,不再传递这些参数给模型,
让模型使用自己的默认值,而不是硬编码 0.7。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 12:58:15 +08:00

1285 lines
45 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::time::Duration;
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::domain::messages::ContentBlock;
/// `model_extra` keys that configure this provider itself rather than the model.
///
/// `request_model_extra` filters these out so they are never copied into the request
/// body. `enable_streaming` (read by `is_streaming_enabled`) belongs here too:
/// otherwise it is forwarded verbatim, and APIs that reject unknown request arguments
/// fail the call.
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[
    "tool_call_arguments_json",
    "mock_response_content",
    "supported_content_types",
    "enable_streaming",
];
/// Tool-call fragments gathered from streaming deltas that share one `index`.
#[derive(Debug, Default)]
struct StreamingToolCall {
    /// Call id; only ever overwritten by a non-empty value.
    id: String,
    /// Function name; only ever overwritten by a non-empty value.
    name: String,
    /// Concatenation of every `arguments` fragment received so far.
    arguments: String,
}

/// Accumulates streamed chat-completion deltas into one `ChatCompletionResponse`.
#[derive(Debug, Default)]
struct StreamingAccumulator {
    content: String,
    reasoning_content: Option<String>,
    /// Tool calls keyed by the `index` field of each streamed tool-call delta.
    tool_calls: HashMap<usize, StreamingToolCall>,
    response_id: String,
}

impl StreamingAccumulator {
    fn new() -> Self {
        Self::default()
    }

    /// Appends a content delta.
    fn add_content(&mut self, delta: &str) {
        self.content.push_str(delta);
    }

    /// Appends a reasoning-content delta, creating the buffer on first use.
    fn add_reasoning_content(&mut self, delta: &str) {
        self.reasoning_content
            .get_or_insert_with(String::new)
            .push_str(delta);
    }

    /// Merges one tool-call delta into the call at `index`.
    ///
    /// Empty ids/names are ignored: some providers repeat the call with `"id": ""` or
    /// omit the name in later chunks, and those must not wipe the values from the
    /// first chunk. Argument fragments are always appended.
    fn add_tool_call(
        &mut self,
        index: usize,
        id: Option<&str>,
        name: Option<&str>,
        arguments: Option<&str>,
    ) {
        let entry = self.tool_calls.entry(index).or_default();
        if let Some(id) = id.filter(|id| !id.is_empty()) {
            entry.id = id.to_string();
        }
        if let Some(name) = name.filter(|name| !name.is_empty()) {
            entry.name = name.to_string();
        }
        if let Some(args) = arguments {
            entry.arguments.push_str(args);
        }
    }

    /// Records the response id; the first id seen wins.
    fn set_response_id(&mut self, id: String) {
        if self.response_id.is_empty() {
            self.response_id = id;
        }
    }

    /// Builds the final response for `model`.
    ///
    /// Tool calls are returned in ascending stream `index` order; calls still missing
    /// an id or a name are dropped. Arguments that are not valid JSON become
    /// `Value::Null`. Usage is reported as zero because the stream does not carry it
    /// here. Without a streamed id, a `stream-<uuid>` id is generated.
    fn build_response(self, model: String) -> ChatCompletionResponse {
        // HashMap iteration order is unspecified (and randomized per process), so sort
        // by the stream index to return calls in the order the model emitted them.
        let mut indexed: Vec<(usize, StreamingToolCall)> = self.tool_calls.into_iter().collect();
        indexed.sort_unstable_by_key(|(index, _)| *index);

        let tool_calls: Vec<ToolCall> = indexed
            .into_iter()
            .map(|(_, call)| call)
            .filter(|call| !call.id.is_empty() && !call.name.is_empty())
            .map(|call| {
                let arguments = serde_json::from_str(&call.arguments)
                    .unwrap_or_else(|_| serde_json::Value::Null);
                ToolCall {
                    id: call.id,
                    name: call.name,
                    arguments,
                }
            })
            .collect();

        ChatCompletionResponse {
            id: if self.response_id.is_empty() {
                format!("stream-{}", uuid::Uuid::new_v4())
            } else {
                self.response_id
            },
            model,
            content: self.content,
            reasoning_content: self.reasoning_content,
            tool_calls,
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }
}
/// Renders `error` and then each error in its `source()` chain, joined so that every
/// cause starts on its own line with a `caused by: ` prefix.
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    // Explicit `*link` so `source` is called on `dyn Error` itself and the returned
    // reference keeps the chain's lifetime rather than the closure argument's.
    std::iter::successors(Some(error), |link| (*link).source())
        .map(|link| link.to_string())
        .collect::<Vec<String>>()
        .join("\ncaused by: ")
}
fn format_transport_error_context(
provider_name: &str,
model_id: &str,
url: &str,
timeout_secs: u64,
error: &(dyn std::error::Error + 'static),
) -> String {
format!(
"transport error: provider={} model={} url={} timeout_secs={} details={}",
provider_name,
model_id,
url,
timeout_secs,
format_error_chain(error)
)
}
/// Converts a message's content blocks into the OpenAI `content` field.
///
/// A single text block becomes a plain JSON string; anything else becomes an array of
/// `{"type": "text"}` / `{"type": "image_url"}` parts.
///
/// When `supports_images` is false and the message contains images, a warning is
/// logged, every image is removed, and one trailing text part (a Chinese notice listing
/// each undelivered image) is appended. If that notice would be the only part, it is
/// returned as a plain string.
fn convert_content_blocks(
    supports_images: bool,
    provider_name: &str,
    model_id: &str,
    blocks: &[ContentBlock],
    message_idx: usize,
) -> Value {
    let image_total = blocks
        .iter()
        .filter(|block| matches!(block, ContentBlock::ImageUrl { .. }))
        .count();

    // Pass-through path: the model accepts images, or there are none to strip.
    if supports_images || image_total == 0 {
        return match blocks {
            [ContentBlock::Text { text }] => Value::String(text.clone()),
            _ => Value::Array(
                blocks
                    .iter()
                    .map(|block| match block {
                        ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
                        ContentBlock::ImageUrl { image_url } => {
                            json!({ "type": "image_url", "image_url": { "url": image_url.url } })
                        }
                    })
                    .collect(),
            ),
        };
    }

    tracing::warn!(
        provider = %provider_name,
        model = %model_id,
        filtered_images = image_total,
        message_idx,
        "模型不支持图片;将图片转换为通知文本"
    );

    let mut text_parts: Vec<Value> = Vec::with_capacity(blocks.len());
    let mut notices: Vec<String> = Vec::with_capacity(image_total);
    for block in blocks {
        match block {
            ContentBlock::Text { text } => {
                text_parts.push(json!({ "type": "text", "text": text }));
            }
            ContentBlock::ImageUrl { .. } => {
                // 1-based position of this image among the message's images.
                let ordinal = notices.len() + 1;
                notices.push(format!(
                    "- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
                    ordinal
                ));
            }
        }
    }

    let notice_text = format!(
        "[系统提示] 以下图片未能成功入模:\n{}",
        notices.join("\n")
    );
    // With no text of its own, the message collapses to the notice as a plain string.
    if text_parts.is_empty() {
        return Value::String(notice_text);
    }
    text_parts.push(json!({ "type": "text", "text": notice_text }));
    Value::Array(text_parts)
}
/// LLM provider for any endpoint that speaks the OpenAI `chat/completions` protocol.
pub struct OpenAIProvider {
    // --- identity ---
    /// Provider name used in logs and returned by `LLMProvider::name`.
    name: String,
    /// Model identifier sent as the request's `model` field.
    model_id: String,

    // --- transport ---
    /// HTTP client (built with `llm_timeout_secs` as its request timeout when possible).
    client: Client,
    /// API root; requests go to `{base_url}/chat/completions`.
    base_url: String,
    /// Sent as `Authorization: Bearer <api_key>`.
    api_key: String,
    /// Additional headers attached to every request.
    extra_headers: HashMap<String, String>,
    /// Timeout in seconds, also reported in transport error messages.
    llm_timeout_secs: u64,

    // --- sampling defaults (only sent when set) ---
    temperature: Option<f32>,
    max_tokens: Option<u32>,

    /// Internal switches plus extra request-body fields forwarded to the API.
    model_extra: HashMap<String, serde_json::Value>,
}
/// Tool-call `arguments` exactly as returned by the API.
///
/// `#[serde(untagged)]` tries variants in declaration order, and `Value` accepts any
/// JSON (strings included). `String` must therefore come first: the standard OpenAI
/// encoding (a JSON-encoded string) then deserializes as `String`, and servers that
/// return a raw JSON object/array deserialize as `Json`. With the previous order,
/// `String` was unreachable.
#[derive(Deserialize)]
#[serde(untagged)]
enum OAIFunctionArguments {
    String(String),
    Json(Value),
}
impl OpenAIProvider {
pub fn new(
name: String,
api_key: String,
base_url: String,
extra_headers: HashMap<String, String>,
llm_timeout_secs: u64,
model_id: String,
temperature: Option<f32>,
max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>,
) -> Self {
let client = Client::builder()
.timeout(Duration::from_secs(llm_timeout_secs))
.build()
.unwrap_or_else(|_| Client::new());
Self {
client,
name,
api_key,
base_url,
extra_headers,
llm_timeout_secs,
model_id,
temperature,
max_tokens,
model_extra,
}
}
fn uses_json_tool_arguments(&self) -> bool {
self.model_extra
.get("tool_call_arguments_json")
.and_then(|value| value.as_bool())
.unwrap_or(false)
}
/// 检查是否启用流式输出,默认启用
fn is_streaming_enabled(&self) -> bool {
self.model_extra
.get("enable_streaming")
.and_then(|value| value.as_bool())
.unwrap_or(true)
}
/// 检查模型是否支持指定内容类型
/// 默认支持所有类型text, image
fn supports_content_type(&self, content_type: &str) -> bool {
self.model_extra
.get("supported_content_types")
.and_then(|value| value.as_array())
.map(|types| {
types.iter().any(|t| t.as_str() == Some(content_type))
})
.unwrap_or(true)
}
/// 检查模型是否支持图片
fn supports_images(&self) -> bool {
self.supports_content_type("image")
}
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
match arguments {
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
_ => arguments.clone(),
}
}
fn serialize_tool_arguments(&self, arguments: &Value) -> Value {
let normalized = self.normalize_tool_arguments(arguments);
if self.uses_json_tool_arguments() {
// Model expects JSON object format (e.g., some code models)
normalized
} else {
// Standard OpenAI format: arguments as JSON string
// But ensure we serialize valid JSON, not null
match normalized {
Value::Null => Value::String("{}".to_string()),
Value::String(raw) => {
// If the string is already valid JSON, keep it as-is
// Otherwise, ensure it's a proper JSON string
if serde_json::from_str::<Value>(&raw).is_ok() {
Value::String(raw)
} else {
// Invalid JSON string - wrap it as a proper JSON string
Value::String(serde_json::to_string(&raw).unwrap_or_else(|_| "null".to_string()))
}
}
value => Value::String(
serde_json::to_string(&value).unwrap_or_else(|_| "{}".to_string()),
),
}
}
}
fn request_model_extra(&self) -> impl Iterator<Item = (&String, &Value)> {
self.model_extra.iter().filter(|(key, _)| {
!INTERNAL_MODEL_EXTRA_KEYS
.iter()
.any(|internal| internal == &key.as_str())
})
}
/// 内部流式聊天实现
async fn chat_streaming(
&self,
request: &ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat");
let url = format!("{}/chat/completions", self.base_url);
let mut body = self.build_request_body(request);
// 启用流式输出
body["stream"] = json!(true);
let mut req_builder = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream");
for (key, value) in &self.extra_headers {
req_builder = req_builder.header(key.as_str(), value.as_str());
}
let resp = req_builder.json(&body).send().await.map_err(|err| {
format_transport_error_context(
&self.name,
&self.model_id,
&url,
self.llm_timeout_secs,
&err,
)
})?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(format!("API error {}: {}", status, text).into());
}
let mut accumulator = StreamingAccumulator::new();
// 读取 SSE 流
let mut stream = resp.bytes_stream();
let mut buffer = String::new();
let mut done_received = false;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result?;
let text = String::from_utf8_lossy(&chunk);
buffer.push_str(&text);
// 处理缓冲区中的完整行
while let Some(newline_pos) = buffer.find('\n') {
let line = buffer[..newline_pos].to_string();
buffer = buffer[newline_pos + 1..].to_string();
let line_trimmed = line.trim();
if line_trimmed.is_empty() || line_trimmed.starts_with(':') {
continue;
}
// SSE 格式: data: {...} 或 data:{...}(某些 API 如 139 云没有空格)
let data_opt = line_trimmed.strip_prefix("data: ")
.or_else(|| line_trimmed.strip_prefix("data:"));
if let Some(data) = data_opt {
if data == "[DONE]" {
// 流结束
done_received = true;
break;
}
// 解析 JSON
match serde_json::from_str::<Value>(data) {
Ok(json) => {
// 提取响应 ID
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
accumulator.set_response_id(id.to_string());
}
// 提取 choices
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
for choice in choices {
// 尝试从 delta 提取(标准 OpenAI 流式格式)
if let Some(delta) = choice.get("delta") {
// 提取内容增量
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content);
}
// 提取推理内容增量
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
accumulator.add_reasoning_content(reasoning);
}
// 提取工具调用增量
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
for tool_call in tool_calls {
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
let id = tool_call.get("id").and_then(|v| v.as_str());
let name = tool_call.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str());
let arguments = tool_call.get("function")
.and_then(|f| f.get("arguments"))
.and_then(|a| a.as_str());
accumulator.add_tool_call(index, id, name, arguments);
}
}
}
// 尝试从 message 提取(某些非标准 API 格式)
else if let Some(message) = choice.get("message") {
if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content);
}
if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) {
accumulator.add_reasoning_content(reasoning);
}
}
}
}
}
Err(e) => {
tracing::debug!(
error = %e,
data = %data,
"Failed to parse SSE data"
);
}
}
}
}
if done_received {
break;
}
}
// 处理缓冲区中剩余的内容
for line in buffer.lines() {
let line_trimmed = line.trim();
if line_trimmed.is_empty() || line_trimmed.starts_with(':') {
continue;
}
// 同样支持 data: {...} 和 data:{...} 两种格式
let data_opt = line_trimmed.strip_prefix("data: ")
.or_else(|| line_trimmed.strip_prefix("data:"));
if let Some(data) = data_opt {
if data == "[DONE]" {
break;
}
if let Ok(json) = serde_json::from_str::<Value>(data) {
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
accumulator.set_response_id(id.to_string());
}
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
for choice in choices {
// 尝试从 delta 提取
if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content);
}
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
accumulator.add_reasoning_content(reasoning);
}
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
for tool_call in tool_calls {
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
let id = tool_call.get("id").and_then(|v| v.as_str());
let name = tool_call.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str());
let arguments = tool_call.get("function")
.and_then(|f| f.get("arguments"))
.and_then(|a| a.as_str());
accumulator.add_tool_call(index, id, name, arguments);
}
}
}
// 尝试从 message 提取(某些非标准 API 格式)
else if let Some(message) = choice.get("message") {
if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content);
}
if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) {
accumulator.add_reasoning_content(reasoning);
}
}
}
}
}
}
}
let response = accumulator.build_response(self.model_id.clone());
tracing::debug!(
content_len = response.content.len(),
tool_calls_count = response.tool_calls.len(),
has_reasoning = response.reasoning_content.is_some(),
"Streaming response built"
);
Ok(response)
}
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
let supports_images = self.supports_images();
let mut body = json!({
"model": self.model_id,
"messages": request.messages.iter().enumerate().map(|(i, m)| {
if m.role == "tool" {
json!({
"role": m.role,
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
"tool_call_id": m.tool_call_id,
"name": m.name,
})
} else if m.role == "assistant" && m.tool_calls.is_some() {
let mut message = json!({
"role": m.role,
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
"tool_calls": m.tool_calls.as_ref().map(|calls| {
calls.iter().map(|call| json!({
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": self.serialize_tool_arguments(&call.arguments)
}
})).collect::<Vec<_>>()
})
});
if let Some(reasoning_content) = &m.reasoning_content {
message["reasoning_content"] = Value::String(reasoning_content.clone());
}
message
} else {
let mut message = json!({
"role": m.role,
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i)
});
if m.role == "assistant" {
if let Some(reasoning_content) = &m.reasoning_content {
message["reasoning_content"] = Value::String(reasoning_content.clone());
}
}
message
}
}).collect::<Vec<_>>(),
});
// 只有配置了才添加 temperature否则让模型使用默认值
if let Some(temp) = request.temperature.or(self.temperature) {
body["temperature"] = json!(temp);
}
// 只有配置了才添加 max_tokens
if let Some(tokens) = request.max_tokens.or(self.max_tokens) {
body["max_tokens"] = json!(tokens);
}
for (key, value) in self.request_model_extra() {
body[key] = value.clone();
}
if let Some(tools) = &request.tools {
body["tools"] = json!(tools);
}
body
}
}
#[derive(Deserialize)]
struct OpenAIResponse {
id: String,
model: String,
choices: Vec<OpenAIChoice>,
#[serde(default)]
usage: OpenAIUsage,
}
#[derive(Deserialize)]
struct OpenAIChoice {
message: OpenAIMessage,
}
#[derive(Deserialize)]
struct OpenAIMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
#[allow(dead_code)]
#[serde(default)]
name: Option<String>,
#[serde(default)]
tool_calls: Vec<OpenAIToolCall>,
}
#[derive(Deserialize)]
struct OpenAIToolCall {
id: String,
#[serde(rename = "function")]
function: OAIFunction,
#[allow(dead_code)]
#[serde(default)]
index: Option<u32>,
}
#[derive(Deserialize)]
struct OAIFunction {
name: String,
arguments: OAIFunctionArguments,
}
#[derive(Deserialize, Default)]
struct OpenAIUsage {
#[serde(default)]
prompt_tokens: u32,
#[serde(default)]
completion_tokens: u32,
#[serde(default)]
total_tokens: u32,
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
async fn chat(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
// 检查是否启用流式输出
if self.is_streaming_enabled() {
// 优先尝试流式输出
match self.chat_streaming(&request).await {
Ok(response) => return Ok(response),
Err(e) => {
tracing::debug!(
provider = %self.name,
model = %self.model_id,
error = %e,
"Streaming failed, falling back to non-streaming"
);
// 流式失败,回退到非流式
}
}
} else {
tracing::debug!(provider = %self.name, model = %self.model_id, "Streaming disabled, using non-streaming");
}
// 非流式回退实现
let url = format!("{}/chat/completions", self.base_url);
let body = self.build_request_body(&request);
// Debug: Log LLM request summary (only in debug builds)
#[cfg(debug_assertions)]
{
// Log messages summary
let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0);
tracing::debug!(msg_count = msg_count, "LLM request messages count");
// Log first 20 bytes of base64 images (don't log full base64)
if let Some(msgs) = body["messages"].as_array() {
for (i, msg) in msgs.iter().enumerate() {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for (j, item) in content.iter().enumerate() {
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
if let Some(url_str) = item
.get("image_url")
.and_then(|u| u.get("url"))
.and_then(|v| v.as_str())
{
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
}
}
}
}
}
}
let mut req_builder = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json");
for (key, value) in &self.extra_headers {
req_builder = req_builder.header(key.as_str(), value.as_str());
}
let resp = req_builder.json(&body).send().await.map_err(|err| {
let error_context = format_transport_error_context(
&self.name,
&self.model_id,
&url,
self.llm_timeout_secs,
&err,
);
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
base_url = %self.base_url,
timeout_secs = self.llm_timeout_secs,
error = %error_context,
"OpenAI-compatible API transport request failed"
);
error_context
})?;
let status = resp.status();
let text = resp.text().await?;
// Debug: Log LLM response (only in debug builds)
if !status.is_success() {
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
status = %status,
response_len = text.len(),
response_body = %text,
"OpenAI-compatible API request failed"
);
return Err(format!("API error {}: {}", status, text).into());
}
#[cfg(debug_assertions)]
{
let resp_preview: String = text.chars().take(100).collect();
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "LLM response (first 100 chars shown)");
}
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
error = %format_error_chain(&e),
response_len = text.len(),
response_body = %text,
"Failed to decode OpenAI-compatible API response"
);
format!("decode error: {} | body: {}", e, &text)
})?;
let content = openai_resp.choices[0]
.message
.content
.as_ref()
.unwrap_or(&String::new())
.clone();
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
.message
.tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: match &tc.function.arguments {
OAIFunctionArguments::Json(arguments) => arguments.clone(),
OAIFunctionArguments::String(arguments) => {
serde_json::from_str(arguments).unwrap_or(serde_json::Value::Null)
}
},
})
.collect();
Ok(ChatCompletionResponse {
id: openai_resp.id,
model: openai_resp.model,
content,
reasoning_content: openai_resp.choices[0].message.reasoning_content.clone(),
tool_calls,
usage: Usage {
prompt_tokens: openai_resp.usage.prompt_tokens,
completion_tokens: openai_resp.usage.completion_tokens,
total_tokens: openai_resp.usage.total_tokens,
},
})
}
fn ptype(&self) -> &str {
"openai"
}
fn name(&self) -> &str {
&self.name
}
fn model_id(&self) -> &str {
&self.model_id
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Message;

    /// Provider aimed at a dummy endpoint; the tests below only vary `model_extra`.
    fn provider_with_extra(model_extra: HashMap<String, Value>) -> OpenAIProvider {
        OpenAIProvider::new(
            "test".to_string(),
            "key".to_string(),
            "https://example.com/v1".to_string(),
            HashMap::new(),
            120,
            "gpt-test".to_string(),
            None,
            None,
            model_extra,
        )
    }

    /// Wraps `messages` in a request with no sampling overrides and no tools.
    fn request_with(messages: Vec<Message>) -> ChatCompletionRequest {
        ChatCompletionRequest {
            messages,
            temperature: None,
            max_tokens: None,
            tools: None,
        }
    }

    /// Assistant turn that calls the `calculator` tool once (call id `call_1`).
    fn calculator_call(arguments: Value) -> Message {
        Message {
            role: "assistant".to_string(),
            content: vec![ContentBlock::text("calling tool")],
            reasoning_content: None,
            tool_call_id: None,
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_1".to_string(),
                name: "calculator".to_string(),
                arguments,
            }]),
        }
    }

    #[test]
    fn test_build_request_body_includes_assistant_tool_calls() {
        let provider = provider_with_extra(HashMap::new());
        let request = request_with(vec![calculator_call(json!({"expression": "1+1"}))]);

        let body = provider.build_request_body(&request);
        let tool_calls = body["messages"][0]["tool_calls"].as_array().unwrap();

        assert_eq!(tool_calls.len(), 1);
        let call = &tool_calls[0];
        assert_eq!(call["id"], "call_1");
        assert_eq!(call["type"], "function");
        assert_eq!(call["function"]["name"], "calculator");
        // Default mode: arguments travel as a JSON-encoded string.
        assert_eq!(call["function"]["arguments"], "{\"expression\":\"1+1\"}");
    }

    #[test]
    fn test_build_request_body_uses_json_tool_arguments_when_enabled() {
        let provider = provider_with_extra(HashMap::from([(
            "tool_call_arguments_json".to_string(),
            Value::Bool(true),
        )]));
        let request = request_with(vec![calculator_call(json!({"expression": "1+1"}))]);

        let body = provider.build_request_body(&request);
        let tool_calls = body["messages"][0]["tool_calls"].as_array().unwrap();

        assert_eq!(
            tool_calls[0]["function"]["arguments"],
            json!({"expression": "1+1"})
        );
        // The switch itself is internal and must not leak into the request.
        assert!(body.get("tool_call_arguments_json").is_none());
    }

    #[test]
    fn test_build_request_body_preserves_raw_json_string_arguments() {
        let provider = provider_with_extra(HashMap::new());
        let raw_arguments = Value::String("{\"expression\":\"1+1\"}".to_string());
        let request = request_with(vec![calculator_call(raw_arguments)]);

        let body = provider.build_request_body(&request);
        let tool_calls = body["messages"][0]["tool_calls"].as_array().unwrap();

        assert_eq!(
            tool_calls[0]["function"]["arguments"],
            "{\"expression\":\"1+1\"}"
        );
    }

    #[test]
    fn test_build_request_body_omits_internal_model_extra_keys() {
        let provider = provider_with_extra(HashMap::from([
            ("tool_call_arguments_json".to_string(), Value::Bool(true)),
            (
                "mock_response_content".to_string(),
                Value::String("stub".to_string()),
            ),
            ("parallel_tool_calls".to_string(), Value::Bool(true)),
        ]));

        let body = provider.build_request_body(&request_with(vec![Message::user("hello")]));

        assert!(body.get("tool_call_arguments_json").is_none());
        assert!(body.get("mock_response_content").is_none());
        // Non-internal extras are forwarded verbatim.
        assert_eq!(body["parallel_tool_calls"], Value::Bool(true));
    }

    #[test]
    fn test_build_request_body_includes_assistant_reasoning_content() {
        let provider = provider_with_extra(HashMap::new());
        let answer = Message {
            role: "assistant".to_string(),
            content: vec![ContentBlock::text("final answer")],
            reasoning_content: Some("step by step".to_string()),
            tool_call_id: None,
            name: None,
            tool_calls: None,
        };

        let body = provider.build_request_body(&request_with(vec![answer]));
        let messages = body["messages"].as_array().unwrap();

        assert_eq!(messages[0]["reasoning_content"], "step by step");
    }

    #[test]
    fn test_openai_response_parses_reasoning_content() {
        let payload = json!({
            "id": "resp_1",
            "model": "gpt-test",
            "choices": [{
                "message": {
                    "content": "final answer",
                    "reasoning_content": "hidden reasoning",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });

        let response: OpenAIResponse = serde_json::from_value(payload).unwrap();

        assert_eq!(
            response.choices[0].message.reasoning_content.as_deref(),
            Some("hidden reasoning")
        );
    }

    #[test]
    fn test_openai_response_parses_json_tool_arguments() {
        let payload = json!({
            "id": "resp_1",
            "model": "gpt-test",
            "choices": [{
                "message": {
                    "content": "",
                    "tool_calls": [{
                        "id": "call_1",
                        "function": {
                            "name": "scheduler_manage",
                            "arguments": {"action": "list"}
                        }
                    }]
                }
            }],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2
            }
        });

        let response: OpenAIResponse = serde_json::from_value(payload).unwrap();
        let arguments = &response.choices[0].message.tool_calls[0].function.arguments;

        match arguments {
            OAIFunctionArguments::Json(value) => assert_eq!(value, &json!({"action": "list"})),
            OAIFunctionArguments::String(_) => panic!("expected JSON tool arguments"),
        }
    }

    #[test]
    fn test_streaming_accumulator_preserves_tool_call_id_with_empty_subsequent_chunks() {
        // Mimics providers (e.g. Alibaba Cloud) that send id and name only in the first
        // chunk, then argument fragments with a missing name and sometimes `"id": ""`.
        let chunks: [(Option<&str>, Option<&str>, &str); 5] = [
            (Some("call_abc123"), Some("memory_search"), "{\"action\":\""),
            (None, None, "list"),
            (None, None, "\""),
            (Some(""), None, ", \"limit\": 20"),
            (None, None, "}"),
        ];
        let mut accumulator = StreamingAccumulator::new();
        for &(id, name, arguments) in chunks.iter() {
            accumulator.add_tool_call(0, id, name, Some(arguments));
        }

        let response = accumulator.build_response("test-model".to_string());

        // The empty id in chunk 4 must not have replaced the original one.
        assert_eq!(response.tool_calls.len(), 1);
        let call = &response.tool_calls[0];
        assert_eq!(call.id, "call_abc123");
        assert_eq!(call.name, "memory_search");
        assert_eq!(call.arguments, json!({"action":"list", "limit": 20}));
    }

    #[test]
    fn test_streaming_accumulator_handles_multiple_tool_calls() {
        let mut accumulator = StreamingAccumulator::new();
        accumulator.add_tool_call(0, Some("call_1"), Some("calculator"), Some("{\"expr\": \"1+1\"}"));
        accumulator.add_tool_call(1, Some("call_2"), Some("get_time"), Some("{}"));

        let response = accumulator.build_response("test-model".to_string());
        let summary: Vec<(&str, &str)> = response
            .tool_calls
            .iter()
            .map(|call| (call.id.as_str(), call.name.as_str()))
            .collect();

        assert_eq!(summary, vec![("call_1", "calculator"), ("call_2", "get_time")]);
    }

    #[test]
    fn test_supports_images_default_true() {
        assert!(provider_with_extra(HashMap::new()).supports_images());
    }

    #[test]
    fn test_supports_images_disabled_via_config() {
        let provider = provider_with_extra(HashMap::from([(
            "supported_content_types".to_string(),
            Value::Array(vec![Value::String("text".to_string())]),
        )]));

        assert!(!provider.supports_images());
    }

    #[test]
    fn test_convert_content_blocks_converts_images_to_notice_when_disabled() {
        let blocks = vec![
            ContentBlock::text("hello"),
            ContentBlock::image_url("data:image/png;base64,abc123"),
            ContentBlock::text("world"),
        ];

        let result = convert_content_blocks(false, "test", "test-model", &blocks, 0);

        // Two text parts plus the trailing notice part.
        let parts = result.as_array().unwrap();
        assert_eq!(parts.len(), 3);
        let notice = parts[2].as_object().unwrap();
        assert_eq!(notice["type"], "text");
        let notice_text = notice["text"].as_str().unwrap();
        for expected in [
            "[系统提示] 以下图片未能成功入模",
            "第 1 张图片",
            "当前模型不支持图片输入",
        ]
        .iter()
        {
            assert!(notice_text.contains(expected));
        }
    }

    #[test]
    fn test_convert_content_blocks_keeps_images_when_enabled() {
        let blocks = vec![
            ContentBlock::text("hello"),
            ContentBlock::image_url("data:image/png;base64,abc123"),
        ];

        let result = convert_content_blocks(true, "test", "test-model", &blocks, 0);

        // Array form with the text part followed by the image part.
        let parts = result.as_array().unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[1]["type"], "image_url");
    }

    #[test]
    fn test_build_request_body_omits_supported_content_types_from_api() {
        let provider = provider_with_extra(HashMap::from([
            (
                "supported_content_types".to_string(),
                Value::Array(vec![Value::String("text".to_string())]),
            ),
            ("custom_param".to_string(), Value::String("value".to_string())),
        ]));

        let body = provider.build_request_body(&request_with(vec![Message::user("hello")]));

        // Internal capability list stays local; ordinary extras are forwarded.
        assert!(body.get("supported_content_types").is_none());
        assert_eq!(body["custom_param"], Value::String("value".to_string()));
    }
}