use super::GatewayState; use crate::agent::{AgentError, CompositeSystemPromptProvider}; use crate::bus::{InboundMessage, MediaItem}; use crate::command::adapter::{InputAdapter, OutputAdapter}; use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter}; use crate::command::context::CommandContext; use crate::command::handler::CommandRouter; use crate::command::handlers::delete_topic::DeleteTopicCommandHandler; use crate::command::handlers::get_current::GetCurrentSessionCommandHandler; use crate::command::handlers::help::HelpCommandHandler; use crate::command::handlers::list_channels::ListChannelsCommandHandler; use crate::command::handlers::list_memories::ListMemoriesCommandHandler; use crate::command::handlers::list_skills::ListSkillsCommandHandler; use crate::command::handlers::list_scheduler_jobs::ListSchedulerJobsCommandHandler; use crate::command::handlers::memory_crud::MemoryCrudCommandHandler; use crate::command::handlers::list_sessions::ListSessionsCommandHandler; use crate::command::handlers::list_sessions_by_channel::ListSessionsByChannelCommandHandler; use crate::command::handlers::list_topics::ListTopicsCommandHandler; use crate::command::handlers::load_chat_messages::LoadChatMessagesCommandHandler; use crate::command::handlers::load_task_messages::LoadTaskMessagesCommandHandler; use crate::command::handlers::load_topic::LoadTopicCommandHandler; use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; use crate::command::handlers::stop_execution::StopExecutionCommandHandler; use crate::command::handlers::switch_topic::SwitchTopicCommandHandler; use crate::gateway::agent_prompt_provider::AgentPromptProvider; use crate::protocol::{WsInbound, WsOutbound, MediaSummary, parse_inbound, serialize_outbound}; use crate::skills::SkillPromptProvider; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; use base64::{Engine as _, engine::general_purpose::STANDARD}; use futures_util::{SinkExt, StreamExt}; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::mpsc; const CLI_CHANNEL_NAME: &str = "cli"; /// Default media directory for WebSocket uploads fn default_ws_media_dir() -> PathBuf { let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); home.join(".picobot").join("media").join("ws") } /// Build a unique filename for media upload fn build_media_filename(media_type: &str, file_name: Option<&str>) -> String { if let Some(file_name) = file_name { let sanitized: String = file_name .chars() .map(|ch| match ch { '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_', _ => ch, }) .collect(); if !sanitized.trim().is_empty() { return format!("{}_{}", uuid::Uuid::new_v4(), sanitized); } } format!("{}_{}", media_type, uuid::Uuid::new_v4()) } /// Process attachments with base64 content: save to local file and return MediaItem with correct path /// Keeps content_base64 for frontend display/download fn process_attachments_with_base64(attachments: Vec) -> Result, AgentError> { if attachments.is_empty() { return Ok(Vec::new()); } let media_dir = default_ws_media_dir(); std::fs::create_dir_all(&media_dir) .map_err(|error| AgentError::Other(format!("Failed to create media dir: {}", error)))?; attachments .into_iter() .map(|att| { // If content_base64 exists, save to file and update path if let Some(base64_content) = &att.content_base64 { let decoded = STANDARD .decode(base64_content) .map_err(|error| AgentError::Other(format!("Failed to decode base64: {}", error)))?; let filename = build_media_filename(&att.media_type, att.file_name.as_deref()); let file_path = media_dir.join(&filename); std::fs::write(&file_path, decoded) .map_err(|error| AgentError::Other(format!("Failed to write media file: {}", error)))?; tracing::info!( filename = %filename, media_type = %att.media_type, file_path = %file_path.to_string_lossy(), "Saved WebSocket media to local file" ); Ok(MediaItem { path: file_path.to_string_lossy().to_string(), media_type: att.media_type, mime_type: att.mime_type, original_key: None, // Keep content_base64 for frontend display/download content_base64: att.content_base64, file_name: att.file_name, }) } else { // No base64 content, keep original path (should already be valid) Ok(MediaItem { path: att.path, media_type: att.media_type, mime_type: att.mime_type, original_key: None, content_base64: None, file_name: att.file_name, }) } }) .collect() } pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async { handle_socket(socket, state).await; }) } async fn handle_socket(ws: WebSocket, state: Arc) { let (sender, receiver) = mpsc::channel::(100); let cli_sessions = state.session_manager.cli_sessions(); let store = state.session_manager.store(); // 1. 先查询 websocket 通道的 Sessions let websocket_sessions = store.list_sessions("websocket", false) .unwrap_or_default(); // 2. 如果没有,自动创建一个默认 Session let initial_record = if websocket_sessions.is_empty() { match cli_sessions.create_with_channel("websocket", Some("默认会话")) { Ok(record) => record, Err(e) => { tracing::error!(error = %e, "Failed to create initial WebSocket session"); return; } } } else { // 使用最新的 Session websocket_sessions[0].clone() }; let runtime_session_id = uuid::Uuid::new_v4().to_string(); let mut current_session_id = initial_record.id.clone(); let mut current_topic_id: Option = None; state .channel_manager .cli_channel() .register_connection( current_session_id.clone(), runtime_session_id.clone(), sender.clone(), ) .await; tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "WebSocket session established"); let _ = sender .send(WsOutbound::SessionEstablished { session_id: current_session_id.clone(), }) .await; // 连接建立后立即发送通道列表(合并 websocket + ChannelManager 动态通道) let channels = state.channel_manager.build_channel_list().await; let _ = sender .send(WsOutbound::ChannelList { channels }) .await; // 3. 重新查询 websocket 通道的 Session 列表(包含刚创建的) let final_sessions = store.list_sessions("websocket", false) .unwrap_or_default(); tracing::info!("Sending {} websocket sessions to client", final_sessions.len()); for s in &final_sessions { tracing::info!(" - {}: {} (channel: {})", s.id, s.title, s.channel_name); } let session_summaries: Vec = final_sessions .into_iter() .map(|s| crate::protocol::SessionSummary { session_id: s.id, title: s.title, channel_name: s.channel_name, chat_id: s.chat_id, message_count: s.message_count, last_active_at: s.last_active_at, archived_at: s.archived_at, }) .collect(); let _ = sender .send(WsOutbound::SessionList { sessions: session_summaries, current_session_id: Some(current_session_id.clone()), channel_name: Some("websocket".to_string()), }) .await; let (mut ws_sender, mut ws_receiver) = ws.split(); let mut receiver = receiver; let session_id_for_sender = runtime_session_id.clone(); tokio::spawn(async move { while let Some(msg) = receiver.recv().await { if let Ok(text) = serialize_outbound(&msg) { if ws_sender.send(WsMessage::Text(text.into())).await.is_err() { #[cfg(debug_assertions)] tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error"); break; } } } }); while let Some(msg) = ws_receiver.next().await { match msg { Ok(WsMessage::Text(text)) => { let text = text.to_string(); match parse_inbound(&text) { Ok(inbound) => { if let Err(e) = handle_inbound( &state, &sender, &runtime_session_id, &mut current_session_id, &mut current_topic_id, inbound, ) .await { tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message"); let _ = sender .send(WsOutbound::Error { timestamp: Some(crate::protocol::now_timestamp()), code:"SESSION_ERROR".to_string(), message: e.to_string(), }) .await; } } Err(e) => { tracing::warn!(error = %e, "Failed to parse inbound message"); let _ = sender .send(WsOutbound::Error { timestamp: Some(crate::protocol::now_timestamp()), code:"PARSE_ERROR".to_string(), message: e.to_string(), }) .await; } } } Ok(WsMessage::Close(_)) | Err(_) => { #[cfg(debug_assertions)] tracing::debug!(session_id = %runtime_session_id, "WebSocket closed"); break; } _ => {} } } state .channel_manager .cli_channel() .unregister_connection(&runtime_session_id) .await; tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); } async fn handle_inbound( state: &Arc, sender: &mpsc::Sender, runtime_session_id: &str, current_session_id: &mut String, current_topic_id: &mut Option, inbound: WsInbound, ) -> Result<(), AgentError> { match inbound { WsInbound::Message { content, attachments, chat_id, sender_id, .. } => { let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id); state .channel_manager .cli_channel() .register_connection( chat_id.clone(), runtime_session_id.to_string(), sender.clone(), ) .await; // Process attachments: save base64 content to local files and build MediaItems with correct paths let media = process_attachments_with_base64(attachments)?; state .bus .publish_inbound(InboundMessage { channel: CLI_CHANNEL_NAME.to_string(), sender_id, chat_id, content, timestamp: current_timestamp(), media, metadata: HashMap::new(), forwarded_metadata: HashMap::new(), }) .await .map_err(|error| AgentError::Other(error.to_string()))?; Ok(()) } WsInbound::Command { payload } => { // 使用 Command 系统处理命令 let input_adapter = WebSocketInputAdapter::new(); let output_adapter = WebSocketOutputAdapter::new(); // 解析命令 let adapter_ctx = crate::command::context::AdapterContext::new("websocket") .with_session_id(current_session_id.as_str()); let cmd = match input_adapter.try_parse(&payload, adapter_ctx) { Ok(Some(cmd)) => cmd, Ok(None) => { // 不是命令,返回错误 let _ = sender .send(WsOutbound::Error { timestamp: Some(crate::protocol::now_timestamp()), code: "INVALID_COMMAND".to_string(), message: "Invalid command payload".to_string(), }) .await; return Ok(()); } Err(e) => { let _ = sender .send(WsOutbound::Error { timestamp: Some(crate::protocol::now_timestamp()), code: "PARSE_ERROR".to_string(), message: e.to_string(), }) .await; return Ok(()); } }; // 创建命令路由器 let _cli_sessions = state.session_manager.cli_sessions(); let store = state.session_manager.store(); let skills = state.session_manager.skills(); let skills_for_handler = skills.clone(); let provider_config = state.config.get_provider_config("default") .map_err(|e| AgentError::Other(e.to_string()))?; let prompt_repository = state.session_manager.store().clone(); let system_prompt_provider: Arc = Arc::new(CompositeSystemPromptProvider::new(vec![ Box::new(AgentPromptProvider::new( 0, provider_config.clone(), prompt_repository.clone(), )), Box::new(SkillPromptProvider::new(skills)), ])); let mut router = CommandRouter::new(); // 注册 Session 处理器 let session_handler = SessionCommandHandler::new(store.clone()) .with_session_manager(state.session_manager.clone()); router.register(Box::new(session_handler)); // 注册 list_sessions 处理器 router.register(Box::new(ListSessionsCommandHandler::new(store.clone()))); // 注册 list_sessions_by_channel 处理器 router.register(Box::new(ListSessionsByChannelCommandHandler::new(store.clone()))); // 注册 list_channels 处理器 router.register(Box::new(ListChannelsCommandHandler::new(Arc::new(state.channel_manager.clone())))); // 注册 list_topics 处理器 router.register(Box::new(ListTopicsCommandHandler::new(store.clone()))); // 注册 switch_topic 处理器 let switch_handler = SwitchTopicCommandHandler::new(store.clone()) .with_session_manager(state.session_manager.clone()); router.register(Box::new(switch_handler)); // 注册 get_current 处理器 router.register(Box::new(GetCurrentSessionCommandHandler::new(store.clone()))); // 注册 load_topic 处理器 router.register(Box::new(LoadTopicCommandHandler::new(store.clone()))); // 注册 load_task_messages 处理器 router.register(Box::new(LoadTaskMessagesCommandHandler::new( state.task_repository.clone(), store.clone(), ))); router.register(Box::new(SaveSessionCommandHandler::new( store.clone(), state.task_repository.clone(), system_prompt_provider.clone(), ))); // 注册 delete_topic 处理器 router.register(Box::new( DeleteTopicCommandHandler::new(store.clone()) .with_session_manager(state.session_manager.clone()), )); // 注册 help 处理器 let metadata = router.metadata_arc(); router.register(Box::new(HelpCommandHandler::new(metadata))); // 注册 list_scheduler_jobs 处理器 router.register(Box::new(ListSchedulerJobsCommandHandler::new(store.clone()))); // 注册 list_memories 处理器 router.register(Box::new(ListMemoriesCommandHandler::new(store.clone()))); // 注册 list_skills 处理器 router.register(Box::new(ListSkillsCommandHandler::new(skills_for_handler))); // 注册 memory_crud 处理器 router.register(Box::new(MemoryCrudCommandHandler::new(store.clone()))); // 注册 load_chat_messages 处理器 router.register(Box::new(LoadChatMessagesCommandHandler::new())); // 注册 stop_execution 处理器 router.register(Box::new(StopExecutionCommandHandler::new( state.cancel_manager.clone(), ))); // 构建命令上下文 tracing::debug!( current_session_id = %current_session_id, current_topic_id = ?current_topic_id, "Building CommandContext for WebSocket command" ); let mut cmd_ctx = CommandContext::new("websocket", "cli") .with_session_id(current_session_id.as_str()) .with_chat_id(current_session_id.as_str()); // 只在有 topic_id 时才设置 if let Some(ref topic_id) = *current_topic_id { cmd_ctx = cmd_ctx.with_topic_id(topic_id.as_str()); } // 执行命令 let response = router.dispatch_with_response(cmd, cmd_ctx).await; // 处理响应 if response.success { // 更新当前会话 ID(如果是创建会话) if let Some(session_id) = response.metadata.get("session_id") { tracing::info!( old_session_id = %current_session_id, new_session_id = %session_id, "Updating current_session_id" ); *current_session_id = session_id.clone(); state .channel_manager .cli_channel() .register_connection( session_id.clone(), runtime_session_id.to_string(), sender.clone(), ) .await; } // 更新当前话题 ID(如果是创建话题或切换话题) if let Some(topic_id) = response.metadata.get("topic_id") { tracing::info!( old_topic_id = ?current_topic_id, new_topic_id = %topic_id, "Updating current_topic_id" ); *current_topic_id = Some(topic_id.clone()); // 加载并发送该话题的历史消息 if let Err(e) = send_topic_history(&store, current_session_id, topic_id, sender).await { tracing::warn!(error = %e, topic_id = %topic_id, "Failed to send topic history"); } } // 加载子智能体任务消息 if let Some(task_session_id) = response.metadata.get("task_session_id") { // 提前提取 task_id,用于给历史消息打标记 let task_id = response.metadata.get("task_id").cloned().unwrap_or_default(); if let Err(e) = send_task_messages(&store, task_session_id, sender, Some(task_id.clone())).await { tracing::warn!(error = %e, task_session_id = %task_session_id, "Failed to send task messages"); } // 发送 TaskMessagesLoaded 元数据 let description = response.metadata.get("task_description").cloned().unwrap_or_default(); let subagent_type = response.metadata.get("task_subagent_type").cloned().unwrap_or_default(); let status = response.metadata.get("task_status").cloned().unwrap_or_default(); let summary = response.metadata.get("task_summary").cloned(); let _ = sender.send(WsOutbound::TaskMessagesLoaded { task_id, description, subagent_type, status, summary, }).await; } // 处理定时任务列表 if let Some(jobs_json) = response.metadata.get("scheduler_jobs") { if let Ok(jobs) = serde_json::from_str::>(jobs_json) { let _ = sender.send(WsOutbound::SchedulerJobList { jobs }).await; } } // 处理技能列表 if let Some(skills_json) = response.metadata.get("skills") { if let Ok(skills) = serde_json::from_str::>(skills_json) { let _ = sender.send(WsOutbound::SkillList { skills }).await; } } // 处理记忆列表 if let Some(memories_json) = response.metadata.get("memories") { if let Ok(memories) = serde_json::from_str::>(memories_json) { let _ = sender.send(WsOutbound::MemoryList { memories }).await; } } // 记忆 CRUD 后自动刷新列表 if response.metadata.get("memory_updated").map(|v| v.as_str()) == Some("true") { if let Ok(records) = store.list_memories_for_scope("user", crate::storage::GLOBAL_SCOPE_KEY) { let memories: Vec = records .into_iter() .filter(|m| m.namespace != "_meta") .map(|m| crate::protocol::MemorySummary { id: m.id, namespace: m.namespace, memory_key: m.memory_key, content: m.content, created_at: m.created_at, updated_at: m.updated_at, }) .collect(); let _ = sender.send(WsOutbound::MemoryList { memories }).await; } } // 处理加载聊天消息请求 if let Some(load_chat_id) = response.metadata.get("load_chat_id") { let load_chat_channel = response.metadata.get("load_chat_channel") .cloned() .unwrap_or_default(); // session_id = "{channel}:{chat_id}" (cli channel 例外) let session_id = crate::storage::persistent_session_id( &load_chat_channel, load_chat_id, ); if let Err(e) = send_task_messages(&store, &session_id, sender, None).await { tracing::warn!( error = %e, channel = %load_chat_channel, chat_id = %load_chat_id, session_id = %session_id, "Failed to send chat messages" ); } } if current_topic_id.is_none() { if let Some(topics_json) = response.metadata.get("topics") { match serde_json::from_str::>(topics_json) { Ok(topics) => { if let Some(first_topic) = topics.first() { let topic_id = first_topic.topic_id.clone(); *current_topic_id = Some(topic_id.clone()); if let Err(e) = send_topic_history(&store, current_session_id, &topic_id, sender).await { tracing::warn!(error = %e, topic_id = %topic_id, "Failed to send initial topic history"); } } } Err(e) => { tracing::warn!(error = %e, "Failed to parse topics metadata for initial history"); } } } } } else if let Some(ref error) = response.error { tracing::warn!( error_code = %error.code, error_message = %error.message, "Command failed" ); } // 适配并发送响应 let outbounds = output_adapter.adapt(response); for msg in outbounds { let _ = sender.send(msg).await; } Ok(()) } WsInbound::Ping => { let _ = sender.send(WsOutbound::Pong).await; Ok(()) } } } fn current_timestamp() -> i64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_millis() as i64 } fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String { sender_id .map(str::trim) .filter(|sender_id| !sender_id.is_empty()) .map(ToOwned::to_owned) .unwrap_or_else(|| runtime_session_id.to_string()) } /// 加载并发送话题历史消息 async fn send_topic_history( store: &Arc, _session_id: &str, topic_id: &str, sender: &mpsc::Sender, ) -> Result<(), Box> { // 加载话题消息 let messages = store.load_messages_for_topic(topic_id)?; tracing::info!(topic_id = %topic_id, message_count = messages.len(), "Sending topic history"); // 将消息转换为 WsOutbound 并发送 for msg in messages { for outbound in chat_message_to_ws_outbound(&msg) { let _ = sender.send(outbound).await; } } Ok(()) } /// 加载并发送子智能体任务的历史消息 async fn send_task_messages( store: &Arc, session_id: &str, sender: &mpsc::Sender, subagent_task_id: Option, ) -> Result<(), Box> { let messages = store.load_messages(session_id)?; tracing::info!(session_id = %session_id, message_count = messages.len(), "Sending task messages"); for msg in messages { let mut outbounds = chat_message_to_ws_outbound(&msg); if let Some(ref task_id) = subagent_task_id { for ob in &mut outbounds { set_subagent_task_id(ob, task_id); } } for outbound in outbounds { let _ = sender.send(outbound).await; } } Ok(()) } /// 给 WsOutbound 消息注入 subagent_task_id(仅对有该字段的变体生效) fn set_subagent_task_id(outbound: &mut WsOutbound, task_id: &str) { match outbound { WsOutbound::AssistantResponse { subagent_task_id, .. } | WsOutbound::ToolCall { subagent_task_id, .. } | WsOutbound::ToolResult { subagent_task_id, .. } | WsOutbound::ToolPending { subagent_task_id, .. } => { *subagent_task_id = Some(task_id.to_string()); } _ => {} // 其他变体没有 subagent_task_id 字段 } } /// 将 ChatMessage 转换为 WsOutbound 列表 fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Vec { use crate::bus::message::ToolMessageState; // Helper function to strip media_refs_json from content fn strip_media_refs_json(content: &str) -> String { // Remove the media_refs_json suffix if present if let Some(pos) = content.find("\n\nmedia_refs_json:") { content[..pos].to_string() } else { content.to_string() } } // Build attachments from media_refs, reading file content for base64 let attachments: Vec = msg .media_refs .iter() .filter_map(|path| { // Try to read file and encode as base64 let file_content = std::fs::read(path).ok()?; let base64_content = STANDARD.encode(&file_content); // Guess mime type from path let mime_type = mime_guess::from_path(path) .first_raw() .map(ToOwned::to_owned); // Determine media type from mime type let media_type = mime_type .as_ref() .map(|m| { if m.starts_with("image/") { "image" } else if m.starts_with("audio/") { "audio" } else if m.starts_with("video/") { "video" } else { "file" } }) .unwrap_or("file"); // Get file name from path let file_name = std::path::Path::new(path) .file_name() .and_then(|name| name.to_str()) .map(ToOwned::to_owned); Some(MediaSummary { path: path.clone(), media_type: media_type.to_string(), mime_type, content_base64: Some(base64_content), file_name, }) }) .collect(); match msg.role.as_str() { "assistant" => { if let Some(tool_calls) = &msg.tool_calls { let mut outbound = Vec::new(); let has_content_or_reasoning = !msg.content.trim().is_empty() || msg.reasoning_content.is_some(); if has_content_or_reasoning { outbound.push(WsOutbound::AssistantResponse { id: msg.id.clone(), content: msg.content.clone(), role: msg.role.clone(), attachments: Vec::new(), subagent_task_id: None, topic_id: None, timestamp: Some(msg.timestamp / 1000), reasoning_content: msg.reasoning_content.clone(), }); } // AssistantResponse 已携带 reasoning 时,ToolCall 不再重复 let tc_reasoning = if has_content_or_reasoning { None } else { msg.reasoning_content.clone() }; for tool_call in tool_calls { outbound.push(WsOutbound::ToolCall { id: msg.id.clone(), tool_call_id: tool_call.id.clone(), tool_name: tool_call.name.clone(), arguments: tool_call.arguments.clone(), content: format!("{}\nargs: {}", tool_call.name, tool_call.arguments), role: msg.role.clone(), subagent_task_id: None, topic_id: None, timestamp: Some(msg.timestamp / 1000), reasoning_content: tc_reasoning.clone(), }); } outbound } else { // 普通助手消息 vec![WsOutbound::AssistantResponse { id: msg.id.clone(), content: msg.content.clone(), role: msg.role.clone(), attachments: Vec::new(), subagent_task_id: None, topic_id: None, timestamp: Some(msg.timestamp / 1000), reasoning_content: msg.reasoning_content.clone(), }] } } "tool" => { let tool_state = msg.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed); match tool_state { ToolMessageState::Completed => vec![WsOutbound::ToolResult { id: msg.id.clone(), tool_call_id: msg.tool_call_id.clone().unwrap_or_default(), tool_name: msg.tool_name.clone().unwrap_or_default(), content: msg.content.clone(), role: msg.role.clone(), subagent_task_id: None, topic_id: None, duration_ms: msg.tool_duration_ms, timestamp: Some(msg.timestamp / 1000), }], ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending { id: msg.id.clone(), tool_call_id: msg.tool_call_id.clone().unwrap_or_default(), tool_name: msg.tool_name.clone().unwrap_or_default(), content: msg.content.clone(), role: msg.role.clone(), resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(), subagent_task_id: None, topic_id: None, timestamp: Some(msg.timestamp / 1000), }], } } "user" => vec![WsOutbound::AssistantResponse { id: msg.id.clone(), content: strip_media_refs_json(&msg.content), role: msg.role.clone(), attachments, subagent_task_id: None, topic_id: None, timestamp: Some(msg.timestamp / 1000), reasoning_content: None, }], _ => Vec::new(), } } #[cfg(test)] mod tests { use super::{resolve_ws_sender_id, build_media_filename, process_attachments_with_base64}; use crate::protocol::MediaSummary; use base64::{Engine as _, engine::general_purpose::STANDARD}; #[test] fn test_resolve_ws_sender_id_prefers_inbound_sender() { assert_eq!( resolve_ws_sender_id(Some("user-42"), "runtime-1"), "user-42" ); assert_eq!( resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), "user-42" ); } #[test] fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() { assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1"); } #[test] fn test_build_media_filename_preserves_original_name() { let filename = build_media_filename("image", Some("photo.png")); assert!(filename.ends_with("_photo.png")); // UUID is 36 chars, plus underscore and original name assert!(filename.len() >= 36 + 1 + "photo.png".len()); } #[test] fn test_build_media_filename_generates_default_when_no_name() { let filename = build_media_filename("image", None); assert!(filename.starts_with("image_")); } #[test] fn test_process_attachments_with_base64_saves_to_file() { let test_content = "test image content"; let base64_content = STANDARD.encode(test_content.as_bytes()); let attachments = vec![MediaSummary { path: "test_image.png".to_string(), media_type: "image".to_string(), mime_type: Some("image/png".to_string()), content_base64: Some(base64_content.clone()), file_name: Some("test_image.png".to_string()), }]; let result = process_attachments_with_base64(attachments).unwrap(); // Verify path is now a full path assert!(result[0].path.contains(".picobot")); assert!(result[0].path.contains("media")); assert!(result[0].path.contains("ws")); // Verify content_base64 is kept for frontend display assert!(result[0].content_base64.is_some()); assert_eq!(result[0].content_base64.as_ref().unwrap(), &base64_content); // Verify file was actually written let file_content = std::fs::read(&result[0].path).unwrap(); assert_eq!(file_content, test_content.as_bytes()); // Cleanup std::fs::remove_file(&result[0].path).ok(); } #[test] fn test_process_attachments_without_base64_keeps_original_path() { let attachments = vec![MediaSummary { path: "/existing/path/image.jpg".to_string(), media_type: "image".to_string(), mime_type: Some("image/jpeg".to_string()), content_base64: None, file_name: Some("image.jpg".to_string()), }]; let result = process_attachments_with_base64(attachments).unwrap(); // Path should remain unchanged assert_eq!(result[0].path, "/existing/path/image.jpg"); } #[test] fn test_process_empty_attachments_returns_empty_vec() { let result = process_attachments_with_base64(Vec::new()).unwrap(); assert!(result.is_empty()); } }