use super::GatewayState; use crate::agent::AgentError; use crate::bus::InboundMessage; use crate::command::adapter::OutputAdapter; use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter}; use crate::command::context::CommandContext; use crate::command::handler::CommandRouter; use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::mpsc; const CLI_CHANNEL_NAME: &str = "cli"; pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async { handle_socket(socket, state).await; }) } async fn handle_socket(ws: WebSocket, state: Arc) { let (sender, receiver) = mpsc::channel::(100); let cli_sessions = state.session_manager.cli_sessions(); let initial_record = match cli_sessions.create(None) { Ok(record) => record, Err(e) => { tracing::error!(error = %e, "Failed to create initial CLI session"); return; } }; let runtime_session_id = uuid::Uuid::new_v4().to_string(); let mut current_session_id = initial_record.id.clone(); state .channel_manager .cli_channel() .register_connection( current_session_id.clone(), runtime_session_id.clone(), sender.clone(), ) .await; tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established"); let _ = sender .send(WsOutbound::SessionEstablished { session_id: current_session_id.clone(), }) .await; let (mut ws_sender, mut ws_receiver) = ws.split(); let mut receiver = receiver; let session_id_for_sender = runtime_session_id.clone(); tokio::spawn(async move { while let Some(msg) = receiver.recv().await { if let Ok(text) = serialize_outbound(&msg) { if ws_sender.send(WsMessage::Text(text.into())).await.is_err() { #[cfg(debug_assertions)] tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error"); break; } } } }); while let Some(msg) = ws_receiver.next().await { match msg { Ok(WsMessage::Text(text)) => { let text = text.to_string(); match parse_inbound(&text) { Ok(inbound) => { if let Err(e) = handle_inbound( &state, &sender, &runtime_session_id, &mut current_session_id, inbound, ) .await { tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message"); let _ = sender .send(WsOutbound::Error { code: "SESSION_ERROR".to_string(), message: e.to_string(), }) .await; } } Err(e) => { tracing::warn!(error = %e, "Failed to parse inbound message"); let _ = sender .send(WsOutbound::Error { code: "PARSE_ERROR".to_string(), message: e.to_string(), }) .await; } } } Ok(WsMessage::Close(_)) | Err(_) => { #[cfg(debug_assertions)] tracing::debug!(session_id = %runtime_session_id, "WebSocket closed"); break; } _ => {} } } state .channel_manager .cli_channel() .unregister_connection(&runtime_session_id) .await; tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); } fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary { SessionSummary { session_id: record.id, title: record.title, channel_name: record.channel_name, chat_id: record.chat_id, message_count: record.message_count, last_active_at: record.last_active_at, archived_at: record.archived_at, } } async fn handle_inbound( state: &Arc, sender: &mpsc::Sender, runtime_session_id: &str, current_session_id: &mut String, inbound: WsInbound, ) -> Result<(), crate::agent::AgentError> { match inbound { WsInbound::UserInput { content, chat_id, sender_id, .. } => { let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id); state .channel_manager .cli_channel() .register_connection( chat_id.clone(), runtime_session_id.to_string(), sender.clone(), ) .await; state .bus .publish_inbound(InboundMessage { channel: CLI_CHANNEL_NAME.to_string(), sender_id, chat_id, content, timestamp: current_timestamp(), media: Vec::new(), metadata: HashMap::new(), forwarded_metadata: HashMap::new(), }) .await .map_err(|error| AgentError::Other(error.to_string()))?; Ok(()) } WsInbound::ClearHistory { session_id, chat_id, } => { let target = session_id .or(chat_id) .unwrap_or_else(|| current_session_id.clone()); state .session_manager .cli_sessions() .clear_messages(&target)?; if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await { session.lock().await.remove_history(&target); } let _ = sender .send(WsOutbound::HistoryCleared { session_id: target }) .await; Ok(()) } WsInbound::CreateSession { title } => { // 使用新的命令层处理 let _input_adapter = WebSocketInputAdapter::new(); let output_adapter = WebSocketOutputAdapter::new(); let cli_sessions = state.session_manager.cli_sessions(); let handler = SessionCommandHandler::new(cli_sessions); let router = { let mut r = CommandRouter::new(); r.register(Box::new(handler)); r }; // 构建命令 let cmd = crate::command::Command::CreateSession { title }; let cmd_ctx = CommandContext::new("websocket") .with_session_id(current_session_id.as_str()); // 执行命令 let response = router.dispatch_with_response(cmd, cmd_ctx).await; // 适配输出 let outbounds = output_adapter.adapt(response); // 处理响应 for msg in outbounds { if let WsOutbound::SessionCreated { session_id, title: _ } = &msg { *current_session_id = session_id.clone(); state .channel_manager .cli_channel() .register_connection( session_id.clone(), runtime_session_id.to_string(), sender.clone(), ) .await; } let _ = sender.send(msg).await; } Ok(()) } WsInbound::ListSessions { include_archived } => { let records = state .session_manager .cli_sessions() .list(include_archived)?; let summaries = records.into_iter().map(to_session_summary).collect(); let _ = sender .send(WsOutbound::SessionList { sessions: summaries, current_session_id: Some(current_session_id.clone()), }) .await; Ok(()) } WsInbound::LoadSession { session_id } => { let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else { let _ = sender .send(WsOutbound::Error { code: "SESSION_NOT_FOUND".to_string(), message: format!("Session not found: {}", session_id), }) .await; return Ok(()); }; *current_session_id = record.id.clone(); state .channel_manager .cli_channel() .register_connection( record.id.clone(), runtime_session_id.to_string(), sender.clone(), ) .await; let _ = sender .send(WsOutbound::SessionLoaded { session_id: record.id, title: record.title, message_count: record.message_count, }) .await; Ok(()) } WsInbound::RenameSession { session_id, title } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); state .session_manager .cli_sessions() .rename(&target, &title)?; let _ = sender .send(WsOutbound::SessionRenamed { session_id: target, title, }) .await; Ok(()) } WsInbound::ArchiveSession { session_id } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); state.session_manager.cli_sessions().archive(&target)?; let _ = sender .send(WsOutbound::SessionArchived { session_id: target }) .await; Ok(()) } WsInbound::DeleteSession { session_id } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); state.session_manager.cli_sessions().delete(&target)?; let replacement = if target == *current_session_id { Some(state.session_manager.cli_sessions().create(None)?) } else { None }; if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await { session.lock().await.remove_history(&target); } let _ = sender .send(WsOutbound::SessionDeleted { session_id: target.clone(), }) .await; if let Some(record) = replacement { *current_session_id = record.id.clone(); state .channel_manager .cli_channel() .register_connection( record.id.clone(), runtime_session_id.to_string(), sender.clone(), ) .await; let _ = sender .send(WsOutbound::SessionCreated { session_id: record.id, title: record.title, }) .await; } Ok(()) } WsInbound::SaveSession { filepath, session_id } => { let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone()); // 获取所需依赖 let store = state.session_manager.store(); let provider_config = state.config.get_provider_config("default") .map_err(|e| AgentError::Other(e.to_string()))?; // 构建处理器 let handler = SaveSessionCommandHandler::new(store, provider_config); let router = { let mut r = CommandRouter::new(); r.register(Box::new(handler)); r }; // 构建命令 let cmd = crate::command::Command::SaveSession { filepath }; let cmd_ctx = CommandContext::new("websocket") .with_session_id(&target_session_id); // 执行命令 let response = router.dispatch_with_response(cmd, cmd_ctx).await; // 处理响应 if response.success { let filepath = response .metadata .get("filepath") .cloned() .unwrap_or_default(); let _ = sender .send(WsOutbound::SessionSaved { session_id: target_session_id, filepath, }) .await; } else { let error = response.error.unwrap_or_else(|| { crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error") }); let _ = sender .send(WsOutbound::Error { code: error.code, message: error.message, }) .await; } Ok(()) } WsInbound::Ping => { let _ = sender.send(WsOutbound::Pong).await; Ok(()) } } } fn current_timestamp() -> i64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_millis() as i64 } fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String { sender_id .map(str::trim) .filter(|sender_id| !sender_id.is_empty()) .map(ToOwned::to_owned) .unwrap_or_else(|| runtime_session_id.to_string()) } #[cfg(test)] mod tests { use super::resolve_ws_sender_id; #[test] fn test_resolve_ws_sender_id_prefers_inbound_sender() { assert_eq!( resolve_ws_sender_id(Some("user-42"), "runtime-1"), "user-42" ); assert_eq!( resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), "user-42" ); } #[test] fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() { assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1"); } }