use super::GatewayState; use crate::agent::{AgentError, CompositeSystemPromptProvider}; use crate::bus::InboundMessage; use crate::command::adapter::{InputAdapter, OutputAdapter}; use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter}; use crate::command::context::CommandContext; use crate::command::handler::CommandRouter; use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; use crate::command::handlers::session_query::SessionQueryCommandHandler; use crate::gateway::agent_prompt_provider::AgentPromptProvider; use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use crate::skills::SkillPromptProvider; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::mpsc; const CLI_CHANNEL_NAME: &str = "cli"; pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async { handle_socket(socket, state).await; }) } async fn handle_socket(ws: WebSocket, state: Arc) { let (sender, receiver) = mpsc::channel::(100); let cli_sessions = state.session_manager.cli_sessions(); let initial_record = match cli_sessions.create(None) { Ok(record) => record, Err(e) => { tracing::error!(error = %e, "Failed to create initial CLI session"); return; } }; let runtime_session_id = uuid::Uuid::new_v4().to_string(); let mut current_session_id = initial_record.id.clone(); let mut current_topic_id: Option = None; state .channel_manager .cli_channel() .register_connection( current_session_id.clone(), runtime_session_id.clone(), sender.clone(), ) .await; tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established"); let _ = sender .send(WsOutbound::SessionEstablished { session_id: current_session_id.clone(), }) .await; let (mut ws_sender, mut ws_receiver) = ws.split(); let mut receiver = receiver; let session_id_for_sender = runtime_session_id.clone(); tokio::spawn(async move { while let Some(msg) = receiver.recv().await { if let Ok(text) = serialize_outbound(&msg) { if ws_sender.send(WsMessage::Text(text.into())).await.is_err() { #[cfg(debug_assertions)] tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error"); break; } } } }); while let Some(msg) = ws_receiver.next().await { match msg { Ok(WsMessage::Text(text)) => { let text = text.to_string(); match parse_inbound(&text) { Ok(inbound) => { if let Err(e) = handle_inbound( &state, &sender, &runtime_session_id, &mut current_session_id, &mut current_topic_id, inbound, ) .await { tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message"); let _ = sender .send(WsOutbound::Error { code: "SESSION_ERROR".to_string(), message: e.to_string(), }) .await; } } Err(e) => { tracing::warn!(error = %e, "Failed to parse inbound message"); let _ = sender .send(WsOutbound::Error { code: "PARSE_ERROR".to_string(), message: e.to_string(), }) .await; } } } Ok(WsMessage::Close(_)) | Err(_) => { #[cfg(debug_assertions)] tracing::debug!(session_id = %runtime_session_id, "WebSocket closed"); break; } _ => {} } } state .channel_manager .cli_channel() .unregister_connection(&runtime_session_id) .await; tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); } async fn handle_inbound( state: &Arc, sender: &mpsc::Sender, runtime_session_id: &str, current_session_id: &mut String, current_topic_id: &mut Option, inbound: WsInbound, ) -> Result<(), AgentError> { match inbound { WsInbound::Message { content, chat_id, sender_id, .. } => { let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id); state .channel_manager .cli_channel() .register_connection( chat_id.clone(), runtime_session_id.to_string(), sender.clone(), ) .await; state .bus .publish_inbound(InboundMessage { channel: CLI_CHANNEL_NAME.to_string(), sender_id, chat_id, content, timestamp: current_timestamp(), media: Vec::new(), metadata: HashMap::new(), forwarded_metadata: HashMap::new(), }) .await .map_err(|error| AgentError::Other(error.to_string()))?; Ok(()) } WsInbound::Command { payload } => { // 使用 Command 系统处理命令 let input_adapter = WebSocketInputAdapter::new(); let output_adapter = WebSocketOutputAdapter::new(); // 解析命令 let adapter_ctx = crate::command::context::AdapterContext::new("websocket") .with_session_id(current_session_id.as_str()); let cmd = match input_adapter.try_parse(&payload, adapter_ctx) { Ok(Some(cmd)) => cmd, Ok(None) => { // 不是命令,返回错误 let _ = sender .send(WsOutbound::Error { code: "INVALID_COMMAND".to_string(), message: "Invalid command payload".to_string(), }) .await; return Ok(()); } Err(e) => { let _ = sender .send(WsOutbound::Error { code: "PARSE_ERROR".to_string(), message: e.to_string(), }) .await; return Ok(()); } }; // 创建命令路由器 let _cli_sessions = state.session_manager.cli_sessions(); let store = state.session_manager.store(); let skills = state.session_manager.skills(); let provider_config = state.config.get_provider_config("default") .map_err(|e| AgentError::Other(e.to_string()))?; let prompt_repository = state.session_manager.store().clone(); let system_prompt_provider: Arc = Arc::new(CompositeSystemPromptProvider::new(vec![ Box::new(AgentPromptProvider::new( 0, provider_config.clone(), prompt_repository.clone(), )), Box::new(SkillPromptProvider::new(skills)), ])); let mut router = CommandRouter::new(); // 注册 Session 处理器,添加 SessionManager let session_handler = SessionCommandHandler::new(store.clone()) .with_session_manager(state.session_manager.clone()); router.register(Box::new(session_handler)); // 注册 SessionQuery 处理器 let session_query_handler = SessionQueryCommandHandler::new(store.clone()) .with_session_manager(state.session_manager.clone()); router.register(Box::new(session_query_handler)); router.register(Box::new(SaveSessionCommandHandler::new( store, system_prompt_provider, ))); // 构建命令上下文 tracing::debug!( current_session_id = %current_session_id, current_topic_id = ?current_topic_id, "Building CommandContext for WebSocket command" ); let cmd_ctx = CommandContext::new("websocket", "cli") .with_session_id(current_session_id.as_str()) .with_chat_id(current_session_id.as_str()) .with_topic_id(current_topic_id.as_deref().unwrap_or("")); // 执行命令 let response = router.dispatch_with_response(cmd, cmd_ctx).await; // 处理响应 if response.success { // 更新当前会话 ID(如果是创建会话) if let Some(session_id) = response.metadata.get("session_id") { tracing::info!( old_session_id = %current_session_id, new_session_id = %session_id, "Updating current_session_id" ); *current_session_id = session_id.clone(); state .channel_manager .cli_channel() .register_connection( session_id.clone(), runtime_session_id.to_string(), sender.clone(), ) .await; } // 更新当前话题 ID(如果是创建话题或切换话题) if let Some(topic_id) = response.metadata.get("topic_id") { tracing::info!( old_topic_id = ?current_topic_id, new_topic_id = %topic_id, "Updating current_topic_id" ); *current_topic_id = Some(topic_id.clone()); } } else if let Some(ref error) = response.error { tracing::warn!( error_code = %error.code, error_message = %error.message, "Command failed" ); } // 适配并发送响应 let outbounds = output_adapter.adapt(response); for msg in outbounds { let _ = sender.send(msg).await; } Ok(()) } WsInbound::Ping => { let _ = sender.send(WsOutbound::Pong).await; Ok(()) } } } fn current_timestamp() -> i64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_millis() as i64 } fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String { sender_id .map(str::trim) .filter(|sender_id| !sender_id.is_empty()) .map(ToOwned::to_owned) .unwrap_or_else(|| runtime_session_id.to_string()) } #[cfg(test)] mod tests { use super::resolve_ws_sender_id; #[test] fn test_resolve_ws_sender_id_prefers_inbound_sender() { assert_eq!( resolve_ws_sender_id(Some("user-42"), "runtime-1"), "user-42" ); assert_eq!( resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), "user-42" ); } #[test] fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() { assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1"); } }