use super::{ GatewayState, session::{Session, handle_in_chat_command, schedule_background_history_compaction}, }; use crate::agent::EmittedMessageHandler; use crate::bus::ChatMessage; use crate::bus::message::{ToolMessageState, format_tool_call_content}; use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use async_trait::async_trait; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use std::sync::Arc; use tokio::sync::{Mutex, mpsc}; struct WsToolCallEmitter { sender: mpsc::Sender, show_tool_results: bool, } #[async_trait] impl EmittedMessageHandler for WsToolCallEmitter { async fn handle(&self, message: ChatMessage) { if !should_display_message_to_user(self.show_tool_results, &message) { return; } for outbound in ws_outbound_from_chat_message(&message) { let _ = self.sender.send(outbound).await; } } } pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async { handle_socket(socket, state).await; }) } async fn handle_socket(ws: WebSocket, state: Arc) { let (sender, receiver) = mpsc::channel::(100); let provider_config = match state.config.get_provider_config("default") { Ok(cfg) => cfg, Err(e) => { tracing::error!(error = %e, "Failed to get provider config"); return; } }; let initial_record = match state.session_manager.create_cli_session(None) { Ok(record) => record, Err(e) => { tracing::error!(error = %e, "Failed to create initial CLI session"); return; } }; let channel_name = "cli".to_string(); // 创建 CLI session let session = match Session::new( channel_name.clone(), provider_config, sender, state.session_manager.tools(), state.session_manager.skills(), state.session_manager.store(), state.config.gateway.agent_prompt_reinject_every, ) .await { Ok(s) => Arc::new(Mutex::new(s)), Err(e) => { tracing::error!(error = %e, "Failed to create session"); return; } }; if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) { tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history"); return; } let runtime_session_id = session.lock().await.id.to_string(); let mut current_session_id = initial_record.id.clone(); tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established"); let _ = session .lock() .await .send(WsOutbound::SessionEstablished { session_id: current_session_id.clone(), }) .await; let (mut ws_sender, mut ws_receiver) = ws.split(); let mut receiver = receiver; let session_id_for_sender = runtime_session_id.clone(); tokio::spawn(async move { while let Some(msg) = receiver.recv().await { if let Ok(text) = serialize_outbound(&msg) { if ws_sender.send(WsMessage::Text(text.into())).await.is_err() { #[cfg(debug_assertions)] tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error"); break; } } } }); while let Some(msg) = ws_receiver.next().await { match msg { Ok(WsMessage::Text(text)) => { let text = text.to_string(); match parse_inbound(&text) { Ok(inbound) => { if let Err(e) = handle_inbound( &state, &session, &runtime_session_id, &mut current_session_id, inbound, ) .await { tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message"); let _ = session .lock() .await .send(WsOutbound::Error { code: "SESSION_ERROR".to_string(), message: e.to_string(), }) .await; } } Err(e) => { tracing::warn!(error = %e, "Failed to parse inbound message"); let _ = session .lock() .await .send(WsOutbound::Error { code: "PARSE_ERROR".to_string(), message: e.to_string(), }) .await; } } } Ok(WsMessage::Close(_)) | Err(_) => { #[cfg(debug_assertions)] tracing::debug!(session_id = %runtime_session_id, "WebSocket closed"); break; } _ => {} } } tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); } fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary { SessionSummary { session_id: record.id, title: record.title, channel_name: record.channel_name, chat_id: record.chat_id, message_count: record.message_count, last_active_at: record.last_active_at, archived_at: record.archived_at, } } fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { match message.role.as_str() { "assistant" => { if let Some(tool_calls) = &message.tool_calls { let mut outbound = Vec::new(); if !message.content.trim().is_empty() { outbound.push(WsOutbound::AssistantResponse { id: message.id.clone(), content: message.content.clone(), role: message.role.clone(), }); } outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall { id: message.id.clone(), tool_call_id: tool_call.id.clone(), tool_name: tool_call.name.clone(), arguments: tool_call.arguments.clone(), content: format_tool_call_content(&tool_call.name, &tool_call.arguments), role: message.role.clone(), })); outbound } else { vec![WsOutbound::AssistantResponse { id: message.id.clone(), content: message.content.clone(), role: message.role.clone(), }] } } "tool" => match message .tool_state .as_ref() .unwrap_or(&ToolMessageState::Completed) { ToolMessageState::Completed => vec![WsOutbound::ToolResult { id: message.id.clone(), tool_call_id: message.tool_call_id.clone().unwrap_or_default(), tool_name: message.tool_name.clone().unwrap_or_default(), content: message.content.clone(), role: message.role.clone(), }], ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending { id: message.id.clone(), tool_call_id: message.tool_call_id.clone().unwrap_or_default(), tool_name: message.tool_name.clone().unwrap_or_default(), content: message.content.clone(), role: message.role.clone(), resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(), }], }, _ => Vec::new(), } } fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool { if message.role != "tool" { return true; } show_tool_results || matches!( message .tool_state .as_ref() .unwrap_or(&ToolMessageState::Completed), ToolMessageState::PendingUserAction ) } async fn handle_inbound( state: &Arc, session: &Arc>, runtime_session_id: &str, current_session_id: &mut String, inbound: WsInbound, ) -> Result<(), crate::agent::AgentError> { match inbound { WsInbound::UserInput { content, chat_id, sender_id, .. } => { let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id); let (history, agent, user_tx) = { let mut session_guard = session.lock().await; session_guard.ensure_persistent_session(&chat_id)?; session_guard.ensure_chat_loaded(&chat_id)?; if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? { let _ = session_guard .send(WsOutbound::AssistantResponse { id: uuid::Uuid::new_v4().to_string(), content: command_response, role: "assistant".to_string(), }) .await; return Ok(()); } session_guard.ensure_agent_prompt_before_user_message(&chat_id)?; let user_message = session_guard.create_user_message(&content, Vec::new()); let user_message_id = user_message.id.clone(); session_guard.append_persisted_message(&chat_id, user_message)?; let history = session_guard.get_or_create_history(&chat_id).clone(); session_guard.record_skill_offer(&chat_id)?; let live_emitter = Arc::new(WsToolCallEmitter { sender: session_guard.user_tx.clone(), show_tool_results: state.config.gateway.show_tool_results, }); let agent = session_guard .create_agent(&chat_id, Some(&sender_id), Some(&user_message_id))? .with_emitted_message_handler(live_emitter); (history, agent, session_guard.user_tx.clone()) }; match agent.process(history).await { Ok(result) => { let mut session_guard = session.lock().await; session_guard .append_persisted_messages(&chat_id, result.emitted_messages.clone())?; for outbound in result .emitted_messages .iter() .filter(|message| { !message.is_assistant_tool_call_message() && should_display_message_to_user( state.config.gateway.show_tool_results, message, ) }) .flat_map(ws_outbound_from_chat_message) { let _ = session_guard.send(outbound).await; } drop(session_guard); if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.clone()) .await { tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session"); } } Err(error) => { tracing::error!(chat_id = %chat_id, error = %error, "Agent process error"); let _ = user_tx .send(WsOutbound::Error { code: "LLM_ERROR".to_string(), message: error.to_string(), }) .await; } } Ok(()) } WsInbound::ClearHistory { session_id, chat_id, } => { let target = session_id .or(chat_id) .unwrap_or_else(|| current_session_id.clone()); state.session_manager.clear_session_messages(&target)?; let mut session_guard = session.lock().await; session_guard.remove_history(&target); let _ = session_guard .send(WsOutbound::HistoryCleared { session_id: target }) .await; Ok(()) } WsInbound::CreateSession { title } => { let record = state.session_manager.create_cli_session(title.as_deref())?; *current_session_id = record.id.clone(); let mut session_guard = session.lock().await; session_guard.ensure_chat_loaded(&record.id)?; let _ = session_guard .send(WsOutbound::SessionCreated { session_id: record.id, title: record.title, }) .await; Ok(()) } WsInbound::ListSessions { include_archived } => { let records = state.session_manager.list_cli_sessions(include_archived)?; let summaries = records.into_iter().map(to_session_summary).collect(); let session_guard = session.lock().await; let _ = session_guard .send(WsOutbound::SessionList { sessions: summaries, current_session_id: Some(current_session_id.clone()), }) .await; Ok(()) } WsInbound::LoadSession { session_id } => { let Some(record) = state.session_manager.get_session_record(&session_id)? else { let session_guard = session.lock().await; let _ = session_guard .send(WsOutbound::Error { code: "SESSION_NOT_FOUND".to_string(), message: format!("Session not found: {}", session_id), }) .await; return Ok(()); }; *current_session_id = record.id.clone(); let mut session_guard = session.lock().await; session_guard.ensure_chat_loaded(&record.id)?; let _ = session_guard .send(WsOutbound::SessionLoaded { session_id: record.id, title: record.title, message_count: record.message_count, }) .await; Ok(()) } WsInbound::RenameSession { session_id, title } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); state.session_manager.rename_session(&target, &title)?; let session_guard = session.lock().await; let _ = session_guard .send(WsOutbound::SessionRenamed { session_id: target, title, }) .await; Ok(()) } WsInbound::ArchiveSession { session_id } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); state.session_manager.archive_session(&target)?; let session_guard = session.lock().await; let _ = session_guard .send(WsOutbound::SessionArchived { session_id: target }) .await; Ok(()) } WsInbound::DeleteSession { session_id } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); state.session_manager.delete_session(&target)?; let replacement = if target == *current_session_id { Some(state.session_manager.create_cli_session(None)?) } else { None }; let mut session_guard = session.lock().await; session_guard.remove_history(&target); let _ = session_guard .send(WsOutbound::SessionDeleted { session_id: target.clone(), }) .await; if let Some(record) = replacement { *current_session_id = record.id.clone(); session_guard.ensure_chat_loaded(&record.id)?; let _ = session_guard .send(WsOutbound::SessionCreated { session_id: record.id, title: record.title, }) .await; } Ok(()) } WsInbound::Ping => { let session_guard = session.lock().await; let _ = session_guard.send(WsOutbound::Pong).await; Ok(()) } } } fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String { sender_id .map(str::trim) .filter(|sender_id| !sender_id.is_empty()) .map(ToOwned::to_owned) .unwrap_or_else(|| runtime_session_id.to_string()) } #[cfg(test)] mod tests { use super::{ WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user, ws_outbound_from_chat_message, }; use crate::agent::EmittedMessageHandler; use crate::bus::ChatMessage; use crate::bus::message::ToolMessageState; use crate::protocol::WsOutbound; use crate::providers::ToolCall; use serde_json::json; use tokio::sync::mpsc; #[test] fn test_ws_outbound_from_chat_message_expands_tool_calls() { let message = ChatMessage::assistant_with_tool_calls( "", vec![ToolCall { id: "call-1".to_string(), name: "calculator".to_string(), arguments: json!({"expression": "1 + 1"}), }], ); let outbound = ws_outbound_from_chat_message(&message); assert_eq!(outbound.len(), 1); match &outbound[0] { WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => { assert_eq!(tool_call_id, "call-1"); assert_eq!(tool_name, "calculator"); assert_eq!(arguments["expression"], "1 + 1"); assert_eq!(content, "### calculator\n- expression: 1 + 1"); } other => panic!("unexpected outbound variant: {:?}", other), } } #[test] fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() { let message = ChatMessage::assistant_with_tool_calls( "日报已整理完成。", vec![ToolCall { id: "call-1".to_string(), name: "memory_manage".to_string(), arguments: json!({"action": "put"}), }], ); let outbound = ws_outbound_from_chat_message(&message); assert_eq!(outbound.len(), 2); assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. })); assert!(matches!(outbound[1], WsOutbound::ToolCall { .. })); } #[test] fn test_ws_outbound_from_chat_message_includes_tool_results() { let message = ChatMessage::tool("call-1", "calculator", "2"); let outbound = ws_outbound_from_chat_message(&message); assert_eq!(outbound.len(), 1); assert!(matches!(outbound[0], WsOutbound::ToolResult { .. })); } #[test] fn test_ws_outbound_from_chat_message_includes_tool_pending() { let message = ChatMessage::tool_with_state( "call-1", "bash", "等待你完成授权后再继续。", ToolMessageState::PendingUserAction, ); let outbound = ws_outbound_from_chat_message(&message); assert_eq!(outbound.len(), 1); assert!(matches!(outbound[0], WsOutbound::ToolPending { .. })); } #[test] fn test_should_display_message_to_user_hides_completed_tool_results_by_default() { let completed = ChatMessage::tool("call-1", "calculator", "2"); let pending = ChatMessage::tool_with_state( "call-2", "bash", "waiting", ToolMessageState::PendingUserAction, ); assert!(!should_display_message_to_user(false, &completed)); assert!(should_display_message_to_user(false, &pending)); assert!(should_display_message_to_user(true, &completed)); } #[test] fn test_resolve_ws_sender_id_prefers_inbound_sender() { assert_eq!( resolve_ws_sender_id(Some("user-42"), "runtime-1"), "user-42" ); assert_eq!( resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), "user-42" ); } #[test] fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() { assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1"); } #[tokio::test] async fn test_ws_tool_call_emitter_hides_completed_tool_results_when_disabled() { let (sender, mut receiver) = mpsc::channel(4); let emitter = WsToolCallEmitter { sender, show_tool_results: false, }; emitter .handle(ChatMessage::tool("call-1", "calculator", "2")) .await; assert!( tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv()) .await .is_err() ); } }