PicoBot/src/client/mod.rs

97 lines
3.9 KiB
Rust

pub use crate::protocol::{WsInbound, WsOutbound, serialize_inbound, serialize_outbound};
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::cli::InputHandler;
/// Decodes one raw text frame received from the gateway into a [`WsOutbound`].
///
/// # Errors
///
/// Returns the [`serde_json::Error`] produced when `raw` is not valid JSON or
/// does not match any `WsOutbound` shape.
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
    let decoded: WsOutbound = serde_json::from_str(raw)?;
    Ok(decoded)
}
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
let (ws_stream, _) = connect_async(gateway_url).await?;
tracing::info!(url = %gateway_url, "Connected to gateway");
let (mut sender, mut receiver) = ws_stream.split();
let mut input = InputHandler::new();
input.write_output("picobot CLI - Type /quit to exit, /clear to clear history\n").await?;
// Main loop: poll both stdin and WebSocket
loop {
tokio::select! {
// Handle WebSocket messages
msg = receiver.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
let text = text.to_string();
if let Ok(outbound) = parse_message(&text) {
match outbound {
WsOutbound::AssistantResponse { content, .. } => {
input.write_response(&content).await?;
}
WsOutbound::Error { message, .. } => {
input.write_output(&format!("Error: {}", message)).await?;
}
WsOutbound::SessionEstablished { session_id } => {
tracing::debug!(session_id = %session_id, "Session established");
input.write_output(&format!("Session: {}\n", session_id)).await?;
}
_ => {}
}
}
}
Some(Ok(Message::Close(_))) | None => {
tracing::info!("Gateway disconnected");
input.write_output("Gateway disconnected").await?;
break;
}
_ => {}
}
}
// Handle stdin input
result = input.read_input("> ") => {
match result {
Ok(Some(msg)) => {
match msg.content.as_str() {
"__EXIT__" => {
input.write_output("Goodbye!").await?;
break;
}
"__CLEAR__" => {
let inbound = WsInbound::ClearHistory { chat_id: None };
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
_ => {}
}
let inbound = WsInbound::UserInput {
content: msg.content,
channel: None,
chat_id: None,
sender_id: None,
};
if let Ok(text) = serialize_inbound(&inbound) {
if sender.send(Message::Text(text.into())).await.is_err() {
tracing::error!("Failed to send message to gateway");
break;
}
}
}
Ok(None) => break,
Err(e) => {
tracing::error!(error = %e, "Input error");
break;
}
}
}
}
}
Ok(())
}