402 lines
12 KiB
Rust
402 lines
12 KiB
Rust
use async_trait::async_trait;
|
||
use reqwest::Client;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::collections::HashMap;
|
||
use std::time::Duration;
|
||
|
||
use super::traits::Usage;
|
||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||
use crate::domain::messages::ContentBlock;
|
||
|
||
/// Keys in a provider's `model_extra` that are consumed locally and must not be
/// forwarded in the request body. `supported_content_types` decides whether
/// images are sent to the model or replaced by text notices.
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["supported_content_types"];
|
||
|
||
/// Renders an error together with its whole `source()` chain as one string.
///
/// The outermost message comes first; every underlying cause follows on its
/// own line, prefixed with `caused by: `.
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
    std::iter::successors(Some(error), |&err| err.source())
        .map(|err| err.to_string())
        .collect::<Vec<_>>()
        .join("\ncaused by: ")
}
|
||
|
||
fn serialize_content_blocks<S>(
|
||
blocks: &[serde_json::Value],
|
||
serializer: S,
|
||
) -> Result<S::Ok, S::Error>
|
||
where
|
||
S: serde::Serializer,
|
||
{
|
||
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
||
}
|
||
|
||
fn convert_content_blocks(
|
||
supports_images: bool,
|
||
provider_name: &str,
|
||
model_id: &str,
|
||
blocks: &[ContentBlock],
|
||
message_idx: usize,
|
||
) -> Vec<serde_json::Value> {
|
||
// 检查是否有图片且模型不支持
|
||
if !supports_images {
|
||
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||
|
||
if has_images {
|
||
let image_count = blocks
|
||
.iter()
|
||
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
||
.count();
|
||
|
||
tracing::warn!(
|
||
provider = %provider_name,
|
||
model = %model_id,
|
||
filtered_images = image_count,
|
||
message_idx,
|
||
"模型不支持图片;将图片转换为通知文本"
|
||
);
|
||
|
||
// 复用通知格式,将图片转换为文本通知
|
||
let mut converted_blocks: Vec<serde_json::Value> = Vec::new();
|
||
let mut notices: Vec<String> = Vec::new();
|
||
let mut image_idx = 0;
|
||
|
||
for block in blocks.iter() {
|
||
match block {
|
||
ContentBlock::Text { text } => {
|
||
converted_blocks.push(serde_json::json!({ "type": "text", "text": text }));
|
||
}
|
||
ContentBlock::ImageUrl { .. } => {
|
||
image_idx += 1;
|
||
notices.push(format!(
|
||
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
||
image_idx
|
||
));
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加通知文本块
|
||
if !notices.is_empty() {
|
||
let notice_text = format!(
|
||
"[系统提示] 以下图片未能成功入模:\n{}",
|
||
notices.join("\n")
|
||
);
|
||
converted_blocks.push(serde_json::json!({ "type": "text", "text": notice_text }));
|
||
}
|
||
|
||
return converted_blocks;
|
||
}
|
||
}
|
||
|
||
// 原有逻辑 - 模型支持图片,正常转换
|
||
blocks
|
||
.iter()
|
||
.map(|b| match b {
|
||
ContentBlock::Text { text } => {
|
||
serde_json::json!({ "type": "text", "text": text })
|
||
}
|
||
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
/// Converts an image URL into an Anthropic `image` content block.
///
/// A `data:image/<subtype>;base64,<payload>` URL becomes a `base64` source that
/// carries the media type and payload. Any other string is passed through as a
/// `url` source.
///
/// The data-URL form is recognised only at the very start of the string. A
/// remote URL that merely *contains* `data:image/...;base64,` (for example in
/// a query string) is therefore sent as a `url` source, rather than having a
/// fragment of it misread as inline image data. Everything after `;base64,` is
/// used as the payload.
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
    match parse_base64_image_data_url(url) {
        Some((media_type, data)) => serde_json::json!({
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": data
            }
        }),
        None => serde_json::json!({
            "type": "image",
            "source": {
                "type": "url",
                "url": url
            }
        }),
    }
}

/// Splits a `data:image/<subtype>;base64,<payload>` URL into
/// `(media_type, payload)`, or returns `None` when `url` does not have that form.
///
/// `<subtype>` must be a non-empty run of ASCII letters, digits or `_` (e.g.
/// `png`, `jpeg`, `webp`), and the payload must be non-empty. Parsing uses plain
/// string operations, so no regex is compiled per call.
fn parse_base64_image_data_url(url: &str) -> Option<(&str, &str)> {
    let (media_type, data) = url.strip_prefix("data:")?.split_once(";base64,")?;
    let subtype = media_type.strip_prefix("image/")?;
    let valid_subtype = !subtype.is_empty()
        && subtype
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'_');

    if valid_subtype && !data.is_empty() {
        Some((media_type, data))
    } else {
        None
    }
}
|
||
|
||
/// LLM provider that talks to the Anthropic Messages API
/// (`POST {base_url}/v1/messages`).
///
/// Holds connection settings plus per-model defaults. `temperature` and
/// `max_tokens` set on an individual request take precedence over the ones
/// stored here.
pub struct AnthropicProvider {
    /// HTTP client created in `new`, normally with a request timeout of
    /// `llm_timeout_secs`.
    client: Client,
    /// Provider name; returned by `name()` and attached to log events.
    name: String,
    /// Sent as the `x-api-key` header.
    api_key: String,
    /// Endpoint root; requests go to `{base_url}/v1/messages`.
    base_url: String,
    /// Additional headers added to every request.
    extra_headers: HashMap<String, String>,
    /// Request timeout in seconds (also reported in debug logs).
    llm_timeout_secs: u64,
    /// Model name sent as the request body's `model` field.
    model_id: String,
    /// Default sampling temperature, used when the request sets none.
    temperature: Option<f32>,
    /// Default `max_tokens`, used when the request sets none (1024 if both unset).
    max_tokens: Option<u32>,
    /// Extra model options. `supported_content_types` is read locally to decide
    /// image support; every other key is flattened into the request body.
    model_extra: HashMap<String, serde_json::Value>,
}
|
||
|
||
impl AnthropicProvider {
|
||
pub fn new(
|
||
name: String,
|
||
api_key: String,
|
||
base_url: String,
|
||
extra_headers: HashMap<String, String>,
|
||
llm_timeout_secs: u64,
|
||
model_id: String,
|
||
temperature: Option<f32>,
|
||
max_tokens: Option<u32>,
|
||
model_extra: HashMap<String, serde_json::Value>,
|
||
) -> Self {
|
||
let client = Client::builder()
|
||
.timeout(Duration::from_secs(llm_timeout_secs))
|
||
.build()
|
||
.unwrap_or_else(|_| Client::new());
|
||
|
||
Self {
|
||
client,
|
||
name,
|
||
api_key,
|
||
base_url,
|
||
extra_headers,
|
||
llm_timeout_secs,
|
||
model_id,
|
||
temperature,
|
||
max_tokens,
|
||
model_extra,
|
||
}
|
||
}
|
||
|
||
/// 检查模型是否支持指定内容类型
|
||
/// 默认支持所有类型(text, image)
|
||
fn supports_content_type(&self, content_type: &str) -> bool {
|
||
self.model_extra
|
||
.get("supported_content_types")
|
||
.and_then(|value| value.as_array())
|
||
.map(|types| {
|
||
types.iter().any(|t| t.as_str() == Some(content_type))
|
||
})
|
||
.unwrap_or(true)
|
||
}
|
||
|
||
/// 检查模型是否支持图片
|
||
fn supports_images(&self) -> bool {
|
||
self.supports_content_type("image")
|
||
}
|
||
|
||
/// 过滤掉内部字段,只返回需要发送到 API 的 extra 字段
|
||
fn request_model_extra(&self) -> HashMap<String, serde_json::Value> {
|
||
self.model_extra
|
||
.iter()
|
||
.filter(|(key, _)| !INTERNAL_MODEL_EXTRA_KEYS.contains(&key.as_str()))
|
||
.map(|(k, v)| (k.clone(), v.clone()))
|
||
.collect()
|
||
}
|
||
}
|
||
|
||
#[derive(Serialize)]
|
||
struct AnthropicRequest {
|
||
model: String,
|
||
messages: Vec<AnthropicMessage>,
|
||
max_tokens: u32,
|
||
temperature: Option<f32>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
tools: Option<Vec<AnthropicTool>>,
|
||
#[serde(flatten)]
|
||
extra: HashMap<String, serde_json::Value>,
|
||
}
|
||
|
||
/// One entry of the request's `messages` array.
#[derive(Serialize)]
struct AnthropicMessage {
    /// Copied verbatim from the domain message's role.
    role: String,
    /// Content blocks produced by `convert_content_blocks`; their wire format
    /// is decided by `serialize_content_blocks`.
    #[serde(serialize_with = "serialize_content_blocks")]
    content: Vec<serde_json::Value>,
}
|
||
|
||
/// Tool definition in Anthropic's request format.
#[derive(Serialize)]
struct AnthropicTool {
    /// Tool name, copied from the tool's `function.name`.
    name: String,
    /// Human-readable description, copied from `function.description`.
    description: String,
    /// Parameter schema, copied from `function.parameters`.
    input_schema: serde_json::Value,
}
|
||
|
||
/// The parts of a successful Messages API response that this provider reads.
/// Any other fields in the body are ignored during deserialization.
#[derive(Deserialize)]
struct AnthropicResponse {
    id: String,
    model: String,
    /// Content blocks in the order the API returned them.
    content: Vec<AnthropicContent>,
    usage: AnthropicUsage,
}
|
||
|
||
#[derive(Deserialize)]
|
||
#[serde(tag = "type", rename_all = "snake_case")]
|
||
enum AnthropicContent {
|
||
Text {
|
||
text: String,
|
||
},
|
||
#[allow(dead_code)]
|
||
Thinking {
|
||
thinking: String,
|
||
},
|
||
#[serde(rename = "tool_use")]
|
||
ToolUse {
|
||
id: String,
|
||
name: String,
|
||
input: serde_json::Value,
|
||
},
|
||
}
|
||
|
||
/// Token counts the API reports for one request.
#[derive(Deserialize)]
struct AnthropicUsage {
    /// Input tokens; mapped to `Usage::prompt_tokens`.
    input_tokens: u32,
    /// Output tokens; mapped to `Usage::completion_tokens`.
    output_tokens: u32,
}
|
||
|
||
#[async_trait]
|
||
impl LLMProvider for AnthropicProvider {
|
||
async fn chat(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||
let url = format!("{}/v1/messages", self.base_url);
|
||
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
|
||
|
||
let tools = request.tools.map(|tools| {
|
||
tools
|
||
.iter()
|
||
.map(|t: &Tool| AnthropicTool {
|
||
name: t.function.name.clone(),
|
||
description: t.function.description.clone(),
|
||
input_schema: t.function.parameters.clone(),
|
||
})
|
||
.collect()
|
||
});
|
||
|
||
let body = AnthropicRequest {
|
||
model: self.model_id.clone(),
|
||
messages: request
|
||
.messages
|
||
.iter()
|
||
.enumerate()
|
||
.map(|(i, m)| AnthropicMessage {
|
||
role: m.role.clone(),
|
||
content: convert_content_blocks(
|
||
self.supports_images(),
|
||
&self.name,
|
||
&self.model_id,
|
||
&m.content,
|
||
i,
|
||
),
|
||
})
|
||
.collect(),
|
||
max_tokens,
|
||
temperature: request.temperature.or(self.temperature),
|
||
tools,
|
||
extra: self.request_model_extra(),
|
||
};
|
||
|
||
let mut req_builder = self
|
||
.client
|
||
.post(&url)
|
||
.header("x-api-key", &self.api_key)
|
||
.header("anthropic-version", "2023-06-01")
|
||
.header("Content-Type", "application/json");
|
||
|
||
for (key, value) in &self.extra_headers {
|
||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||
}
|
||
|
||
let resp = req_builder.json(&body).send().await?;
|
||
let status = resp.status();
|
||
let text = resp.text().await?;
|
||
|
||
if !status.is_success() {
|
||
tracing::error!(
|
||
provider = %self.name,
|
||
model = %self.model_id,
|
||
url = %url,
|
||
status = %status,
|
||
response_len = text.len(),
|
||
response_body = %text,
|
||
"Anthropic API request failed"
|
||
);
|
||
return Err(format!("API error {}: {}", status, text).into());
|
||
}
|
||
|
||
#[cfg(debug_assertions)]
|
||
{
|
||
let resp_preview: String = text.chars().take(100).collect();
|
||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "Anthropic response (first 100 chars shown)");
|
||
}
|
||
|
||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&text).map_err(|e| {
|
||
tracing::error!(
|
||
provider = %self.name,
|
||
model = %self.model_id,
|
||
url = %url,
|
||
error = %format_error_chain(&e),
|
||
response_len = text.len(),
|
||
response_body = %text,
|
||
"Failed to decode Anthropic response"
|
||
);
|
||
format!("decode error: {} | body: {}", e, &text)
|
||
})?;
|
||
|
||
let mut content = String::new();
|
||
let mut tool_calls = Vec::new();
|
||
|
||
for c in &anthropic_resp.content {
|
||
match c {
|
||
AnthropicContent::Text { text } => {
|
||
if !text.is_empty() {
|
||
if !content.is_empty() {
|
||
content.push('\n');
|
||
}
|
||
content.push_str(text);
|
||
}
|
||
}
|
||
AnthropicContent::Thinking { .. } => {}
|
||
AnthropicContent::ToolUse { id, name, input } => {
|
||
tool_calls.push(ToolCall {
|
||
id: id.clone(),
|
||
name: name.clone(),
|
||
arguments: input.clone(),
|
||
});
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(ChatCompletionResponse {
|
||
id: anthropic_resp.id,
|
||
model: anthropic_resp.model,
|
||
content,
|
||
reasoning_content: None,
|
||
tool_calls,
|
||
usage: Usage {
|
||
prompt_tokens: anthropic_resp.usage.input_tokens,
|
||
completion_tokens: anthropic_resp.usage.output_tokens,
|
||
total_tokens: anthropic_resp.usage.input_tokens
|
||
+ anthropic_resp.usage.output_tokens,
|
||
},
|
||
})
|
||
}
|
||
|
||
fn ptype(&self) -> &str {
|
||
"anthropic"
|
||
}
|
||
|
||
fn name(&self) -> &str {
|
||
&self.name
|
||
}
|
||
|
||
fn model_id(&self) -> &str {
|
||
&self.model_id
|
||
}
|
||
}
|