- Added SessionSummary struct for session metadata. - Updated ws_handler to create and manage CLI sessions more robustly. - Implemented session creation, loading, renaming, archiving, and deletion via WebSocket messages. - Introduced SessionStore for persistent session storage using SQLite. - Enhanced error handling and logging for session operations. - Updated protocol definitions for new session-related WebSocket messages. - Refactored tests to cover new session functionalities and ensure proper serialization.
320 lines
12 KiB
Rust
320 lines
12 KiB
Rust
use std::sync::Arc;
|
|
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
|
use axum::extract::State;
|
|
use axum::response::Response;
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use tokio::sync::{mpsc, Mutex};
|
|
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
|
use super::{GatewayState, session::Session};
|
|
|
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
|
ws.on_upgrade(|socket| async {
|
|
handle_socket(socket, state).await;
|
|
})
|
|
}
|
|
|
|
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
|
|
|
let provider_config = match state.config.get_provider_config("default") {
|
|
Ok(cfg) => cfg,
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to get provider config");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let initial_record = match state.session_manager.create_cli_session(None) {
|
|
Ok(record) => record,
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to create initial CLI session");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let channel_name = "cli".to_string();
|
|
|
|
// 创建 CLI session
|
|
let session = match Session::new(
|
|
channel_name.clone(),
|
|
provider_config,
|
|
sender,
|
|
state.session_manager.tools(),
|
|
state.session_manager.store(),
|
|
)
|
|
.await
|
|
{
|
|
Ok(s) => Arc::new(Mutex::new(s)),
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to create session");
|
|
return;
|
|
}
|
|
};
|
|
|
|
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
|
|
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
|
|
return;
|
|
}
|
|
|
|
let runtime_session_id = session.lock().await.id;
|
|
let mut current_session_id = initial_record.id.clone();
|
|
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
|
|
|
|
let _ = session
|
|
.lock()
|
|
.await
|
|
.send(WsOutbound::SessionEstablished {
|
|
session_id: current_session_id.clone(),
|
|
})
|
|
.await;
|
|
|
|
let (mut ws_sender, mut ws_receiver) = ws.split();
|
|
|
|
let mut receiver = receiver;
|
|
let session_id_for_sender = runtime_session_id;
|
|
tokio::spawn(async move {
|
|
while let Some(msg) = receiver.recv().await {
|
|
if let Ok(text) = serialize_outbound(&msg) {
|
|
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
while let Some(msg) = ws_receiver.next().await {
|
|
match msg {
|
|
Ok(WsMessage::Text(text)) => {
|
|
let text = text.to_string();
|
|
match parse_inbound(&text) {
|
|
Ok(inbound) => {
|
|
if let Err(e) = handle_inbound(&state, &session, &mut current_session_id, inbound).await {
|
|
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
|
let _ = session
|
|
.lock()
|
|
.await
|
|
.send(WsOutbound::Error {
|
|
code: "SESSION_ERROR".to_string(),
|
|
message: e.to_string(),
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Failed to parse inbound message");
|
|
let _ = session
|
|
.lock()
|
|
.await
|
|
.send(WsOutbound::Error {
|
|
code: "PARSE_ERROR".to_string(),
|
|
message: e.to_string(),
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
Ok(WsMessage::Close(_)) | Err(_) => {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
|
}
|
|
|
|
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
|
SessionSummary {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
channel_name: record.channel_name,
|
|
chat_id: record.chat_id,
|
|
message_count: record.message_count,
|
|
last_active_at: record.last_active_at,
|
|
archived_at: record.archived_at,
|
|
}
|
|
}
|
|
|
|
async fn handle_inbound(
|
|
state: &Arc<GatewayState>,
|
|
session: &Arc<Mutex<Session>>,
|
|
current_session_id: &mut String,
|
|
inbound: WsInbound,
|
|
) -> Result<(), crate::agent::AgentError> {
|
|
match inbound {
|
|
WsInbound::UserInput { content, chat_id, .. } => {
|
|
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
|
let mut session_guard = session.lock().await;
|
|
|
|
session_guard.ensure_persistent_session(&chat_id)?;
|
|
session_guard.ensure_chat_loaded(&chat_id)?;
|
|
|
|
let user_message = session_guard.create_user_message(&content, Vec::new());
|
|
session_guard.append_persisted_message(&chat_id, user_message)?;
|
|
|
|
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
|
|
let history = match session_guard
|
|
.compressor()
|
|
.compress_if_needed(raw_history, session_guard.provider_config())
|
|
.await
|
|
{
|
|
Ok(history) => history,
|
|
Err(error) => {
|
|
tracing::warn!(chat_id = %chat_id, error = %error, "Compression failed, using original history");
|
|
session_guard.get_or_create_history(&chat_id).clone()
|
|
}
|
|
};
|
|
|
|
let agent = session_guard.create_agent()?;
|
|
match agent.process(history).await {
|
|
Ok(response) => {
|
|
session_guard.append_persisted_message(&chat_id, response.clone())?;
|
|
let _ = session_guard
|
|
.send(WsOutbound::AssistantResponse {
|
|
id: response.id,
|
|
content: response.content,
|
|
role: response.role,
|
|
})
|
|
.await;
|
|
}
|
|
Err(error) => {
|
|
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
|
|
let _ = session_guard
|
|
.send(WsOutbound::Error {
|
|
code: "LLM_ERROR".to_string(),
|
|
message: error.to_string(),
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
WsInbound::ClearHistory { session_id, chat_id } => {
|
|
let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.clear_session_messages(&target)?;
|
|
|
|
let mut session_guard = session.lock().await;
|
|
session_guard.remove_history(&target);
|
|
let _ = session_guard
|
|
.send(WsOutbound::HistoryCleared {
|
|
session_id: target,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::CreateSession { title } => {
|
|
let record = state.session_manager.create_cli_session(title.as_deref())?;
|
|
*current_session_id = record.id.clone();
|
|
|
|
let mut session_guard = session.lock().await;
|
|
session_guard.ensure_chat_loaded(&record.id)?;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionCreated {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::ListSessions { include_archived } => {
|
|
let records = state.session_manager.list_cli_sessions(include_archived)?;
|
|
let summaries = records.into_iter().map(to_session_summary).collect();
|
|
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionList {
|
|
sessions: summaries,
|
|
current_session_id: Some(current_session_id.clone()),
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::LoadSession { session_id } => {
|
|
let Some(record) = state.session_manager.get_session_record(&session_id)? else {
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard
|
|
.send(WsOutbound::Error {
|
|
code: "SESSION_NOT_FOUND".to_string(),
|
|
message: format!("Session not found: {}", session_id),
|
|
})
|
|
.await;
|
|
return Ok(());
|
|
};
|
|
|
|
*current_session_id = record.id.clone();
|
|
let mut session_guard = session.lock().await;
|
|
session_guard.ensure_chat_loaded(&record.id)?;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionLoaded {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
message_count: record.message_count,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::RenameSession { session_id, title } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.rename_session(&target, &title)?;
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionRenamed {
|
|
session_id: target,
|
|
title,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::ArchiveSession { session_id } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.archive_session(&target)?;
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionArchived { session_id: target })
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::DeleteSession { session_id } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.delete_session(&target)?;
|
|
|
|
let replacement = if target == *current_session_id {
|
|
Some(state.session_manager.create_cli_session(None)?)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let mut session_guard = session.lock().await;
|
|
session_guard.remove_history(&target);
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionDeleted {
|
|
session_id: target.clone(),
|
|
})
|
|
.await;
|
|
|
|
if let Some(record) = replacement {
|
|
*current_session_id = record.id.clone();
|
|
session_guard.ensure_chat_loaded(&record.id)?;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionCreated {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
})
|
|
.await;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
WsInbound::Ping => {
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard.send(WsOutbound::Pong).await;
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|