diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 9023548..fd8d793 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -27,6 +27,13 @@ const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。"; +const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &[ + "image/jpeg", + "image/png", + "image/gif", + "image/webp", +]; + /// Build content blocks from text and media paths fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { let mut blocks = Vec::new(); @@ -38,6 +45,11 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec // Add image blocks for media paths for path in media_paths { + if supported_image_mime_type(path).is_none() { + tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block"); + continue; + } + if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) { let url = format!("data:{};base64,{}", mime_type, base64_data); blocks.push(ContentBlock::image_url(url)); @@ -52,18 +64,32 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec blocks } +fn supported_image_mime_type(path: &str) -> Option { + let mime = mime_guess::from_path(path).first_or_octet_stream(); + let essence = mime.essence_str(); + + if SUPPORTED_IMAGE_MIME_TYPES.contains(&essence) { + Some(essence.to_string()) + } else { + None + } +} + /// Encode an image file to base64 data URL fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { use base64::{Engine as _, engine::general_purpose::STANDARD}; + let mime = supported_image_mime_type(path).ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("unsupported image media type for path: {}", path), + ) + })?; + let mut file = std::fs::File::open(path)?; let mut buffer = Vec::new(); file.read_to_end(&mut buffer)?; - let mime = mime_guess::from_path(path) - .first_or_octet_stream() - .to_string(); - let encoded = STANDARD.encode(&buffer); Ok((mime, encoded)) } @@ -779,6 +805,7 @@ impl AgentLoop { mod tests { use super::*; use crate::observability::{MultiObserver, Observer}; + use tempfile::tempdir; struct TestObserver { events: std::sync::Mutex>, @@ -881,6 +908,31 @@ mod tests { assert_eq!(output.as_deref(), Some("请完成授权")); assert!(parse_pending_tool_output("normal output").is_none()); } + + #[test] + fn test_build_content_blocks_skips_non_image_media_refs() { + let temp_dir = tempdir().unwrap(); + let pdf_path = temp_dir.path().join("demo.pdf"); + std::fs::write(&pdf_path, b"%PDF-1.4").unwrap(); + + let blocks = build_content_blocks("hello", &[pdf_path.to_string_lossy().to_string()]); + + assert_eq!(blocks.len(), 1); + assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); + } + + #[test] + fn test_build_content_blocks_keeps_supported_images() { + let temp_dir = tempdir().unwrap(); + let jpg_path = temp_dir.path().join("demo.jpg"); + std::fs::write(&jpg_path, b"fake-jpeg-data").unwrap(); + + let blocks = build_content_blocks("hello", &[jpg_path.to_string_lossy().to_string()]); + + assert_eq!(blocks.len(), 2); + assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); + assert!(matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))); + } } #[derive(Debug)] diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 237da67..d3bfb0f 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -262,6 +262,17 @@ fn preview_text(content: &str, max_chars: usize) -> String { preview.replace('\n', "\\n") } +fn enrich_user_content_with_media_refs(content: &str, media_refs: &[String]) -> Result { + if media_refs.is_empty() { + return Ok(content.to_string()); + } + + let media_refs_json = serde_json::to_string(media_refs) + .map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?; + + Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}")) +} + fn combine_managed_memory_markdown(chunks: &[String]) -> String { let normalized_chunks = chunks .iter() @@ -1237,7 +1248,8 @@ impl SessionManager { if !media_refs.is_empty() { tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media"); } - let user_message = session_guard.create_user_message(content, media_refs); + let enriched_content = enrich_user_content_with_media_refs(content, &media_refs)?; + let user_message = session_guard.create_user_message(&enriched_content, media_refs); let user_message_id = user_message.id.clone(); session_guard.append_persisted_message(chat_id, user_message)?; @@ -1501,6 +1513,25 @@ mod tests { assert_eq!(selected.model_id, "planner-model"); } + #[test] + fn test_enrich_user_content_with_media_refs_appends_tagged_json() { + let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()]; + + let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap(); + + assert_eq!( + enriched, + "hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]" + ); + } + + #[test] + fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() { + let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap(); + + assert_eq!(enriched, "hello"); + } + #[tokio::test] async fn test_latest_user_message_guard_tracks_current_turn() { let store = Arc::new(SessionStore::in_memory().unwrap());