重构文件消息处理流程。

This commit is contained in:
xiaoski 2026-05-25 11:59:29 +08:00
parent 22be6e404b
commit f0879f8d13
14 changed files with 301 additions and 84 deletions

View File

@ -1,6 +1,6 @@
--- ---
name: about-picobot name: about-picobot
description: PicoBot 自身设计信息的索引入口。含配置、数据库、架构、常见问题等。具体内容在 references/ 目录下config 示例在 assets/ 目录下,请用 file_read 工具查阅对应文件。 description: PicoBot 自身设计信息的索引入口。含配置、数据库、架构、常见问题等,如需要修改自身配置或了解自身工作机制加载查询。具体内容在 references/ 目录下config 示例在 assets/ 目录下,请用 file_read 工具查阅对应文件。
always: true always: true
--- ---
# About PicoBot # About PicoBot

View File

@ -28,12 +28,14 @@
"gpt-4o": { "gpt-4o": {
"model_id": "gpt-4o", "model_id": "gpt-4o",
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 4096 "max_tokens": 4096,
"input_type": ["text", "image"]
}, },
"claude-sonnet-4-20250514": { "claude-sonnet-4-20250514": {
"model_id": "claude-sonnet-4-20250514", "model_id": "claude-sonnet-4-20250514",
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 8192 "max_tokens": 8192,
"input_type": ["text", "image"]
} }
}, },
"agents": { "agents": {

View File

@ -37,6 +37,7 @@
| `model_id` | 模型标识名称 | | `model_id` | 模型标识名称 |
| `temperature` | 采样温度,可选 | | `temperature` | 采样温度,可选 |
| `max_tokens` | 最大输出 token 数,可选 | | `max_tokens` | 最大输出 token 数,可选 |
| `input_type` | 模型支持的输入类型,如 `["text"]``["text", "image"]`,默认 `["text"]`. 纯内部使用,不会传递给 LLM API |
## agents 字段 ## agents 字段

View File

@ -28,12 +28,14 @@
"gpt-4o": { "gpt-4o": {
"model_id": "gpt-4o", "model_id": "gpt-4o",
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 4096 "max_tokens": 4096,
"input_type": ["text", "image"]
}, },
"claude-sonnet-4-20250514": { "claude-sonnet-4-20250514": {
"model_id": "claude-sonnet-4-20250514", "model_id": "claude-sonnet-4-20250514",
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 8192 "max_tokens": 8192,
"input_type": ["text", "image"]
} }
}, },
"agents": { "agents": {

View File

@ -1,7 +1,8 @@
use crate::agent::context_compressor::estimate_tokens; use crate::agent::context_compressor::estimate_tokens;
use crate::agent::media_handler::MediaHandlerRegistry;
use crate::agent::system_prompt::build_system_prompt; use crate::agent::system_prompt::build_system_prompt;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use crate::bus::ChatMessage; use crate::bus::{ChatMessage, MediaRef};
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::observability::{ use crate::observability::{
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
@ -10,7 +11,6 @@ use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Mess
use crate::tools::ToolRegistry; use crate::tools::ToolRegistry;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::io::Read;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -21,24 +21,50 @@ const MAX_TOOL_RESULT_CHARS: usize = 16_000;
/// Minimum characters to keep when truncating /// Minimum characters to keep when truncating
const TRUNCATION_SUFFIX_LEN: usize = 200; const TRUNCATION_SUFFIX_LEN: usize = 200;
/// Build content blocks from text and media paths /// Build content blocks from text and media, respecting model input capabilities
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> { fn build_content_blocks(
text: &str,
media_refs: &[MediaRef],
input_types: &[String],
registry: &MediaHandlerRegistry,
) -> Vec<ContentBlock> {
let mut blocks = Vec::new(); let mut blocks = Vec::new();
// Add text block if there's text if !media_refs.is_empty() {
if !text.is_empty() { for mr in media_refs {
if input_types.contains(&mr.media_type) {
match registry.handle(&mr.media_type, &mr.path) {
Ok(content_blocks) => blocks.extend(content_blocks),
Err(e) => {
tracing::warn!(
path = %mr.path,
media_type = %mr.media_type,
error = %e,
"Media handler failed, falling back to text placeholder"
);
blocks.push(ContentBlock::text(format!(
"[用户发来了一个文件,但处理失败: {}, 错误: {}]",
mr.path, e
)));
}
}
} else {
tracing::debug!(
path = %mr.path,
media_type = %mr.media_type,
model_input_types = ?input_types,
"Media type not supported by model, using text placeholder"
);
blocks.push(ContentBlock::text(format!(
"[用户发来了一个文件: {}]",
mr.path
)));
}
}
} else if !text.is_empty() {
blocks.push(ContentBlock::text(text)); blocks.push(ContentBlock::text(text));
} }
// Add image blocks for media paths
for path in media_paths {
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
let url = format!("data:{};base64,{}", mime_type, base64_data);
blocks.push(ContentBlock::image_url(url));
}
}
// If nothing, add empty text block
if blocks.is_empty() { if blocks.is_empty() {
blocks.push(ContentBlock::text("")); blocks.push(ContentBlock::text(""));
} }
@ -46,22 +72,6 @@ fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock>
blocks blocks
} }
/// Encode an image file to base64 data URL
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
use base64::{Engine as _, engine::general_purpose::STANDARD};
let mut file = std::fs::File::open(path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;
let mime = mime_guess::from_path(path)
.first_or_octet_stream()
.to_string();
let encoded = STANDARD.encode(&buffer);
Ok((mime, encoded))
}
/// Truncate tool result if it exceeds MAX_TOOL_RESULT_CHARS. /// Truncate tool result if it exceeds MAX_TOOL_RESULT_CHARS.
/// Preserves the end of the output as it often contains the conclusion/useful result. /// Preserves the end of the output as it often contains the conclusion/useful result.
fn truncate_tool_result(output: &str) -> String { fn truncate_tool_result(output: &str) -> String {
@ -200,24 +210,6 @@ fn canonicalise_json(value: &serde_json::Value) -> serde_json::Value {
other => other.clone(), other => other.clone(),
} }
} }
/// Convert ChatMessage to LLM Message format
fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
build_content_blocks(&m.content, &m.media_refs)
};
Message {
role: m.role.clone(),
content,
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
tool_calls: m.tool_calls.clone(),
}
}
/// AgentLoop - Stateless agent that processes messages with tool calling support. /// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager. /// History is managed externally by SessionManager.
pub struct AgentLoop { pub struct AgentLoop {
@ -229,8 +221,9 @@ pub struct AgentLoop {
model_name: String, model_name: String,
context_window: usize, context_window: usize,
notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>, notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
input_types: Vec<String>,
media_registry: MediaHandlerRegistry,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct AgentProcessResult { pub struct AgentProcessResult {
pub final_response: ChatMessage, pub final_response: ChatMessage,
@ -243,6 +236,7 @@ impl AgentLoop {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
let model_name = provider_config.model_id.clone(); let model_name = provider_config.model_id.clone();
let workspace_dir = provider_config.workspace_dir.clone(); let workspace_dir = provider_config.workspace_dir.clone();
let input_types = provider_config.input_types.clone();
let provider = create_provider(provider_config) let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?; .map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
@ -255,6 +249,8 @@ impl AgentLoop {
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
input_types,
media_registry: MediaHandlerRegistry::with_defaults(),
}) })
} }
@ -263,6 +259,7 @@ impl AgentLoop {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
let model_name = provider_config.model_id.clone(); let model_name = provider_config.model_id.clone();
let workspace_dir = provider_config.workspace_dir.clone(); let workspace_dir = provider_config.workspace_dir.clone();
let input_types = provider_config.input_types.clone();
let provider = create_provider(provider_config) let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?; .map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
@ -275,11 +272,13 @@ impl AgentLoop {
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
input_types,
media_registry: MediaHandlerRegistry::with_defaults(),
}) })
} }
/// Create a new AgentLoop with an existing shared provider. /// Create a new AgentLoop with an existing shared provider.
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf) -> Self { pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self {
Self { Self {
provider, provider,
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
@ -289,6 +288,8 @@ impl AgentLoop {
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
input_types,
media_registry: MediaHandlerRegistry::with_defaults(),
} }
} }
@ -299,6 +300,7 @@ impl AgentLoop {
max_iterations: usize, max_iterations: usize,
model_name: String, model_name: String,
workspace_dir: PathBuf, workspace_dir: PathBuf,
input_types: Vec<String>,
) -> Self { ) -> Self {
Self { Self {
provider, provider,
@ -309,6 +311,8 @@ impl AgentLoop {
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
input_types,
media_registry: MediaHandlerRegistry::with_defaults(),
} }
} }
@ -369,6 +373,22 @@ impl AgentLoop {
&self.tools &self.tools
} }
fn chat_message_to_llm_message(&self, m: &ChatMessage) -> Message {
let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry)
};
Message {
role: m.role.clone(),
content,
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
tool_calls: m.tool_calls.clone(),
}
}
/// Process a message using the provided conversation history. /// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager. /// History management is handled externally by SessionManager.
/// ///
@ -421,7 +441,7 @@ impl AgentLoop {
// Convert messages to LLM format // Convert messages to LLM format
let messages_for_llm: Vec<Message> = messages let messages_for_llm: Vec<Message> = messages
.iter() .iter()
.map(chat_message_to_llm_message) .map(|m| self.chat_message_to_llm_message(m))
.collect(); .collect();
// Build request // Build request
@ -549,7 +569,7 @@ impl AgentLoop {
// Convert messages to LLM format // Convert messages to LLM format
let messages_for_llm: Vec<Message> = messages let messages_for_llm: Vec<Message> = messages
.iter() .iter()
.map(chat_message_to_llm_message) .map(|m| self.chat_message_to_llm_message(m))
.collect(); .collect();
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
@ -766,6 +786,8 @@ mod tests {
#[test] #[test]
fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() { fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
use crate::providers::Message;
let chat_message = ChatMessage::assistant_with_tool_calls( let chat_message = ChatMessage::assistant_with_tool_calls(
"calling tool", "calling tool",
vec![ToolCall { vec![ToolCall {
@ -775,7 +797,14 @@ mod tests {
}], }],
); );
let provider_message = chat_message_to_llm_message(&chat_message); let content = vec![ContentBlock::text(&chat_message.content)];
let provider_message = Message {
role: chat_message.role.clone(),
content,
tool_call_id: chat_message.tool_call_id.clone(),
name: chat_message.tool_name.clone(),
tool_calls: chat_message.tool_calls.clone(),
};
assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);

101
src/agent/media_handler.rs Normal file
View File

@ -0,0 +1,101 @@
use std::collections::HashMap;
use std::fmt;
use std::io::Read;
use crate::bus::message::ContentBlock;
pub trait MediaHandler: Send + Sync {
fn media_type(&self) -> &str;
fn handle(&self, path: &str) -> Result<Vec<ContentBlock>, MediaHandlerError>;
}
#[derive(Debug)]
pub enum MediaHandlerError {
Io(std::io::Error),
UnsupportedFormat(String),
}
impl fmt::Display for MediaHandlerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MediaHandlerError::Io(e) => write!(f, "I/O error: {}", e),
MediaHandlerError::UnsupportedFormat(msg) => write!(f, "Unsupported format: {}", msg),
}
}
}
impl std::error::Error for MediaHandlerError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
MediaHandlerError::Io(e) => Some(e),
MediaHandlerError::UnsupportedFormat(_) => None,
}
}
}
pub struct ImageHandler;
impl MediaHandler for ImageHandler {
fn media_type(&self) -> &str {
"image"
}
fn handle(&self, path: &str) -> Result<Vec<ContentBlock>, MediaHandlerError> {
let (mime_type, base64_data) =
encode_image_to_base64(path).map_err(MediaHandlerError::Io)?;
let url = format!("data:{};base64,{}", mime_type, base64_data);
Ok(vec![ContentBlock::image_url(url)])
}
}
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
use base64::{engine::general_purpose::STANDARD, Engine as _};
let mut file = std::fs::File::open(path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;
let mime = mime_guess::from_path(path)
.first_or_octet_stream()
.to_string();
let encoded = STANDARD.encode(&buffer);
Ok((mime, encoded))
}
pub struct MediaHandlerRegistry {
handlers: HashMap<String, Box<dyn MediaHandler>>,
}
impl MediaHandlerRegistry {
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
}
}
pub fn register(&mut self, handler: Box<dyn MediaHandler>) {
self.handlers
.insert(handler.media_type().to_string(), handler);
}
pub fn handle(
&self,
media_type: &str,
path: &str,
) -> Result<Vec<ContentBlock>, MediaHandlerError> {
match self.handlers.get(media_type) {
Some(handler) => handler.handle(path),
None => Err(MediaHandlerError::UnsupportedFormat(format!(
"no handler for type: {}",
media_type
))),
}
}
pub fn with_defaults() -> Self {
let mut reg = Self::new();
reg.register(Box::new(ImageHandler));
reg
}
}

