feat: 添加图像处理预算和估算逻辑,优化消息内容构建,支持图像媒体引用

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
ooodc 2026-05-01 21:22:07 +08:00
parent fc5b2a359f
commit 531e72d24f
6 changed files with 422 additions and 35 deletions

View File

@ -30,6 +30,7 @@ cron = { version = "0.13", features = ["serde"] }
iana-time-zone = "0.1"
mime_guess = "2.0"
base64 = "0.22"
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
tempfile = "3"
meval = "0.2"
rusqlite = { version = "0.32", features = ["bundled"] }

View File

@ -24,10 +24,31 @@ const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9;
const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048;
const DATA_URL_OVERHEAD_TOKENS: usize = 16;
const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36];
const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64;
const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:";
/// Build content blocks from text and media paths
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
/// Builds the content blocks for one message from its text and media paths.
///
/// Thin wrapper: forwards `text`, `media_paths` and the mutable `budget`
/// unchanged to `build_content_blocks_with_image_budget` and returns its result.
fn build_content_blocks(
text: &str,
media_paths: &[String],
budget: &mut ImageInlineBudget,
) -> Vec<ContentBlock> {
build_content_blocks_with_image_budget(text, media_paths, budget)
}
fn build_content_blocks_with_image_budget(
text: &str,
media_paths: &[String],
budget: &mut ImageInlineBudget,
) -> Vec<ContentBlock> {
let mut blocks = Vec::new();
let mut skipped_image_notices = Vec::new();
// Add text block if there's text
if !text.is_empty() {
@ -41,10 +62,36 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock>
continue;
}
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
let Some(target_tokens) = budget.take_next_image_tokens() else {
tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains");
skipped_image_notices.push(format!(
"- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
display_media_name(path)
));
continue;
};
match encode_image_to_base64_with_budget(path, target_tokens) {
Ok((mime_type, base64_data)) => {
let url = format!("data:{};base64,{}", mime_type, base64_data);
blocks.push(ContentBlock::image_url(url));
}
Err(err) => {
tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed");
skipped_image_notices.push(format!(
"- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。",
display_media_name(path)
));
continue;
}
}
}
if !skipped_image_notices.is_empty() {
let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX);
notice.push('\n');
notice.push_str(&skipped_image_notices.join("\n"));
blocks.push(ContentBlock::text(notice));
}
// If nothing, add empty text block
@ -66,6 +113,88 @@ fn supported_image_mime_type(path: &str) -> Option<String> {
}
}
/// Returns a short display name for a media path.
///
/// The result is the final component of `path` when one exists and is valid
/// UTF-8; otherwise the whole input string is returned unchanged.
fn display_media_name(path: &str) -> String {
    let final_component = std::path::Path::new(path)
        .file_name()
        .and_then(std::ffi::OsStr::to_str);

    match final_component {
        Some(file_name) => file_name.to_owned(),
        // No usable final component (e.g. empty input or a path ending in `..`).
        None => path.to_string(),
    }
}
/// Token allowance shared by the images that are still waiting to be inlined.
#[derive(Debug, Clone)]
struct ImageInlineBudget {
    /// Tokens that have not been handed out yet.
    remaining_tokens: usize,
    /// Images that have not asked for their share yet.
    remaining_images: usize,
}

impl ImageInlineBudget {
    /// Creates a budget of `total_tokens` to be split across `image_count` images.
    fn new(total_tokens: usize, image_count: usize) -> Self {
        Self {
            remaining_images: image_count,
            remaining_tokens: total_tokens,
        }
    }

