feat: 更新内容处理逻辑,支持图片转换为通知文本并添加模型支持检查

This commit is contained in:
ooodc 2026-05-23 14:08:06 +08:00
parent a74c801945
commit 44d9171b86
4 changed files with 302 additions and 30 deletions

View File

@ -134,26 +134,6 @@ impl InitWizard {
} }
} }
/// Asks a yes/no question and reads a single line of input as the answer.
///
/// The prompt shows `label` together with the default choice. An empty answer
/// (after trimming) yields `default`; otherwise the result is `true` only for
/// `yes`, `y` or `1` (case-insensitive) and `false` for anything else.
///
/// # Errors
///
/// Propagates flush/read failures, and returns `InitError::InputError` when
/// the read reports zero bytes (end of input).
async fn prompt_yes_no(&mut self, label: &str, default: bool) -> Result<bool, InitError> {
    let default_label = match default {
        true => "yes",
        false => "no",
    };
    println!("{} [yes/no, default: {}]: ", label, default_label);
    self.write.flush().await?;

    let mut buffer = String::new();
    // A zero-byte read means the input stream ended before an answer arrived.
    if self.read.read_line(&mut buffer).await? == 0 {
        return Err(InitError::InputError("EOF reached".to_string()));
    }

    let answer = buffer.trim().to_lowercase();
    if answer.is_empty() {
        return Ok(default);
    }
    Ok(matches!(answer.as_str(), "yes" | "y" | "1"))
}
async fn prompt_select( async fn prompt_select(
&mut self, &mut self,
label: &str, label: &str,

View File

@ -8,6 +8,8 @@ use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::domain::messages::ContentBlock; use crate::domain::messages::ContentBlock;
/// Keys kept in `model_extra` for local use only; `request_model_extra` filters
/// them out before the remaining entries are placed in the outgoing request.
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["supported_content_types"];
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
let mut details = vec![error.to_string()]; let mut details = vec![error.to_string()];
let mut current = error.source(); let mut current = error.source();
@ -30,7 +32,65 @@ where
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string())) serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
} }
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> { fn convert_content_blocks(
supports_images: bool,
provider_name: &str,
model_id: &str,
blocks: &[ContentBlock],
message_idx: usize,
) -> Vec<serde_json::Value> {
// 检查是否有图片且模型不支持
if !supports_images {
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
if has_images {
let image_count = blocks
.iter()
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
.count();
tracing::warn!(
provider = %provider_name,
model = %model_id,
filtered_images = image_count,
message_idx,
"模型不支持图片;将图片转换为通知文本"
);
// 复用通知格式,将图片转换为文本通知
let mut converted_blocks: Vec<serde_json::Value> = Vec::new();
let mut notices: Vec<String> = Vec::new();
let mut image_idx = 0;
for block in blocks.iter() {
match block {
ContentBlock::Text { text } => {
converted_blocks.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentBlock::ImageUrl { .. } => {
image_idx += 1;
notices.push(format!(
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
image_idx
));
}
}
}
// 添加通知文本块
if !notices.is_empty() {
let notice_text = format!(
"[系统提示] 以下图片未能成功入模:\n{}",
notices.join("\n")
);
converted_blocks.push(serde_json::json!({ "type": "text", "text": notice_text }));
}
return converted_blocks;
}
}
// 原有逻辑 - 模型支持图片,正常转换
blocks blocks
.iter() .iter()
.map(|b| match b { .map(|b| match b {
@ -112,6 +172,32 @@ impl AnthropicProvider {
model_extra, model_extra,
} }
} }
/// Reports whether the configured model accepts the given content type.
///
/// Reads the `supported_content_types` entry of `model_extra`. When that entry
/// is absent or is not an array, every content type (text, image) counts as
/// supported. Otherwise the type is supported only if the array contains it
/// as a string.
fn supports_content_type(&self, content_type: &str) -> bool {
    let declared = self
        .model_extra
        .get("supported_content_types")
        .and_then(|value| value.as_array());

    match declared {
        Some(entries) => entries
            .iter()
            .any(|entry| entry.as_str() == Some(content_type)),
        // No usable declaration: fall back to "everything is supported".
        None => true,
    }
}
/// Reports whether image content may be sent to this model.
///
/// Shorthand for asking `supports_content_type` about the `"image"` type.
fn supports_images(&self) -> bool {
    let image_kind = "image";
    self.supports_content_type(image_kind)
}
/// Builds the `extra` map that is sent to the API.
///
/// Copies every `model_extra` entry except those whose key is listed in
/// `INTERNAL_MODEL_EXTRA_KEYS`, which are internal-only settings.
fn request_model_extra(&self) -> HashMap<String, serde_json::Value> {
    let mut outbound = HashMap::new();
    for (name, value) in &self.model_extra {
        if INTERNAL_MODEL_EXTRA_KEYS.contains(&name.as_str()) {
            continue;
        }
        outbound.insert(name.clone(), value.clone());
    }
    outbound
}
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -197,15 +283,22 @@ impl LLMProvider for AnthropicProvider {
messages: request messages: request
.messages .messages
.iter() .iter()
.map(|m| AnthropicMessage { .enumerate()
.map(|(i, m)| AnthropicMessage {
role: m.role.clone(), role: m.role.clone(),
content: convert_content_blocks(&m.content), content: convert_content_blocks(
self.supports_images(),
&self.name,
&self.model_id,
&m.content,
i,
),
}) })
.collect(), .collect(),
max_tokens, max_tokens,
temperature: request.temperature.or(self.temperature), temperature: request.temperature.or(self.temperature),
tools, tools,
extra: self.model_extra.clone(), extra: self.request_model_extra(),
}; };
let mut req_builder = self let mut req_builder = self

View File

@ -10,7 +10,7 @@ use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::domain::messages::ContentBlock; use crate::domain::messages::ContentBlock;
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content", "supported_content_types"];
/// 流式响应中的工具调用增量 /// 流式响应中的工具调用增量
#[derive(Debug, Default)] #[derive(Debug, Default)]
@ -139,7 +139,75 @@ fn format_transport_error_context(
) )
} }
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { fn convert_content_blocks(
supports_images: bool,
provider_name: &str,
model_id: &str,
blocks: &[ContentBlock],
message_idx: usize,
) -> Value {
// 检查是否有图片且模型不支持
if !supports_images {
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
if has_images {
let image_count = blocks.iter()
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
.count();
tracing::warn!(
provider = %provider_name,
model = %model_id,
filtered_images = image_count,
message_idx,
"模型不支持图片;将图片转换为通知文本"
);
// 复用通知格式,将图片转换为文本通知
let mut converted_blocks: Vec<Value> = Vec::new();
let mut notices: Vec<String> = Vec::new();
let mut image_idx = 0;
for block in blocks.iter() {
match block {
ContentBlock::Text { text } => {
converted_blocks.push(json!({ "type": "text", "text": text }));
}
ContentBlock::ImageUrl { .. } => {
image_idx += 1;
notices.push(format!(
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
image_idx
));
}
}
}
// 添加通知文本块
if !notices.is_empty() {
let notice_text = format!(
"[系统提示] 以下图片未能成功入模:\n{}",
notices.join("\n")
);
converted_blocks.push(json!({ "type": "text", "text": notice_text }));
}
// A single block here can only be the notice itself (an image-only message); send it as a plain string
if converted_blocks.len() == 1 {
if let Some(block) = converted_blocks.first() {
if block.get("type").and_then(|t| t.as_str()) == Some("text") {
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
return Value::String(text.to_string());
}
}
}
}
return Value::Array(converted_blocks);
}
}
// 原有逻辑 - 模型支持图片,正常转换
if blocks.len() == 1 { if blocks.len() == 1 {
if let ContentBlock::Text { text } = &blocks[0] { if let ContentBlock::Text { text } = &blocks[0] {
return Value::String(text.clone()); return Value::String(text.clone());
@ -224,6 +292,23 @@ impl OpenAIProvider {
.unwrap_or(true) .unwrap_or(true)
} }
/// Checks whether the model is configured to accept `content_type`.
///
/// The allow-list is the `supported_content_types` array in `model_extra`.
/// If the key is missing, or its value is not an array, all content types
/// (text, image) are treated as supported.
fn supports_content_type(&self, content_type: &str) -> bool {
    let allow_list = self
        .model_extra
        .get("supported_content_types")
        .and_then(|value| value.as_array());

    if let Some(kinds) = allow_list {
        kinds.iter().any(|kind| kind.as_str() == Some(content_type))
    } else {
        // Nothing declared: default to supporting everything.
        true
    }
}
/// 检查模型是否支持图片
fn supports_images(&self) -> bool {
self.supports_content_type("image")
}
fn normalize_tool_arguments(&self, arguments: &Value) -> Value { fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
match arguments { match arguments {
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()), Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
@ -480,20 +565,21 @@ impl OpenAIProvider {
} }
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
let supports_images = self.supports_images();
let mut body = json!({ let mut body = json!({
"model": self.model_id, "model": self.model_id,
"messages": request.messages.iter().map(|m| { "messages": request.messages.iter().enumerate().map(|(i, m)| {
if m.role == "tool" { if m.role == "tool" {
json!({ json!({
"role": m.role, "role": m.role,
"content": convert_content_blocks(&m.content), "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
"tool_call_id": m.tool_call_id, "tool_call_id": m.tool_call_id,
"name": m.name, "name": m.name,
}) })
} else if m.role == "assistant" && m.tool_calls.is_some() { } else if m.role == "assistant" && m.tool_calls.is_some() {
let mut message = json!({ let mut message = json!({
"role": m.role, "role": m.role,
"content": convert_content_blocks(&m.content), "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
"tool_calls": m.tool_calls.as_ref().map(|calls| { "tool_calls": m.tool_calls.as_ref().map(|calls| {
calls.iter().map(|call| json!({ calls.iter().map(|call| json!({
"id": call.id, "id": call.id,
@ -514,7 +600,7 @@ impl OpenAIProvider {
} else { } else {
let mut message = json!({ let mut message = json!({
"role": m.role, "role": m.role,
"content": convert_content_blocks(&m.content) "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i)
}); });
if m.role == "assistant" { if m.role == "assistant" {
@ -1076,4 +1162,115 @@ mod tests {
assert_eq!(response.tool_calls[1].id, "call_2"); assert_eq!(response.tool_calls[1].id, "call_2");
assert_eq!(response.tool_calls[1].name, "get_time"); assert_eq!(response.tool_calls[1].name, "get_time");
} }
// Without a `supported_content_types` entry, image support defaults to true.
#[test]
fn test_supports_images_default_true() {
    let no_model_extra = HashMap::new();
    let provider = OpenAIProvider::new(
        "test".to_string(),
        "key".to_string(),
        "https://example.com/v1".to_string(),
        HashMap::new(),
        120,
        "gpt-test".to_string(),
        None,
        None,
        no_model_extra,
    );

    assert!(provider.supports_images());
}
// Declaring only "text" as a supported content type turns image support off.
#[test]
fn test_supports_images_disabled_via_config() {
    let mut text_only_extra = HashMap::new();
    text_only_extra.insert(
        "supported_content_types".to_string(),
        Value::Array(vec![Value::String("text".to_string())]),
    );

    let provider = OpenAIProvider::new(
        "test".to_string(),
        "key".to_string(),
        "https://example.com/v1".to_string(),
        HashMap::new(),
        120,
        "gpt-test".to_string(),
        None,
        None,
        text_only_extra,
    );

    assert!(!provider.supports_images());
}
// With image support disabled, text blocks are kept and the image is replaced
// by one trailing notice block.
#[test]
fn test_convert_content_blocks_converts_images_to_notice_when_disabled() {
    let input = vec![
        ContentBlock::text("hello"),
        ContentBlock::image_url("data:image/png;base64,abc123"),
        ContentBlock::text("world"),
    ];

    let converted = convert_content_blocks(false, "test", "test-model", &input, 0);

    // The result must be an array: two text blocks followed by the notice block.
    let items = converted.as_array().unwrap();
    assert_eq!(items.len(), 3);

    // Inspect the notice block.
    let notice = items[2].as_object().unwrap();
    assert_eq!(notice["type"], "text");
    let notice_text = notice["text"].as_str().unwrap();
    for expected in [
        "[系统提示] 以下图片未能成功入模",
        "第 1 张图片",
        "当前模型不支持图片输入",
    ] {
        assert!(notice_text.contains(expected));
    }
}
// With image support enabled, both the text and the image block are passed through.
#[test]
fn test_convert_content_blocks_keeps_images_when_enabled() {
    let input = vec![
        ContentBlock::text("hello"),
        ContentBlock::image_url("data:image/png;base64,abc123"),
    ];

    let converted = convert_content_blocks(true, "test", "test-model", &input, 0);

    // The result must be an array holding the text block, then the image block.
    let parts = converted.as_array().unwrap();
    let kinds: Vec<Option<&str>> = parts.iter().map(|part| part["type"].as_str()).collect();
    assert_eq!(kinds, [Some("text"), Some("image_url")]);
}
// The internal `supported_content_types` key must not reach the request body,
// while ordinary extra parameters are still forwarded.
#[test]
fn test_build_request_body_omits_supported_content_types_from_api() {
    let mut model_extra = HashMap::new();
    model_extra.insert(
        "supported_content_types".to_string(),
        Value::Array(vec![Value::String("text".to_string())]),
    );
    model_extra.insert(
        "custom_param".to_string(),
        Value::String("value".to_string()),
    );

    let provider = OpenAIProvider::new(
        "test".to_string(),
        "key".to_string(),
        "https://example.com/v1".to_string(),
        HashMap::new(),
        120,
        "gpt-test".to_string(),
        None,
        None,
        model_extra,
    );

    let request = ChatCompletionRequest {
        messages: vec![Message::user("hello")],
        temperature: None,
        max_tokens: None,
        tools: None,
    };
    let payload = provider.build_request_body(&request);

    // Internal-only key is stripped from the API payload.
    assert!(payload.get("supported_content_types").is_none());
    // Regular extra parameter is kept.
    assert_eq!(payload["custom_param"], Value::String("value".to_string()));
}
} }

View File

@ -341,7 +341,9 @@ impl SkillSource {
// Holds a list of skills together with two size limits.
// NOTE(review): judging by their names, the limits presumably cap the index
// text length and the number of listed skills — confirm against the code that uses them.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SkillCatalog { pub struct SkillCatalog {
skills: Vec<Skill>, skills: Vec<Skill>,
// NOTE(review): suppresses the unused-field lint; confirm whether this limit should be enforced or removed.
#[allow(dead_code)]
max_index_chars: usize, max_index_chars: usize,
// NOTE(review): suppresses the unused-field lint; confirm whether this limit should be enforced or removed.
#[allow(dead_code)]
max_listed_skills: usize, max_listed_skills: usize,
}