629 lines
20 KiB
Rust
629 lines
20 KiB
Rust
use async_trait::async_trait;
|
|
use reqwest::Client;
|
|
use serde::Deserialize;
|
|
use serde_json::{json, Value};
|
|
use std::collections::HashMap;
|
|
use std::time::Duration;
|
|
|
|
use crate::bus::message::ContentBlock;
|
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
|
use super::traits::Usage;
|
|
|
|
/// `model_extra` keys that configure this provider itself and are stripped before
/// `model_extra` is merged into the request body.
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
|
|
|
/// Renders `error` followed by every error in its `source()` chain, outermost first,
/// joining them with a newline and a `caused by: ` prefix.
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    // Walk the chain lazily: each element yields its own `source()` until it is `None`.
    let chain: Vec<String> = std::iter::successors(Some(error), |&err| err.source())
        .map(|err| err.to_string())
        .collect();

    chain.join("\ncaused by: ")
}
|
|
|
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
|
if blocks.len() == 1 {
|
|
if let ContentBlock::Text { text } = &blocks[0] {
|
|
return Value::String(text.clone());
|
|
}
|
|
}
|
|
Value::Array(blocks.iter().map(|b| match b {
|
|
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
|
ContentBlock::ImageUrl { image_url } => {
|
|
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
|
}
|
|
}).collect())
|
|
}
|
|
|
|
pub struct OpenAIProvider {
|
|
client: Client,
|
|
name: String,
|
|
api_key: String,
|
|
base_url: String,
|
|
extra_headers: HashMap<String, String>,
|
|
llm_timeout_secs: u64,
|
|
model_id: String,
|
|
temperature: Option<f32>,
|
|
max_tokens: Option<u32>,
|
|
model_extra: HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(untagged)]
|
|
enum OAIFunctionArguments {
|
|
Json(Value),
|
|
String(String),
|
|
}
|
|
|
|
impl OpenAIProvider {
|
|
pub fn new(
|
|
name: String,
|
|
api_key: String,
|
|
base_url: String,
|
|
extra_headers: HashMap<String, String>,
|
|
llm_timeout_secs: u64,
|
|
model_id: String,
|
|
temperature: Option<f32>,
|
|
max_tokens: Option<u32>,
|
|
model_extra: HashMap<String, serde_json::Value>,
|
|
) -> Self {
|
|
let client = Client::builder()
|
|
.timeout(Duration::from_secs(llm_timeout_secs))
|
|
.build()
|
|
.unwrap_or_else(|_| Client::new());
|
|
|
|
Self {
|
|
client,
|
|
name,
|
|
api_key,
|
|
base_url,
|
|
extra_headers,
|
|
llm_timeout_secs,
|
|
model_id,
|
|
temperature,
|
|
max_tokens,
|
|
model_extra,
|
|
}
|
|
}
|
|
|
|
fn uses_json_tool_arguments(&self) -> bool {
|
|
self.model_extra
|
|
.get("tool_call_arguments_json")
|
|
.and_then(|value| value.as_bool())
|
|
.unwrap_or(false)
|
|
}
|
|
|
|
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
|
match arguments {
|
|
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
|
_ => arguments.clone(),
|
|
}
|
|
}
|
|
|
|
fn serialize_tool_arguments(&self, arguments: &Value) -> Value {
|
|
let normalized = self.normalize_tool_arguments(arguments);
|
|
|
|
if self.uses_json_tool_arguments() {
|
|
normalized
|
|
} else {
|
|
match normalized {
|
|
Value::String(raw) => Value::String(raw),
|
|
value => Value::String(
|
|
serde_json::to_string(&value).unwrap_or_else(|_| "null".to_string()),
|
|
),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn request_model_extra(&self) -> impl Iterator<Item = (&String, &Value)> {
|
|
self.model_extra.iter().filter(|(key, _)| {
|
|
!INTERNAL_MODEL_EXTRA_KEYS.iter().any(|internal| internal == &key.as_str())
|
|
})
|
|
}
|
|
|
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
|
let mut body = json!({
|
|
"model": self.model_id,
|
|
"messages": request.messages.iter().map(|m| {
|
|
if m.role == "tool" {
|
|
json!({
|
|
"role": m.role,
|
|
"content": convert_content_blocks(&m.content),
|
|
"tool_call_id": m.tool_call_id,
|
|
"name": m.name,
|
|
})
|
|
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
|
let mut message = json!({
|
|
"role": m.role,
|
|
"content": convert_content_blocks(&m.content),
|
|
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
|
calls.iter().map(|call| json!({
|
|
"id": call.id,
|
|
"type": "function",
|
|
"function": {
|
|
"name": call.name,
|
|
"arguments": self.serialize_tool_arguments(&call.arguments)
|
|
}
|
|
})).collect::<Vec<_>>()
|
|
})
|
|
});
|
|
|
|
if let Some(reasoning_content) = &m.reasoning_content {
|
|
message["reasoning_content"] = Value::String(reasoning_content.clone());
|
|
}
|
|
|
|
message
|
|
} else {
|
|
let mut message = json!({
|
|
"role": m.role,
|
|
"content": convert_content_blocks(&m.content)
|
|
});
|
|
|
|
if m.role == "assistant" {
|
|
if let Some(reasoning_content) = &m.reasoning_content {
|
|
message["reasoning_content"] = Value::String(reasoning_content.clone());
|
|
}
|
|
}
|
|
|
|
message
|
|
}
|
|
}).collect::<Vec<_>>(),
|
|
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
|
"max_tokens": request.max_tokens.or(self.max_tokens),
|
|
});
|
|
|
|
for (key, value) in self.request_model_extra() {
|
|
body[key] = value.clone();
|
|
}
|
|
|
|
if let Some(tools) = &request.tools {
|
|
body["tools"] = json!(tools);
|
|
}
|
|
|
|
body
|
|
}
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIResponse {
|
|
id: String,
|
|
model: String,
|
|
choices: Vec<OpenAIChoice>,
|
|
#[serde(default)]
|
|
usage: OpenAIUsage,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIChoice {
|
|
message: OpenAIMessage,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIMessage {
|
|
#[serde(default)]
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
reasoning_content: Option<String>,
|
|
#[allow(dead_code)]
|
|
#[serde(default)]
|
|
name: Option<String>,
|
|
#[serde(default)]
|
|
tool_calls: Vec<OpenAIToolCall>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIToolCall {
|
|
id: String,
|
|
#[serde(rename = "function")]
|
|
function: OAIFunction,
|
|
#[allow(dead_code)]
|
|
#[serde(default)]
|
|
index: Option<u32>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OAIFunction {
|
|
name: String,
|
|
arguments: OAIFunctionArguments,
|
|
}
|
|
|
|
#[derive(Deserialize, Default)]
|
|
struct OpenAIUsage {
|
|
#[serde(default)]
|
|
prompt_tokens: u32,
|
|
#[serde(default)]
|
|
completion_tokens: u32,
|
|
#[serde(default)]
|
|
total_tokens: u32,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LLMProvider for OpenAIProvider {
|
|
async fn chat(
|
|
&self,
|
|
request: ChatCompletionRequest,
|
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
|
let url = format!("{}/chat/completions", self.base_url);
|
|
|
|
let body = self.build_request_body(&request);
|
|
|
|
// Debug: Log LLM request summary (only in debug builds)
|
|
#[cfg(debug_assertions)]
|
|
{
|
|
// Log messages summary
|
|
let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0);
|
|
tracing::debug!(msg_count = msg_count, "LLM request messages count");
|
|
|
|
// Log first 20 bytes of base64 images (don't log full base64)
|
|
if let Some(msgs) = body["messages"].as_array() {
|
|
for (i, msg) in msgs.iter().enumerate() {
|
|
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
|
for (j, item) in content.iter().enumerate() {
|
|
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
|
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
|
let prefix: String = url_str.chars().take(20).collect();
|
|
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut req_builder = self
|
|
.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("Content-Type", "application/json");
|
|
|
|
for (key, value) in &self.extra_headers {
|
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
|
}
|
|
|
|
let resp = req_builder.json(&body).send().await?;
|
|
|
|
let status = resp.status();
|
|
let text = resp.text().await?;
|
|
|
|
// Debug: Log LLM response (only in debug builds)
|
|
if !status.is_success() {
|
|
tracing::error!(
|
|
provider = %self.name,
|
|
model = %self.model_id,
|
|
url = %url,
|
|
status = %status,
|
|
response_len = text.len(),
|
|
response_body = %text,
|
|
"OpenAI-compatible API request failed"
|
|
);
|
|
return Err(format!("API error {}: {}", status, text).into());
|
|
}
|
|
|
|
#[cfg(debug_assertions)]
|
|
{
|
|
let resp_preview: String = text.chars().take(100).collect();
|
|
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "LLM response (first 100 chars shown)");
|
|
}
|
|
|
|
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
|
|
tracing::error!(
|
|
provider = %self.name,
|
|
model = %self.model_id,
|
|
url = %url,
|
|
error = %format_error_chain(&e),
|
|
response_len = text.len(),
|
|
response_body = %text,
|
|
"Failed to decode OpenAI-compatible API response"
|
|
);
|
|
format!("decode error: {} | body: {}", e, &text)
|
|
})?;
|
|
|
|
let content = openai_resp.choices[0]
|
|
.message
|
|
.content
|
|
.as_ref()
|
|
.unwrap_or(&String::new())
|
|
.clone();
|
|
|
|
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
|
|
.message
|
|
.tool_calls
|
|
.iter()
|
|
.map(|tc| ToolCall {
|
|
id: tc.id.clone(),
|
|
name: tc.function.name.clone(),
|
|
arguments: match &tc.function.arguments {
|
|
OAIFunctionArguments::Json(arguments) => arguments.clone(),
|
|
OAIFunctionArguments::String(arguments) => {
|
|
serde_json::from_str(arguments).unwrap_or(serde_json::Value::Null)
|
|
}
|
|
},
|
|
})
|
|
.collect();
|
|
|
|
Ok(ChatCompletionResponse {
|
|
id: openai_resp.id,
|
|
model: openai_resp.model,
|
|
content,
|
|
reasoning_content: openai_resp.choices[0].message.reasoning_content.clone(),
|
|
tool_calls,
|
|
usage: Usage {
|
|
prompt_tokens: openai_resp.usage.prompt_tokens,
|
|
completion_tokens: openai_resp.usage.completion_tokens,
|
|
total_tokens: openai_resp.usage.total_tokens,
|
|
},
|
|
})
|
|
}
|
|
|
|
fn ptype(&self) -> &str {
|
|
"openai"
|
|
}
|
|
|
|
fn name(&self) -> &str {
|
|
&self.name
|
|
}
|
|
|
|
fn model_id(&self) -> &str {
|
|
&self.model_id
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Message;

    /// Provider aimed at a dummy endpoint, with no sampling defaults and the given `model_extra`.
    fn test_provider(model_extra: HashMap<String, Value>) -> OpenAIProvider {
        OpenAIProvider::new(
            "test".to_string(),
            "key".to_string(),
            "https://example.com/v1".to_string(),
            HashMap::new(),
            120,
            "gpt-test".to_string(),
            None,
            None,
            model_extra,
        )
    }

    /// Request carrying `messages` with no per-request overrides and no tools.
    fn request_of(messages: Vec<Message>) -> ChatCompletionRequest {
        ChatCompletionRequest {
            messages,
            temperature: None,
            max_tokens: None,
            tools: None,
        }
    }

    /// Assistant message issuing a single `calculator` tool call with `arguments`.
    fn calculator_call(arguments: Value) -> Message {
        Message {
            role: "assistant".to_string(),
            content: vec![ContentBlock::text("calling tool")],
            reasoning_content: None,
            tool_call_id: None,
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_1".to_string(),
                name: "calculator".to_string(),
                arguments,
            }]),
        }
    }

    /// Builds the request body and returns the `tool_calls` array of its first message.
    fn first_message_tool_calls(
        provider: &OpenAIProvider,
        request: &ChatCompletionRequest,
    ) -> Vec<Value> {
        let body = provider.build_request_body(request);
        body["messages"][0]["tool_calls"]
            .as_array()
            .expect("first message should carry tool_calls")
            .clone()
    }

    #[test]
    fn test_build_request_body_includes_assistant_tool_calls() {
        let provider = test_provider(HashMap::new());
        let request = request_of(vec![calculator_call(json!({"expression": "1+1"}))]);

        let tool_calls = first_message_tool_calls(&provider, &request);

        assert_eq!(tool_calls.len(), 1);
        let call = &tool_calls[0];
        assert_eq!(call["id"], "call_1");
        assert_eq!(call["type"], "function");
        assert_eq!(call["function"]["name"], "calculator");
        assert_eq!(call["function"]["arguments"], "{\"expression\":\"1+1\"}");
    }

    #[test]
    fn test_build_request_body_uses_json_tool_arguments_when_enabled() {
        let provider = test_provider(HashMap::from([(
            "tool_call_arguments_json".to_string(),
            Value::Bool(true),
        )]));
        let request = request_of(vec![calculator_call(json!({"expression": "1+1"}))]);

        let tool_calls = first_message_tool_calls(&provider, &request);
        let body = provider.build_request_body(&request);

        assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"}));
        assert!(body.get("tool_call_arguments_json").is_none());
    }

    #[test]
    fn test_build_request_body_preserves_raw_json_string_arguments() {
        let provider = test_provider(HashMap::new());
        let request = request_of(vec![calculator_call(Value::String(
            "{\"expression\":\"1+1\"}".to_string(),
        ))]);

        let tool_calls = first_message_tool_calls(&provider, &request);

        assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
    }

    #[test]
    fn test_build_request_body_omits_internal_model_extra_keys() {
        let provider = test_provider(HashMap::from([
            ("tool_call_arguments_json".to_string(), Value::Bool(true)),
            ("mock_response_content".to_string(), Value::String("stub".to_string())),
            ("parallel_tool_calls".to_string(), Value::Bool(true)),
        ]));

        let body = provider.build_request_body(&request_of(vec![Message::user("hello")]));

        assert!(body.get("tool_call_arguments_json").is_none());
        assert!(body.get("mock_response_content").is_none());
        assert_eq!(body["parallel_tool_calls"], Value::Bool(true));
    }

    #[test]
    fn test_build_request_body_includes_assistant_reasoning_content() {
        let provider = test_provider(HashMap::new());
        let answer = Message {
            role: "assistant".to_string(),
            content: vec![ContentBlock::text("final answer")],
            reasoning_content: Some("step by step".to_string()),
            tool_call_id: None,
            name: None,
            tool_calls: None,
        };

        let body = provider.build_request_body(&request_of(vec![answer]));
        let messages = body["messages"].as_array().unwrap();

        assert_eq!(messages[0]["reasoning_content"], "step by step");
    }

    #[test]
    fn test_openai_response_parses_reasoning_content() {
        let payload = json!({
            "id": "resp_1",
            "model": "gpt-test",
            "choices": [{
                "message": {
                    "content": "final answer",
                    "reasoning_content": "hidden reasoning",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });

        let response: OpenAIResponse = serde_json::from_value(payload).unwrap();

        assert_eq!(
            response.choices[0].message.reasoning_content.as_deref(),
            Some("hidden reasoning")
        );
    }

    #[test]
    fn test_openai_response_parses_json_tool_arguments() {
        let payload = json!({
            "id": "resp_1",
            "model": "gpt-test",
            "choices": [{
                "message": {
                    "content": "",
                    "tool_calls": [{
                        "id": "call_1",
                        "function": {
                            "name": "scheduler_manage",
                            "arguments": {"action": "list"}
                        }
                    }]
                }
            }],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2
            }
        });

        let response: OpenAIResponse = serde_json::from_value(payload).unwrap();
        let arguments = &response.choices[0].message.tool_calls[0].function.arguments;

        if let OAIFunctionArguments::Json(value) = arguments {
            assert_eq!(value, &json!({"action": "list"}));
        } else {
            panic!("expected JSON tool arguments");
        }
    }
}
|