158 lines
4.1 KiB
Rust

use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use crate::bus::{MessageBus, OutboundMessage};
use crate::protocol::WsOutbound;
use crate::protocol::ws_adapter::ws_outbound_from_outbound_message;
use super::base::{Channel, ChannelError};
/// One registered CLI endpoint: the id of the connection that owns it and the
/// channel half used to push outbound frames to it.
#[derive(Clone)]
struct CliConnection {
    /// Identifier compared by `CliChannel::unregister_connection` to decide
    /// which entries to drop.
    connection_id: String,
    /// Sending half that `CliChannel::send` pushes `WsOutbound` values into.
    sender: mpsc::Sender<WsOutbound>,
}
/// Channel implementation that routes outbound messages to registered CLI
/// connections.
///
/// Cloning is cheap: clones share the same connection table through the `Arc`.
#[derive(Clone)]
pub struct CliChannel {
    /// Registered connections, keyed by session id.
    connections: Arc<RwLock<HashMap<String, CliConnection>>>,
}
impl CliChannel {
    /// Creates a channel with an empty connection table.
    pub fn new() -> Self {
        let connections = Arc::new(RwLock::new(HashMap::new()));
        Self { connections }
    }

    /// Registers `sender` as the destination for `session_id`, tagged with
    /// `connection_id`.
    ///
    /// If the session already had an entry it is overwritten and an info-level
    /// trace event is emitted.
    pub async fn register_connection(
        &self,
        session_id: impl Into<String>,
        connection_id: impl Into<String>,
        sender: mpsc::Sender<WsOutbound>,
    ) {
        let session_id: String = session_id.into();
        let connection_id: String = connection_id.into();
        let entry = CliConnection {
            connection_id: connection_id.clone(),
            sender,
        };

        // Keep the write guard confined to this scope so it is released
        // before the log line below.
        let previous = {
            let mut table = self.connections.write().await;
            table.insert(session_id.clone(), entry)
        };

        if previous.is_none() {
            return;
        }
        tracing::info!(session_id = %session_id, connection_id = %connection_id, "CLI session sender replaced");
    }

    /// Removes every session entry whose stored connection id equals
    /// `connection_id`. Entries owned by other connections are left in place.
    pub async fn unregister_connection(&self, connection_id: &str) {
        let mut table = self.connections.write().await;
        table.retain(|_session, entry| entry.connection_id != connection_id);
    }
}
impl Default for CliChannel {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Channel for CliChannel {
    /// Returns the fixed channel name, `"cli"`.
    fn name(&self) -> &str {
        "cli"
    }

    /// Always reports `true`.
    fn is_running(&self) -> bool {
        true
    }

    /// Does nothing and returns `Ok(())`; the bus handle is not used.
    async fn start(&self, _bus: Arc<MessageBus>) -> Result<(), ChannelError> {
        Ok(())
    }

    /// Drops every registered connection.
    async fn stop(&self) -> Result<(), ChannelError> {
        let mut table = self.connections.write().await;
        table.clear();
        Ok(())
    }

    /// Delivers `msg` to the connection registered under `msg.chat_id`.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::SendError` when no connection is registered for
    /// that session, or when the connection's sender reports failure.
    async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
        // Clone the entry out so the read guard is released before any
        // awaiting on the sender happens.
        let lookup = {
            let table = self.connections.read().await;
            table.get(&msg.chat_id).cloned()
        };

        let connection = match lookup {
            Some(connection) => connection,
            None => {
                return Err(ChannelError::SendError(format!(
                    "No active CLI connection for session {}",
                    msg.chat_id
                )));
            }
        };

        for outbound in ws_outbound_from_outbound_message(&msg) {
            let delivered = connection.sender.send(outbound).await;
            if delivered.is_err() {
                return Err(ChannelError::SendError(
                    "CLI websocket sender closed".to_string(),
                ));
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bus::OutboundMessage;

    /// Builds the assistant message used by both tests: channel `"cli"`,
    /// chat id `"session-1"`, content `"hello"`.
    fn hello_for_session_1() -> OutboundMessage {
        OutboundMessage::assistant(
            "cli",
            "session-1",
            None, // session_id
            "hello",
            None,
            HashMap::new(),
        )
    }

    /// A message addressed to a registered session arrives on that session's
    /// receiver as an assistant response.
    #[tokio::test]
    async fn test_cli_channel_sends_to_registered_session() {
        let channel = CliChannel::new();
        let (tx, mut rx) = mpsc::channel(4);
        channel.register_connection("session-1", "conn-1", tx).await;

        let sent = channel.send(hello_for_session_1()).await;
        sent.unwrap();

        let outbound = rx.recv().await.unwrap();
        assert!(matches!(outbound, WsOutbound::AssistantResponse { .. }));
    }

    /// After a connection is unregistered, sending to its session fails with
    /// the "no active connection" error.
    #[tokio::test]
    async fn test_cli_channel_unregisters_connection_sessions() {
        let channel = CliChannel::new();
        // The receiver is kept alive so the failure comes from the missing
        // registration rather than from a closed channel.
        let (tx, _rx) = mpsc::channel(4);
        channel.register_connection("session-1", "conn-1", tx).await;
        channel.unregister_connection("conn-1").await;

        let sent = channel.send(hello_for_session_1()).await;
        let error = sent.unwrap_err();
        assert!(error.to_string().contains("No active CLI connection"));
    }
}