158 lines
4.1 KiB
Rust
158 lines
4.1 KiB
Rust
use async_trait::async_trait;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::{RwLock, mpsc};
|
|
|
|
use crate::bus::{MessageBus, OutboundMessage};
|
|
use crate::protocol::WsOutbound;
|
|
use crate::protocol::ws_adapter::ws_outbound_from_outbound_message;
|
|
|
|
use super::base::{Channel, ChannelError};
|
|
|
|
#[derive(Clone)]
|
|
struct CliConnection {
|
|
connection_id: String,
|
|
sender: mpsc::Sender<WsOutbound>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct CliChannel {
|
|
connections: Arc<RwLock<HashMap<String, CliConnection>>>,
|
|
}
|
|
|
|
impl CliChannel {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
connections: Arc::new(RwLock::new(HashMap::new())),
|
|
}
|
|
}
|
|
|
|
pub async fn register_connection(
|
|
&self,
|
|
session_id: impl Into<String>,
|
|
connection_id: impl Into<String>,
|
|
sender: mpsc::Sender<WsOutbound>,
|
|
) {
|
|
let session_id = session_id.into();
|
|
let connection_id = connection_id.into();
|
|
let previous = self.connections.write().await.insert(
|
|
session_id.clone(),
|
|
CliConnection {
|
|
connection_id: connection_id.clone(),
|
|
sender,
|
|
},
|
|
);
|
|
|
|
if previous.is_some() {
|
|
tracing::info!(session_id = %session_id, connection_id = %connection_id, "CLI session sender replaced");
|
|
}
|
|
}
|
|
|
|
pub async fn unregister_connection(&self, connection_id: &str) {
|
|
self.connections
|
|
.write()
|
|
.await
|
|
.retain(|_, connection| connection.connection_id != connection_id);
|
|
}
|
|
}
|
|
|
|
impl Default for CliChannel {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Channel for CliChannel {
|
|
fn name(&self) -> &str {
|
|
"cli"
|
|
}
|
|
|
|
fn is_running(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
async fn start(&self, _bus: Arc<MessageBus>) -> Result<(), ChannelError> {
|
|
Ok(())
|
|
}
|
|
|
|
async fn stop(&self) -> Result<(), ChannelError> {
|
|
self.connections.write().await.clear();
|
|
Ok(())
|
|
}
|
|
|
|
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
|
let connection = self.connections.read().await.get(&msg.chat_id).cloned();
|
|
let Some(connection) = connection else {
|
|
return Err(ChannelError::SendError(format!(
|
|
"No active CLI connection for session {}",
|
|
msg.chat_id
|
|
)));
|
|
};
|
|
|
|
for outbound in ws_outbound_from_outbound_message(&msg) {
|
|
connection
|
|
.sender
|
|
.send(outbound)
|
|
.await
|
|
.map_err(|_| ChannelError::SendError("CLI websocket sender closed".to_string()))?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bus::OutboundMessage;

    /// Builds the assistant message that both tests push through the channel.
    fn hello_message() -> OutboundMessage {
        OutboundMessage::assistant(
            "cli",
            "session-1",
            None, // session_id
            "hello",
            None,
            HashMap::new(),
        )
    }

    /// A message addressed to a registered session reaches that session's
    /// receiver as an assistant response.
    #[tokio::test]
    async fn test_cli_channel_sends_to_registered_session() {
        let channel = CliChannel::new();
        let (tx, mut rx) = mpsc::channel(4);
        channel.register_connection("session-1", "conn-1", tx).await;

        let delivery = channel.send(hello_message()).await;
        delivery.unwrap();

        let outbound = rx.recv().await.unwrap();
        assert!(matches!(outbound, WsOutbound::AssistantResponse { .. }));
    }

    /// After a connection is unregistered, sending to its session fails with
    /// the "no active connection" error.
    #[tokio::test]
    async fn test_cli_channel_unregisters_connection_sessions() {
        let channel = CliChannel::new();
        // The receiver is kept alive so the channel itself stays open.
        let (tx, _rx) = mpsc::channel(4);
        channel.register_connection("session-1", "conn-1", tx).await;
        channel.unregister_connection("conn-1").await;

        let delivery = channel.send(hello_message()).await;
        let error = delivery.unwrap_err();

        assert!(error.to_string().contains("No active CLI connection"));
    }
}
|