From 44d9171b8649f64c453aa696d584b121475e9365 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Sat, 23 May 2026 14:08:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E5=86=85=E5=AE=B9?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E8=BD=AC=E6=8D=A2=E4=B8=BA=E9=80=9A=E7=9F=A5?= =?UTF-8?q?=E6=96=87=E6=9C=AC=E5=B9=B6=E6=B7=BB=E5=8A=A0=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cli/init.rs | 20 ---- src/providers/anthropic.rs | 101 +++++++++++++++++- src/providers/openai.rs | 209 +++++++++++++++++++++++++++++++++++-- src/skills/mod.rs | 2 + 4 files changed, 302 insertions(+), 30 deletions(-) diff --git a/src/cli/init.rs b/src/cli/init.rs index 3616f48..bbe0f96 100644 --- a/src/cli/init.rs +++ b/src/cli/init.rs @@ -134,26 +134,6 @@ impl InitWizard { } } - async fn prompt_yes_no(&mut self, label: &str, default: bool) -> Result { - let default_str = if default { "yes" } else { "no" }; - println!("{} [yes/no, default: {}]: ", label, default_str); - self.write.flush().await?; - - let mut line = String::new(); - let bytes_read = self.read.read_line(&mut line).await?; - - if bytes_read == 0 { - return Err(InitError::InputError("EOF reached".to_string())); - } - - let input = line.trim().to_lowercase(); - if input.is_empty() { - Ok(default) - } else { - Ok(input == "yes" || input == "y" || input == "1") - } - } - async fn prompt_select( &mut self, label: &str, diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index a2632fe..d3ba937 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -8,6 +8,8 @@ use super::traits::Usage; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use crate::domain::messages::ContentBlock; +const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["supported_content_types"]; + fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; let mut current = error.source(); @@ -30,7 +32,65 @@ where serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string())) } -fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec { +fn convert_content_blocks( + supports_images: bool, + provider_name: &str, + model_id: &str, + blocks: &[ContentBlock], + message_idx: usize, +) -> Vec { + // 检查是否有图片且模型不支持 + if !supports_images { + let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. })); + + if has_images { + let image_count = blocks + .iter() + .filter(|b| matches!(b, ContentBlock::ImageUrl { .. })) + .count(); + + tracing::warn!( + provider = %provider_name, + model = %model_id, + filtered_images = image_count, + message_idx, + "模型不支持图片;将图片转换为通知文本" + ); + + // 复用通知格式,将图片转换为文本通知 + let mut converted_blocks: Vec = Vec::new(); + let mut notices: Vec = Vec::new(); + let mut image_idx = 0; + + for block in blocks.iter() { + match block { + ContentBlock::Text { text } => { + converted_blocks.push(serde_json::json!({ "type": "text", "text": text })); + } + ContentBlock::ImageUrl { .. } => { + image_idx += 1; + notices.push(format!( + "- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。", + image_idx + )); + } + } + } + + // 添加通知文本块 + if !notices.is_empty() { + let notice_text = format!( + "[系统提示] 以下图片未能成功入模:\n{}", + notices.join("\n") + ); + converted_blocks.push(serde_json::json!({ "type": "text", "text": notice_text })); + } + + return converted_blocks; + } + } + + // 原有逻辑 - 模型支持图片,正常转换 blocks .iter() .map(|b| match b { @@ -112,6 +172,32 @@ impl AnthropicProvider { model_extra, } } + + /// 检查模型是否支持指定内容类型 + /// 默认支持所有类型(text, image) + fn supports_content_type(&self, content_type: &str) -> bool { + self.model_extra + .get("supported_content_types") + .and_then(|value| value.as_array()) + .map(|types| { + types.iter().any(|t| t.as_str() == Some(content_type)) + }) + .unwrap_or(true) + } + + /// 检查模型是否支持图片 + fn supports_images(&self) -> bool { + self.supports_content_type("image") + } + + /// 过滤掉内部字段,只返回需要发送到 API 的 extra 字段 + fn request_model_extra(&self) -> HashMap { + self.model_extra + .iter() + .filter(|(key, _)| !INTERNAL_MODEL_EXTRA_KEYS.contains(&key.as_str())) + .map(|(k, v)| (k.clone(), v.clone())) + .collect() + } } #[derive(Serialize)] @@ -197,15 +283,22 @@ impl LLMProvider for AnthropicProvider { messages: request .messages .iter() - .map(|m| AnthropicMessage { + .enumerate() + .map(|(i, m)| AnthropicMessage { role: m.role.clone(), - content: convert_content_blocks(&m.content), + content: convert_content_blocks( + self.supports_images(), + &self.name, + &self.model_id, + &m.content, + i, + ), }) .collect(), max_tokens, temperature: request.temperature.or(self.temperature), tools, - extra: self.model_extra.clone(), + extra: self.request_model_extra(), }; let mut req_builder = self diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 35e93a7..15dde1d 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -10,7 +10,7 @@ use super::traits::Usage; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use crate::domain::messages::ContentBlock; -const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; +const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content", "supported_content_types"]; /// 流式响应中的工具调用增量 #[derive(Debug, Default)] @@ -139,7 +139,75 @@ fn format_transport_error_context( ) } -fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { +fn convert_content_blocks( + supports_images: bool, + provider_name: &str, + model_id: &str, + blocks: &[ContentBlock], + message_idx: usize, +) -> Value { + // 检查是否有图片且模型不支持 + if !supports_images { + let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. })); + + if has_images { + let image_count = blocks.iter() + .filter(|b| matches!(b, ContentBlock::ImageUrl { .. })) + .count(); + + tracing::warn!( + provider = %provider_name, + model = %model_id, + filtered_images = image_count, + message_idx, + "模型不支持图片;将图片转换为通知文本" + ); + + // 复用通知格式,将图片转换为文本通知 + let mut converted_blocks: Vec = Vec::new(); + let mut notices: Vec = Vec::new(); + let mut image_idx = 0; + + for block in blocks.iter() { + match block { + ContentBlock::Text { text } => { + converted_blocks.push(json!({ "type": "text", "text": text })); + } + ContentBlock::ImageUrl { .. } => { + image_idx += 1; + notices.push(format!( + "- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。", + image_idx + )); + } + } + } + + // 添加通知文本块 + if !notices.is_empty() { + let notice_text = format!( + "[系统提示] 以下图片未能成功入模:\n{}", + notices.join("\n") + ); + converted_blocks.push(json!({ "type": "text", "text": notice_text })); + } + + // 如果只有一个文本块且没有通知,返回字符串形式 + if converted_blocks.len() == 1 { + if let Some(block) = converted_blocks.first() { + if block.get("type").and_then(|t| t.as_str()) == Some("text") { + if let Some(text) = block.get("text").and_then(|t| t.as_str()) { + return Value::String(text.to_string()); + } + } + } + } + + return Value::Array(converted_blocks); + } + } + + // 原有逻辑 - 模型支持图片,正常转换 if blocks.len() == 1 { if let ContentBlock::Text { text } = &blocks[0] { return Value::String(text.clone()); @@ -224,6 +292,23 @@ impl OpenAIProvider { .unwrap_or(true) } + /// 检查模型是否支持指定内容类型 + /// 默认支持所有类型(text, image) + fn supports_content_type(&self, content_type: &str) -> bool { + self.model_extra + .get("supported_content_types") + .and_then(|value| value.as_array()) + .map(|types| { + types.iter().any(|t| t.as_str() == Some(content_type)) + }) + .unwrap_or(true) + } + + /// 检查模型是否支持图片 + fn supports_images(&self) -> bool { + self.supports_content_type("image") + } + fn normalize_tool_arguments(&self, arguments: &Value) -> Value { match arguments { Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()), @@ -480,20 +565,21 @@ impl OpenAIProvider { } fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { + let supports_images = self.supports_images(); let mut body = json!({ "model": self.model_id, - "messages": request.messages.iter().map(|m| { + "messages": request.messages.iter().enumerate().map(|(i, m)| { if m.role == "tool" { json!({ "role": m.role, - "content": convert_content_blocks(&m.content), + "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i), "tool_call_id": m.tool_call_id, "name": m.name, }) } else if m.role == "assistant" && m.tool_calls.is_some() { let mut message = json!({ "role": m.role, - "content": convert_content_blocks(&m.content), + "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i), "tool_calls": m.tool_calls.as_ref().map(|calls| { calls.iter().map(|call| json!({ "id": call.id, @@ -514,7 +600,7 @@ impl OpenAIProvider { } else { let mut message = json!({ "role": m.role, - "content": convert_content_blocks(&m.content) + "content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i) }); if m.role == "assistant" { @@ -1076,4 +1162,115 @@ mod tests { assert_eq!(response.tool_calls[1].id, "call_2"); assert_eq!(response.tool_calls[1].name, "get_time"); } + + #[test] + fn test_supports_images_default_true() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::new(), + ); + + assert!(provider.supports_images()); + } + + #[test] + fn test_supports_images_disabled_via_config() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::from([( + "supported_content_types".to_string(), + Value::Array(vec![Value::String("text".to_string())]), + )]), + ); + + assert!(!provider.supports_images()); + } + + #[test] + fn test_convert_content_blocks_converts_images_to_notice_when_disabled() { + let blocks = vec![ + ContentBlock::text("hello"), + ContentBlock::image_url("data:image/png;base64,abc123"), + ContentBlock::text("world"), + ]; + + let result = convert_content_blocks(false, "test", "test-model", &blocks, 0); + + // 应该是数组形式 + let arr = result.as_array().unwrap(); + assert_eq!(arr.len(), 3); // 两个文本块 + 一个通知块 + + // 检查通知内容 + let notice_block = arr[2].as_object().unwrap(); + assert_eq!(notice_block["type"], "text"); + let notice_text = notice_block["text"].as_str().unwrap(); + assert!(notice_text.contains("[系统提示] 以下图片未能成功入模")); + assert!(notice_text.contains("第 1 张图片")); + assert!(notice_text.contains("当前模型不支持图片输入")); + } + + #[test] + fn test_convert_content_blocks_keeps_images_when_enabled() { + let blocks = vec![ + ContentBlock::text("hello"), + ContentBlock::image_url("data:image/png;base64,abc123"), + ]; + + let result = convert_content_blocks(true, "test", "test-model", &blocks, 0); + + // 应该是数组形式,包含文本和图片 + let arr = result.as_array().unwrap(); + assert_eq!(arr.len(), 2); + assert_eq!(arr[0]["type"], "text"); + assert_eq!(arr[1]["type"], "image_url"); + } + + #[test] + fn test_build_request_body_omits_supported_content_types_from_api() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::from([ + ( + "supported_content_types".to_string(), + Value::Array(vec![Value::String("text".to_string())]), + ), + ("custom_param".to_string(), Value::String("value".to_string())), + ]), + ); + + let request = ChatCompletionRequest { + messages: vec![Message::user("hello")], + temperature: None, + max_tokens: None, + tools: None, + }; + + let body = provider.build_request_body(&request); + + // supported_content_types 不应该发送到 API + assert!(body.get("supported_content_types").is_none()); + // custom_param 应该保留 + assert_eq!(body["custom_param"], Value::String("value".to_string())); + } } diff --git a/src/skills/mod.rs b/src/skills/mod.rs index 577fc73..4feaa48 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -341,7 +341,9 @@ impl SkillSource { #[derive(Debug, Clone)] pub struct SkillCatalog { skills: Vec, + #[allow(dead_code)] max_index_chars: usize, + #[allow(dead_code)] max_listed_skills: usize, }