From f0879f8d1370842a6221424adc16f62e0bbbf7ac Mon Sep 17 00:00:00 2001 From: xiaoski Date: Mon, 25 May 2026 11:59:29 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=96=87=E4=BB=B6=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E5=A4=84=E7=90=86=E6=B5=81=E7=A8=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- resources/skills/about-picobot/SKILL.md | 2 +- .../about-picobot/assets/config.example.json | 6 +- .../skills/about-picobot/references/config.md | 1 + resources/templates/config.example.json | 6 +- src/agent/agent_loop.rs | 137 +++++++++++------- src/agent/media_handler.rs | 101 +++++++++++++ src/agent/mod.rs | 1 + src/bus/message.rs | 21 ++- src/bus/mod.rs | 2 +- src/channels/feishu.rs | 83 ++++++++--- src/config/mod.rs | 8 + src/session/session.rs | 13 +- tests/test_integration.rs | 2 + tests/test_tool_calling.rs | 2 + 14 files changed, 301 insertions(+), 84 deletions(-) create mode 100644 src/agent/media_handler.rs diff --git a/resources/skills/about-picobot/SKILL.md b/resources/skills/about-picobot/SKILL.md index 8b1127d..c06e1da 100644 --- a/resources/skills/about-picobot/SKILL.md +++ b/resources/skills/about-picobot/SKILL.md @@ -1,6 +1,6 @@ --- name: about-picobot -description: PicoBot 自身设计信息的索引入口。含配置、数据库、架构、常见问题等。具体内容在 references/ 目录下,config 示例在 assets/ 目录下,请用 file_read 工具查阅对应文件。 +description: PicoBot 自身设计信息的索引入口。含配置、数据库、架构、常见问题等,如需要修改自身配置或了解自身工作机制加载查询。具体内容在 references/ 目录下,config 示例在 assets/ 目录下,请用 file_read 工具查阅对应文件。 always: true --- # About PicoBot diff --git a/resources/skills/about-picobot/assets/config.example.json b/resources/skills/about-picobot/assets/config.example.json index 1167c4b..92f247e 100644 --- a/resources/skills/about-picobot/assets/config.example.json +++ b/resources/skills/about-picobot/assets/config.example.json @@ -28,12 +28,14 @@ "gpt-4o": { "model_id": "gpt-4o", "temperature": 0.7, - "max_tokens": 4096 + "max_tokens": 4096, + "input_type": ["text", "image"] }, "claude-sonnet-4-20250514": { "model_id": "claude-sonnet-4-20250514", "temperature": 0.7, - "max_tokens": 8192 + "max_tokens": 8192, + "input_type": ["text", "image"] } }, "agents": { diff --git a/resources/skills/about-picobot/references/config.md b/resources/skills/about-picobot/references/config.md index 31a1939..e9c3c6a 100644 --- a/resources/skills/about-picobot/references/config.md +++ b/resources/skills/about-picobot/references/config.md @@ -37,6 +37,7 @@ | `model_id` | 模型标识名称 | | `temperature` | 采样温度,可选 | | `max_tokens` | 最大输出 token 数,可选 | +| `input_type` | 模型支持的输入类型,如 `["text"]` 或 `["text", "image"]`,默认 `["text"]`. 纯内部使用,不会传递给 LLM API | ## agents 字段 diff --git a/resources/templates/config.example.json b/resources/templates/config.example.json index 1167c4b..92f247e 100644 --- a/resources/templates/config.example.json +++ b/resources/templates/config.example.json @@ -28,12 +28,14 @@ "gpt-4o": { "model_id": "gpt-4o", "temperature": 0.7, - "max_tokens": 4096 + "max_tokens": 4096, + "input_type": ["text", "image"] }, "claude-sonnet-4-20250514": { "model_id": "claude-sonnet-4-20250514", "temperature": 0.7, - "max_tokens": 8192 + "max_tokens": 8192, + "input_type": ["text", "image"] } }, "agents": { diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index ea9bec2..5eaffd3 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,7 +1,8 @@ use crate::agent::context_compressor::estimate_tokens; +use crate::agent::media_handler::MediaHandlerRegistry; use crate::agent::system_prompt::build_system_prompt; use crate::bus::message::ContentBlock; -use crate::bus::ChatMessage; +use crate::bus::{ChatMessage, MediaRef}; use crate::config::LLMProviderConfig; use crate::observability::{ truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, @@ -10,7 +11,6 @@ use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Mess use crate::tools::ToolRegistry; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; -use std::io::Read; use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; @@ -21,24 +21,50 @@ const MAX_TOOL_RESULT_CHARS: usize = 16_000; /// Minimum characters to keep when truncating const TRUNCATION_SUFFIX_LEN: usize = 200; -/// Build content blocks from text and media paths -fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { +/// Build content blocks from text and media, respecting model input capabilities +fn build_content_blocks( + text: &str, + media_refs: &[MediaRef], + input_types: &[String], + registry: &MediaHandlerRegistry, +) -> Vec { let mut blocks = Vec::new(); - // Add text block if there's text - if !text.is_empty() { + if !media_refs.is_empty() { + for mr in media_refs { + if input_types.contains(&mr.media_type) { + match registry.handle(&mr.media_type, &mr.path) { + Ok(content_blocks) => blocks.extend(content_blocks), + Err(e) => { + tracing::warn!( + path = %mr.path, + media_type = %mr.media_type, + error = %e, + "Media handler failed, falling back to text placeholder" + ); + blocks.push(ContentBlock::text(format!( + "[用户发来了一个文件,但处理失败: {}, 错误: {}]", + mr.path, e + ))); + } + } + } else { + tracing::debug!( + path = %mr.path, + media_type = %mr.media_type, + model_input_types = ?input_types, + "Media type not supported by model, using text placeholder" + ); + blocks.push(ContentBlock::text(format!( + "[用户发来了一个文件: {}]", + mr.path + ))); + } + } + } else if !text.is_empty() { blocks.push(ContentBlock::text(text)); } - // Add image blocks for media paths - for path in media_paths { - if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) { - let url = format!("data:{};base64,{}", mime_type, base64_data); - blocks.push(ContentBlock::image_url(url)); - } - } - - // If nothing, add empty text block if blocks.is_empty() { blocks.push(ContentBlock::text("")); } @@ -46,22 +72,6 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec blocks } -/// Encode an image file to base64 data URL -fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { - use base64::{Engine as _, engine::general_purpose::STANDARD}; - - let mut file = std::fs::File::open(path)?; - let mut buffer = Vec::new(); - file.read_to_end(&mut buffer)?; - - let mime = mime_guess::from_path(path) - .first_or_octet_stream() - .to_string(); - - let encoded = STANDARD.encode(&buffer); - Ok((mime, encoded)) -} - /// Truncate tool result if it exceeds MAX_TOOL_RESULT_CHARS. /// Preserves the end of the output as it often contains the conclusion/useful result. fn truncate_tool_result(output: &str) -> String { @@ -200,24 +210,6 @@ fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value { other => other.clone(), } } - -/// Convert ChatMessage to LLM Message format -fn chat_message_to_llm_message(m: &ChatMessage) -> Message { - let content = if m.media_refs.is_empty() { - vec![ContentBlock::text(&m.content)] - } else { - build_content_blocks(&m.content, &m.media_refs) - }; - - Message { - role: m.role.clone(), - content, - tool_call_id: m.tool_call_id.clone(), - name: m.tool_name.clone(), - tool_calls: m.tool_calls.clone(), - } -} - /// AgentLoop - Stateless agent that processes messages with tool calling support. /// History is managed externally by SessionManager. pub struct AgentLoop { @@ -229,8 +221,9 @@ pub struct AgentLoop { model_name: String, context_window: usize, notify_tx: Option>, + input_types: Vec, + media_registry: MediaHandlerRegistry, } - #[derive(Debug, Clone)] pub struct AgentProcessResult { pub final_response: ChatMessage, @@ -243,6 +236,7 @@ impl AgentLoop { let max_iterations = provider_config.max_tool_iterations; let model_name = provider_config.model_id.clone(); let workspace_dir = provider_config.workspace_dir.clone(); + let input_types = provider_config.input_types.clone(); let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; @@ -255,6 +249,8 @@ impl AgentLoop { max_iterations, workspace_dir, model_name, + input_types, + media_registry: MediaHandlerRegistry::with_defaults(), }) } @@ -263,6 +259,7 @@ impl AgentLoop { let max_iterations = provider_config.max_tool_iterations; let model_name = provider_config.model_id.clone(); let workspace_dir = provider_config.workspace_dir.clone(); + let input_types = provider_config.input_types.clone(); let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; @@ -275,11 +272,13 @@ impl AgentLoop { max_iterations, workspace_dir, model_name, + input_types, + media_registry: MediaHandlerRegistry::with_defaults(), }) } /// Create a new AgentLoop with an existing shared provider. - pub fn with_provider(provider: Arc, max_iterations: usize, model_name: String, workspace_dir: PathBuf) -> Self { + pub fn with_provider(provider: Arc, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec) -> Self { Self { provider, tools: Arc::new(ToolRegistry::new()), @@ -289,6 +288,8 @@ impl AgentLoop { max_iterations, workspace_dir, model_name, + input_types, + media_registry: MediaHandlerRegistry::with_defaults(), } } @@ -299,6 +300,7 @@ impl AgentLoop { max_iterations: usize, model_name: String, workspace_dir: PathBuf, + input_types: Vec, ) -> Self { Self { provider, @@ -309,6 +311,8 @@ impl AgentLoop { max_iterations, workspace_dir, model_name, + input_types, + media_registry: MediaHandlerRegistry::with_defaults(), } } @@ -369,6 +373,22 @@ impl AgentLoop { &self.tools } + fn chat_message_to_llm_message(&self, m: &ChatMessage) -> Message { + let content = if m.media_refs.is_empty() { + vec![ContentBlock::text(&m.content)] + } else { + build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry) + }; + + Message { + role: m.role.clone(), + content, + tool_call_id: m.tool_call_id.clone(), + name: m.tool_name.clone(), + tool_calls: m.tool_calls.clone(), + } + } + /// Process a message using the provided conversation history. /// History management is handled externally by SessionManager. /// @@ -421,7 +441,7 @@ impl AgentLoop { // Convert messages to LLM format let messages_for_llm: Vec = messages .iter() - .map(chat_message_to_llm_message) + .map(|m| self.chat_message_to_llm_message(m)) .collect(); // Build request @@ -549,7 +569,7 @@ impl AgentLoop { // Convert messages to LLM format let messages_for_llm: Vec = messages .iter() - .map(chat_message_to_llm_message) + .map(|m| self.chat_message_to_llm_message(m)) .collect(); let request = ChatCompletionRequest { @@ -766,6 +786,8 @@ mod tests { #[test] fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() { + use crate::providers::Message; + let chat_message = ChatMessage::assistant_with_tool_calls( "calling tool", vec![ToolCall { @@ -775,7 +797,14 @@ mod tests { }], ); - let provider_message = chat_message_to_llm_message(&chat_message); + let content = vec![ContentBlock::text(&chat_message.content)]; + let provider_message = Message { + role: chat_message.role.clone(), + content, + tool_call_id: chat_message.tool_call_id.clone(), + name: chat_message.tool_name.clone(), + tool_calls: chat_message.tool_calls.clone(), + }; assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); diff --git a/src/agent/media_handler.rs b/src/agent/media_handler.rs new file mode 100644 index 0000000..6a93f8d --- /dev/null +++ b/src/agent/media_handler.rs @@ -0,0 +1,101 @@ +use std::collections::HashMap; +use std::fmt; +use std::io::Read; + +use crate::bus::message::ContentBlock; + +pub trait MediaHandler: Send + Sync { + fn media_type(&self) -> &str; + fn handle(&self, path: &str) -> Result, MediaHandlerError>; +} + +#[derive(Debug)] +pub enum MediaHandlerError { + Io(std::io::Error), + UnsupportedFormat(String), +} + +impl fmt::Display for MediaHandlerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MediaHandlerError::Io(e) => write!(f, "I/O error: {}", e), + MediaHandlerError::UnsupportedFormat(msg) => write!(f, "Unsupported format: {}", msg), + } + } +} + +impl std::error::Error for MediaHandlerError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + MediaHandlerError::Io(e) => Some(e), + MediaHandlerError::UnsupportedFormat(_) => None, + } + } +} + +pub struct ImageHandler; + +impl MediaHandler for ImageHandler { + fn media_type(&self) -> &str { + "image" + } + + fn handle(&self, path: &str) -> Result, MediaHandlerError> { + let (mime_type, base64_data) = + encode_image_to_base64(path).map_err(MediaHandlerError::Io)?; + let url = format!("data:{};base64,{}", mime_type, base64_data); + Ok(vec![ContentBlock::image_url(url)]) + } +} + +fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { + use base64::{engine::general_purpose::STANDARD, Engine as _}; + + let mut file = std::fs::File::open(path)?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer)?; + + let mime = mime_guess::from_path(path) + .first_or_octet_stream() + .to_string(); + + let encoded = STANDARD.encode(&buffer); + Ok((mime, encoded)) +} + +pub struct MediaHandlerRegistry { + handlers: HashMap>, +} + +impl MediaHandlerRegistry { + pub fn new() -> Self { + Self { + handlers: HashMap::new(), + } + } + + pub fn register(&mut self, handler: Box) { + self.handlers + .insert(handler.media_type().to_string(), handler); + } + + pub fn handle( + &self, + media_type: &str, + path: &str, + ) -> Result, MediaHandlerError> { + match self.handlers.get(media_type) { + Some(handler) => handler.handle(path), + None => Err(MediaHandlerError::UnsupportedFormat(format!( + "no handler for type: {}", + media_type + ))), + } + } + + pub fn with_defaults() -> Self { + let mut reg = Self::new(); + reg.register(Box::new(ImageHandler)); + reg + } +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 2153826..5319eeb 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,5 +1,6 @@ pub mod agent_loop; pub mod context_compressor; +pub mod media_handler; pub mod system_prompt; pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; diff --git a/src/bus/message.rs b/src/bus/message.rs index ded2b00..1981fb2 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -33,6 +33,16 @@ impl ContentBlock { } } +// ============================================================================ +// MediaRef - Media reference in ChatMessage (carries type info) +// ============================================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MediaRef { + pub path: String, + pub media_type: String, +} + // ============================================================================ // MediaItem - Media metadata for messages // ============================================================================ @@ -54,6 +64,13 @@ impl MediaItem { original_key: None, } } + + pub fn to_media_ref(&self) -> MediaRef { + MediaRef { + path: self.path.clone(), + media_type: self.media_type.clone(), + } + } } // ============================================================================ @@ -65,7 +82,7 @@ pub struct ChatMessage { pub id: String, pub role: String, pub content: String, - pub media_refs: Vec, // Paths to media files for context + pub media_refs: Vec, pub timestamp: i64, #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, @@ -112,7 +129,7 @@ impl ChatMessage { } } - pub fn user_with_media(content: impl Into, media_refs: Vec) -> Self { + pub fn user_with_media(content: impl Into, media_refs: Vec) -> Self { Self { id: uuid::Uuid::new_v4().to_string(), role: "user".to_string(), diff --git a/src/bus/mod.rs b/src/bus/mod.rs index 70be932..6c46feb 100644 --- a/src/bus/mod.rs +++ b/src/bus/mod.rs @@ -2,7 +2,7 @@ pub mod dispatcher; pub mod message; pub use dispatcher::OutboundDispatcher; -pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; +pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind}; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs index 8a78f20..fa9b555 100644 --- a/src/channels/feishu.rs +++ b/src/channels/feishu.rs @@ -366,13 +366,7 @@ impl FeishuChannel { .unwrap_or("image/jpeg") .to_string(); - let ext = match content_type.as_str() { - "image/png" => "png", - "image/gif" => "gif", - "image/webp" => "webp", - "image/bmp" => "bmp", - _ => "jpg", - }; + let ext = resolve_image_ext(&content_type); let data = resp.bytes().await .map_err(|e| ChannelError::Other(format!("Failed to read image data: {}", e)))? @@ -382,7 +376,7 @@ impl FeishuChannel { tracing::debug!(data_len = %data.len(), content_type = %content_type, "Downloaded image data"); let filename = format!("{}_{}.{}", message_id, &image_key[..8.min(image_key.len())], ext); - let file_path = media_dir.join(&filename); + let file_path = resolve_unique_path(media_dir, &filename).await; tokio::fs::write(&file_path, &data).await .map_err(|e| ChannelError::Other(format!("Failed to write image: {}", e)))?; @@ -394,7 +388,7 @@ impl FeishuChannel { tracing::info!(message_id = %message_id, filename = %filename, "Downloaded image"); - Ok((format!("[image: {}]", filename), Some(media_item))) + Ok((String::new(), Some(media_item))) } /// Download file/audio from Feishu @@ -434,13 +428,19 @@ impl FeishuChannel { .map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))? .to_vec(); - let extension = match file_type { - "audio" => "mp3", - "video" => "mp4", - _ => "bin", - }; - let filename = format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], extension); - let file_path = media_dir.join(&filename); + let filename = content_json + .get("file_name") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| { + let ext = resolve_file_ext(content_json); + if ext.is_empty() { + format!("{}_{}", message_id, &file_key[..8.min(file_key.len())]) + } else { + format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], ext) + } + }); + let file_path = resolve_unique_path(media_dir, &filename).await; tokio::fs::write(&file_path, &data).await .map_err(|e| ChannelError::Other(format!("Failed to write file: {}", e)))?; @@ -452,7 +452,7 @@ impl FeishuChannel { tracing::info!(message_id = %message_id, filename = %filename, file_type = %file_type, "Downloaded file"); - Ok((format!("[{}: {}]", file_type, filename), Some(media_item))) + Ok((String::new(), Some(media_item))) } /// Upload image to Feishu and return the image_key @@ -1543,6 +1543,55 @@ fn strip_at_placeholders(text: &str) -> String { result } +fn resolve_image_ext(content_type: &str) -> &str { + match content_type { + "image/png" => "png", + "image/gif" => "gif", + "image/webp" => "webp", + "image/bmp" => "bmp", + _ => "jpg", + } +} + +fn resolve_file_ext(content_json: &serde_json::Value) -> String { + if let Some(name) = content_json + .get("file_name") + .and_then(|v| v.as_str()) + { + if let Some(ext) = std::path::Path::new(name).extension().and_then(|e| e.to_str()) { + return ext.to_string(); + } + } + String::new() +} + +async fn resolve_unique_path(dir: &Path, filename: &str) -> std::path::PathBuf { + let candidate = dir.join(filename); + if !tokio::fs::try_exists(&candidate).await.unwrap_or(false) { + return candidate; + } + let stem = std::path::Path::new(filename) + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or(filename); + let ext = std::path::Path::new(filename) + .extension() + .and_then(|s| s.to_str()) + .unwrap_or(""); + let mut n = 1; + loop { + let candidate = if ext.is_empty() { + dir.join(format!("{}({})", stem, n)) + } else { + dir.join(format!("{}({}).{}", stem, n, ext)) + }; + if !tokio::fs::try_exists(&candidate).await.unwrap_or(false) { + return candidate; + } + n += 1; + } +} + impl FeishuChannel { fn strip_thinking_tags(content: &str) -> String { use std::sync::LazyLock; diff --git a/src/config/mod.rs b/src/config/mod.rs index 7c3bc14..af5fa14 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -112,10 +112,16 @@ pub struct ModelConfig { pub temperature: Option, #[serde(default)] pub max_tokens: Option, + #[serde(default = "default_input_type")] + pub input_type: Vec, #[serde(flatten)] pub extra: HashMap, } +fn default_input_type() -> Vec { + vec!["text".to_string()] +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AgentConfig { pub provider: String, @@ -376,6 +382,7 @@ pub struct LLMProviderConfig { pub max_tool_iterations: usize, pub token_limit: usize, pub workspace_dir: PathBuf, + pub input_types: Vec, } fn get_default_config_path() -> PathBuf { @@ -443,6 +450,7 @@ impl Config { max_tool_iterations: agent.max_tool_iterations, token_limit: agent.token_limit, workspace_dir: expand_path(&self.workspace_dir), + input_types: model.input_type.clone(), }) } } diff --git a/src/session/session.rs b/src/session/session.rs index d677581..fb29557 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use tokio::sync::{Mutex, mpsc, oneshot}; -use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; +use crate::bus::{ChatMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind}; use crate::mcp::get_mcp_status; use crate::storage::{Storage, StorageError}; use std::sync::Arc as StdArc; @@ -361,7 +361,7 @@ impl Session { tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory"); } - pub fn create_user_message(&self, content: &str, media_refs: Vec) -> ChatMessage { + pub fn create_user_message(&self, content: &str, media_refs: Vec) -> ChatMessage { if media_refs.is_empty() { ChatMessage::user(content) } else { @@ -372,7 +372,7 @@ impl Session { pub fn create_user_message_with_source( &self, content: &str, - media_refs: Vec, + media_refs: Vec, source: MessageSource, ) -> ChatMessage { if media_refs.is_empty() { @@ -486,6 +486,7 @@ impl Session { self.provider_config.max_tool_iterations, self.provider_config.model_id.clone(), self.provider_config.workspace_dir.clone(), + self.provider_config.input_types.clone(), ).with_context_window(self.provider_config.token_limit)) } @@ -889,6 +890,7 @@ impl SessionManager { self.provider_config.max_tool_iterations, self.provider_config.model_id.clone(), self.provider_config.workspace_dir.clone(), + self.provider_config.input_types.clone(), ).with_context_window(self.provider_config.token_limit)) } @@ -1501,8 +1503,8 @@ fn spawn_agent_worker( return; // stale worker } - let media_refs: Vec = - task.media.iter().map(|m| m.path.clone()).collect(); + let media_refs: Vec = + task.media.iter().map(|m| m.to_media_ref()).collect(); let user_message = guard.create_user_message(&task.content, media_refs); if let Err(e) = guard.add_message(user_message, true).await { @@ -1902,6 +1904,7 @@ mod tests { max_tool_iterations: 1, token_limit: 4096, workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"), + input_types: vec!["text".to_string()], } } } diff --git a/tests/test_integration.rs b/tests/test_integration.rs index 09f705e..dc0ecd9 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -25,6 +25,8 @@ fn load_config() -> Option { model_extra: HashMap::new(), max_tool_iterations: 20, token_limit: 128_000, + workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"), + input_types: vec!["text".to_string()], }) } diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index 39ead1d..86e565d 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -25,6 +25,8 @@ fn load_openai_config() -> Option { model_extra: HashMap::new(), max_tool_iterations: 20, token_limit: 128_000, + workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"), + input_types: vec!["text".to_string()], }) }