feat: 添加图片上下文限制配置,支持最大图片数量和消息轮次限制

This commit is contained in:
ooodc 2026-05-24 18:06:22 +08:00
parent 4605c2dad3
commit b571d7b7b3
9 changed files with 355 additions and 10 deletions

View File

@ -187,6 +187,133 @@ fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
.count()
}
/// Drops image attachments that are too old or that exceed the image budget.
///
/// `messages` is expected oldest-first, so the last element is the newest message.
///
/// # Parameters
/// - `messages`: the original message list; it is not modified.
/// - `max_age_rounds`: images in messages that sit at least this many positions behind the
///   newest message are dropped. Every message counts as one position, whatever its role.
///   `0` disables the age limit.
/// - `max_images`: upper bound on the number of images kept across all messages. The budget
///   is spent on the newest messages first; within one message the leading images win.
///   `0` disables the count limit, consistently with `max_age_rounds`. (Previously `0`
///   disabled the limit only when `max_age_rounds` was also `0`, and otherwise removed
///   every image.)
///
/// # Returns
/// A new list with the same length and order. Non-image media refs are always kept. Every
/// message that lost at least one image gets a notice appended to its content:
/// `[图片已过期:超出 N 条消息范围]` when the message is outside the age window, otherwise
/// `[图片已过期:超出最大图片数量限制]`.
fn filter_images_by_age_and_count(
    messages: &[ChatMessage],
    max_age_rounds: usize,
    max_images: usize,
) -> Vec<ChatMessage> {
    if messages.is_empty() || (max_age_rounds == 0 && max_images == 0) {
        return messages.to_vec();
    }
    if count_supported_image_media_refs(messages) == 0 {
        return messages.to_vec();
    }

    let msg_count = messages.len();

    // `0` means "no limit" for either setting.
    let image_cap = if max_images == 0 { usize::MAX } else { max_images };

    // A message at `idx` has age `msg_count - 1 - idx`; it is inside the age window exactly
    // when `idx >= msg_count - max_age_rounds`.
    let first_fresh_idx = if max_age_rounds == 0 {
        0
    } else {
        msg_count.saturating_sub(max_age_rounds)
    };

    // Pass 1 (newest -> oldest): hand out the image budget to the messages inside the age
    // window. Messages outside the window keep their initial allowance of zero.
    let mut keep_per_msg = vec![0usize; msg_count];
    let mut remaining = image_cap;
    for (idx, message) in messages.iter().enumerate().skip(first_fresh_idx).rev() {
        if remaining == 0 {
            break;
        }
        let image_count = message
            .media_refs
            .iter()
            .filter(|p| supported_image_mime_type(p).is_some())
            .count();
        let keep = image_count.min(remaining);
        keep_per_msg[idx] = keep;
        remaining -= keep;
    }

    // Pass 2 (oldest -> newest): rebuild every message with only its allowed images.
    let mut filtered = Vec::with_capacity(msg_count);
    for (idx, (message, keep)) in messages.iter().zip(keep_per_msg).enumerate() {
        // Counting images while filtering avoids re-scanning the refs afterwards.
        let mut images_seen = 0usize;
        let media_refs: Vec<String> = message
            .media_refs
            .iter()
            .filter(|path| {
                if supported_image_mime_type(path).is_none() {
                    return true; // non-image media is never filtered
                }
                images_seen += 1;
                images_seen <= keep
            })
            .cloned()
            .collect();

        let content = if images_seen > keep {
            // At least one image was dropped: tell the model why it is missing.
            if idx < first_fresh_idx {
                format!("{} [图片已过期:超出 {} 条消息范围]", message.content, max_age_rounds)
            } else {
                format!("{} [图片已过期:超出最大图片数量限制]", message.content)
            }
        } else {
            message.content.clone()
        };

        filtered.push(ChatMessage {
            id: message.id.clone(),
            role: message.role.clone(),
            content,
            media_refs,
            timestamp: message.timestamp,
            system_context: message.system_context.clone(),
            reasoning_content: message.reasoning_content.clone(),
            tool_call_id: message.tool_call_id.clone(),
            tool_name: message.tool_name.clone(),
            tool_state: message.tool_state.clone(),
            tool_calls: message.tool_calls.clone(),
        });
    }
    filtered
}
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
target_tokens
.saturating_sub(DATA_URL_OVERHEAD_TOKENS)
@ -699,7 +826,13 @@ impl AgentLoop {
Some(tool_defs)
};
let image_count = count_supported_image_media_refs(&messages);
// 过滤超出轮次和数量限制的图片
let filtered_messages = filter_images_by_age_and_count(
&messages,
self.runtime_config.max_image_age_rounds,
self.runtime_config.max_images_in_context,
);
let image_count = count_supported_image_media_refs(&filtered_messages);
// 构建系统提示词(统一注入 Agent 和 Skill 提示词)
let system_prompt = system_prompt_context.and_then(|ctx| {
@ -708,11 +841,11 @@ impl AgentLoop {
.and_then(|provider| provider.build(ctx))
});
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 2);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 2);
if let Some(ref prompt) = system_prompt {
text_only_messages.push(Message::system(prompt.content.clone()));
}
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens = image_token_budget_for_request(
&self.runtime_config,
@ -720,13 +853,13 @@ impl AgentLoop {
tools.as_ref(),
);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 2);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 2);
// 使用相同的系统提示词(已构建)
if let Some(ref prompt) = system_prompt {
messages_for_llm.push(Message::system(prompt.content.clone()));
}
messages_for_llm.extend(
messages
filtered_messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
@ -912,9 +1045,16 @@ impl AgentLoop {
);
messages.push(summary_request);
// 过滤超出轮次和数量限制的图片
let filtered_messages = filter_images_by_age_and_count(
&messages,
self.runtime_config.max_image_age_rounds,
self.runtime_config.max_images_in_context,
);
// Convert messages to LLM format (使用系统提示词提供者)
let image_count = count_supported_image_media_refs(&messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 1);
let image_count = count_supported_image_media_refs(&filtered_messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 1);
if let Some(ref provider) = self.system_prompt_provider {
if let Some(ctx) = system_prompt_context {
if let Some(prompt) = provider.build(ctx) {
@ -922,11 +1062,11 @@ impl AgentLoop {
}
}
}
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens =
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(filtered_messages.len() + 1);
if let Some(ref provider) = self.system_prompt_provider {
if let Some(ctx) = system_prompt_context {
if let Some(prompt) = provider.build(ctx) {
@ -935,7 +1075,7 @@ impl AgentLoop {
}
}
messages_for_llm.extend(
messages
filtered_messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
@ -1208,6 +1348,8 @@ mod tests {
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
}
}
@ -1426,6 +1568,137 @@ mod tests {
if text.contains("图片未能成功入模") && text.contains("demo.jpg")
));
}
/// A conversation without any images must come back unchanged.
#[test]
fn test_filter_images_by_age_and_count_no_images() {
    let messages = vec![
        ChatMessage::user("hello"),
        ChatMessage::assistant("hi"),
        ChatMessage::user("how are you"),
    ];

    let filtered = filter_images_by_age_and_count(&messages, 10, 5);

    // Same number of messages, same text, same order.
    let contents: Vec<&str> = filtered.iter().map(|m| m.content.as_str()).collect();
    assert_eq!(contents, ["hello", "hi", "how are you"]);
}
/// An image attached to a message outside the age window is removed and replaced by a notice.
#[test]
fn test_filter_images_by_age_limit() {
    let temp_dir = tempdir().unwrap();
    let jpg_path = temp_dir.path().join("test.jpg");
    image::DynamicImage::new_rgb8(8, 8).save(&jpg_path).unwrap();
    let image_ref = jpg_path.to_string_lossy().to_string();

    // 12 messages in total; only the oldest one (index 0) carries an image.
    let mut messages: Vec<ChatMessage> = vec![ChatMessage::user_with_media(
        "first message with image",
        vec![image_ref],
    )];
    messages.extend((1..12).map(|i| {
        if i % 2 == 0 {
            ChatMessage::user(format!("user message {}", i))
        } else {
            ChatMessage::assistant(format!("assistant message {}", i))
        }
    }));

    // With max_age_rounds = 10, index 0 is 11 messages behind the newest one (index 11),
    // so its image must be filtered out.
    let filtered = filter_images_by_age_and_count(&messages, 10, 100);

    let oldest = &filtered[0];
    assert_eq!(oldest.media_refs.len(), 0);
    assert!(oldest.content.contains("图片已过期"));
}
/// With a budget of one image, only the newest message keeps its attachment.
#[test]
fn test_filter_images_by_count_limit() {
    let temp_dir = tempdir().unwrap();

    // Three distinct image files, one per message.
    let mut jpg_paths: Vec<String> = Vec::with_capacity(3);
    for i in 0..3 {
        let path = temp_dir.path().join(format!("test{}.jpg", i));
        image::DynamicImage::new_rgb8(8, 8).save(&path).unwrap();
        jpg_paths.push(path.to_string_lossy().to_string());
    }

    let messages: Vec<ChatMessage> = jpg_paths
        .iter()
        .enumerate()
        .map(|(i, path)| ChatMessage::user_with_media(format!("message {}", i), vec![path.clone()]))
        .collect();

    // max_images = 1: only the newest image (index 2) survives.
    let filtered = filter_images_by_age_and_count(&messages, 100, 1);

    assert_eq!(filtered[2].media_refs.len(), 1); // newest is kept
    for idx in [1, 0] {
        let dropped = &filtered[idx];
        assert_eq!(dropped.media_refs.len(), 0); // filtered out
        assert!(dropped.content.contains("超出最大图片数量限制"));
    }
}
/// Age and count limits applied together: images outside the age window are dropped with the
/// age notice, and the images inside the window fit the count budget.
#[test]
fn test_filter_images_combined_limits() {
    let temp_dir = tempdir().unwrap();

    // Indices of the messages that carry an image; each one gets its own file.
    // (The previous `i / 5` lookup mapped both 15 and 19 to the same file and never used
    // the fifth image.)
    let image_positions = [0usize, 5, 10, 15, 19];
    let jpg_paths: Vec<String> = (0..image_positions.len())
        .map(|i| {
            let path = temp_dir.path().join(format!("test{}.jpg", i));
            let image = image::DynamicImage::new_rgb8(8, 8);
            image.save(&path).unwrap();
            path.to_string_lossy().to_string()
        })
        .collect();

    // 20 messages, alternating user/assistant except where an image is attached.
    let messages: Vec<ChatMessage> = (0..20usize)
        .map(|i| {
            if let Some(slot) = image_positions.iter().position(|&pos| pos == i) {
                ChatMessage::user_with_media(
                    format!("message {} with image", i),
                    vec![jpg_paths[slot].clone()],
                )
            } else if i % 2 == 0 {
                ChatMessage::user(format!("user message {}", i))
            } else {
                ChatMessage::assistant(format!("assistant message {}", i))
            }
        })
        .collect();

    // max_age_rounds = 10 (only images within the newest 10 messages survive)
    // max_images = 3 (at most 3 images)
    // Ages: index 19 -> 0, index 15 -> 4, index 10 -> 9, index 5 -> 14, index 0 -> 19.
    // Indices 5 and 0 are outside the age window; 19, 15 and 10 hold exactly 3 images,
    // which fits the budget.
    let filtered = filter_images_by_age_and_count(&messages, 10, 3);

    assert!(!filtered[19].media_refs.is_empty(), "最新消息应保留图片");
    assert!(!filtered[15].media_refs.is_empty(), "age=4 的消息应保留图片");
    assert!(!filtered[10].media_refs.is_empty(), "age=9 的消息应保留图片");
    assert_eq!(filtered[5].media_refs.len(), 0, "age=14 的消息图片应被过滤");
    assert!(filtered[5].content.contains("超出 10 条消息范围"));
    assert_eq!(filtered[0].media_refs.len(), 0, "age=19 的消息图片应被过滤");
    assert!(filtered[0].content.contains("超出 10 条消息范围"));
}
}
#[derive(Debug)]

View File

@ -9,6 +9,9 @@ pub struct AgentRuntimeConfig {
pub max_tool_iterations: usize,
pub tool_result_max_chars: usize,
pub context_tool_result_trim_chars: usize,
/// 图片上下文限制配置
pub max_images_in_context: usize,
pub max_image_age_rounds: usize,
}
impl From<LLMProviderConfig> for AgentRuntimeConfig {
@ -34,6 +37,8 @@ impl From<LLMProviderConfig> for AgentRuntimeConfig {
max_tool_iterations: config.max_tool_iterations,
tool_result_max_chars: config.tool_result_max_chars,
context_tool_result_trim_chars: config.context_tool_result_trim_chars,
max_images_in_context: config.max_images_in_context,
max_image_age_rounds: config.max_image_age_rounds,
}
}
}

