use crate::command::adapter::{AdapterError, InputAdapter, OutputAdapter}; use crate::command::context::AdapterContext; use crate::command::response::{CommandResponse, MessageKind}; use crate::command::Command; use crate::protocol::WsOutbound; /// WebSocket 输入适配器 /// /// 将 WebSocket 的 JSON 输入直接反序列化为 Command pub struct WebSocketInputAdapter; impl WebSocketInputAdapter { /// 创建新的 WebSocket 输入适配器 pub fn new() -> Self { Self } } impl Default for WebSocketInputAdapter { fn default() -> Self { Self::new() } } impl InputAdapter for WebSocketInputAdapter { fn try_parse( &self, input: &str, _ctx: AdapterContext, ) -> Result, AdapterError> { // 尝试将 JSON 反序列化为 Command // 如果失败,说明不是 Command 消息,返回 None match serde_json::from_str(input) { Ok(cmd) => Ok(Some(cmd)), Err(_) => Ok(None), } } } /// WebSocket 输出适配器 /// /// 将 CommandResponse 转换为 WsOutbound 消息列表 pub struct WebSocketOutputAdapter; impl WebSocketOutputAdapter { /// 创建新的 WebSocket 输出适配器 pub fn new() -> Self { Self } } impl Default for WebSocketOutputAdapter { fn default() -> Self { Self::new() } } impl OutputAdapter for WebSocketOutputAdapter { type Output = Vec; fn adapt(&self, response: CommandResponse) -> Vec { let mut outbounds = Vec::new(); // 如果出错,返回错误消息 if let Some(error) = response.error { outbounds.push(WsOutbound::Error { code: error.code, message: error.message, }); return outbounds; } // 转换响应消息为 WsOutbound for msg in &response.messages { let outbound = match msg.kind { MessageKind::Text => WsOutbound::AssistantResponse { id: response.request_id.to_string(), content: msg.content.clone(), role: "assistant".to_string(), }, MessageKind::Notification => { // 根据元数据判断具体类型 if let Some(topics_json) = response.metadata.get("topics") { // Topic 列表响应 - 优先检查 topics match serde_json::from_str::>(topics_json) { Ok(topics) => { let session_id = response.metadata.get("session_id") .cloned() .unwrap_or_default(); WsOutbound::TopicList { topics, session_id, } } Err(_) => WsOutbound::AssistantResponse { id: response.request_id.to_string(), content: msg.content.clone(), role: "assistant".to_string(), }, } } else if let Some(session_id) = response.metadata.get("session_id") { // 有 session_id 但没有 topic_id 的是创建会话 if response.metadata.get("topic_id").is_none() { WsOutbound::SessionCreated { session_id: session_id.clone(), title: msg.content.clone(), } } else { // 加载会话 let message_count = response.metadata.get("message_count") .and_then(|s| s.parse().ok()) .unwrap_or(0); WsOutbound::SessionLoaded { session_id: session_id.clone(), title: msg.content.clone(), message_count, } } } else if let Some(topic_id) = response.metadata.get("topic_id") { // 只有 topic_id,可能是加载话题 let message_count = response.metadata.get("message_count") .and_then(|s| s.parse().ok()) .unwrap_or(0); WsOutbound::SessionLoaded { session_id: topic_id.clone(), title: msg.content.clone(), message_count, } } else if let Some(channels_json) = response.metadata.get("channels") { // 通道列表响应 match serde_json::from_str::>(channels_json) { Ok(channels) => WsOutbound::ChannelList { channels }, Err(_) => WsOutbound::AssistantResponse { id: response.request_id.to_string(), content: msg.content.clone(), role: "assistant".to_string(), }, } } else if let Some(sessions_json) = response.metadata.get("sessions") { // 会话列表响应 match serde_json::from_str::>(sessions_json) { Ok(sessions) => { let channel_name = response.metadata.get("channel_name").cloned(); WsOutbound::SessionList { sessions, current_session_id: None, channel_name, } } Err(_) => WsOutbound::AssistantResponse { id: response.request_id.to_string(), content: msg.content.clone(), role: "assistant".to_string(), }, } } else if let Some(topics_json) = response.metadata.get("topics") { // Topic 列表响应 match serde_json::from_str::>(topics_json) { Ok(topics) => { let session_id = response.metadata.get("session_id") .cloned() .unwrap_or_default(); WsOutbound::TopicList { topics, session_id, } } Err(_) => WsOutbound::AssistantResponse { id: response.request_id.to_string(), content: msg.content.clone(), role: "assistant".to_string(), }, } } else { // 默认通知 WsOutbound::AssistantResponse { id: response.request_id.to_string(), content: msg.content.clone(), role: "assistant".to_string(), } } } MessageKind::Error => WsOutbound::Error { code: "RESPONSE_ERROR".to_string(), message: msg.content.clone(), }, _ => WsOutbound::AssistantResponse { id: response.request_id.to_string(), content: msg.content.clone(), role: "assistant".to_string(), }, }; outbounds.push(outbound); } outbounds } } #[cfg(test)] mod tests { use super::*; #[test] fn test_websocket_input_adapter_valid_command() { let adapter = WebSocketInputAdapter::new(); let ctx = AdapterContext::new("test"); let json = r#"{"type":"create_session","title":"my session"}"#; let result = adapter.try_parse(json, ctx).unwrap(); assert!(result.is_some()); let cmd = result.unwrap(); assert!(matches!( cmd, Command::CreateSession { title: Some(ref t) } if t == "my session" )); } #[test] fn test_websocket_input_adapter_invalid_json() { let adapter = WebSocketInputAdapter::new(); let ctx = AdapterContext::new("test"); let json = "not a command"; let result = adapter.try_parse(json, ctx).unwrap(); assert!(result.is_none()); } #[test] fn test_websocket_output_adapter_session_created() { let adapter = WebSocketOutputAdapter::new(); let request_id = uuid::Uuid::new_v4(); let response = CommandResponse::success(request_id) .with_message(MessageKind::Notification, "My Session") .with_metadata("session_id", "abc123"); let outbounds = adapter.adapt(response); assert_eq!(outbounds.len(), 1); assert!(matches!(outbounds[0], WsOutbound::SessionCreated { .. })); } }