PicoBot/tests/test_request_format.rs

185 lines
6.0 KiB
Rust

use picobot::protocol::{WsInbound, WsOutbound};
use picobot::providers::{ChatCompletionRequest, Message};
/// Checks that quotes, newlines and tabs inside a user message survive a
/// serialize/deserialize round trip and come out JSON-escaped when the
/// message content is re-encoded.
#[test]
fn test_message_special_characters() {
    let original = Message::user("Hello \"world\"\nNew line\tTab");
    let round_tripped: Message = {
        let wire = serde_json::to_string(&original).unwrap();
        serde_json::from_str(&wire).unwrap()
    };

    assert_eq!(round_tripped.role, "user");
    assert_eq!(round_tripped.content.len(), 1);

    // Re-encode only the content so the escaped form can be matched directly.
    let content_json = serde_json::to_string(&round_tripped.content).unwrap();
    let expected_escaped = "Hello \\\"world\\\"\\nNew line\\tTab";
    assert!(content_json.contains(expected_escaped));
}
/// Test that a multi-line system prompt is preserved through serialization,
/// including its line breaks, not just the words between them.
#[test]
fn test_multiline_system_prompt() {
    let msg = Message::system(
        "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
    );
    let json = serde_json::to_string(&msg).unwrap();
    assert!(json.contains("helpful assistant"));
    assert!(json.contains("rules"));
    assert!(json.contains("1. Be kind"));

    // The substring checks above would still pass if the newlines were dropped
    // or collapsed, so require the JSON-escaped line breaks verbatim.
    let escaped_body = r#"assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"#;
    assert!(json.contains(escaped_body));

    // Round trip: the decoded message keeps the system role and the full text.
    let decoded: Message = serde_json::from_str(&json).unwrap();
    assert_eq!(decoded.role, "system");
    let reencoded = serde_json::to_string(&decoded.content).unwrap();
    assert!(reencoded.contains(escaped_body));
}
/// Test `ChatCompletionRequest` serialization (without model field).
///
/// Verifies that both messages and the sampling parameters appear in the
/// JSON, and that no `model` key is emitted.
#[test]
fn test_chat_request_serialization() {
    let request = ChatCompletionRequest {
        messages: vec![Message::system("You are helpful"), Message::user("Hello")],
        temperature: Some(0.7),
        max_tokens: Some(100),
        tools: None,
    };
    let json = serde_json::to_string(&request).unwrap();
    // Verify structure
    assert!(json.contains(r#""role":"system""#));
    assert!(json.contains(r#""role":"user""#));
    assert!(json.contains("You are helpful"));
    assert!(json.contains("Hello"));
    assert!(json.contains(r#""temperature":0.7"#));
    assert!(json.contains(r#""max_tokens":100"#));
    // The test's premise is "without model field"; assert it instead of assuming it.
    assert!(!json.contains(r#""model""#));
}
/// Test that a command travels as a JSON-string `payload` inside
/// `WsInbound::Command` and survives a serialize/deserialize round trip
/// byte-for-byte.
#[test]
fn test_command_inbound_serialization() {
    // Command is now sent as payload in WsInbound::Command
    let command_json = r#"{"type":"create_session","title":"demo"}"#;
    let msg = WsInbound::Command {
        payload: command_json.to_string(),
    };
    let json = serde_json::to_string(&msg).unwrap();
    assert!(json.contains(r#""type":"command""#));
    assert!(json.contains(r#""payload":""#));
    assert!(json.contains(r#"create_session"#));

    // The payload is embedded as a string, so its inner quotes are escaped on
    // the wire; decoding must restore the exact original text.
    let decoded: WsInbound = serde_json::from_str(&json).unwrap();
    match decoded {
        WsInbound::Command { payload } => assert_eq!(payload, command_json),
        other => panic!("unexpected decoded variant: {:?}", other),
    }
}
/// Test that `WsInbound::Message` serializes with its tag and key fields, and
/// that every field set here round-trips unchanged.
#[test]
fn test_message_inbound_serialization() {
    let msg = WsInbound::Message {
        content: "Hello world".to_string(),
        attachments: Vec::new(),
        channel: None,
        chat_id: Some("session-1".to_string()),
        sender_id: Some("user-1".to_string()),
    };
    let json = serde_json::to_string(&msg).unwrap();
    assert!(json.contains(r#""type":"message""#));
    assert!(json.contains(r#""content":"Hello world""#));
    assert!(json.contains(r#""chat_id":"session-1""#));
    let decoded: WsInbound = serde_json::from_str(&json).unwrap();
    match decoded {
        // Destructure every field so a silently dropped one is caught.
        WsInbound::Message {
            content,
            attachments,
            channel,
            chat_id,
            sender_id,
        } => {
            assert_eq!(content, "Hello world");
            assert!(attachments.is_empty());
            assert!(channel.is_none());
            assert_eq!(chat_id.as_deref(), Some("session-1"));
            assert_eq!(sender_id.as_deref(), Some("user-1"));
        }
        other => panic!("unexpected decoded variant: {:?}", other),
    }
}
/// Round-trips `WsOutbound::SessionCreated`, checking the `type` tag and both
/// fields on the wire and after decoding.
#[test]
fn test_session_created_outbound_serialization() {
    let outbound = WsOutbound::SessionCreated {
        session_id: "session-1".to_string(),
        title: "demo".to_string(),
    };
    let encoded = serde_json::to_string(&outbound).unwrap();

    for needle in [
        r#""type":"session_created""#,
        r#""session_id":"session-1""#,
        r#""title":"demo""#,
    ] {
        assert!(encoded.contains(needle));
    }

    let decoded: WsOutbound = serde_json::from_str(&encoded).unwrap();
    if let WsOutbound::SessionCreated { session_id, title } = &decoded {
        assert_eq!(session_id, "session-1");
        assert_eq!(title, "demo");
    } else {
        panic!("unexpected decoded variant: {:?}", decoded);
    }
}
/// Round-trips `WsOutbound::ToolCall`, checking the `type` tag and tool name,
/// and that the structured `arguments` JSON value survives intact.
#[test]
fn test_tool_call_outbound_serialization() {
    let args = serde_json::json!({"expression": "1 + 1"});
    let outbound = WsOutbound::ToolCall {
        id: "msg-1".to_string(),
        tool_call_id: "call-1".to_string(),
        tool_name: "calculator".to_string(),
        arguments: args,
        content: "调用工具: calculator".to_string(),
        role: "assistant".to_string(),
        subagent_task_id: None,
        topic_id: None,
        timestamp: None,
        reasoning_content: None,
    };
    let encoded = serde_json::to_string(&outbound).unwrap();

    for needle in [
        r#""type":"tool_call""#,
        r#""tool_name":"calculator""#,
        r#""expression":"1 + 1""#,
    ] {
        assert!(encoded.contains(needle));
    }

    match serde_json::from_str::<WsOutbound>(&encoded).unwrap() {
        WsOutbound::ToolCall {
            tool_call_id,
            tool_name,
            arguments,
            ..
        } => {
            assert_eq!(tool_call_id, "call-1");
            assert_eq!(tool_name, "calculator");
            assert_eq!(arguments["expression"], "1 + 1");
        }
        other => panic!("unexpected decoded variant: {:?}", other),
    }
}
/// Test that `WsOutbound::ToolResult` round-trips with its `type` tag, tool
/// name, role and exact multi-line (non-ASCII) content intact.
#[test]
fn test_tool_result_outbound_serialization() {
    let content_text = "工具结果: calculator\n\n2";
    let msg = WsOutbound::ToolResult {
        id: "msg-2".to_string(),
        tool_call_id: "call-1".to_string(),
        tool_name: "calculator".to_string(),
        content: content_text.to_string(),
        role: "tool".to_string(),
        subagent_task_id: None,
        duration_ms: None,
        topic_id: None,
        timestamp: None,
    };
    let json = serde_json::to_string(&msg).unwrap();
    assert!(json.contains(r#""type":"tool_result""#));
    assert!(json.contains(r#""tool_name":"calculator""#));
    let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
    match decoded {
        WsOutbound::ToolResult {
            tool_call_id,
            tool_name,
            content,
            role,
            ..
        } => {
            assert_eq!(tool_call_id, "call-1");
            assert_eq!(tool_name, "calculator");
            // `content.contains('2')` would accept truncated or mangled text
            // (e.g. lost newlines or broken UTF-8 handling); require an exact match.
            assert_eq!(content, content_text);
            assert_eq!(role, "tool");
        }
        other => panic!("unexpected decoded variant: {:?}", other),
    }
}