PicoBot/src/providers/openai.rs

351 lines
11 KiB
Rust

use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage;
use std::sync::Arc;
use crate::storage::Storage;
/// Serializes message content blocks into the JSON value placed in a chat
/// message's `content` field.
///
/// A slice holding exactly one text block collapses to a plain JSON string.
/// Every other input (several blocks, a lone image, or an empty slice)
/// becomes a JSON array of typed parts.
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
    // Single text block: emit the bare string form.
    if let [ContentBlock::Text { text }] = blocks {
        return Value::String(text.clone());
    }

    let mut parts = Vec::with_capacity(blocks.len());
    for block in blocks {
        let part = match block {
            ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
            ContentBlock::ImageUrl { image_url } => {
                json!({ "type": "image_url", "image_url": { "url": image_url.url } })
            }
        };
        parts.push(part);
    }
    Value::Array(parts)
}
/// Chat-completion provider that POSTs requests to `{base_url}/chat/completions`.
pub struct OpenAIProvider {
/// HTTP client used for every request.
client: Client,
/// Provider instance name; returned by `name()` and passed to storage when a call is recorded.
name: String,
/// Sent as `Authorization: Bearer <api_key>`.
api_key: String,
/// URL prefix to which `/chat/completions` is appended.
base_url: String,
/// Additional headers attached to every request.
extra_headers: HashMap<String, String>,
/// Value of the `model` field in the request body.
model_id: String,
/// Fallback temperature when the request carries none (0.7 is used if both are absent).
temperature: Option<f32>,
/// Fallback `max_tokens` when the request carries none.
max_tokens: Option<u32>,
/// Extra top-level body fields; written after the standard ones, so they override same-named keys.
model_extra: HashMap<String, serde_json::Value>,
/// Optional sink for call records; `None` until `set_storage` is called.
storage: Option<Arc<Storage>>,
}
impl OpenAIProvider {
pub fn new(
name: String,
api_key: String,
base_url: String,
extra_headers: HashMap<String, String>,
model_id: String,
temperature: Option<f32>,
max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>,
) -> Self {
Self {
client: Client::new(),
name,
api_key,
base_url,
extra_headers,
model_id,
temperature,
max_tokens,
model_extra,
storage: None,
}
}
pub fn set_storage(&mut self, storage: Arc<Storage>) {
self.storage = Some(storage);
}
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
let mut body = json!({
"model": self.model_id,
"messages": request.messages.iter().map(|m| {
if m.role == "tool" {
json!({
"role": m.role,
"content": convert_content_blocks(&m.content),
"tool_call_id": m.tool_call_id,
"name": m.name,
})
} else if m.role == "assistant" && m.tool_calls.as_ref().map_or(false, |c| !c.is_empty()) {
json!({
"role": m.role,
"content": convert_content_blocks(&m.content),
"tool_calls": m.tool_calls.as_ref().map(|calls| {
calls.iter().map(|call| json!({
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": serde_json::to_string(&call.arguments).unwrap_or_else(|_| "null".to_string())
}
})).collect::<Vec<_>>()
})
})
} else {
json!({
"role": m.role,
"content": convert_content_blocks(&m.content)
})
}
}).collect::<Vec<_>>(),
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
"max_tokens": request.max_tokens.or(self.max_tokens),
});
for (key, value) in &self.model_extra {
body[key] = value.clone();
}
if let Some(tools) = &request.tools {
body["tools"] = json!(tools);
}
body
}
}
/// Top-level JSON body decoded from a successful `/chat/completions` reply.
#[derive(Deserialize)]
struct OpenAIResponse {
/// Response identifier; required, so decoding fails when it is absent.
id: String,
/// Model name reported by the server; required.
model: String,
/// Completion candidates; required.
choices: Vec<OpenAIChoice>,
/// Token accounting; falls back to `OpenAIUsage::default()` when the key is missing.
#[serde(default)]
usage: OpenAIUsage,
}
/// One entry of the response's `choices` array; only `message` is decoded.
#[derive(Deserialize)]
struct OpenAIChoice {
/// The message produced for this choice.
message: OpenAIMessage,
}
#[derive(Deserialize)]
struct OpenAIMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
name: Option<String>,
#[serde(default)]
tool_calls: Vec<OpenAIToolCall>,
}
/// A single tool call as it appears in a response message's `tool_calls` array.
#[derive(Deserialize)]
struct OpenAIToolCall {
/// Call identifier; required.
id: String,
/// Function name and arguments. The rename maps to the same key (`function`), so the wire name is unchanged.
#[serde(rename = "function")]
function: OAIFunction,
/// Optional position value; `None` when absent. NOTE(review): not read by the code visible in this file.
#[serde(default)]
index: Option<u32>,
}
/// The `function` object inside a tool call.
#[derive(Deserialize)]
struct OAIFunction {
/// Name of the function to invoke; required.
name: String,
/// Arguments as a string; `chat` parses this as JSON when building the `ToolCall`.
arguments: String,
}
/// Token usage counters from the response; every counter defaults to 0 when its key is missing.
#[derive(Deserialize, Default)]
struct OpenAIUsage {
/// Tokens counted for the prompt.
#[serde(default)]
prompt_tokens: u32,
/// Tokens counted for the completion.
#[serde(default)]
completion_tokens: u32,
/// Total token count as reported by the server (not recomputed here).
#[serde(default)]
total_tokens: u32,
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
async fn chat(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
let start = std::time::Instant::now();
let url = format!("{}/chat/completions", self.base_url);
let body = self.build_request_body(&request);
// Debug: Log LLM request summary (only in debug builds)
#[cfg(debug_assertions)]
{
// Log messages summary
let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0);
tracing::debug!(msg_count = msg_count, "LLM request messages count");
// Log first 20 bytes of base64 images (don't log full base64)
if let Some(msgs) = body["messages"].as_array() {
for (i, msg) in msgs.iter().enumerate() {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for (j, item) in content.iter().enumerate() {
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
}
}
}
}
}
}
let mut req_builder = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json");
for (key, value) in &self.extra_headers {
req_builder = req_builder.header(key.as_str(), value.as_str());
}
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await?;
let status = resp.status();
let text = resp.text().await?;
tracing::debug!(status = %status, resp_body = %text, "LLM response");
if !status.is_success() {
let error = format!("API error {}: {}", status, text);
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), Some(&error),
start.elapsed().as_millis() as u64,
).await;
}
return Err(error.into());
}
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
.map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp = text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await;
});
}
err_msg
})?;
let content = openai_resp.choices[0]
.message
.content
.as_ref()
.unwrap_or(&String::new())
.clone();
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
.message
.tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
})
.collect();
let response = ChatCompletionResponse {
id: openai_resp.id,
model: openai_resp.model,
content,
tool_calls,
usage: Usage {
prompt_tokens: openai_resp.usage.prompt_tokens,
completion_tokens: openai_resp.usage.completion_tokens,
total_tokens: openai_resp.usage.total_tokens,
},
};
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), None,
start.elapsed().as_millis() as u64,
).await;
}
Ok(response)
}
fn ptype(&self) -> &str {
"openai"
}
fn name(&self) -> &str {
&self.name
}
fn model_id(&self) -> &str {
&self.model_id
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Message;

    /// Provider with no extra headers, no sampling defaults and no storage.
    fn sample_provider() -> OpenAIProvider {
        OpenAIProvider::new(
            "test".to_string(),
            "key".to_string(),
            "https://example.com/v1".to_string(),
            HashMap::new(),
            "gpt-test".to_string(),
            None,
            None,
            HashMap::new(),
        )
    }

    /// An assistant message that carries tool calls must serialize them in
    /// the `tool_calls` array with string-encoded arguments.
    #[test]
    fn test_build_request_body_includes_assistant_tool_calls() {
        let calculator_call = ToolCall {
            id: "call_1".to_string(),
            name: "calculator".to_string(),
            arguments: json!({"expression": "1+1"}),
        };
        let assistant_turn = Message {
            role: "assistant".to_string(),
            content: vec![ContentBlock::text("calling tool")],
            tool_call_id: None,
            name: None,
            tool_calls: Some(vec![calculator_call]),
        };
        let request = ChatCompletionRequest {
            messages: vec![assistant_turn],
            temperature: None,
            max_tokens: None,
            tools: None,
        };

        let body = sample_provider().build_request_body(&request);

        let messages = body["messages"].as_array().unwrap();
        let encoded_calls = messages[0]["tool_calls"].as_array().unwrap();
        assert_eq!(encoded_calls.len(), 1);

        let first = &encoded_calls[0];
        assert_eq!(first["id"], "call_1");
        assert_eq!(first["type"], "function");
        assert_eq!(first["function"]["name"], "calculator");
        assert_eq!(first["function"]["arguments"], "{\"expression\":\"1+1\"}");
    }
}