From 531e72d24fdba7b24b13b723ada62e0341ac5b69 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Fri, 1 May 2026 21:22:07 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=9B=BE=E5=83=8F?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=A2=84=E7=AE=97=E5=92=8C=E4=BC=B0=E7=AE=97?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BC=98=E5=8C=96=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=86=85=E5=AE=B9=E6=9E=84=E5=BB=BA=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=9B=BE=E5=83=8F=E5=AA=92=E4=BD=93=E5=BC=95=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- Cargo.toml | 1 + src/agent/agent_loop.rs | 342 ++++++++++++++++++++++++++++++-- src/agent/context_compressor.rs | 40 +++- src/gateway/session.rs | 25 ++- tests/test_integration.rs | 24 ++- tests/test_tool_calling.rs | 25 ++- 6 files changed, 422 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0cb61b0..0795f50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ cron = { version = "0.13", features = ["serde"] } iana-time-zone = "0.1" mime_guess = "2.0" base64 = "0.22" +image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } tempfile = "3" meval = "0.2" rusqlite = { version = "0.32", features = ["bundled"] } diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index b154408..9eb45d0 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -24,10 +24,31 @@ const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。"; const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"]; +const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4; +const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2; +const CONTEXT_INPUT_SAFETY_RATIO: f64 = 0.9; +const DEFAULT_COMPLETION_TOKEN_RESERVE: usize = 2048; +const DATA_URL_OVERHEAD_TOKENS: usize = 16; +const JPEG_QUALITY_STEPS: &[u8] = &[82, 72, 60, 48, 36]; +const MIN_COMPRESSED_IMAGE_SIDE: u32 = 64; +const IMAGE_INPUT_NOTICE_PREFIX: &str = "[系统提示] 以下图片未能成功入模:"; /// Build content blocks from text and media paths -fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { +fn build_content_blocks( + text: &str, + media_paths: &[String], + budget: &mut ImageInlineBudget, +) -> Vec { + build_content_blocks_with_image_budget(text, media_paths, budget) +} + +fn build_content_blocks_with_image_budget( + text: &str, + media_paths: &[String], + budget: &mut ImageInlineBudget, +) -> Vec { let mut blocks = Vec::new(); + let mut skipped_image_notices = Vec::new(); // Add text block if there's text if !text.is_empty() { @@ -41,12 +62,38 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec continue; } - if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) { - let url = format!("data:{};base64,{}", mime_type, base64_data); - blocks.push(ContentBlock::image_url(url)); + let Some(target_tokens) = budget.take_next_image_tokens() else { + tracing::warn!(media_path = %path, "Skipping image media ref because no LLM context budget remains"); + skipped_image_notices.push(format!( + "- {}:模型上下文预算不足,当前轮无法读取这张图片,请直接告知用户图片未成功入模。", + display_media_name(path) + )); + continue; + }; + + match encode_image_to_base64_with_budget(path, target_tokens) { + Ok((mime_type, base64_data)) => { + let url = format!("data:{};base64,{}", mime_type, base64_data); + blocks.push(ContentBlock::image_url(url)); + } + Err(err) => { + tracing::warn!(media_path = %path, target_tokens = target_tokens, error = %err, "Skipping image media ref after compression failed"); + skipped_image_notices.push(format!( + "- {}:图片压缩或编码失败,当前轮无法读取这张图片,请直接告知用户图片未成功入模。", + display_media_name(path) + )); + continue; + } } } + if !skipped_image_notices.is_empty() { + let mut notice = String::from(IMAGE_INPUT_NOTICE_PREFIX); + notice.push('\n'); + notice.push_str(&skipped_image_notices.join("\n")); + blocks.push(ContentBlock::text(notice)); + } + // If nothing, add empty text block if blocks.is_empty() { blocks.push(ContentBlock::text("")); @@ -66,6 +113,88 @@ fn supported_image_mime_type(path: &str) -> Option { } } +fn display_media_name(path: &str) -> String { + std::path::Path::new(path) + .file_name() + .and_then(|name| name.to_str()) + .map(ToOwned::to_owned) + .unwrap_or_else(|| path.to_string()) +} + +#[derive(Debug, Clone)] +struct ImageInlineBudget { + remaining_tokens: usize, + remaining_images: usize, +} + +impl ImageInlineBudget { + fn new(total_tokens: usize, image_count: usize) -> Self { + Self { + remaining_tokens: total_tokens, + remaining_images: image_count, + } + } + + fn take_next_image_tokens(&mut self) -> Option { + if self.remaining_tokens == 0 || self.remaining_images == 0 { + self.remaining_images = self.remaining_images.saturating_sub(1); + return None; + } + + let target = self.remaining_tokens.div_ceil(self.remaining_images); + self.remaining_tokens = self.remaining_tokens.saturating_sub(target); + self.remaining_images = self.remaining_images.saturating_sub(1); + Some(target) + } +} + +fn estimate_tokens_from_serialized_json(value: &T) -> usize { + let raw_len = serde_json::to_string(value) + .map(|serialized| serialized.len()) + .unwrap_or_default(); + ((raw_len.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64) * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) + as usize +} + +fn image_token_budget_for_request( + runtime_config: &AgentRuntimeConfig, + text_only_messages: &[Message], + tools: Option<&Vec>, +) -> usize { + let completion_reserve = runtime_config + .provider + .max_tokens + .map(|tokens| tokens as usize) + .unwrap_or(DEFAULT_COMPLETION_TOKEN_RESERVE); + let input_window = runtime_config + .context_window_tokens + .saturating_sub(completion_reserve); + let safe_input_window = (input_window as f64 * CONTEXT_INPUT_SAFETY_RATIO) as usize; + + let text_tokens = estimate_tokens_from_serialized_json(&text_only_messages) + + tools + .map(estimate_tokens_from_serialized_json) + .unwrap_or_default(); + + safe_input_window.saturating_sub(text_tokens) +} + +fn count_supported_image_media_refs(messages: &[ChatMessage]) -> usize { + messages + .iter() + .flat_map(|message| message.media_refs.iter()) + .filter(|path| supported_image_mime_type(path).is_some()) + .count() +} + +fn target_image_bytes_for_tokens(target_tokens: usize) -> usize { + target_tokens + .saturating_sub(DATA_URL_OVERHEAD_TOKENS) + .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN) + .saturating_mul(3) + / 4 +} + /// Encode an image file to base64 data URL fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { use base64::{Engine as _, engine::general_purpose::STANDARD}; @@ -85,6 +214,80 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error Ok((mime, encoded)) } +fn encode_image_to_base64_with_budget( + path: &str, + target_tokens: usize, +) -> Result<(String, String), std::io::Error> { + use base64::{Engine as _, engine::general_purpose::STANDARD}; + + let (mime, encoded) = encode_image_to_base64(path)?; + let target_base64_chars = target_tokens + .saturating_sub(DATA_URL_OVERHEAD_TOKENS) + .saturating_mul(TOKEN_ESTIMATE_CHARS_PER_TOKEN); + if encoded.len() <= target_base64_chars { + return Ok((mime, encoded)); + } + + let target_bytes = target_image_bytes_for_tokens(target_tokens); + let image_bytes = std::fs::read(path)?; + let image = image::load_from_memory(&image_bytes) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?; + let compressed = compress_image_to_jpeg_budget(&image, target_bytes)?; + + Ok(("image/jpeg".to_string(), STANDARD.encode(compressed))) +} + +fn compress_image_to_jpeg_budget( + image: &image::DynamicImage, + target_bytes: usize, +) -> Result, std::io::Error> { + use image::codecs::jpeg::JpegEncoder; + use image::imageops::FilterType; + + let mut best: Option> = None; + let mut max_side = image + .width() + .max(image.height()) + .max(MIN_COMPRESSED_IMAGE_SIDE); + + loop { + let candidate = if image.width().max(image.height()) > max_side { + image.resize(max_side, max_side, FilterType::Triangle) + } else { + image.clone() + }; + + for quality in JPEG_QUALITY_STEPS { + let mut encoded = Vec::new(); + let mut encoder = JpegEncoder::new_with_quality(&mut encoded, *quality); + encoder + .encode_image(&candidate) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?; + + if encoded.len() <= target_bytes { + return Ok(encoded); + } + + if best + .as_ref() + .map(|current| encoded.len() < current.len()) + .unwrap_or(true) + { + best = Some(encoded); + } + } + + if max_side <= MIN_COMPRESSED_IMAGE_SIDE { + break; + } + max_side = (max_side * 3 / 4).max(MIN_COMPRESSED_IMAGE_SIDE); + } + + best.ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "image compression failed") + }) +} + /// Truncate tool result if it exceeds the configured limit. /// Preserves the end of the output as it often contains the conclusion/useful result. fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String { @@ -272,11 +475,11 @@ fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value { } /// Convert ChatMessage to LLM Message format -fn chat_message_to_llm_message(m: &ChatMessage) -> Message { +fn chat_message_to_llm_message(m: &ChatMessage, image_budget: &mut ImageInlineBudget) -> Message { let content = if m.media_refs.is_empty() { vec![ContentBlock::text(&m.content)] } else { - build_content_blocks(&m.content, &m.media_refs) + build_content_blocks(&m.content, &m.media_refs, image_budget) }; Message { @@ -289,6 +492,17 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message { } } +fn chat_message_to_text_only_llm_message(m: &ChatMessage) -> Message { + Message { + role: m.role.clone(), + content: vec![ContentBlock::text(&m.content)], + reasoning_content: m.reasoning_content.clone(), + tool_call_id: m.tool_call_id.clone(), + name: m.tool_name.clone(), + tool_calls: m.tool_calls.clone(), + } +} + /// AgentLoop - Stateless agent that processes messages with tool calling support. /// History is managed externally by SessionManager. pub struct AgentLoop { @@ -434,14 +648,6 @@ impl AgentLoop { #[cfg(debug_assertions)] tracing::debug!(iteration, "Agent iteration started"); - // Convert messages to LLM format - let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 1); - if let Some(skill_prompt) = self.skills.system_index_prompt() { - messages_for_llm.push(Message::system(skill_prompt)); - } - messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT)); - messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message)); - // Build request let tool_defs = self.tools.get_definitions(); let tools = if tool_defs.is_empty() { @@ -450,6 +656,31 @@ impl AgentLoop { Some(tool_defs) }; + let image_count = count_supported_image_media_refs(&messages); + let mut text_only_messages: Vec = Vec::with_capacity(messages.len() + 2); + if let Some(skill_prompt) = self.skills.system_index_prompt() { + text_only_messages.push(Message::system(skill_prompt.clone())); + } + text_only_messages.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT)); + text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message)); + + let image_tokens = image_token_budget_for_request( + &self.runtime_config, + &text_only_messages, + tools.as_ref(), + ); + let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); + let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 2); + if let Some(skill_prompt) = self.skills.system_index_prompt() { + messages_for_llm.push(Message::system(skill_prompt)); + } + messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT)); + messages_for_llm.extend( + messages + .iter() + .map(|message| chat_message_to_llm_message(message, &mut image_budget)), + ); + let request = ChatCompletionRequest { messages: messages_for_llm, temperature: None, @@ -457,7 +688,6 @@ impl AgentLoop { tools, }; - // Call LLM let response = match (*self.provider).chat(request).await { Ok(response) => response, Err(e) => { @@ -633,11 +863,24 @@ impl AgentLoop { messages.push(summary_request); // Convert messages to LLM format + let image_count = count_supported_image_media_refs(&messages); + let mut text_only_messages: Vec = Vec::with_capacity(messages.len() + 1); + if let Some(skill_prompt) = self.skills.system_index_prompt() { + text_only_messages.push(Message::system(skill_prompt)); + } + text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message)); + let image_tokens = + image_token_budget_for_request(&self.runtime_config, &text_only_messages, None); + let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 1); if let Some(skill_prompt) = self.skills.system_index_prompt() { messages_for_llm.push(Message::system(skill_prompt)); } - messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message)); + messages_for_llm.extend( + messages + .iter() + .map(|message| chat_message_to_llm_message(message, &mut image_budget)), + ); let request = ChatCompletionRequest { messages: messages_for_llm, @@ -896,7 +1139,8 @@ mod tests { }], ); - let provider_message = chat_message_to_llm_message(&chat_message); + let mut image_budget = ImageInlineBudget::new(0, 0); + let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget); assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); @@ -915,7 +1159,8 @@ mod tests { let chat_message = ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought"); - let provider_message = chat_message_to_llm_message(&chat_message); + let mut image_budget = ImageInlineBudget::new(0, 0); + let provider_message = chat_message_to_llm_message(&chat_message, &mut image_budget); assert_eq!(provider_message.role, "assistant"); assert_eq!( @@ -983,7 +1228,12 @@ mod tests { let pdf_path = temp_dir.path().join("demo.pdf"); std::fs::write(&pdf_path, b"%PDF-1.4").unwrap(); - let blocks = build_content_blocks("hello", &[pdf_path.to_string_lossy().to_string()]); + let mut budget = ImageInlineBudget::new(1_000, 0); + let blocks = build_content_blocks( + "hello", + &[pdf_path.to_string_lossy().to_string()], + &mut budget, + ); assert_eq!(blocks.len(), 1); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); @@ -993,9 +1243,15 @@ mod tests { fn test_build_content_blocks_keeps_supported_images() { let temp_dir = tempdir().unwrap(); let jpg_path = temp_dir.path().join("demo.jpg"); - std::fs::write(&jpg_path, b"fake-jpeg-data").unwrap(); + let image = image::DynamicImage::new_rgb8(8, 8); + image.save(&jpg_path).unwrap(); - let blocks = build_content_blocks("hello", &[jpg_path.to_string_lossy().to_string()]); + let mut budget = ImageInlineBudget::new(10_000, 1); + let blocks = build_content_blocks( + "hello", + &[jpg_path.to_string_lossy().to_string()], + &mut budget, + ); assert_eq!(blocks.len(), 2); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); @@ -1003,6 +1259,50 @@ mod tests { matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")) ); } + + #[test] + fn test_build_content_blocks_compresses_images_to_budget() { + let temp_dir = tempdir().unwrap(); + let png_path = temp_dir.path().join("large.png"); + let image = image::DynamicImage::new_rgb8(512, 512); + image.save(&png_path).unwrap(); + + let mut budget = ImageInlineBudget::new(512, 1); + let blocks = build_content_blocks( + "hello", + &[png_path.to_string_lossy().to_string()], + &mut budget, + ); + + assert_eq!(blocks.len(), 2); + assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); + assert!( + matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")) + ); + } + + #[test] + fn test_build_content_blocks_adds_user_visible_notice_when_image_cannot_be_sent() { + let temp_dir = tempdir().unwrap(); + let jpg_path = temp_dir.path().join("demo.jpg"); + let image = image::DynamicImage::new_rgb8(8, 8); + image.save(&jpg_path).unwrap(); + + let mut budget = ImageInlineBudget::new(0, 1); + let blocks = build_content_blocks( + "hello", + &[jpg_path.to_string_lossy().to_string()], + &mut budget, + ); + + assert_eq!(blocks.len(), 2); + assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); + assert!(matches!( + &blocks[1], + ContentBlock::Text { text } + if text.contains("图片未能成功入模") && text.contains("demo.jpg") + )); + } } #[derive(Debug)] diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index c70cbf6..b3c88fe 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -8,13 +8,34 @@ use crate::text::{char_count, take_prefix_chars}; use crate::agent::{AgentError, AgentRuntimeConfig}; +const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4; +const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2; + /// Token estimation using ~4 chars/token heuristic with 1.2x safety margin. pub fn estimate_tokens(messages: &[ChatMessage]) -> usize { let raw: usize = messages .iter() - .map(|m| m.content.len().div_ceil(4) + 4) + .map(|message| { + message + .content + .len() + .div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) + + estimate_image_tokens(&message.media_refs) + + 4 + }) .sum(); - (raw as f64 * 1.2) as usize + (raw as f64 * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize +} + +fn estimate_image_tokens(media_refs: &[String]) -> usize { + media_refs + .iter() + .filter_map(|path| std::fs::metadata(path).ok()) + .map(|metadata| { + let base64_chars = metadata.len().saturating_mul(4).div_ceil(3) as usize; + base64_chars.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) + }) + .sum() } /// Configuration for context compression. @@ -492,6 +513,21 @@ mod tests { ); } + #[test] + fn test_estimate_tokens_includes_image_media_refs() { + let temp_dir = tempfile::tempdir().unwrap(); + let image_path = temp_dir.path().join("demo.jpg"); + std::fs::write(&image_path, vec![0_u8; 12_000]).unwrap(); + + let plain = vec![ChatMessage::user("hello")]; + let with_image = vec![ChatMessage::user_with_media( + "hello", + vec![image_path.to_string_lossy().to_string()], + )]; + + assert!(estimate_tokens(&with_image) > estimate_tokens(&plain)); + } + #[test] fn test_should_compress() { let compressor = ContextCompressor::new(20); diff --git a/src/gateway/session.rs b/src/gateway/session.rs index e733120..27b3055 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -494,11 +494,12 @@ impl SessionManager { mod tests { use super::*; use crate::bus::MessageBus; + use crate::gateway::tool_registry_factory::ToolRegistryFactory; use crate::storage::MemoryRecord; use axum::http::StatusCode; use axum::{Json, Router, routing::post}; use serde_json::{Value, json}; - use std::collections::HashMap; + use std::collections::{HashMap, HashSet}; use std::sync::{ Arc as StdArc, atomic::{AtomicUsize, Ordering}, @@ -534,6 +535,8 @@ mod tests { ToolRegistryFactory::new( skills.clone(), store.clone(), + store.clone(), + store.clone(), HashSet::new(), "Asia/Shanghai".to_string(), ) @@ -576,6 +579,8 @@ mod tests { ToolRegistryFactory::new( skills.clone(), store.clone(), + store.clone(), + store.clone(), HashSet::new(), "Asia/Shanghai".to_string(), ) @@ -1476,6 +1481,8 @@ mod tests { ToolRegistryFactory::new( skills.clone(), store.clone(), + store.clone(), + store.clone(), HashSet::new(), "Asia/Shanghai".to_string(), ) @@ -1511,6 +1518,8 @@ mod tests { ToolRegistryFactory::new( skills.clone(), store.clone(), + store.clone(), + store.clone(), HashSet::new(), "Asia/Shanghai".to_string(), ) @@ -1574,6 +1583,8 @@ mod tests { ToolRegistryFactory::new( skills.clone(), store.clone(), + store.clone(), + store.clone(), HashSet::new(), "Asia/Shanghai".to_string(), ) @@ -1616,9 +1627,15 @@ mod tests { fn test_default_tools_registers_get_time() { let skills = Arc::new(SkillRuntime::default()); let store = Arc::new(SessionStore::in_memory().unwrap()); - let tools = - ToolRegistryFactory::new(skills, store, HashSet::new(), "Asia/Shanghai".to_string()) - .build(); + let tools = ToolRegistryFactory::new( + skills, + store.clone(), + store.clone(), + store, + HashSet::new(), + "Asia/Shanghai".to_string(), + ) + .build(); assert!(tools.tool_names().iter().any(|name| name == "get_time")); } diff --git a/tests/test_integration.rs b/tests/test_integration.rs index 8b0892a..f27358f 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -1,7 +1,22 @@ use picobot::config::{Config, LLMProviderConfig}; -use picobot::providers::{ChatCompletionRequest, Message, create_provider}; +use picobot::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider}; use std::collections::HashMap; +fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig { + ProviderRuntimeConfig { + provider_type: config.provider_type, + name: config.name, + base_url: config.base_url, + api_key: config.api_key, + extra_headers: config.extra_headers, + llm_timeout_secs: config.llm_timeout_secs, + model_id: config.model_id, + temperature: config.temperature, + max_tokens: config.max_tokens, + model_extra: config.model_extra, + } +} + fn load_config() -> Option { dotenv::from_filename("tests/test.env").ok()?; @@ -45,7 +60,7 @@ fn create_request(content: &str) -> ChatCompletionRequest { async fn test_openai_simple_completion() { let config = load_config().expect("Please configure tests/test.env with valid API keys"); - let provider = create_provider(config).expect("Failed to create provider"); + let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider"); let response = provider.chat(create_request("Say 'ok'")).await.unwrap(); assert!(!response.id.is_empty()); @@ -59,7 +74,7 @@ async fn test_openai_simple_completion() { async fn test_openai_conversation() { let config = load_config().expect("Please configure tests/test.env with valid API keys"); - let provider = create_provider(config).expect("Failed to create provider"); + let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider"); let request = ChatCompletionRequest { messages: vec![ @@ -89,7 +104,8 @@ async fn test_config_load() { assert_eq!(provider_config.name, "aliyun"); assert_eq!(provider_config.model_id, "qwen-plus"); - let provider = create_provider(provider_config).expect("Failed to create provider"); + let provider = + create_provider(to_runtime_config(provider_config)).expect("Failed to create provider"); assert_eq!(provider.ptype(), "openai"); assert_eq!(provider.name(), "aliyun"); assert_eq!(provider.model_id(), "qwen-plus"); diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index 7084cc1..bfe7b9c 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -1,7 +1,24 @@ use picobot::config::LLMProviderConfig; -use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider}; +use picobot::providers::{ + ChatCompletionRequest, Message, ProviderRuntimeConfig, Tool, ToolFunction, create_provider, +}; use std::collections::HashMap; +fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig { + ProviderRuntimeConfig { + provider_type: config.provider_type, + name: config.name, + base_url: config.base_url, + api_key: config.api_key, + extra_headers: config.extra_headers, + llm_timeout_secs: config.llm_timeout_secs, + model_id: config.model_id, + temperature: config.temperature, + max_tokens: config.max_tokens, + model_extra: config.model_extra, + } +} + fn load_openai_config() -> Option { dotenv::from_filename("tests/test.env").ok()?; @@ -56,7 +73,7 @@ fn make_weather_tool() -> Tool { async fn test_openai_tool_call() { let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); - let provider = create_provider(config).expect("Failed to create provider"); + let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider"); let request = ChatCompletionRequest { messages: vec![Message::user("What is the weather in Tokyo?")], @@ -84,7 +101,7 @@ async fn test_openai_tool_call() { async fn test_openai_tool_call_with_manual_execution() { let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); - let provider = create_provider(config).expect("Failed to create provider"); + let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider"); // First request with tool let request1 = ChatCompletionRequest { @@ -120,7 +137,7 @@ async fn test_openai_tool_call_with_manual_execution() { async fn test_openai_no_tool_when_not_provided() { let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); - let provider = create_provider(config).expect("Failed to create provider"); + let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider"); let request = ChatCompletionRequest { messages: vec![Message::user("Say hello in one word.")],