Byte index slicing like `&text[..100.min(text.len())]` panics when the byte index falls inside a multi-byte UTF-8 character (e.g., Chinese). Changed to `text.chars().take(100).collect::<String>()` for safe character-based truncation.
245 lines · 7.3 KiB · Rust
use async_trait::async_trait;
|
|
use reqwest::Client;
|
|
use serde::Deserialize;
|
|
use serde_json::{json, Value};
|
|
use std::collections::HashMap;
|
|
|
|
use crate::bus::message::ContentBlock;
|
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
|
use super::traits::Usage;
|
|
|
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
|
if blocks.len() == 1 {
|
|
if let ContentBlock::Text { text } = &blocks[0] {
|
|
return Value::String(text.clone());
|
|
}
|
|
}
|
|
Value::Array(blocks.iter().map(|b| match b {
|
|
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
|
ContentBlock::ImageUrl { image_url } => {
|
|
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
|
}
|
|
}).collect())
|
|
}
|
|
|
|
/// LLM provider speaking the OpenAI chat-completions wire protocol.
///
/// Works against any OpenAI-compatible endpoint: the target is selected by
/// `base_url`, with optional extra headers for gateways or proxies.
pub struct OpenAIProvider {
    // HTTP client reused across requests.
    client: Client,
    // Human-readable instance name (returned by `name()`).
    name: String,
    // Bearer token sent in the `Authorization` header.
    api_key: String,
    // API root, e.g. `https://api.openai.com/v1`; `/chat/completions` is appended.
    base_url: String,
    // Additional headers added to every request.
    extra_headers: HashMap<String, String>,
    // Model identifier sent as `"model"` in the request body.
    model_id: String,
    // Default sampling temperature, used when the request does not set one.
    temperature: Option<f32>,
    // Default max_tokens, used when the request does not set one.
    max_tokens: Option<u32>,
    // Arbitrary extra top-level fields merged into the request body.
    model_extra: HashMap<String, serde_json::Value>,
}
|
|
|
|
impl OpenAIProvider {
    /// Construct a provider for an OpenAI-compatible endpoint.
    ///
    /// All configuration is stored as-is; a fresh `reqwest::Client` is
    /// created internally. `temperature` and `max_tokens` act as defaults
    /// that an individual request can override.
    pub fn new(
        name: String,
        api_key: String,
        base_url: String,
        extra_headers: HashMap<String, String>,
        model_id: String,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        model_extra: HashMap<String, serde_json::Value>,
    ) -> Self {
        Self {
            client: Client::new(),
            name,
            api_key,
            base_url,
            extra_headers,
            model_id,
            temperature,
            max_tokens,
            model_extra,
        }
    }
}
|
|
|
|
/// Top-level chat-completions response body.
#[derive(Deserialize)]
struct OpenAIResponse {
    id: String,
    model: String,
    choices: Vec<OpenAIChoice>,
    // Some OpenAI-compatible servers omit `usage`; default to all zeros.
    #[serde(default)]
    usage: OpenAIUsage,
}
|
|
|
|
/// One completion choice; only the message is consumed here.
#[derive(Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
}
|
|
|
|
/// Assistant message inside a choice.
#[derive(Deserialize)]
struct OpenAIMessage {
    // Null/absent when the model responds with tool calls only.
    #[serde(default)]
    content: Option<String>,
    // Parsed for completeness; not read by `chat()`.
    #[serde(default)]
    name: Option<String>,
    // Empty when the model produced plain text.
    #[serde(default)]
    tool_calls: Vec<OpenAIToolCall>,
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIToolCall {
|
|
id: String,
|
|
#[serde(rename = "function")]
|
|
function: OAIFunction,
|
|
#[serde(default)]
|
|
index: Option<u32>,
|
|
}
|
|
|
|
/// Function payload of a tool call: the tool name plus its arguments as a
/// raw JSON string (decoded later in `chat()`).
#[derive(Deserialize)]
struct OAIFunction {
    name: String,
    arguments: String,
}
|
|
|
|
/// Token accounting; every field defaults to 0 when absent from the reply.
#[derive(Deserialize, Default)]
struct OpenAIUsage {
    #[serde(default)]
    prompt_tokens: u32,
    #[serde(default)]
    completion_tokens: u32,
    #[serde(default)]
    total_tokens: u32,
}
|
|
|
|
#[async_trait]
|
|
impl LLMProvider for OpenAIProvider {
|
|
async fn chat(
|
|
&self,
|
|
request: ChatCompletionRequest,
|
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
|
let url = format!("{}/chat/completions", self.base_url);
|
|
|
|
let mut body = json!({
|
|
"model": self.model_id,
|
|
"messages": request.messages.iter().map(|m| {
|
|
if m.role == "tool" {
|
|
json!({
|
|
"role": m.role,
|
|
"content": convert_content_blocks(&m.content),
|
|
"tool_call_id": m.tool_call_id,
|
|
"name": m.name,
|
|
})
|
|
} else {
|
|
json!({
|
|
"role": m.role,
|
|
"content": convert_content_blocks(&m.content)
|
|
})
|
|
}
|
|
}).collect::<Vec<_>>(),
|
|
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
|
"max_tokens": request.max_tokens.or(self.max_tokens),
|
|
});
|
|
|
|
// Add model extra fields
|
|
for (key, value) in &self.model_extra {
|
|
body[key] = value.clone();
|
|
}
|
|
|
|
if let Some(tools) = &request.tools {
|
|
body["tools"] = json!(tools);
|
|
}
|
|
|
|
// Debug: Log LLM request summary (only in debug builds)
|
|
#[cfg(debug_assertions)]
|
|
{
|
|
// Log messages summary
|
|
let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0);
|
|
tracing::debug!(msg_count = msg_count, "LLM request messages count");
|
|
|
|
// Log first 20 bytes of base64 images (don't log full base64)
|
|
if let Some(msgs) = body["messages"].as_array() {
|
|
for (i, msg) in msgs.iter().enumerate() {
|
|
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
|
for (j, item) in content.iter().enumerate() {
|
|
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
|
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
|
let prefix: String = url_str.chars().take(20).collect();
|
|
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut req_builder = self
|
|
.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("Content-Type", "application/json");
|
|
|
|
for (key, value) in &self.extra_headers {
|
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
|
}
|
|
|
|
let resp = req_builder.json(&body).send().await?;
|
|
|
|
let status = resp.status();
|
|
let text = resp.text().await?;
|
|
|
|
// Debug: Log LLM response (only in debug builds)
|
|
#[cfg(debug_assertions)]
|
|
{
|
|
let resp_preview: String = text.chars().take(100).collect();
|
|
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
|
|
}
|
|
|
|
if !status.is_success() {
|
|
return Err(format!("API error {}: {}", status, text).into());
|
|
}
|
|
|
|
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
|
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
|
|
|
|
let content = openai_resp.choices[0]
|
|
.message
|
|
.content
|
|
.as_ref()
|
|
.unwrap_or(&String::new())
|
|
.clone();
|
|
|
|
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
|
|
.message
|
|
.tool_calls
|
|
.iter()
|
|
.map(|tc| ToolCall {
|
|
id: tc.id.clone(),
|
|
name: tc.function.name.clone(),
|
|
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
|
})
|
|
.collect();
|
|
|
|
Ok(ChatCompletionResponse {
|
|
id: openai_resp.id,
|
|
model: openai_resp.model,
|
|
content,
|
|
tool_calls,
|
|
usage: Usage {
|
|
prompt_tokens: openai_resp.usage.prompt_tokens,
|
|
completion_tokens: openai_resp.usage.completion_tokens,
|
|
total_tokens: openai_resp.usage.total_tokens,
|
|
},
|
|
})
|
|
}
|
|
|
|
fn ptype(&self) -> &str {
|
|
"openai"
|
|
}
|
|
|
|
fn name(&self) -> &str {
|
|
&self.name
|
|
}
|
|
|
|
fn model_id(&self) -> &str {
|
|
&self.model_id
|
|
}
|
|
}
|