ooodc 8bb32fa066 feat: enhance WebSocket session management and storage
- Added SessionSummary struct for session metadata.
- Updated ws_handler to create and manage CLI sessions more robustly.
- Implemented session creation, loading, renaming, archiving, and deletion via WebSocket messages.
- Introduced SessionStore for persistent session storage using SQLite.
- Enhanced error handling and logging for session operations.
- Updated protocol definitions for new session-related WebSocket messages.
- Refactored tests to cover new session functionalities and ensure proper serialization.
2026-04-18 13:09:14 +08:00

320 lines
12 KiB
Rust

use std::sync::Arc;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::Session};
/// Axum route handler: accepts the WebSocket upgrade and hands the upgraded
/// socket, together with the shared gateway state, to `handle_socket`.
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
    ws.on_upgrade(move |socket| handle_socket(socket, state))
}
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let provider_config = match state.config.get_provider_config("default") {
Ok(cfg) => cfg,
Err(e) => {
tracing::error!(error = %e, "Failed to get provider config");
return;
}
};
let initial_record = match state.session_manager.create_cli_session(None) {
Ok(record) => record,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session");
return;
}
};
let channel_name = "cli".to_string();
// 创建 CLI session
let session = match Session::new(
channel_name.clone(),
provider_config,
sender,
state.session_manager.tools(),
state.session_manager.store(),
)
.await
{
Ok(s) => Arc::new(Mutex::new(s)),
Err(e) => {
tracing::error!(error = %e, "Failed to create session");
return;
}
};
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
return;
}
let runtime_session_id = session.lock().await.id;
let mut current_session_id = initial_record.id.clone();
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
let _ = session
.lock()
.await
.send(WsOutbound::SessionEstablished {
session_id: current_session_id.clone(),
})
.await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
let session_id_for_sender = runtime_session_id;
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
break;
}
}
}
});
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let text = text.to_string();
match parse_inbound(&text) {
Ok(inbound) => {
if let Err(e) = handle_inbound(&state, &session, &mut current_session_id, inbound).await {
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = session
.lock()
.await
.send(WsOutbound::Error {
code: "SESSION_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = session
.lock()
.await
.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
break;
}
_ => {}
}
}
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
}
/// Converts a stored session record into the summary shape sent to clients.
///
/// The record's `id` becomes `session_id`. The remaining listed fields are
/// copied across unchanged.
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
    let crate::storage::SessionRecord {
        id: session_id,
        title,
        channel_name,
        chat_id,
        message_count,
        last_active_at,
        archived_at,
        ..
    } = record;

    SessionSummary {
        session_id,
        title,
        channel_name,
        chat_id,
        message_count,
        last_active_at,
        archived_at,
    }
}
/// Dispatches one parsed client message against this connection's session.
///
/// `current_session_id` is the persisted session the connection is bound to.
/// - Handlers that take an optional target id fall back to it when the client
///   omits one.
/// - `CreateSession` and `LoadSession` rebind it, but only after the target
///   session's history has been loaded successfully.
/// - `DeleteSession` rebinds it when it deletes the current session.
///
/// # Errors
///
/// Returns the first [`crate::agent::AgentError`] raised by the session manager
/// or the session; `handle_socket` reports it to the client as `SESSION_ERROR`.
/// Failures of `agent.process` while answering `UserInput` are NOT returned:
/// they are sent to the client directly as an `LLM_ERROR` frame.
/// Failures to deliver outbound frames are ignored.
async fn handle_inbound(
    state: &Arc<GatewayState>,
    session: &Arc<Mutex<Session>>,
    current_session_id: &mut String,
    inbound: WsInbound,
) -> Result<(), crate::agent::AgentError> {
    match inbound {
        WsInbound::UserInput { content, chat_id, .. } => {
            let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
            // The lock is held for the whole turn, including the model call.
            let mut session_guard = session.lock().await;
            session_guard.ensure_persistent_session(&chat_id)?;
            session_guard.ensure_chat_loaded(&chat_id)?;
            let user_message = session_guard.create_user_message(&content, Vec::new());
            session_guard.append_persisted_message(&chat_id, user_message)?;
            let raw_history = session_guard.get_or_create_history(&chat_id).clone();
            let history = match session_guard
                .compressor()
                .compress_if_needed(raw_history, session_guard.provider_config())
                .await
            {
                Ok(history) => history,
                Err(error) => {
                    tracing::warn!(chat_id = %chat_id, error = %error, "Compression failed, using original history");
                    // `raw_history` was moved into the compressor; re-read it.
                    session_guard.get_or_create_history(&chat_id).clone()
                }
            };
            let agent = session_guard.create_agent()?;
            match agent.process(history).await {
                Ok(response) => {
                    session_guard.append_persisted_message(&chat_id, response.clone())?;
                    let _ = session_guard
                        .send(WsOutbound::AssistantResponse {
                            id: response.id,
                            content: response.content,
                            role: response.role,
                        })
                        .await;
                }
                Err(error) => {
                    tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
                    let _ = session_guard
                        .send(WsOutbound::Error {
                            code: "LLM_ERROR".to_string(),
                            message: error.to_string(),
                        })
                        .await;
                }
            }
            Ok(())
        }
        WsInbound::ClearHistory { session_id, chat_id } => {
            let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone());
            state.session_manager.clear_session_messages(&target)?;
            let mut session_guard = session.lock().await;
            session_guard.remove_history(&target);
            let _ = session_guard
                .send(WsOutbound::HistoryCleared {
                    session_id: target,
                })
                .await;
            Ok(())
        }
        WsInbound::CreateSession { title } => {
            let record = state.session_manager.create_cli_session(title.as_deref())?;
            let mut session_guard = session.lock().await;
            session_guard.ensure_chat_loaded(&record.id)?;
            // Rebind only once the new session is usable; on failure the
            // connection stays on its previous session.
            *current_session_id = record.id.clone();
            let _ = session_guard
                .send(WsOutbound::SessionCreated {
                    session_id: record.id,
                    title: record.title,
                })
                .await;
            Ok(())
        }
        WsInbound::ListSessions { include_archived } => {
            let records = state.session_manager.list_cli_sessions(include_archived)?;
            let summaries = records.into_iter().map(to_session_summary).collect();
            let session_guard = session.lock().await;
            let _ = session_guard
                .send(WsOutbound::SessionList {
                    sessions: summaries,
                    current_session_id: Some(current_session_id.clone()),
                })
                .await;
            Ok(())
        }
        WsInbound::LoadSession { session_id } => {
            let Some(record) = state.session_manager.get_session_record(&session_id)? else {
                let session_guard = session.lock().await;
                let _ = session_guard
                    .send(WsOutbound::Error {
                        code: "SESSION_NOT_FOUND".to_string(),
                        message: format!("Session not found: {}", session_id),
                    })
                    .await;
                return Ok(());
            };
            let mut session_guard = session.lock().await;
            session_guard.ensure_chat_loaded(&record.id)?;
            // Rebind only after the history loaded, so a failed load does not
            // silently switch the connection to a session the client was told
            // failed to load.
            *current_session_id = record.id.clone();
            let _ = session_guard
                .send(WsOutbound::SessionLoaded {
                    session_id: record.id,
                    title: record.title,
                    message_count: record.message_count,
                })
                .await;
            Ok(())
        }
        WsInbound::RenameSession { session_id, title } => {
            let target = session_id.unwrap_or_else(|| current_session_id.clone());
            state.session_manager.rename_session(&target, &title)?;
            let session_guard = session.lock().await;
            let _ = session_guard
                .send(WsOutbound::SessionRenamed {
                    session_id: target,
                    title,
                })
                .await;
            Ok(())
        }
        WsInbound::ArchiveSession { session_id } => {
            let target = session_id.unwrap_or_else(|| current_session_id.clone());
            state.session_manager.archive_session(&target)?;
            let session_guard = session.lock().await;
            let _ = session_guard
                .send(WsOutbound::SessionArchived { session_id: target })
                .await;
            Ok(())
        }
        WsInbound::DeleteSession { session_id } => {
            let target = session_id.unwrap_or_else(|| current_session_id.clone());
            state.session_manager.delete_session(&target)?;
            let deleted_current = target == *current_session_id;

            let mut session_guard = session.lock().await;
            session_guard.remove_history(&target);
            // Report the deletion before trying to create a replacement, so a
            // replacement failure cannot hide a deletion that already happened.
            let _ = session_guard
                .send(WsOutbound::SessionDeleted { session_id: target })
                .await;

            if deleted_current {
                // If any step below fails, the error is returned and
                // `current_session_id` still names the deleted session.
                let record = state.session_manager.create_cli_session(None)?;
                session_guard.ensure_chat_loaded(&record.id)?;
                *current_session_id = record.id.clone();
                let _ = session_guard
                    .send(WsOutbound::SessionCreated {
                        session_id: record.id,
                        title: record.title,
                    })
                    .await;
            }
            Ok(())
        }
        WsInbound::Ping => {
            let session_guard = session.lock().await;
            let _ = session_guard.send(WsOutbound::Pong).await;
            Ok(())
        }
    }
}