View File

@ -1,5 +1,6 @@
pub mod agent_loop; pub mod agent_loop;
pub mod context_compressor; pub mod context_compressor;
pub mod media_handler;
pub mod system_prompt; pub mod system_prompt;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};

View File

@ -33,6 +33,16 @@ impl ContentBlock {
} }
} }
// ============================================================================
// MediaRef - Media reference in ChatMessage (carries type info)
// ============================================================================
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MediaRef {
pub path: String,
pub media_type: String,
}
// ============================================================================ // ============================================================================
// MediaItem - Media metadata for messages // MediaItem - Media metadata for messages
// ============================================================================ // ============================================================================
@ -54,6 +64,13 @@ impl MediaItem {
original_key: None, original_key: None,
} }
} }
pub fn to_media_ref(&self) -> MediaRef {
MediaRef {
path: self.path.clone(),
media_type: self.media_type.clone(),
}
}
} }
// ============================================================================ // ============================================================================
@ -65,7 +82,7 @@ pub struct ChatMessage {
pub id: String, pub id: String,
pub role: String, pub role: String,
pub content: String, pub content: String,
pub media_refs: Vec<String>, // Paths to media files for context pub media_refs: Vec<MediaRef>,
pub timestamp: i64, pub timestamp: i64,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>, pub tool_call_id: Option<String>,
@ -112,7 +129,7 @@ impl ChatMessage {
} }
} }
pub fn user_with_media(content: impl Into<String>, media_refs: Vec<String>) -> Self { pub fn user_with_media(content: impl Into<String>, media_refs: Vec<MediaRef>) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
role: "user".to_string(), role: "user".to_string(),

View File

@ -2,7 +2,7 @@ pub mod dispatcher;
pub mod message; pub mod message;
pub use dispatcher::OutboundDispatcher; pub use dispatcher::OutboundDispatcher;
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};

