347 lines
11 KiB
Rust
347 lines
11 KiB
Rust
use async_trait::async_trait;
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
|
|
use crate::bus::message::ContentBlock;
|
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
|
use super::traits::Usage;
|
|
use std::sync::Arc;
|
|
use crate::storage::Storage;
|
|
|
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
|
blocks.iter().map(|b| match b {
|
|
ContentBlock::Text { text } => {
|
|
serde_json::json!({ "type": "text", "text": text })
|
|
}
|
|
ContentBlock::ImageUrl { image_url } => {
|
|
convert_image_url_to_anthropic(&image_url.url)
|
|
}
|
|
}).collect()
|
|
}
|
|
|
|
/// Builds an Anthropic `image` content block from an image URL.
///
/// A `data:image/<subtype>;base64,<payload>` URL becomes a `base64` source carrying
/// the media type and the payload. Any other string is passed through unchanged as a
/// `url` source.
///
/// The data-URL check uses plain prefix/split operations instead of a regex, so no
/// pattern is compiled per call, and the `data:` scheme must be at the very start of
/// the string: a regular URL that merely contains `data:image/...;base64,` somewhere
/// (for example in its query string) is still treated as a regular URL.
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
    // data:image/png;base64,... -> Anthropic image block with inline base64 data
    if let Some((media_type, data)) = split_base64_image_data_url(url) {
        return serde_json::json!({
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": data
            }
        });
    }

    // Regular URL -> Anthropic image block with url source
    serde_json::json!({
        "type": "image",
        "source": {
            "type": "url",
            "url": url
        }
    })
}

