PicoBot/src/command/adapters/websocket.rs

161 lines
4.8 KiB
Rust

use crate::command::adapter::{AdapterError, InputAdapter, OutputAdapter};
use crate::command::context::AdapterContext;
use crate::command::response::{CommandResponse, MessageKind};
use crate::command::Command;
use crate::protocol::WsOutbound;
/// WebSocket input adapter.
///
/// Deserializes the JSON text received over a WebSocket directly into a [`Command`].
/// It carries no state, so it is cheap to construct and to copy.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketInputAdapter;

impl WebSocketInputAdapter {
    /// Creates a new WebSocket input adapter.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Default for WebSocketInputAdapter {
    /// Equivalent to [`WebSocketInputAdapter::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl InputAdapter for WebSocketInputAdapter {
    /// Interprets a raw WebSocket text frame as a JSON-encoded [`Command`].
    ///
    /// Yields `Ok(Some(command))` when `input` deserializes into a `Command`.
    /// Any deserialization failure maps to `Ok(None)`, meaning "this frame is
    /// not a command message". This implementation therefore never returns
    /// `Err`, and it does not consult the adapter context.
    fn try_parse(
        &self,
        input: &str,
        _ctx: AdapterContext,
    ) -> Result<Option<Command>, AdapterError> {
        // `.ok()` deliberately discards the serde error: a frame that is not a
        // command is not an adapter failure.
        // NOTE(review): this also hides malformed commands (e.g. a known `type`
        // with bad fields); confirm that silent `None` is intended for those.
        let command = serde_json::from_str::<Command>(input).ok();
        Ok(command)
    }
}
/// WebSocket output adapter.
///
/// Converts a [`CommandResponse`] into a list of [`WsOutbound`] messages.
/// It carries no state, so it is cheap to construct and to copy.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketOutputAdapter;

impl WebSocketOutputAdapter {
    /// Creates a new WebSocket output adapter.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Default for WebSocketOutputAdapter {
    /// Equivalent to [`WebSocketOutputAdapter::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl OutputAdapter for WebSocketOutputAdapter {
type Output = Vec<WsOutbound>;
fn adapt(&self, response: CommandResponse) -> Vec<WsOutbound> {
let mut outbounds = Vec::new();
// 如果出错,返回错误消息
if let Some(error) = response.error {
outbounds.push(WsOutbound::Error {
code: error.code,
message: error.message,
});
return outbounds;
}
// 转换响应消息为 WsOutbound
for msg in &response.messages {
let outbound = match msg.kind {
MessageKind::Text => WsOutbound::AssistantResponse {
id: response.request_id.to_string(),
content: msg.content.clone(),
role: "assistant".to_string(),
},
MessageKind::Notification => {
// 根据元数据判断具体类型
if let Some(session_id) = response.metadata.get("session_id") {
WsOutbound::SessionCreated {
session_id: session_id.clone(),
title: msg.content.clone(),
}
} else {
// 默认通知
WsOutbound::AssistantResponse {
id: response.request_id.to_string(),
content: msg.content.clone(),
role: "assistant".to_string(),
}
}
}
MessageKind::Error => WsOutbound::Error {
code: "RESPONSE_ERROR".to_string(),
message: msg.content.clone(),
},
_ => WsOutbound::AssistantResponse {
id: response.request_id.to_string(),
content: msg.content.clone(),
role: "assistant".to_string(),
},
};
outbounds.push(outbound);
}
outbounds
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_input_adapter_valid_command() {
        let adapter = WebSocketInputAdapter::new();
        let ctx = AdapterContext::new("test");
        let json = r#"{"type":"create_session","title":"my session"}"#;
        let result = adapter.try_parse(json, ctx).unwrap();
        assert!(result.is_some());
        let cmd = result.unwrap();
        assert!(matches!(
            cmd,
            Command::CreateSession {
                title: Some(ref t)
            } if t == "my session"
        ));
    }

    #[test]
    fn test_websocket_input_adapter_invalid_json() {
        let adapter = WebSocketInputAdapter::new();
        let ctx = AdapterContext::new("test");
        let json = "not a command";
        let result = adapter.try_parse(json, ctx).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_websocket_input_adapter_non_command_inputs() {
        let adapter = WebSocketInputAdapter::new();
        // Empty input and valid-but-non-object JSON are "not a command", not errors.
        for input in ["", "42"] {
            let result = adapter.try_parse(input, AdapterContext::new("test")).unwrap();
            assert!(result.is_none(), "input {input:?} should not parse");
        }
    }

    #[test]
    fn test_websocket_output_adapter_session_created() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let response = CommandResponse::success(request_id)
            .with_message(MessageKind::Notification, "My Session")
            .with_metadata("session_id", "abc123");
        let outbounds = adapter.adapt(response);
        assert_eq!(outbounds.len(), 1);
        assert!(matches!(
            outbounds[0],
            WsOutbound::SessionCreated {
                ref session_id,
                ref title
            } if session_id == "abc123" && title == "My Session"
        ));
    }

    #[test]
    fn test_websocket_output_adapter_text_message() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let expected_id = request_id.to_string();
        let response =
            CommandResponse::success(request_id).with_message(MessageKind::Text, "hello");
        let outbounds = adapter.adapt(response);
        assert_eq!(outbounds.len(), 1);
        assert!(matches!(
            outbounds[0],
            WsOutbound::AssistantResponse {
                ref id,
                ref content,
                ref role
            } if *id == expected_id && content == "hello" && role == "assistant"
        ));
    }

    #[test]
    fn test_websocket_output_adapter_notification_without_session_id() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let response = CommandResponse::success(request_id)
            .with_message(MessageKind::Notification, "heads up");
        let outbounds = adapter.adapt(response);
        assert_eq!(outbounds.len(), 1);
        assert!(matches!(
            outbounds[0],
            WsOutbound::AssistantResponse { ref content, .. } if content == "heads up"
        ));
    }

    #[test]
    fn test_websocket_output_adapter_error_message() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let response =
            CommandResponse::success(request_id).with_message(MessageKind::Error, "boom");
        let outbounds = adapter.adapt(response);
        assert_eq!(outbounds.len(), 1);
        assert!(matches!(
            outbounds[0],
            WsOutbound::Error {
                ref code,
                ref message
            } if code == "RESPONSE_ERROR" && message == "boom"
        ));
    }

    #[test]
    fn test_websocket_output_adapter_preserves_message_order() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let response = CommandResponse::success(request_id)
            .with_message(MessageKind::Text, "first")
            .with_message(MessageKind::Error, "second");
        let outbounds = adapter.adapt(response);
        assert_eq!(outbounds.len(), 2);
        assert!(matches!(
            outbounds[0],
            WsOutbound::AssistantResponse { ref content, .. } if content == "first"
        ));
        assert!(matches!(
            outbounds[1],
            WsOutbound::Error { ref message, .. } if message == "second"
        ));
    }
}