View File

@ -366,13 +366,7 @@ impl FeishuChannel {
.unwrap_or("image/jpeg") .unwrap_or("image/jpeg")
.to_string(); .to_string();
let ext = match content_type.as_str() { let ext = resolve_image_ext(&content_type);
"image/png" => "png",
"image/gif" => "gif",
"image/webp" => "webp",
"image/bmp" => "bmp",
_ => "jpg",
};
let data = resp.bytes().await let data = resp.bytes().await
.map_err(|e| ChannelError::Other(format!("Failed to read image data: {}", e)))? .map_err(|e| ChannelError::Other(format!("Failed to read image data: {}", e)))?
@ -382,7 +376,7 @@ impl FeishuChannel {
tracing::debug!(data_len = %data.len(), content_type = %content_type, "Downloaded image data"); tracing::debug!(data_len = %data.len(), content_type = %content_type, "Downloaded image data");
let filename = format!("{}_{}.{}", message_id, &image_key[..8.min(image_key.len())], ext); let filename = format!("{}_{}.{}", message_id, &image_key[..8.min(image_key.len())], ext);
let file_path = media_dir.join(&filename); let file_path = resolve_unique_path(media_dir, &filename).await;
tokio::fs::write(&file_path, &data).await tokio::fs::write(&file_path, &data).await
.map_err(|e| ChannelError::Other(format!("Failed to write image: {}", e)))?; .map_err(|e| ChannelError::Other(format!("Failed to write image: {}", e)))?;
@ -394,7 +388,7 @@ impl FeishuChannel {
tracing::info!(message_id = %message_id, filename = %filename, "Downloaded image"); tracing::info!(message_id = %message_id, filename = %filename, "Downloaded image");
Ok((format!("[image: {}]", filename), Some(media_item))) Ok((String::new(), Some(media_item)))
} }
/// Download file/audio from Feishu /// Download file/audio from Feishu
@ -434,13 +428,19 @@ impl FeishuChannel {
.map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))? .map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))?
.to_vec(); .to_vec();
let extension = match file_type { let filename = content_json
"audio" => "mp3", .get("file_name")
"video" => "mp4", .and_then(|v| v.as_str())
_ => "bin", .map(|s| s.to_string())
}; .unwrap_or_else(|| {
let filename = format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], extension); let ext = resolve_file_ext(content_json);
let file_path = media_dir.join(&filename); if ext.is_empty() {
format!("{}_{}", message_id, &file_key[..8.min(file_key.len())])
} else {
format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], ext)
}
});
let file_path = resolve_unique_path(media_dir, &filename).await;
tokio::fs::write(&file_path, &data).await tokio::fs::write(&file_path, &data).await
.map_err(|e| ChannelError::Other(format!("Failed to write file: {}", e)))?; .map_err(|e| ChannelError::Other(format!("Failed to write file: {}", e)))?;
@ -452,7 +452,7 @@ impl FeishuChannel {
tracing::info!(message_id = %message_id, filename = %filename, file_type = %file_type, "Downloaded file"); tracing::info!(message_id = %message_id, filename = %filename, file_type = %file_type, "Downloaded file");
Ok((format!("[{}: {}]", file_type, filename), Some(media_item))) Ok((String::new(), Some(media_item)))
} }
/// Upload image to Feishu and return the image_key /// Upload image to Feishu and return the image_key
@ -1543,6 +1543,55 @@ fn strip_at_placeholders(text: &str) -> String {
result result
} }
fn resolve_image_ext(content_type: &str) -> &str {
match content_type {
"image/png" => "png",
"image/gif" => "gif",
"image/webp" => "webp",
"image/bmp" => "bmp",
_ => "jpg",
}
}
fn resolve_file_ext(content_json: &serde_json::Value) -> String {
if let Some(name) = content_json
.get("file_name")
.and_then(|v| v.as_str())
{
if let Some(ext) = std::path::Path::new(name).extension().and_then(|e| e.to_str()) {
return ext.to_string();
}
}
String::new()
}
async fn resolve_unique_path(dir: &Path, filename: &str) -> std::path::PathBuf {
let candidate = dir.join(filename);
if !tokio::fs::try_exists(&candidate).await.unwrap_or(false) {
return candidate;
}
let stem = std::path::Path::new(filename)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or(filename);
let ext = std::path::Path::new(filename)
.extension()
.and_then(|s| s.to_str())
.unwrap_or("");
let mut n = 1;
loop {
let candidate = if ext.is_empty() {
dir.join(format!("{}({})", stem, n))
} else {
dir.join(format!("{}({}).{}", stem, n, ext))
};
if !tokio::fs::try_exists(&candidate).await.unwrap_or(false) {
return candidate;
}
n += 1;
}
}
impl FeishuChannel { impl FeishuChannel {
fn strip_thinking_tags(content: &str) -> String { fn strip_thinking_tags(content: &str) -> String {
use std::sync::LazyLock; use std::sync::LazyLock;

View File

@ -112,10 +112,16 @@ pub struct ModelConfig {
pub temperature: Option<f32>, pub temperature: Option<f32>,
#[serde(default)] #[serde(default)]
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
#[serde(default = "default_input_type")]
pub input_type: Vec<String>,
#[serde(flatten)] #[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>, pub extra: HashMap<String, serde_json::Value>,
} }
fn default_input_type() -> Vec<String> {
vec!["text".to_string()]
}
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentConfig { pub struct AgentConfig {
pub provider: String, pub provider: String,
@ -376,6 +382,7 @@ pub struct LLMProviderConfig {
pub max_tool_iterations: usize, pub max_tool_iterations: usize,
pub token_limit: usize, pub token_limit: usize,
pub workspace_dir: PathBuf, pub workspace_dir: PathBuf,
pub input_types: Vec<String>,
} }
fn get_default_config_path() -> PathBuf { fn get_default_config_path() -> PathBuf {
@ -443,6 +450,7 @@ impl Config {
max_tool_iterations: agent.max_tool_iterations, max_tool_iterations: agent.max_tool_iterations,
token_limit: agent.token_limit, token_limit: agent.token_limit,
workspace_dir: expand_path(&self.workspace_dir), workspace_dir: expand_path(&self.workspace_dir),
input_types: model.input_type.clone(),
}) })
} }
} }