/// Splits a `data:image/<subtype>;base64,<payload>` URL into `(media_type, payload)`.
///
/// Returns `None` unless the string starts with `data:`, the media type is `image/`
/// followed by a non-empty run of alphanumeric/underscore characters, the `;base64,`
/// marker directly follows the media type, and the payload is non-empty.
fn split_base64_image_data_url(url: &str) -> Option<(&str, &str)> {
    let (media_type, data) = url.strip_prefix("data:")?.split_once(";base64,")?;
    let subtype = media_type.strip_prefix("image/")?;
    let subtype_ok =
        !subtype.is_empty() && subtype.chars().all(|c| c.is_alphanumeric() || c == '_');
    (subtype_ok && !data.is_empty()).then_some((media_type, data))
}
|
|
|
|
pub struct AnthropicProvider {
|
|
client: Client,
|
|
name: String,
|
|
api_key: String,
|
|
base_url: String,
|
|
extra_headers: HashMap<String, String>,
|
|
model_id: String,
|
|
temperature: Option<f32>,
|
|
max_tokens: Option<u32>,
|
|
model_extra: HashMap<String, serde_json::Value>,
|
|
storage: Option<Arc<Storage>>,
|
|
}
|
|
|
|
impl AnthropicProvider {
|
|
pub fn new(
|
|
name: String,
|
|
api_key: String,
|
|
base_url: String,
|
|
extra_headers: HashMap<String, String>,
|
|
model_id: String,
|
|
temperature: Option<f32>,
|
|
max_tokens: Option<u32>,
|
|
model_extra: HashMap<String, serde_json::Value>,
|
|
) -> Self {
|
|
Self {
|
|
client: Client::new(),
|
|
name,
|
|
api_key,
|
|
base_url,
|
|
extra_headers,
|
|
model_id,
|
|
temperature,
|
|
max_tokens,
|
|
model_extra,
|
|
storage: None,
|
|
}
|
|
}
|
|
|
|
pub fn set_storage(&mut self, storage: Arc<Storage>) {
|
|
self.storage = Some(storage);
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct AnthropicRequest {
|
|
model: String,
|
|
messages: Vec<AnthropicMessage>,
|
|
max_tokens: u32,
|
|
temperature: Option<f32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tools: Option<Vec<AnthropicTool>>,
|
|
#[serde(flatten)]
|
|
extra: HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct AnthropicMessage {
|
|
role: String,
|
|
content: Vec<serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct AnthropicTool {
|
|
name: String,
|
|
description: String,
|
|
input_schema: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct AnthropicResponse {
|
|
id: Option<String>,
|
|
model: Option<String>,
|
|
#[serde(default)]
|
|
content: Vec<AnthropicContent>,
|
|
#[serde(default)]
|
|
usage: Option<AnthropicUsage>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
|
enum AnthropicContent {
|
|
Text {
|
|
#[serde(alias = "content")]
|
|
text: String,
|
|
},
|
|
Thinking {
|
|
#[serde(alias = "content")]
|
|
thinking: String,
|
|
},
|
|
#[serde(rename = "tool_use")]
|
|
ToolUse {
|
|
id: String,
|
|
name: String,
|
|
#[serde(alias = "arguments")]
|
|
input: serde_json::Value,
|
|
},
|
|
#[serde(other)]
|
|
Unknown,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct AnthropicUsage {
|
|
#[serde(default)]
|
|
input_tokens: u32,
|
|
#[serde(default)]
|
|
output_tokens: u32,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LLMProvider for AnthropicProvider {
|
|
async fn chat(
|
|
&self,
|
|
request: ChatCompletionRequest,
|
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
|
let start = std::time::Instant::now();
|
|
let url = format!("{}/v1/messages", self.base_url);
|
|
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
|
|
|
|
let tools = request.tools.map(|tools| {
|
|
tools
|
|
.iter()
|
|
.map(|t: &Tool| AnthropicTool {
|
|
name: t.function.name.clone(),
|
|
description: t.function.description.clone(),
|
|
input_schema: t.function.parameters.clone(),
|
|
})
|
|
.collect()
|
|
});
|
|
|
|
let body = AnthropicRequest {
|
|
model: self.model_id.clone(),
|
|
messages: request
|
|
.messages
|
|
.iter()
|
|
.map(|m| {
|
|
let role = if m.role == "tool" {
|
|
// Anthropic uses "user" role for tool result messages
|
|
"user".to_string()
|
|
} else {
|
|
m.role.clone()
|
|
};
|
|
let content = if let Some(ref tc_id) = m.tool_call_id {
|
|
// Tool result: wrap as tool_result content block
|
|
let output = m.content.iter()
|
|
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
|
|
.collect::<Vec<_>>()
|
|
.join("");
|
|
vec![serde_json::json!({
|
|
"type": "tool_result",
|
|
"tool_use_id": tc_id,
|
|
"content": output,
|
|
})]
|
|
} else {
|
|
let mut blocks = convert_content_blocks(&m.content);
|
|
// Append tool_use blocks from assistant messages with tool calls
|
|
if let Some(tool_calls) = m.tool_calls.as_ref().filter(|c| !c.is_empty()) {
|
|
for tc in tool_calls {
|
|
blocks.push(serde_json::json!({
|
|
"type": "tool_use",
|
|
"id": tc.id,
|
|
"name": tc.name,
|
|
"input": tc.arguments,
|
|
}));
|
|
}
|
|
}
|
|
blocks
|
|
};
|
|
AnthropicMessage { role, content }
|
|
})
|
|
.collect(),
|
|
max_tokens,
|
|
temperature: request.temperature.or(self.temperature),
|
|
tools,
|
|
extra: self.model_extra.clone(),
|
|
};
|
|
|
|
let mut req_builder = self
|
|
.client
|
|
.post(&url)
|
|
.header("x-api-key", &self.api_key)
|
|
.header("anthropic-version", "2023-06-01")
|
|
.header("Content-Type", "application/json");
|
|
|
|
for (key, value) in &self.extra_headers {
|
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
|
}
|
|
|
|
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
|
tracing::debug!(req_body = %req_body_str, "LLM request");
|
|
|
|
let resp = req_builder.json(&body).send().await?;
|
|
|
|
let status = resp.status();
|
|
let body_text = resp.text().await?;
|
|
tracing::debug!(status = %status, resp_body = %body_text, "LLM response");
|
|
|
|
if !status.is_success() {
|
|
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
|
|
.ok()
|
|
.and_then(|v| {
|
|
v.get("error")
|
|
.and_then(|e| e.get("message"))
|
|
.and_then(|m| m.as_str())
|
|
.map(|s| s.to_string())
|
|
})
|
|
.unwrap_or_else(|| body_text.clone());
|
|
if let Some(ref storage) = self.storage {
|
|
let _ = storage.append_llm_call(
|
|
&self.name, &self.model_id, &req_body_str,
|
|
Some(&body_text), Some(&error_msg),
|
|
start.elapsed().as_millis() as u64,
|
|
).await;
|
|
}
|
|
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
|
}
|
|
|
|
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
|
.map_err(|e| {
|
|
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
|
if let Some(ref storage) = self.storage {
|
|
let name = self.name.clone();
|
|
let model = self.model_id.clone();
|
|
let req = req_body_str.clone();
|
|
let resp_body = body_text.clone();
|
|
let dur = start.elapsed().as_millis() as u64;
|
|
let err = err_msg.clone();
|
|
let s = storage.clone();
|
|
tokio::spawn(async move {
|
|
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
|
|
});
|
|
}
|
|
err_msg
|
|
})?;
|
|
|
|
let mut content = String::new();
|
|
let mut tool_calls = Vec::new();
|
|
|
|
for c in &anthropic_resp.content {
|
|
match c {
|
|
AnthropicContent::Text { text } => {
|
|
if !text.is_empty() {
|
|
if !content.is_empty() {
|
|
content.push('\n');
|
|
}
|
|
content.push_str(text);
|
|
}
|
|
}
|
|
AnthropicContent::Thinking { .. } => {}
|
|
AnthropicContent::Unknown => {}
|
|
AnthropicContent::ToolUse { id, name, input } => {
|
|
tool_calls.push(ToolCall {
|
|
id: id.clone(),
|
|
name: name.clone(),
|
|
arguments: input.clone(),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
let response = ChatCompletionResponse {
|
|
id: anthropic_resp.id.unwrap_or_default(),
|
|
model: anthropic_resp.model.unwrap_or_default(),
|
|
content,
|
|
tool_calls,
|
|
usage: Usage {
|
|
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
|
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
|
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
|
},
|
|
};
|
|
|
|
if let Some(ref storage) = self.storage {
|
|
let _ = storage.append_llm_call(
|
|
&self.name,
|
|
&self.model_id,
|
|
&req_body_str,
|
|
Some(&body_text),
|
|
None,
|
|
start.elapsed().as_millis() as u64,
|
|
).await;
|
|
}
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
fn ptype(&self) -> &str {
|
|
"anthropic"
|
|
}
|
|
|
|
fn name(&self) -> &str {
|
|
&self.name
|
|
}
|
|
|
|
fn model_id(&self) -> &str {
|
|
&self.model_id
|
|
}
|
|
}
|