    /// Hands out the token share for the next image.
    ///
    /// The share is the remaining tokens divided by the remaining images,
    /// rounded up. Returns `None` when either no tokens or no image slots are
    /// left. Every call uses up one image slot, whether or not a share is granted.
    fn take_next_image_tokens(&mut self) -> Option<usize> {
        let pending_images = self.remaining_images;
        // The slot is consumed on both the granted and the refused path.
        self.remaining_images = pending_images.saturating_sub(1);

        if pending_images == 0 || self.remaining_tokens == 0 {
            return None;
        }

        let share = self.remaining_tokens.div_ceil(pending_images);
        self.remaining_tokens = self.remaining_tokens.saturating_sub(share);
        Some(share)
    }
}
/// Estimates the token cost of `value` from the length of its JSON serialization.
///
/// The serialized byte length is divided by `TOKEN_ESTIMATE_CHARS_PER_TOKEN`
/// (rounding up) and scaled by `TOKEN_ESTIMATE_SAFETY_MULTIPLIER`; the fractional
/// part of the product is dropped. A value that fails to serialize counts as
/// zero tokens.
fn estimate_tokens_from_serialized_json<T: serde::Serialize>(value: &T) -> usize {
    let serialized_len = match serde_json::to_string(value) {
        Ok(json) => json.len(),
        Err(_) => 0,
    };
    let unscaled_tokens = serialized_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    (unscaled_tokens as f64 * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
}
/// Computes how many tokens of the context window are left over for inlined images.
///
/// Steps:
/// 1. Reserve room for the completion: the provider's `max_tokens` when set,
///    otherwise `DEFAULT_COMPLETION_TOKEN_RESERVE`.
/// 2. Subtract that reserve from `context_window_tokens` and scale the result
///    by `CONTEXT_INPUT_SAFETY_RATIO`.
/// 3. Subtract the estimated tokens of the text-only messages and, when
///    present, of the tool definitions.
///
/// All subtractions saturate, so the result is zero when nothing is left.
fn image_token_budget_for_request(
    runtime_config: &AgentRuntimeConfig,
    text_only_messages: &[Message],
    tools: Option<&Vec<crate::domain::tools::Tool>>,
) -> usize {
    let completion_reserve = match runtime_config.provider.max_tokens {
        Some(tokens) => tokens as usize,
        None => DEFAULT_COMPLETION_TOKEN_RESERVE,
    };

    let input_window = runtime_config
        .context_window_tokens
        .saturating_sub(completion_reserve);
    let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize;

    let message_tokens = estimate_tokens_from_serialized_json(&text_only_messages);
    let tool_tokens = match tools {
        Some(definitions) => estimate_tokens_from_serialized_json(definitions),
        None => 0,
    };

    safe_input_window.saturating_sub(message_tokens + tool_tokens)
}
/// Counts, across all `messages`, the media refs for which
/// `supported_image_mime_type` returns a MIME type.
fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize {
    let mut supported = 0;
    for message in messages {
        for media_ref in &message.media_refs {
            if supported_image_mime_type(media_ref).is_some() {
                supported += 1;
            }
        }
    }
    supported
}
/// Converts a per-image token allowance into a byte limit for the raw image.
///
/// `DATA_URL_OVERHEAD_TOKENS` is subtracted first, the rest is converted to
/// characters via `TOKEN_ESTIMATE_CHARS_PER_TOKEN`, and the character count is
/// scaled by 3/4 (base64 spends four characters on every three bytes).
/// All arithmetic saturates instead of overflowing.
fn target_image_bytes_for_tokens(target_tokens: usize) -> usize {
    let payload_tokens = target_tokens.saturating_sub(DATA_URL_OVERHEAD_TOKENS);
    let base64_chars = payload_tokens.saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    base64_chars.saturating_mul(3) / 4
}
/// Encode an image file to base64 data URL
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
use base64::{Engine as _, engine::general_purpose::STANDARD};
@ -85,6 +214,80 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error
Ok((mime, encoded))
}
/// Encodes the image at `path` as base64, recompressing it when it is too large.
///
/// The file is first encoded as-is via `encode_image_to_base64`. If that
/// base64 text fits the character limit derived from `target_tokens`, it is
/// returned with its original MIME type. Otherwise the file is read again,
/// decoded, recompressed with `compress_image_to_jpeg_budget`, and returned
/// as `image/jpeg`.
///
/// # Errors
/// Propagates I/O errors from reading the file and errors from the
/// compression step; a decode failure is reported as `InvalidData`.
fn encode_image_to_base64_with_budget(
    path: &str,
    target_tokens: usize,
) -> Result<(String, String), std::io::Error> {
    use base64::{Engine as _, engine::general_purpose::STANDARD};

    let (original_mime, original_base64) = encode_image_to_base64(path)?;
    let base64_char_limit = target_tokens
        .saturating_sub(DATA_URL_OVERHEAD_TOKENS)
        .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN);

    if original_base64.len() > base64_char_limit {
        let byte_limit = target_image_bytes_for_tokens(target_tokens);
        let raw_bytes = std::fs::read(path)?;
        let decoded = image::load_from_memory(&raw_bytes)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
        let jpeg_bytes = compress_image_to_jpeg_budget(&decoded, byte_limit)?;
        return Ok(("image/jpeg".to_string(), STANDARD.encode(jpeg_bytes)));
    }

    // Already within budget: keep the untouched original encoding.
    Ok((original_mime, original_base64))
}
/// Re-encodes `image` as JPEG, trying to get at or below `target_bytes`.
///
/// Each round encodes the current candidate once per entry of
/// `JPEG_QUALITY_STEPS`, in order, and returns the first encoding that fits.
/// When none fits, the side limit is reduced to three quarters of its value
/// (never below `MIN_COMPRESSED_IMAGE_SIDE`) and the image is resized to that
/// bound for the next round. The first round uses the image at its original
/// size.
///
/// If no encoding fits even at the minimum side limit, the smallest encoding
/// seen is returned, so the result may exceed `target_bytes`.
///
/// # Errors
/// Encoder failures are reported as `InvalidData`, as is the case where no
/// candidate was produced at all.
fn compress_image_to_jpeg_budget(
    image: &image::DynamicImage,
    target_bytes: usize,
) -> Result<Vec<u8>, std::io::Error> {
    use image::codecs::jpeg::JpegEncoder;
    use image::imageops::FilterType;

    let longest_side = image.width().max(image.height());
    let mut side_limit = longest_side.max(MIN_COMPRESSED_IMAGE_SIDE);
    let mut smallest_so_far: Option<Vec<u8>> = None;

    loop {
        // Only downscale; an image already within the limit is used as-is.
        let scaled = if longest_side > side_limit {
            image.resize(side_limit, side_limit, FilterType::Triangle)
        } else {
            image.clone()
        };

        for &quality in JPEG_QUALITY_STEPS {
            let mut jpeg_bytes = Vec::new();
            JpegEncoder::new_with_quality(&mut jpeg_bytes, quality)
                .encode_image(&scaled)
                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;

            if jpeg_bytes.len() <= target_bytes {
                return Ok(jpeg_bytes);
            }

            // Remember the smallest over-budget attempt as the fallback.
            let is_new_smallest = match &smallest_so_far {
                Some(previous) => jpeg_bytes.len() < previous.len(),
                None => true,
            };
            if is_new_smallest {
                smallest_so_far = Some(jpeg_bytes);
            }
        }

        if side_limit <= MIN_COMPRESSED_IMAGE_SIDE {
            break;
        }
        side_limit = (side_limit * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE);
    }

    smallest_so_far.ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed")
    })
}
/// Truncate tool result if it exceeds the configured limit.
/// Preserves the end of the output as it often contains the conclusion/useful result.
fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
@ -272,11 +475,11 @@ fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
}
/// Convert ChatMessage to LLM Message format
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message {
let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
build_content_blocks(&m.content, &m.media_refs)
build_content_blocks(&m.content, &m.media_refs, image_budget)
};
Message {
@ -289,6 +492,17 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
}
}
/// Converts a `ChatMessage` into an LLM `Message` that carries only text.
///
/// The content is always a single text block built from `m.content`; the
/// message's `media_refs` are not consulted. Role, reasoning content, tool
/// call id, tool name and tool calls are cloned across.
fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message {
    let text_block = ContentBlock::text(&m.content);

    Message {
        content: vec![text_block],
        role: m.role.clone(),
        name: m.tool_name.clone(),
        tool_call_id: m.tool_call_id.clone(),
        tool_calls: m.tool_calls.clone(),
        reasoning_content: m.reasoning_content.clone(),
    }
}
/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop {
@ -434,14 +648,6 @@ impl AgentLoop {
#[cfg(debug_assertions)]
tracing::debug!(iteration, "Agent iteration started");
// Convert messages to LLM format
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
}
messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
// Build request
let tool_defs = self.tools.get_definitions();
let tools = if tool_defs.is_empty() {
@ -450,6 +656,31 @@ impl AgentLoop {
Some(tool_defs)
};
let image_count = count_supported_image_media_refs(&messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 2);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
text_only_messages.push(Message::system(skill_prompt.clone()));
}
text_only_messages.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens = image_token_budget_for_request(
&self.runtime_config,
&text_only_messages,
tools.as_ref(),
);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 2);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
}
messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
messages_for_llm.extend(
messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
@ -457,7 +688,6 @@ impl AgentLoop {
tools,
};
// Call LLM
let response = match (*self.provider).chat(request).await {
Ok(response) => response,
Err(e) => {
@ -633,11 +863,24 @@ impl AgentLoop {
messages.push(summary_request);
// Convert messages to LLM format
let image_count = count_supported_image_media_refs(&messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 1);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
text_only_messages.push(Message::system(skill_prompt));
}
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens =
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
}
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
messages_for_llm.extend(
messages
.iter()
.map(|message| chat_message_to_llm_message(message, &mut image_budget)),
);
let request = ChatCompletionRequest {
messages: messages_for_llm,
@ -896,7 +1139,8 @@ mod tests {
}],
);
let provider_message = chat_message_to_llm_message(&chat_message);
let mut image_budget = ImageInlineBudget::new(0, 0);
let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);
assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
@ -915,7 +1159,8 @@ mod tests {
let chat_message =
ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");
let provider_message = chat_message_to_llm_message(&chat_message);
let mut image_budget = ImageInlineBudget::new(0, 0);
let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget);
assert_eq!(provider_message.role, "assistant");
assert_eq!(
@ -983,7 +1228,12 @@ mod tests {
let pdf_path = temp_dir.path().join("demo.pdf");
std::fs::write(&pdf_path, b"%PDF-1.4").unwrap();
let blocks = build_content_blocks("hello", &[pdf_path.to_string_lossy().to_string()]);
let mut budget = ImageInlineBudget::new(1_000, 0);
let blocks = build_content_blocks(
"hello",
&[pdf_path.to_string_lossy().to_string()],
&mut budget,
);
assert_eq!(blocks.len(), 1);
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
@ -993,9 +1243,15 @@ mod tests {
fn test_build_content_blocks_keeps_supported_images() {
let temp_dir = tempdir().unwrap();
let jpg_path = temp_dir.path().join("demo.jpg");
std::fs::write(&jpg_path, b"fake-jpeg-data").unwrap();
let image = image::DynamicImage::new_rgb8(8, 8);
image.save(&jpg_path).unwrap();
let blocks = build_content_blocks("hello", &[jpg_path.to_string_lossy().to_string()]);
let mut budget = ImageInlineBudget::new(10_000, 1);
let blocks = build_content_blocks(
"hello",
&[jpg_path.to_string_lossy().to_string()],
&mut budget,
);
assert_eq!(blocks.len(), 2);
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
@ -1003,6 +1259,50 @@ mod tests {
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
);
}
#[test]
fn test_build_content_blocks_compresses_images_to_budget() {
    // A 512x512 PNG is offered with a 512-token budget for one image; the
    // resulting image block is expected to be a JPEG data URL.
    let workspace = tempdir().unwrap();
    let source_png = workspace.path().join("large.png");
    image::DynamicImage::new_rgb8(512, 512)
        .save(&source_png)
        .unwrap();

    let media = [source_png.to_string_lossy().to_string()];
    let mut budget = ImageInlineBudget::new(512, 1);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(
        matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
    );
}
#[test]
fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() {
    // With a zero-token budget the image cannot be inlined; a text notice
    // naming the file is expected in its place.
    let workspace = tempdir().unwrap();
    let source_jpg = workspace.path().join("demo.jpg");
    image::DynamicImage::new_rgb8(8, 8)
        .save(&source_jpg)
        .unwrap();

    let media = [source_jpg.to_string_lossy().to_string()];
    let mut budget = ImageInlineBudget::new(0, 1);
    let blocks = build_content_blocks("hello", &media, &mut budget);

    assert_eq!(blocks.len(), 2);
    assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
    assert!(matches!(
        &blocks[1],
        ContentBlock::Text { text }
            if text.contains("图片未能成功入模") && text.contains("demo.jpg")
    ));
}
}
#[derive(Debug)]

View File

@ -8,13 +8,34 @@ use crate::text::{char_count, take_prefix_chars};
use crate::agent::{AgentError, AgentRuntimeConfig};
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin.
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
let raw: usize = messages
.iter()
.map(|m| m.content.len().div_ceil(4) + 4)
.map(|message| {
message
.content
.len()
.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
+ estimate_image_tokens(&message.media_refs)
+ 4
})
.sum();
(raw as f64 * 1.2) as usize
(raw as f64 * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
}
/// Estimates the tokens the given media files would cost if sent as base64.
///
/// For each path whose file metadata can be read, the file size is converted
/// to an approximate base64 length (4/3 of the size, rounded up) and then to
/// tokens via `TOKEN_ESTIMATE_CHARS_PER_TOKEN`. Paths whose metadata cannot
/// be read contribute nothing. Every readable path is counted; no filtering
/// by media type happens here.
fn estimate_image_tokens(media_refs: &[String]) -> usize {
    let mut total = 0;
    for path in media_refs {
        let Ok(metadata) = std::fs::metadata(path) else {
            continue;
        };
        let base64_chars = metadata.len().saturating_mul(4).div_ceil(3) as usize;
        total += base64_chars.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN);
    }
    total
}
/// Configuration for context compression.
@ -492,6 +513,21 @@ mod tests {
);
}
#[test]
fn test_estimate_tokens_includes_image_media_refs() {
    // The same text with a 12,000-byte media file attached must be estimated
    // higher than the text alone.
    let scratch = tempfile::tempdir().unwrap();
    let media_file = scratch.path().join("demo.jpg");
    std::fs::write(&media_file, vec![0_u8; 12_000]).unwrap();

    let text_only = [ChatMessage::user("hello")];
    let text_and_media = [ChatMessage::user_with_media(
        "hello",
        vec![media_file.to_string_lossy().to_string()],
    )];

    let baseline = estimate_tokens(&text_only);
    assert!(estimate_tokens(&text_and_media) > baseline);
}
#[test]
fn test_should_compress() {
let compressor = ContextCompressor::new(20);

View File

@ -494,11 +494,12 @@ impl SessionManager {
mod tests {
use super::*;
use crate::bus::MessageBus;
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
use crate::storage::MemoryRecord;
use axum::http::StatusCode;
use axum::{Json, Router, routing::post};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::{
Arc as StdArc,
atomic::{AtomicUsize, Ordering},
@ -534,6 +535,8 @@ mod tests {
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
store.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
@ -576,6 +579,8 @@ mod tests {
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
store.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
@ -1476,6 +1481,8 @@ mod tests {
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
store.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
@ -1511,6 +1518,8 @@ mod tests {
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
store.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
@ -1574,6 +1583,8 @@ mod tests {
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
store.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
@ -1616,8 +1627,14 @@ mod tests {
fn test_default_tools_registers_get_time() {
let skills = Arc::new(SkillRuntime::default());
let store = Arc::new(SessionStore::in_memory().unwrap());
let tools =
ToolRegistryFactory::new(skills, store, HashSet::new(), "Asia/Shanghai".to_string())
let tools = ToolRegistryFactory::new(
skills,
store.clone(),
store.clone(),
store,
HashSet::new(),
"Asia/Shanghai".to_string(),
)
.build();
assert!(tools.tool_names().iter().any(|name| name == "get_time"));

View File

@ -1,7 +1,22 @@
use picobot::config::{Config, LLMProviderConfig};
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
use picobot::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
use std::collections::HashMap;
fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig {
ProviderRuntimeConfig {
provider_type: config.provider_type,
name: config.name,
base_url: config.base_url,
api_key: config.api_key,
extra_headers: config.extra_headers,
llm_timeout_secs: config.llm_timeout_secs,
model_id: config.model_id,
temperature: config.temperature,
max_tokens: config.max_tokens,
model_extra: config.model_extra,
}
}
fn load_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?;
@ -45,7 +60,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
async fn test_openai_simple_completion() {
let config = load_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
assert!(!response.id.is_empty());
@ -59,7 +74,7 @@ async fn test_openai_simple_completion() {
async fn test_openai_conversation() {
let config = load_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
let request = ChatCompletionRequest {
messages: vec![
@ -89,7 +104,8 @@ async fn test_config_load() {
assert_eq!(provider_config.name, "aliyun");
assert_eq!(provider_config.model_id, "qwen-plus");
let provider = create_provider(provider_config).expect("Failed to create provider");
let provider =
create_provider(to_runtime_config(provider_config)).expect("Failed to create provider");
assert_eq!(provider.ptype(), "openai");
assert_eq!(provider.name(), "aliyun");
assert_eq!(provider.model_id(), "qwen-plus");

View File

@ -1,7 +1,24 @@
use picobot::config::LLMProviderConfig;
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
use picobot::providers::{
ChatCompletionRequest, Message, ProviderRuntimeConfig, Tool, ToolFunction, create_provider,
};
use std::collections::HashMap;
fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig {
ProviderRuntimeConfig {
provider_type: config.provider_type,
name: config.name,
base_url: config.base_url,
api_key: config.api_key,
extra_headers: config.extra_headers,
llm_timeout_secs: config.llm_timeout_secs,
model_id: config.model_id,
temperature: config.temperature,
max_tokens: config.max_tokens,
model_extra: config.model_extra,
}
}
fn load_openai_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?;
@ -56,7 +73,7 @@ fn make_weather_tool() -> Tool {
async fn test_openai_tool_call() {
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
let request = ChatCompletionRequest {
messages: vec![Message::user("What is the weather in Tokyo?")],
@ -84,7 +101,7 @@ async fn test_openai_tool_call() {
async fn test_openai_tool_call_with_manual_execution() {
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
// First request with tool
let request1 = ChatCompletionRequest {
@ -120,7 +137,7 @@ async fn test_openai_tool_call_with_manual_execution() {
async fn test_openai_no_tool_when_not_provided() {
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
let request = ChatCompletionRequest {
messages: vec![Message::user("Say hello in one word.")],