feat: WebSocket 媒体文件处理优化

- 后端 ws.rs: 处理前端上传的 base64 内容,保存到本地文件并更新路径
- 后端 ws.rs: 历史消息加载时从文件读取内容填充 base64,过滤 media_refs_json
- 前端 App.tsx: 传递 attachments 给 handleMessage 实现实时显示
- 前端 useChat.ts: handleMessage 支持 attachments 参数
- 前端 MessageInput.tsx: 支持剪贴板粘贴文件/图片
- 前端 MessageInput.tsx: 修复拖拽文件时闪烁问题
- 测试 test_request_format.rs: 补充缺失的 attachments 字段
This commit is contained in:
ooodc 2026-05-30 10:22:30 +08:00
parent c2293238fc
commit 7d9355fd78
6 changed files with 278 additions and 21 deletions

View File

@ -208,6 +208,8 @@ impl AgentExecutionService {
} }
let enriched_content = let enriched_content =
enrich_user_content_with_media_refs(request.content, &media_refs)?; enrich_user_content_with_media_refs(request.content, &media_refs)?;
enrich_user_content_with_media_refs(request.content, &media_refs)?;
enrich_user_content_with_media_refs(request.content, &media_refs)?;
// 先计算 user_message_count在添加新消息之前 // 先计算 user_message_count在添加新消息之前
let history_before = session_guard.get_or_create_history(request.chat_id).clone(); let history_before = session_guard.get_or_create_history(request.chat_id).clone();

View File

