feat: 更新内容处理逻辑,支持图片转换为通知文本并添加模型支持检查
This commit is contained in:
parent
a74c801945
commit
44d9171b86
@ -134,26 +134,6 @@ impl InitWizard {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn prompt_yes_no(&mut self, label: &str, default: bool) -> Result<bool, InitError> {
|
|
||||||
let default_str = if default { "yes" } else { "no" };
|
|
||||||
println!("{} [yes/no, default: {}]: ", label, default_str);
|
|
||||||
self.write.flush().await?;
|
|
||||||
|
|
||||||
let mut line = String::new();
|
|
||||||
let bytes_read = self.read.read_line(&mut line).await?;
|
|
||||||
|
|
||||||
if bytes_read == 0 {
|
|
||||||
return Err(InitError::InputError("EOF reached".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let input = line.trim().to_lowercase();
|
|
||||||
if input.is_empty() {
|
|
||||||
Ok(default)
|
|
||||||
} else {
|
|
||||||
Ok(input == "yes" || input == "y" || input == "1")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn prompt_select(
|
async fn prompt_select(
|
||||||
&mut self,
|
&mut self,
|
||||||
label: &str,
|
label: &str,
|
||||||
|
|||||||
@ -8,6 +8,8 @@ use super::traits::Usage;
|
|||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
use crate::domain::messages::ContentBlock;
|
use crate::domain::messages::ContentBlock;
|
||||||
|
|
||||||
|
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["supported_content_types"];
|
||||||
|
|
||||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||||
let mut details = vec![error.to_string()];
|
let mut details = vec![error.to_string()];
|
||||||
let mut current = error.source();
|
let mut current = error.source();
|
||||||
@ -30,7 +32,65 @@ where
|
|||||||
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
fn convert_content_blocks(
|
||||||
|
supports_images: bool,
|
||||||
|
provider_name: &str,
|
||||||
|
model_id: &str,
|
||||||
|
blocks: &[ContentBlock],
|
||||||
|
message_idx: usize,
|
||||||
|
) -> Vec<serde_json::Value> {
|
||||||
|
// 检查是否有图片且模型不支持
|
||||||
|
if !supports_images {
|
||||||
|
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||||||
|
|
||||||
|
if has_images {
|
||||||
|
let image_count = blocks
|
||||||
|
.iter()
|
||||||
|
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
||||||
|
.count();
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
provider = %provider_name,
|
||||||
|
model = %model_id,
|
||||||
|
filtered_images = image_count,
|
||||||
|
message_idx,
|
||||||
|
"模型不支持图片;将图片转换为通知文本"
|
||||||
|
);
|
||||||
|
|
||||||
|
// 复用通知格式,将图片转换为文本通知
|
||||||
|
let mut converted_blocks: Vec<serde_json::Value> = Vec::new();
|
||||||
|
let mut notices: Vec<String> = Vec::new();
|
||||||
|
let mut image_idx = 0;
|
||||||
|
|
||||||
|
for block in blocks.iter() {
|
||||||
|
match block {
|
||||||
|
ContentBlock::Text { text } => {
|
||||||
|
converted_blocks.push(serde_json::json!({ "type": "text", "text": text }));
|
||||||
|
}
|
||||||
|
ContentBlock::ImageUrl { .. } => {
|
||||||
|
image_idx += 1;
|
||||||
|
notices.push(format!(
|
||||||
|
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
||||||
|
image_idx
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加通知文本块
|
||||||
|
if !notices.is_empty() {
|
||||||
|
let notice_text = format!(
|
||||||
|
"[系统提示] 以下图片未能成功入模:\n{}",
|
||||||
|
notices.join("\n")
|
||||||
|
);
|
||||||
|
converted_blocks.push(serde_json::json!({ "type": "text", "text": notice_text }));
|
||||||
|
}
|
||||||
|
|
||||||
|
return converted_blocks;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有逻辑 - 模型支持图片,正常转换
|
||||||
blocks
|
blocks
|
||||||
.iter()
|
.iter()
|
||||||
.map(|b| match b {
|
.map(|b| match b {
|
||||||
@ -112,6 +172,32 @@ impl AnthropicProvider {
|
|||||||
model_extra,
|
model_extra,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 检查模型是否支持指定内容类型
|
||||||
|
/// 默认支持所有类型(text, image)
|
||||||
|
fn supports_content_type(&self, content_type: &str) -> bool {
|
||||||
|
self.model_extra
|
||||||
|
.get("supported_content_types")
|
||||||
|
.and_then(|value| value.as_array())
|
||||||
|
.map(|types| {
|
||||||
|
types.iter().any(|t| t.as_str() == Some(content_type))
|
||||||
|
})
|
||||||
|
.unwrap_or(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查模型是否支持图片
|
||||||
|
fn supports_images(&self) -> bool {
|
||||||
|
self.supports_content_type("image")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 过滤掉内部字段,只返回需要发送到 API 的 extra 字段
|
||||||
|
fn request_model_extra(&self) -> HashMap<String, serde_json::Value> {
|
||||||
|
self.model_extra
|
||||||
|
.iter()
|
||||||
|
.filter(|(key, _)| !INTERNAL_MODEL_EXTRA_KEYS.contains(&key.as_str()))
|
||||||
|
.map(|(k, v)| (k.clone(), v.clone()))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
@ -197,15 +283,22 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
messages: request
|
messages: request
|
||||||
.messages
|
.messages
|
||||||
.iter()
|
.iter()
|
||||||
.map(|m| AnthropicMessage {
|
.enumerate()
|
||||||
|
.map(|(i, m)| AnthropicMessage {
|
||||||
role: m.role.clone(),
|
role: m.role.clone(),
|
||||||
content: convert_content_blocks(&m.content),
|
content: convert_content_blocks(
|
||||||
|
self.supports_images(),
|
||||||
|
&self.name,
|
||||||
|
&self.model_id,
|
||||||
|
&m.content,
|
||||||
|
i,
|
||||||
|
),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
max_tokens,
|
max_tokens,
|
||||||
temperature: request.temperature.or(self.temperature),
|
temperature: request.temperature.or(self.temperature),
|
||||||
tools,
|
tools,
|
||||||
extra: self.model_extra.clone(),
|
extra: self.request_model_extra(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut req_builder = self
|
let mut req_builder = self
|
||||||
|
|||||||
@ -10,7 +10,7 @@ use super::traits::Usage;
|
|||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
use crate::domain::messages::ContentBlock;
|
use crate::domain::messages::ContentBlock;
|
||||||
|
|
||||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content", "supported_content_types"];
|
||||||
|
|
||||||
/// 流式响应中的工具调用增量
|
/// 流式响应中的工具调用增量
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
@ -139,7 +139,75 @@ fn format_transport_error_context(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
fn convert_content_blocks(
|
||||||
|
supports_images: bool,
|
||||||
|
provider_name: &str,
|
||||||
|
model_id: &str,
|
||||||
|
blocks: &[ContentBlock],
|
||||||
|
message_idx: usize,
|
||||||
|
) -> Value {
|
||||||
|
// 检查是否有图片且模型不支持
|
||||||
|
if !supports_images {
|
||||||
|
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||||||
|
|
||||||
|
if has_images {
|
||||||
|
let image_count = blocks.iter()
|
||||||
|
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
||||||
|
.count();
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
provider = %provider_name,
|
||||||
|
model = %model_id,
|
||||||
|
filtered_images = image_count,
|
||||||
|
message_idx,
|
||||||
|
"模型不支持图片;将图片转换为通知文本"
|
||||||
|
);
|
||||||
|
|
||||||
|
// 复用通知格式,将图片转换为文本通知
|
||||||
|
let mut converted_blocks: Vec<Value> = Vec::new();
|
||||||
|
let mut notices: Vec<String> = Vec::new();
|
||||||
|
let mut image_idx = 0;
|
||||||
|
|
||||||
|
for block in blocks.iter() {
|
||||||
|
match block {
|
||||||
|
ContentBlock::Text { text } => {
|
||||||
|
converted_blocks.push(json!({ "type": "text", "text": text }));
|
||||||
|
}
|
||||||
|
ContentBlock::ImageUrl { .. } => {
|
||||||
|
image_idx += 1;
|
||||||
|
notices.push(format!(
|
||||||
|
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
||||||
|
image_idx
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加通知文本块
|
||||||
|
if !notices.is_empty() {
|
||||||
|
let notice_text = format!(
|
||||||
|
"[系统提示] 以下图片未能成功入模:\n{}",
|
||||||
|
notices.join("\n")
|
||||||
|
);
|
||||||
|
converted_blocks.push(json!({ "type": "text", "text": notice_text }));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果只有一个文本块且没有通知,返回字符串形式
|
||||||
|
if converted_blocks.len() == 1 {
|
||||||
|
if let Some(block) = converted_blocks.first() {
|
||||||
|
if block.get("type").and_then(|t| t.as_str()) == Some("text") {
|
||||||
|
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
|
||||||
|
return Value::String(text.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Value::Array(converted_blocks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有逻辑 - 模型支持图片,正常转换
|
||||||
if blocks.len() == 1 {
|
if blocks.len() == 1 {
|
||||||
if let ContentBlock::Text { text } = &blocks[0] {
|
if let ContentBlock::Text { text } = &blocks[0] {
|
||||||
return Value::String(text.clone());
|
return Value::String(text.clone());
|
||||||
@ -224,6 +292,23 @@ impl OpenAIProvider {
|
|||||||
.unwrap_or(true)
|
.unwrap_or(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 检查模型是否支持指定内容类型
|
||||||
|
/// 默认支持所有类型(text, image)
|
||||||
|
fn supports_content_type(&self, content_type: &str) -> bool {
|
||||||
|
self.model_extra
|
||||||
|
.get("supported_content_types")
|
||||||
|
.and_then(|value| value.as_array())
|
||||||
|
.map(|types| {
|
||||||
|
types.iter().any(|t| t.as_str() == Some(content_type))
|
||||||
|
})
|
||||||
|
.unwrap_or(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查模型是否支持图片
|
||||||
|
fn supports_images(&self) -> bool {
|
||||||
|
self.supports_content_type("image")
|
||||||
|
}
|
||||||
|
|
||||||
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||||
match arguments {
|
match arguments {
|
||||||
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
||||||
@ -480,20 +565,21 @@ impl OpenAIProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
|
let supports_images = self.supports_images();
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
"messages": request.messages.iter().map(|m| {
|
"messages": request.messages.iter().enumerate().map(|(i, m)| {
|
||||||
if m.role == "tool" {
|
if m.role == "tool" {
|
||||||
json!({
|
json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": convert_content_blocks(&m.content),
|
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
||||||
"tool_call_id": m.tool_call_id,
|
"tool_call_id": m.tool_call_id,
|
||||||
"name": m.name,
|
"name": m.name,
|
||||||
})
|
})
|
||||||
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
||||||
let mut message = json!({
|
let mut message = json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": convert_content_blocks(&m.content),
|
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
||||||
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||||||
calls.iter().map(|call| json!({
|
calls.iter().map(|call| json!({
|
||||||
"id": call.id,
|
"id": call.id,
|
||||||
@ -514,7 +600,7 @@ impl OpenAIProvider {
|
|||||||
} else {
|
} else {
|
||||||
let mut message = json!({
|
let mut message = json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": convert_content_blocks(&m.content)
|
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i)
|
||||||
});
|
});
|
||||||
|
|
||||||
if m.role == "assistant" {
|
if m.role == "assistant" {
|
||||||
@ -1076,4 +1162,115 @@ mod tests {
|
|||||||
assert_eq!(response.tool_calls[1].id, "call_2");
|
assert_eq!(response.tool_calls[1].id, "call_2");
|
||||||
assert_eq!(response.tool_calls[1].name, "get_time");
|
assert_eq!(response.tool_calls[1].name, "get_time");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_supports_images_default_true() {
|
||||||
|
let provider = OpenAIProvider::new(
|
||||||
|
"test".to_string(),
|
||||||
|
"key".to_string(),
|
||||||
|
"https://example.com/v1".to_string(),
|
||||||
|
HashMap::new(),
|
||||||
|
120,
|
||||||
|
"gpt-test".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
HashMap::new(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(provider.supports_images());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_supports_images_disabled_via_config() {
|
||||||
|
let provider = OpenAIProvider::new(
|
||||||
|
"test".to_string(),
|
||||||
|
"key".to_string(),
|
||||||
|
"https://example.com/v1".to_string(),
|
||||||
|
HashMap::new(),
|
||||||
|
120,
|
||||||
|
"gpt-test".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
HashMap::from([(
|
||||||
|
"supported_content_types".to_string(),
|
||||||
|
Value::Array(vec![Value::String("text".to_string())]),
|
||||||
|
)]),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(!provider.supports_images());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_content_blocks_converts_images_to_notice_when_disabled() {
|
||||||
|
let blocks = vec![
|
||||||
|
ContentBlock::text("hello"),
|
||||||
|
ContentBlock::image_url("data:image/png;base64,abc123"),
|
||||||
|
ContentBlock::text("world"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = convert_content_blocks(false, "test", "test-model", &blocks, 0);
|
||||||
|
|
||||||
|
// 应该是数组形式
|
||||||
|
let arr = result.as_array().unwrap();
|
||||||
|
assert_eq!(arr.len(), 3); // 两个文本块 + 一个通知块
|
||||||
|
|
||||||
|
// 检查通知内容
|
||||||
|
let notice_block = arr[2].as_object().unwrap();
|
||||||
|
assert_eq!(notice_block["type"], "text");
|
||||||
|
let notice_text = notice_block["text"].as_str().unwrap();
|
||||||
|
assert!(notice_text.contains("[系统提示] 以下图片未能成功入模"));
|
||||||
|
assert!(notice_text.contains("第 1 张图片"));
|
||||||
|
assert!(notice_text.contains("当前模型不支持图片输入"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_content_blocks_keeps_images_when_enabled() {
|
||||||
|
let blocks = vec![
|
||||||
|
ContentBlock::text("hello"),
|
||||||
|
ContentBlock::image_url("data:image/png;base64,abc123"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = convert_content_blocks(true, "test", "test-model", &blocks, 0);
|
||||||
|
|
||||||
|
// 应该是数组形式,包含文本和图片
|
||||||
|
let arr = result.as_array().unwrap();
|
||||||
|
assert_eq!(arr.len(), 2);
|
||||||
|
assert_eq!(arr[0]["type"], "text");
|
||||||
|
assert_eq!(arr[1]["type"], "image_url");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_request_body_omits_supported_content_types_from_api() {
|
||||||
|
let provider = OpenAIProvider::new(
|
||||||
|
"test".to_string(),
|
||||||
|
"key".to_string(),
|
||||||
|
"https://example.com/v1".to_string(),
|
||||||
|
HashMap::new(),
|
||||||
|
120,
|
||||||
|
"gpt-test".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
HashMap::from([
|
||||||
|
(
|
||||||
|
"supported_content_types".to_string(),
|
||||||
|
Value::Array(vec![Value::String("text".to_string())]),
|
||||||
|
),
|
||||||
|
("custom_param".to_string(), Value::String("value".to_string())),
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message::user("hello")],
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: None,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let body = provider.build_request_body(&request);
|
||||||
|
|
||||||
|
// supported_content_types 不应该发送到 API
|
||||||
|
assert!(body.get("supported_content_types").is_none());
|
||||||
|
// custom_param 应该保留
|
||||||
|
assert_eq!(body["custom_param"], Value::String("value".to_string()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -341,7 +341,9 @@ impl SkillSource {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct SkillCatalog {
|
pub struct SkillCatalog {
|
||||||
skills: Vec<Skill>,
|
skills: Vec<Skill>,
|
||||||
|
#[allow(dead_code)]
|
||||||
max_index_chars: usize,
|
max_index_chars: usize,
|
||||||
|
#[allow(dead_code)]
|
||||||
max_listed_skills: usize,
|
max_listed_skills: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user