feat(media): 添加媒体引用处理,增强用户内容的丰富性
This commit is contained in:
parent
7fefd40dca
commit
ab7a8ad924
@ -27,6 +27,13 @@ const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
||||
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
||||
|
||||
/// MIME essence strings that `supported_image_mime_type` accepts; media paths whose
/// guessed type is not listed here are skipped when building image content blocks.
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &[
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
];
|
||||
|
||||
/// Build content blocks from text and media paths
|
||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||
let mut blocks = Vec::new();
|
||||
@ -38,6 +45,11 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock>
|
||||
|
||||
// Add image blocks for media paths
|
||||
for path in media_paths {
|
||||
if supported_image_mime_type(path).is_none() {
|
||||
tracing::debug!(media_path = %path, "Skipping non-image media ref for LLM image block");
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
|
||||
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||
blocks.push(ContentBlock::image_url(url));
|
||||
@ -52,18 +64,32 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock>
|
||||
blocks
|
||||
}
|
||||
|
||||
fn supported_image_mime_type(path: &str) -> Option<String> {
|
||||
let mime = mime_guess::from_path(path).first_or_octet_stream();
|
||||
let essence = mime.essence_str();
|
||||
|
||||
if SUPPORTED_IMAGE_MIME_TYPES.contains(&essence) {
|
||||
Some(essence.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode an image file to base64 data URL
|
||||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
|
||||
let mime = supported_image_mime_type(path).ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("unsupported image media type for path: {}", path),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let mut buffer = Vec::new();
|
||||
file.read_to_end(&mut buffer)?;
|
||||
|
||||
let mime = mime_guess::from_path(path)
|
||||
.first_or_octet_stream()
|
||||
.to_string();
|
||||
|
||||
let encoded = STANDARD.encode(&buffer);
|
||||
Ok((mime, encoded))
|
||||
}
|
||||
@ -779,6 +805,7 @@ impl AgentLoop {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::observability::{MultiObserver, Observer};
|
||||
use tempfile::tempdir;
|
||||
|
||||
struct TestObserver {
|
||||
events: std::sync::Mutex<Vec<ObserverEvent>>,
|
||||
@ -881,6 +908,31 @@ mod tests {
|
||||
assert_eq!(output.as_deref(), Some("请完成授权"));
|
||||
assert!(parse_pending_tool_output("normal output").is_none());
|
||||
}
|
||||
|
||||
#[test]
fn test_build_content_blocks_skips_non_image_media_refs() {
    // A PDF media ref must not produce an image block; only the text block is expected.
    let workspace = tempdir().unwrap();
    let document = workspace.path().join("demo.pdf");
    std::fs::write(&document, b"%PDF-1.4").unwrap();

    let media = vec![document.to_string_lossy().into_owned()];
    let blocks = build_content_blocks("hello", &media);

    assert_eq!(blocks.len(), 1);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
}
|
||||
|
||||
#[test]
fn test_build_content_blocks_keeps_supported_images() {
    // A `.jpg` media ref should be kept and emitted as a JPEG data URL after the text.
    let workspace = tempdir().unwrap();
    let picture = workspace.path().join("demo.jpg");
    std::fs::write(&picture, b"fake-jpeg-data").unwrap();

    let media = vec![picture.to_string_lossy().into_owned()];
    let blocks = build_content_blocks("hello", &media);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(matches!(
        &blocks[1],
        ContentBlock::ImageUrl { image_url }
            if image_url.url.starts_with("data:image/jpeg;base64,")
    ));
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@ -262,6 +262,17 @@ fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
preview.replace('\n', "\\n")
|
||||
}
|
||||
|
||||
fn enrich_user_content_with_media_refs(content: &str, media_refs: &[String]) -> Result<String, AgentError> {
|
||||
if media_refs.is_empty() {
|
||||
return Ok(content.to_string());
|
||||
}
|
||||
|
||||
let media_refs_json = serde_json::to_string(media_refs)
|
||||
.map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?;
|
||||
|
||||
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
|
||||
}
|
||||
|
||||
fn combine_managed_memory_markdown(chunks: &[String]) -> String {
|
||||
let normalized_chunks = chunks
|
||||
.iter()
|
||||
@ -1237,7 +1248,8 @@ impl SessionManager {
|
||||
if !media_refs.is_empty() {
|
||||
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
|
||||
}
|
||||
let user_message = session_guard.create_user_message(content, media_refs);
|
||||
let enriched_content = enrich_user_content_with_media_refs(content, &media_refs)?;
|
||||
let user_message = session_guard.create_user_message(&enriched_content, media_refs);
|
||||
let user_message_id = user_message.id.clone();
|
||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||
|
||||
@ -1501,6 +1513,25 @@ mod tests {
|
||||
assert_eq!(selected.model_id, "planner-model");
|
||||
}
|
||||
|
||||
#[test]
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
    // Both image and non-image refs are expected in the appended JSON array, in order.
    let media_refs = ["/tmp/a.png", "/tmp/b.pdf"].map(String::from);
    let expected = "hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]";

    let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();

    assert_eq!(enriched, expected);
}
|
||||
|
||||
#[test]
fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() {
    // With no media refs the content must come back untouched.
    let no_media: &[String] = &[];

    let enriched = enrich_user_content_with_media_refs("hello", no_media).unwrap();

    assert_eq!(enriched, "hello");
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_latest_user_message_guard_tracks_current_turn() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user