446 lines
15 KiB
Rust
446 lines
15 KiB
Rust
use super::GatewayState;
|
|
use crate::agent::AgentError;
|
|
use crate::bus::InboundMessage;
|
|
use crate::command::adapter::OutputAdapter;
|
|
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
|
use crate::command::context::CommandContext;
|
|
use crate::command::handler::CommandRouter;
|
|
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
|
use crate::command::handlers::session::SessionCommandHandler;
|
|
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
|
use axum::extract::State;
|
|
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
|
use axum::response::Response;
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::mpsc;
|
|
|
|
/// Channel name stamped on every `InboundMessage` published by this WebSocket
/// gateway, and the key used to look up the shared CLI session via
/// `session_manager.get(CLI_CHANNEL_NAME)`.
const CLI_CHANNEL_NAME: &str = "cli";
|
|
|
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
|
ws.on_upgrade(|socket| async {
|
|
handle_socket(socket, state).await;
|
|
})
|
|
}
|
|
|
|
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
|
|
|
let cli_sessions = state.session_manager.cli_sessions();
|
|
let initial_record = match cli_sessions.create(None) {
|
|
Ok(record) => record,
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to create initial CLI session");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let runtime_session_id = uuid::Uuid::new_v4().to_string();
|
|
let mut current_session_id = initial_record.id.clone();
|
|
state
|
|
.channel_manager
|
|
.cli_channel()
|
|
.register_connection(
|
|
current_session_id.clone(),
|
|
runtime_session_id.clone(),
|
|
sender.clone(),
|
|
)
|
|
.await;
|
|
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
|
|
|
|
let _ = sender
|
|
.send(WsOutbound::SessionEstablished {
|
|
session_id: current_session_id.clone(),
|
|
})
|
|
.await;
|
|
|
|
let (mut ws_sender, mut ws_receiver) = ws.split();
|
|
|
|
let mut receiver = receiver;
|
|
let session_id_for_sender = runtime_session_id.clone();
|
|
tokio::spawn(async move {
|
|
while let Some(msg) = receiver.recv().await {
|
|
if let Ok(text) = serialize_outbound(&msg) {
|
|
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
while let Some(msg) = ws_receiver.next().await {
|
|
match msg {
|
|
Ok(WsMessage::Text(text)) => {
|
|
let text = text.to_string();
|
|
match parse_inbound(&text) {
|
|
Ok(inbound) => {
|
|
if let Err(e) = handle_inbound(
|
|
&state,
|
|
&sender,
|
|
&runtime_session_id,
|
|
&mut current_session_id,
|
|
inbound,
|
|
)
|
|
.await
|
|
{
|
|
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
|
let _ = sender
|
|
.send(WsOutbound::Error {
|
|
code: "SESSION_ERROR".to_string(),
|
|
message: e.to_string(),
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Failed to parse inbound message");
|
|
let _ = sender
|
|
.send(WsOutbound::Error {
|
|
code: "PARSE_ERROR".to_string(),
|
|
message: e.to_string(),
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
Ok(WsMessage::Close(_)) | Err(_) => {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
state
|
|
.channel_manager
|
|
.cli_channel()
|
|
.unregister_connection(&runtime_session_id)
|
|
.await;
|
|
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
|
}
|
|
|
|
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
|
SessionSummary {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
channel_name: record.channel_name,
|
|
chat_id: record.chat_id,
|
|
message_count: record.message_count,
|
|
last_active_at: record.last_active_at,
|
|
archived_at: record.archived_at,
|
|
}
|
|
}
|
|
|
|
async fn handle_inbound(
|
|
state: &Arc<GatewayState>,
|
|
sender: &mpsc::Sender<WsOutbound>,
|
|
runtime_session_id: &str,
|
|
current_session_id: &mut String,
|
|
inbound: WsInbound,
|
|
) -> Result<(), crate::agent::AgentError> {
|
|
match inbound {
|
|
WsInbound::UserInput {
|
|
content,
|
|
chat_id,
|
|
sender_id,
|
|
..
|
|
} => {
|
|
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
|
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
|
|
|
state
|
|
.channel_manager
|
|
.cli_channel()
|
|
.register_connection(
|
|
chat_id.clone(),
|
|
runtime_session_id.to_string(),
|
|
sender.clone(),
|
|
)
|
|
.await;
|
|
|
|
state
|
|
.bus
|
|
.publish_inbound(InboundMessage {
|
|
channel: CLI_CHANNEL_NAME.to_string(),
|
|
sender_id,
|
|
chat_id,
|
|
content,
|
|
timestamp: current_timestamp(),
|
|
media: Vec::new(),
|
|
metadata: HashMap::new(),
|
|
forwarded_metadata: HashMap::new(),
|
|
})
|
|
.await
|
|
.map_err(|error| AgentError::Other(error.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
WsInbound::ClearHistory {
|
|
session_id,
|
|
chat_id,
|
|
} => {
|
|
let target = session_id
|
|
.or(chat_id)
|
|
.unwrap_or_else(|| current_session_id.clone());
|
|
state
|
|
.session_manager
|
|
.cli_sessions()
|
|
.clear_messages(&target)?;
|
|
|
|
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
|
session.lock().await.remove_history(&target);
|
|
}
|
|
|
|
let _ = sender
|
|
.send(WsOutbound::HistoryCleared { session_id: target })
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::CreateSession { title } => {
|
|
// 使用新的命令层处理
|
|
let _input_adapter = WebSocketInputAdapter::new();
|
|
let output_adapter = WebSocketOutputAdapter::new();
|
|
let cli_sessions = state.session_manager.cli_sessions();
|
|
let handler = SessionCommandHandler::new(cli_sessions);
|
|
let router = {
|
|
let mut r = CommandRouter::new();
|
|
r.register(Box::new(handler));
|
|
r
|
|
};
|
|
|
|
// 构建命令
|
|
let cmd = crate::command::Command::CreateSession { title };
|
|
let cmd_ctx = CommandContext::new("websocket")
|
|
.with_session_id(current_session_id.as_str());
|
|
|
|
// 执行命令
|
|
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
|
|
|
// 适配输出
|
|
let outbounds = output_adapter.adapt(response);
|
|
|
|
// 处理响应
|
|
for msg in outbounds {
|
|
if let WsOutbound::SessionCreated { session_id, title: _ } = &msg {
|
|
*current_session_id = session_id.clone();
|
|
state
|
|
.channel_manager
|
|
.cli_channel()
|
|
.register_connection(
|
|
session_id.clone(),
|
|
runtime_session_id.to_string(),
|
|
sender.clone(),
|
|
)
|
|
.await;
|
|
}
|
|
let _ = sender.send(msg).await;
|
|
}
|
|
Ok(())
|
|
}
|
|
WsInbound::ListSessions { include_archived } => {
|
|
let records = state
|
|
.session_manager
|
|
.cli_sessions()
|
|
.list(include_archived)?;
|
|
let summaries = records.into_iter().map(to_session_summary).collect();
|
|
|
|
let _ = sender
|
|
.send(WsOutbound::SessionList {
|
|
sessions: summaries,
|
|
current_session_id: Some(current_session_id.clone()),
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::LoadSession { session_id } => {
|
|
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
|
|
let _ = sender
|
|
.send(WsOutbound::Error {
|
|
code: "SESSION_NOT_FOUND".to_string(),
|
|
message: format!("Session not found: {}", session_id),
|
|
})
|
|
.await;
|
|
return Ok(());
|
|
};
|
|
|
|
*current_session_id = record.id.clone();
|
|
state
|
|
.channel_manager
|
|
.cli_channel()
|
|
.register_connection(
|
|
record.id.clone(),
|
|
runtime_session_id.to_string(),
|
|
sender.clone(),
|
|
)
|
|
.await;
|
|
let _ = sender
|
|
.send(WsOutbound::SessionLoaded {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
message_count: record.message_count,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::RenameSession { session_id, title } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state
|
|
.session_manager
|
|
.cli_sessions()
|
|
.rename(&target, &title)?;
|
|
let _ = sender
|
|
.send(WsOutbound::SessionRenamed {
|
|
session_id: target,
|
|
title,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::ArchiveSession { session_id } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.cli_sessions().archive(&target)?;
|
|
let _ = sender
|
|
.send(WsOutbound::SessionArchived { session_id: target })
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::DeleteSession { session_id } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.cli_sessions().delete(&target)?;
|
|
|
|
let replacement = if target == *current_session_id {
|
|
Some(state.session_manager.cli_sessions().create(None)?)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
|
session.lock().await.remove_history(&target);
|
|
}
|
|
|
|
let _ = sender
|
|
.send(WsOutbound::SessionDeleted {
|
|
session_id: target.clone(),
|
|
})
|
|
.await;
|
|
|
|
if let Some(record) = replacement {
|
|
*current_session_id = record.id.clone();
|
|
state
|
|
.channel_manager
|
|
.cli_channel()
|
|
.register_connection(
|
|
record.id.clone(),
|
|
runtime_session_id.to_string(),
|
|
sender.clone(),
|
|
)
|
|
.await;
|
|
let _ = sender
|
|
.send(WsOutbound::SessionCreated {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
})
|
|
.await;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
WsInbound::SaveSession { filepath, session_id } => {
|
|
let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
|
|
// 获取所需依赖
|
|
let store = state.session_manager.store();
|
|
let provider_config = state.config.get_provider_config("default")
|
|
.map_err(|e| AgentError::Other(e.to_string()))?;
|
|
|
|
// 构建处理器
|
|
let handler = SaveSessionCommandHandler::new(store, provider_config);
|
|
let router = {
|
|
let mut r = CommandRouter::new();
|
|
r.register(Box::new(handler));
|
|
r
|
|
};
|
|
|
|
// 构建命令
|
|
let cmd = crate::command::Command::SaveSession { filepath };
|
|
let cmd_ctx = CommandContext::new("websocket")
|
|
.with_session_id(&target_session_id);
|
|
|
|
// 执行命令
|
|
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
|
|
|
// 处理响应
|
|
if response.success {
|
|
let filepath = response
|
|
.metadata
|
|
.get("filepath")
|
|
.cloned()
|
|
.unwrap_or_default();
|
|
let _ = sender
|
|
.send(WsOutbound::SessionSaved {
|
|
session_id: target_session_id,
|
|
filepath,
|
|
})
|
|
.await;
|
|
} else {
|
|
let error = response.error.unwrap_or_else(|| {
|
|
crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error")
|
|
});
|
|
let _ = sender
|
|
.send(WsOutbound::Error {
|
|
code: error.code,
|
|
message: error.message,
|
|
})
|
|
.await;
|
|
}
|
|
Ok(())
|
|
}
|
|
WsInbound::Ping => {
|
|
let _ = sender.send(WsOutbound::Pong).await;
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// This never panics, because it runs inside a request handler:
/// * A system clock set before 1970 yields a negative value.
/// * A millisecond count that does not fit in `i64` saturates instead of
///   silently wrapping through an `as` cast.
fn current_timestamp() -> i64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX),
        // The clock is behind the epoch; `duration()` is how far behind it is.
        Err(err) => i64::try_from(err.duration().as_millis()).map_or(i64::MIN, |ms| -ms),
    }
}
|
|
|
|
/// Picks the sender id to stamp on an inbound user message.
///
/// A client-supplied id is used after trimming surrounding whitespace. When it
/// is absent, or nothing is left after trimming, the connection's runtime
/// session id is used instead.
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
    match sender_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
        _ => runtime_session_id.to_owned(),
    }
}
|
|
|
|
/// Unit tests for the pure helpers in this module.
#[cfg(test)]
mod tests {
    use super::resolve_ws_sender_id;

    /// Runtime id used as the fallback in every case below.
    const RUNTIME_ID: &str = "runtime-1";

    #[test]
    fn test_resolve_ws_sender_id_prefers_inbound_sender() {
        // Both the bare id and a whitespace-padded variant resolve to the id.
        for raw in ["user-42", " user-42 "] {
            assert_eq!(resolve_ws_sender_id(Some(raw), RUNTIME_ID), "user-42");
        }
    }

    #[test]
    fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() {
        // A missing id and a whitespace-only id both fall back to the runtime id.
        for raw in [None, Some(" ")] {
            assert_eq!(resolve_ws_sender_id(raw, RUNTIME_ID), RUNTIME_ID);
        }
    }
}
|