930 lines
38 KiB
Rust
930 lines
38 KiB
Rust
use super::GatewayState;
|
||
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
||
use crate::bus::{InboundMessage, MediaItem};
|
||
use crate::command::adapter::{InputAdapter, OutputAdapter};
|
||
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
||
use crate::command::context::CommandContext;
|
||
use crate::command::handler::CommandRouter;
|
||
use crate::command::handlers::delete_topic::DeleteTopicCommandHandler;
|
||
use crate::command::handlers::get_current::GetCurrentSessionCommandHandler;
|
||
use crate::command::handlers::help::HelpCommandHandler;
|
||
use crate::command::handlers::list_channels::ListChannelsCommandHandler;
|
||
use crate::command::handlers::list_memories::ListMemoriesCommandHandler;
|
||
use crate::command::handlers::list_skills::ListSkillsCommandHandler;
|
||
use crate::command::handlers::list_scheduler_jobs::ListSchedulerJobsCommandHandler;
|
||
use crate::command::handlers::memory_crud::MemoryCrudCommandHandler;
|
||
use crate::command::handlers::list_sessions::ListSessionsCommandHandler;
|
||
use crate::command::handlers::list_sessions_by_channel::ListSessionsByChannelCommandHandler;
|
||
use crate::command::handlers::list_topics::ListTopicsCommandHandler;
|
||
use crate::command::handlers::load_chat_messages::LoadChatMessagesCommandHandler;
|
||
use crate::command::handlers::load_task_messages::LoadTaskMessagesCommandHandler;
|
||
use crate::command::handlers::load_topic::LoadTopicCommandHandler;
|
||
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||
use crate::command::handlers::session::SessionCommandHandler;
|
||
use crate::command::handlers::stop_execution::StopExecutionCommandHandler;
|
||
use crate::command::handlers::switch_topic::SwitchTopicCommandHandler;
|
||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||
use crate::protocol::{WsInbound, WsOutbound, MediaSummary, parse_inbound, serialize_outbound};
|
||
use crate::skills::SkillPromptProvider;
|
||
use axum::extract::State;
|
||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||
use axum::response::Response;
|
||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||
use futures_util::{SinkExt, StreamExt};
|
||
use std::collections::HashMap;
|
||
use std::path::PathBuf;
|
||
use std::sync::Arc;
|
||
use tokio::sync::mpsc;
|
||
|
||
/// Channel name stamped on every inbound bus message published by this WebSocket handler.
const CLI_CHANNEL_NAME: &str = "cli";
|
||
|
||
/// Default media directory for WebSocket uploads
|
||
fn default_ws_media_dir() -> PathBuf {
|
||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||
home.join(".picobot").join("media").join("ws")
|
||
}
|
||
|
||
/// Build a unique filename for media upload
|
||
fn build_media_filename(media_type: &str, file_name: Option<&str>) -> String {
|
||
if let Some(file_name) = file_name {
|
||
let sanitized: String = file_name
|
||
.chars()
|
||
.map(|ch| match ch {
|
||
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
|
||
_ => ch,
|
||
})
|
||
.collect();
|
||
if !sanitized.trim().is_empty() {
|
||
return format!("{}_{}", uuid::Uuid::new_v4(), sanitized);
|
||
}
|
||
}
|
||
format!("{}_{}", media_type, uuid::Uuid::new_v4())
|
||
}
|
||
|
||
/// Process attachments with base64 content: save to local file and return MediaItem with correct path
|
||
/// Keeps content_base64 for frontend display/download
|
||
fn process_attachments_with_base64(attachments: Vec<MediaSummary>) -> Result<Vec<MediaItem>, AgentError> {
|
||
if attachments.is_empty() {
|
||
return Ok(Vec::new());
|
||
}
|
||
|
||
let media_dir = default_ws_media_dir();
|
||
std::fs::create_dir_all(&media_dir)
|
||
.map_err(|error| AgentError::Other(format!("Failed to create media dir: {}", error)))?;
|
||
|
||
attachments
|
||
.into_iter()
|
||
.map(|att| {
|
||
// If content_base64 exists, save to file and update path
|
||
if let Some(base64_content) = &att.content_base64 {
|
||
let decoded = STANDARD
|
||
.decode(base64_content)
|
||
.map_err(|error| AgentError::Other(format!("Failed to decode base64: {}", error)))?;
|
||
|
||
let filename = build_media_filename(&att.media_type, att.file_name.as_deref());
|
||
let file_path = media_dir.join(&filename);
|
||
|
||
std::fs::write(&file_path, decoded)
|
||
.map_err(|error| AgentError::Other(format!("Failed to write media file: {}", error)))?;
|
||
|
||
tracing::info!(
|
||
filename = %filename,
|
||
media_type = %att.media_type,
|
||
file_path = %file_path.to_string_lossy(),
|
||
"Saved WebSocket media to local file"
|
||
);
|
||
|
||
Ok(MediaItem {
|
||
path: file_path.to_string_lossy().to_string(),
|
||
media_type: att.media_type,
|
||
mime_type: att.mime_type,
|
||
original_key: None,
|
||
// Keep content_base64 for frontend display/download
|
||
content_base64: att.content_base64,
|
||
file_name: att.file_name,
|
||
})
|
||
} else {
|
||
// No base64 content, keep original path (should already be valid)
|
||
Ok(MediaItem {
|
||
path: att.path,
|
||
media_type: att.media_type,
|
||
mime_type: att.mime_type,
|
||
original_key: None,
|
||
content_base64: None,
|
||
file_name: att.file_name,
|
||
})
|
||
}
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||
ws.on_upgrade(|socket| async {
|
||
handle_socket(socket, state).await;
|
||
})
|
||
}
|
||
|
||
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
||
|
||
let cli_sessions = state.session_manager.cli_sessions();
|
||
let store = state.session_manager.store();
|
||
|
||
// 1. 先查询 websocket 通道的 Sessions
|
||
let websocket_sessions = store.list_sessions("websocket", false)
|
||
.unwrap_or_default();
|
||
|
||
// 2. 如果没有,自动创建一个默认 Session
|
||
let initial_record = if websocket_sessions.is_empty() {
|
||
match cli_sessions.create_with_channel("websocket", Some("默认会话")) {
|
||
Ok(record) => record,
|
||
Err(e) => {
|
||
tracing::error!(error = %e, "Failed to create initial WebSocket session");
|
||
return;
|
||
}
|
||
}
|
||
} else {
|
||
// 使用最新的 Session
|
||
websocket_sessions[0].clone()
|
||
};
|
||
|
||
let runtime_session_id = uuid::Uuid::new_v4().to_string();
|
||
let mut current_session_id = initial_record.id.clone();
|
||
let mut current_topic_id: Option<String> = None;
|
||
state
|
||
.channel_manager
|
||
.cli_channel()
|
||
.register_connection(
|
||
current_session_id.clone(),
|
||
runtime_session_id.clone(),
|
||
sender.clone(),
|
||
)
|
||
.await;
|
||
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "WebSocket session established");
|
||
|
||
let _ = sender
|
||
.send(WsOutbound::SessionEstablished {
|
||
session_id: current_session_id.clone(),
|
||
})
|
||
.await;
|
||
|
||
// 连接建立后立即发送通道列表(合并 websocket + ChannelManager 动态通道)
|
||
let channels = state.channel_manager.build_channel_list().await;
|
||
let _ = sender
|
||
.send(WsOutbound::ChannelList { channels })
|
||
.await;
|
||
|
||
// 3. 重新查询 websocket 通道的 Session 列表(包含刚创建的)
|
||
let final_sessions = store.list_sessions("websocket", false)
|
||
.unwrap_or_default();
|
||
|
||
tracing::info!("Sending {} websocket sessions to client", final_sessions.len());
|
||
for s in &final_sessions {
|
||
tracing::info!(" - {}: {} (channel: {})", s.id, s.title, s.channel_name);
|
||
}
|
||
|
||
let session_summaries: Vec<crate::protocol::SessionSummary> = final_sessions
|
||
.into_iter()
|
||
.map(|s| crate::protocol::SessionSummary {
|
||
session_id: s.id,
|
||
title: s.title,
|
||
channel_name: s.channel_name,
|
||
chat_id: s.chat_id,
|
||
message_count: s.message_count,
|
||
last_active_at: s.last_active_at,
|
||
archived_at: s.archived_at,
|
||
})
|
||
.collect();
|
||
|
||
let _ = sender
|
||
.send(WsOutbound::SessionList {
|
||
sessions: session_summaries,
|
||
current_session_id: Some(current_session_id.clone()),
|
||
channel_name: Some("websocket".to_string()),
|
||
})
|
||
.await;
|
||
|
||
let (mut ws_sender, mut ws_receiver) = ws.split();
|
||
|
||
let mut receiver = receiver;
|
||
let session_id_for_sender = runtime_session_id.clone();
|
||
tokio::spawn(async move {
|
||
while let Some(msg) = receiver.recv().await {
|
||
if let Ok(text) = serialize_outbound(&msg) {
|
||
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
});
|
||
|
||
while let Some(msg) = ws_receiver.next().await {
|
||
match msg {
|
||
Ok(WsMessage::Text(text)) => {
|
||
let text = text.to_string();
|
||
match parse_inbound(&text) {
|
||
Ok(inbound) => {
|
||
if let Err(e) = handle_inbound(
|
||
&state,
|
||
&sender,
|
||
&runtime_session_id,
|
||
&mut current_session_id,
|
||
&mut current_topic_id,
|
||
inbound,
|
||
)
|
||
.await
|
||
{
|
||
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
||
let _ = sender
|
||
.send(WsOutbound::Error {
|
||
timestamp: Some(crate::protocol::now_timestamp()),
|
||
code:"SESSION_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
})
|
||
.await;
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||
let _ = sender
|
||
.send(WsOutbound::Error {
|
||
timestamp: Some(crate::protocol::now_timestamp()),
|
||
code:"PARSE_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
})
|
||
.await;
|
||
}
|
||
}
|
||
}
|
||
Ok(WsMessage::Close(_)) | Err(_) => {
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
|
||
break;
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
state
|
||
.channel_manager
|
||
.cli_channel()
|
||
.unregister_connection(&runtime_session_id)
|
||
.await;
|
||
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
||
}
|
||
|
||
|
||
async fn handle_inbound(
|
||
state: &Arc<GatewayState>,
|
||
sender: &mpsc::Sender<WsOutbound>,
|
||
runtime_session_id: &str,
|
||
current_session_id: &mut String,
|
||
current_topic_id: &mut Option<String>,
|
||
inbound: WsInbound,
|
||
) -> Result<(), AgentError> {
|
||
match inbound {
|
||
WsInbound::Message {
|
||
content,
|
||
attachments,
|
||
chat_id,
|
||
sender_id,
|
||
..
|
||
} => {
|
||
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
||
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
||
|
||
state
|
||
.channel_manager
|
||
.cli_channel()
|
||
.register_connection(
|
||
chat_id.clone(),
|
||
runtime_session_id.to_string(),
|
||
sender.clone(),
|
||
)
|
||
.await;
|
||
|
||
// Process attachments: save base64 content to local files and build MediaItems with correct paths
|
||
let media = process_attachments_with_base64(attachments)?;
|
||
|
||
state
|
||
.bus
|
||
.publish_inbound(InboundMessage {
|
||
channel: CLI_CHANNEL_NAME.to_string(),
|
||
sender_id,
|
||
chat_id,
|
||
content,
|
||
timestamp: current_timestamp(),
|
||
media,
|
||
metadata: HashMap::new(),
|
||
forwarded_metadata: HashMap::new(),
|
||
})
|
||
.await
|
||
.map_err(|error| AgentError::Other(error.to_string()))?;
|
||
|
||
Ok(())
|
||
}
|
||
WsInbound::Command { payload } => {
|
||
// 使用 Command 系统处理命令
|
||
let input_adapter = WebSocketInputAdapter::new();
|
||
let output_adapter = WebSocketOutputAdapter::new();
|
||
|
||
// 解析命令
|
||
let adapter_ctx = crate::command::context::AdapterContext::new("websocket")
|
||
.with_session_id(current_session_id.as_str());
|
||
|
||
let cmd = match input_adapter.try_parse(&payload, adapter_ctx) {
|
||
Ok(Some(cmd)) => cmd,
|
||
Ok(None) => {
|
||
// 不是命令,返回错误
|
||
let _ = sender
|
||
.send(WsOutbound::Error {
|
||
timestamp: Some(crate::protocol::now_timestamp()),
|
||
code: "INVALID_COMMAND".to_string(),
|
||
message: "Invalid command payload".to_string(),
|
||
})
|
||
.await;
|
||
return Ok(());
|
||
}
|
||
Err(e) => {
|
||
let _ = sender
|
||
.send(WsOutbound::Error {
|
||
timestamp: Some(crate::protocol::now_timestamp()),
|
||
code: "PARSE_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
})
|
||
.await;
|
||
return Ok(());
|
||
}
|
||
};
|
||
|
||
// 创建命令路由器
|
||
let _cli_sessions = state.session_manager.cli_sessions();
|
||
let store = state.session_manager.store();
|
||
let skills = state.session_manager.skills();
|
||
let skills_for_handler = skills.clone();
|
||
let provider_config = state.config.get_provider_config("default")
|
||
.map_err(|e| AgentError::Other(e.to_string()))?;
|
||
let prompt_repository = state.session_manager.store().clone();
|
||
|
||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||
Box::new(AgentPromptProvider::new(
|
||
0,
|
||
provider_config.clone(),
|
||
prompt_repository.clone(),
|
||
)),
|
||
Box::new(SkillPromptProvider::new(skills)),
|
||
]));
|
||
|
||
let mut router = CommandRouter::new();
|
||
// 注册 Session 处理器
|
||
let session_handler = SessionCommandHandler::new(store.clone())
|
||
.with_session_manager(state.session_manager.clone());
|
||
router.register(Box::new(session_handler));
|
||
// 注册 list_sessions 处理器
|
||
router.register(Box::new(ListSessionsCommandHandler::new(store.clone())));
|
||
// 注册 list_sessions_by_channel 处理器
|
||
router.register(Box::new(ListSessionsByChannelCommandHandler::new(store.clone())));
|
||
// 注册 list_channels 处理器
|
||
router.register(Box::new(ListChannelsCommandHandler::new(Arc::new(state.channel_manager.clone()))));
|
||
// 注册 list_topics 处理器
|
||
router.register(Box::new(ListTopicsCommandHandler::new(store.clone())));
|
||
// 注册 switch_topic 处理器
|
||
let switch_handler = SwitchTopicCommandHandler::new(store.clone())
|
||
.with_session_manager(state.session_manager.clone());
|
||
router.register(Box::new(switch_handler));
|
||
// 注册 get_current 处理器
|
||
router.register(Box::new(GetCurrentSessionCommandHandler::new(store.clone())));
|
||
// 注册 load_topic 处理器
|
||
router.register(Box::new(LoadTopicCommandHandler::new(store.clone())));
|
||
// 注册 load_task_messages 处理器
|
||
router.register(Box::new(LoadTaskMessagesCommandHandler::new(
|
||
state.task_repository.clone(),
|
||
store.clone(),
|
||
)));
|
||
router.register(Box::new(SaveSessionCommandHandler::new(
|
||
store.clone(),
|
||
state.task_repository.clone(),
|
||
system_prompt_provider.clone(),
|
||
)));
|
||
// 注册 delete_topic 处理器
|
||
router.register(Box::new(
|
||
DeleteTopicCommandHandler::new(store.clone())
|
||
.with_session_manager(state.session_manager.clone()),
|
||
));
|
||
// 注册 help 处理器
|
||
let metadata = router.metadata_arc();
|
||
router.register(Box::new(HelpCommandHandler::new(metadata)));
|
||
// 注册 list_scheduler_jobs 处理器
|
||
router.register(Box::new(ListSchedulerJobsCommandHandler::new(store.clone())));
|
||
// 注册 list_memories 处理器
|
||
router.register(Box::new(ListMemoriesCommandHandler::new(store.clone())));
|
||
// 注册 list_skills 处理器
|
||
router.register(Box::new(ListSkillsCommandHandler::new(skills_for_handler)));
|
||
// 注册 memory_crud 处理器
|
||
router.register(Box::new(MemoryCrudCommandHandler::new(store.clone())));
|
||
// 注册 load_chat_messages 处理器
|
||
router.register(Box::new(LoadChatMessagesCommandHandler::new()));
|
||
// 注册 stop_execution 处理器
|
||
router.register(Box::new(StopExecutionCommandHandler::new(
|
||
state.cancel_manager.clone(),
|
||
)));
|
||
|
||
// 构建命令上下文
|
||
tracing::debug!(
|
||
current_session_id = %current_session_id,
|
||
current_topic_id = ?current_topic_id,
|
||
"Building CommandContext for WebSocket command"
|
||
);
|
||
let mut cmd_ctx = CommandContext::new("websocket", "cli")
|
||
.with_session_id(current_session_id.as_str())
|
||
.with_chat_id(current_session_id.as_str());
|
||
// 只在有 topic_id 时才设置
|
||
if let Some(ref topic_id) = *current_topic_id {
|
||
cmd_ctx = cmd_ctx.with_topic_id(topic_id.as_str());
|
||
}
|
||
|
||
// 执行命令
|
||
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
||
|
||
// 处理响应
|
||
if response.success {
|
||
// 更新当前会话 ID(如果是创建会话)
|
||
if let Some(session_id) = response.metadata.get("session_id") {
|
||
tracing::info!(
|
||
old_session_id = %current_session_id,
|
||
new_session_id = %session_id,
|
||
"Updating current_session_id"
|
||
);
|
||
*current_session_id = session_id.clone();
|
||
state
|
||
.channel_manager
|
||
.cli_channel()
|
||
.register_connection(
|
||
session_id.clone(),
|
||
runtime_session_id.to_string(),
|
||
sender.clone(),
|
||
)
|
||
.await;
|
||
}
|
||
// 更新当前话题 ID(如果是创建话题或切换话题)
|
||
if let Some(topic_id) = response.metadata.get("topic_id") {
|
||
tracing::info!(
|
||
old_topic_id = ?current_topic_id,
|
||
new_topic_id = %topic_id,
|
||
"Updating current_topic_id"
|
||
);
|
||
*current_topic_id = Some(topic_id.clone());
|
||
|
||
// 加载并发送该话题的历史消息
|
||
if let Err(e) = send_topic_history(&store, current_session_id, topic_id, sender).await {
|
||
tracing::warn!(error = %e, topic_id = %topic_id, "Failed to send topic history");
|
||
}
|
||
}
|
||
// 加载子智能体任务消息
|
||
if let Some(task_session_id) = response.metadata.get("task_session_id") {
|
||
// 提前提取 task_id,用于给历史消息打标记
|
||
let task_id = response.metadata.get("task_id").cloned().unwrap_or_default();
|
||
if let Err(e) = send_task_messages(&store, task_session_id, sender, Some(task_id.clone())).await {
|
||
tracing::warn!(error = %e, task_session_id = %task_session_id, "Failed to send task messages");
|
||
}
|
||
|
||
// 发送 TaskMessagesLoaded 元数据
|
||
let description = response.metadata.get("task_description").cloned().unwrap_or_default();
|
||
let subagent_type = response.metadata.get("task_subagent_type").cloned().unwrap_or_default();
|
||
let status = response.metadata.get("task_status").cloned().unwrap_or_default();
|
||
let summary = response.metadata.get("task_summary").cloned();
|
||
|
||
let _ = sender.send(WsOutbound::TaskMessagesLoaded {
|
||
task_id,
|
||
description,
|
||
subagent_type,
|
||
status,
|
||
summary,
|
||
}).await;
|
||
}
|
||
|
||
// 处理定时任务列表
|
||
if let Some(jobs_json) = response.metadata.get("scheduler_jobs") {
|
||
if let Ok(jobs) = serde_json::from_str::<Vec<crate::protocol::SchedulerJobSummary>>(jobs_json) {
|
||
let _ = sender.send(WsOutbound::SchedulerJobList { jobs }).await;
|
||
}
|
||
}
|
||
|
||
// 处理技能列表
|
||
if let Some(skills_json) = response.metadata.get("skills") {
|
||
if let Ok(skills) = serde_json::from_str::<Vec<crate::protocol::SkillSummary>>(skills_json) {
|
||
let _ = sender.send(WsOutbound::SkillList { skills }).await;
|
||
}
|
||
}
|
||
|
||
// 处理记忆列表
|
||
if let Some(memories_json) = response.metadata.get("memories") {
|
||
if let Ok(memories) = serde_json::from_str::<Vec<crate::protocol::MemorySummary>>(memories_json) {
|
||
let _ = sender.send(WsOutbound::MemoryList { memories }).await;
|
||
}
|
||
}
|
||
|
||
// 记忆 CRUD 后自动刷新列表
|
||
if response.metadata.get("memory_updated").map(|v| v.as_str()) == Some("true") {
|
||
if let Ok(records) = store.list_memories_for_scope("user", crate::storage::GLOBAL_SCOPE_KEY) {
|
||
let memories: Vec<crate::protocol::MemorySummary> = records
|
||
.into_iter()
|
||
.filter(|m| m.namespace != "_meta")
|
||
.map(|m| crate::protocol::MemorySummary {
|
||
id: m.id,
|
||
namespace: m.namespace,
|
||
memory_key: m.memory_key,
|
||
content: m.content,
|
||
created_at: m.created_at,
|
||
updated_at: m.updated_at,
|
||
})
|
||
.collect();
|
||
let _ = sender.send(WsOutbound::MemoryList { memories }).await;
|
||
}
|
||
}
|
||
|
||
// 处理加载聊天消息请求
|
||
if let Some(load_chat_id) = response.metadata.get("load_chat_id") {
|
||
let load_chat_channel = response.metadata.get("load_chat_channel")
|
||
.cloned()
|
||
.unwrap_or_default();
|
||
// session_id = "{channel}:{chat_id}" (cli channel 例外)
|
||
let session_id = crate::storage::persistent_session_id(
|
||
&load_chat_channel,
|
||
load_chat_id,
|
||
);
|
||
if let Err(e) = send_task_messages(&store, &session_id, sender, None).await {
|
||
tracing::warn!(
|
||
error = %e,
|
||
channel = %load_chat_channel,
|
||
chat_id = %load_chat_id,
|
||
session_id = %session_id,
|
||
"Failed to send chat messages"
|
||
);
|
||
}
|
||
}
|
||
|
||
if current_topic_id.is_none() {
|
||
if let Some(topics_json) = response.metadata.get("topics") {
|
||
match serde_json::from_str::<Vec<crate::protocol::TopicSummary>>(topics_json) {
|
||
Ok(topics) => {
|
||
if let Some(first_topic) = topics.first() {
|
||
let topic_id = first_topic.topic_id.clone();
|
||
*current_topic_id = Some(topic_id.clone());
|
||
if let Err(e) = send_topic_history(&store, current_session_id, &topic_id, sender).await {
|
||
tracing::warn!(error = %e, topic_id = %topic_id, "Failed to send initial topic history");
|
||
}
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "Failed to parse topics metadata for initial history");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
} else if let Some(ref error) = response.error {
|
||
tracing::warn!(
|
||
error_code = %error.code,
|
||
error_message = %error.message,
|
||
"Command failed"
|
||
);
|
||
}
|
||
|
||
// 适配并发送响应
|
||
let outbounds = output_adapter.adapt(response);
|
||
for msg in outbounds {
|
||
let _ = sender.send(msg).await;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
WsInbound::Ping => {
|
||
let _ = sender.send(WsOutbound::Pong).await;
|
||
Ok(())
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Never panics: a system clock set before the epoch yields `0`, and a value too
/// large for `i64` saturates at `i64::MAX` instead of wrapping.
fn current_timestamp() -> i64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // Clamp before narrowing so the cast below cannot truncate.
        .map(|elapsed| elapsed.as_millis().min(i64::MAX as u128) as i64)
        .unwrap_or(0)
}
|
||
|
||
/// Picks the sender id to attach to an inbound WebSocket message.
///
/// Returns the client-supplied `sender_id` with surrounding whitespace removed, or
/// `runtime_session_id` when no id was supplied or it is blank.
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
    match sender_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
        _ => runtime_session_id.to_owned(),
    }
}
|
||
|
||
/// 加载并发送话题历史消息
|
||
async fn send_topic_history(
|
||
store: &Arc<crate::storage::SessionStore>,
|
||
_session_id: &str,
|
||
topic_id: &str,
|
||
sender: &mpsc::Sender<WsOutbound>,
|
||
) -> Result<(), Box<dyn std::error::Error>> {
|
||
// 加载话题消息
|
||
let messages = store.load_messages_for_topic(topic_id)?;
|
||
|
||
tracing::info!(topic_id = %topic_id, message_count = messages.len(), "Sending topic history");
|
||
|
||
// 将消息转换为 WsOutbound 并发送
|
||
for msg in messages {
|
||
for outbound in chat_message_to_ws_outbound(&msg) {
|
||
let _ = sender.send(outbound).await;
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 加载并发送子智能体任务的历史消息
|
||
async fn send_task_messages(
|
||
store: &Arc<crate::storage::SessionStore>,
|
||
session_id: &str,
|
||
sender: &mpsc::Sender<WsOutbound>,
|
||
subagent_task_id: Option<String>,
|
||
) -> Result<(), Box<dyn std::error::Error>> {
|
||
let messages = store.load_messages(session_id)?;
|
||
|
||
tracing::info!(session_id = %session_id, message_count = messages.len(), "Sending task messages");
|
||
|
||
for msg in messages {
|
||
let mut outbounds = chat_message_to_ws_outbound(&msg);
|
||
if let Some(ref task_id) = subagent_task_id {
|
||
for ob in &mut outbounds {
|
||
set_subagent_task_id(ob, task_id);
|
||
}
|
||
}
|
||
for outbound in outbounds {
|
||
let _ = sender.send(outbound).await;
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 给 WsOutbound 消息注入 subagent_task_id(仅对有该字段的变体生效)
|
||
fn set_subagent_task_id(outbound: &mut WsOutbound, task_id: &str) {
|
||
match outbound {
|
||
WsOutbound::AssistantResponse {
|
||
subagent_task_id, ..
|
||
}
|
||
| WsOutbound::ToolCall {
|
||
subagent_task_id, ..
|
||
}
|
||
| WsOutbound::ToolResult {
|
||
subagent_task_id, ..
|
||
}
|
||
| WsOutbound::ToolPending {
|
||
subagent_task_id, ..
|
||
} => {
|
||
*subagent_task_id = Some(task_id.to_string());
|
||
}
|
||
_ => {} // 其他变体没有 subagent_task_id 字段
|
||
}
|
||
}
|
||
|
||
/// 将 ChatMessage 转换为 WsOutbound 列表
|
||
fn chat_message_to_ws_outbound(msg: &crate::bus::ChatMessage) -> Vec<WsOutbound> {
|
||
use crate::bus::message::ToolMessageState;
|
||
|
||
// Helper function to strip media_refs_json from content
|
||
fn strip_media_refs_json(content: &str) -> String {
|
||
// Remove the media_refs_json suffix if present
|
||
if let Some(pos) = content.find("\n\nmedia_refs_json:") {
|
||
content[..pos].to_string()
|
||
} else {
|
||
content.to_string()
|
||
}
|
||
}
|
||
|
||
// Build attachments from media_refs, reading file content for base64
|
||
let attachments: Vec<MediaSummary> = msg
|
||
.media_refs
|
||
.iter()
|
||
.filter_map(|path| {
|
||
// Try to read file and encode as base64
|
||
let file_content = std::fs::read(path).ok()?;
|
||
let base64_content = STANDARD.encode(&file_content);
|
||
|
||
// Guess mime type from path
|
||
let mime_type = mime_guess::from_path(path)
|
||
.first_raw()
|
||
.map(ToOwned::to_owned);
|
||
|
||
// Determine media type from mime type
|
||
let media_type = mime_type
|
||
.as_ref()
|
||
.map(|m| {
|
||
if m.starts_with("image/") { "image" }
|
||
else if m.starts_with("audio/") { "audio" }
|
||
else if m.starts_with("video/") { "video" }
|
||
else { "file" }
|
||
})
|
||
.unwrap_or("file");
|
||
|
||
// Get file name from path
|
||
let file_name = std::path::Path::new(path)
|
||
.file_name()
|
||
.and_then(|name| name.to_str())
|
||
.map(ToOwned::to_owned);
|
||
|
||
Some(MediaSummary {
|
||
path: path.clone(),
|
||
media_type: media_type.to_string(),
|
||
mime_type,
|
||
content_base64: Some(base64_content),
|
||
file_name,
|
||
})
|
||
})
|
||
.collect();
|
||
|
||
match msg.role.as_str() {
|
||
"assistant" => {
|
||
if let Some(tool_calls) = &msg.tool_calls {
|
||
let mut outbound = Vec::new();
|
||
let has_content_or_reasoning = !msg.content.trim().is_empty() || msg.reasoning_content.is_some();
|
||
if has_content_or_reasoning {
|
||
outbound.push(WsOutbound::AssistantResponse {
|
||
id: msg.id.clone(),
|
||
content: msg.content.clone(),
|
||
role: msg.role.clone(),
|
||
attachments: Vec::new(),
|
||
subagent_task_id: None,
|
||
topic_id: None,
|
||
timestamp: Some(msg.timestamp / 1000),
|
||
reasoning_content: msg.reasoning_content.clone(),
|
||
});
|
||
}
|
||
// AssistantResponse 已携带 reasoning 时,ToolCall 不再重复
|
||
let tc_reasoning = if has_content_or_reasoning { None } else { msg.reasoning_content.clone() };
|
||
for tool_call in tool_calls {
|
||
outbound.push(WsOutbound::ToolCall {
|
||
id: msg.id.clone(),
|
||
tool_call_id: tool_call.id.clone(),
|
||
tool_name: tool_call.name.clone(),
|
||
arguments: tool_call.arguments.clone(),
|
||
content: format!("{}\nargs: {}", tool_call.name, tool_call.arguments),
|
||
role: msg.role.clone(),
|
||
subagent_task_id: None,
|
||
topic_id: None,
|
||
timestamp: Some(msg.timestamp / 1000),
|
||
reasoning_content: tc_reasoning.clone(),
|
||
});
|
||
}
|
||
outbound
|
||
} else {
|
||
// 普通助手消息
|
||
vec![WsOutbound::AssistantResponse {
|
||
id: msg.id.clone(),
|
||
content: msg.content.clone(),
|
||
role: msg.role.clone(),
|
||
attachments: Vec::new(),
|
||
subagent_task_id: None,
|
||
topic_id: None,
|
||
timestamp: Some(msg.timestamp / 1000),
|
||
reasoning_content: msg.reasoning_content.clone(),
|
||
}]
|
||
}
|
||
}
|
||
"tool" => {
|
||
let tool_state = msg.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed);
|
||
match tool_state {
|
||
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
||
id: msg.id.clone(),
|
||
tool_call_id: msg.tool_call_id.clone().unwrap_or_default(),
|
||
tool_name: msg.tool_name.clone().unwrap_or_default(),
|
||
content: msg.content.clone(),
|
||
role: msg.role.clone(),
|
||
subagent_task_id: None,
|
||
topic_id: None,
|
||
duration_ms: msg.tool_duration_ms,
|
||
timestamp: Some(msg.timestamp / 1000),
|
||
}],
|
||
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
|
||
id: msg.id.clone(),
|
||
tool_call_id: msg.tool_call_id.clone().unwrap_or_default(),
|
||
tool_name: msg.tool_name.clone().unwrap_or_default(),
|
||
content: msg.content.clone(),
|
||
role: msg.role.clone(),
|
||
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
|
||
subagent_task_id: None,
|
||
topic_id: None,
|
||
timestamp: Some(msg.timestamp / 1000),
|
||
}],
|
||
}
|
||
}
|
||
"user" => vec![WsOutbound::AssistantResponse {
|
||
id: msg.id.clone(),
|
||
content: strip_media_refs_json(&msg.content),
|
||
role: msg.role.clone(),
|
||
attachments,
|
||
subagent_task_id: None,
|
||
topic_id: None,
|
||
timestamp: Some(msg.timestamp / 1000),
|
||
reasoning_content: None,
|
||
}],
|
||
_ => Vec::new(),
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::{build_media_filename, process_attachments_with_base64, resolve_ws_sender_id};
    use crate::protocol::MediaSummary;
    use base64::{engine::general_purpose::STANDARD, Engine as _};

    /// Runtime session id used as the fallback sender in the sender-id tests.
    const RUNTIME_ID: &str = "runtime-1";

    /// Builds an image attachment with the given path, MIME type, optional inline
    /// payload and file name.
    fn image_attachment(
        path: &str,
        mime_type: &str,
        content_base64: Option<String>,
        file_name: &str,
    ) -> MediaSummary {
        MediaSummary {
            path: path.to_string(),
            media_type: "image".to_string(),
            mime_type: Some(mime_type.to_string()),
            content_base64,
            file_name: Some(file_name.to_string()),
        }
    }

    #[test]
    fn test_resolve_ws_sender_id_prefers_inbound_sender() {
        for candidate in ["user-42", " user-42 "] {
            assert_eq!(resolve_ws_sender_id(Some(candidate), RUNTIME_ID), "user-42");
        }
    }

    #[test]
    fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() {
        for candidate in [None, Some(" ")] {
            assert_eq!(resolve_ws_sender_id(candidate, RUNTIME_ID), "runtime-1");
        }
    }

    #[test]
    fn test_build_media_filename_preserves_original_name() {
        let original = "photo.png";
        let filename = build_media_filename("image", Some(original));
        assert!(filename.ends_with("_photo.png"));
        // A UUID is 36 characters, followed by an underscore and the original name.
        let minimum_len = 36 + 1 + original.len();
        assert!(filename.len() >= minimum_len);
    }

    #[test]
    fn test_build_media_filename_generates_default_when_no_name() {
        assert!(build_media_filename("image", None).starts_with("image_"));
    }

    #[test]
    fn test_process_attachments_with_base64_saves_to_file() {
        let payload = "test image content";
        let encoded = STANDARD.encode(payload.as_bytes());
        let attachment = image_attachment(
            "test_image.png",
            "image/png",
            Some(encoded.clone()),
            "test_image.png",
        );

        let result = process_attachments_with_base64(vec![attachment]).unwrap();
        let saved = &result[0];

        // The path now points into the WebSocket media directory.
        for segment in [".picobot", "media", "ws"] {
            assert!(saved.path.contains(segment));
        }

        // The inline payload is kept for frontend display.
        assert!(saved.content_base64.is_some());
        assert_eq!(saved.content_base64.as_ref().unwrap(), &encoded);

        // The decoded payload was actually written to disk.
        let written = std::fs::read(&saved.path).unwrap();
        assert_eq!(written, payload.as_bytes());

        // Cleanup.
        std::fs::remove_file(&saved.path).ok();
    }

    #[test]
    fn test_process_attachments_without_base64_keeps_original_path() {
        let attachment = image_attachment("/existing/path/image.jpg", "image/jpeg", None, "image.jpg");

        let result = process_attachments_with_base64(vec![attachment]).unwrap();

        // The path is passed through unchanged.
        assert_eq!(result[0].path, "/existing/path/image.jpg");
    }

    #[test]
    fn test_process_empty_attachments_returns_empty_vec() {
        let result = process_attachments_with_base64(Vec::new()).unwrap();
        assert!(result.is_empty());
    }
}
|