feat: 添加图像处理预算和估算逻辑,优化消息内容构建,支持图像媒体引用
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
fc5b2a359f
commit
531e72d24f
@ -30,6 +30,7 @@ cron = { version = "0.13", features = ["serde"] }
|
||||
iana-time-zone = "0.1"
|
||||
mime_guess = "2.0"
|
||||
base64 = "0.22"
|
||||
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
|
||||
tempfile = "3"
|
||||
meval = "0.2"
|
||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||
|
||||
@ -24,10 +24,31 @@ const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
|
||||
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
||||
|
||||
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
|
||||
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
|
||||
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
|
||||
const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9;
|
||||
const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048;
|
||||
const DATA_URL_OVERHEAD_TOKENS: usize = 16;
|
||||
const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36];
|
||||
const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64;
|
||||
const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:";
|
||||
|
||||
/// Build content blocks from text and media paths
|
||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||
fn build_content_blocks(
|
||||
text: &str,
|
||||
media_paths: &[String],
|
||||
budget: &mut ImageInlineBudget,
|
||||
) -> Vec<ContentBlock> {
|
||||
build_content_blocks_with_image_budget(text, media_paths, budget)
|
||||
}
|
||||
|
||||
fn build_content_blocks_with_image_budget(
|
||||
text: &str,
|
||||
media_paths: &[String],
|
||||
budget: &mut ImageInlineBudget,
|
||||
) -> Vec<ContentBlock> {
|
||||
let mut blocks = Vec::new();
|
||||
let mut skipped_image_notices = Vec::new();
|
||||
|
||||
// Add text block if there's text
|
||||
if !text.is_empty() {
|
||||
@ -41,12 +62,38 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock>
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
|
||||
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||
blocks.push(ContentBlock::image_url(url));
|
||||
let Some(target_tokens) = budget.take_next_image_tokens() else {
|
||||
tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains");
|
||||
skipped_image_notices.push(format!(
|
||||
"- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
|
||||
display_media_name(path)
|
||||
));
|
||||
continue;
|
||||
};
|
||||
|
||||
match encode_image_to_base64_with_budget(path, target_tokens) {
|
||||
Ok((mime_type, base64_data)) => {
|
||||
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||
blocks.push(ContentBlock::image_url(url));
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed");
|
||||
skipped_image_notices.push(format!(
|
||||
"- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
|
||||
display_media_name(path)
|
||||
));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !skipped_image_notices.is_empty() {
|
||||
let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX);
|
||||
notice.push('\n');
|
||||
notice.push_str(&skipped_image_notices.join("\n"));
|
||||
blocks.push(ContentBlock::text(notice));
|
||||
}
|
||||
|
||||
// If nothing, add empty text block
|
||||
if blocks.is_empty() {
|
||||
blocks.push(ContentBlock::text(""));
|
||||
@ -66,6 +113,88 @@ fn supported_image_mime_type(path: &str) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
fn display_media_name(path: &str) -> String {
|
||||
std::path::Path::new(path)
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.map(ToOwned::to_owned)
|
||||
.unwrap_or_else(|| path.to_string())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ImageInlineBudget {
|
||||
remaining_tokens: usize,
|
||||
remaining_images: usize,
|
||||
}
|
||||
|
||||
impl ImageInlineBudget {
|
||||
fn new(total_tokens: usize, image_count: usize) -> Self {
|
||||
Self {
|
||||
remaining_tokens: total_tokens,
|
||||
remaining_images: image_count,
|
||||
}
|
||||
}
|
||||
|
||||
fn take_next_image_tokens(&mut self) -> Option<usize> {
|
||||
if self.remaining_tokens == 0 || self.remaining_images == 0 {
|
||||
self.remaining_images = self.remaining_images.saturating_sub(1);
|
||||
return None;
|
||||
}
|
||||
|
||||
let target = self.remaining_tokens.div_ceil(self.remaining_images);
|
||||
self.remaining_tokens = self.remaining_tokens.saturating_sub(target);
|
||||
self.remaining_images = self.remaining_images.saturating_sub(1);
|
||||
Some(target)
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_tokens_from_serialized_json<T: serde::Serialize>(value: &T) -> usize {
|
||||
let raw_len = serde_json::to_string(value)
|
||||
.map(|serialized| serialized.len())
|
||||
.unwrap_or_default();
|
||||
((raw_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64) * TOKEN_ESTIMATE_SAFETY_MULTIPLIER)
|
||||
as usize
|
||||
}
|
||||
|
||||
fn image_token_budget_for_request(
|
||||
runtime_config: &AgentRuntimeConfig,
|
||||
text_only_messages: &[Message],
|
||||
tools: Option<&Vec<crate::domain::tools::Tool>>,
|
||||
) -> usize {
|
||||
let completion_reserve = runtime_config
|
||||
.provider
|
||||
.max_tokens
|
||||
.map(|tokens| tokens as usize)
|
||||
.unwrap_or(DEFAULT_COMPLETION_TOKEN_RESERVE);
|
||||
let input_window = runtime_config
|
||||
.context_window_tokens
|
||||
.saturating_sub(completion_reserve);
|
||||
let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize;
|
||||
|
||||
let text_tokens = estimate_tokens_from_serialized_json(&text_only_messages)
|
||||
+ tools
|
||||
.map(estimate_tokens_from_serialized_json)
|
||||
.unwrap_or_default();
|
||||
|
||||
safe_input_window.saturating_sub(text_tokens)
|
||||
}
|
||||
|
||||
fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
|
||||
messages
|
||||
.iter()
|
||||
.flat_map(|message| message.media_refs.iter())
|
||||
.filter(|path| supported_image_mime_type(path).is_some())
|
||||
.count()
|
||||
}
|
||||
|
||||
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
|
||||
target_tokens
|
||||
.saturating_sub(DATA_URL_OVERHEAD_TOKENS)
|
||||
.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
|
||||
.saturating_mul(3)
|
||||
/ 4
|
||||
}
|
||||
|
||||
/// Encode an image file to base64 data URL
|
||||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
@ -85,6 +214,80 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error
|
||||
Ok((mime, encoded))
|
||||
}
|
||||
|
||||
fn encode_image_to_base64_with_budget(
|
||||
path: &str,
|
||||
target_tokens: usize,
|
||||
) -> Result<(String, String), std::io::Error> {
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
|
||||
let (mime, encoded) = encode_image_to_base64(path)?;
|
||||
let target_base64_chars = target_tokens
|
||||
.saturating_sub(DATA_URL_OVERHEAD_TOKENS)
|
||||
.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
|
||||
if encoded.len() <= target_base64_chars {
|
||||
return Ok((mime, encoded));
|
||||
}
|
||||
|
||||
let target_bytes = target_image_bytes_for_tokens(target_tokens);
|
||||
let image_bytes = std::fs::read(path)?;
|
||||
let image = image::load_from_memory(&image_bytes)
|
||||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
|
||||
let compressed = compress_image_to_jpeg_budget(&image, target_bytes)?;
|
||||
|
||||
Ok(("image/jpeg".to_string(), STANDARD.encode(compressed)))
|
||||
}
|
||||
|
||||
fn compress_image_to_jpeg_budget(
|
||||
image: &image::DynamicImage,
|
||||
target_bytes: usize,
|
||||
) -> Result<Vec<u8>, std::io::Error> {
|
||||
use image::codecs::jpeg::JpegEncoder;
|
||||
use image::imageops::FilterType;
|
||||
|
||||
let mut best: Option<Vec<u8>> = None;
|
||||
let mut max_side = image
|
||||
.width()
|
||||
.max(image.height())
|
||||
.max(MIN_COMPRESSED_IMAGE_SIDE);
|
||||
|
||||
loop {
|
||||
let candidate = if image.width().max(image.height()) > max_side {
|
||||
image.resize(max_side, max_side, FilterType::Triangle)
|
||||
} else {
|
||||
image.clone()
|
||||
};
|
||||
|
||||
for quality in JPEG_QUALITY_STEPS {
|
||||
let mut encoded = Vec::new();
|
||||
let mut encoder = JpegEncoder::new_with_quality(&mut encoded, *quality);
|
||||
encoder
|
||||
.encode_image(&candidate)
|
||||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
|
||||
|
||||
if encoded.len() <= target_bytes {
|
||||
return Ok(encoded);
|
||||
}
|
||||
|
||||
if best
|
||||
.as_ref()
|
||||
.map(|current| encoded.len() < current.len())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
best = Some(encoded);
|
||||
}
|
||||
}
|
||||
|
||||
if max_side <= MIN_COMPRESSED_IMAGE_SIDE {
|
||||
break;
|
||||
}
|
||||
max_side = (max_side * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE);
|
||||
}
|
||||
|
||||
best.ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed")
|
||||
})
|
||||
}
|
||||
|
||||
/// Truncate tool result if it exceeds the configured limit.
|
||||
/// Preserves the end of the output as it often contains the conclusion/useful result.
|
||||
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
|
||||
@ -272,11 +475,11 @@ fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
|
||||
}
|
||||
|
||||
/// Convert ChatMessage to LLM Message format
|
||||
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
||||
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
|
||||
let content = if m.media_refs.is_empty() {
|
||||
vec![ContentBlock::text(&m.content)]
|
||||
} else {
|
||||
build_content_blocks(&m.content, &m.media_refs)
|
||||
build_content_blocks(&m.content, &m.media_refs, image_budget)
|
||||
};
|
||||
|
||||
Message {
|
||||
@ -289,6 +492,17 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
||||
}
|
||||
}
|
||||
|
||||
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
|
||||
Message {
|
||||
role: m.role.clone(),
|
||||
content: vec![ContentBlock::text(&m.content)],
|
||||
reasoning_content: m.reasoning_content.clone(),
|
||||
tool_call_id: m.tool_call_id.clone(),
|
||||
name: m.tool_name.clone(),
|
||||
tool_calls: m.tool_calls.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
||||
/// History is managed externally by SessionManager.
|
||||
pub struct AgentLoop {
|
||||
@ -434,14 +648,6 @@ impl AgentLoop {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(iteration, "Agent iteration started");
|
||||
|
||||
// Convert messages to LLM format
|
||||
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||
messages_for_llm.push(Message::system(skill_prompt));
|
||||
}
|
||||
messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
|
||||
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
||||
|
||||
// Build request
|
||||
let tool_defs = self.tools.get_definitions();
|
||||
let tools = if tool_defs.is_empty() {
|
||||
@ -450,6 +656,31 @@ impl AgentLoop {
|
||||
Some(tool_defs)
|
||||
};
|
||||
|
||||
let image_count = count_supported_image_media_refs(&messages);
|
||||
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 2);
|
||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||
text_only_messages.push(Message::system(skill_prompt.clone()));
|
||||
}
|
||||
text_only_messages.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
|
||||
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
|
||||
|
||||
let image_tokens = image_token_budget_for_request(
|
||||
&self.runtime_config,
|
||||
&text_only_messages,
|
||||
tools.as_ref(),
|
||||
);
|
||||
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
|
||||
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 2);
|
||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||
messages_for_llm.push(Message::system(skill_prompt));
|
||||
}
|
||||
messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
|
||||
messages_for_llm.extend(
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: messages_for_llm,
|
||||
temperature: None,
|
||||
@ -457,7 +688,6 @@ impl AgentLoop {
|
||||
tools,
|
||||
};
|
||||
|
||||
// Call LLM
|
||||
let response = match (*self.provider).chat(request).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
@ -633,11 +863,24 @@ impl AgentLoop {
|
||||
messages.push(summary_request);
|
||||
|
||||
// Convert messages to LLM format
|
||||
let image_count = count_supported_image_media_refs(&messages);
|
||||
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||
text_only_messages.push(Message::system(skill_prompt));
|
||||
}
|
||||
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
|
||||
let image_tokens =
|
||||
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
|
||||
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
|
||||
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
|
||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||
messages_for_llm.push(Message::system(skill_prompt));
|
||||
}
|
||||
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
||||
messages_for_llm.extend(
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: messages_for_llm,
|
||||
@ -896,7 +1139,8 @@ mod tests {
|
||||
}],
|
||||
);
|
||||
|
||||
let provider_message = chat_message_to_llm_message(&chat_message);
|
||||
let mut image_budget = ImageInlineBudget::new(0, 0);
|
||||
let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);
|
||||
|
||||
assert_eq!(provider_message.role, "assistant");
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||
@ -915,7 +1159,8 @@ mod tests {
|
||||
let chat_message =
|
||||
ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");
|
||||
|
||||
let provider_message = chat_message_to_llm_message(&chat_message);
|
||||
let mut image_budget = ImageInlineBudget::new(0, 0);
|
||||
let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);
|
||||
|
||||
assert_eq!(provider_message.role, "assistant");
|
||||
assert_eq!(
|
||||
@ -983,7 +1228,12 @@ mod tests {
|
||||
let pdf_path = temp_dir.path().join("demo.pdf");
|
||||
std::fs::write(&pdf_path, b"%PDF-1.4").unwrap();
|
||||
|
||||
let blocks = build_content_blocks("hello", &[pdf_path.to_string_lossy().to_string()]);
|
||||
let mut budget = ImageInlineBudget::new(1_000, 0);
|
||||
let blocks = build_content_blocks(
|
||||
"hello",
|
||||
&[pdf_path.to_string_lossy().to_string()],
|
||||
&mut budget,
|
||||
);
|
||||
|
||||
assert_eq!(blocks.len(), 1);
|
||||
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||
@ -993,9 +1243,15 @@ mod tests {
|
||||
fn test_build_content_blocks_keeps_supported_images() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let jpg_path = temp_dir.path().join("demo.jpg");
|
||||
std::fs::write(&jpg_path, b"fake-jpeg-data").unwrap();
|
||||
let image = image::DynamicImage::new_rgb8(8, 8);
|
||||
image.save(&jpg_path).unwrap();
|
||||
|
||||
let blocks = build_content_blocks("hello", &[jpg_path.to_string_lossy().to_string()]);
|
||||
let mut budget = ImageInlineBudget::new(10_000, 1);
|
||||
let blocks = build_content_blocks(
|
||||
"hello",
|
||||
&[jpg_path.to_string_lossy().to_string()],
|
||||
&mut budget,
|
||||
);
|
||||
|
||||
assert_eq!(blocks.len(), 2);
|
||||
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||
@ -1003,6 +1259,50 @@ mod tests {
|
||||
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_content_blocks_compresses_images_to_budget() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let png_path = temp_dir.path().join("large.png");
|
||||
let image = image::DynamicImage::new_rgb8(512, 512);
|
||||
image.save(&png_path).unwrap();
|
||||
|
||||
let mut budget = ImageInlineBudget::new(512, 1);
|
||||
let blocks = build_content_blocks(
|
||||
"hello",
|
||||
&[png_path.to_string_lossy().to_string()],
|
||||
&mut budget,
|
||||
);
|
||||
|
||||
assert_eq!(blocks.len(), 2);
|
||||
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||
assert!(
|
||||
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let jpg_path = temp_dir.path().join("demo.jpg");
|
||||
let image = image::DynamicImage::new_rgb8(8, 8);
|
||||
image.save(&jpg_path).unwrap();
|
||||
|
||||
let mut budget = ImageInlineBudget::new(0, 1);
|
||||
let blocks = build_content_blocks(
|
||||
"hello",
|
||||
&[jpg_path.to_string_lossy().to_string()],
|
||||
&mut budget,
|
||||
);
|
||||
|
||||
assert_eq!(blocks.len(), 2);
|
||||
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||
assert!(matches!(
|
||||
&blocks[1],
|
||||
ContentBlock::Text { text }
|
||||
if text.contains("图片未能成功入模") && text.contains("demo.jpg")
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@ -8,13 +8,34 @@ use crate::text::{char_count, take_prefix_chars};
|
||||
|
||||
use crate::agent::{AgentError, AgentRuntimeConfig};
|
||||
|
||||
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
|
||||
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
|
||||
|
||||
/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin.
|
||||
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
||||
let raw: usize = messages
|
||||
.iter()
|
||||
.map(|m| m.content.len().div_ceil(4) + 4)
|
||||
.map(|message| {
|
||||
message
|
||||
.content
|
||||
.len()
|
||||
.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
|
||||
+ estimate_image_tokens(&message.media_refs)
|
||||
+ 4
|
||||
})
|
||||
.sum();
|
||||
(raw as f64 * 1.2) as usize
|
||||
(raw as f64 * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
|
||||
}
|
||||
|
||||
fn estimate_image_tokens(media_refs: &[String]) -> usize {
|
||||
media_refs
|
||||
.iter()
|
||||
.filter_map(|path| std::fs::metadata(path).ok())
|
||||
.map(|metadata| {
|
||||
let base64_chars = metadata.len().saturating_mul(4).div_ceil(3) as usize;
|
||||
base64_chars.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Configuration for context compression.
|
||||
@ -492,6 +513,21 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens_includes_image_media_refs() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let image_path = temp_dir.path().join("demo.jpg");
|
||||
std::fs::write(&image_path, vec![0_u8; 12_000]).unwrap();
|
||||
|
||||
let plain = vec![ChatMessage::user("hello")];
|
||||
let with_image = vec![ChatMessage::user_with_media(
|
||||
"hello",
|
||||
vec![image_path.to_string_lossy().to_string()],
|
||||
)];
|
||||
|
||||
assert!(estimate_tokens(&with_image) > estimate_tokens(&plain));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_compress() {
|
||||
let compressor = ContextCompressor::new(20);
|
||||
|
||||
@ -494,11 +494,12 @@ impl SessionManager {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::bus::MessageBus;
|
||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||
use crate::storage::MemoryRecord;
|
||||
use axum::http::StatusCode;
|
||||
use axum::{Json, Router, routing::post};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::{
|
||||
Arc as StdArc,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
@ -534,6 +535,8 @@ mod tests {
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
@ -576,6 +579,8 @@ mod tests {
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
@ -1476,6 +1481,8 @@ mod tests {
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
@ -1511,6 +1518,8 @@ mod tests {
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
@ -1574,6 +1583,8 @@ mod tests {
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
@ -1616,9 +1627,15 @@ mod tests {
|
||||
fn test_default_tools_registers_get_time() {
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tools =
|
||||
ToolRegistryFactory::new(skills, store, HashSet::new(), "Asia/Shanghai".to_string())
|
||||
.build();
|
||||
let tools = ToolRegistryFactory::new(
|
||||
skills,
|
||||
store.clone(),
|
||||
store.clone(),
|
||||
store,
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
.build();
|
||||
|
||||
assert!(tools.tool_names().iter().any(|name| name == "get_time"));
|
||||
}
|
||||
|
||||
@ -1,7 +1,22 @@
|
||||
use picobot::config::{Config, LLMProviderConfig};
|
||||
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
|
||||
use picobot::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig {
|
||||
ProviderRuntimeConfig {
|
||||
provider_type: config.provider_type,
|
||||
name: config.name,
|
||||
base_url: config.base_url,
|
||||
api_key: config.api_key,
|
||||
extra_headers: config.extra_headers,
|
||||
llm_timeout_secs: config.llm_timeout_secs,
|
||||
model_id: config.model_id,
|
||||
temperature: config.temperature,
|
||||
max_tokens: config.max_tokens,
|
||||
model_extra: config.model_extra,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_config() -> Option<LLMProviderConfig> {
|
||||
dotenv::from_filename("tests/test.env").ok()?;
|
||||
|
||||
@ -45,7 +60,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
|
||||
async fn test_openai_simple_completion() {
|
||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||
|
||||
assert!(!response.id.is_empty());
|
||||
@ -59,7 +74,7 @@ async fn test_openai_simple_completion() {
|
||||
async fn test_openai_conversation() {
|
||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![
|
||||
@ -89,7 +104,8 @@ async fn test_config_load() {
|
||||
assert_eq!(provider_config.name, "aliyun");
|
||||
assert_eq!(provider_config.model_id, "qwen-plus");
|
||||
|
||||
let provider = create_provider(provider_config).expect("Failed to create provider");
|
||||
let provider =
|
||||
create_provider(to_runtime_config(provider_config)).expect("Failed to create provider");
|
||||
assert_eq!(provider.ptype(), "openai");
|
||||
assert_eq!(provider.name(), "aliyun");
|
||||
assert_eq!(provider.model_id(), "qwen-plus");
|
||||
|
||||
@ -1,7 +1,24 @@
|
||||
use picobot::config::LLMProviderConfig;
|
||||
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
|
||||
use picobot::providers::{
|
||||
ChatCompletionRequest, Message, ProviderRuntimeConfig, Tool, ToolFunction, create_provider,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig {
|
||||
ProviderRuntimeConfig {
|
||||
provider_type: config.provider_type,
|
||||
name: config.name,
|
||||
base_url: config.base_url,
|
||||
api_key: config.api_key,
|
||||
extra_headers: config.extra_headers,
|
||||
llm_timeout_secs: config.llm_timeout_secs,
|
||||
model_id: config.model_id,
|
||||
temperature: config.temperature,
|
||||
max_tokens: config.max_tokens,
|
||||
model_extra: config.model_extra,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||
dotenv::from_filename("tests/test.env").ok()?;
|
||||
|
||||
@ -56,7 +73,7 @@ fn make_weather_tool() -> Tool {
|
||||
async fn test_openai_tool_call() {
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::user("What is the weather in Tokyo?")],
|
||||
@ -84,7 +101,7 @@ async fn test_openai_tool_call() {
|
||||
async fn test_openai_tool_call_with_manual_execution() {
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||
|
||||
// First request with tool
|
||||
let request1 = ChatCompletionRequest {
|
||||
@ -120,7 +137,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
||||
async fn test_openai_no_tool_when_not_provided() {
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::user("Say hello in one word.")],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user