如果配置中没有设置 temperature 或 max_tokens,不再传递这些参数给模型, 让模型使用自己的默认值,而不是硬编码 0.7。 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1285 lines
45 KiB
Rust
1285 lines
45 KiB
Rust
use async_trait::async_trait;
|
||
use futures_util::StreamExt;
|
||
use reqwest::Client;
|
||
use serde::Deserialize;
|
||
use serde_json::{Value, json};
|
||
use std::collections::HashMap;
|
||
use std::time::Duration;
|
||
|
||
use super::traits::Usage;
|
||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||
use crate::domain::messages::ContentBlock;
|
||
|
||
/// Keys of `model_extra` that configure this provider locally and must never
/// be forwarded to the remote API as request parameters.
///
/// `enable_streaming` is listed because it is read as a local switch; the
/// wire-level flag is the separate `stream` field set by the streaming path.
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[
    "tool_call_arguments_json",
    "mock_response_content",
    "supported_content_types",
    "enable_streaming",
];
|
||
|
||
/// One tool call being assembled from the fragments of a streamed response.
#[derive(Default, Debug)]
struct StreamingToolCall {
    /// Tool-call identifier; empty until a fragment supplies a non-empty one.
    id: String,
    /// Function name; empty until a fragment supplies a non-empty one.
    name: String,
    /// Concatenation of every argument fragment received so far.
    arguments: String,
}
|
||
|
||
/// 流式响应累积器
|
||
#[derive(Debug, Default)]
|
||
struct StreamingAccumulator {
|
||
content: String,
|
||
reasoning_content: Option<String>,
|
||
tool_calls: HashMap<usize, StreamingToolCall>,
|
||
response_id: String,
|
||
}
|
||
|
||
impl StreamingAccumulator {
|
||
fn new() -> Self {
|
||
Self::default()
|
||
}
|
||
|
||
/// 添加内容增量
|
||
fn add_content(&mut self, delta: &str) {
|
||
self.content.push_str(delta);
|
||
}
|
||
|
||
/// 添加推理内容增量
|
||
fn add_reasoning_content(&mut self, delta: &str) {
|
||
if self.reasoning_content.is_none() {
|
||
self.reasoning_content = Some(String::new());
|
||
}
|
||
self.reasoning_content.as_mut().unwrap().push_str(delta);
|
||
}
|
||
|
||
/// 添加工具调用增量
|
||
fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) {
|
||
let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default);
|
||
|
||
// 只在 id 非空时才更新,防止流式响应中后续 chunk 的空 id 覆盖之前的值
|
||
if let Some(id) = id {
|
||
if !id.is_empty() {
|
||
entry.id = id.to_string();
|
||
}
|
||
}
|
||
// 只在 name 非空时才更新,防止流式响应中后续 chunk 的 None 覆盖之前的值
|
||
if let Some(name) = name {
|
||
if !name.is_empty() {
|
||
entry.name = name.to_string();
|
||
}
|
||
}
|
||
if let Some(args) = arguments {
|
||
entry.arguments.push_str(args);
|
||
}
|
||
}
|
||
|
||
/// 设置响应 ID
|
||
fn set_response_id(&mut self, id: String) {
|
||
if self.response_id.is_empty() {
|
||
self.response_id = id;
|
||
}
|
||
}
|
||
|
||
/// 构建最终的 ChatCompletionResponse
|
||
fn build_response(self, model: String) -> ChatCompletionResponse {
|
||
let tool_calls: Vec<ToolCall> = self.tool_calls
|
||
.into_iter()
|
||
.filter(|(_, call)| !call.id.is_empty() && !call.name.is_empty())
|
||
.map(|(_, call)| {
|
||
let arguments = serde_json::from_str(&call.arguments)
|
||
.unwrap_or_else(|_| serde_json::Value::Null);
|
||
ToolCall {
|
||
id: call.id,
|
||
name: call.name,
|
||
arguments,
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
ChatCompletionResponse {
|
||
id: if self.response_id.is_empty() {
|
||
format!("stream-{}", uuid::Uuid::new_v4())
|
||
} else {
|
||
self.response_id
|
||
},
|
||
model,
|
||
content: self.content,
|
||
reasoning_content: self.reasoning_content,
|
||
tool_calls,
|
||
usage: Usage {
|
||
prompt_tokens: 0,
|
||
completion_tokens: 0,
|
||
total_tokens: 0,
|
||
},
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Renders `error` followed by every error in its `source()` chain.
///
/// Each message is produced with `to_string()`; consecutive messages are
/// separated by a newline and the text `caused by: `.
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    std::iter::successors(Some(error), |&cause| cause.source())
        .map(|cause| cause.to_string())
        .collect::<Vec<_>>()
        .join("\ncaused by: ")
}
|
||
|
||
fn format_transport_error_context(
|
||
provider_name: &str,
|
||
model_id: &str,
|
||
url: &str,
|
||
timeout_secs: u64,
|
||
error: &(dyn std::error::Error + 'static),
|
||
) -> String {
|
||
format!(
|
||
"transport error: provider={} model={} url={} timeout_secs={} details={}",
|
||
provider_name,
|
||
model_id,
|
||
url,
|
||
timeout_secs,
|
||
format_error_chain(error)
|
||
)
|
||
}
|
||
|
||
fn convert_content_blocks(
|
||
supports_images: bool,
|
||
provider_name: &str,
|
||
model_id: &str,
|
||
blocks: &[ContentBlock],
|
||
message_idx: usize,
|
||
) -> Value {
|
||
// 检查是否有图片且模型不支持
|
||
if !supports_images {
|
||
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||
|
||
if has_images {
|
||
let image_count = blocks.iter()
|
||
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
||
.count();
|
||
|
||
tracing::warn!(
|
||
provider = %provider_name,
|
||
model = %model_id,
|
||
filtered_images = image_count,
|
||
message_idx,
|
||
"模型不支持图片;将图片转换为通知文本"
|
||
);
|
||
|
||
// 复用通知格式,将图片转换为文本通知
|
||
let mut converted_blocks: Vec<Value> = Vec::new();
|
||
let mut notices: Vec<String> = Vec::new();
|
||
let mut image_idx = 0;
|
||
|
||
for block in blocks.iter() {
|
||
match block {
|
||
ContentBlock::Text { text } => {
|
||
converted_blocks.push(json!({ "type": "text", "text": text }));
|
||
}
|
||
ContentBlock::ImageUrl { .. } => {
|
||
image_idx += 1;
|
||
notices.push(format!(
|
||
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
||
image_idx
|
||
));
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加通知文本块
|
||
if !notices.is_empty() {
|
||
let notice_text = format!(
|
||
"[系统提示] 以下图片未能成功入模:\n{}",
|
||
notices.join("\n")
|
||
);
|
||
converted_blocks.push(json!({ "type": "text", "text": notice_text }));
|
||
}
|
||
|
||
// 如果只有一个文本块且没有通知,返回字符串形式
|
||
if converted_blocks.len() == 1 {
|
||
if let Some(block) = converted_blocks.first() {
|
||
if block.get("type").and_then(|t| t.as_str()) == Some("text") {
|
||
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
|
||
return Value::String(text.to_string());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return Value::Array(converted_blocks);
|
||
}
|
||
}
|
||
|
||
// 原有逻辑 - 模型支持图片,正常转换
|
||
if blocks.len() == 1 {
|
||
if let ContentBlock::Text { text } = &blocks[0] {
|
||
return Value::String(text.clone());
|
||
}
|
||
}
|
||
Value::Array(
|
||
blocks
|
||
.iter()
|
||
.map(|b| match b {
|
||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||
ContentBlock::ImageUrl { image_url } => {
|
||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||
}
|
||
})
|
||
.collect(),
|
||
)
|
||
}
|
||
|
||
pub struct OpenAIProvider {
|
||
client: Client,
|
||
name: String,
|
||
api_key: String,
|
||
base_url: String,
|
||
extra_headers: HashMap<String, String>,
|
||
llm_timeout_secs: u64,
|
||
model_id: String,
|
||
temperature: Option<f32>,
|
||
max_tokens: Option<u32>,
|
||
model_extra: HashMap<String, serde_json::Value>,
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
#[serde(untagged)]
|
||
enum OAIFunctionArguments {
|
||
Json(Value),
|
||
String(String),
|
||
}
|
||
|
||
impl OpenAIProvider {
|
||
pub fn new(
|
||
name: String,
|
||
api_key: String,
|
||
base_url: String,
|
||
extra_headers: HashMap<String, String>,
|
||
llm_timeout_secs: u64,
|
||
model_id: String,
|
||
temperature: Option<f32>,
|
||
max_tokens: Option<u32>,
|
||
model_extra: HashMap<String, serde_json::Value>,
|
||
) -> Self {
|
||
let client = Client::builder()
|
||
.timeout(Duration::from_secs(llm_timeout_secs))
|
||
.build()
|
||
.unwrap_or_else(|_| Client::new());
|
||
|
||
Self {
|
||
client,
|
||
name,
|
||
api_key,
|
||
base_url,
|
||
extra_headers,
|
||
llm_timeout_secs,
|
||
model_id,
|
||
temperature,
|
||
max_tokens,
|
||
model_extra,
|
||
}
|
||
}
|
||
|
||
fn uses_json_tool_arguments(&self) -> bool {
|
||
self.model_extra
|
||
.get("tool_call_arguments_json")
|
||
.and_then(|value| value.as_bool())
|
||
.unwrap_or(false)
|
||
}
|
||
|
||
/// 检查是否启用流式输出,默认启用
|
||
fn is_streaming_enabled(&self) -> bool {
|
||
self.model_extra
|
||
.get("enable_streaming")
|
||
.and_then(|value| value.as_bool())
|
||
.unwrap_or(true)
|
||
}
|
||
|
||
/// 检查模型是否支持指定内容类型
|
||
/// 默认支持所有类型(text, image)
|
||
fn supports_content_type(&self, content_type: &str) -> bool {
|
||
self.model_extra
|
||
.get("supported_content_types")
|
||
.and_then(|value| value.as_array())
|
||
.map(|types| {
|
||
types.iter().any(|t| t.as_str() == Some(content_type))
|
||
})
|
||
.unwrap_or(true)
|
||
}
|
||
|
||
/// 检查模型是否支持图片
|
||
fn supports_images(&self) -> bool {
|
||
self.supports_content_type("image")
|
||
}
|
||
|
||
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
||
match arguments {
|
||
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
||
_ => arguments.clone(),
|
||
}
|
||
}
|
||
|
||
fn serialize_tool_arguments(&self, arguments: &Value) -> Value {
|
||
let normalized = self.normalize_tool_arguments(arguments);
|
||
|
||
if self.uses_json_tool_arguments() {
|
||
// Model expects JSON object format (e.g., some code models)
|
||
normalized
|
||
} else {
|
||
// Standard OpenAI format: arguments as JSON string
|
||
// But ensure we serialize valid JSON, not null
|
||
match normalized {
|
||
Value::Null => Value::String("{}".to_string()),
|
||
Value::String(raw) => {
|
||
// If the string is already valid JSON, keep it as-is
|
||
// Otherwise, ensure it's a proper JSON string
|
||
if serde_json::from_str::<Value>(&raw).is_ok() {
|
||
Value::String(raw)
|
||
} else {
|
||
// Invalid JSON string - wrap it as a proper JSON string
|
||
Value::String(serde_json::to_string(&raw).unwrap_or_else(|_| "null".to_string()))
|
||
}
|
||
}
|
||
value => Value::String(
|
||
serde_json::to_string(&value).unwrap_or_else(|_| "{}".to_string()),
|
||
),
|
||
}
|
||
}
|
||
}
|
||
|
||
fn request_model_extra(&self) -> impl Iterator<Item = (&String, &Value)> {
|
||
self.model_extra.iter().filter(|(key, _)| {
|
||
!INTERNAL_MODEL_EXTRA_KEYS
|
||
.iter()
|
||
.any(|internal| internal == &key.as_str())
|
||
})
|
||
}
|
||
|
||
/// 内部流式聊天实现
|
||
async fn chat_streaming(
|
||
&self,
|
||
request: &ChatCompletionRequest,
|
||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||
tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat");
|
||
|
||
let url = format!("{}/chat/completions", self.base_url);
|
||
|
||
let mut body = self.build_request_body(request);
|
||
// 启用流式输出
|
||
body["stream"] = json!(true);
|
||
|
||
let mut req_builder = self
|
||
.client
|
||
.post(&url)
|
||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||
.header("Content-Type", "application/json")
|
||
.header("Accept", "text/event-stream");
|
||
|
||
for (key, value) in &self.extra_headers {
|
||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||
}
|
||
|
||
let resp = req_builder.json(&body).send().await.map_err(|err| {
|
||
format_transport_error_context(
|
||
&self.name,
|
||
&self.model_id,
|
||
&url,
|
||
self.llm_timeout_secs,
|
||
&err,
|
||
)
|
||
})?;
|
||
|
||
let status = resp.status();
|
||
if !status.is_success() {
|
||
let text = resp.text().await.unwrap_or_default();
|
||
return Err(format!("API error {}: {}", status, text).into());
|
||
}
|
||
|
||
let mut accumulator = StreamingAccumulator::new();
|
||
|
||
// 读取 SSE 流
|
||
let mut stream = resp.bytes_stream();
|
||
let mut buffer = String::new();
|
||
let mut done_received = false;
|
||
|
||
while let Some(chunk_result) = stream.next().await {
|
||
let chunk = chunk_result?;
|
||
let text = String::from_utf8_lossy(&chunk);
|
||
buffer.push_str(&text);
|
||
|
||
// 处理缓冲区中的完整行
|
||
while let Some(newline_pos) = buffer.find('\n') {
|
||
let line = buffer[..newline_pos].to_string();
|
||
buffer = buffer[newline_pos + 1..].to_string();
|
||
|
||
let line_trimmed = line.trim();
|
||
|
||
if line_trimmed.is_empty() || line_trimmed.starts_with(':') {
|
||
continue;
|
||
}
|
||
|
||
// SSE 格式: data: {...} 或 data:{...}(某些 API 如 139 云没有空格)
|
||
let data_opt = line_trimmed.strip_prefix("data: ")
|
||
.or_else(|| line_trimmed.strip_prefix("data:"));
|
||
|
||
if let Some(data) = data_opt {
|
||
if data == "[DONE]" {
|
||
// 流结束
|
||
done_received = true;
|
||
break;
|
||
}
|
||
|
||
// 解析 JSON
|
||
match serde_json::from_str::<Value>(data) {
|
||
Ok(json) => {
|
||
// 提取响应 ID
|
||
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
|
||
accumulator.set_response_id(id.to_string());
|
||
}
|
||
|
||
// 提取 choices
|
||
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
|
||
for choice in choices {
|
||
// 尝试从 delta 提取(标准 OpenAI 流式格式)
|
||
if let Some(delta) = choice.get("delta") {
|
||
// 提取内容增量
|
||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||
accumulator.add_content(content);
|
||
}
|
||
|
||
// 提取推理内容增量
|
||
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
||
accumulator.add_reasoning_content(reasoning);
|
||
}
|
||
|
||
// 提取工具调用增量
|
||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||
for tool_call in tool_calls {
|
||
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
|
||
|
||
let id = tool_call.get("id").and_then(|v| v.as_str());
|
||
let name = tool_call.get("function")
|
||
.and_then(|f| f.get("name"))
|
||
.and_then(|n| n.as_str());
|
||
let arguments = tool_call.get("function")
|
||
.and_then(|f| f.get("arguments"))
|
||
.and_then(|a| a.as_str());
|
||
|
||
accumulator.add_tool_call(index, id, name, arguments);
|
||
}
|
||
}
|
||
}
|
||
// 尝试从 message 提取(某些非标准 API 格式)
|
||
else if let Some(message) = choice.get("message") {
|
||
if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
|
||
accumulator.add_content(content);
|
||
}
|
||
if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) {
|
||
accumulator.add_reasoning_content(reasoning);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::debug!(
|
||
error = %e,
|
||
data = %data,
|
||
"Failed to parse SSE data"
|
||
);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if done_received {
|
||
break;
|
||
}
|
||
}
|
||
|
||
// 处理缓冲区中剩余的内容
|
||
for line in buffer.lines() {
|
||
let line_trimmed = line.trim();
|
||
if line_trimmed.is_empty() || line_trimmed.starts_with(':') {
|
||
continue;
|
||
}
|
||
|
||
// 同样支持 data: {...} 和 data:{...} 两种格式
|
||
let data_opt = line_trimmed.strip_prefix("data: ")
|
||
.or_else(|| line_trimmed.strip_prefix("data:"));
|
||
|
||
if let Some(data) = data_opt {
|
||
if data == "[DONE]" {
|
||
break;
|
||
}
|
||
|
||
if let Ok(json) = serde_json::from_str::<Value>(data) {
|
||
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
|
||
accumulator.set_response_id(id.to_string());
|
||
}
|
||
|
||
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
|
||
for choice in choices {
|
||
// 尝试从 delta 提取
|
||
if let Some(delta) = choice.get("delta") {
|
||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||
accumulator.add_content(content);
|
||
}
|
||
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
||
accumulator.add_reasoning_content(reasoning);
|
||
}
|
||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||
for tool_call in tool_calls {
|
||
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
|
||
let id = tool_call.get("id").and_then(|v| v.as_str());
|
||
let name = tool_call.get("function")
|
||
.and_then(|f| f.get("name"))
|
||
.and_then(|n| n.as_str());
|
||
let arguments = tool_call.get("function")
|
||
.and_then(|f| f.get("arguments"))
|
||
.and_then(|a| a.as_str());
|
||
accumulator.add_tool_call(index, id, name, arguments);
|
||
}
|
||
}
|
||
}
|
||
// 尝试从 message 提取(某些非标准 API 格式)
|
||
else if let Some(message) = choice.get("message") {
|
||
if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
|
||
accumulator.add_content(content);
|
||
}
|
||
if let Some(reasoning) = message.get("reasoning_content").and_then(|r| r.as_str()) {
|
||
accumulator.add_reasoning_content(reasoning);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
let response = accumulator.build_response(self.model_id.clone());
|
||
tracing::debug!(
|
||
content_len = response.content.len(),
|
||
tool_calls_count = response.tool_calls.len(),
|
||
has_reasoning = response.reasoning_content.is_some(),
|
||
"Streaming response built"
|
||
);
|
||
Ok(response)
|
||
}
|
||
|
||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||
let supports_images = self.supports_images();
|
||
let mut body = json!({
|
||
"model": self.model_id,
|
||
"messages": request.messages.iter().enumerate().map(|(i, m)| {
|
||
if m.role == "tool" {
|
||
json!({
|
||
"role": m.role,
|
||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
||
"tool_call_id": m.tool_call_id,
|
||
"name": m.name,
|
||
})
|
||
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
||
let mut message = json!({
|
||
"role": m.role,
|
||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
||
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||
calls.iter().map(|call| json!({
|
||
"id": call.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": call.name,
|
||
"arguments": self.serialize_tool_arguments(&call.arguments)
|
||
}
|
||
})).collect::<Vec<_>>()
|
||
})
|
||
});
|
||
|
||
if let Some(reasoning_content) = &m.reasoning_content {
|
||
message["reasoning_content"] = Value::String(reasoning_content.clone());
|
||
}
|
||
|
||
message
|
||
} else {
|
||
let mut message = json!({
|
||
"role": m.role,
|
||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i)
|
||
});
|
||
|
||
if m.role == "assistant" {
|
||
if let Some(reasoning_content) = &m.reasoning_content {
|
||
message["reasoning_content"] = Value::String(reasoning_content.clone());
|
||
}
|
||
}
|
||
|
||
message
|
||
}
|
||
}).collect::<Vec<_>>(),
|
||
});
|
||
|
||
// 只有配置了才添加 temperature,否则让模型使用默认值
|
||
if let Some(temp) = request.temperature.or(self.temperature) {
|
||
body["temperature"] = json!(temp);
|
||
}
|
||
|
||
// 只有配置了才添加 max_tokens
|
||
if let Some(tokens) = request.max_tokens.or(self.max_tokens) {
|
||
body["max_tokens"] = json!(tokens);
|
||
}
|
||
|
||
for (key, value) in self.request_model_extra() {
|
||
body[key] = value.clone();
|
||
}
|
||
|
||
if let Some(tools) = &request.tools {
|
||
body["tools"] = json!(tools);
|
||
}
|
||
|
||
body
|
||
}
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
struct OpenAIResponse {
|
||
id: String,
|
||
model: String,
|
||
choices: Vec<OpenAIChoice>,
|
||
#[serde(default)]
|
||
usage: OpenAIUsage,
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
struct OpenAIChoice {
|
||
message: OpenAIMessage,
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
struct OpenAIMessage {
|
||
#[serde(default)]
|
||
content: Option<String>,
|
||
#[serde(default)]
|
||
reasoning_content: Option<String>,
|
||
#[allow(dead_code)]
|
||
#[serde(default)]
|
||
name: Option<String>,
|
||
#[serde(default)]
|
||
tool_calls: Vec<OpenAIToolCall>,
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
struct OpenAIToolCall {
|
||
id: String,
|
||
#[serde(rename = "function")]
|
||
function: OAIFunction,
|
||
#[allow(dead_code)]
|
||
#[serde(default)]
|
||
index: Option<u32>,
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
struct OAIFunction {
|
||
name: String,
|
||
arguments: OAIFunctionArguments,
|
||
}
|
||
|
||
#[derive(Deserialize, Default)]
|
||
struct OpenAIUsage {
|
||
#[serde(default)]
|
||
prompt_tokens: u32,
|
||
#[serde(default)]
|
||
completion_tokens: u32,
|
||
#[serde(default)]
|
||
total_tokens: u32,
|
||
}
|
||
|
||
#[async_trait]
|
||
impl LLMProvider for OpenAIProvider {
|
||
async fn chat(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||
// 检查是否启用流式输出
|
||
if self.is_streaming_enabled() {
|
||
// 优先尝试流式输出
|
||
match self.chat_streaming(&request).await {
|
||
Ok(response) => return Ok(response),
|
||
Err(e) => {
|
||
tracing::debug!(
|
||
provider = %self.name,
|
||
model = %self.model_id,
|
||
error = %e,
|
||
"Streaming failed, falling back to non-streaming"
|
||
);
|
||
// 流式失败,回退到非流式
|
||
}
|
||
}
|
||
} else {
|
||
tracing::debug!(provider = %self.name, model = %self.model_id, "Streaming disabled, using non-streaming");
|
||
}
|
||
|
||
// 非流式回退实现
|
||
let url = format!("{}/chat/completions", self.base_url);
|
||
|
||
let body = self.build_request_body(&request);
|
||
|
||
// Debug: Log LLM request summary (only in debug builds)
|
||
#[cfg(debug_assertions)]
|
||
{
|
||
// Log messages summary
|
||
let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0);
|
||
tracing::debug!(msg_count = msg_count, "LLM request messages count");
|
||
|
||
// Log first 20 bytes of base64 images (don't log full base64)
|
||
if let Some(msgs) = body["messages"].as_array() {
|
||
for (i, msg) in msgs.iter().enumerate() {
|
||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||
for (j, item) in content.iter().enumerate() {
|
||
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
||
if let Some(url_str) = item
|
||
.get("image_url")
|
||
.and_then(|u| u.get("url"))
|
||
.and_then(|v| v.as_str())
|
||
{
|
||
let prefix: String = url_str.chars().take(20).collect();
|
||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
let mut req_builder = self
|
||
.client
|
||
.post(&url)
|
||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||
.header("Content-Type", "application/json");
|
||
|
||
for (key, value) in &self.extra_headers {
|
||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||
}
|
||
|
||
let resp = req_builder.json(&body).send().await.map_err(|err| {
|
||
let error_context = format_transport_error_context(
|
||
&self.name,
|
||
&self.model_id,
|
||
&url,
|
||
self.llm_timeout_secs,
|
||
&err,
|
||
);
|
||
tracing::error!(
|
||
provider = %self.name,
|
||
model = %self.model_id,
|
||
url = %url,
|
||
base_url = %self.base_url,
|
||
timeout_secs = self.llm_timeout_secs,
|
||
error = %error_context,
|
||
"OpenAI-compatible API transport request failed"
|
||
);
|
||
error_context
|
||
})?;
|
||
|
||
let status = resp.status();
|
||
let text = resp.text().await?;
|
||
|
||
// Debug: Log LLM response (only in debug builds)
|
||
if !status.is_success() {
|
||
tracing::error!(
|
||
provider = %self.name,
|
||
model = %self.model_id,
|
||
url = %url,
|
||
status = %status,
|
||
response_len = text.len(),
|
||
response_body = %text,
|
||
"OpenAI-compatible API request failed"
|
||
);
|
||
return Err(format!("API error {}: {}", status, text).into());
|
||
}
|
||
|
||
#[cfg(debug_assertions)]
|
||
{
|
||
let resp_preview: String = text.chars().take(100).collect();
|
||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "LLM response (first 100 chars shown)");
|
||
}
|
||
|
||
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
|
||
tracing::error!(
|
||
provider = %self.name,
|
||
model = %self.model_id,
|
||
url = %url,
|
||
error = %format_error_chain(&e),
|
||
response_len = text.len(),
|
||
response_body = %text,
|
||
"Failed to decode OpenAI-compatible API response"
|
||
);
|
||
format!("decode error: {} | body: {}", e, &text)
|
||
})?;
|
||
|
||
let content = openai_resp.choices[0]
|
||
.message
|
||
.content
|
||
.as_ref()
|
||
.unwrap_or(&String::new())
|
||
.clone();
|
||
|
||
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
|
||
.message
|
||
.tool_calls
|
||
.iter()
|
||
.map(|tc| ToolCall {
|
||
id: tc.id.clone(),
|
||
name: tc.function.name.clone(),
|
||
arguments: match &tc.function.arguments {
|
||
OAIFunctionArguments::Json(arguments) => arguments.clone(),
|
||
OAIFunctionArguments::String(arguments) => {
|
||
serde_json::from_str(arguments).unwrap_or(serde_json::Value::Null)
|
||
}
|
||
},
|
||
})
|
||
.collect();
|
||
|
||
Ok(ChatCompletionResponse {
|
||
id: openai_resp.id,
|
||
model: openai_resp.model,
|
||
content,
|
||
reasoning_content: openai_resp.choices[0].message.reasoning_content.clone(),
|
||
tool_calls,
|
||
usage: Usage {
|
||
prompt_tokens: openai_resp.usage.prompt_tokens,
|
||
completion_tokens: openai_resp.usage.completion_tokens,
|
||
total_tokens: openai_resp.usage.total_tokens,
|
||
},
|
||
})
|
||
}
|
||
|
||
fn ptype(&self) -> &str {
|
||
"openai"
|
||
}
|
||
|
||
fn name(&self) -> &str {
|
||
&self.name
|
||
}
|
||
|
||
fn model_id(&self) -> &str {
|
||
&self.model_id
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
use crate::providers::Message;
|
||
|
||
/// An assistant message with tool calls is emitted with a `tool_calls`
/// array whose arguments are a JSON-encoded string.
#[test]
fn test_build_request_body_includes_assistant_tool_calls() {
    let call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: json!({"expression": "1+1"}),
    };
    let assistant = Message {
        role: "assistant".to_string(),
        content: vec![ContentBlock::text("calling tool")],
        reasoning_content: None,
        tool_call_id: None,
        name: None,
        tool_calls: Some(vec![call]),
    };
    let request = ChatCompletionRequest {
        messages: vec![assistant],
        temperature: None,
        max_tokens: None,
        tools: None,
    };
    let provider = OpenAIProvider::new(
        "test".to_string(),
        "key".to_string(),
        "https://example.com/v1".to_string(),
        HashMap::new(),
        120,
        "gpt-test".to_string(),
        None,
        None,
        HashMap::new(),
    );

    let body = provider.build_request_body(&request);
    let emitted = body["messages"][0]["tool_calls"].as_array().unwrap();

    assert_eq!(emitted.len(), 1);
    let first = &emitted[0];
    assert_eq!(first["id"], "call_1");
    assert_eq!(first["type"], "function");
    assert_eq!(first["function"]["name"], "calculator");
    assert_eq!(first["function"]["arguments"], "{\"expression\":\"1+1\"}");
}
|
||
|
||
/// With `tool_call_arguments_json` enabled, arguments are emitted as a JSON
/// object and the switch itself is not forwarded.
#[test]
fn test_build_request_body_uses_json_tool_arguments_when_enabled() {
    let mut extra = HashMap::new();
    extra.insert("tool_call_arguments_json".to_string(), Value::Bool(true));
    let provider = OpenAIProvider::new(
        "test".to_string(),
        "key".to_string(),
        "https://example.com/v1".to_string(),
        HashMap::new(),
        120,
        "gpt-test".to_string(),
        None,
        None,
        extra,
    );

    let call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: json!({"expression": "1+1"}),
    };
    let assistant = Message {
        role: "assistant".to_string(),
        content: vec![ContentBlock::text("calling tool")],
        reasoning_content: None,
        tool_call_id: None,
        name: None,
        tool_calls: Some(vec![call]),
    };
    let request = ChatCompletionRequest {
        messages: vec![assistant],
        temperature: None,
        max_tokens: None,
        tools: None,
    };

    let body = provider.build_request_body(&request);
    let emitted = body["messages"][0]["tool_calls"].as_array().unwrap();

    assert_eq!(
        emitted[0]["function"]["arguments"],
        json!({"expression": "1+1"})
    );
    assert!(body.get("tool_call_arguments_json").is_none());
}
|
||
|
||
/// Arguments that already arrive as a JSON-encoded string are emitted as the
/// same string.
#[test]
fn test_build_request_body_preserves_raw_json_string_arguments() {
    let raw_arguments = "{\"expression\":\"1+1\"}";
    let call = ToolCall {
        id: "call_1".to_string(),
        name: "calculator".to_string(),
        arguments: Value::String(raw_arguments.to_string()),
    };
    let assistant = Message {
        role: "assistant".to_string(),
        content: vec![ContentBlock::text("calling tool")],
        reasoning_content: None,
        tool_call_id: None,
        name: None,
        tool_calls: Some(vec![call]),
    };
    let request = ChatCompletionRequest {
        messages: vec![assistant],
        temperature: None,
        max_tokens: None,
        tools: None,
    };
    let provider = OpenAIProvider::new(
        "test".to_string(),
        "key".to_string(),
        "https://example.com/v1".to_string(),
        HashMap::new(),
        120,
        "gpt-test".to_string(),
        None,
        None,
        HashMap::new(),
    );

    let body = provider.build_request_body(&request);
    let emitted = body["messages"][0]["tool_calls"].as_array().unwrap();

    assert_eq!(emitted[0]["function"]["arguments"], raw_arguments);
}
|
||
|
||
/// Internal `model_extra` switches stay out of the request body while other
/// extras are forwarded.
#[test]
fn test_build_request_body_omits_internal_model_extra_keys() {
    let mut extra = HashMap::new();
    extra.insert("tool_call_arguments_json".to_string(), Value::Bool(true));
    extra.insert(
        "mock_response_content".to_string(),
        Value::String("stub".to_string()),
    );
    extra.insert("parallel_tool_calls".to_string(), Value::Bool(true));

    let provider = OpenAIProvider::new(
        "test".to_string(),
        "key".to_string(),
        "https://example.com/v1".to_string(),
        HashMap::new(),
        120,
        "gpt-test".to_string(),
        None,
        None,
        extra,
    );
    let request = ChatCompletionRequest {
        messages: vec![Message::user("hello")],
        temperature: None,
        max_tokens: None,
        tools: None,
    };

    let body = provider.build_request_body(&request);

    for internal_key in ["tool_call_arguments_json", "mock_response_content"] {
        assert!(body.get(internal_key).is_none());
    }
    assert_eq!(body["parallel_tool_calls"], Value::Bool(true));
}
|
||
|
||
/// An assistant message's reasoning text is emitted as `reasoning_content`.
#[test]
fn test_build_request_body_includes_assistant_reasoning_content() {
    let assistant = Message {
        role: "assistant".to_string(),
        content: vec![ContentBlock::text("final answer")],
        reasoning_content: Some("step by step".to_string()),
        tool_call_id: None,
        name: None,
        tool_calls: None,
    };
    let request = ChatCompletionRequest {
        messages: vec![assistant],
        temperature: None,
        max_tokens: None,
        tools: None,
    };
    let provider = OpenAIProvider::new(
        "test".to_string(),
        "key".to_string(),
        "https://example.com/v1".to_string(),
        HashMap::new(),
        120,
        "gpt-test".to_string(),
        None,
        None,
        HashMap::new(),
    );

    let body = provider.build_request_body(&request);
    let emitted = body["messages"].as_array().unwrap();

    assert_eq!(emitted[0]["reasoning_content"], "step by step");
}
|
||
|
||
/// `reasoning_content` in a response message is deserialized.
#[test]
fn test_openai_response_parses_reasoning_content() {
    let payload = json!({
        "id": "resp_1",
        "model": "gpt-test",
        "choices": [{
            "message": {
                "content": "final answer",
                "reasoning_content": "hidden reasoning",
                "tool_calls": []
            }
        }],
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 5,
            "total_tokens": 15
        }
    });

    let response: OpenAIResponse = serde_json::from_value(payload).unwrap();
    let message = &response.choices[0].message;

    assert_eq!(
        message.reasoning_content.as_deref(),
        Some("hidden reasoning")
    );
}
|
||
|
||
/// Tool arguments delivered as a JSON object deserialize into the `Json`
/// variant.
#[test]
fn test_openai_response_parses_json_tool_arguments() {
    let payload = json!({
        "id": "resp_1",
        "model": "gpt-test",
        "choices": [{
            "message": {
                "content": "",
                "tool_calls": [{
                    "id": "call_1",
                    "function": {
                        "name": "scheduler_manage",
                        "arguments": {"action": "list"}
                    }
                }]
            }
        }],
        "usage": {
            "prompt_tokens": 1,
            "completion_tokens": 1,
            "total_tokens": 2
        }
    });

    let response: OpenAIResponse = serde_json::from_value(payload).unwrap();
    let arguments = &response.choices[0].message.tool_calls[0].function.arguments;

    if let OAIFunctionArguments::Json(value) = arguments {
        assert_eq!(value, &json!({"action": "list"}));
    } else {
        panic!("expected JSON tool arguments");
    }
}
|
||
|
||
/// Later fragments with an empty or missing id/name must not overwrite the
/// values supplied by the first fragment.
#[test]
fn test_streaming_accumulator_preserves_tool_call_id_with_empty_subsequent_chunks() {
    // Fragment sequence as (id, name, argument piece): the first carries the
    // id and name, the fourth carries an empty id, the rest carry neither.
    let fragments: [(Option<&str>, Option<&str>, &str); 5] = [
        (Some("call_abc123"), Some("memory_search"), "{\"action\":\""),
        (None, None, "list"),
        (None, None, "\""),
        (Some(""), None, ", \"limit\": 20"),
        (None, None, "}"),
    ];

    let mut accumulator = StreamingAccumulator::new();
    for (id, name, piece) in fragments {
        accumulator.add_tool_call(0, id, name, Some(piece));
    }

    let response = accumulator.build_response("test-model".to_string());

    assert_eq!(response.tool_calls.len(), 1);
    let call = &response.tool_calls[0];
    assert_eq!(call.id, "call_abc123");
    assert_eq!(call.name, "memory_search");
    assert_eq!(call.arguments, json!({"action":"list", "limit": 20}));
}
|
||
|
||
/// Two tool calls streamed under different indices must both appear in the
/// built response, in index order, each with its own id and name.
#[test]
fn test_streaming_accumulator_handles_multiple_tool_calls() {
    let mut accumulator = StreamingAccumulator::new();

    // Each call arrives as a single chunk carrying id, name and arguments.
    accumulator.add_tool_call(0, Some("call_1"), Some("calculator"), Some("{\"expr\": \"1+1\"}"));
    accumulator.add_tool_call(1, Some("call_2"), Some("get_time"), Some("{}"));

    let response = accumulator.build_response("test-model".to_string());

    assert_eq!(response.tool_calls.len(), 2);

    let expected = [("call_1", "calculator"), ("call_2", "get_time")];
    for (call, &(id, name)) in response.tool_calls.iter().zip(expected.iter()) {
        assert_eq!(call.id, id);
        assert_eq!(call.name, name);
    }
}
|
||
|
||
/// With no `supported_content_types` entry in the trailing map argument, the
/// provider reports that images are supported.
#[test]
fn test_supports_images_default_true() {
    let provider = OpenAIProvider::new(
        String::from("test"),
        String::from("key"),
        String::from("https://example.com/v1"),
        HashMap::new(),
        120,
        String::from("gpt-test"),
        None,
        None,
        HashMap::new(),
    );

    let images_supported = provider.supports_images();
    assert!(images_supported);
}
|
||
|
||
/// When the trailing map argument lists only `"text"` under
/// `supported_content_types`, the provider reports images as unsupported.
#[test]
fn test_supports_images_disabled_via_config() {
    let mut text_only = HashMap::new();
    text_only.insert("supported_content_types".to_string(), json!(["text"]));

    let provider = OpenAIProvider::new(
        String::from("test"),
        String::from("key"),
        String::from("https://example.com/v1"),
        HashMap::new(),
        120,
        String::from("gpt-test"),
        None,
        None,
        text_only,
    );

    let images_supported = provider.supports_images();
    assert!(!images_supported);
}
|
||
|
||
/// With the first argument `false`, an image block between two text blocks is
/// expected to be replaced by a textual notice placed last in the output.
#[test]
fn test_convert_content_blocks_converts_images_to_notice_when_disabled() {
    let blocks = vec![
        ContentBlock::text("hello"),
        ContentBlock::image_url("data:image/png;base64,abc123"),
        ContentBlock::text("world"),
    ];

    let converted = convert_content_blocks(false, "test", "test-model", &blocks, 0);

    // The output must be a JSON array: two text blocks plus one notice block.
    let items = converted.as_array().unwrap();
    assert_eq!(items.len(), 3);

    // The notice is itself a text block.
    let notice = items[2].as_object().unwrap();
    assert_eq!(notice["type"], "text");

    // Its text names the failed image and the reason.
    let message = notice["text"].as_str().unwrap();
    assert!(message.contains("[系统提示] 以下图片未能成功入模"));
    assert!(message.contains("第 1 张图片"));
    assert!(message.contains("当前模型不支持图片输入"));
}
|
||
|
||
/// With the first argument `true`, text and image blocks are both kept, in
/// their original order.
#[test]
fn test_convert_content_blocks_keeps_images_when_enabled() {
    let blocks = vec![
        ContentBlock::text("hello"),
        ContentBlock::image_url("data:image/png;base64,abc123"),
    ];

    let converted = convert_content_blocks(true, "test", "test-model", &blocks, 0);

    // The output must be a JSON array holding the text block, then the image.
    let items = converted.as_array().unwrap();
    assert_eq!(items.len(), 2);

    let expected_types = ["text", "image_url"];
    for (item, &expected_type) in items.iter().zip(expected_types.iter()) {
        assert_eq!(item["type"], expected_type);
    }
}
|
||
|
||
/// `supported_content_types` must not appear in the request body, while other
/// entries of the same map (here `custom_param`) are still passed through.
#[test]
fn test_build_request_body_omits_supported_content_types_from_api() {
    let mut extra = HashMap::new();
    extra.insert("supported_content_types".to_string(), json!(["text"]));
    extra.insert("custom_param".to_string(), json!("value"));

    let provider = OpenAIProvider::new(
        String::from("test"),
        String::from("key"),
        String::from("https://example.com/v1"),
        HashMap::new(),
        120,
        String::from("gpt-test"),
        None,
        None,
        extra,
    );

    let request = ChatCompletionRequest {
        messages: vec![Message::user("hello")],
        temperature: None,
        max_tokens: None,
        tools: None,
    };

    let body = provider.build_request_body(&request);

    // supported_content_types must not be sent to the API.
    let leaked = body.get("supported_content_types");
    assert!(leaked.is_none());

    // custom_param must be preserved.
    assert_eq!(body["custom_param"], json!("value"));
}
|
||
}
|