PicoBot/src/providers/anthropic.rs

402 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::domain::messages::ContentBlock;
/// Keys in `model_extra` that only configure this provider locally and are
/// therefore stripped before the remaining extras are sent to the API.
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[
    // Consumed locally by `AnthropicProvider::supports_content_type`.
    "supported_content_types",
];
/// Renders `error` followed by every error in its `source()` chain, outermost
/// first, with consecutive entries separated by `"\ncaused by: "`.
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    std::iter::successors(Some(error), |&err| err.source())
        .map(|err| err.to_string())
        .collect::<Vec<_>>()
        .join("\ncaused by: ")
}
/// Serializes a message's Anthropic content blocks as a JSON array.
///
/// The Messages API accepts `content` either as a plain string or as an array of
/// content blocks. The values here are already Anthropic block objects (`text`,
/// `image`), so they must be emitted as an array. Encoding them into a JSON *string*
/// instead would make the model see the raw JSON text, and images would arrive as
/// base64 text rather than as image input.
///
/// # Errors
/// Propagates any error from `serializer` instead of silently substituting a value.
fn serialize_content_blocks<S>(
    blocks: &[serde_json::Value],
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    serde::Serialize::serialize(blocks, serializer)
}
/// Converts a message's internal content blocks into Anthropic content-block JSON.
///
/// If `supports_images` is `true`, or the message has no image blocks, every block
/// is converted one-to-one: text becomes a `text` block and an image becomes an
/// `image` block via `convert_image_url_to_anthropic`.
///
/// If `supports_images` is `false` and at least one image is present:
/// - a warning is logged;
/// - image blocks are dropped and text blocks are kept in order;
/// - one trailing text notice lists each dropped image, so the model can inform the user.
///
/// `provider_name`, `model_id` and `message_idx` are used only as log context.
fn convert_content_blocks(
    supports_images: bool,
    provider_name: &str,
    model_id: &str,
    blocks: &[ContentBlock],
    message_idx: usize,
) -> Vec<serde_json::Value> {
    if !supports_images {
        // Count once; zero images means the normal conversion below is already correct.
        let image_count = blocks
            .iter()
            .filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
            .count();
        if image_count > 0 {
            tracing::warn!(
                provider = %provider_name,
                model = %model_id,
                filtered_images = image_count,
                message_idx,
                "模型不支持图片;将图片转换为通知文本"
            );
            return replace_images_with_notice(blocks, image_count);
        }
    }
    // The model accepts images (or there are none): convert every block directly.
    blocks
        .iter()
        .map(|b| match b {
            ContentBlock::Text { text } => serde_json::json!({ "type": "text", "text": text }),
            ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
        })
        .collect()
}

/// Keeps text blocks in order and replaces the `image_count` image blocks with a
/// single trailing text notice containing one numbered line (from 1) per image.
fn replace_images_with_notice(
    blocks: &[ContentBlock],
    image_count: usize,
) -> Vec<serde_json::Value> {
    let mut converted = Vec::with_capacity(blocks.len() - image_count + 1);
    let mut notices = Vec::with_capacity(image_count);
    for block in blocks {
        match block {
            ContentBlock::Text { text } => {
                converted.push(serde_json::json!({ "type": "text", "text": text }));
            }
            ContentBlock::ImageUrl { .. } => {
                notices.push(format!(
                    "- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
                    notices.len() + 1
                ));
            }
        }
    }
    if !notices.is_empty() {
        let notice_text = format!(
            "[系统提示] 以下图片未能成功入模:\n{}",
            notices.join("\n")
        );
        converted.push(serde_json::json!({ "type": "text", "text": notice_text }));
    }
    converted
}
/// Converts an image URL into an Anthropic `image` content block.
///
/// A `data:image/<subtype>;base64,<payload>` URL becomes a `base64` source carrying
/// the media type and payload. Any other string is forwarded unchanged as a `url`
/// source.
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
    if let Some((media_type, data)) = parse_base64_image_data_url(url) {
        return serde_json::json!({
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": data
            }
        });
    }
    // Regular URL -> Anthropic image block with url source
    serde_json::json!({
        "type": "image",
        "source": {
            "type": "url",
            "url": url
        }
    })
}

