diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index ab82f87..b7f0da1 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -187,6 +187,133 @@ fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize { .count() } +/// 过滤超出轮次限制的图片,并限制总图片数量 +/// +/// # 参数 +/// - `messages`: 原始消息列表 +/// - `max_age_rounds`: 图片超过多少消息轮次后不再发送(从最新消息向前数,包括所有 role) +/// - `max_images`: 最多发送多少张图片(优先保留最近的图片) +/// +/// # 返回 +/// 过滤后的消息列表,图片被转换为文本提示 "[图片已过期]" +fn filter_images_by_age_and_count( + messages: &[ChatMessage], + max_age_rounds: usize, + max_images: usize, +) -> Vec { + if messages.is_empty() || (max_age_rounds == 0 && max_images == 0) { + return messages.to_vec(); + } + + let total_images = count_supported_image_media_refs(messages); + if total_images == 0 { + return messages.to_vec(); + } + + // 从最新消息向前遍历,优先保留最新的图片 + // 消息列表顺序:[old, ..., new],所以末尾是最新的 + let msg_count = messages.len(); + + // 先从后向前遍历,计算每条消息应该保留多少张图片 + // 使用 Vec 存储每条消息应该保留的图片数量 + let mut images_to_keep_per_msg: Vec = vec![0; msg_count]; + let mut images_kept = 0usize; + + // 从后向前遍历(从最新到最旧) + for (idx, message) in messages.iter().enumerate().rev() { + // 计算距离最新消息的消息数(末尾消息的 age = 0) + let age_from_end = msg_count.saturating_sub(idx).saturating_sub(1); + + // 检查是否超出轮次限制 + let exceeds_age_limit = max_age_rounds > 0 && age_from_end >= max_age_rounds; + + if exceeds_age_limit { + continue; // 超出轮次限制的图片不保留 + } + + // 计算这条消息中的图片数量 + let image_count_in_msg = message.media_refs.iter() + .filter(|p| supported_image_mime_type(p).is_some()) + .count(); + + if image_count_in_msg == 0 { + continue; + } + + // 计算可以保留多少张 + let can_keep = std::cmp::min(image_count_in_msg, max_images.saturating_sub(images_kept)); + if can_keep > 0 { + images_to_keep_per_msg[idx] = can_keep; + images_kept += can_keep; + } + } + + // 然后从前向后遍历,构建过滤后的消息列表 + let mut filtered = Vec::with_capacity(msg_count); + + for (idx, message) in messages.iter().enumerate() { + // 计算距离最新消息的消息数(末尾消息的 age = 0) + let age_from_end = msg_count.saturating_sub(idx).saturating_sub(1); + + // 检查是否超出轮次限制 + let exceeds_age_limit = max_age_rounds > 0 && age_from_end >= max_age_rounds; + + let keep_count = images_to_keep_per_msg[idx]; + + // 过滤图片:保留非图片媒体和指定数量的图片 + let mut images_kept_in_msg = 0usize; + let filtered_media_refs: Vec = message.media_refs.iter() + .filter_map(|path| { + if supported_image_mime_type(path).is_some() { + if images_kept_in_msg < keep_count { + images_kept_in_msg += 1; + Some(path.clone()) + } else { + None + } + } else { + Some(path.clone()) // 保留非图片媒体 + } + }) + .collect(); + + // 如果图片被过滤,添加文本提示 + let original_image_count = message.media_refs.iter() + .filter(|p| supported_image_mime_type(p).is_some()) + .count(); + let filtered_image_count = filtered_media_refs.iter() + .filter(|p| supported_image_mime_type(p).is_some()) + .count(); + + let content = if original_image_count > filtered_image_count { + let notice = if exceeds_age_limit { + format!("{} [图片已过期:超出 {} 条消息范围]", message.content, max_age_rounds) + } else { + format!("{} [图片已过期:超出最大图片数量限制]", message.content) + }; + notice + } else { + message.content.clone() + }; + + filtered.push(ChatMessage { + id: message.id.clone(), + role: message.role.clone(), + content, + media_refs: filtered_media_refs, + timestamp: message.timestamp, + system_context: message.system_context.clone(), + reasoning_content: message.reasoning_content.clone(), + tool_call_id: message.tool_call_id.clone(), + tool_name: message.tool_name.clone(), + tool_state: message.tool_state.clone(), + tool_calls: message.tool_calls.clone(), + }); + } + + filtered +} + fn target_image_bytes_for_tokens(target_tokens: usize) -> usize { target_tokens .saturating_sub(DATA_URL_OVERHEAD_TOKENS) @@ -699,7 +826,13 @@ impl AgentLoop { Some(tool_defs) }; - let image_count = count_supported_image_media_refs(&messages); + // 过滤超出轮次和数量限制的图片 + let filtered_messages = filter_images_by_age_and_count( + &messages, + self.runtime_config.max_image_age_rounds, + self.runtime_config.max_images_in_context, + ); + let image_count = count_supported_image_media_refs(&filtered_messages); // 构建系统提示词(统一注入 Agent 和 Skill 提示词) let system_prompt = system_prompt_context.and_then(|ctx| { @@ -708,11 +841,11 @@ impl AgentLoop { .and_then(|provider| provider.build(ctx)) }); - let mut text_only_messages: Vec = Vec::with_capacity(messages.len() + 2); + let mut text_only_messages: Vec = Vec::with_capacity(filtered_messages.len() + 2); if let Some(ref prompt) = system_prompt { text_only_messages.push(Message::system(prompt.content.clone())); } - text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message)); + text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message)); let image_tokens = image_token_budget_for_request( &self.runtime_config, @@ -720,13 +853,13 @@ impl AgentLoop { tools.as_ref(), ); let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); - let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 2); + let mut messages_for_llm: Vec = Vec::with_capacity(filtered_messages.len() + 2); // 使用相同的系统提示词(已构建) if let Some(ref prompt) = system_prompt { messages_for_llm.push(Message::system(prompt.content.clone())); } messages_for_llm.extend( - messages + filtered_messages .iter() .map(|message| chat_message_to_llm_message(message, &mut image_budget)), ); @@ -912,9 +1045,16 @@ impl AgentLoop { ); messages.push(summary_request); + // 过滤超出轮次和数量限制的图片 + let filtered_messages = filter_images_by_age_and_count( + &messages, + self.runtime_config.max_image_age_rounds, + self.runtime_config.max_images_in_context, + ); + // Convert messages to LLM format (使用系统提示词提供者) - let image_count = count_supported_image_media_refs(&messages); - let mut text_only_messages: Vec = Vec::with_capacity(messages.len() + 1); + let image_count = count_supported_image_media_refs(&filtered_messages); + let mut text_only_messages: Vec = Vec::with_capacity(filtered_messages.len() + 1); if let Some(ref provider) = self.system_prompt_provider { if let Some(ctx) = system_prompt_context { if let Some(prompt) = provider.build(ctx) { @@ -922,11 +1062,11 @@ impl AgentLoop { } } } - text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message)); + text_only_messages.extend(filtered_messages.iter().map(chat_message_to_text_only_llm_message)); let image_tokens = image_token_budget_for_request(&self.runtime_config, &text_only_messages, None); let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); - let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 1); + let mut messages_for_llm: Vec = Vec::with_capacity(filtered_messages.len() + 1); if let Some(ref provider) = self.system_prompt_provider { if let Some(ctx) = system_prompt_context { if let Some(prompt) = provider.build(ctx) { @@ -935,7 +1075,7 @@ impl AgentLoop { } } messages_for_llm.extend( - messages + filtered_messages .iter() .map(|message| chat_message_to_llm_message(message, &mut image_budget)), ); @@ -1208,6 +1348,8 @@ mod tests { max_tool_iterations: 1, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, } } @@ -1426,6 +1568,137 @@ mod tests { if text.contains("图片未能成功入模") && text.contains("demo.jpg") )); } + + #[test] + fn test_filter_images_by_age_and_count_no_images() { + let messages = vec![ + ChatMessage::user("hello"), + ChatMessage::assistant("hi"), + ChatMessage::user("how are you"), + ]; + + let filtered = filter_images_by_age_and_count(&messages, 10, 5); + assert_eq!(filtered.len(), 3); + assert_eq!(filtered[0].content, "hello"); + assert_eq!(filtered[1].content, "hi"); + assert_eq!(filtered[2].content, "how are you"); + } + + #[test] + fn test_filter_images_by_age_limit() { + let temp_dir = tempdir().unwrap(); + let jpg_path = temp_dir.path().join("test.jpg"); + let image = image::DynamicImage::new_rgb8(8, 8); + image.save(&jpg_path).unwrap(); + + // 创建 12 条消息,第 0 条(最旧)有图片 + let messages: Vec = (0..12) + .map(|i| { + if i == 0 { + ChatMessage::user_with_media( + "first message with image", + vec![jpg_path.to_string_lossy().to_string()], + ) + } else if i % 2 == 0 { + ChatMessage::user(format!("user message {}", i)) + } else { + ChatMessage::assistant(format!("assistant message {}", i)) + } + }) + .collect(); + + // max_age_rounds = 10,第一条消息(索引 0)距离最新消息(索引 11)有 11 条消息 + // 所以第一条消息的图片应该被过滤 + let filtered = filter_images_by_age_and_count(&messages, 10, 100); + + // 检查第一条消息的图片被过滤 + assert_eq!(filtered[0].media_refs.len(), 0); + assert!(filtered[0].content.contains("图片已过期")); + } + + #[test] + fn test_filter_images_by_count_limit() { + let temp_dir = tempdir().unwrap(); + + // 创建 3 张不同的图片 + let jpg_paths: Vec = (0..3) + .map(|i| { + let path = temp_dir.path().join(format!("test{}.jpg", i)); + let image = image::DynamicImage::new_rgb8(8, 8); + image.save(&path).unwrap(); + path.to_string_lossy().to_string() + }) + .collect(); + + // 创建 3 条消息,每条都有图片 + let messages: Vec = (0..3) + .map(|i| { + ChatMessage::user_with_media( + format!("message {}", i), + vec![jpg_paths[i].clone()], + ) + }) + .collect(); + + // max_images = 1,只保留最新的图片(索引 2) + let filtered = filter_images_by_age_and_count(&messages, 100, 1); + + // 检查最新的消息保留图片,旧消息的图片被过滤 + assert_eq!(filtered[2].media_refs.len(), 1); // 最新保留 + assert_eq!(filtered[1].media_refs.len(), 0); // 被过滤 + assert!(filtered[1].content.contains("超出最大图片数量限制")); + assert_eq!(filtered[0].media_refs.len(), 0); // 被过滤 + assert!(filtered[0].content.contains("超出最大图片数量限制")); + } + + #[test] + fn test_filter_images_combined_limits() { + let temp_dir = tempdir().unwrap(); + + // 创建 5 张图片 + let jpg_paths: Vec = (0..5) + .map(|i| { + let path = temp_dir.path().join(format!("test{}.jpg", i)); + let image = image::DynamicImage::new_rgb8(8, 8); + image.save(&path).unwrap(); + path.to_string_lossy().to_string() + }) + .collect(); + + // 创建 20 条消息,在特定位置添加图片 + let messages: Vec = (0..20) + .map(|i| { + // 在索引 0, 5, 10, 15, 19 添加图片 + if i == 0 || i == 5 || i == 10 || i == 15 || i == 19 { + let image_idx = i / 5; + ChatMessage::user_with_media( + format!("message {} with image", i), + vec![jpg_paths[image_idx.clamp(0, 4)].clone()], + ) + } else if i % 2 == 0 { + ChatMessage::user(format!("user message {}", i)) + } else { + ChatMessage::assistant(format!("assistant message {}", i)) + } + }) + .collect(); + + // max_age_rounds = 10(保留最近 10 条消息内的图片) + // max_images = 3(最多 3 张图片) + // 最新消息索引 19 (age=0),索引 15 (age=4),索引 10 (age=9),索引 5 (age=14),索引 0 (age=19) + // 索引 5 和 0 超出 age 限制 + // 索引 19, 15, 10 的图片应该保留(共 3 张,不超过 max_images) + let filtered = filter_images_by_age_and_count(&messages, 10, 3); + + // 检查结果 + assert!(filtered[19].media_refs.len() > 0, "最新消息应保留图片"); + assert!(filtered[15].media_refs.len() > 0, "age=4 的消息应保留图片"); + assert!(filtered[10].media_refs.len() > 0, "age=9 的消息应保留图片"); + assert_eq!(filtered[5].media_refs.len(), 0, "age=14 的消息图片应被过滤"); + assert!(filtered[5].content.contains("超出 10 条消息范围")); + assert_eq!(filtered[0].media_refs.len(), 0, "age=19 的消息图片应被过滤"); + assert!(filtered[0].content.contains("超出 10 条消息范围")); + } } #[derive(Debug)] diff --git a/src/agent/runtime_config.rs b/src/agent/runtime_config.rs index e38d1c2..836c72f 100644 --- a/src/agent/runtime_config.rs +++ b/src/agent/runtime_config.rs @@ -9,6 +9,9 @@ pub struct AgentRuntimeConfig { pub max_tool_iterations: usize, pub tool_result_max_chars: usize, pub context_tool_result_trim_chars: usize, + /// 图片上下文限制配置 + pub max_images_in_context: usize, + pub max_image_age_rounds: usize, } impl From for AgentRuntimeConfig { @@ -34,6 +37,8 @@ impl From for AgentRuntimeConfig { max_tool_iterations: config.max_tool_iterations, tool_result_max_chars: config.tool_result_max_chars, context_tool_result_trim_chars: config.context_tool_result_trim_chars, + max_images_in_context: config.max_images_in_context, + max_image_age_rounds: config.max_image_age_rounds, } } } diff --git a/src/cli/init.rs b/src/cli/init.rs index d3b190d..7f5b61d 100644 --- a/src/cli/init.rs +++ b/src/cli/init.rs @@ -77,6 +77,7 @@ impl InitWizard { tools: crate::config::ToolsConfig::default(), memory_maintenance: crate::config::MemoryMaintenanceConfig::default(), mcp_servers: HashMap::new(), + image_context: crate::config::ImageContextConfig::default(), } } @@ -828,6 +829,7 @@ impl InitWizard { tools: existing.tools.clone(), memory_maintenance: existing.memory_maintenance.clone(), mcp_servers: existing.mcp_servers.clone(), + image_context: existing.image_context.clone(), } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 77e49b1..fbd2605 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -31,6 +31,38 @@ pub struct Config { pub memory_maintenance: MemoryMaintenanceConfig, #[serde(default, rename = "mcpServers")] pub mcp_servers: HashMap, + #[serde(default)] + pub image_context: ImageContextConfig, +} + +/// 图片上下文限制配置 +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ImageContextConfig { + /// topic 上下文历史中最多发送给模型的图片数量 (默认 1) + #[serde(default = "default_max_images_in_context")] + pub max_images_in_context: usize, + /// 图片超过多少消息轮次后就不再提交给模型 (默认 10) + /// "轮次"定义为:消息在历史中的位置(距离最新消息的消息数) + /// 包括所有 role 类型(user、assistant、tool、system 等) + #[serde(default = "default_max_image_age_rounds")] + pub max_image_age_rounds: usize, +} + +fn default_max_images_in_context() -> usize { + 1 +} + +fn default_max_image_age_rounds() -> usize { + 10 +} + +impl Default for ImageContextConfig { + fn default() -> Self { + Self { + max_images_in_context: default_max_images_in_context(), + max_image_age_rounds: default_max_image_age_rounds(), + } + } } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -758,6 +790,9 @@ pub struct LLMProviderConfig { pub max_tool_iterations: usize, pub tool_result_max_chars: usize, pub context_tool_result_trim_chars: usize, + /// 图片上下文限制配置 + pub max_images_in_context: usize, + pub max_image_age_rounds: usize, } impl LLMProviderConfig { @@ -857,6 +892,8 @@ impl Config { max_tool_iterations: agent.max_tool_iterations, tool_result_max_chars: agent.tool_result_max_chars, context_tool_result_trim_chars: agent.context_tool_result_trim_chars, + max_images_in_context: self.image_context.max_images_in_context, + max_image_age_rounds: self.image_context.max_image_age_rounds, }) } } diff --git a/src/gateway/agent_prompt_provider.rs b/src/gateway/agent_prompt_provider.rs index d8726f0..5364db5 100644 --- a/src/gateway/agent_prompt_provider.rs +++ b/src/gateway/agent_prompt_provider.rs @@ -123,6 +123,8 @@ mod tests { max_tool_iterations: 1, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, } } diff --git a/src/gateway/provider_config_service.rs b/src/gateway/provider_config_service.rs index 752de7f..0bee5a7 100644 --- a/src/gateway/provider_config_service.rs +++ b/src/gateway/provider_config_service.rs @@ -67,6 +67,8 @@ mod tests { max_tool_iterations: 1, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, } } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index b82f1ac..5a4b572 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -693,6 +693,8 @@ mod tests { max_tool_iterations: 1, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, } } @@ -953,6 +955,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let session_manager = SessionManager::new( @@ -998,6 +1002,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let planner_provider = LLMProviderConfig { model_id: "planner-model".to_string(), @@ -1073,6 +1079,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let session_manager = SessionManager::new( @@ -1158,6 +1166,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let session_manager = SessionManager::new( @@ -1245,6 +1255,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let session_manager = SessionManager::new( @@ -1331,6 +1343,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let session_manager = SessionManager::new( @@ -1399,6 +1413,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let session_manager = SessionManager::new( @@ -1476,6 +1492,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let session_manager = SessionManager::new( @@ -1540,6 +1558,8 @@ mod tests { memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }; let session_manager = SessionManager::new( diff --git a/tests/test_integration.rs b/tests/test_integration.rs index b428dba..84c6bd2 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -44,6 +44,8 @@ fn load_config() -> Option { max_tool_iterations: 20, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }) } diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index 653d1c6..071edfa 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -46,6 +46,8 @@ fn load_openai_config() -> Option { max_tool_iterations: 20, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, + max_images_in_context: 1, + max_image_age_rounds: 10, }) }