PicoBot/src/protocol/ws_adapter.rs

232 lines
8.3 KiB
Rust

#[cfg(test)]
use crate::bus::ChatMessage;
use crate::bus::OutboundMessage;
use crate::bus::message::OutboundEventKind;
#[cfg(test)]
use crate::bus::message::{ToolMessageState, format_tool_call_content};
use super::WsOutbound;
/// Resume hint attached to every `ToolPending` frame produced by this module.
///
/// English gloss: "After completing the external operation, just send a message to continue."
const TOOL_PENDING_RESUME_HINT: &str = "完成外部操作后,直接发一条继续消息即可。";
/// Expands a stored [`ChatMessage`] into the WebSocket frames that represent it.
///
/// * `"assistant"` messages without tool calls become one `AssistantResponse` frame.
///   With tool calls, the `AssistantResponse` frame is emitted first, but only when the
///   content is not blank after trimming. It is followed by one `ToolCall` frame per call,
///   each sharing the message id.
/// * `"tool"` messages become a `ToolResult` frame. When `tool_state` is
///   `PendingUserAction`, they become a `ToolPending` frame carrying
///   [`TOOL_PENDING_RESUME_HINT`] instead. A missing `tool_state` counts as `Completed`, and
///   missing tool call ids or names become empty strings.
/// * Messages with any other role produce no frames.
#[cfg(test)]
pub(crate) fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
    let assistant_frame = || WsOutbound::AssistantResponse {
        id: message.id.clone(),
        content: message.content.clone(),
        role: message.role.clone(),
    };

    match message.role.as_str() {
        "assistant" => match &message.tool_calls {
            None => vec![assistant_frame()],
            Some(calls) => {
                // Blank text alongside tool calls is dropped; only the calls are shown.
                let has_text = !message.content.trim().is_empty();
                has_text
                    .then(assistant_frame)
                    .into_iter()
                    .chain(calls.iter().map(|call| WsOutbound::ToolCall {
                        id: message.id.clone(),
                        tool_call_id: call.id.clone(),
                        tool_name: call.name.clone(),
                        arguments: call.arguments.clone(),
                        content: format_tool_call_content(&call.name, &call.arguments),
                        role: message.role.clone(),
                    }))
                    .collect()
            }
        },
        "tool" => {
            let id = message.id.clone();
            let tool_call_id = message.tool_call_id.clone().unwrap_or_default();
            let tool_name = message.tool_name.clone().unwrap_or_default();
            let content = message.content.clone();
            let role = message.role.clone();

            let frame = match message
                .tool_state
                .as_ref()
                .unwrap_or(&ToolMessageState::Completed)
            {
                ToolMessageState::Completed => WsOutbound::ToolResult {
                    id,
                    tool_call_id,
                    tool_name,
                    content,
                    role,
                },
                ToolMessageState::PendingUserAction => WsOutbound::ToolPending {
                    id,
                    tool_call_id,
                    tool_name,
                    content,
                    role,
                    resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
                },
            };
            vec![frame]
        }
        _ => Vec::new(),
    }
}
pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec<WsOutbound> {
match message.event_kind {
OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => {
vec![WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
content: message.content.clone(),
role: message.role.clone(),
}]
}
OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
arguments: message
.tool_arguments
.clone()
.unwrap_or(serde_json::Value::Null),
content: message.content.clone(),
role: message.role.clone(),
}],
OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
}],
OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
}],
OutboundEventKind::ErrorNotification => vec![WsOutbound::Error {
code: "AGENT_ERROR".to_string(),
message: message.content.clone(),
}],
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::messages::ToolCall;
    use serde_json::json;

    /// Asserts that `frame` is the `calculator` tool-call frame for `call-1` evaluating `1 + 1`.
    fn assert_calculator_tool_call(frame: &WsOutbound) {
        match frame {
            WsOutbound::ToolCall {
                tool_call_id,
                tool_name,
                arguments,
                content,
                ..
            } => {
                assert_eq!(tool_call_id, "call-1");
                assert_eq!(tool_name, "calculator");
                assert_eq!(arguments["expression"], "1 + 1");
                assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
            }
            other => panic!("unexpected outbound variant: {:?}", other),
        }
    }

    /// An assistant message with blank text and one tool call yields only the tool-call frame.
    #[test]
    fn test_ws_outbound_from_chat_message_expands_tool_calls() {
        let calls = vec![ToolCall {
            id: "call-1".to_string(),
            name: "calculator".to_string(),
            arguments: json!({"expression": "1 + 1"}),
        }];
        let message = ChatMessage::assistant_with_tool_calls("", calls);

        let frames = ws_outbound_from_chat_message(&message);

        assert_eq!(frames.len(), 1);
        assert_calculator_tool_call(&frames[0]);
    }

    /// Non-blank assistant text is emitted before the tool-call frames.
    #[test]
    fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
        let calls = vec![ToolCall {
            id: "call-1".to_string(),
            name: "memory_manage".to_string(),
            arguments: json!({"action": "put"}),
        }];
        let message = ChatMessage::assistant_with_tool_calls("日报已整理完成。", calls);

        let frames = ws_outbound_from_chat_message(&message);

        assert_eq!(frames.len(), 2);
        assert!(matches!(frames[0], WsOutbound::AssistantResponse { .. }));
        assert!(matches!(frames[1], WsOutbound::ToolCall { .. }));
    }

    /// A plain tool message maps to a single `ToolResult` frame.
    #[test]
    fn test_ws_outbound_from_chat_message_includes_tool_results() {
        let frames = ws_outbound_from_chat_message(&ChatMessage::tool("call-1", "calculator", "2"));

        assert_eq!(frames.len(), 1);
        assert!(matches!(frames[0], WsOutbound::ToolResult { .. }));
    }

    /// A tool message awaiting user action maps to a single `ToolPending` frame.
    #[test]
    fn test_ws_outbound_from_chat_message_includes_tool_pending() {
        let message = ChatMessage::tool_with_state(
            "call-1",
            "bash",
            "等待你完成授权后再继续。",
            ToolMessageState::PendingUserAction,
        );

        let frames = ws_outbound_from_chat_message(&message);

        assert_eq!(frames.len(), 1);
        assert!(matches!(frames[0], WsOutbound::ToolPending { .. }));
    }

    /// A bus-level tool-call event maps to the equivalent `ToolCall` frame.
    #[test]
    fn test_ws_outbound_from_outbound_message_maps_tool_call() {
        let message = OutboundMessage::tool_call(
            "cli",
            "session-1",
            None, // session_id
            "call-1",
            "calculator",
            json!({"expression": "1 + 1"}),
            None,
            Default::default(),
        );

        let frames = ws_outbound_from_outbound_message(&message);

        assert_eq!(frames.len(), 1);
        assert_calculator_tool_call(&frames[0]);
    }
}