/// Splits `data:image/<subtype>;base64,<payload>` into `(media_type, payload)`.
///
/// Returns `None` unless all of the following hold:
/// - `url` *starts with* `data:` (anchored, so an ordinary URL that merely contains
///   such text is not misparsed);
/// - the media type is `image/` followed by one or more word characters;
/// - the payload is non-empty.
///
/// Plain string parsing also avoids compiling a regex on every call.
fn parse_base64_image_data_url(url: &str) -> Option<(&str, &str)> {
    let (media_type, data) = url.strip_prefix("data:")?.split_once(";base64,")?;
    let subtype = media_type.strip_prefix("image/")?;
    let subtype_ok =
        !subtype.is_empty() && subtype.chars().all(|c| c.is_alphanumeric() || c == '_');
    (subtype_ok && !data.is_empty()).then_some((media_type, data))
}
/// LLM provider that sends chat requests to an Anthropic Messages API endpoint.
pub struct AnthropicProvider {
/// HTTP client used for every request; built in `new`.
client: Client,
/// Provider name, returned by `name()` and attached to log events.
name: String,
/// Sent as the `x-api-key` request header.
api_key: String,
/// Endpoint root; `chat` posts to `{base_url}/v1/messages`.
base_url: String,
/// Additional headers added to every request.
extra_headers: HashMap<String, String>,
/// Timeout in seconds passed to the client builder (also logged in debug builds).
llm_timeout_secs: u64,
/// Model identifier sent as the request's `model` field.
model_id: String,
/// Temperature used when the request does not set one.
temperature: Option<f32>,
/// `max_tokens` used when the request does not set one (falls back to 1024 if both are unset).
max_tokens: Option<u32>,
/// Per-model extras: local keys listed in `INTERNAL_MODEL_EXTRA_KEYS` plus fields forwarded to the API.
model_extra: HashMap<String, serde_json::Value>,
}
impl AnthropicProvider {
pub fn new(
name: String,
api_key: String,
base_url: String,
extra_headers: HashMap<String, String>,
llm_timeout_secs: u64,
model_id: String,
temperature: Option<f32>,
max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>,
) -> Self {
let client = Client::builder()
.timeout(Duration::from_secs(llm_timeout_secs))
.build()
.unwrap_or_else(|_| Client::new());
Self {
client,
name,
api_key,
base_url,
extra_headers,
llm_timeout_secs,
model_id,
temperature,
max_tokens,
model_extra,
}
}
/// 检查模型是否支持指定内容类型
/// 默认支持所有类型text, image
fn supports_content_type(&self, content_type: &str) -> bool {
self.model_extra
.get("supported_content_types")
.and_then(|value| value.as_array())
.map(|types| {
types.iter().any(|t| t.as_str() == Some(content_type))
})
.unwrap_or(true)
}
/// 检查模型是否支持图片
fn supports_images(&self) -> bool {
self.supports_content_type("image")
}
/// 过滤掉内部字段,只返回需要发送到 API 的 extra 字段
fn request_model_extra(&self) -> HashMap<String, serde_json::Value> {
self.model_extra
.iter()
.filter(|(key, _)| !INTERNAL_MODEL_EXTRA_KEYS.contains(&key.as_str()))
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
}
/// JSON body posted to the Messages endpoint.
#[derive(Serialize)]
struct AnthropicRequest {
    model: String,
    messages: Vec<AnthropicMessage>,
    max_tokens: u32,
    /// Omitted when `None` (like `tools`) rather than sent as `null`. The API default
    /// then applies, and a `temperature` supplied through `extra` is not emitted
    /// twice in that case.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AnthropicTool>>,
    /// Provider-specific extra fields, flattened into the top-level JSON object.
    /// NOTE(review): a key that collides with a named field above (e.g. `model`,
    /// `max_tokens`) would be emitted as a duplicate JSON key; confirm such keys are
    /// never configured.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
/// One entry of the request's `messages` array.
#[derive(Serialize)]
struct AnthropicMessage {
/// Copied verbatim from the internal message's role.
/// NOTE(review): the Messages API documents only `user`/`assistant` roles here
/// (system prompts are a top-level field); confirm callers never pass other roles.
role: String,
/// Anthropic content blocks built by `convert_content_blocks`; their JSON form is
/// produced by `serialize_content_blocks`.
#[serde(serialize_with = "serialize_content_blocks")]
content: Vec<serde_json::Value>,
}
/// Tool definition in the Messages API format, built in `chat` from the internal
/// `Tool`'s `function.name`, `function.description` and `function.parameters`.
#[derive(Serialize)]
struct AnthropicTool {
name: String,
description: String,
/// JSON Schema of the tool's input, copied from `function.parameters`.
input_schema: serde_json::Value,
}
/// The subset of a Messages API response that this provider reads; any other
/// response fields are ignored during deserialization.
#[derive(Deserialize)]
struct AnthropicResponse {
id: String,
model: String,
/// Ordered content blocks (text, thinking, tool calls).
content: Vec<AnthropicContent>,
usage: AnthropicUsage,
}
/// One block of the response `content` array, discriminated by its `type` field.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent {
    Text {
        text: String,
    },
    // The thinking text is deserialized but not surfaced; `chat` skips these blocks.
    #[allow(dead_code)]
    Thinking {
        thinking: String,
    },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Any block type not listed above (e.g. `redacted_thinking` or server-tool blocks).
    /// Without this catch-all, one unrecognised block would make the whole response
    /// fail to decode.
    #[serde(other)]
    Other,
}

/// Token accounting reported by the API for one request.
#[derive(Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}

#[async_trait]
impl LLMProvider for AnthropicProvider {
    /// Sends `request` to `{base_url}/v1/messages` and maps the reply.
    ///
    /// Request values take precedence over the provider defaults for `max_tokens`
    /// (final fallback 1024) and `temperature`. In the response:
    /// - text blocks are joined with `'\n'` (empty ones skipped);
    /// - `tool_use` blocks become `ToolCall`s;
    /// - thinking and unrecognised blocks are ignored.
    ///
    /// # Errors
    /// Returns an error on transport failure, on a non-2xx status (message contains the
    /// status and body), or when the body cannot be decoded.
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
        // Tolerate a configured base URL ending in '/', which would otherwise yield `//v1/messages`.
        let url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
        let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
        let tools = request.tools.map(|tools| {
            tools
                .iter()
                .map(|t: &Tool| AnthropicTool {
                    name: t.function.name.clone(),
                    description: t.function.description.clone(),
                    input_schema: t.function.parameters.clone(),
                })
                .collect()
        });
        // Loop-invariant, so evaluate it once rather than per message.
        let supports_images = self.supports_images();
        let body = AnthropicRequest {
            model: self.model_id.clone(),
            messages: request
                .messages
                .iter()
                .enumerate()
                .map(|(i, m)| AnthropicMessage {
                    role: m.role.clone(),
                    content: convert_content_blocks(
                        supports_images,
                        &self.name,
                        &self.model_id,
                        &m.content,
                        i,
                    ),
                })
                .collect(),
            max_tokens,
            temperature: request.temperature.or(self.temperature),
            tools,
            extra: self.request_model_extra(),
        };
        let mut req_builder = self
            .client
            .post(&url)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json");
        for (key, value) in &self.extra_headers {
            req_builder = req_builder.header(key.as_str(), value.as_str());
        }
        let resp = req_builder.json(&body).send().await?;
        let status = resp.status();
        let text = resp.text().await?;
        if !status.is_success() {
            tracing::error!(
                provider = %self.name,
                model = %self.model_id,
                url = %url,
                status = %status,
                response_len = text.len(),
                response_body = %text,
                "Anthropic API request failed"
            );
            return Err(format!("API error {}: {}", status, text).into());
        }
        #[cfg(debug_assertions)]
        {
            let resp_preview: String = text.chars().take(100).collect();
            tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "Anthropic response (first 100 chars shown)");
        }
        let anthropic_resp: AnthropicResponse = serde_json::from_str(&text).map_err(|e| {
            tracing::error!(
                provider = %self.name,
                model = %self.model_id,
                url = %url,
                error = %format_error_chain(&e),
                response_len = text.len(),
                response_body = %text,
                "Failed to decode Anthropic response"
            );
            format!("decode error: {} | body: {}", e, &text)
        })?;
        let mut content = String::new();
        let mut tool_calls = Vec::new();
        for c in &anthropic_resp.content {
            match c {
                AnthropicContent::Text { text } => {
                    if !text.is_empty() {
                        if !content.is_empty() {
                            content.push('\n');
                        }
                        content.push_str(text);
                    }
                }
                AnthropicContent::Thinking { .. } | AnthropicContent::Other => {}
                AnthropicContent::ToolUse { id, name, input } => {
                    tool_calls.push(ToolCall {
                        id: id.clone(),
                        name: name.clone(),
                        arguments: input.clone(),
                    });
                }
            }
        }
        Ok(ChatCompletionResponse {
            id: anthropic_resp.id,
            model: anthropic_resp.model,
            content,
            reasoning_content: None,
            tool_calls,
            usage: Usage {
                prompt_tokens: anthropic_resp.usage.input_tokens,
                completion_tokens: anthropic_resp.usage.output_tokens,
                total_tokens: anthropic_resp
                    .usage
                    .input_tokens
                    .saturating_add(anthropic_resp.usage.output_tokens),
            },
        })
    }

    fn ptype(&self) -> &str {
        "anthropic"
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn model_id(&self) -> &str {
        &self.model_id
    }
}