129 lines
4.1 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::Arc;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
use super::{GatewayState, session::Session};
/// Axum route handler that upgrades an incoming HTTP request to a WebSocket.
///
/// When the upgrade completes, the socket and the shared gateway state are
/// handed to `handle_socket`, which drives the connection until it ends.
pub async fn ws_handler(ws: WebSocketUpgrade, State(gateway): State<Arc<GatewayState>>) -> Response {
    // `handle_socket` already returns a `Future<Output = ()>`, so it can be
    // passed to `on_upgrade` directly instead of being wrapped in another
    // async block.
    ws.on_upgrade(move |socket| handle_socket(socket, gateway))
}
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let provider_config = match state.config.get_provider_config("default") {
Ok(cfg) => cfg,
Err(e) => {
eprintln!("Failed to get provider config: {}", e);
return;
}
};
// CLI 使用独立的 sessionchannel_name = "cli-{uuid}"
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
// 创建 CLI session
let session = match Session::new(channel_name.clone(), provider_config, sender).await {
Ok(s) => Arc::new(Mutex::new(s)),
Err(e) => {
eprintln!("Failed to create session: {}", e);
return;
}
};
let session_id = session.lock().await.id;
let _ = session.lock().await.send(WsOutbound::SessionEstablished {
session_id: session_id.to_string(),
}).await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
break;
}
}
}
});
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let text = text.to_string();
match parse_inbound(&text) {
Ok(inbound) => {
handle_inbound(&session, inbound).await;
}
Err(e) => {
let _ = session.lock().await.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
}).await;
}
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
break;
}
_ => {}
}
}
}
async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
let inbound_clone = inbound.clone();
// 提取 content 和 chat_idCLI 使用 session id 作为 chat_id
let (content, chat_id) = match inbound_clone {
WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => {
// CLI 使用 session 中的 channel_name 作为标识
// chat_id 使用传入的或使用默认
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
(content, chat_id)
}
_ => return,
};
let user_msg = ChatMessage::user(content);
let mut session_guard = session.lock().await;
let agent = match session_guard.get_or_create_agent(&chat_id).await {
Ok(a) => a,
Err(e) => {
let _ = session_guard.send(WsOutbound::Error {
code: "AGENT_ERROR".to_string(),
message: e.to_string(),
}).await;
return;
}
};
drop(session_guard);
let mut agent = agent.lock().await;
match agent.process(user_msg).await {
Ok(response) => {
let _ = session.lock().await.send(WsOutbound::AssistantResponse {
id: response.id,
content: response.content,
role: response.role,
}).await;
}
Err(e) => {
let _ = session.lock().await.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: e.to_string(),
}).await;
}
}
}