161 lines
4.8 KiB
Rust
161 lines
4.8 KiB
Rust
use crate::command::adapter::{AdapterError, InputAdapter, OutputAdapter};
|
|
use crate::command::context::AdapterContext;
|
|
use crate::command::response::{CommandResponse, MessageKind};
|
|
use crate::command::Command;
|
|
use crate::protocol::WsOutbound;
|
|
|
|
/// WebSocket input adapter.
///
/// Deserializes JSON text received over a WebSocket directly into a [`Command`].
#[derive(Debug, Clone)]
pub struct WebSocketInputAdapter;
|
|
|
|
impl WebSocketInputAdapter {
|
|
/// 创建新的 WebSocket 输入适配器
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
}
|
|
|
|
impl Default for WebSocketInputAdapter {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl InputAdapter for WebSocketInputAdapter {
|
|
fn try_parse(
|
|
&self,
|
|
input: &str,
|
|
_ctx: AdapterContext,
|
|
) -> Result<Option<Command>, AdapterError> {
|
|
// 尝试将 JSON 反序列化为 Command
|
|
// 如果失败,说明不是 Command 消息,返回 None
|
|
match serde_json::from_str(input) {
|
|
Ok(cmd) => Ok(Some(cmd)),
|
|
Err(_) => Ok(None),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// WebSocket output adapter.
///
/// Converts a `CommandResponse` into the list of `WsOutbound` messages sent to the client.
#[derive(Debug, Clone)]
pub struct WebSocketOutputAdapter;
|
|
|
|
impl WebSocketOutputAdapter {
|
|
/// 创建新的 WebSocket 输出适配器
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
}
|
|
|
|
impl Default for WebSocketOutputAdapter {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl OutputAdapter for WebSocketOutputAdapter {
|
|
type Output = Vec<WsOutbound>;
|
|
|
|
fn adapt(&self, response: CommandResponse) -> Vec<WsOutbound> {
|
|
let mut outbounds = Vec::new();
|
|
|
|
// 如果出错,返回错误消息
|
|
if let Some(error) = response.error {
|
|
outbounds.push(WsOutbound::Error {
|
|
code: error.code,
|
|
message: error.message,
|
|
});
|
|
return outbounds;
|
|
}
|
|
|
|
// 转换响应消息为 WsOutbound
|
|
for msg in &response.messages {
|
|
let outbound = match msg.kind {
|
|
MessageKind::Text => WsOutbound::AssistantResponse {
|
|
id: response.request_id.to_string(),
|
|
content: msg.content.clone(),
|
|
role: "assistant".to_string(),
|
|
},
|
|
MessageKind::Notification => {
|
|
// 根据元数据判断具体类型
|
|
if let Some(session_id) = response.metadata.get("session_id") {
|
|
WsOutbound::SessionCreated {
|
|
session_id: session_id.clone(),
|
|
title: msg.content.clone(),
|
|
}
|
|
} else {
|
|
// 默认通知
|
|
WsOutbound::AssistantResponse {
|
|
id: response.request_id.to_string(),
|
|
content: msg.content.clone(),
|
|
role: "assistant".to_string(),
|
|
}
|
|
}
|
|
}
|
|
MessageKind::Error => WsOutbound::Error {
|
|
code: "RESPONSE_ERROR".to_string(),
|
|
message: msg.content.clone(),
|
|
},
|
|
_ => WsOutbound::AssistantResponse {
|
|
id: response.request_id.to_string(),
|
|
content: msg.content.clone(),
|
|
role: "assistant".to_string(),
|
|
},
|
|
};
|
|
outbounds.push(outbound);
|
|
}
|
|
|
|
outbounds
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_input_adapter_valid_command() {
        let adapter = WebSocketInputAdapter::new();
        let ctx = AdapterContext::new("test");

        let json = r#"{"type":"create_session","title":"my session"}"#;
        let result = adapter.try_parse(json, ctx).unwrap();

        assert!(result.is_some());
        let cmd = result.unwrap();
        assert!(matches!(
            cmd,
            Command::CreateSession {
                title: Some(ref t)
            } if t == "my session"
        ));
    }

    #[test]
    fn test_websocket_input_adapter_invalid_json() {
        let adapter = WebSocketInputAdapter::new();
        let ctx = AdapterContext::new("test");

        let json = "not a command";
        let result = adapter.try_parse(json, ctx).unwrap();

        assert!(result.is_none());
    }

    #[test]
    fn test_websocket_output_adapter_session_created() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let response = CommandResponse::success(request_id)
            .with_message(MessageKind::Notification, "My Session")
            .with_metadata("session_id", "abc123");

        let outbounds = adapter.adapt(response);

        assert_eq!(outbounds.len(), 1);
        // Check the payload, not just the variant: id comes from metadata, title from content.
        assert!(matches!(
            &outbounds[0],
            WsOutbound::SessionCreated { session_id, title }
                if session_id == "abc123" && title == "My Session"
        ));
    }

    #[test]
    fn test_websocket_output_adapter_text_message() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let expected_id = request_id.to_string();
        let response =
            CommandResponse::success(request_id).with_message(MessageKind::Text, "hello");

        let outbounds = adapter.adapt(response);

        assert_eq!(outbounds.len(), 1);
        assert!(matches!(
            &outbounds[0],
            WsOutbound::AssistantResponse { id, content, role }
                if *id == expected_id && content == "hello" && role == "assistant"
        ));
    }

    #[test]
    fn test_websocket_output_adapter_notification_without_session_id() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let response =
            CommandResponse::success(request_id).with_message(MessageKind::Notification, "note");

        let outbounds = adapter.adapt(response);

        // Without `session_id` metadata a notification falls back to an assistant reply.
        assert_eq!(outbounds.len(), 1);
        assert!(matches!(
            &outbounds[0],
            WsOutbound::AssistantResponse { content, role, .. }
                if content == "note" && role == "assistant"
        ));
    }

    #[test]
    fn test_websocket_output_adapter_error_message_kind() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let response =
            CommandResponse::success(request_id).with_message(MessageKind::Error, "boom");

        let outbounds = adapter.adapt(response);

        assert_eq!(outbounds.len(), 1);
        assert!(matches!(
            &outbounds[0],
            WsOutbound::Error { code, message }
                if code == "RESPONSE_ERROR" && message == "boom"
        ));
    }

    #[test]
    fn test_websocket_output_adapter_preserves_message_order() {
        let adapter = WebSocketOutputAdapter::new();
        let request_id = uuid::Uuid::new_v4();
        let response = CommandResponse::success(request_id)
            .with_message(MessageKind::Text, "first")
            .with_message(MessageKind::Error, "second");

        let outbounds = adapter.adapt(response);

        assert_eq!(outbounds.len(), 2);
        assert!(matches!(
            &outbounds[0],
            WsOutbound::AssistantResponse { content, .. } if content == "first"
        ));
        assert!(matches!(
            &outbounds[1],
            WsOutbound::Error { message, .. } if message == "second"
        ));
    }
}
|