diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index e92dcf2..7597b75 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -208,6 +208,8 @@ impl AgentExecutionService { } let enriched_content = enrich_user_content_with_media_refs(request.content, &media_refs)?; + enrich_user_content_with_media_refs(request.content, &media_refs)?; + enrich_user_content_with_media_refs(request.content, &media_refs)?; // 先计算 user_message_count(在添加新消息之前) let history_before = session_guard.get_or_create_history(request.chat_id).clone(); diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 5a1520f..b1f7d1b 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -17,18 +17,100 @@ use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; use crate::command::handlers::switch_topic::SwitchTopicCommandHandler; use crate::gateway::agent_prompt_provider::AgentPromptProvider; -use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound}; +use crate::protocol::{WsInbound, WsOutbound, MediaSummary, parse_inbound, serialize_outbound}; use crate::skills::SkillPromptProvider; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; +use base64::{Engine as _, engine::general_purpose::STANDARD}; use futures_util::{SinkExt, StreamExt}; use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; use tokio::sync::mpsc; const CLI_CHANNEL_NAME: &str = "cli"; +/// Default media directory for WebSocket uploads +fn default_ws_media_dir() -> PathBuf { + let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); + home.join(".picobot").join("media").join("ws") +} + +/// Build a unique filename for media upload +fn build_media_filename(media_type: &str, file_name: Option<&str>) -> String { + if let Some(file_name) = file_name { + let sanitized: String = file_name + .chars() + .map(|ch| match ch { + '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_', + _ => ch, + }) + .collect(); + if !sanitized.trim().is_empty() { + return format!("{}_{}", uuid::Uuid::new_v4(), sanitized); + } + } + format!("{}_{}", media_type, uuid::Uuid::new_v4()) +} + +/// Process attachments with base64 content: save to local file and return MediaItem with correct path +/// Keeps content_base64 for frontend display/download +fn process_attachments_with_base64(attachments: Vec) -> Result, AgentError> { + if attachments.is_empty() { + return Ok(Vec::new()); + } + + let media_dir = default_ws_media_dir(); + std::fs::create_dir_all(&media_dir) + .map_err(|error| AgentError::Other(format!("Failed to create media dir: {}", error)))?; + + attachments + .into_iter() + .map(|att| { + // If content_base64 exists, save to file and update path + if let Some(base64_content) = &att.content_base64 { + let decoded = STANDARD + .decode(base64_content) + .map_err(|error| AgentError::Other(format!("Failed to decode base64: {}", error)))?; + + let filename = build_media_filename(&att.media_type, att.file_name.as_deref()); + let file_path = media_dir.join(&filename); + + std::fs::write(&file_path, decoded) + .map_err(|error| AgentError::Other(format!("Failed to write media file: {}", error)))?; + + tracing::info!( + filename = %filename, + media_type = %att.media_type, + file_path = %file_path.to_string_lossy(), + "Saved WebSocket media to local file" + ); + + Ok(MediaItem { + path: file_path.to_string_lossy().to_string(), + media_type: att.media_type, + mime_type: att.mime_type, + original_key: None, + // Keep content_base64 for frontend display/download + content_base64: att.content_base64, + file_name: att.file_name, + }) + } else { + // No base64 content, keep original path (should already be valid) + Ok(MediaItem { + path: att.path, + media_type: att.media_type, + mime_type: att.mime_type, + original_key: None, + content_base64: None, + file_name: att.file_name, + }) + } + }) + .collect() +} + pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async { handle_socket(socket, state).await; @@ -214,18 +296,8 @@ async fn handle_inbound( ) .await; - // 将协议层 attachments 转换为内部 MediaItem - let media: Vec = attachments - .iter() - .map(|a| MediaItem { - path: a.path.clone(), - media_type: a.media_type.clone(), - mime_type: a.mime_type.clone(), - original_key: None, - content_base64: a.content_base64.clone(), - file_name: a.file_name.clone(), - }) - .collect(); + // Process attachments: save base64 content to local files and build MediaItems with correct paths + let media = process_attachments_with_base64(attachments)?; state .bus @@ -504,6 +576,57 @@ async fn send_task_messages( fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option { use crate::bus::message::ToolMessageState; + // Helper function to strip media_refs_json from content + fn strip_media_refs_json(content: &str) -> String { + // Remove the media_refs_json suffix if present + if let Some(pos) = content.find("\n\nmedia_refs_json:") { + content[..pos].to_string() + } else { + content.to_string() + } + } + + // Build attachments from media_refs, reading file content for base64 + let attachments: Vec = msg + .media_refs + .iter() + .filter_map(|path| { + // Try to read file and encode as base64 + let file_content = std::fs::read(path).ok()?; + let base64_content = STANDARD.encode(&file_content); + + // Guess mime type from path + let mime_type = mime_guess::from_path(path) + .first_raw() + .map(ToOwned::to_owned); + + // Determine media type from mime type + let media_type = mime_type + .as_ref() + .map(|m| { + if m.starts_with("image/") { "image" } + else if m.starts_with("audio/") { "audio" } + else if m.starts_with("video/") { "video" } + else { "file" } + }) + .unwrap_or("file"); + + // Get file name from path + let file_name = std::path::Path::new(path) + .file_name() + .and_then(|name| name.to_str()) + .map(ToOwned::to_owned); + + Some(MediaSummary { + path: path.clone(), + media_type: media_type.to_string(), + mime_type, + content_base64: Some(base64_content), + file_name, + }) + }) + .collect(); + match msg.role.as_str() { "assistant" => { if let Some(tool_calls) = &msg.tool_calls { @@ -553,9 +676,9 @@ fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option Some(WsOutbound::AssistantResponse { id: msg.id.clone(), - content: msg.content.clone(), + content: strip_media_refs_json(&msg.content), role: msg.role.clone(), - attachments: Vec::new(), + attachments, subagent_task_id: None, }), _ => None, @@ -564,7 +687,9 @@ fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Option= 36 + 1 + "photo.png".len()); + } + + #[test] + fn test_build_media_filename_generates_default_when_no_name() { + let filename = build_media_filename("image", None); + assert!(filename.starts_with("image_")); + } + + #[test] + fn test_process_attachments_with_base64_saves_to_file() { + let test_content = "test image content"; + let base64_content = STANDARD.encode(test_content.as_bytes()); + + let attachments = vec![MediaSummary { + path: "test_image.png".to_string(), + media_type: "image".to_string(), + mime_type: Some("image/png".to_string()), + content_base64: Some(base64_content.clone()), + file_name: Some("test_image.png".to_string()), + }]; + + let result = process_attachments_with_base64(attachments).unwrap(); + + // Verify path is now a full path + assert!(result[0].path.contains(".picobot")); + assert!(result[0].path.contains("media")); + assert!(result[0].path.contains("ws")); + + // Verify content_base64 is kept for frontend display + assert!(result[0].content_base64.is_some()); + assert_eq!(result[0].content_base64.as_ref().unwrap(), &base64_content); + + // Verify file was actually written + let file_content = std::fs::read(&result[0].path).unwrap(); + assert_eq!(file_content, test_content.as_bytes()); + + // Cleanup + std::fs::remove_file(&result[0].path).ok(); + } + + #[test] + fn test_process_attachments_without_base64_keeps_original_path() { + let attachments = vec![MediaSummary { + path: "/existing/path/image.jpg".to_string(), + media_type: "image".to_string(), + mime_type: Some("image/jpeg".to_string()), + content_base64: None, + file_name: Some("image.jpg".to_string()), + }]; + + let result = process_attachments_with_base64(attachments).unwrap(); + + // Path should remain unchanged + assert_eq!(result[0].path, "/existing/path/image.jpg"); + } + + #[test] + fn test_process_empty_attachments_returns_empty_vec() { + let result = process_attachments_with_base64(Vec::new()).unwrap(); + assert!(result.is_empty()); + } } diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index 0311948..e86d6c5 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -70,6 +70,7 @@ fn test_command_inbound_serialization() { fn test_message_inbound_serialization() { let msg = WsInbound::Message { content: "Hello world".to_string(), + attachments: Vec::new(), channel: None, chat_id: Some("session-1".to_string()), sender_id: Some("user-1".to_string()), diff --git a/web/src/App.tsx b/web/src/App.tsx index caf9c71..51de7cf 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -126,7 +126,7 @@ function App() { handleCommand(cmd) sendMessage({ type: 'command', payload: JSON.stringify(cmd) }) } else { - handleMessage(content) + handleMessage(content, attachments) sendMessage({ type: 'message', content, diff --git a/web/src/components/Chat/MessageInput.tsx b/web/src/components/Chat/MessageInput.tsx index 568f5d7..104b722 100644 --- a/web/src/components/Chat/MessageInput.tsx +++ b/web/src/components/Chat/MessageInput.tsx @@ -122,6 +122,59 @@ export function MessageInput({ setAttachments(prev => prev.filter((_, i) => i !== index)) } + // 粘贴事件处理 + const handlePaste = async (e: React.ClipboardEvent) => { + if (disabled || isReadOnly) return + + const clipboardData = e.clipboardData + const items = clipboardData.items + + // 检查是否有文件(图片或其他文件) + const files: File[] = [] + for (const item of Array.from(items)) { + if (item.kind === 'file') { + const file = item.getAsFile() + if (file) { + files.push(file) + } + } + } + + // 如果有文件,处理文件并阻止默认粘贴行为 + if (files.length > 0) { + e.preventDefault() + // 直接处理文件数组 + setError(null) + for (const file of files) { + if (file.size > MAX_FILE_SIZE) { + setError(`文件 "${file.name}" 超过 50MB 限制`) + continue + } + + const base64 = await readFileAsBase64(file) + const mimeType = file.type || 'application/octet-stream' + const mediaType = getMediaType(mimeType) + + const attachment: Attachment = { + path: file.name, + media_type: mediaType, + mime_type: mimeType, + content_base64: base64, + file_name: file.name, + } + + const fileAttachment: FileAttachment = { + file, + attachment, + preview: mediaType === 'image' ? base64 : undefined, + } + + setAttachments(prev => [...prev, fileAttachment]) + } + } + // 否则让默认的文本粘贴行为继续 + } + // 拖拽事件 const handleDragEnter = (e: React.DragEvent) => { e.preventDefault() @@ -134,7 +187,12 @@ export function MessageInput({ const handleDragLeave = (e: React.DragEvent) => { e.preventDefault() e.stopPropagation() - setIsDragging(false) + // 检查是否真的离开了拖拽区域(而不是进入子元素) + const relatedTarget = e.relatedTarget as Node | null + const currentTarget = e.currentTarget + if (!relatedTarget || !currentTarget.contains(relatedTarget)) { + setIsDragging(false) + } } const handleDragOver = (e: React.DragEvent) => { @@ -299,6 +357,7 @@ export function MessageInput({ value={content} onChange={(e) => setContent(e.target.value)} onKeyDown={handleKeyDown} + onPaste={handlePaste} placeholder={placeholder} disabled={disabled} rows={1} @@ -323,7 +382,7 @@ export function MessageInput({ {/* 提示 */}
- 按 Enter 发送,Shift+Enter 换行 · 支持拖拽文件 · 最大 50MB + 按 Enter 发送,Shift+Enter 换行 · 支持拖拽/粘贴文件 · 最大 50MB
diff --git a/web/src/hooks/useChat.ts b/web/src/hooks/useChat.ts index 84fee6e..b220a91 100644 --- a/web/src/hooks/useChat.ts +++ b/web/src/hooks/useChat.ts @@ -14,6 +14,7 @@ import type { TopicSummary, Session, TaskMessagesLoaded, + Attachment, } from '../types/protocol' // 简化后的层级状态 @@ -40,7 +41,7 @@ interface UseChatReturn { subAgentView: SubAgentView | null // 方法 - handleMessage: (content: string) => void + handleMessage: (content: string, attachments?: Attachment[]) => void handleCommand: (command: Command) => void clearMessages: () => void handleServerMessage: (message: WsOutbound) => void @@ -388,7 +389,7 @@ export function useChat(): UseChatReturn { } }, []) - const handleMessage = useCallback((content: string) => { + const handleMessage = useCallback((content: string, attachments?: Attachment[]) => { setMessages((prev) => [ ...prev, { @@ -397,6 +398,7 @@ export function useChat(): UseChatReturn { content, timestamp: Date.now(), type: 'message', + attachments: attachments || [], }, ]) setIsLoading(true)