View File

@ -77,6 +77,7 @@ impl InitWizard {
tools: crate::config::ToolsConfig::default(),
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
mcp_servers: HashMap::new(),
image_context: crate::config::ImageContextConfig::default(),
}
}
@ -828,6 +829,7 @@ impl InitWizard {
tools: existing.tools.clone(),
memory_maintenance: existing.memory_maintenance.clone(),
mcp_servers: existing.mcp_servers.clone(),
image_context: existing.image_context.clone(),
}
}

View File

@ -31,6 +31,38 @@ pub struct Config {
pub memory_maintenance: MemoryMaintenanceConfig,
#[serde(default, rename = "mcpServers")]
pub mcp_servers: HashMap<String, crate::mcp::McpServerConfig>,
#[serde(default)]
pub image_context: ImageContextConfig,
}
/// Limits on the images from the conversation history that are sent to the model.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageContextConfig {
/// Maximum number of images from the topic's context history sent to the model (default: 1).
#[serde(default = "default_max_images_in_context")]
pub max_images_in_context: usize,
/// Once an image is this many message rounds old it is no longer submitted to the model
/// (default: 10).
/// A "round" is the message's position in the history, i.e. how many messages lie between
/// it and the newest message.
/// Messages of every role count (user, assistant, tool, system, ...).
#[serde(default = "default_max_image_age_rounds")]
pub max_image_age_rounds: usize,
}
/// Serde fallback for `max_images_in_context` when the key is absent from the config.
fn default_max_images_in_context() -> usize {
    const DEFAULT_MAX_IMAGES_IN_CONTEXT: usize = 1;
    DEFAULT_MAX_IMAGES_IN_CONTEXT
}
/// Serde fallback for `max_image_age_rounds` when the key is absent from the config.
fn default_max_image_age_rounds() -> usize {
    const DEFAULT_MAX_IMAGE_AGE_ROUNDS: usize = 10;
    DEFAULT_MAX_IMAGE_AGE_ROUNDS
}
/// The defaults mirror the per-field serde fallbacks, so a missing `image_context` section
/// and an empty one produce the same values.
impl Default for ImageContextConfig {
    fn default() -> Self {
        let max_images_in_context = default_max_images_in_context();
        let max_image_age_rounds = default_max_image_age_rounds();
        Self {
            max_images_in_context,
            max_image_age_rounds,
        }
    }
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -758,6 +790,9 @@ pub struct LLMProviderConfig {
pub max_tool_iterations: usize,
pub tool_result_max_chars: usize,
pub context_tool_result_trim_chars: usize,
/// 图片上下文限制配置
pub max_images_in_context: usize,
pub max_image_age_rounds: usize,
}
impl LLMProviderConfig {
@ -857,6 +892,8 @@ impl Config {
max_tool_iterations: agent.max_tool_iterations,
tool_result_max_chars: agent.tool_result_max_chars,
context_tool_result_trim_chars: agent.context_tool_result_trim_chars,
max_images_in_context: self.image_context.max_images_in_context,
max_image_age_rounds: self.image_context.max_image_age_rounds,
})
}
}

View File

@ -123,6 +123,8 @@ mod tests {
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
}
}

View File

@ -67,6 +67,8 @@ mod tests {
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
}
}

View File

@ -693,6 +693,8 @@ mod tests {
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
}
}
@ -953,6 +955,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
@ -998,6 +1002,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let planner_provider = LLMProviderConfig {
model_id: "planner-model".to_string(),
@ -1073,6 +1079,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
@ -1158,6 +1166,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
@ -1245,6 +1255,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
@ -1331,6 +1343,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
@ -1399,6 +1413,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
@ -1476,6 +1492,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
@ -1540,6 +1558,8 @@ mod tests {
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(

View File

@ -44,6 +44,8 @@ fn load_config() -> Option<LLMProviderConfig> {
max_tool_iterations: 20,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
})
}

View File

@ -46,6 +46,8 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
max_tool_iterations: 20,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
})
}