feat: introduce multimodal content handling with media support
- Added ContentBlock enum for multimodal content representation (text, image). - Enhanced ChatMessage struct to include media references. - Updated InboundMessage and OutboundMessage to use MediaItem for media handling. - Implemented media download and upload functionality in FeishuChannel. - Modified message processing in the gateway to handle media items. - Improved logging for message processing and media handling in debug mode. - Refactored message serialization for LLM providers to support content blocks.
This commit is contained in:
parent
a051f83050
commit
2dada36bc6
@ -4,7 +4,7 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls"] }
|
||||
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] }
|
||||
dotenv = "0.15"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
regex = "1.0"
|
||||
@ -23,3 +23,5 @@ tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
tracing-appender = "0.2"
|
||||
anyhow = "1.0"
|
||||
mime_guess = "2.0"
|
||||
base64 = "0.22"
|
||||
|
||||
@ -1,9 +1,52 @@
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||
use crate::tools::ToolRegistry;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Build content blocks from text and media paths
|
||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
// Add text block if there's text
|
||||
if !text.is_empty() {
|
||||
blocks.push(ContentBlock::text(text));
|
||||
}
|
||||
|
||||
// Add image blocks for media paths
|
||||
for path in media_paths {
|
||||
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
|
||||
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||
blocks.push(ContentBlock::image_url(url));
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing, add empty text block
|
||||
if blocks.is_empty() {
|
||||
blocks.push(ContentBlock::text(""));
|
||||
}
|
||||
|
||||
blocks
|
||||
}
|
||||
|
||||
/// Encode an image file to base64 data URL
|
||||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let mut buffer = Vec::new();
|
||||
file.read_to_end(&mut buffer)?;
|
||||
|
||||
let mime = mime_guess::from_path(path)
|
||||
.first_or_octet_stream()
|
||||
.to_string();
|
||||
|
||||
let encoded = STANDARD.encode(&buffer);
|
||||
Ok((mime, encoded))
|
||||
}
|
||||
|
||||
/// Stateless AgentLoop - history is managed externally by SessionManager
|
||||
pub struct AgentLoop {
|
||||
provider: Box<dyn LLMProvider>,
|
||||
@ -40,14 +83,26 @@ impl AgentLoop {
|
||||
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
||||
let messages_for_llm: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
tool_call_id: m.tool_call_id.clone(),
|
||||
name: m.tool_name.clone(),
|
||||
.map(|m| {
|
||||
let content = if m.media_refs.is_empty() {
|
||||
vec![ContentBlock::text(&m.content)]
|
||||
} else {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(media_refs = ?m.media_refs, "Building content blocks with media");
|
||||
build_content_blocks(&m.content, &m.media_refs)
|
||||
};
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(role = %m.role, content_len = %m.content.len(), media_refs_len = %m.media_refs.len(), "ChatMessage converted to LLM Message");
|
||||
Message {
|
||||
role: m.role.clone(),
|
||||
content,
|
||||
tool_call_id: m.tool_call_id.clone(),
|
||||
name: m.tool_name.clone(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(history_len = messages.len(), "Sending request to LLM");
|
||||
|
||||
let tools = if self.tools.has_tools() {
|
||||
@ -69,6 +124,7 @@ impl AgentLoop {
|
||||
AgentError::LlmError(e.to_string())
|
||||
})?;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
response_len = response.content.len(),
|
||||
tool_calls_len = response.tool_calls.len(),
|
||||
@ -103,11 +159,18 @@ impl AgentLoop {
|
||||
async fn continue_with_tool_results(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
||||
let messages_for_llm: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
tool_call_id: m.tool_call_id.clone(),
|
||||
name: m.tool_name.clone(),
|
||||
.map(|m| {
|
||||
let content = if m.media_refs.is_empty() {
|
||||
vec![ContentBlock::text(&m.content)]
|
||||
} else {
|
||||
build_content_blocks(&m.content, &m.media_refs)
|
||||
};
|
||||
Message {
|
||||
role: m.role.clone(),
|
||||
content,
|
||||
tool_call_id: m.tool_call_id.clone(),
|
||||
name: m.tool_name.clone(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ impl OutboundDispatcher {
|
||||
|
||||
loop {
|
||||
let msg = self.bus.consume_outbound().await;
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
channel = %msg.channel,
|
||||
chat_id = %msg.chat_id,
|
||||
|
||||
@ -2,7 +2,60 @@ use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ============================================================================
|
||||
// ChatMessage - Legacy type used by AgentLoop for LLM conversation history
|
||||
// ContentBlock - Multimodal content representation (OpenAI-style)
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlock {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl { image_url: ImageUrlBlock },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageUrlBlock {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
impl ContentBlock {
|
||||
pub fn text(content: impl Into<String>) -> Self {
|
||||
Self::Text { text: content.into() }
|
||||
}
|
||||
|
||||
pub fn image_url(url: impl Into<String>) -> Self {
|
||||
Self::ImageUrl {
|
||||
image_url: ImageUrlBlock { url: url.into() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MediaItem - Media metadata for messages
|
||||
// ============================================================================
|
||||
|
||||
/// Metadata describing one media attachment carried by a message.
#[derive(Debug, Clone)]
pub struct MediaItem {
    /// Local file path.
    pub path: String,
    /// Kind of media: "image", "audio", "file" or "video".
    pub media_type: String,
    /// MIME type, when known.
    pub mime_type: Option<String>,
    /// Feishu file_key for download.
    pub original_key: Option<String>,
}

impl MediaItem {
    /// Create an item from a local path and a media type.
    ///
    /// `mime_type` and `original_key` start out as `None`.
    pub fn new(path: impl Into<String>, media_type: impl Into<String>) -> Self {
        let path = path.into();
        let media_type = media_type.into();
        Self {
            path,
            media_type,
            mime_type: None,
            original_key: None,
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// ChatMessage - Used by AgentLoop for LLM conversation history
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@ -10,6 +63,7 @@ pub struct ChatMessage {
|
||||
pub id: String,
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub media_refs: Vec<String>, // Paths to media files for context
|
||||
pub timestamp: i64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
@ -23,6 +77,19 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "user".to_string(),
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user_with_media(content: impl Into<String>, media_refs: Vec<String>) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "user".to_string(),
|
||||
content: content.into(),
|
||||
media_refs,
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -34,6 +101,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -45,6 +113,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "system".to_string(),
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -56,6 +125,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "tool".to_string(),
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
tool_name: Some(tool_name.into()),
|
||||
@ -74,7 +144,7 @@ pub struct InboundMessage {
|
||||
pub chat_id: String,
|
||||
pub content: String,
|
||||
pub timestamp: i64,
|
||||
pub media: Vec<String>,
|
||||
pub media: Vec<MediaItem>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
@ -94,7 +164,7 @@ pub struct OutboundMessage {
|
||||
pub chat_id: String,
|
||||
pub content: String,
|
||||
pub reply_to: Option<String>,
|
||||
pub media: Vec<String>,
|
||||
pub media: Vec<MediaItem>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ pub mod dispatcher;
|
||||
pub mod message;
|
||||
|
||||
pub use dispatcher::OutboundDispatcher;
|
||||
pub use message::{ChatMessage, InboundMessage, OutboundMessage};
|
||||
pub use message::{ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
@ -33,6 +33,8 @@ impl MessageBus {
|
||||
|
||||
/// Publish an inbound message (Channel -> Bus)
|
||||
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, content_len = %msg.content.len(), media_count = %msg.media.len(), "Bus: publishing inbound message");
|
||||
self.inbound_tx
|
||||
.send(msg)
|
||||
.await
|
||||
@ -41,16 +43,21 @@ impl MessageBus {
|
||||
|
||||
/// Consume an inbound message (Agent -> Bus)
|
||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||
self.inbound_rx
|
||||
let msg = self.inbound_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.expect("bus inbound closed")
|
||||
.expect("bus inbound closed");
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
|
||||
msg
|
||||
}
|
||||
|
||||
/// Publish an outbound message (Agent -> Bus)
|
||||
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(channel = %msg.channel, chat_id = %msg.chat_id, content_len = %msg.content.len(), "Bus: publishing outbound message");
|
||||
self.outbound_tx
|
||||
.send(msg)
|
||||
.await
|
||||
|
||||
@ -62,37 +62,18 @@ pub trait Channel: Send + Sync + 'static {
|
||||
async fn handle_and_publish(
|
||||
&self,
|
||||
bus: &Arc<MessageBus>,
|
||||
sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
msg: &InboundMessage,
|
||||
) -> Result<(), ChannelError> {
|
||||
if !self.is_allowed(sender_id) {
|
||||
if !self.is_allowed(&msg.sender_id) {
|
||||
tracing::warn!(
|
||||
channel = %self.name(),
|
||||
sender = %sender_id,
|
||||
sender = %msg.sender_id,
|
||||
"Access denied"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let msg = InboundMessage {
|
||||
channel: self.name().to_string(),
|
||||
sender_id: sender_id.to_string(),
|
||||
chat_id: chat_id.to_string(),
|
||||
content: content.to_string(),
|
||||
timestamp: current_timestamp(),
|
||||
media: vec![],
|
||||
metadata: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
bus.publish_inbound(msg).await?;
|
||||
bus.publish_inbound(msg.clone()).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as i64
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
@ -5,7 +6,7 @@ use serde::Deserialize;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::{Message as ProstMessage, bytes::Bytes};
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::bus::{MessageBus, MediaItem, OutboundMessage};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
||||
|
||||
@ -135,10 +136,10 @@ pub struct FeishuChannel {
|
||||
|
||||
/// Parsed message data from a Feishu frame
|
||||
struct ParsedMessage {
|
||||
message_id: String,
|
||||
open_id: String,
|
||||
chat_id: String,
|
||||
content: String,
|
||||
media: Option<MediaItem>,
|
||||
}
|
||||
|
||||
impl FeishuChannel {
|
||||
@ -220,6 +221,270 @@ impl FeishuChannel {
|
||||
.ok_or_else(|| ChannelError::Other("No token in response".to_string()))
|
||||
}
|
||||
|
||||
/// Download media and save locally, return (description, media_item)
|
||||
async fn download_media(
|
||||
&self,
|
||||
msg_type: &str,
|
||||
content_json: &serde_json::Value,
|
||||
message_id: &str,
|
||||
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||
let media_dir = Path::new(&self.config.media_dir);
|
||||
tokio::fs::create_dir_all(media_dir).await
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to create media dir: {}", e)))?;
|
||||
|
||||
match msg_type {
|
||||
"image" => self.download_image(content_json, message_id, media_dir).await,
|
||||
"audio" | "file" | "media" => self.download_file(content_json, message_id, media_dir, msg_type).await,
|
||||
_ => Ok((format!("[unsupported media type: {}]", msg_type), None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Download image from Feishu
|
||||
async fn download_image(
|
||||
&self,
|
||||
content_json: &serde_json::Value,
|
||||
message_id: &str,
|
||||
media_dir: &Path,
|
||||
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||
let image_key = content_json.get("image_key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ChannelError::Other("No image_key in message".to_string()))?;
|
||||
|
||||
let token = self.get_tenant_token().await?;
|
||||
|
||||
// Use message resource API for downloading message images
|
||||
let url = format!("{}/im/v1/messages/{}/resources/{}?type=image", FEISHU_API_BASE, message_id, image_key);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(url = %url, image_key = %image_key, message_id = %message_id, "Downloading image from Feishu via message resource API");
|
||||
|
||||
let resp = self.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ChannelError::ConnectionError(format!("Download image HTTP error: {}", e)))?;
|
||||
|
||||
let status = resp.status();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(status = %status, "Image download response status");
|
||||
|
||||
if !status.is_success() {
|
||||
let error_text = resp.text().await.unwrap_or_default();
|
||||
return Err(ChannelError::Other(format!("Image download failed {}: {}", status, error_text)));
|
||||
}
|
||||
|
||||
let data = resp.bytes().await
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to read image data: {}", e)))?
|
||||
.to_vec();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(data_len = %data.len(), "Downloaded image data");
|
||||
|
||||
let filename = format!("{}_{}.jpg", message_id, &image_key[..8.min(image_key.len())]);
|
||||
let file_path = media_dir.join(&filename);
|
||||
|
||||
tokio::fs::write(&file_path, &data).await
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to write image: {}", e)))?;
|
||||
|
||||
let media_item = MediaItem::new(
|
||||
file_path.to_string_lossy().to_string(),
|
||||
"image",
|
||||
);
|
||||
|
||||
tracing::info!(message_id = %message_id, filename = %filename, "Downloaded image");
|
||||
|
||||
Ok((format!("[image: {}]", filename), Some(media_item)))
|
||||
}
|
||||
|
||||
/// Download file/audio from Feishu
|
||||
async fn download_file(
|
||||
&self,
|
||||
content_json: &serde_json::Value,
|
||||
message_id: &str,
|
||||
media_dir: &Path,
|
||||
file_type: &str,
|
||||
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||
let file_key = content_json.get("file_key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ChannelError::Other("No file_key in message".to_string()))?;
|
||||
|
||||
let token = self.get_tenant_token().await?;
|
||||
|
||||
// Use message resource API for downloading message files
|
||||
let url = format!("{}/im/v1/messages/{}/resources/{}?type=file", FEISHU_API_BASE, message_id, file_key);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(url = %url, file_key = %file_key, message_id = %message_id, "Downloading file from Feishu via message resource API");
|
||||
|
||||
let resp = self.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ChannelError::ConnectionError(format!("Download file HTTP error: {}", e)))?;
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let error_text = resp.text().await.unwrap_or_default();
|
||||
return Err(ChannelError::Other(format!("File download failed {}: {}", status, error_text)));
|
||||
}
|
||||
|
||||
let data = resp.bytes().await
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))?
|
||||
.to_vec();
|
||||
|
||||
let extension = match file_type {
|
||||
"audio" => "mp3",
|
||||
"video" => "mp4",
|
||||
_ => "bin",
|
||||
};
|
||||
let filename = format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], extension);
|
||||
let file_path = media_dir.join(&filename);
|
||||
|
||||
tokio::fs::write(&file_path, &data).await
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to write file: {}", e)))?;
|
||||
|
||||
let media_item = MediaItem::new(
|
||||
file_path.to_string_lossy().to_string(),
|
||||
file_type,
|
||||
);
|
||||
|
||||
tracing::info!(message_id = %message_id, filename = %filename, file_type = %file_type, "Downloaded file");
|
||||
|
||||
Ok((format!("[{}: {}]", file_type, filename), Some(media_item)))
|
||||
}
|
||||
|
||||
/// Upload image to Feishu and return the image_key
|
||||
async fn upload_image(&self, file_path: &str) -> Result<String, ChannelError> {
|
||||
let token = self.get_tenant_token().await?;
|
||||
|
||||
let mime = mime_guess::from_path(file_path)
|
||||
.first_or_octet_stream()
|
||||
.to_string();
|
||||
|
||||
let file_name = std::path::Path::new(file_path)
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("image.jpg");
|
||||
|
||||
let file_data = tokio::fs::read(file_path).await
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?;
|
||||
|
||||
let part = reqwest::multipart::Part::bytes(file_data)
|
||||
.file_name(file_name.to_string())
|
||||
.mime_str(&mime)
|
||||
.map_err(|e| ChannelError::Other(format!("Invalid mime type: {}", e)))?;
|
||||
|
||||
let form = reqwest::multipart::Form::new()
|
||||
.text("image_type", "message".to_string())
|
||||
.part("image", part);
|
||||
|
||||
let resp = self.http_client
|
||||
.post(format!("{}/im/v1/images/upload", FEISHU_API_BASE))
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ChannelError::ConnectionError(format!("Upload image HTTP error: {}", e)))?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UploadResp {
|
||||
code: i32,
|
||||
msg: Option<String>,
|
||||
data: Option<UploadData>,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
struct UploadData {
|
||||
image_key: String,
|
||||
}
|
||||
|
||||
let result: UploadResp = resp.json().await
|
||||
.map_err(|e| ChannelError::Other(format!("Parse upload response error: {}", e)))?;
|
||||
|
||||
if result.code != 0 {
|
||||
return Err(ChannelError::Other(format!(
|
||||
"Upload image failed: code={} msg={}",
|
||||
result.code,
|
||||
result.msg.as_deref().unwrap_or("unknown")
|
||||
)));
|
||||
}
|
||||
|
||||
result.data
|
||||
.map(|d| d.image_key)
|
||||
.ok_or_else(|| ChannelError::Other("No image_key in response".to_string()))
|
||||
}
|
||||
|
||||
/// Upload file to Feishu and return the file_key
|
||||
async fn upload_file(&self, file_path: &str) -> Result<String, ChannelError> {
|
||||
let token = self.get_tenant_token().await?;
|
||||
|
||||
let file_name = std::path::Path::new(file_path)
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("file.bin");
|
||||
|
||||
let extension = std::path::Path::new(file_path)
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.unwrap_or("")
|
||||
.to_lowercase();
|
||||
|
||||
let file_type = match extension.as_str() {
|
||||
"mp3" | "m4a" | "wav" | "ogg" => "audio",
|
||||
"mp4" | "mov" | "avi" | "mkv" => "video",
|
||||
"pdf" | "doc" | "docx" | "xls" | "xlsx" => "doc",
|
||||
_ => "file",
|
||||
};
|
||||
|
||||
let file_data = tokio::fs::read(file_path).await
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?;
|
||||
|
||||
let part = reqwest::multipart::Part::bytes(file_data)
|
||||
.file_name(file_name.to_string())
|
||||
.mime_str("application/octet-stream")
|
||||
.map_err(|e| ChannelError::Other(format!("Invalid mime type: {}", e)))?;
|
||||
|
||||
let form = reqwest::multipart::Form::new()
|
||||
.text("file_type", file_type.to_string())
|
||||
.text("file_name", file_name.to_string())
|
||||
.part("file", part);
|
||||
|
||||
let resp = self.http_client
|
||||
.post(format!("{}/im/v1/files", FEISHU_API_BASE))
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ChannelError::ConnectionError(format!("Upload file HTTP error: {}", e)))?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UploadResp {
|
||||
code: i32,
|
||||
msg: Option<String>,
|
||||
data: Option<UploadData>,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
struct UploadData {
|
||||
file_key: String,
|
||||
}
|
||||
|
||||
let result: UploadResp = resp.json().await
|
||||
.map_err(|e| ChannelError::Other(format!("Parse upload response error: {}", e)))?;
|
||||
|
||||
if result.code != 0 {
|
||||
return Err(ChannelError::Other(format!(
|
||||
"Upload file failed: code={} msg={}",
|
||||
result.code,
|
||||
result.msg.as_deref().unwrap_or("unknown")
|
||||
)));
|
||||
}
|
||||
|
||||
result.data
|
||||
.map(|d| d.file_key)
|
||||
.ok_or_else(|| ChannelError::Other("No file_key in response".to_string()))
|
||||
}
|
||||
|
||||
/// Send a text message to Feishu chat (implements Channel trait)
|
||||
async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
|
||||
let token = self.get_tenant_token().await?;
|
||||
@ -285,10 +550,15 @@ impl FeishuChannel {
|
||||
let payload = frame.payload.as_deref()
|
||||
.ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(payload_len = %payload.len(), "Received frame payload");
|
||||
|
||||
let event: LarkEvent = serde_json::from_slice(payload)
|
||||
.map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?;
|
||||
|
||||
let event_type = event.header.event_type.as_str();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(event_type = %event_type, "Received event type");
|
||||
if event_type != "im.message.receive_v1" {
|
||||
return Ok(None);
|
||||
}
|
||||
@ -303,22 +573,66 @@ impl FeishuChannel {
|
||||
|
||||
let message_id = payload_data.message.message_id.clone();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(message_id = %message_id, "Received Feishu message");
|
||||
|
||||
let open_id = payload_data.sender.sender_id.open_id
|
||||
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
|
||||
|
||||
let msg = payload_data.message;
|
||||
let chat_id = msg.chat_id.clone();
|
||||
let msg_type = msg.message_type.as_str();
|
||||
let content = parse_message_content(msg_type, &msg.content);
|
||||
let raw_content = msg.content.clone();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(msg_type = %msg_type, chat_id = %chat_id, open_id = %open_id, "Parsing message content");
|
||||
|
||||
let (content, media) = self.parse_and_download_message(msg_type, &raw_content, &message_id).await?;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
if let Some(ref m) = media {
|
||||
tracing::debug!(media_type = %m.media_type, media_path = %m.path, "Media downloaded successfully");
|
||||
}
|
||||
|
||||
Ok(Some(ParsedMessage {
|
||||
message_id,
|
||||
open_id,
|
||||
chat_id,
|
||||
content,
|
||||
media,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Parse message content and download media if needed
|
||||
async fn parse_and_download_message(
|
||||
&self,
|
||||
msg_type: &str,
|
||||
content: &str,
|
||||
message_id: &str,
|
||||
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||
match msg_type {
|
||||
"text" => {
|
||||
let text = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||
parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string()
|
||||
} else {
|
||||
content.to_string()
|
||||
};
|
||||
Ok((text, None))
|
||||
}
|
||||
"post" => {
|
||||
let text = parse_post_content(content);
|
||||
Ok((text, None))
|
||||
}
|
||||
"image" | "audio" | "file" | "media" => {
|
||||
if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) {
|
||||
self.download_media(msg_type, &content_json, message_id).await
|
||||
} else {
|
||||
Ok((format!("[{}: content unavailable]", msg_type), None))
|
||||
}
|
||||
}
|
||||
_ => Ok((content.to_string(), None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send acknowledgment for a message
|
||||
async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> {
|
||||
let mut ack = frame.clone();
|
||||
@ -389,8 +703,26 @@ impl FeishuChannel {
|
||||
let channel = self.clone();
|
||||
let bus = bus.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = channel.handle_and_publish(&bus, &parsed.open_id, &parsed.chat_id, &parsed.content).await {
|
||||
let media_count = if parsed.media.is_some() { 1 } else { 0 };
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
|
||||
let msg = crate::bus::InboundMessage {
|
||||
channel: "feishu".to_string(),
|
||||
sender_id: parsed.open_id.clone(),
|
||||
chat_id: parsed.chat_id.clone(),
|
||||
content: parsed.content.clone(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as i64,
|
||||
media: parsed.media.map(|m| vec![m]).unwrap_or_default(),
|
||||
metadata: std::collections::HashMap::new(),
|
||||
};
|
||||
if let Err(e) = channel.handle_and_publish(&bus, &msg).await {
|
||||
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to publish Feishu message to bus");
|
||||
} else {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Message published to bus successfully");
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -416,6 +748,7 @@ impl FeishuChannel {
|
||||
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
|
||||
}
|
||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!("Feishu WebSocket closed");
|
||||
break;
|
||||
}
|
||||
@ -456,43 +789,31 @@ impl FeishuChannel {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_message_content(msg_type: &str, content: &str) -> String {
|
||||
match msg_type {
|
||||
"text" => {
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||
parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string()
|
||||
} else {
|
||||
content.to_string()
|
||||
}
|
||||
}
|
||||
"post" => {
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||
let mut texts = vec![];
|
||||
if let Some(post) = parsed.get("post") {
|
||||
if let Some(content_arr) = post.get("content") {
|
||||
if let Some(arr) = content_arr.as_array() {
|
||||
for item in arr {
|
||||
if let Some(arr2) = item.as_array() {
|
||||
for inner in arr2 {
|
||||
if let Some(text) = inner.get("text").and_then(|v| v.as_str()) {
|
||||
texts.push(text.to_string());
|
||||
}
|
||||
}
|
||||
fn parse_post_content(content: &str) -> String {
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||
let mut texts = vec![];
|
||||
if let Some(post) = parsed.get("post") {
|
||||
if let Some(content_arr) = post.get("content") {
|
||||
if let Some(arr) = content_arr.as_array() {
|
||||
for item in arr {
|
||||
if let Some(arr2) = item.as_array() {
|
||||
for inner in arr2 {
|
||||
if let Some(text) = inner.get("text").and_then(|v| v.as_str()) {
|
||||
texts.push(text.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if texts.is_empty() {
|
||||
content.to_string()
|
||||
} else {
|
||||
texts.join("")
|
||||
}
|
||||
} else {
|
||||
content.to_string()
|
||||
}
|
||||
}
|
||||
_ => content.to_string(),
|
||||
if texts.is_empty() {
|
||||
content.to_string()
|
||||
} else {
|
||||
texts.join("")
|
||||
}
|
||||
} else {
|
||||
content.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
@ -574,6 +895,104 @@ impl Channel for FeishuChannel {
|
||||
let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { &msg.reply_to.as_ref().unwrap_or(&msg.chat_id) };
|
||||
let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
||||
|
||||
self.send_message_to_feishu(receive_id, receive_id_type, &msg.content).await
|
||||
// If no media, send text only
|
||||
if msg.media.is_empty() {
|
||||
return self.send_message_to_feishu(receive_id, receive_id_type, &msg.content).await;
|
||||
}
|
||||
|
||||
// Handle multimodal message - send with media
|
||||
let token = self.get_tenant_token().await?;
|
||||
|
||||
// Build content with media references
|
||||
let mut content_parts = Vec::new();
|
||||
|
||||
// Add text content if present
|
||||
if !msg.content.is_empty() {
|
||||
content_parts.push(serde_json::json!({
|
||||
"tag": "text",
|
||||
"text": msg.content
|
||||
}));
|
||||
}
|
||||
|
||||
// Upload and add media
|
||||
for media_item in &msg.media {
|
||||
let path = &media_item.path;
|
||||
match media_item.media_type.as_str() {
|
||||
"image" => {
|
||||
match self.upload_image(path).await {
|
||||
Ok(image_key) => {
|
||||
content_parts.push(serde_json::json!({
|
||||
"tag": "image",
|
||||
"image_key": image_key
|
||||
}));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, path = %path, "Failed to upload image");
|
||||
}
|
||||
}
|
||||
}
|
||||
"audio" | "file" | "video" => {
|
||||
match self.upload_file(path).await {
|
||||
Ok(file_key) => {
|
||||
content_parts.push(serde_json::json!({
|
||||
"tag": "file",
|
||||
"file_key": file_key
|
||||
}));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, path = %path, "Failed to upload file");
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(media_type = %media_item.media_type, "Unsupported media type for sending");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no content parts after processing, just send empty text
|
||||
if content_parts.is_empty() {
|
||||
return self.send_message_to_feishu(receive_id, receive_id_type, "").await;
|
||||
}
|
||||
|
||||
// Determine message type
|
||||
let has_image = msg.media.iter().any(|m| m.media_type == "image");
|
||||
let msg_type = if has_image && msg.content.is_empty() {
|
||||
"image"
|
||||
} else {
|
||||
"post"
|
||||
};
|
||||
|
||||
let content = serde_json::json!({
|
||||
"content": content_parts
|
||||
}).to_string();
|
||||
|
||||
let resp = self.http_client
|
||||
.post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.json(&serde_json::json!({
|
||||
"receive_id": receive_id,
|
||||
"msg_type": msg_type,
|
||||
"content": content
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ChannelError::ConnectionError(format!("Send multimodal message HTTP error: {}", e)))?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SendResp {
|
||||
code: i32,
|
||||
msg: String,
|
||||
}
|
||||
|
||||
let send_resp: SendResp = resp.json().await
|
||||
.map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?;
|
||||
|
||||
if send_resp.code != 0 {
|
||||
return Err(ChannelError::Other(format!("Send multimodal message failed: code={} msg={}", send_resp.code, send_resp.msg)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@ -35,6 +35,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
input.write_output(&format!("Error: {}", message)).await?;
|
||||
}
|
||||
WsOutbound::SessionEstablished { session_id } => {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(session_id = %session_id, "Session established");
|
||||
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
||||
}
|
||||
|
||||
@ -28,12 +28,19 @@ pub struct FeishuChannelConfig {
|
||||
pub allow_from: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub agent: String,
|
||||
#[serde(default = "default_media_dir")]
|
||||
pub media_dir: String,
|
||||
}
|
||||
|
||||
/// Returns the default allow-list: a single `"*"` entry.
fn default_allow_from() -> Vec<String> {
    vec![String::from("*")]
}
|
||||
|
||||
fn default_media_dir() -> String {
|
||||
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
|
||||
home.join(".picobot/media/feishu").to_string_lossy().to_string()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ProviderConfig {
|
||||
#[serde(rename = "type")]
|
||||
|
||||
@ -53,11 +53,22 @@ impl GatewayState {
|
||||
tracing::info!("Inbound processor started");
|
||||
loop {
|
||||
let inbound = bus_for_inbound.consume_inbound().await;
|
||||
tracing::debug!(
|
||||
channel = %inbound.channel,
|
||||
chat_id = %inbound.chat_id,
|
||||
"Processing inbound message"
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
tracing::debug!(
|
||||
channel = %inbound.channel,
|
||||
chat_id = %inbound.chat_id,
|
||||
sender = %inbound.sender_id,
|
||||
content = %inbound.content,
|
||||
media_count = %inbound.media.len(),
|
||||
"Processing inbound message"
|
||||
);
|
||||
if !inbound.media.is_empty() {
|
||||
for (i, m) in inbound.media.iter().enumerate() {
|
||||
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process via session manager
|
||||
match session_manager.handle_message(
|
||||
@ -65,6 +76,7 @@ impl GatewayState {
|
||||
&inbound.sender_id,
|
||||
&inbound.chat_id,
|
||||
&inbound.content,
|
||||
inbound.media,
|
||||
).await {
|
||||
Ok(response_content) => {
|
||||
let outbound = crate::bus::OutboundMessage {
|
||||
|
||||
@ -56,6 +56,12 @@ impl Session {
|
||||
history.push(ChatMessage::user(content));
|
||||
}
|
||||
|
||||
/// 添加带媒体的用户消息到指定 chat_id 的历史
|
||||
pub fn add_user_message_with_media(&mut self, chat_id: &str, content: &str, media_refs: Vec<String>) {
|
||||
let history = self.get_or_create_history(chat_id);
|
||||
history.push(ChatMessage::user_with_media(content, media_refs));
|
||||
}
|
||||
|
||||
/// 添加助手响应到指定 chat_id 的历史
|
||||
pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
@ -68,6 +74,7 @@ impl Session {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
let len = history.len();
|
||||
history.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
||||
}
|
||||
}
|
||||
@ -76,6 +83,7 @@ impl Session {
|
||||
pub fn clear_all_history(&mut self) {
|
||||
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
||||
self.chat_histories.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||||
}
|
||||
|
||||
@ -139,6 +147,7 @@ impl SessionManager {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(channel = %channel_name, "Creating new session");
|
||||
true
|
||||
};
|
||||
@ -184,13 +193,21 @@ impl SessionManager {
|
||||
_sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
media: Vec<crate::bus::MediaItem>,
|
||||
) -> Result<String, AgentError> {
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
content_len = content.len(),
|
||||
"Routing message to agent"
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
content_len = content.len(),
|
||||
media_count = %media.len(),
|
||||
"Routing message to agent"
|
||||
);
|
||||
for (i, m) in media.iter().enumerate() {
|
||||
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 session 存在(可能需要重建)
|
||||
self.ensure_session(channel_name).await?;
|
||||
@ -209,7 +226,14 @@ impl SessionManager {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
// 添加用户消息到历史
|
||||
session_guard.add_user_message(chat_id, content);
|
||||
if media.is_empty() {
|
||||
session_guard.add_user_message(chat_id, content);
|
||||
} else {
|
||||
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
|
||||
session_guard.add_user_message_with_media(chat_id, content, media_refs);
|
||||
}
|
||||
|
||||
// 获取完整历史
|
||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||
@ -224,6 +248,7 @@ impl SessionManager {
|
||||
response
|
||||
};
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
|
||||
@ -62,6 +62,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
if let Ok(text) = serialize_outbound(&msg) {
|
||||
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
||||
break;
|
||||
}
|
||||
@ -91,6 +92,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
}
|
||||
}
|
||||
Ok(WsMessage::Close(_)) | Err(_) => {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(session_id = %session_id, "WebSocket closed");
|
||||
break;
|
||||
}
|
||||
@ -145,6 +147,7 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
||||
|
||||
match agent.process(history).await {
|
||||
Ok(response) => {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(chat_id = %chat_id, "Agent response sent");
|
||||
// 添加助手响应到历史
|
||||
session_guard.add_assistant_message(&chat_id, response.clone());
|
||||
|
||||
@ -3,9 +3,55 @@ use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use super::traits::Usage;
|
||||
|
||||
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
||||
}
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||
blocks.iter().map(|b| match b {
|
||||
ContentBlock::Text { text } => {
|
||||
serde_json::json!({ "type": "text", "text": text })
|
||||
}
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
convert_image_url_to_anthropic(&image_url.url)
|
||||
}
|
||||
}).collect()
|
||||
}
|
||||
|
||||
/// Converts an image URL into an Anthropic image block.
///
/// * `data:image/<subtype>;base64,<payload>` becomes a block with a `base64`
///   source carrying the media type and the payload.
/// * Anything else is passed through as a block with a `url` source.
///
/// The data-URL form is recognised with plain string operations rather than a
/// regular expression. This avoids compiling a regex on every call, and it
/// accepts media subtypes containing characters such as `+` or `-`
/// (for example `image/svg+xml`) as well as payloads that contain newlines.
/// The `data:` prefix must be at the start of `url`.
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
    const IMAGE_PREFIX: &str = "image/";

    let inline = url
        .strip_prefix("data:")
        .and_then(|rest| rest.split_once(";base64,"))
        // Require a non-empty image subtype and a non-empty payload.
        .filter(|(media_type, data)| {
            media_type.starts_with(IMAGE_PREFIX)
                && media_type.len() > IMAGE_PREFIX.len()
                && !data.is_empty()
        });

    match inline {
        // data:image/png;base64,... -> Anthropic image block
        Some((media_type, data)) => serde_json::json!({
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": data
            }
        }),
        // Regular URL -> Anthropic image block with url source
        None => serde_json::json!({
            "type": "image",
            "source": {
                "type": "url",
                "url": url
            }
        }),
    }
}
|
||||
|
||||
pub struct AnthropicProvider {
|
||||
client: Client,
|
||||
name: String,
|
||||
@ -58,7 +104,8 @@ struct AnthropicRequest {
|
||||
#[derive(Serialize)]
|
||||
struct AnthropicMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
#[serde(serialize_with = "serialize_content_blocks")]
|
||||
content: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@ -122,7 +169,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
.iter()
|
||||
.map(|m| AnthropicMessage {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
content: convert_content_blocks(&m.content),
|
||||
})
|
||||
.collect(),
|
||||
max_tokens,
|
||||
|
||||
@ -1,12 +1,27 @@
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use super::traits::Usage;
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||
if blocks.len() == 1 {
|
||||
if let ContentBlock::Text { text } = &blocks[0] {
|
||||
return Value::String(text.clone());
|
||||
}
|
||||
}
|
||||
Value::Array(blocks.iter().map(|b| match b {
|
||||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||
}
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub struct OpenAIProvider {
|
||||
client: Client,
|
||||
name: String,
|
||||
@ -107,14 +122,14 @@ impl LLMProvider for OpenAIProvider {
|
||||
if m.role == "tool" {
|
||||
json!({
|
||||
"role": m.role,
|
||||
"content": m.content,
|
||||
"content": convert_content_blocks(&m.content),
|
||||
"tool_call_id": m.tool_call_id,
|
||||
"name": m.name,
|
||||
})
|
||||
} else {
|
||||
json!({
|
||||
"role": m.role,
|
||||
"content": m.content
|
||||
"content": convert_content_blocks(&m.content)
|
||||
})
|
||||
}
|
||||
}).collect::<Vec<_>>(),
|
||||
@ -131,6 +146,30 @@ impl LLMProvider for OpenAIProvider {
|
||||
body["tools"] = json!(tools);
|
||||
}
|
||||
|
||||
// Debug: Log LLM request summary (only in debug builds)
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
// Log messages summary
|
||||
let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0);
|
||||
tracing::debug!(msg_count = msg_count, "LLM request messages count");
|
||||
|
||||
// Log first 20 bytes of base64 images (don't log full base64)
|
||||
if let Some(msgs) = body["messages"].as_array() {
|
||||
for (i, msg) in msgs.iter().enumerate() {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||
for (j, item) in content.iter().enumerate() {
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
||||
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
||||
let prefix = &url_str[..20.min(url_str.len())];
|
||||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut req_builder = self
|
||||
.client
|
||||
.post(&url)
|
||||
@ -146,6 +185,13 @@ impl LLMProvider for OpenAIProvider {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
|
||||
// Debug: Log LLM response (only in debug builds)
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let resp_preview = &text[..text.len().min(100)];
|
||||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 bytes shown)");
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(format!("API error {}: {}", status, text).into());
|
||||
}
|
||||
|
||||
@ -1,16 +1,64 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::bus::message::ContentBlock;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub content: Vec<ContentBlock>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user_with_blocks(content: Vec<ContentBlock>) -> Self {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "system".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "tool".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
name: Some(tool_name.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user