930 lines
38 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use super::GatewayState;
use crate::agent::{AgentError, CompositeSystemPromptProvider};
use crate::bus::{InboundMessage, MediaItem};
use crate::command::adapter::{InputAdapter, OutputAdapter};
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
use crate::command::context::CommandContext;
use crate::command::handler::CommandRouter;
use crate::command::handlers::delete_topic::DeleteTopicCommandHandler;
use crate::command::handlers::get_current::GetCurrentSessionCommandHandler;
use crate::command::handlers::help::HelpCommandHandler;
use crate::command::handlers::list_channels::ListChannelsCommandHandler;
use crate::command::handlers::list_memories::ListMemoriesCommandHandler;
use crate::command::handlers::list_skills::ListSkillsCommandHandler;
use crate::command::handlers::list_scheduler_jobs::ListSchedulerJobsCommandHandler;
use crate::command::handlers::memory_crud::MemoryCrudCommandHandler;
use crate::command::handlers::list_sessions::ListSessionsCommandHandler;
use crate::command::handlers::list_sessions_by_channel::ListSessionsByChannelCommandHandler;
use crate::command::handlers::list_topics::ListTopicsCommandHandler;
use crate::command::handlers::load_chat_messages::LoadChatMessagesCommandHandler;
use crate::command::handlers::load_task_messages::LoadTaskMessagesCommandHandler;
use crate::command::handlers::load_topic::LoadTopicCommandHandler;
use crate::command::handlers::save_session::SaveSessionCommandHandler;
use crate::command::handlers::session::SessionCommandHandler;
use crate::command::handlers::stop_execution::StopExecutionCommandHandler;
use crate::command::handlers::switch_topic::SwitchTopicCommandHandler;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::protocol::{WsInbound, WsOutbound, MediaSummary, parse_inbound, serialize_outbound};
use crate::skills::SkillPromptProvider;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use base64::{Engine as _, engine::general_purpose::STANDARD};
use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Channel name stamped on every `InboundMessage` that this handler publishes to the bus.
const CLI_CHANNEL_NAME: &str = "cli";
/// Returns the directory that receives media uploaded over WebSocket:
/// `<home>/.picobot/media/ws`.
///
/// When the home directory cannot be determined the path is built relative to `.`.
fn default_ws_media_dir() -> PathBuf {
    let mut media_dir = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    media_dir.push(".picobot");
    media_dir.push("media");
    media_dir.push("ws");
    media_dir
}
/// Build a unique filename for media upload
fn build_media_filename(media_type: &str, file_name: Option<&str>) -> String {
if let Some(file_name) = file_name {
let sanitized: String = file_name
.chars()
.map(|ch| match ch {
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
_ => ch,
})
.collect();
if !sanitized.trim().is_empty() {
return format!("{}_{}", uuid::Uuid::new_v4(), sanitized);
}
}
format!("{}_{}", media_type, uuid::Uuid::new_v4())
}
/// Process attachments with base64 content: save to local file and return MediaItem with correct path
/// Keeps content_base64 for frontend display/download
fn process_attachments_with_base64(attachments: Vec<MediaSummary>) -> Result<Vec<MediaItem>, AgentError> {
if attachments.is_empty() {
return Ok(Vec::new());
}
let media_dir = default_ws_media_dir();
std::fs::create_dir_all(&media_dir)
.map_err(|error| AgentError::Other(format!("Failed to create media dir: {}", error)))?;
attachments
.into_iter()
.map(|att| {
// If content_base64 exists, save to file and update path
if let Some(base64_content) = &att.content_base64 {
let decoded = STANDARD
.decode(base64_content)
.map_err(|error| AgentError::Other(format!("Failed to decode base64: {}", error)))?;
let filename = build_media_filename(&att.media_type, att.file_name.as_deref());
let file_path = media_dir.join(&filename);
std::fs::write(&file_path, decoded)
.map_err(|error| AgentError::Other(format!("Failed to write media file: {}", error)))?;
tracing::info!(
filename = %filename,
media_type = %att.media_type,
file_path = %file_path.to_string_lossy(),
"Saved WebSocket media to local file"
);
Ok(MediaItem {
path: file_path.to_string_lossy().to_string(),
media_type: att.media_type,
mime_type: att.mime_type,
original_key: None,
// Keep content_base64 for frontend display/download
content_base64: att.content_base64,
file_name: att.file_name,
})
} else {
// No base64 content, keep original path (should already be valid)
Ok(MediaItem {
path: att.path,
media_type: att.media_type,
mime_type: att.mime_type,
original_key: None,
content_base64: None,
file_name: att.file_name,
})
}
})
.collect()
}
/// Axum route handler: accepts the WebSocket upgrade and hands the upgraded socket, together
/// with the shared gateway state, to `handle_socket`.
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
    ws.on_upgrade(move |socket| handle_socket(socket, state))
}
/// Drives a single WebSocket connection from upgrade to disconnect.
///
/// Steps, in order:
/// 1. Pick the first stored "websocket" session, or create a default one when none is listed.
/// 2. Register the connection with the CLI channel under a fresh runtime session id.
/// 3. Queue `SessionEstablished`, `ChannelList` and `SessionList` for the client.
/// 4. Spawn a writer task that serializes queued `WsOutbound` values onto the socket.
/// 5. Read text frames, parse them and pass them to `handle_inbound` until the socket
///    closes or errors, then unregister the connection.
///
/// Returns early, before anything is registered, if the default session cannot be created.
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
// Outbound queue: everything destined for this client goes through `sender`.
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let cli_sessions = state.session_manager.cli_sessions();
let store = state.session_manager.store();
// 1. Look up the existing sessions of the websocket channel first
// (a store error is treated the same as "no sessions").
let websocket_sessions = store.list_sessions("websocket", false)
.unwrap_or_default();
// 2. If there are none, automatically create a default session
let initial_record = if websocket_sessions.is_empty() {
match cli_sessions.create_with_channel("websocket", Some("默认会话")) {
Ok(record) => record,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial WebSocket session");
return;
}
}
} else {
// Use the first listed session (presumably the most recent one; this depends on the
// ordering of `list_sessions` - TODO confirm)
websocket_sessions[0].clone()
};
// `runtime_session_id` identifies this connection; `current_session_id` is the session
// the connection is currently attached to and may change while the connection lives.
let runtime_session_id = uuid::Uuid::new_v4().to_string();
let mut current_session_id = initial_record.id.clone();
let mut current_topic_id: Option<String> = None;
state
.channel_manager
.cli_channel()
.register_connection(
current_session_id.clone(),
runtime_session_id.clone(),
sender.clone(),
)
.await;
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "WebSocket session established");
let _ = sender
.send(WsOutbound::SessionEstablished {
session_id: current_session_id.clone(),
})
.await;
// Right after the connection is established, send the channel list
// (websocket merged with the ChannelManager's dynamic channels)
let channels = state.channel_manager.build_channel_list().await;
let _ = sender
.send(WsOutbound::ChannelList { channels })
.await;
// 3. Re-query the websocket channel's session list (so it includes the one just created)
let final_sessions = store.list_sessions("websocket", false)
.unwrap_or_default();
tracing::info!("Sending {} websocket sessions to client", final_sessions.len());
for s in &final_sessions {
tracing::info!(" - {}: {} (channel: {})", s.id, s.title, s.channel_name);
}
let session_summaries: Vec<crate::protocol::SessionSummary> = final_sessions
.into_iter()
.map(|s| crate::protocol::SessionSummary {
session_id: s.id,
title: s.title,
channel_name: s.channel_name,
chat_id: s.chat_id,
message_count: s.message_count,
last_active_at: s.last_active_at,
archived_at: s.archived_at,
})
.collect();
let _ = sender
.send(WsOutbound::SessionList {
sessions: session_summaries,
current_session_id: Some(current_session_id.clone()),
channel_name: Some("websocket".to_string()),
})
.await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
let session_id_for_sender = runtime_session_id.clone();
// Writer task: drains the outbound queue onto the socket. It stops when the queue is
// closed or a socket write fails; messages that fail to serialize are skipped.
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
break;
}
}
}
});
// Reader loop: only text frames are processed; a close frame or a read error ends the
// loop and every other frame kind is ignored.
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let text = text.to_string();
match parse_inbound(&text) {
Ok(inbound) => {
if let Err(e) = handle_inbound(
&state,
&sender,
&runtime_session_id,
&mut current_session_id,
&mut current_topic_id,
inbound,
)
.await
{
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = sender
.send(WsOutbound::Error {
timestamp: Some(crate::protocol::now_timestamp()),
code:"SESSION_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = sender
.send(WsOutbound::Error {
timestamp: Some(crate::protocol::now_timestamp()),
code:"PARSE_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
break;
}
_ => {}
}
}
// Connection is gone: remove it from the CLI channel.
state
.channel_manager
.cli_channel()
.unregister_connection(&runtime_session_id)
.await;
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
}
/// Handles one parsed inbound WebSocket frame.
///
/// * `Message` - re-registers the connection under the message's chat id, stores any base64
///   attachments on disk and publishes an `InboundMessage` on the bus.
/// * `Command` - parses the payload, builds a `CommandRouter` with all command handlers,
///   dispatches the command, reacts to the response metadata (session/topic switches,
///   history replay, list pushes) and finally sends the adapted response.
/// * `Ping` - answers with `Pong`.
///
/// `current_session_id` and `current_topic_id` are per-connection state owned by the caller
/// and are updated in place.
///
/// # Errors
/// Returns `AgentError` when attachments cannot be processed, publishing to the bus fails, or
/// the "default" provider config cannot be loaded. Invalid or unparsable commands are reported
/// to the client as `WsOutbound::Error` and still yield `Ok(())`.
async fn handle_inbound(
state: &Arc<GatewayState>,
sender: &mpsc::Sender<WsOutbound>,
runtime_session_id: &str,
current_session_id: &mut String,
current_topic_id: &mut Option<String>,
inbound: WsInbound,
) -> Result<(), AgentError> {
match inbound {
WsInbound::Message {
content,
attachments,
chat_id,
sender_id,
..
} => {
// A message without an explicit chat id targets the connection's current session.
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
state
.channel_manager
.cli_channel()
.register_connection(
chat_id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
// Process attachments: save base64 content to local files and build MediaItems with correct paths
let media = process_attachments_with_base64(attachments)?;
state
.bus
.publish_inbound(InboundMessage {
channel: CLI_CHANNEL_NAME.to_string(),
sender_id,
chat_id,
content,
timestamp: current_timestamp(),
media,
metadata: HashMap::new(),
forwarded_metadata: HashMap::new(),
})
.await
.map_err(|error| AgentError::Other(error.to_string()))?;
Ok(())
}
WsInbound::Command { payload } => {
// Handle the command through the Command system
let input_adapter = WebSocketInputAdapter::new();
let output_adapter = WebSocketOutputAdapter::new();
// Parse the command
let adapter_ctx = crate::command::context::AdapterContext::new("websocket")
.with_session_id(current_session_id.as_str());
let cmd = match input_adapter.try_parse(&payload, adapter_ctx) {
Ok(Some(cmd)) => cmd,
Ok(None) => {
// Not a command: report an error
let _ = sender
.send(WsOutbound::Error {
timestamp: Some(crate::protocol::now_timestamp()),
code: "INVALID_COMMAND".to_string(),
message: "Invalid command payload".to_string(),
})
.await;
return Ok(());
}
Err(e) => {
let _ = sender
.send(WsOutbound::Error {
timestamp: Some(crate::protocol::now_timestamp()),
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
})
.await;
return Ok(());
}
};
// Build the command router (a fresh router is assembled for every command)
let _cli_sessions = state.session_manager.cli_sessions();
let store = state.session_manager.store();
let skills = state.session_manager.skills();
let skills_for_handler = skills.clone();
let provider_config = state.config.get_provider_config("default")
.map_err(|e| AgentError::Other(e.to_string()))?;
let prompt_repository = state.session_manager.store().clone();
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
Box::new(AgentPromptProvider::new(
0,
provider_config.clone(),
prompt_repository.clone(),
)),
Box::new(SkillPromptProvider::new(skills)),
]));
let mut router = CommandRouter::new();
// Register the session handler
let session_handler = SessionCommandHandler::new(store.clone())
.with_session_manager(state.session_manager.clone());
router.register(Box::new(session_handler));
// Register the list_sessions handler
router.register(Box::new(ListSessionsCommandHandler::new(store.clone())));
// Register the list_sessions_by_channel handler
router.register(Box::new(ListSessionsByChannelCommandHandler::new(store.clone())));
// Register the list_channels handler
router.register(Box::new(ListChannelsCommandHandler::new(Arc::new(state.channel_manager.clone()))));
// Register the list_topics handler
router.register(Box::new(ListTopicsCommandHandler::new(store.clone())));
// Register the switch_topic handler
let switch_handler = SwitchTopicCommandHandler::new(store.clone())
.with_session_manager(state.session_manager.clone());
router.register(Box::new(switch_handler));
// Register the get_current handler
router.register(Box::new(GetCurrentSessionCommandHandler::new(store.clone())));
// Register the load_topic handler
router.register(Box::new(LoadTopicCommandHandler::new(store.clone())));
// Register the load_task_messages handler
router.register(Box::new(LoadTaskMessagesCommandHandler::new(
state.task_repository.clone(),
store.clone(),
)));
// Register the save_session handler
router.register(Box::new(SaveSessionCommandHandler::new(
store.clone(),
state.task_repository.clone(),
system_prompt_provider.clone(),
)));
// Register the delete_topic handler
router.register(Box::new(
DeleteTopicCommandHandler::new(store.clone())
.with_session_manager(state.session_manager.clone()),
));
// Register the help handler
let metadata = router.metadata_arc();
router.register(Box::new(HelpCommandHandler::new(metadata)));
// Register the list_scheduler_jobs handler
router.register(Box::new(ListSchedulerJobsCommandHandler::new(store.clone())));
// Register the list_memories handler
router.register(Box::new(ListMemoriesCommandHandler::new(store.clone())));
// Register the list_skills handler
router.register(Box::new(ListSkillsCommandHandler::new(skills_for_handler)));
// Register the memory_crud handler
router.register(Box::new(MemoryCrudCommandHandler::new(store.clone())));
// Register the load_chat_messages handler
router.register(Box::new(LoadChatMessagesCommandHandler::new()));
// Register the stop_execution handler
router.register(Box::new(StopExecutionCommandHandler::new(
state.cancel_manager.clone(),
)));
// Build the command context
tracing::debug!(
current_session_id = %current_session_id,
current_topic_id = ?current_topic_id,
"Building CommandContext for WebSocket command"
);
let mut cmd_ctx = CommandContext::new("websocket", "cli")
.with_session_id(current_session_id.as_str())
.with_chat_id(current_session_id.as_str());
// Only set it when a topic_id is present
if let Some(ref topic_id) = *current_topic_id {
cmd_ctx = cmd_ctx.with_topic_id(topic_id.as_str());
}
// Execute the command
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
// Process the response
if response.success {
// Update the current session ID (if a session was created)
if let Some(session_id) = response.metadata.get("session_id") {
tracing::info!(
old_session_id = %current_session_id,
new_session_id = %session_id,
"Updating current_session_id"
);
*current_session_id = session_id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
session_id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
}
// Update the current topic ID (if a topic was created or switched)
if let Some(topic_id) = response.metadata.get("topic_id") {
tracing::info!(
old_topic_id = ?current_topic_id,
new_topic_id = %topic_id,
"Updating current_topic_id"
);
*current_topic_id = Some(topic_id.clone());
// Load and send that topic's message history
if let Err(e) = send_topic_history(&store, current_session_id, topic_id, sender).await {
tracing::warn!(error = %e, topic_id = %topic_id, "Failed to send topic history");
}
}
// Load the sub-agent task messages
if let Some(task_session_id) = response.metadata.get("task_session_id") {
// Extract task_id up front; it is used to tag the history messages
let task_id = response.metadata.get("task_id").cloned().unwrap_or_default();
if let Err(e) = send_task_messages(&store, task_session_id, sender, Some(task_id.clone())).await {
tracing::warn!(error = %e, task_session_id = %task_session_id, "Failed to send task messages");
}
// Send the TaskMessagesLoaded metadata
let description = response.metadata.get("task_description").cloned().unwrap_or_default();
let subagent_type = response.metadata.get("task_subagent_type").cloned().unwrap_or_default();
let status = response.metadata.get("task_status").cloned().unwrap_or_default();
let summary = response.metadata.get("task_summary").cloned();
let _ = sender.send(WsOutbound::TaskMessagesLoaded {
task_id,
description,
subagent_type,
status,
summary,
}).await;
}
// Handle the scheduled-job list (metadata that fails to parse is silently skipped)
if let Some(jobs_json) = response.metadata.get("scheduler_jobs") {
if let Ok(jobs) = serde_json::from_str::<Vec<crate::protocol::SchedulerJobSummary>>(jobs_json) {
let _ = sender.send(WsOutbound::SchedulerJobList { jobs }).await;
}
}
// Handle the skill list
if let Some(skills_json) = response.metadata.get("skills") {
if let Ok(skills) = serde_json::from_str::<Vec<crate::protocol::SkillSummary>>(skills_json) {
let _ = sender.send(WsOutbound::SkillList { skills }).await;
}
}
// Handle the memory list
if let Some(memories_json) = response.metadata.get("memories") {
if let Ok(memories) = serde_json::from_str::<Vec<crate::protocol::MemorySummary>>(memories_json) {
let _ = sender.send(WsOutbound::MemoryList { memories }).await;
}
}
// Automatically refresh the list after a memory CRUD operation
if response.metadata.get("memory_updated").map(|v| v.as_str()) == Some("true") {
if let Ok(records) = store.list_memories_for_scope("user", crate::storage::GLOBAL_SCOPE_KEY) {
let memories: Vec<crate::protocol::MemorySummary> = records
.into_iter()
// Entries in the "_meta" namespace are not sent to the client.
.filter(|m| m.namespace != "_meta")
.map(|m| crate::protocol::MemorySummary {
id: m.id,
namespace: m.namespace,
memory_key: m.memory_key,
content: m.content,
created_at: m.created_at,
updated_at: m.updated_at,
})
.collect();
let _ = sender.send(WsOutbound::MemoryList { memories }).await;
}
}
// Handle a load-chat-messages request
if let Some(load_chat_id) = response.metadata.get("load_chat_id") {
let load_chat_channel = response.metadata.get("load_chat_channel")
.cloned()
.unwrap_or_default();
// session_id = "{channel}:{chat_id}" (the cli channel is an exception)
let session_id = crate::storage::persistent_session_id(
&load_chat_channel,
load_chat_id,
);
if let Err(e) = send_task_messages(&store, &session_id, sender, None).await {
tracing::warn!(
error = %e,
channel = %load_chat_channel,
chat_id = %load_chat_id,
session_id = %session_id,
"Failed to send chat messages"
);
}
}
// No topic selected yet: when the response carries a topic list, adopt its first
// entry as the current topic and replay that topic's history.
if current_topic_id.is_none() {
if let Some(topics_json) = response.metadata.get("topics") {
match serde_json::from_str::<Vec<crate::protocol::TopicSummary>>(topics_json) {
Ok(topics) => {
if let Some(first_topic) = topics.first() {
let topic_id = first_topic.topic_id.clone();
*current_topic_id = Some(topic_id.clone());
if let Err(e) = send_topic_history(&store, current_session_id, &topic_id, sender).await {
tracing::warn!(error = %e, topic_id = %topic_id, "Failed to send initial topic history");
}
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse topics metadata for initial history");
}
}
}
}
} else if let Some(ref error) = response.error {
tracing::warn!(
error_code = %error.code,
error_message = %error.message,
"Command failed"
);
}
// Adapt and send the response
let outbounds = output_adapter.adapt(response);
for msg in outbounds {
let _ = sender.send(msg).await;
}
Ok(())
}
WsInbound::Ping => {
let _ = sender.send(WsOutbound::Pong).await;
Ok(())
}
}
}
/// Returns the current wall-clock time as milliseconds relative to the Unix epoch.
///
/// Never panics: a system clock set before the epoch yields a negative value instead of
/// aborting the connection task, and values outside the `i64` range saturate rather than
/// being truncated by an `as` cast.
fn current_timestamp() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let saturate = |millis: u128| i64::try_from(millis).unwrap_or(i64::MAX);

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => saturate(since_epoch.as_millis()),
        // The clock is before the epoch; `duration()` reports by how much.
        Err(before_epoch) => -saturate(before_epoch.duration().as_millis()),
    }
}
/// Picks the sender id to attribute an inbound message to.
///
/// Returns the client-supplied `sender_id` with surrounding whitespace removed; when it is
/// absent or blank, `runtime_session_id` is used instead.
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
    match sender_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
        _ => runtime_session_id.to_string(),
    }
}
/// Loads the stored messages of a topic and replays them to the client.
///
/// Every message is converted with `chat_message_to_ws_outbound` and the resulting frames are
/// queued on `sender` in order. Failures to queue a frame are ignored. `_session_id` is
/// currently unused.
///
/// # Errors
/// Propagates the error returned by `store.load_messages_for_topic`.
async fn send_topic_history(
    store: &Arc<crate::storage::SessionStore>,
    _session_id: &str,
    topic_id: &str,
    sender: &mpsc::Sender<WsOutbound>,
) -> Result<(), Box<dyn std::error::Error>> {
    let history = store.load_messages_for_topic(topic_id)?;
    tracing::info!(topic_id = %topic_id, message_count = history.len(), "Sending topic history");

    // Messages are converted lazily, one at a time, just before their frames are sent.
    let frames = history
        .into_iter()
        .flat_map(|message| chat_message_to_ws_outbound(&message));
    for frame in frames {
        let _ = sender.send(frame).await;
    }
    Ok(())
}
/// Loads the stored messages of a session and replays them to the client.
///
/// Every message is converted with `chat_message_to_ws_outbound`. When `subagent_task_id` is
/// given, each frame is tagged with it via `set_subagent_task_id` before being queued on
/// `sender`. Failures to queue a frame are ignored.
///
/// # Errors
/// Propagates the error returned by `store.load_messages`.
async fn send_task_messages(
    store: &Arc<crate::storage::SessionStore>,
    session_id: &str,
    sender: &mpsc::Sender<WsOutbound>,
    subagent_task_id: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
    let history = store.load_messages(session_id)?;
    tracing::info!(session_id = %session_id, message_count = history.len(), "Sending task messages");

    let tag = subagent_task_id.as_deref();
    for message in history {
        for mut frame in chat_message_to_ws_outbound(&message) {
            if let Some(task_id) = tag {
                set_subagent_task_id(&mut frame, task_id);
            }
            let _ = sender.send(frame).await;
        }
    }
    Ok(())
}
/// Stamps `task_id` onto an outbound frame.
///
/// Only `AssistantResponse`, `ToolCall`, `ToolResult` and `ToolPending` are affected; every
/// other variant is left untouched.
fn set_subagent_task_id(outbound: &mut WsOutbound, task_id: &str) {
    let slot = match outbound {
        WsOutbound::AssistantResponse { subagent_task_id, .. } => subagent_task_id,
        WsOutbound::ToolCall { subagent_task_id, .. } => subagent_task_id,
        WsOutbound::ToolResult { subagent_task_id, .. } => subagent_task_id,
        WsOutbound::ToolPending { subagent_task_id, .. } => subagent_task_id,
        // The remaining variants are not tagged.
        _ => return,
    };
    *slot = Some(task_id.to_string());
}
/// Converts a stored `ChatMessage` into the `WsOutbound` frames used to replay history.
///
/// Mapping by `msg.role`:
/// * `"assistant"` with tool calls - an optional `AssistantResponse` (only when the content is
///   not blank or reasoning is present) followed by one `ToolCall` per call.
/// * `"assistant"` without tool calls - a single `AssistantResponse`.
/// * `"tool"` - a `ToolResult` or a `ToolPending`, depending on `msg.tool_state`; a missing
///   state is treated as completed.
/// * `"user"` - an `AssistantResponse` frame carrying the user role, the content with any
///   trailing `media_refs_json` section removed, and the message's media files as base64
///   attachments. Files that cannot be read are skipped.
/// * any other role - no frames.
///
/// Timestamps are emitted as `msg.timestamp / 1000`.
///
/// Media files are read and encoded only for user messages, the one role whose frame carries
/// attachments, so replaying assistant and tool messages does no file I/O.
fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Vec<WsOutbound> {
    use crate::bus::message::ToolMessageState;

    /// Drops everything from the first `"\n\nmedia_refs_json:"` marker onwards.
    fn strip_media_refs_json(content: &str) -> String {
        content
            .find("\n\nmedia_refs_json:")
            .map_or(content, |pos| &content[..pos])
            .to_string()
    }

    // Builds attachments from `media_refs`, reading each file and encoding it as base64.
    // Kept as a closure so the file reads happen only where the result is used.
    let load_attachments = || -> Vec<MediaSummary> {
        msg.media_refs
            .iter()
            .filter_map(|path| {
                // Unreadable files are skipped rather than failing the whole replay.
                let bytes = std::fs::read(path).ok()?;
                let encoded = STANDARD.encode(&bytes);

                // Guess the MIME type from the path, then derive the coarse media type.
                let mime_type = mime_guess::from_path(path)
                    .first_raw()
                    .map(ToOwned::to_owned);
                let media_type = match mime_type.as_deref() {
                    Some(mime) if mime.starts_with("image/") => "image",
                    Some(mime) if mime.starts_with("audio/") => "audio",
                    Some(mime) if mime.starts_with("video/") => "video",
                    _ => "file",
                };

                let file_name = std::path::Path::new(path)
                    .file_name()
                    .and_then(|name| name.to_str())
                    .map(ToOwned::to_owned);

                Some(MediaSummary {
                    path: path.clone(),
                    media_type: media_type.to_string(),
                    mime_type,
                    content_base64: Some(encoded),
                    file_name,
                })
            })
            .collect()
    };

    let timestamp = Some(msg.timestamp / 1000);

    // The assistant text frame is identical with and without tool calls.
    let assistant_response = || WsOutbound::AssistantResponse {
        id: msg.id.clone(),
        content: msg.content.clone(),
        role: msg.role.clone(),
        attachments: Vec::new(),
        subagent_task_id: None,
        topic_id: None,
        timestamp,
        reasoning_content: msg.reasoning_content.clone(),
    };

    match msg.role.as_str() {
        "assistant" => {
            if let Some(tool_calls) = &msg.tool_calls {
                let mut outbound = Vec::new();
                let has_content_or_reasoning =
                    !msg.content.trim().is_empty() || msg.reasoning_content.is_some();
                if has_content_or_reasoning {
                    outbound.push(assistant_response());
                }
                for tool_call in tool_calls {
                    outbound.push(WsOutbound::ToolCall {
                        id: tool_call.id.clone(),
                        tool_call_id: tool_call.id.clone(),
                        tool_name: tool_call.name.clone(),
                        arguments: tool_call.arguments.clone(),
                        content: format!("{}\nargs: {}", tool_call.name, tool_call.arguments),
                        role: msg.role.clone(),
                        subagent_task_id: None,
                        topic_id: None,
                        timestamp,
                        // Reasoning always travels on the AssistantResponse: that frame is
                        // emitted whenever reasoning exists, so a ToolCall never repeats it.
                        reasoning_content: None,
                    });
                }
                outbound
            } else {
                // Plain assistant message.
                vec![assistant_response()]
            }
        }
        "tool" => {
            let tool_state = msg.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed);
            // A missing tool_call_id gets a fresh UUID as frame id, while the
            // `tool_call_id` field itself falls back to an empty string.
            let frame_id = msg
                .tool_call_id
                .clone()
                .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
            let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
            let tool_name = msg.tool_name.clone().unwrap_or_default();
            match tool_state {
                ToolMessageState::Completed => vec![WsOutbound::ToolResult {
                    id: frame_id,
                    tool_call_id,
                    tool_name,
                    content: msg.content.clone(),
                    role: msg.role.clone(),
                    subagent_task_id: None,
                    topic_id: None,
                    duration_ms: msg.tool_duration_ms,
                    timestamp,
                }],
                ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
                    id: frame_id,
                    tool_call_id,
                    tool_name,
                    content: msg.content.clone(),
                    role: msg.role.clone(),
                    resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
                    subagent_task_id: None,
                    topic_id: None,
                    timestamp,
                }],
            }
        }
        "user" => vec![WsOutbound::AssistantResponse {
            id: msg.id.clone(),
            content: strip_media_refs_json(&msg.content),
            role: msg.role.clone(),
            attachments: load_attachments(),
            subagent_task_id: None,
            topic_id: None,
            timestamp,
            reasoning_content: None,
        }],
        _ => Vec::new(),
    }
}
#[cfg(test)]
mod tests {
    use super::{resolve_ws_sender_id, build_media_filename, process_attachments_with_base64};
    use crate::protocol::MediaSummary;
    use base64::{Engine as _, engine::general_purpose::STANDARD};

