feat: WebSocket 媒体文件处理优化
- 后端 ws.rs: 处理前端上传的 base64 内容,保存到本地文件并更新路径
- 后端 ws.rs: 历史消息加载时从文件读取内容填充 base64,过滤 media_refs_json
- 前端 App.tsx: 传递 attachments 给 handleMessage 实现实时显示
- 前端 useChat.ts: handleMessage 支持 attachments 参数
- 前端 MessageInput.tsx: 支持剪贴板粘贴文件/图片
- 前端 MessageInput.tsx: 修复拖拽文件时闪烁问题
- 测试 test_request_format.rs: 补充缺失的 attachments 字段
This commit is contained in:
parent
c2293238fc
commit
7d9355fd78
@ -208,6 +208,8 @@ impl AgentExecutionService {
|
||||
}
|
||||
let enriched_content =
|
||||
enrich_user_content_with_media_refs(request.content, &media_refs)?;
|
||||
enrich_user_content_with_media_refs(request.content, &media_refs)?;
|
||||
enrich_user_content_with_media_refs(request.content, &media_refs)?;
|
||||
|
||||
// 先计算 user_message_count(在添加新消息之前)
|
||||
let history_before = session_guard.get_or_create_history(request.chat_id).clone();
|
||||
|
||||
@ -17,18 +17,100 @@ use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||||
use crate::command::handlers::session::SessionCommandHandler;
|
||||
use crate::command::handlers::switch_topic::SwitchTopicCommandHandler;
|
||||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||||
use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||
use crate::protocol::{WsInbound, WsOutbound, MediaSummary, parse_inbound, serialize_outbound};
|
||||
use crate::skills::SkillPromptProvider;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||
use axum::response::Response;
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
const CLI_CHANNEL_NAME: &str = "cli";
|
||||
|
||||
/// Default media directory for WebSocket uploads
|
||||
fn default_ws_media_dir() -> PathBuf {
|
||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
home.join(".picobot").join("media").join("ws")
|
||||
}
|
||||
|
||||
/// Build a unique filename for media upload
|
||||
fn build_media_filename(media_type: &str, file_name: Option<&str>) -> String {
|
||||
if let Some(file_name) = file_name {
|
||||
let sanitized: String = file_name
|
||||
.chars()
|
||||
.map(|ch| match ch {
|
||||
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
|
||||
_ => ch,
|
||||
})
|
||||
.collect();
|
||||
if !sanitized.trim().is_empty() {
|
||||
return format!("{}_{}", uuid::Uuid::new_v4(), sanitized);
|
||||
}
|
||||
}
|
||||
format!("{}_{}", media_type, uuid::Uuid::new_v4())
|
||||
}
|
||||
|
||||
/// Process attachments with base64 content: save to local file and return MediaItem with correct path
|
||||
/// Keeps content_base64 for frontend display/download
|
||||
fn process_attachments_with_base64(attachments: Vec<MediaSummary>) -> Result<Vec<MediaItem>, AgentError> {
|
||||
if attachments.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let media_dir = default_ws_media_dir();
|
||||
std::fs::create_dir_all(&media_dir)
|
||||
.map_err(|error| AgentError::Other(format!("Failed to create media dir: {}", error)))?;
|
||||
|
||||
attachments
|
||||
.into_iter()
|
||||
.map(|att| {
|
||||
// If content_base64 exists, save to file and update path
|
||||
if let Some(base64_content) = &att.content_base64 {
|
||||
let decoded = STANDARD
|
||||
.decode(base64_content)
|
||||
.map_err(|error| AgentError::Other(format!("Failed to decode base64: {}", error)))?;
|
||||
|
||||
let filename = build_media_filename(&att.media_type, att.file_name.as_deref());
|
||||
let file_path = media_dir.join(&filename);
|
||||
|
||||
std::fs::write(&file_path, decoded)
|
||||
.map_err(|error| AgentError::Other(format!("Failed to write media file: {}", error)))?;
|
||||
|
||||
tracing::info!(
|
||||
filename = %filename,
|
||||
media_type = %att.media_type,
|
||||
file_path = %file_path.to_string_lossy(),
|
||||
"Saved WebSocket media to local file"
|
||||
);
|
||||
|
||||
Ok(MediaItem {
|
||||
path: file_path.to_string_lossy().to_string(),
|
||||
media_type: att.media_type,
|
||||
mime_type: att.mime_type,
|
||||
original_key: None,
|
||||
// Keep content_base64 for frontend display/download
|
||||
content_base64: att.content_base64,
|
||||
file_name: att.file_name,
|
||||
})
|
||||
} else {
|
||||
// No base64 content, keep original path (should already be valid)
|
||||
Ok(MediaItem {
|
||||
path: att.path,
|
||||
media_type: att.media_type,
|
||||
mime_type: att.mime_type,
|
||||
original_key: None,
|
||||
content_base64: None,
|
||||
file_name: att.file_name,
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||
ws.on_upgrade(|socket| async {
|
||||
handle_socket(socket, state).await;
|
||||
@ -214,18 +296,8 @@ async fn handle_inbound(
|
||||
)
|
||||
.await;
|
||||
|
||||
// 将协议层 attachments 转换为内部 MediaItem
|
||||
let media: Vec<MediaItem> = attachments
|
||||
.iter()
|
||||
.map(|a| MediaItem {
|
||||
path: a.path.clone(),
|
||||
media_type: a.media_type.clone(),
|
||||
mime_type: a.mime_type.clone(),
|
||||
original_key: None,
|
||||
content_base64: a.content_base64.clone(),
|
||||
file_name: a.file_name.clone(),
|
||||
})
|
||||
.collect();
|
||||
// Process attachments: save base64 content to local files and build MediaItems with correct paths
|
||||
let media = process_attachments_with_base64(attachments)?;
|
||||
|
||||
state
|
||||
.bus
|
||||
@ -504,6 +576,57 @@ async fn send_task_messages(
|
||||
fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option<WsOutbound> {
|
||||
use crate::bus::message::ToolMessageState;
|
||||
|
||||
// Helper: return `content` without its `media_refs_json` section.
// Everything from the first "\n\nmedia_refs_json:" marker onwards is dropped;
// content without the marker is returned unchanged.
fn strip_media_refs_json(content: &str) -> String {
    const MARKER: &str = "\n\nmedia_refs_json:";
    content
        .split_once(MARKER)
        .map_or(content, |(visible, _refs)| visible)
        .to_string()
}
|
||||
|
||||
// Build attachments from media_refs, reading file content for base64
|
||||
let attachments: Vec<MediaSummary> = msg
|
||||
.media_refs
|
||||
.iter()
|
||||
.filter_map(|path| {
|
||||
// Try to read file and encode as base64
|
||||
let file_content = std::fs::read(path).ok()?;
|
||||
let base64_content = STANDARD.encode(&file_content);
|
||||
|
||||
// Guess mime type from path
|
||||
let mime_type = mime_guess::from_path(path)
|
||||
.first_raw()
|
||||
.map(ToOwned::to_owned);
|
||||
|
||||
// Determine media type from mime type
|
||||
let media_type = mime_type
|
||||
.as_ref()
|
||||
.map(|m| {
|
||||
if m.starts_with("image/") { "image" }
|
||||
else if m.starts_with("audio/") { "audio" }
|
||||
else if m.starts_with("video/") { "video" }
|
||||
else { "file" }
|
||||
})
|
||||
.unwrap_or("file");
|
||||
|
||||
// Get file name from path
|
||||
let file_name = std::path::Path::new(path)
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.map(ToOwned::to_owned);
|
||||
|
||||
Some(MediaSummary {
|
||||
path: path.clone(),
|
||||
media_type: media_type.to_string(),
|
||||
mime_type,
|
||||
content_base64: Some(base64_content),
|
||||
file_name,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
match msg.role.as_str() {
|
||||
"assistant" => {
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
@ -553,9 +676,9 @@ fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option<WsOutbou
|
||||
}
|
||||
"user" => Some(WsOutbound::AssistantResponse {
|
||||
id: msg.id.clone(),
|
||||
content: msg.content.clone(),
|
||||
content: strip_media_refs_json(&msg.content),
|
||||
role: msg.role.clone(),
|
||||
attachments: Vec::new(),
|
||||
attachments,
|
||||
subagent_task_id: None,
|
||||
}),
|
||||
_ => None,
|
||||
@ -564,7 +687,9 @@ fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option<WsOutbou
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::resolve_ws_sender_id;
|
||||
use super::{resolve_ws_sender_id, build_media_filename, process_attachments_with_base64};
|
||||
use crate::protocol::MediaSummary;
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
|
||||
#[test]
|
||||
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
|
||||
@ -583,4 +708,72 @@ mod tests {
|
||||
assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
|
||||
assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
|
||||
}
|
||||
|
||||
#[test]
fn test_build_media_filename_preserves_original_name() {
    let original = "photo.png";
    let filename = build_media_filename("image", Some(original));

    // The original name must survive as the suffix.
    assert!(filename.ends_with("_photo.png"));
    // A hyphenated UUID is 36 chars, followed by '_' and the original name.
    assert!(filename.len() >= 36 + 1 + original.len());
}
|
||||
|
||||
#[test]
fn test_build_media_filename_generates_default_when_no_name() {
    // Without a file name the media type becomes the prefix.
    let generated = build_media_filename("image", None);
    assert!(generated.starts_with("image_"));
}
|
||||
|
||||
#[test]
fn test_process_attachments_with_base64_saves_to_file() {
    let payload: &[u8] = b"test image content";
    let encoded = STANDARD.encode(payload);

    let upload = MediaSummary {
        path: "test_image.png".to_string(),
        media_type: "image".to_string(),
        mime_type: Some("image/png".to_string()),
        content_base64: Some(encoded.clone()),
        file_name: Some("test_image.png".to_string()),
    };

    let result = process_attachments_with_base64(vec![upload]).unwrap();
    let saved = &result[0];

    // The bare upload name must have been replaced by a path inside the
    // WebSocket media directory.
    for segment in [".picobot", "media", "ws"] {
        assert!(saved.path.contains(segment));
    }

    // The base64 payload is kept for frontend display.
    assert_eq!(saved.content_base64.as_deref(), Some(encoded.as_str()));

    // The decoded bytes were actually written to disk.
    let file_content = std::fs::read(&saved.path).unwrap();
    assert_eq!(file_content, payload);

    // Best-effort cleanup of the file this test created.
    std::fs::remove_file(&saved.path).ok();
}
|
||||
|
||||
#[test]
fn test_process_attachments_without_base64_keeps_original_path() {
    let existing_path = "/existing/path/image.jpg";
    let reference = MediaSummary {
        path: existing_path.to_string(),
        media_type: "image".to_string(),
        mime_type: Some("image/jpeg".to_string()),
        content_base64: None,
        file_name: Some("image.jpg".to_string()),
    };

    let result = process_attachments_with_base64(vec![reference]).unwrap();

    // With no inline payload the path is passed through untouched.
    assert_eq!(result[0].path, existing_path);
}
|
||||
|
||||
#[test]
fn test_process_empty_attachments_returns_empty_vec() {
    // An empty input yields an empty output.
    let attachments: Vec<MediaSummary> = Vec::new();
    let result = process_attachments_with_base64(attachments).unwrap();
    assert!(result.is_empty());
}
|
||||
}
|
||||
|
||||
@ -70,6 +70,7 @@ fn test_command_inbound_serialization() {
|
||||
fn test_message_inbound_serialization() {
|
||||
let msg = WsInbound::Message {
|
||||
content: "Hello world".to_string(),
|
||||
attachments: Vec::new(),
|
||||
channel: None,
|
||||
chat_id: Some("session-1".to_string()),
|
||||
sender_id: Some("user-1".to_string()),
|
||||
|
||||
@ -126,7 +126,7 @@ function App() {
|
||||
handleCommand(cmd)
|
||||
sendMessage({ type: 'command', payload: JSON.stringify(cmd) })
|
||||
} else {
|
||||
handleMessage(content)
|
||||
handleMessage(content, attachments)
|
||||
sendMessage({
|
||||
type: 'message',
|
||||
content,
|
||||
|
||||
@ -122,6 +122,59 @@ export function MessageInput({
|
||||
setAttachments(prev => prev.filter((_, i) => i !== index))
|
||||
}
|
||||
|
||||
// Paste handler: turns files/images found on the clipboard into attachments.
// When the clipboard holds no files, the browser's default text paste runs.
const handlePaste = async (e: React.ClipboardEvent) => {
  if (disabled || isReadOnly) return

  // Collect every clipboard entry that is a file (images or other files).
  const files: File[] = []
  for (const item of Array.from(e.clipboardData.items)) {
    if (item.kind === 'file') {
      const file = item.getAsFile()
      if (file) {
        files.push(file)
      }
    }
  }

  // No files: let the default text paste continue.
  if (files.length === 0) return

  // Files present: handle them here and suppress the default paste.
  e.preventDefault()
  setError(null)

  for (const file of files) {
    if (file.size > MAX_FILE_SIZE) {
      setError(`文件 "${file.name}" 超过 50MB 限制`)
      continue
    }

    try {
      const base64 = await readFileAsBase64(file)
      const mimeType = file.type || 'application/octet-stream'
      const mediaType = getMediaType(mimeType)

      const attachment: Attachment = {
        path: file.name,
        media_type: mediaType,
        mime_type: mimeType,
        content_base64: base64,
        file_name: file.name,
      }

      const fileAttachment: FileAttachment = {
        file,
        attachment,
        // Only images get an inline preview.
        preview: mediaType === 'image' ? base64 : undefined,
      }

      setAttachments(prev => [...prev, fileAttachment])
    } catch (err) {
      // A failed read must not surface as an unhandled rejection, and must
      // not stop the remaining pasted files from being processed.
      console.error('Failed to read pasted file', err)
      setError(`文件 "${file.name}" 读取失败`)
    }
  }
}
|
||||
|
||||
// 拖拽事件
|
||||
const handleDragEnter = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
@ -134,8 +187,13 @@ export function MessageInput({
|
||||
const handleDragLeave = (e: React.DragEvent) => {
  e.preventDefault()
  e.stopPropagation()

  // dragleave also fires when the pointer moves onto a child element.
  // Only clear the dragging state when the pointer has really left the
  // drop zone, i.e. the node being entered is not inside this element.
  const enteredNode = e.relatedTarget as Node | null
  const stillInside = !!enteredNode && e.currentTarget.contains(enteredNode)
  if (!stillInside) {
    setIsDragging(false)
  }
}
|
||||
|
||||
const handleDragOver = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
@ -299,6 +357,7 @@ export function MessageInput({
|
||||
value={content}
|
||||
onChange={(e) => setContent(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
onPaste={handlePaste}
|
||||
placeholder={placeholder}
|
||||
disabled={disabled}
|
||||
rows={1}
|
||||
@ -323,7 +382,7 @@ export function MessageInput({
|
||||
|
||||
{/* 提示 */}
|
||||
<div className="mt-2 text-center text-xs text-zinc-500">
|
||||
按 Enter 发送,Shift+Enter 换行 · 支持拖拽文件 · 最大 50MB
|
||||
按 Enter 发送,Shift+Enter 换行 · 支持拖拽/粘贴文件 · 最大 50MB
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -14,6 +14,7 @@ import type {
|
||||
TopicSummary,
|
||||
Session,
|
||||
TaskMessagesLoaded,
|
||||
Attachment,
|
||||
} from '../types/protocol'
|
||||
|
||||
// 简化后的层级状态
|
||||
@ -40,7 +41,7 @@ interface UseChatReturn {
|
||||
subAgentView: SubAgentView | null
|
||||
|
||||
// 方法
|
||||
handleMessage: (content: string) => void
|
||||
handleMessage: (content: string, attachments?: Attachment[]) => void
|
||||
handleCommand: (command: Command) => void
|
||||
clearMessages: () => void
|
||||
handleServerMessage: (message: WsOutbound) => void
|
||||
@ -388,7 +389,7 @@ export function useChat(): UseChatReturn {
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleMessage = useCallback((content: string) => {
|
||||
const handleMessage = useCallback((content: string, attachments?: Attachment[]) => {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
@ -397,6 +398,7 @@ export function useChat(): UseChatReturn {
|
||||
content,
|
||||
timestamp: Date.now(),
|
||||
type: 'message',
|
||||
attachments: attachments || [],
|
||||
},
|
||||
])
|
||||
setIsLoading(true)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user