@ -17,18 +17,100 @@ use crate::command::handlers::save_session::SaveSessionCommandHandler;
use crate::command::handlers::session::SessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler;
use crate::command::handlers::switch_topic::SwitchTopicCommandHandler; use crate::command::handlers::switch_topic::SwitchTopicCommandHandler;
use crate::gateway::agent_prompt_provider::AgentPromptProvider; use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use crate::protocol::{WsInbound, WsOutbound, MediaSummary, parse_inbound, serialize_outbound};
use crate::skills::SkillPromptProvider; use crate::skills::SkillPromptProvider;
use axum::extract::State; use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response; use axum::response::Response;
use base64::{Engine as _, engine::general_purpose::STANDARD};
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
const CLI_CHANNEL_NAME: &str = "cli"; const CLI_CHANNEL_NAME: &str = "cli";
/// Returns the directory used to store media uploaded over WebSocket:
/// `<home>/.picobot/media/ws`, where `<home>` falls back to `.` when the
/// home directory cannot be determined.
fn default_ws_media_dir() -> PathBuf {
    let mut dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
    dir.extend([".picobot", "media", "ws"]);
    dir
}
/// Builds a unique, single-component filename for an uploaded media item.
///
/// * When `file_name` is present and not blank after sanitization, the result is
///   `<uuid>_<sanitized file name>`.
/// * Otherwise the result is `<sanitized media_type>_<uuid>`.
///
/// Both `file_name` and `media_type` are sanitized, because the result is joined
/// onto the media directory: an unsanitized `/` or `\` in either value could make
/// the joined path point outside that directory. Control characters (including
/// NUL) are replaced as well.
fn build_media_filename(media_type: &str, file_name: Option<&str>) -> String {
    /// Replaces path separators, reserved punctuation and control characters with `_`.
    fn sanitize_component(raw: &str) -> String {
        raw.chars()
            .map(|ch| match ch {
                '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
                ch if ch.is_control() => '_',
                ch => ch,
            })
            .collect()
    }

    match file_name.map(sanitize_component) {
        Some(name) if !name.trim().is_empty() => {
            format!("{}_{}", uuid::Uuid::new_v4(), name)
        }
        // No usable original name: fall back to the (sanitized) media type as prefix.
        _ => format!("{}_{}", sanitize_component(media_type), uuid::Uuid::new_v4()),
    }
}
/// Converts protocol-level attachments into internal `MediaItem`s.
///
/// An attachment that carries `content_base64` is decoded and written into the
/// WebSocket media directory; the resulting item points at the written file and
/// keeps the base64 payload so the frontend can still display/download it.
/// An attachment without a payload is passed through with its original path.
///
/// Returns an error as soon as the media directory cannot be created, a payload
/// cannot be decoded, or a file cannot be written.
fn process_attachments_with_base64(attachments: Vec<MediaSummary>) -> Result<Vec<MediaItem>, AgentError> {
    if attachments.is_empty() {
        return Ok(Vec::new());
    }

    let media_dir = default_ws_media_dir();
    std::fs::create_dir_all(&media_dir)
        .map_err(|error| AgentError::Other(format!("Failed to create media dir: {}", error)))?;

    let mut items = Vec::with_capacity(attachments.len());
    for att in attachments {
        // Resolve the path the item should point at: a freshly written file when a
        // payload is present, otherwise the path supplied by the client.
        let path = match att.content_base64.as_deref() {
            Some(encoded) => {
                let bytes = STANDARD
                    .decode(encoded)
                    .map_err(|error| AgentError::Other(format!("Failed to decode base64: {}", error)))?;

                let filename = build_media_filename(&att.media_type, att.file_name.as_deref());
                let file_path = media_dir.join(&filename);
                std::fs::write(&file_path, bytes)
                    .map_err(|error| AgentError::Other(format!("Failed to write media file: {}", error)))?;

                tracing::info!(
                    filename = %filename,
                    media_type = %att.media_type,
                    file_path = %file_path.to_string_lossy(),
                    "Saved WebSocket media to local file"
                );

                file_path.to_string_lossy().to_string()
            }
            None => att.path,
        };

        items.push(MediaItem {
            path,
            media_type: att.media_type,
            mime_type: att.mime_type,
            original_key: None,
            // Still `Some(..)` for uploaded payloads, `None` for pass-through items.
            content_base64: att.content_base64,
            file_name: att.file_name,
        });
    }

    Ok(items)
}
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response { pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async { ws.on_upgrade(|socket| async {
handle_socket(socket, state).await; handle_socket(socket, state).await;
@ -214,18 +296,8 @@ async fn handle_inbound(
) )
.await; .await;
// 将协议层 attachments 转换为内部 MediaItem // Process attachments: save base64 content to local files and build MediaItems with correct paths
let media: Vec<MediaItem> = attachments let media = process_attachments_with_base64(attachments)?;
.iter()
.map(|a| MediaItem {
path: a.path.clone(),
media_type: a.media_type.clone(),
mime_type: a.mime_type.clone(),
original_key: None,
content_base64: a.content_base64.clone(),
file_name: a.file_name.clone(),
})
.collect();
state state
.bus .bus
@ -504,6 +576,57 @@ async fn send_task_messages(
fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option<WsOutbound> { fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option<WsOutbound> {
use crate::bus::message::ToolMessageState; use crate::bus::message::ToolMessageState;
// Returns `content` truncated at the first "\n\nmedia_refs_json:" marker,
// or the whole content when the marker is absent.
fn strip_media_refs_json(content: &str) -> String {
    match content.split_once("\n\nmedia_refs_json:") {
        Some((visible, _refs)) => visible.to_owned(),
        None => content.to_owned(),
    }
}
// Build attachments from media_refs by reading each referenced file and encoding it
// as base64. Only the "user" arm of the match below consumes `attachments`, so the
// disk reads and encoding are skipped for every other role.
let attachments: Vec<MediaSummary> = if msg.role.as_str() == "user" {
    msg.media_refs
        .iter()
        .filter_map(|path| {
            // A file that can no longer be read is skipped, but logged so the
            // missing attachment is explainable.
            let file_content = match std::fs::read(path) {
                Ok(bytes) => bytes,
                Err(error) => {
                    tracing::warn!(
                        path = %path,
                        error = %error,
                        "Failed to read media file for history message, skipping attachment"
                    );
                    return None;
                }
            };
            let base64_content = STANDARD.encode(&file_content);

            // Guess mime type from the path
            let mime_type = mime_guess::from_path(path)
                .first_raw()
                .map(ToOwned::to_owned);

            // Derive the coarse media type from the mime type prefix
            let media_type = mime_type
                .as_ref()
                .map(|m| {
                    if m.starts_with("image/") { "image" }
                    else if m.starts_with("audio/") { "audio" }
                    else if m.starts_with("video/") { "video" }
                    else { "file" }
                })
                .unwrap_or("file");

            // File name is the last path component, when it is valid UTF-8
            let file_name = std::path::Path::new(path)
                .file_name()
                .and_then(|name| name.to_str())
                .map(ToOwned::to_owned);

            Some(MediaSummary {
                path: path.clone(),
                media_type: media_type.to_string(),
                mime_type,
                content_base64: Some(base64_content),
                file_name,
            })
        })
        .collect()
} else {
    Vec::new()
};
match msg.role.as_str() { match msg.role.as_str() {
"assistant" => { "assistant" => {
if let Some(tool_calls) = &msg.tool_calls { if let Some(tool_calls) = &msg.tool_calls {
@ -553,9 +676,9 @@ fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option<WsOutbou
} }
"user" => Some(WsOutbound::AssistantResponse { "user" => Some(WsOutbound::AssistantResponse {
id: msg.id.clone(), id: msg.id.clone(),
content: msg.content.clone(), content: strip_media_refs_json(&msg.content),
role: msg.role.clone(), role: msg.role.clone(),
attachments: Vec::new(), attachments,
subagent_task_id: None, subagent_task_id: None,
}), }),
_ => None, _ => None,
@ -564,7 +687,9 @@ fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option<WsOutbou
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::resolve_ws_sender_id; use super::{resolve_ws_sender_id, build_media_filename, process_attachments_with_base64};
use crate::protocol::MediaSummary;
use base64::{Engine as _, engine::general_purpose::STANDARD};
#[test] #[test]
fn test_resolve_ws_sender_id_prefers_inbound_sender() { fn test_resolve_ws_sender_id_prefers_inbound_sender() {
@ -583,4 +708,72 @@ mod tests {
assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
} }
#[test]
fn test_build_media_filename_preserves_original_name() {
    let original = "photo.png";
    let filename = build_media_filename("image", Some(original));

    assert!(filename.ends_with("_photo.png"));
    // UUID is 36 chars, followed by an underscore and the original name
    let min_len = 36 + 1 + original.len();
    assert!(filename.len() >= min_len);
}
#[test]
fn test_build_media_filename_generates_default_when_no_name() {
    // Without an original name, the media type is used as the filename prefix.
    let generated = build_media_filename("image", None);
    assert!(generated.starts_with("image_"));
}
#[test]
fn test_process_attachments_with_base64_saves_to_file() {
    let test_content = "test image content";
    let base64_content = STANDARD.encode(test_content.as_bytes());

    let attachments = vec![MediaSummary {
        path: "test_image.png".to_string(),
        media_type: "image".to_string(),
        mime_type: Some("image/png".to_string()),
        content_base64: Some(base64_content.clone()),
        file_name: Some("test_image.png".to_string()),
    }];

    let result = process_attachments_with_base64(attachments).unwrap();
    let saved = &result[0];

    // Read the file back and remove it BEFORE asserting, so a failing assertion
    // does not leave the written file behind in the media directory.
    let file_content = std::fs::read(&saved.path);
    std::fs::remove_file(&saved.path).ok();

    // Verify path is now a full path inside the media directory
    assert!(saved.path.contains(".picobot"));
    assert!(saved.path.contains("media"));
    assert!(saved.path.contains("ws"));

    // Verify content_base64 is kept for frontend display
    assert!(saved.content_base64.is_some());
    assert_eq!(saved.content_base64.as_ref().unwrap(), &base64_content);

    // Verify the decoded bytes were actually written
    assert_eq!(file_content.unwrap(), test_content.as_bytes());
}
#[test]
fn test_process_attachments_without_base64_keeps_original_path() {
    let original_path = "/existing/path/image.jpg";
    let attachment = MediaSummary {
        path: original_path.to_string(),
        media_type: "image".to_string(),
        mime_type: Some("image/jpeg".to_string()),
        content_base64: None,
        file_name: Some("image.jpg".to_string()),
    };

    let result = process_attachments_with_base64(vec![attachment]).unwrap();

    // An attachment without a payload must pass through with its path untouched
    assert_eq!(result[0].path, "/existing/path/image.jpg");
}
#[test]
fn test_process_empty_attachments_returns_empty_vec() {
    // An empty input yields an empty output.
    let items = process_attachments_with_base64(Vec::new()).unwrap();
    assert_eq!(items.len(), 0);
}
} }

View File

@ -70,6 +70,7 @@ fn test_command_inbound_serialization() {
fn test_message_inbound_serialization() { fn test_message_inbound_serialization() {
let msg = WsInbound::Message { let msg = WsInbound::Message {
content: "Hello world".to_string(), content: "Hello world".to_string(),
attachments: Vec::new(),
channel: None, channel: None,
chat_id: Some("session-1".to_string()), chat_id: Some("session-1".to_string()),
sender_id: Some("user-1".to_string()), sender_id: Some("user-1".to_string()),

View File

@ -126,7 +126,7 @@ function App() {
handleCommand(cmd) handleCommand(cmd)
sendMessage({ type: 'command', payload: JSON.stringify(cmd) }) sendMessage({ type: 'command', payload: JSON.stringify(cmd) })
} else { } else {
handleMessage(content) handleMessage(content, attachments)
sendMessage({ sendMessage({
type: 'message', type: 'message',
content, content,

View File

@ -122,6 +122,59 @@ export function MessageInput({
setAttachments(prev => prev.filter((_, i) => i !== index)) setAttachments(prev => prev.filter((_, i) => i !== index))
} }
// Paste handler: turns pasted files/images into attachments; plain-text paste is
// left to the browser's default behavior.
const handlePaste = async (e: React.ClipboardEvent) => {
  if (disabled || isReadOnly) return

  // Collect every clipboard item that is a file (image or any other file)
  const files: File[] = []
  for (const item of Array.from(e.clipboardData.items)) {
    if (item.kind !== 'file') continue
    const file = item.getAsFile()
    if (file) {
      files.push(file)
    }
  }

  // No files: let the default text paste continue
  if (files.length === 0) return

  // Files present: take over the paste and convert each file into an attachment
  e.preventDefault()
  setError(null)

  for (const file of files) {
    if (file.size > MAX_FILE_SIZE) {
      setError(`文件 "${file.name}" 超过 50MB 限制`)
      continue
    }

    try {
      const base64 = await readFileAsBase64(file)
      const mimeType = file.type || 'application/octet-stream'
      const mediaType = getMediaType(mimeType)

      const attachment: Attachment = {
        path: file.name,
        media_type: mediaType,
        mime_type: mimeType,
        content_base64: base64,
        file_name: file.name,
      }

      const fileAttachment: FileAttachment = {
        file,
        attachment,
        preview: mediaType === 'image' ? base64 : undefined,
      }

      setAttachments(prev => [...prev, fileAttachment])
    } catch {
      // A single unreadable file must not reject the handler or drop the
      // remaining pasted files; report it and keep going.
      setError(`读取文件 "${file.name}" 失败`)
    }
  }
}
// 拖拽事件 // 拖拽事件
const handleDragEnter = (e: React.DragEvent) => { const handleDragEnter = (e: React.DragEvent) => {
e.preventDefault() e.preventDefault()
@ -134,7 +187,12 @@ export function MessageInput({
const handleDragLeave = (e: React.DragEvent) => { const handleDragLeave = (e: React.DragEvent) => {
e.preventDefault() e.preventDefault()
e.stopPropagation() e.stopPropagation()
setIsDragging(false) // 检查是否真的离开了拖拽区域(而不是进入子元素)
const relatedTarget = e.relatedTarget as Node | null
const currentTarget = e.currentTarget
if (!relatedTarget || !currentTarget.contains(relatedTarget)) {
setIsDragging(false)
}
} }
const handleDragOver = (e: React.DragEvent) => { const handleDragOver = (e: React.DragEvent) => {
@ -299,6 +357,7 @@ export function MessageInput({
value={content} value={content}
onChange={(e) => setContent(e.target.value)} onChange={(e) => setContent(e.target.value)}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
onPaste={handlePaste}
placeholder={placeholder} placeholder={placeholder}
disabled={disabled} disabled={disabled}
rows={1} rows={1}
@ -323,7 +382,7 @@ export function MessageInput({
{/* 提示 */} {/* 提示 */}
<div className="mt-2 text-center text-xs text-zinc-500"> <div className="mt-2 text-center text-xs text-zinc-500">
Enter Shift+Enter · · 50MB Enter Shift+Enter · / · 50MB
</div> </div>
</div> </div>
</div> </div>

View File

@ -14,6 +14,7 @@ import type {
TopicSummary, TopicSummary,
Session, Session,
TaskMessagesLoaded, TaskMessagesLoaded,
Attachment,
} from '../types/protocol' } from '../types/protocol'
// 简化后的层级状态 // 简化后的层级状态
@ -40,7 +41,7 @@ interface UseChatReturn {
subAgentView: SubAgentView | null subAgentView: SubAgentView | null
// 方法 // 方法
handleMessage: (content: string) => void handleMessage: (content: string, attachments?: Attachment[]) => void
handleCommand: (command: Command) => void handleCommand: (command: Command) => void
clearMessages: () => void clearMessages: () => void
handleServerMessage: (message: WsOutbound) => void handleServerMessage: (message: WsOutbound) => void
@ -388,7 +389,7 @@ export function useChat(): UseChatReturn {
} }
}, []) }, [])
const handleMessage = useCallback((content: string) => { const handleMessage = useCallback((content: string, attachments?: Attachment[]) => {
setMessages((prev) => [ setMessages((prev) => [
...prev, ...prev,
{ {
@ -397,6 +398,7 @@ export function useChat(): UseChatReturn {
content, content,
timestamp: Date.now(), timestamp: Date.now(),
type: 'message', type: 'message',
attachments: attachments || [],
}, },
]) ])
setIsLoading(true) setIsLoading(true)