    /// A usable inbound sender id wins, with surrounding whitespace removed.
    #[test]
    fn test_resolve_ws_sender_id_prefers_inbound_sender() {
        for raw in ["user-42", " user-42 "] {
            assert_eq!(resolve_ws_sender_id(Some(raw), "runtime-1"), "user-42");
        }
    }

    /// A missing or blank sender id falls back to the runtime session id.
    #[test]
    fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() {
        for missing in [None, Some(" ")] {
            assert_eq!(resolve_ws_sender_id(missing, "runtime-1"), "runtime-1");
        }
    }

    /// The original file name is kept as the suffix behind a UUID prefix.
    #[test]
    fn test_build_media_filename_preserves_original_name() {
        let original = "photo.png";
        let generated = build_media_filename("image", Some(original));
        assert!(generated.ends_with("_photo.png"));
        // UUID is 36 chars, plus underscore and original name
        let minimum_len = 36 + 1 + original.len();
        assert!(generated.len() >= minimum_len);
    }

    /// Without a file name the media type becomes the prefix.
    #[test]
    fn test_build_media_filename_generates_default_when_no_name() {
        let generated = build_media_filename("image", None);
        assert!(generated.starts_with("image_"));
    }

    /// A base64 payload is written to the media directory and kept on the item.
    #[test]
    fn test_process_attachments_with_base64_saves_to_file() {
        let payload = "test image content";
        let encoded = STANDARD.encode(payload.as_bytes());
        let upload = MediaSummary {
            path: "test_image.png".to_string(),
            media_type: "image".to_string(),
            mime_type: Some("image/png".to_string()),
            content_base64: Some(encoded.clone()),
            file_name: Some("test_image.png".to_string()),
        };

        let items = process_attachments_with_base64(vec![upload]).unwrap();
        let saved = &items[0];

        // The path now points into the media directory.
        for segment in [".picobot", "media", "ws"] {
            assert!(saved.path.contains(segment));
        }
        // The payload is kept for frontend display.
        assert_eq!(saved.content_base64.as_ref(), Some(&encoded));
        // The file really exists with the decoded content.
        let on_disk = std::fs::read(&saved.path).unwrap();
        assert_eq!(on_disk, payload.as_bytes());

        // Cleanup
        std::fs::remove_file(&saved.path).ok();
    }

    /// Without a payload the incoming path is passed through unchanged.
    #[test]
    fn test_process_attachments_without_base64_keeps_original_path() {
        let existing = MediaSummary {
            path: "/existing/path/image.jpg".to_string(),
            media_type: "image".to_string(),
            mime_type: Some("image/jpeg".to_string()),
            content_base64: None,
            file_name: Some("image.jpg".to_string()),
        };

        let items = process_attachments_with_base64(vec![existing]).unwrap();
        assert_eq!(items[0].path, "/existing/path/image.jpg");
    }

    /// No attachments in, no media items out.
    #[test]
    fn test_process_empty_attachments_returns_empty_vec() {
        let items = process_attachments_with_base64(Vec::new()).unwrap();
        assert!(items.is_empty());
    }
}