View File

@ -3,7 +3,7 @@ use std::sync::Arc;
use tokio::sync::{Mutex, mpsc, oneshot}; use tokio::sync::{Mutex, mpsc, oneshot};
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; use crate::bus::{ChatMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind};
use crate::mcp::get_mcp_status; use crate::mcp::get_mcp_status;
use crate::storage::{Storage, StorageError}; use crate::storage::{Storage, StorageError};
use std::sync::Arc as StdArc; use std::sync::Arc as StdArc;
@ -361,7 +361,7 @@ impl Session {
tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory"); tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory");
} }
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage { pub fn create_user_message(&self, content: &str, media_refs: Vec<MediaRef>) -> ChatMessage {
if media_refs.is_empty() { if media_refs.is_empty() {
ChatMessage::user(content) ChatMessage::user(content)
} else { } else {
@ -372,7 +372,7 @@ impl Session {
pub fn create_user_message_with_source( pub fn create_user_message_with_source(
&self, &self,
content: &str, content: &str,
media_refs: Vec<String>, media_refs: Vec<MediaRef>,
source: MessageSource, source: MessageSource,
) -> ChatMessage { ) -> ChatMessage {
if media_refs.is_empty() { if media_refs.is_empty() {
@ -486,6 +486,7 @@ impl Session {
self.provider_config.max_tool_iterations, self.provider_config.max_tool_iterations,
self.provider_config.model_id.clone(), self.provider_config.model_id.clone(),
self.provider_config.workspace_dir.clone(), self.provider_config.workspace_dir.clone(),
self.provider_config.input_types.clone(),
).with_context_window(self.provider_config.token_limit)) ).with_context_window(self.provider_config.token_limit))
} }
@ -889,6 +890,7 @@ impl SessionManager {
self.provider_config.max_tool_iterations, self.provider_config.max_tool_iterations,
self.provider_config.model_id.clone(), self.provider_config.model_id.clone(),
self.provider_config.workspace_dir.clone(), self.provider_config.workspace_dir.clone(),
self.provider_config.input_types.clone(),
).with_context_window(self.provider_config.token_limit)) ).with_context_window(self.provider_config.token_limit))
} }
@ -1501,8 +1503,8 @@ fn spawn_agent_worker(
return; // stale worker return; // stale worker
} }
let media_refs: Vec<String> = let media_refs: Vec<MediaRef> =
task.media.iter().map(|m| m.path.clone()).collect(); task.media.iter().map(|m| m.to_media_ref()).collect();
let user_message = let user_message =
guard.create_user_message(&task.content, media_refs); guard.create_user_message(&task.content, media_refs);
if let Err(e) = guard.add_message(user_message, true).await { if let Err(e) = guard.add_message(user_message, true).await {
@ -1902,6 +1904,7 @@ mod tests {
max_tool_iterations: 1, max_tool_iterations: 1,
token_limit: 4096, token_limit: 4096,
workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"), workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"),
input_types: vec!["text".to_string()],
} }
} }
} }

View File

@ -25,6 +25,8 @@ fn load_config() -> Option<LLMProviderConfig> {
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 20, max_tool_iterations: 20,
token_limit: 128_000, token_limit: 128_000,
workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"),
input_types: vec!["text".to_string()],
}) })
} }

View File

@ -25,6 +25,8 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 20, max_tool_iterations: 20,
token_limit: 128_000, token_limit: 128_000,
workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"),
input_types: vec!["text".to_string()],
}) })
} }