// PicoBot/src/providers/anthropic.rs
// (347 lines, 11 KiB, Rust)

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage;
use std::sync::Arc;
use crate::storage::Storage;
/// Converts internal content blocks into Anthropic Messages API content blocks.
///
/// Text blocks become `{"type": "text", "text": ...}` values and image URLs are delegated to
/// `convert_image_url_to_anthropic`. Text blocks whose text is empty are skipped, because the
/// Messages API rejects a request containing an empty text block (this typically happens for
/// assistant turns that consist only of tool calls).
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
    blocks
        .iter()
        .filter_map(|block| match block {
            // An empty text block would make the whole request fail validation.
            ContentBlock::Text { text } if text.is_empty() => None,
            ContentBlock::Text { text } => {
                Some(serde_json::json!({ "type": "text", "text": text }))
            }
            ContentBlock::ImageUrl { image_url } => {
                Some(convert_image_url_to_anthropic(&image_url.url))
            }
        })
        .collect()
}
/// Converts an image reference into an Anthropic `image` content block.
///
/// A `data:image/<subtype>;base64,<payload>` URI becomes a block with a `base64` source that
/// carries the media type and the payload. Any other string is passed through unchanged as a
/// `url` source.
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
    // Parsed with plain string operations instead of a regex: nothing is recompiled per call,
    // the match is anchored at the start (so an ordinary URL that merely *contains*
    // "data:image/...;base64," is not mistaken for inline data), and the payload is not cut
    // off at the first line break.
    let inline = url
        .trim()
        .strip_prefix("data:")
        .and_then(|rest| rest.split_once(";base64,"))
        .filter(|(media_type, data)| {
            let subtype = media_type.strip_prefix("image/").unwrap_or("");
            !subtype.is_empty()
                && !data.is_empty()
                && subtype
                    .chars()
                    .all(|c| c.is_alphanumeric() || matches!(c, '_' | '+' | '-' | '.'))
        });

    if let Some((media_type, data)) = inline {
        // data:image/png;base64,... -> Anthropic image block with inline data
        return serde_json::json!({
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": data
            }
        });
    }

    // Regular URL -> Anthropic image block with url source
    serde_json::json!({
        "type": "image",
        "source": {
            "type": "url",
            "url": url
        }
    })
}
/// LLM provider that talks to an Anthropic-style `/v1/messages` HTTP endpoint.
pub struct AnthropicProvider {
// HTTP client used for every request.
client: Client,
// Provider instance name; returned by `name()` and written to call records.
name: String,
// Sent as the `x-api-key` request header.
api_key: String,
// Prefix to which `/v1/messages` is appended.
base_url: String,
// Additional headers added to every request.
extra_headers: HashMap<String, String>,
// Sent as the `model` field of the request body.
model_id: String,
// Fallback sampling temperature when the request does not set one.
temperature: Option<f32>,
// Fallback `max_tokens` when the request does not set one.
max_tokens: Option<u32>,
// Extra key/value pairs flattened into the top level of the request body.
model_extra: HashMap<String, serde_json::Value>,
// Optional sink for per-call records; `None` until `set_storage` is called.
storage: Option<Arc<Storage>>,
}
impl AnthropicProvider {
    /// Creates a provider with a fresh HTTP client and no storage attached.
    ///
    /// `temperature` and `max_tokens` act as defaults for requests that do not specify their
    /// own values; `model_extra` is merged into the top level of every request body and
    /// `extra_headers` are added to every HTTP request.
    pub fn new(
        name: String,
        api_key: String,
        base_url: String,
        extra_headers: HashMap<String, String>,
        model_id: String,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
        model_extra: HashMap<String, serde_json::Value>,
    ) -> Self {
        let client = Client::new();
        // Call recording stays off until `set_storage` is invoked.
        let storage = None;
        Self {
            client,
            storage,
            name,
            api_key,
            base_url,
            extra_headers,
            model_id,
            temperature,
            max_tokens,
            model_extra,
        }
    }

