use std::sync::Arc; use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; use axum::extract::State; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use tokio::sync::{mpsc, Mutex}; use crate::bus::ChatMessage; use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound}; use super::{GatewayState, session::Session}; pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async { handle_socket(socket, state).await; }) } async fn handle_socket(ws: WebSocket, state: Arc) { let (sender, receiver) = mpsc::channel::(100); let provider_config = match state.config.get_provider_config("default") { Ok(cfg) => cfg, Err(e) => { eprintln!("Failed to get provider config: {}", e); return; } }; // CLI 使用独立的 session,channel_name = "cli-{uuid}" let channel_name = format!("cli-{}", uuid::Uuid::new_v4()); // 创建 CLI session let session = match Session::new(channel_name.clone(), provider_config, sender).await { Ok(s) => Arc::new(Mutex::new(s)), Err(e) => { eprintln!("Failed to create session: {}", e); return; } }; let session_id = session.lock().await.id; let _ = session.lock().await.send(WsOutbound::SessionEstablished { session_id: session_id.to_string(), }).await; let (mut ws_sender, mut ws_receiver) = ws.split(); let mut receiver = receiver; tokio::spawn(async move { while let Some(msg) = receiver.recv().await { if let Ok(text) = serialize_outbound(&msg) { if ws_sender.send(WsMessage::Text(text.into())).await.is_err() { break; } } } }); while let Some(msg) = ws_receiver.next().await { match msg { Ok(WsMessage::Text(text)) => { let text = text.to_string(); match parse_inbound(&text) { Ok(inbound) => { handle_inbound(&session, inbound).await; } Err(e) => { let _ = session.lock().await.send(WsOutbound::Error { code: "PARSE_ERROR".to_string(), message: e.to_string(), }).await; } } } Ok(WsMessage::Close(_)) | Err(_) => { break; } _ => {} } } } async fn handle_inbound(session: &Arc>, inbound: WsInbound) { let inbound_clone = inbound.clone(); // 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id) let (content, chat_id) = match inbound_clone { WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => { // CLI 使用 session 中的 channel_name 作为标识 // chat_id 使用传入的或使用默认 let chat_id = chat_id.unwrap_or_else(|| "default".to_string()); (content, chat_id) } _ => return, }; let user_msg = ChatMessage::user(content); let mut session_guard = session.lock().await; let agent = match session_guard.get_or_create_agent(&chat_id).await { Ok(a) => a, Err(e) => { let _ = session_guard.send(WsOutbound::Error { code: "AGENT_ERROR".to_string(), message: e.to_string(), }).await; return; } }; drop(session_guard); let mut agent = agent.lock().await; match agent.process(user_msg).await { Ok(response) => { let _ = session.lock().await.send(WsOutbound::AssistantResponse { id: response.id, content: response.content, role: response.role, }).await; } Err(e) => { let _ = session.lock().await.send(WsOutbound::Error { code: "LLM_ERROR".to_string(), message: e.to_string(), }).await; } } }