use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{RwLock, mpsc}; use crate::bus::{MessageBus, OutboundMessage}; use crate::protocol::WsOutbound; use crate::protocol::ws_adapter::ws_outbound_from_outbound_message; use super::base::{Channel, ChannelError}; #[derive(Clone)] struct CliConnection { connection_id: String, sender: mpsc::Sender, } #[derive(Clone)] pub struct CliChannel { connections: Arc>>, } impl CliChannel { pub fn new() -> Self { Self { connections: Arc::new(RwLock::new(HashMap::new())), } } pub async fn register_connection( &self, session_id: impl Into, connection_id: impl Into, sender: mpsc::Sender, ) { let session_id = session_id.into(); let connection_id = connection_id.into(); let previous = self.connections.write().await.insert( session_id.clone(), CliConnection { connection_id: connection_id.clone(), sender, }, ); if previous.is_some() { tracing::info!(session_id = %session_id, connection_id = %connection_id, "CLI session sender replaced"); } } pub async fn unregister_connection(&self, connection_id: &str) { self.connections .write() .await .retain(|_, connection| connection.connection_id != connection_id); } } impl Default for CliChannel { fn default() -> Self { Self::new() } } #[async_trait] impl Channel for CliChannel { fn name(&self) -> &str { "cli" } fn is_running(&self) -> bool { true } async fn start(&self, _bus: Arc) -> Result<(), ChannelError> { Ok(()) } async fn stop(&self) -> Result<(), ChannelError> { self.connections.write().await.clear(); Ok(()) } async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { let connection = self.connections.read().await.get(&msg.chat_id).cloned(); let Some(connection) = connection else { return Err(ChannelError::SendError(format!( "No active CLI connection for session {}", msg.chat_id ))); }; for outbound in ws_outbound_from_outbound_message(&msg) { connection .sender .send(outbound) .await .map_err(|_| ChannelError::SendError("CLI websocket sender closed".to_string()))?; } Ok(()) } } #[cfg(test)] mod tests { use super::*; use crate::bus::OutboundMessage; #[tokio::test] async fn test_cli_channel_sends_to_registered_session() { let channel = CliChannel::new(); let (sender, mut receiver) = mpsc::channel(4); channel .register_connection("session-1", "conn-1", sender) .await; channel .send(OutboundMessage::assistant( "cli", "session-1", None, // session_id "hello", None, HashMap::new(), )) .await .unwrap(); let outbound = receiver.recv().await.unwrap(); assert!(matches!(outbound, WsOutbound::AssistantResponse { .. })); } #[tokio::test] async fn test_cli_channel_unregisters_connection_sessions() { let channel = CliChannel::new(); let (sender, _receiver) = mpsc::channel(4); channel .register_connection("session-1", "conn-1", sender) .await; channel.unregister_connection("conn-1").await; let error = channel .send(OutboundMessage::assistant( "cli", "session-1", None, // session_id "hello", None, HashMap::new(), )) .await .unwrap_err(); assert!(error.to_string().contains("No active CLI connection")); } }