feat: 添加图像处理预算和估算逻辑,优化消息内容构建,支持图像媒体引用
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
fc5b2a359f
commit
531e72d24f
@ -30,6 +30,7 @@ cron = { version = "0.13", features = ["serde"] }
|
|||||||
iana-time-zone = "0.1"
|
iana-time-zone = "0.1"
|
||||||
mime_guess = "2.0"
|
mime_guess = "2.0"
|
||||||
base64 = "0.22"
|
base64 = "0.22"
|
||||||
|
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
|
||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
meval = "0.2"
|
meval = "0.2"
|
||||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||||
|
|||||||
@ -24,10 +24,31 @@ const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
|
|||||||
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
||||||
|
|
||||||
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
|
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
|
||||||
|
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
|
||||||
|
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
|
||||||
|
const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9;
|
||||||
|
const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048;
|
||||||
|
const DATA_URL_OVERHEAD_TOKENS: usize = 16;
|
||||||
|
const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36];
|
||||||
|
const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64;
|
||||||
|
const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:";
|
||||||
|
|
||||||
/// Build content blocks from text and media paths
|
/// Build content blocks from text and media paths
|
||||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
fn build_content_blocks(
|
||||||
|
text: &str,
|
||||||
|
media_paths: &[String],
|
||||||
|
budget: &mut ImageInlineBudget,
|
||||||
|
) -> Vec<ContentBlock> {
|
||||||
|
build_content_blocks_with_image_budget(text, media_paths, budget)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_content_blocks_with_image_budget(
|
||||||
|
text: &str,
|
||||||
|
media_paths: &[String],
|
||||||
|
budget: &mut ImageInlineBudget,
|
||||||
|
) -> Vec<ContentBlock> {
|
||||||
let mut blocks = Vec::new();
|
let mut blocks = Vec::new();
|
||||||
|
let mut skipped_image_notices = Vec::new();
|
||||||
|
|
||||||
// Add text block if there's text
|
// Add text block if there's text
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
@ -41,10 +62,36 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock>
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
|
let Some(target_tokens) = budget.take_next_image_tokens() else {
|
||||||
|
tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains");
|
||||||
|
skipped_image_notices.push(format!(
|
||||||
|
"- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
|
||||||
|
display_media_name(path)
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
match encode_image_to_base64_with_budget(path, target_tokens) {
|
||||||
|
Ok((mime_type, base64_data)) => {
|
||||||
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||||
blocks.push(ContentBlock::image_url(url));
|
blocks.push(ContentBlock::image_url(url));
|
||||||
}
|
}
|
||||||
|
Err(err) => {
|
||||||
|
tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed");
|
||||||
|
skipped_image_notices.push(format!(
|
||||||
|
"- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
|
||||||
|
display_media_name(path)
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !skipped_image_notices.is_empty() {
|
||||||
|
let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX);
|
||||||
|
notice.push('\n');
|
||||||
|
notice.push_str(&skipped_image_notices.join("\n"));
|
||||||
|
blocks.push(ContentBlock::text(notice));
|
||||||
}
|
}
|
||||||
|
|
||||||
// If nothing, add empty text block
|
// If nothing, add empty text block
|
||||||
@ -66,6 +113,88 @@ fn supported_image_mime_type(path: &str) -> Option<String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn display_media_name(path: &str) -> String {
|
||||||
|
std::path::Path::new(path)
|
||||||
|
.file_name()
|
||||||
|
.and_then(|name| name.to_str())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.unwrap_or_else(|| path.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct ImageInlineBudget {
|
||||||
|
remaining_tokens: usize,
|
||||||
|
remaining_images: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ImageInlineBudget {
|
||||||
|
fn new(total_tokens: usize, image_count: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
remaining_tokens: total_tokens,
|
||||||
|
remaining_images: image_count,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_next_image_tokens(&mut self) -> Option<usize> {
|
||||||
|
if self.remaining_tokens == 0 || self.remaining_images == 0 {
|
||||||
|
self.remaining_images = self.remaining_images.saturating_sub(1);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let target = self.remaining_tokens.div_ceil(self.remaining_images);
|
||||||
|
self.remaining_tokens = self.remaining_tokens.saturating_sub(target);
|
||||||
|
self.remaining_images = self.remaining_images.saturating_sub(1);
|
||||||
|
Some(target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_tokens_from_serialized_json<T: serde::Serialize>(value: &T) -> usize {
|
||||||
|
let raw_len = serde_json::to_string(value)
|
||||||
|
.map(|serialized| serialized.len())
|
||||||
|
.unwrap_or_default();
|
||||||
|
((raw_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64) * TOKEN_ESTIMATE_SAFETY_MULTIPLIER)
|
||||||
|
as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn image_token_budget_for_request(
|
||||||
|
runtime_config: &AgentRuntimeConfig,
|
||||||
|
text_only_messages: &[Message],
|
||||||
|
tools: Option<&Vec<crate::domain::tools::Tool>>,
|
||||||
|
) -> usize {
|
||||||
|
let completion_reserve = runtime_config
|
||||||
|
.provider
|
||||||
|
.max_tokens
|
||||||
|
.map(|tokens| tokens as usize)
|
||||||
|
.unwrap_or(DEFAULT_COMPLETION_TOKEN_RESERVE);
|
||||||
|
let input_window = runtime_config
|
||||||
|
.context_window_tokens
|
||||||
|
.saturating_sub(completion_reserve);
|
||||||
|
let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize;
|
||||||
|
|
||||||
|
let text_tokens = estimate_tokens_from_serialized_json(&text_only_messages)
|
||||||
|
+ tools
|
||||||
|
.map(estimate_tokens_from_serialized_json)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
safe_input_window.saturating_sub(text_tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.flat_map(|message| message.media_refs.iter())
|
||||||
|
.filter(|path| supported_image_mime_type(path).is_some())
|
||||||
|
.count()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
|
||||||
|
target_tokens
|
||||||
|
.saturating_sub(DATA_URL_OVERHEAD_TOKENS)
|
||||||
|
.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
|
||||||
|
.saturating_mul(3)
|
||||||
|
/ 4
|
||||||
|
}
|
||||||
|
|
||||||
/// Encode an image file to base64 data URL
|
/// Encode an image file to base64 data URL
|
||||||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||||
@ -85,6 +214,80 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error
|
|||||||
Ok((mime, encoded))
|
Ok((mime, encoded))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn encode_image_to_base64_with_budget(
|
||||||
|
path: &str,
|
||||||
|
target_tokens: usize,
|
||||||
|
) -> Result<(String, String), std::io::Error> {
|
||||||
|
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||||
|
|
||||||
|
let (mime, encoded) = encode_image_to_base64(path)?;
|
||||||
|
let target_base64_chars = target_tokens
|
||||||
|
.saturating_sub(DATA_URL_OVERHEAD_TOKENS)
|
||||||
|
.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
|
||||||
|
if encoded.len() <= target_base64_chars {
|
||||||
|
return Ok((mime, encoded));
|
||||||
|
}
|
||||||
|
|
||||||
|
let target_bytes = target_image_bytes_for_tokens(target_tokens);
|
||||||
|
let image_bytes = std::fs::read(path)?;
|
||||||
|
let image = image::load_from_memory(&image_bytes)
|
||||||
|
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
|
||||||
|
let compressed = compress_image_to_jpeg_budget(&image, target_bytes)?;
|
||||||
|
|
||||||
|
Ok(("image/jpeg".to_string(), STANDARD.encode(compressed)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compress_image_to_jpeg_budget(
|
||||||
|
image: &image::DynamicImage,
|
||||||
|
target_bytes: usize,
|
||||||
|
) -> Result<Vec<u8>, std::io::Error> {
|
||||||
|
use image::codecs::jpeg::JpegEncoder;
|
||||||
|
use image::imageops::FilterType;
|
||||||
|
|
||||||
|
let mut best: Option<Vec<u8>> = None;
|
||||||
|
let mut max_side = image
|
||||||
|
.width()
|
||||||
|
.max(image.height())
|
||||||
|
.max(MIN_COMPRESSED_IMAGE_SIDE);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let candidate = if image.width().max(image.height()) > max_side {
|
||||||
|
image.resize(max_side, max_side, FilterType::Triangle)
|
||||||
|
} else {
|
||||||
|
image.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
for quality in JPEG_QUALITY_STEPS {
|
||||||
|
let mut encoded = Vec::new();
|
||||||
|
let mut encoder = JpegEncoder::new_with_quality(&mut encoded, *quality);
|
||||||
|
encoder
|
||||||
|
.encode_image(&candidate)
|
||||||
|
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
|
||||||
|
|
||||||
|
if encoded.len() <= target_bytes {
|
||||||
|
return Ok(encoded);
|
||||||
|
}
|
||||||
|
|
||||||
|
if best
|
||||||
|
.as_ref()
|
||||||
|
.map(|current| encoded.len() < current.len())
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
best = Some(encoded);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if max_side <= MIN_COMPRESSED_IMAGE_SIDE {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
max_side = (max_side * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE);
|
||||||
|
}
|
||||||
|
|
||||||
|
best.ok_or_else(|| {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Truncate tool result if it exceeds the configured limit.
|
/// Truncate tool result if it exceeds the configured limit.
|
||||||
/// Preserves the end of the output as it often contains the conclusion/useful result.
|
/// Preserves the end of the output as it often contains the conclusion/useful result.
|
||||||
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
|
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
|
||||||
@ -272,11 +475,11 @@ fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Convert ChatMessage to LLM Message format
|
/// Convert ChatMessage to LLM Message format
|
||||||
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
|
||||||
let content = if m.media_refs.is_empty() {
|
let content = if m.media_refs.is_empty() {
|
||||||
vec![ContentBlock::text(&m.content)]
|
vec![ContentBlock::text(&m.content)]
|
||||||
} else {
|
} else {
|
||||||
build_content_blocks(&m.content, &m.media_refs)
|
build_content_blocks(&m.content, &m.media_refs, image_budget)
|
||||||
};
|
};
|
||||||
|
|
||||||
Message {
|
Message {
|
||||||
@ -289,6 +492,17 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
|
||||||
|
Message {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: vec![ContentBlock::text(&m.content)],
|
||||||
|
reasoning_content: m.reasoning_content.clone(),
|
||||||
|
tool_call_id: m.tool_call_id.clone(),
|
||||||
|
name: m.tool_name.clone(),
|
||||||
|
tool_calls: m.tool_calls.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
||||||
/// History is managed externally by SessionManager.
|
/// History is managed externally by SessionManager.
|
||||||
pub struct AgentLoop {
|
pub struct AgentLoop {
|
||||||
@ -434,14 +648,6 @@ impl AgentLoop {
|
|||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(iteration, "Agent iteration started");
|
tracing::debug!(iteration, "Agent iteration started");
|
||||||
|
|
||||||
// Convert messages to LLM format
|
|
||||||
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
|
||||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
|
||||||
messages_for_llm.push(Message::system(skill_prompt));
|
|
||||||
}
|
|
||||||
messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
|
|
||||||
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
|
||||||
|
|
||||||
// Build request
|
// Build request
|
||||||
let tool_defs = self.tools.get_definitions();
|
let tool_defs = self.tools.get_definitions();
|
||||||
let tools = if tool_defs.is_empty() {
|
let tools = if tool_defs.is_empty() {
|
||||||
@ -450,6 +656,31 @@ impl AgentLoop {
|
|||||||
Some(tool_defs)
|
Some(tool_defs)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let image_count = count_supported_image_media_refs(&messages);
|
||||||
|
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 2);
|
||||||
|
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||||
|
text_only_messages.push(Message::system(skill_prompt.clone()));
|
||||||
|
}
|
||||||
|
text_only_messages.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
|
||||||
|
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
|
||||||
|
|
||||||
|
let image_tokens = image_token_budget_for_request(
|
||||||
|
&self.runtime_config,
|
||||||
|
&text_only_messages,
|
||||||
|
tools.as_ref(),
|
||||||
|
);
|
||||||
|
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
|
||||||
|
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 2);
|
||||||
|
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||||
|
messages_for_llm.push(Message::system(skill_prompt));
|
||||||
|
}
|
||||||
|
messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
|
||||||
|
messages_for_llm.extend(
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
|
||||||
|
);
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: messages_for_llm,
|
messages: messages_for_llm,
|
||||||
temperature: None,
|
temperature: None,
|
||||||
@ -457,7 +688,6 @@ impl AgentLoop {
|
|||||||
tools,
|
tools,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Call LLM
|
|
||||||
let response = match (*self.provider).chat(request).await {
|
let response = match (*self.provider).chat(request).await {
|
||||||
Ok(response) => response,
|
Ok(response) => response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@ -633,11 +863,24 @@ impl AgentLoop {
|
|||||||
messages.push(summary_request);
|
messages.push(summary_request);
|
||||||
|
|
||||||
// Convert messages to LLM format
|
// Convert messages to LLM format
|
||||||
|
let image_count = count_supported_image_media_refs(&messages);
|
||||||
|
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
||||||
|
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||||
|
text_only_messages.push(Message::system(skill_prompt));
|
||||||
|
}
|
||||||
|
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
|
||||||
|
let image_tokens =
|
||||||
|
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
|
||||||
|
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
|
||||||
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
||||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||||
messages_for_llm.push(Message::system(skill_prompt));
|
messages_for_llm.push(Message::system(skill_prompt));
|
||||||
}
|
}
|
||||||
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
messages_for_llm.extend(
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
|
||||||
|
);
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: messages_for_llm,
|
messages: messages_for_llm,
|
||||||
@ -896,7 +1139,8 @@ mod tests {
|
|||||||
}],
|
}],
|
||||||
);
|
);
|
||||||
|
|
||||||
let provider_message = chat_message_to_llm_message(&chat_message);
|
let mut image_budget = ImageInlineBudget::new(0, 0);
|
||||||
|
let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);
|
||||||
|
|
||||||
assert_eq!(provider_message.role, "assistant");
|
assert_eq!(provider_message.role, "assistant");
|
||||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||||
@ -915,7 +1159,8 @@ mod tests {
|
|||||||
let chat_message =
|
let chat_message =
|
||||||
ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");
|
ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");
|
||||||
|
|
||||||
let provider_message = chat_message_to_llm_message(&chat_message);
|
let mut image_budget = ImageInlineBudget::new(0, 0);
|
||||||
|
let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);
|
||||||
|
|
||||||
assert_eq!(provider_message.role, "assistant");
|
assert_eq!(provider_message.role, "assistant");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -983,7 +1228,12 @@ mod tests {
|
|||||||
let pdf_path = temp_dir.path().join("demo.pdf");
|
let pdf_path = temp_dir.path().join("demo.pdf");
|
||||||
std::fs::write(&pdf_path, b"%PDF-1.4").unwrap();
|
std::fs::write(&pdf_path, b"%PDF-1.4").unwrap();
|
||||||
|
|
||||||
let blocks = build_content_blocks("hello", &[pdf_path.to_string_lossy().to_string()]);
|
let mut budget = ImageInlineBudget::new(1_000, 0);
|
||||||
|
let blocks = build_content_blocks(
|
||||||
|
"hello",
|
||||||
|
&[pdf_path.to_string_lossy().to_string()],
|
||||||
|
&mut budget,
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(blocks.len(), 1);
|
assert_eq!(blocks.len(), 1);
|
||||||
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||||
@ -993,9 +1243,15 @@ mod tests {
|
|||||||
fn test_build_content_blocks_keeps_supported_images() {
|
fn test_build_content_blocks_keeps_supported_images() {
|
||||||
let temp_dir = tempdir().unwrap();
|
let temp_dir = tempdir().unwrap();
|
||||||
let jpg_path = temp_dir.path().join("demo.jpg");
|
let jpg_path = temp_dir.path().join("demo.jpg");
|
||||||
std::fs::write(&jpg_path, b"fake-jpeg-data").unwrap();
|
let image = image::DynamicImage::new_rgb8(8, 8);
|
||||||
|
image.save(&jpg_path).unwrap();
|
||||||
|
|
||||||
let blocks = build_content_blocks("hello", &[jpg_path.to_string_lossy().to_string()]);
|
let mut budget = ImageInlineBudget::new(10_000, 1);
|
||||||
|
let blocks = build_content_blocks(
|
||||||
|
"hello",
|
||||||
|
&[jpg_path.to_string_lossy().to_string()],
|
||||||
|
&mut budget,
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(blocks.len(), 2);
|
assert_eq!(blocks.len(), 2);
|
||||||
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||||
@ -1003,6 +1259,50 @@ mod tests {
|
|||||||
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
|
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_content_blocks_compresses_images_to_budget() {
|
||||||
|
let temp_dir = tempdir().unwrap();
|
||||||
|
let png_path = temp_dir.path().join("large.png");
|
||||||
|
let image = image::DynamicImage::new_rgb8(512, 512);
|
||||||
|
image.save(&png_path).unwrap();
|
||||||
|
|
||||||
|
let mut budget = ImageInlineBudget::new(512, 1);
|
||||||
|
let blocks = build_content_blocks(
|
||||||
|
"hello",
|
||||||
|
&[png_path.to_string_lossy().to_string()],
|
||||||
|
&mut budget,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(blocks.len(), 2);
|
||||||
|
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||||
|
assert!(
|
||||||
|
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() {
|
||||||
|
let temp_dir = tempdir().unwrap();
|
||||||
|
let jpg_path = temp_dir.path().join("demo.jpg");
|
||||||
|
let image = image::DynamicImage::new_rgb8(8, 8);
|
||||||
|
image.save(&jpg_path).unwrap();
|
||||||
|
|
||||||
|
let mut budget = ImageInlineBudget::new(0, 1);
|
||||||
|
let blocks = build_content_blocks(
|
||||||
|
"hello",
|
||||||
|
&[jpg_path.to_string_lossy().to_string()],
|
||||||
|
&mut budget,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(blocks.len(), 2);
|
||||||
|
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||||
|
assert!(matches!(
|
||||||
|
&blocks[1],
|
||||||
|
ContentBlock::Text { text }
|
||||||
|
if text.contains("图片未能成功入模") && text.contains("demo.jpg")
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|||||||
@ -8,13 +8,34 @@ use crate::text::{char_count, take_prefix_chars};
|
|||||||
|
|
||||||
use crate::agent::{AgentError, AgentRuntimeConfig};
|
use crate::agent::{AgentError, AgentRuntimeConfig};
|
||||||
|
|
||||||
|
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
|
||||||
|
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
|
||||||
|
|
||||||
/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin.
|
/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin.
|
||||||
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
||||||
let raw: usize = messages
|
let raw: usize = messages
|
||||||
.iter()
|
.iter()
|
||||||
.map(|m| m.content.len().div_ceil(4) + 4)
|
.map(|message| {
|
||||||
|
message
|
||||||
|
.content
|
||||||
|
.len()
|
||||||
|
.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
|
||||||
|
+ estimate_image_tokens(&message.media_refs)
|
||||||
|
+ 4
|
||||||
|
})
|
||||||
.sum();
|
.sum();
|
||||||
(raw as f64 * 1.2) as usize
|
(raw as f64 * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_image_tokens(media_refs: &[String]) -> usize {
|
||||||
|
media_refs
|
||||||
|
.iter()
|
||||||
|
.filter_map(|path| std::fs::metadata(path).ok())
|
||||||
|
.map(|metadata| {
|
||||||
|
let base64_chars = metadata.len().saturating_mul(4).div_ceil(3) as usize;
|
||||||
|
base64_chars.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configuration for context compression.
|
/// Configuration for context compression.
|
||||||
@ -492,6 +513,21 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_estimate_tokens_includes_image_media_refs() {
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let image_path = temp_dir.path().join("demo.jpg");
|
||||||
|
std::fs::write(&image_path, vec![0_u8; 12_000]).unwrap();
|
||||||
|
|
||||||
|
let plain = vec![ChatMessage::user("hello")];
|
||||||
|
let with_image = vec![ChatMessage::user_with_media(
|
||||||
|
"hello",
|
||||||
|
vec![image_path.to_string_lossy().to_string()],
|
||||||
|
)];
|
||||||
|
|
||||||
|
assert!(estimate_tokens(&with_image) > estimate_tokens(&plain));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_should_compress() {
|
fn test_should_compress() {
|
||||||
let compressor = ContextCompressor::new(20);
|
let compressor = ContextCompressor::new(20);
|
||||||
|
|||||||
@ -494,11 +494,12 @@ impl SessionManager {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::bus::MessageBus;
|
use crate::bus::MessageBus;
|
||||||
|
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||||
use crate::storage::MemoryRecord;
|
use crate::storage::MemoryRecord;
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
use axum::{Json, Router, routing::post};
|
use axum::{Json, Router, routing::post};
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
Arc as StdArc,
|
Arc as StdArc,
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
@ -534,6 +535,8 @@ mod tests {
|
|||||||
ToolRegistryFactory::new(
|
ToolRegistryFactory::new(
|
||||||
skills.clone(),
|
skills.clone(),
|
||||||
store.clone(),
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
)
|
)
|
||||||
@ -576,6 +579,8 @@ mod tests {
|
|||||||
ToolRegistryFactory::new(
|
ToolRegistryFactory::new(
|
||||||
skills.clone(),
|
skills.clone(),
|
||||||
store.clone(),
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
)
|
)
|
||||||
@ -1476,6 +1481,8 @@ mod tests {
|
|||||||
ToolRegistryFactory::new(
|
ToolRegistryFactory::new(
|
||||||
skills.clone(),
|
skills.clone(),
|
||||||
store.clone(),
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
)
|
)
|
||||||
@ -1511,6 +1518,8 @@ mod tests {
|
|||||||
ToolRegistryFactory::new(
|
ToolRegistryFactory::new(
|
||||||
skills.clone(),
|
skills.clone(),
|
||||||
store.clone(),
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
)
|
)
|
||||||
@ -1574,6 +1583,8 @@ mod tests {
|
|||||||
ToolRegistryFactory::new(
|
ToolRegistryFactory::new(
|
||||||
skills.clone(),
|
skills.clone(),
|
||||||
store.clone(),
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
)
|
)
|
||||||
@ -1616,8 +1627,14 @@ mod tests {
|
|||||||
fn test_default_tools_registers_get_time() {
|
fn test_default_tools_registers_get_time() {
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tools =
|
let tools = ToolRegistryFactory::new(
|
||||||
ToolRegistryFactory::new(skills, store, HashSet::new(), "Asia/Shanghai".to_string())
|
skills,
|
||||||
|
store.clone(),
|
||||||
|
store.clone(),
|
||||||
|
store,
|
||||||
|
HashSet::new(),
|
||||||
|
"Asia/Shanghai".to_string(),
|
||||||
|
)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
assert!(tools.tool_names().iter().any(|name| name == "get_time"));
|
assert!(tools.tool_names().iter().any(|name| name == "get_time"));
|
||||||
|
|||||||
@ -1,7 +1,22 @@
|
|||||||
use picobot::config::{Config, LLMProviderConfig};
|
use picobot::config::{Config, LLMProviderConfig};
|
||||||
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
|
use picobot::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig {
|
||||||
|
ProviderRuntimeConfig {
|
||||||
|
provider_type: config.provider_type,
|
||||||
|
name: config.name,
|
||||||
|
base_url: config.base_url,
|
||||||
|
api_key: config.api_key,
|
||||||
|
extra_headers: config.extra_headers,
|
||||||
|
llm_timeout_secs: config.llm_timeout_secs,
|
||||||
|
model_id: config.model_id,
|
||||||
|
temperature: config.temperature,
|
||||||
|
max_tokens: config.max_tokens,
|
||||||
|
model_extra: config.model_extra,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn load_config() -> Option<LLMProviderConfig> {
|
fn load_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
|
|
||||||
@ -45,7 +60,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
|
|||||||
async fn test_openai_simple_completion() {
|
async fn test_openai_simple_completion() {
|
||||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||||
|
|
||||||
assert!(!response.id.is_empty());
|
assert!(!response.id.is_empty());
|
||||||
@ -59,7 +74,7 @@ async fn test_openai_simple_completion() {
|
|||||||
async fn test_openai_conversation() {
|
async fn test_openai_conversation() {
|
||||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
@ -89,7 +104,8 @@ async fn test_config_load() {
|
|||||||
assert_eq!(provider_config.name, "aliyun");
|
assert_eq!(provider_config.name, "aliyun");
|
||||||
assert_eq!(provider_config.model_id, "qwen-plus");
|
assert_eq!(provider_config.model_id, "qwen-plus");
|
||||||
|
|
||||||
let provider = create_provider(provider_config).expect("Failed to create provider");
|
let provider =
|
||||||
|
create_provider(to_runtime_config(provider_config)).expect("Failed to create provider");
|
||||||
assert_eq!(provider.ptype(), "openai");
|
assert_eq!(provider.ptype(), "openai");
|
||||||
assert_eq!(provider.name(), "aliyun");
|
assert_eq!(provider.name(), "aliyun");
|
||||||
assert_eq!(provider.model_id(), "qwen-plus");
|
assert_eq!(provider.model_id(), "qwen-plus");
|
||||||
|
|||||||
@ -1,7 +1,24 @@
|
|||||||
use picobot::config::LLMProviderConfig;
|
use picobot::config::LLMProviderConfig;
|
||||||
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
|
use picobot::providers::{
|
||||||
|
ChatCompletionRequest, Message, ProviderRuntimeConfig, Tool, ToolFunction, create_provider,
|
||||||
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig {
|
||||||
|
ProviderRuntimeConfig {
|
||||||
|
provider_type: config.provider_type,
|
||||||
|
name: config.name,
|
||||||
|
base_url: config.base_url,
|
||||||
|
api_key: config.api_key,
|
||||||
|
extra_headers: config.extra_headers,
|
||||||
|
llm_timeout_secs: config.llm_timeout_secs,
|
||||||
|
model_id: config.model_id,
|
||||||
|
temperature: config.temperature,
|
||||||
|
max_tokens: config.max_tokens,
|
||||||
|
model_extra: config.model_extra,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
|
|
||||||
@ -56,7 +73,7 @@ fn make_weather_tool() -> Tool {
|
|||||||
async fn test_openai_tool_call() {
|
async fn test_openai_tool_call() {
|
||||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![Message::user("What is the weather in Tokyo?")],
|
messages: vec![Message::user("What is the weather in Tokyo?")],
|
||||||
@ -84,7 +101,7 @@ async fn test_openai_tool_call() {
|
|||||||
async fn test_openai_tool_call_with_manual_execution() {
|
async fn test_openai_tool_call_with_manual_execution() {
|
||||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||||
|
|
||||||
// First request with tool
|
// First request with tool
|
||||||
let request1 = ChatCompletionRequest {
|
let request1 = ChatCompletionRequest {
|
||||||
@ -120,7 +137,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
async fn test_openai_no_tool_when_not_provided() {
|
async fn test_openai_no_tool_when_not_provided() {
|
||||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![Message::user("Say hello in one word.")],
|
messages: vec![Message::user("Say hello in one word.")],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user