feat: 更新内容处理逻辑,支持图片转换为通知文本并添加模型支持检查
This commit is contained in:
parent
a74c801945
commit
44d9171b86
@ -134,26 +134,6 @@ impl InitWizard {
|
||||
}
|
||||
}
|
||||
|
||||
async fn prompt_yes_no(&mut self, label: &str, default: bool) -> Result<bool, InitError> {
|
||||
let default_str = if default { "yes" } else { "no" };
|
||||
println!("{} [yes/no, default: {}]: ", label, default_str);
|
||||
self.write.flush().await?;
|
||||
|
||||
let mut line = String::new();
|
||||
let bytes_read = self.read.read_line(&mut line).await?;
|
||||
|
||||
if bytes_read == 0 {
|
||||
return Err(InitError::InputError("EOF reached".to_string()));
|
||||
}
|
||||
|
||||
let input = line.trim().to_lowercase();
|
||||
if input.is_empty() {
|
||||
Ok(default)
|
||||
} else {
|
||||
Ok(input == "yes" || input == "y" || input == "1")
|
||||
}
|
||||
}
|
||||
|
||||
async fn prompt_select(
|
||||
&mut self,
|
||||
label: &str,
|
||||
|
||||
@ -8,6 +8,8 @@ use super::traits::Usage;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use crate::domain::messages::ContentBlock;
|
||||
|
||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["supported_content_types"];
|
||||
|
||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||
let mut details = vec![error.to_string()];
|
||||
let mut current = error.source();
|
||||
@ -30,7 +32,65 @@ where
|
||||
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
||||
}
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||
fn convert_content_blocks(
|
||||
supports_images: bool,
|
||||
provider_name: &str,
|
||||
model_id: &str,
|
||||
blocks: &[ContentBlock],
|
||||
message_idx: usize,
|
||||
) -> Vec<serde_json::Value> {
|
||||
// 检查是否有图片且模型不支持
|
||||
if !supports_images {
|
||||
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||||
|
||||
if has_images {
|
||||
let image_count = blocks
|
||||
.iter()
|
||||
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
||||
.count();
|
||||
|
||||
tracing::warn!(
|
||||
provider = %provider_name,
|
||||
model = %model_id,
|
||||
filtered_images = image_count,
|
||||
message_idx,
|
||||
"模型不支持图片;将图片转换为通知文本"
|
||||
);
|
||||
|
||||
// 复用通知格式,将图片转换为文本通知
|
||||
let mut converted_blocks: Vec<serde_json::Value> = Vec::new();
|
||||
let mut notices: Vec<String> = Vec::new();
|
||||
let mut image_idx = 0;
|
||||
|
||||
for block in blocks.iter() {
|
||||
match block {
|
||||
ContentBlock::Text { text } => {
|
||||
converted_blocks.push(serde_json::json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentBlock::ImageUrl { .. } => {
|
||||
image_idx += 1;
|
||||
notices.push(format!(
|
||||
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
||||
image_idx
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加通知文本块
|
||||
if !notices.is_empty() {
|
||||
let notice_text = format!(
|
||||
"[系统提示] 以下图片未能成功入模:\n{}",
|
||||
notices.join("\n")
|
||||
);
|
||||
converted_blocks.push(serde_json::json!({ "type": "text", "text": notice_text }));
|
||||
}
|
||||
|
||||
return converted_blocks;
|
||||
}
|
||||
}
|
||||
|
||||
// 原有逻辑 - 模型支持图片,正常转换
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| match b {
|
||||
@ -112,6 +172,32 @@ impl AnthropicProvider {
|
||||
model_extra,
|
||||
}
|
||||
}
|
||||
|
||||
/// 检查模型是否支持指定内容类型
|
||||
/// 默认支持所有类型(text, image)
|
||||
fn supports_content_type(&self, content_type: &str) -> bool {
|
||||
self.model_extra
|
||||
.get("supported_content_types")
|
||||
.and_then(|value| value.as_array())
|
||||
.map(|types| {
|
||||
types.iter().any(|t| t.as_str() == Some(content_type))
|
||||
})
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// 检查模型是否支持图片
|
||||
fn supports_images(&self) -> bool {
|
||||
self.supports_content_type("image")
|
||||
}
|
||||
|
||||
/// 过滤掉内部字段,只返回需要发送到 API 的 extra 字段
|
||||
fn request_model_extra(&self) -> HashMap<String, serde_json::Value> {
|
||||
self.model_extra
|
||||
.iter()
|
||||
.filter(|(key, _)| !INTERNAL_MODEL_EXTRA_KEYS.contains(&key.as_str()))
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@ -197,15 +283,22 @@ impl LLMProvider for AnthropicProvider {
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|m| AnthropicMessage {
|
||||
.enumerate()
|
||||
.map(|(i, m)| AnthropicMessage {
|
||||
role: m.role.clone(),
|
||||
content: convert_content_blocks(&m.content),
|
||||
content: convert_content_blocks(
|
||||
self.supports_images(),
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&m.content,
|
||||
i,
|
||||
),
|
||||
})
|
||||
.collect(),
|
||||
max_tokens,
|
||||
temperature: request.temperature.or(self.temperature),
|
||||
tools,
|
||||
extra: self.model_extra.clone(),
|
||||
extra: self.request_model_extra(),
|
||||
};
|
||||
|
||||
let mut req_builder = self
|
||||
|
||||
@ -10,7 +10,7 @@ use super::traits::Usage;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use crate::domain::messages::ContentBlock;
|
||||
|
||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content", "supported_content_types"];
|
||||
|
||||
/// 流式响应中的工具调用增量
|
||||
#[derive(Debug, Default)]
|
||||
@ -139,7 +139,75 @@ fn format_transport_error_context(
|
||||
)
|
||||
}
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||
fn convert_content_blocks(
|
||||
supports_images: bool,
|
||||
provider_name: &str,
|
||||
model_id: &str,
|
||||
blocks: &[ContentBlock],
|
||||
message_idx: usize,
|
||||
) -> Value {
|
||||
// 检查是否有图片且模型不支持
|
||||
if !supports_images {
|
||||
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||||
|
||||
if has_images {
|
||||
let image_count = blocks.iter()
|
||||
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
||||
.count();
|
||||
|
||||
tracing::warn!(
|
||||
provider = %provider_name,
|
||||
model = %model_id,
|
||||
filtered_images = image_count,
|
||||
message_idx,
|
||||
"模型不支持图片;将图片转换为通知文本"
|
||||
);
|
||||
|
||||
// 复用通知格式,将图片转换为文本通知
|
||||
let mut converted_blocks: Vec<Value> = Vec::new();
|
||||
let mut notices: Vec<String> = Vec::new();
|
||||
let mut image_idx = 0;
|
||||
|
||||
for block in blocks.iter() {
|
||||
match block {
|
||||
ContentBlock::Text { text } => {
|
||||
converted_blocks.push(json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentBlock::ImageUrl { .. } => {
|
||||
image_idx += 1;
|
||||
notices.push(format!(
|
||||
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
||||
image_idx
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加通知文本块
|
||||
if !notices.is_empty() {
|
||||
let notice_text = format!(
|
||||
"[系统提示] 以下图片未能成功入模:\n{}",
|
||||
notices.join("\n")
|
||||
);
|
||||
converted_blocks.push(json!({ "type": "text", "text": notice_text }));
|
||||
}
|
||||
|
||||
// 如果只有一个文本块且没有通知,返回字符串形式
|
||||
if converted_blocks.len() == 1 {
|
||||
if let Some(block) = converted_blocks.first() {
|
||||
if block.get("type").and_then(|t| t.as_str()) == Some("text") {
|
||||
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
|
||||
return Value::String(text.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Value::Array(converted_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
// 原有逻辑 - 模型支持图片,正常转换
|
||||
if blocks.len() == 1 {
|
||||
if let ContentBlock::Text { text } = &blocks[0] {
|
||||
return Value::String(text.clone());
|
||||
@ -224,6 +292,23 @@ impl OpenAIProvider {
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// 检查模型是否支持指定内容类型
|
||||
/// 默认支持所有类型(text, image)
|
||||
fn supports_content_type(&self, content_type: &str) -> bool {
|
||||
self.model_extra
|
||||
.get("supported_content_types")
|
||||
.and_then(|value| value.as_array())
|
||||
.map(|types| {
|
||||
types.iter().any(|t| t.as_str() == Some(content_type))
|
||||
})
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// 检查模型是否支持图片
|
||||
fn supports_images(&self) -> bool {
|
||||
self.supports_content_type("image")
|
||||
}
|
||||
|
||||
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||
match arguments {
|
||||
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
||||
@ -480,20 +565,21 @@ impl OpenAIProvider {
|
||||
}
|
||||
|
||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||
let supports_images = self.supports_images();
|
||||
let mut body = json!({
|
||||
"model": self.model_id,
|
||||
"messages": request.messages.iter().map(|m| {
|
||||
"messages": request.messages.iter().enumerate().map(|(i, m)| {
|
||||
if m.role == "tool" {
|
||||
json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content),
|
||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
||||
"tool_call_id": m.tool_call_id,
|
||||
"name": m.name,
|
||||
})
|
||||
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
||||
let mut message = json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content),
|
||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
||||
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||||
calls.iter().map(|call| json!({
|
||||
"id": call.id,
|
||||
@ -514,7 +600,7 @@ impl OpenAIProvider {
|
||||
} else {
|
||||
let mut message = json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content)
|
||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i)
|
||||
});
|
||||
|
||||
if m.role == "assistant" {
|
||||
@ -1076,4 +1162,115 @@ mod tests {
|
||||
assert_eq!(response.tool_calls[1].id, "call_2");
|
||||
assert_eq!(response.tool_calls[1].name, "get_time");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supports_images_default_true() {
|
||||
let provider = OpenAIProvider::new(
|
||||
"test".to_string(),
|
||||
"key".to_string(),
|
||||
"https://example.com/v1".to_string(),
|
||||
HashMap::new(),
|
||||
120,
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
HashMap::new(),
|
||||
);
|
||||
|
||||
assert!(provider.supports_images());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supports_images_disabled_via_config() {
|
||||
let provider = OpenAIProvider::new(
|
||||
"test".to_string(),
|
||||
"key".to_string(),
|
||||
"https://example.com/v1".to_string(),
|
||||
HashMap::new(),
|
||||
120,
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
HashMap::from([(
|
||||
"supported_content_types".to_string(),
|
||||
Value::Array(vec![Value::String("text".to_string())]),
|
||||
)]),
|
||||
);
|
||||
|
||||
assert!(!provider.supports_images());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_content_blocks_converts_images_to_notice_when_disabled() {
|
||||
let blocks = vec![
|
||||
ContentBlock::text("hello"),
|
||||
ContentBlock::image_url("data:image/png;base64,abc123"),
|
||||
ContentBlock::text("world"),
|
||||
];
|
||||
|
||||
let result = convert_content_blocks(false, "test", "test-model", &blocks, 0);
|
||||
|
||||
// 应该是数组形式
|
||||
let arr = result.as_array().unwrap();
|
||||
assert_eq!(arr.len(), 3); // 两个文本块 + 一个通知块
|
||||
|
||||
// 检查通知内容
|
||||
let notice_block = arr[2].as_object().unwrap();
|
||||
assert_eq!(notice_block["type"], "text");
|
||||
let notice_text = notice_block["text"].as_str().unwrap();
|
||||
assert!(notice_text.contains("[系统提示] 以下图片未能成功入模"));
|
||||
assert!(notice_text.contains("第 1 张图片"));
|
||||
assert!(notice_text.contains("当前模型不支持图片输入"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_content_blocks_keeps_images_when_enabled() {
|
||||
let blocks = vec![
|
||||
ContentBlock::text("hello"),
|
||||
ContentBlock::image_url("data:image/png;base64,abc123"),
|
||||
];
|
||||
|
||||
let result = convert_content_blocks(true, "test", "test-model", &blocks, 0);
|
||||
|
||||
// 应该是数组形式,包含文本和图片
|
||||
let arr = result.as_array().unwrap();
|
||||
assert_eq!(arr.len(), 2);
|
||||
assert_eq!(arr[0]["type"], "text");
|
||||
assert_eq!(arr[1]["type"], "image_url");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_request_body_omits_supported_content_types_from_api() {
|
||||
let provider = OpenAIProvider::new(
|
||||
"test".to_string(),
|
||||
"key".to_string(),
|
||||
"https://example.com/v1".to_string(),
|
||||
HashMap::new(),
|
||||
120,
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
HashMap::from([
|
||||
(
|
||||
"supported_content_types".to_string(),
|
||||
Value::Array(vec![Value::String("text".to_string())]),
|
||||
),
|
||||
("custom_param".to_string(), Value::String("value".to_string())),
|
||||
]),
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::user("hello")],
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let body = provider.build_request_body(&request);
|
||||
|
||||
// supported_content_types 不应该发送到 API
|
||||
assert!(body.get("supported_content_types").is_none());
|
||||
// custom_param 应该保留
|
||||
assert_eq!(body["custom_param"], Value::String("value".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -341,7 +341,9 @@ impl SkillSource {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SkillCatalog {
|
||||
skills: Vec<Skill>,
|
||||
#[allow(dead_code)]
|
||||
max_index_chars: usize,
|
||||
#[allow(dead_code)]
|
||||
max_listed_skills: usize,
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user