    /// Attaches the storage backend that receives a record of each LLM call.
    pub fn set_storage(&mut self, storage: Arc<Storage>) {
        self.storage = Some(storage);
    }
}
/// JSON body posted to the `/v1/messages` endpoint.
#[derive(Serialize)]
struct AnthropicRequest {
    model: String,
    messages: Vec<AnthropicMessage>,
    max_tokens: u32,
    /// Left out of the body when unset, so the server default applies instead of an explicit
    /// `null`, and so a `temperature` key supplied through `extra` is not duplicated.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Left out of the body when the request carries no tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AnthropicTool>>,
    /// Provider-configured extras, flattened into the top level of the body.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
/// One conversation turn in the request body: a role plus its content blocks.
#[derive(Serialize)]
struct AnthropicMessage {
role: String,
// Blocks are kept as raw JSON values so different block shapes can share one list.
content: Vec<serde_json::Value>,
}
/// Tool definition as serialized into the request's `tools` array.
#[derive(Serialize)]
struct AnthropicTool {
name: String,
description: String,
// Schema value describing the tool's input, passed through as-is.
input_schema: serde_json::Value,
}
/// Subset of the response body that this provider reads.
#[derive(Deserialize)]
struct AnthropicResponse {
id: Option<String>,
model: Option<String>,
// Defaults to an empty list when the field is absent.
#[serde(default)]
content: Vec<AnthropicContent>,
// Defaults to `None` when the field is absent.
#[serde(default)]
usage: Option<AnthropicUsage>,
}
/// One content block of a response, selected by the block's `type` field.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent {
// `type: "text"`; the text may also arrive under the key `content`.
Text {
#[serde(alias = "content")]
text: String,
},
// `type: "thinking"`; the text may also arrive under the key `content`.
Thinking {
#[serde(alias = "content")]
thinking: String,
},
// `type: "tool_use"`; the input may also arrive under the key `arguments`.
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
#[serde(alias = "arguments")]
input: serde_json::Value,
},
// Any other `type` value deserializes to this variant instead of failing.
#[serde(other)]
Unknown,
}
/// Token counts reported in the response; each count defaults to 0 when absent.
#[derive(Deserialize)]
struct AnthropicUsage {
#[serde(default)]
input_tokens: u32,
#[serde(default)]
output_tokens: u32,
}
#[async_trait]
impl LLMProvider for AnthropicProvider {
async fn chat(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
let start = std::time::Instant::now();
let url = format!("{}/v1/messages", self.base_url);
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
let tools = request.tools.map(|tools| {
tools
.iter()
.map(|t: &Tool| AnthropicTool {
name: t.function.name.clone(),
description: t.function.description.clone(),
input_schema: t.function.parameters.clone(),
})
.collect()
});
let body = AnthropicRequest {
model: self.model_id.clone(),
messages: request
.messages
.iter()
.map(|m| {
let role = if m.role == "tool" {
// Anthropic uses "user" role for tool result messages
"user".to_string()
} else {
m.role.clone()
};
let content = if let Some(ref tc_id) = m.tool_call_id {
// Tool result: wrap as tool_result content block
let output = m.content.iter()
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
.collect::<Vec<_>>()
.join("");
vec![serde_json::json!({
"type": "tool_result",
"tool_use_id": tc_id,
"content": output,
})]
} else {
let mut blocks = convert_content_blocks(&m.content);
// Append tool_use blocks from assistant messages with tool calls
if let Some(tool_calls) = m.tool_calls.as_ref().filter(|c| !c.is_empty()) {
for tc in tool_calls {
blocks.push(serde_json::json!({
"type": "tool_use",
"id": tc.id,
"name": tc.name,
"input": tc.arguments,
}));
}
}
blocks
};
AnthropicMessage { role, content }
})
.collect(),
max_tokens,
temperature: request.temperature.or(self.temperature),
tools,
extra: self.model_extra.clone(),
};
let mut req_builder = self
.client
.post(&url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json");
for (key, value) in &self.extra_headers {
req_builder = req_builder.header(key.as_str(), value.as_str());
}
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await?;
let status = resp.status();
let body_text = resp.text().await?;
tracing::debug!(status = %status, resp_body = %body_text, "LLM response");
if !status.is_success() {
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
.ok()
.and_then(|v| {
v.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.map(|s| s.to_string())
})
.unwrap_or_else(|| body_text.clone());
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&body_text), Some(&error_msg),
start.elapsed().as_millis() as u64,
).await;
}
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
}
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
.map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp_body = body_text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
});
}
err_msg
})?;
let mut content = String::new();
let mut tool_calls = Vec::new();
for c in &anthropic_resp.content {
match c {
AnthropicContent::Text { text } => {
if !text.is_empty() {
if !content.is_empty() {
content.push('\n');
}
content.push_str(text);
}
}
AnthropicContent::Thinking { .. } => {}
AnthropicContent::Unknown => {}
AnthropicContent::ToolUse { id, name, input } => {
tool_calls.push(ToolCall {
id: id.clone(),
name: name.clone(),
arguments: input.clone(),
});
}
}
}
let response = ChatCompletionResponse {
id: anthropic_resp.id.unwrap_or_default(),
model: anthropic_resp.model.unwrap_or_default(),
content,
tool_calls,
usage: Usage {
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
},
};
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
None,
start.elapsed().as_millis() as u64,
).await;
}
Ok(response)
}
fn ptype(&self) -> &str {
"anthropic"
}
fn name(&self) -> &str {
&self.name
}
fn model_id(&self) -> &str {
&self.model_id
}
}