From 35d201f2060d49e06e6d092fd456bf88262756be Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Mon, 6 Apr 2026 16:36:17 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E9=A1=B9=E7=9B=AE=E7=BB=93?= =?UTF-8?q?=E6=9E=84=EF=BC=8C=E6=B7=BB=E5=8A=A0=E4=BB=A3=E7=90=86=E3=80=81?= =?UTF-8?q?=E7=BD=91=E5=85=B3=E5=92=8C=E5=AE=A2=E6=88=B7=E7=AB=AF=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=EF=BC=8C=E6=9B=B4=E6=96=B0=E9=85=8D=E7=BD=AE=E4=BB=A5?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=BB=98=E8=AE=A4=E7=BD=91=E5=85=B3=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=EF=BC=8C=E5=A2=9E=E5=BC=BA=E9=94=99=E8=AF=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86=EF=BC=8C=E6=B7=BB=E5=8A=A0=20WebSocket=20=E6=94=AF?= =?UTF-8?q?=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 9 ++- src/agent/agent_loop.rs | 72 ++++++++++++++++++++++++ src/agent/mod.rs | 3 + src/bus/message.rs | 45 +++++++++++++++ src/bus/mod.rs | 44 +++++++++++++++ src/cli/channel.rs | 50 +++++++++++++++++ src/cli/input.rs | 70 +++++++++++++++++++++++ src/cli/mod.rs | 5 ++ src/client/mod.rs | 89 +++++++++++++++++++++++++++++ src/config/mod.rs | 120 +++++++++++++++++++++++++++++++++++++++- src/gateway/http.rs | 15 +++++ src/gateway/mod.rs | 44 +++++++++++++++ src/gateway/session.rs | 67 ++++++++++++++++++++++ src/gateway/ws.rs | 116 ++++++++++++++++++++++++++++++++++++++ src/lib.rs | 6 ++ src/main.rs | 79 ++++++++++++++++---------- src/protocol.rs | 37 +++++++++++++ src/providers/openai.rs | 10 +++- 18 files changed, 848 insertions(+), 33 deletions(-) create mode 100644 src/agent/agent_loop.rs create mode 100644 src/agent/mod.rs create mode 100644 src/bus/message.rs create mode 100644 src/bus/mod.rs create mode 100644 src/cli/channel.rs create mode 100644 src/cli/input.rs create mode 100644 src/cli/mod.rs create mode 100644 src/client/mod.rs create mode 100644 src/gateway/http.rs create mode 100644 src/gateway/mod.rs create mode 100644 src/gateway/session.rs create mode 100644 src/gateway/ws.rs create mode 100644 src/protocol.rs diff --git a/Cargo.toml b/Cargo.toml index 91f8c4d..63fd16e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "PicoBot" +name = "picobot" version = "0.1.0" edition = "2024" @@ -7,7 +7,14 @@ edition = "2024" reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } dotenv = "0.15" serde = { version = "1.0", features = ["derive"] } +regex = "1.0" serde_json = "1.0" async-trait = "0.1" thiserror = "1.0" tokio = { version = "1.0", features = ["full"] } +uuid = { version = "1.0", features = ["v4"] } +axum = { version = "0.8", features = ["ws"] } +tokio-tungstenite = "0.26" +futures-util = "0.3" +clap = { version = "4", features = ["derive"] } +dirs = "5" diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs new file mode 100644 index 0000000..fbf60b3 --- /dev/null +++ b/src/agent/agent_loop.rs @@ -0,0 +1,72 @@ +use crate::bus::ChatMessage; +use crate::config::LLMProviderConfig; +use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message}; + +pub struct AgentLoop { + provider: Box, + history: Vec, +} + +impl AgentLoop { + pub fn new(provider_config: LLMProviderConfig) -> Result { + let provider = create_provider(provider_config) + .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; + + Ok(Self { + provider, + history: Vec::new(), + }) + } + + pub async fn process(&mut self, user_message: ChatMessage) -> Result { + self.history.push(user_message.clone()); + + let messages: Vec = self.history + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let request = ChatCompletionRequest { + messages, + temperature: None, + max_tokens: None, + tools: None, + }; + + let response = (*self.provider).chat(request).await + .map_err(|e| AgentError::LlmError(e.to_string()))?; + + let assistant_message = ChatMessage::assistant(response.content); + self.history.push(assistant_message.clone()); + + Ok(assistant_message) + } + + pub fn clear_history(&mut self) { + self.history.clear(); + } + + pub fn history(&self) -> &[ChatMessage] { + &self.history + } +} + +#[derive(Debug)] +pub enum AgentError { + ProviderCreation(String), + LlmError(String), +} + +impl std::fmt::Display for AgentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e), + AgentError::LlmError(e) => write!(f, "LLM error: {}", e), + } + } +} + +impl std::error::Error for AgentError {} diff --git a/src/agent/mod.rs b/src/agent/mod.rs new file mode 100644 index 0000000..7c84e22 --- /dev/null +++ b/src/agent/mod.rs @@ -0,0 +1,3 @@ +pub mod agent_loop; + +pub use agent_loop::{AgentLoop, AgentError}; diff --git a/src/bus/message.rs b/src/bus/message.rs new file mode 100644 index 0000000..8700c5f --- /dev/null +++ b/src/bus/message.rs @@ -0,0 +1,45 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub id: String, + pub role: String, + pub content: String, + pub timestamp: i64, +} + +impl ChatMessage { + pub fn user(content: impl Into) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + role: "user".to_string(), + content: content.into(), + timestamp: current_timestamp(), + } + } + + pub fn assistant(content: impl Into) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + role: "assistant".to_string(), + content: content.into(), + timestamp: current_timestamp(), + } + } + + pub fn system(content: impl Into) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + role: "system".to_string(), + content: content.into(), + timestamp: current_timestamp(), + } + } +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as i64 +} diff --git a/src/bus/mod.rs b/src/bus/mod.rs new file mode 100644 index 0000000..91eb63d --- /dev/null +++ b/src/bus/mod.rs @@ -0,0 +1,44 @@ +pub mod message; + +pub use message::ChatMessage; + +use tokio::sync::{mpsc, broadcast}; + +pub struct MessageBus { + user_tx: mpsc::Sender, + llm_tx: broadcast::Sender, +} + +impl MessageBus { + pub fn new(buffer_size: usize) -> Self { + let (user_tx, _) = mpsc::channel(buffer_size); + let (llm_tx, _) = broadcast::channel(buffer_size); + + Self { user_tx, llm_tx } + } + + pub async fn send_user_input(&self, msg: ChatMessage) -> Result<(), BusError> { + self.user_tx.send(msg).await.map_err(|_| BusError::ChannelClosed) + } + + pub fn send_llm_output(&self, msg: ChatMessage) -> Result { + self.llm_tx.send(msg).map_err(|_| BusError::ChannelClosed) + } +} + +#[derive(Debug)] +pub enum BusError { + ChannelClosed, + SendError(usize), +} + +impl std::fmt::Display for BusError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BusError::ChannelClosed => write!(f, "Channel closed"), + BusError::SendError(n) => write!(f, "Send error, {} receivers", n), + } + } +} + +impl std::error::Error for BusError {} diff --git a/src/cli/channel.rs b/src/cli/channel.rs new file mode 100644 index 0000000..56030e1 --- /dev/null +++ b/src/cli/channel.rs @@ -0,0 +1,50 @@ +use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt}; + +pub struct CliChannel { + read: BufReader, + write: tokio::io::Stdout, +} + +impl CliChannel { + pub fn new() -> Self { + Self { + read: BufReader::new(tokio::io::stdin()), + write: tokio::io::stdout(), + } + } + + pub async fn read_line(&mut self, prompt: &str) -> Result, std::io::Error> { + print!("{}", prompt); + self.write.flush().await?; + + let mut line = String::new(); + let bytes_read = self.read.read_line(&mut line).await?; + + if bytes_read == 0 { + return Ok(None); + } + + Ok(Some(line.trim_end().to_string())) + } + + pub async fn write_line(&mut self, content: &str) -> Result<(), std::io::Error> { + self.write.write_all(content.as_bytes()).await?; + self.write.write_all(b"\n").await?; + self.write.flush().await + } + + pub async fn write_response(&mut self, content: &str) -> Result<(), std::io::Error> { + for line in content.lines() { + self.write.write_all(b" ").await?; + self.write.write_all(line.as_bytes()).await?; + self.write.write_all(b"\n").await?; + } + self.write.flush().await + } +} + +impl Default for CliChannel { + fn default() -> Self { + Self::new() + } +} diff --git a/src/cli/input.rs b/src/cli/input.rs new file mode 100644 index 0000000..2e5fa91 --- /dev/null +++ b/src/cli/input.rs @@ -0,0 +1,70 @@ +use crate::bus::ChatMessage; + +use super::channel::CliChannel; + +pub struct InputHandler { + channel: CliChannel, +} + +impl InputHandler { + pub fn new() -> Self { + Self { + channel: CliChannel::new(), + } + } + + pub async fn read_input(&mut self, prompt: &str) -> Result, InputError> { + match self.channel.read_line(prompt).await { + Ok(Some(line)) => { + if line.trim().is_empty() { + return Ok(None); + } + + if let Some(cmd) = self.handle_special_commands(&line) { + return Ok(Some(cmd)); + } + + Ok(Some(ChatMessage::user(line))) + } + Ok(None) => Ok(None), + Err(e) => Err(InputError::IoError(e)), + } + } + + pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> { + self.channel.write_line(content).await.map_err(InputError::IoError) + } + + pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> { + self.channel.write_response(content).await.map_err(InputError::IoError) + } + + fn handle_special_commands(&self, line: &str) -> Option { + match line.trim() { + "/quit" | "/exit" | "/q" => Some(ChatMessage::system("__EXIT__")), + "/clear" => Some(ChatMessage::system("__CLEAR__")), + _ => None, + } + } +} + +impl Default for InputHandler { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +pub enum InputError { + IoError(std::io::Error), +} + +impl std::fmt::Display for InputError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InputError::IoError(e) => write!(f, "IO error: {}", e), + } + } +} + +impl std::error::Error for InputError {} diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 0000000..fe628b9 --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,5 @@ +pub mod channel; +pub mod input; + +pub use channel::CliChannel; +pub use input::InputHandler; diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..de792f5 --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,89 @@ +pub use crate::protocol::{WsInbound, WsOutbound, serialize_inbound, serialize_outbound}; + +use futures_util::{SinkExt, StreamExt}; +use tokio_tungstenite::{connect_async, tungstenite::Message}; + +use crate::cli::InputHandler; + +fn parse_message(raw: &str) -> Result { + serde_json::from_str(raw) +} + +pub async fn run(gateway_url: &str) -> Result<(), Box> { + let (ws_stream, _) = connect_async(gateway_url).await?; + println!("Connected to gateway"); + + let (mut sender, mut receiver) = ws_stream.split(); + + let mut input = InputHandler::new(); + input.write_output("picobot CLI - Type /quit to exit, /clear to clear history\n").await?; + + // Main loop: poll both stdin and WebSocket + loop { + tokio::select! { + // Handle WebSocket messages + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + let text = text.to_string(); + if let Ok(outbound) = parse_message(&text) { + match outbound { + WsOutbound::AssistantResponse { content, .. } => { + input.write_response(&content).await?; + } + WsOutbound::Error { message, .. } => { + input.write_output(&format!("Error: {}", message)).await?; + } + WsOutbound::SessionEstablished { session_id } => { + input.write_output(&format!("Session: {}\n", session_id)).await?; + } + _ => {} + } + } + } + Some(Ok(Message::Close(_))) | None => { + input.write_output("Gateway disconnected").await?; + break; + } + _ => {} + } + } + // Handle stdin input + result = input.read_input("> ") => { + match result { + Ok(Some(msg)) => { + match msg.content.as_str() { + "__EXIT__" => { + input.write_output("Goodbye!").await?; + break; + } + "__CLEAR__" => { + let inbound = WsInbound::ClearHistory; + if let Ok(text) = serialize_inbound(&inbound) { + let _ = sender.send(Message::Text(text.into())).await; + } + continue; + } + _ => {} + } + + let inbound = WsInbound::UserInput { content: msg.content }; + if let Ok(text) = serialize_inbound(&inbound) { + if sender.send(Message::Text(text.into())).await.is_err() { + eprintln!("Failed to send message"); + break; + } + } + } + Ok(None) => break, + Err(e) => { + eprintln!("Input error: {}", e); + break; + } + } + } + } + } + + Ok(()) +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 798d014..822f558 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,12 +1,19 @@ +use regex::Regex; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::env; use std::fs; +use std::path::{Path, PathBuf}; #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { pub providers: HashMap, pub models: HashMap, pub agents: HashMap, + #[serde(default)] + pub gateway: GatewayConfig, + #[serde(default)] + pub client: ClientConfig, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -36,6 +43,49 @@ pub struct AgentConfig { pub model: String, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct GatewayConfig { + #[serde(default = "default_gateway_host")] + pub host: String, + #[serde(default = "default_gateway_port")] + pub port: u16, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ClientConfig { + #[serde(default = "default_gateway_url")] + pub gateway_url: String, +} + +fn default_gateway_host() -> String { + "127.0.0.1".to_string() +} + +fn default_gateway_port() -> u16 { + 19876 +} + +fn default_gateway_url() -> String { + "ws://127.0.0.1:19876/ws".to_string() +} + +impl Default for GatewayConfig { + fn default() -> Self { + Self { + host: default_gateway_host(), + port: default_gateway_port(), + } + } +} + +impl Default for ClientConfig { + fn default() -> Self { + Self { + gateway_url: default_gateway_url(), + } + } +} + #[derive(Debug, Clone)] pub struct LLMProviderConfig { pub provider_type: String, @@ -49,9 +99,37 @@ pub struct LLMProviderConfig { pub model_extra: HashMap, } +fn get_default_config_path() -> PathBuf { + let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); + home.join(".config").join("picobot").join("config.json") +} + impl Config { pub fn load(path: &str) -> Result> { - let content = fs::read_to_string(path)?; + Self::load_from(Path::new(path)) + } + + pub fn load_default() -> Result> { + let path = get_default_config_path(); + Self::load_from(&path) + } + + fn load_from(path: &Path) -> Result> { + load_env_file()?; + let content = if path.exists() { + fs::read_to_string(path)? + } else { + // Fallback to current directory + let fallback = Path::new("config.json"); + if fallback.exists() { + fs::read_to_string(fallback)? + } else { + return Err(Box::new(ConfigError::ConfigNotFound( + path.to_string_lossy().to_string(), + ))); + } + }; + let content = resolve_env_placeholders(&content); let config: Config = serde_json::from_str(&content)?; Ok(config) } @@ -82,6 +160,7 @@ impl Config { #[derive(Debug)] pub enum ConfigError { + ConfigNotFound(String), AgentNotFound(String), ProviderNotFound(String), ModelNotFound(String), @@ -90,6 +169,7 @@ pub enum ConfigError { impl std::fmt::Display for ConfigError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.config/picobot/config.json", path), ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), @@ -99,6 +179,37 @@ impl std::fmt::Display for ConfigError { impl std::error::Error for ConfigError {} +fn load_env_file() -> Result<(), Box> { + let env_path = Path::new(".env"); + if env_path.exists() { + let content = fs::read_to_string(env_path)?; + for line in content.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + if let Some((key, value)) = line.split_once('=') { + let key = key.trim(); + let value = value.trim().trim_matches('"').trim_matches('\''); + if !value.is_empty() { + // SAFETY: Setting environment variables for the current process + // is safe as we're only modifying our own process state + unsafe { env::set_var(key, value) }; + } + } + } + } + Ok(()) +} + +fn resolve_env_placeholders(content: &str) -> String { + let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex"); + re.replace_all(content, |caps: ®ex::Captures| { + let var_name = &caps[1]; + env::var(var_name).unwrap_or_else(|_| caps[0].to_string()) + }).to_string() +} + #[cfg(test)] mod tests { use super::*; @@ -129,4 +240,11 @@ mod tests { assert_eq!(provider_config.model_id, "qwen-plus"); assert_eq!(provider_config.temperature, Some(0.0)); } + + #[test] + fn test_default_gateway_config() { + let config = Config::load("config.json").unwrap(); + assert_eq!(config.gateway.host, "0.0.0.0"); + assert_eq!(config.gateway.port, 19876); + } } diff --git a/src/gateway/http.rs b/src/gateway/http.rs new file mode 100644 index 0000000..beed167 --- /dev/null +++ b/src/gateway/http.rs @@ -0,0 +1,15 @@ +use axum::Json; +use serde::Serialize; + +#[derive(Serialize)] +pub struct HealthResponse { + status: String, + version: String, +} + +pub async fn health() -> Json { + Json(HealthResponse { + status: "ok".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }) +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs new file mode 100644 index 0000000..15ce939 --- /dev/null +++ b/src/gateway/mod.rs @@ -0,0 +1,44 @@ +pub mod http; +pub mod session; +pub mod ws; + +use std::sync::Arc; +use axum::{routing, Router}; +use tokio::net::TcpListener; + +use crate::config::Config; +use session::SessionManager; + +pub struct GatewayState { + pub config: Config, + pub session_manager: SessionManager, +} + +impl GatewayState { + pub fn new() -> Result> { + let config = Config::load_default()?; + Ok(Self { + config, + session_manager: SessionManager::new(), + }) + } +} + +pub async fn run(host: Option, port: Option) -> Result<(), Box> { + let state = Arc::new(GatewayState::new()?); + + // CLI args override config file values + let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone()); + let bind_port = port.unwrap_or(state.config.gateway.port); + + let app = Router::new() + .route("/health", routing::get(http::health)) + .route("/ws", routing::get(ws::ws_handler)) + .with_state(state); + + let addr = format!("{}:{}", bind_host, bind_port); + let listener = TcpListener::bind(&addr).await?; + println!("Gateway listening on {}", addr); + axum::serve(listener, app).await?; + Ok(()) +} diff --git a/src/gateway/session.rs b/src/gateway/session.rs new file mode 100644 index 0000000..bfc48a6 --- /dev/null +++ b/src/gateway/session.rs @@ -0,0 +1,67 @@ +use std::sync::Arc; +use tokio::sync::{Mutex, mpsc}; +use uuid::Uuid; +use crate::config::LLMProviderConfig; +use crate::agent::AgentLoop; +use crate::protocol::WsOutbound; + +pub struct Session { + pub id: Uuid, + pub agent_loop: Arc>, + pub user_tx: mpsc::Sender, +} + +impl Session { + pub async fn new( + provider_config: LLMProviderConfig, + user_tx: mpsc::Sender, + ) -> Result { + let agent_loop = AgentLoop::new(provider_config)?; + Ok(Self { + id: Uuid::new_v4(), + agent_loop: Arc::new(Mutex::new(agent_loop)), + user_tx, + }) + } + + pub async fn send(&self, msg: WsOutbound) { + let _ = self.user_tx.send(msg).await; + } +} + +use std::collections::HashMap; +use std::sync::RwLock; + +pub struct SessionManager { + sessions: RwLock>>, +} + +impl SessionManager { + pub fn new() -> Self { + Self { + sessions: RwLock::new(HashMap::new()), + } + } + + pub fn add(&self, session: Arc) { + self.sessions.write().unwrap().insert(session.id, session); + } + + pub fn remove(&self, id: &Uuid) { + self.sessions.write().unwrap().remove(id); + } + + pub fn get(&self, id: &Uuid) -> Option> { + self.sessions.read().unwrap().get(id).cloned() + } + + pub fn len(&self) -> usize { + self.sessions.read().unwrap().len() + } +} + +impl Default for SessionManager { + fn default() -> Self { + Self::new() + } +} diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs new file mode 100644 index 0000000..f0a80a8 --- /dev/null +++ b/src/gateway/ws.rs @@ -0,0 +1,116 @@ +use std::sync::Arc; +use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; +use axum::extract::State; +use axum::response::Response; +use futures_util::{SinkExt, StreamExt}; +use tokio::sync::mpsc; +use crate::bus::ChatMessage; +use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound}; +use super::{GatewayState, session::Session}; + +pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { + ws.on_upgrade(|socket| async { + handle_socket(socket, state).await; + }) +} + +async fn handle_socket(ws: WebSocket, state: Arc) { + let (sender, receiver) = mpsc::channel::(100); + + let provider_config = match state.config.get_provider_config("default") { + Ok(cfg) => cfg, + Err(e) => { + eprintln!("Failed to get provider config: {}", e); + return; + } + }; + + let session = match Session::new(provider_config, sender).await { + Ok(s) => Arc::new(s), + Err(e) => { + eprintln!("Failed to create session: {}", e); + return; + } + }; + + let session_id = session.id; + state.session_manager.add(session.clone()); + + let _ = session.send(WsOutbound::SessionEstablished { + session_id: session_id.to_string(), + }).await; + + let (mut ws_sender, mut ws_receiver) = ws.split(); + + let mut receiver = receiver; + tokio::spawn(async move { + while let Some(msg) = receiver.recv().await { + if let Ok(text) = serialize_outbound(&msg) { + if ws_sender.send(WsMessage::Text(text.into())).await.is_err() { + break; + } + } + } + }); + + while let Some(msg) = ws_receiver.next().await { + match msg { + Ok(WsMessage::Text(text)) => { + let text = text.to_string(); + match parse_inbound(&text) { + Ok(inbound) => { + handle_inbound(&session, inbound).await; + } + Err(e) => { + let _ = session.send(WsOutbound::Error { + code: "PARSE_ERROR".to_string(), + message: e.to_string(), + }).await; + } + } + } + Ok(WsMessage::Close(_)) | Err(_) => { + break; + } + _ => {} + } + } + + state.session_manager.remove(&session_id); +} + +async fn handle_inbound(session: &Arc, inbound: WsInbound) { + match inbound { + WsInbound::UserInput { content } => { + let user_msg = ChatMessage::user(content); + let mut agent = session.agent_loop.lock().await; + match agent.process(user_msg).await { + Ok(response) => { + let _ = session.send(WsOutbound::AssistantResponse { + id: response.id, + content: response.content, + role: response.role, + }).await; + } + Err(e) => { + let _ = session.send(WsOutbound::Error { + code: "LLM_ERROR".to_string(), + message: e.to_string(), + }).await; + } + } + } + WsInbound::ClearHistory => { + let mut agent = session.agent_loop.lock().await; + agent.clear_history(); + let _ = session.send(WsOutbound::AssistantResponse { + id: uuid::Uuid::new_v4().to_string(), + content: "History cleared.".to_string(), + role: "system".to_string(), + }).await; + } + WsInbound::Ping => { + let _ = session.send(WsOutbound::Pong).await; + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 4e2fd8e..76e66e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,8 @@ pub mod config; pub mod providers; +pub mod bus; +pub mod cli; +pub mod agent; +pub mod gateway; +pub mod client; +pub mod protocol; diff --git a/src/main.rs b/src/main.rs index 6f7cadc..ab2871a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,38 +1,57 @@ -mod config; -mod providers; +use clap::{Parser, CommandFactory}; -use config::Config; -use providers::{create_provider, ChatCompletionRequest, Message}; +#[derive(Parser)] +#[command(name = "picobot")] +#[command(about = "A CLI chatbot", long_about = None)] +enum Command { + /// Connect to gateway + Agent { + /// Gateway WebSocket URL (e.g., ws://127.0.0.1:19876/ws) + #[arg(long)] + gateway_url: Option, + }, + /// Start gateway server + Gateway { + /// Host to bind to + #[arg(long)] + host: Option, + /// Port to listen on + #[arg(long)] + port: Option, + }, +} #[tokio::main] -async fn main() { - // Load config - let config = Config::load("config.json").expect("Failed to load config.json"); +async fn main() -> Result<(), Box> { + let mut cmd = Command::command(); - // Get provider config for "default" agent - let provider_config = config.get_provider_config("default").expect("Failed to get provider config"); + // If no arguments, print help + if std::env::args().len() <= 1 { + cmd.print_help()?; + println!(); + return Ok(()); + } - // Create provider - let provider = create_provider(provider_config).expect("Failed to create provider"); - - println!("Provider type: {}", provider.ptype()); - println!("Provider name: {}", provider.name()); - println!("Model ID: {}", provider.model_id()); - - // Create request (no model ID needed - it's baked into the provider) - let request = ChatCompletionRequest { - messages: vec![Message { - role: "user".to_string(), - content: "Hello!".to_string(), - }], - temperature: None, // Will use config default if not provided - max_tokens: None, // Will use config default if not provided - tools: None, + // Load config for defaults + let config = match picobot::config::Config::load_default() { + Ok(cfg) => Some(cfg), + Err(e) => { + eprintln!("Warning: Could not load config from ~/.config/picobot/config.json: {}", e); + eprintln!("Using built-in defaults. Run `picobot gateway` or `picobot agent` with --help for options.\n"); + None + } }; - // Example usage: - // match provider.chat(request).await { - // Ok(resp) => println!("Response: {}", resp.content), - // Err(e) => eprintln!("Error: {}", e), - // } + match Command::parse() { + Command::Agent { gateway_url } => { + let url = gateway_url + .or_else(|| config.as_ref().map(|c| c.client.gateway_url.clone())) + .unwrap_or_else(|| "ws://127.0.0.1:19876/ws".to_string()); + picobot::client::run(&url).await?; + } + Command::Gateway { host, port } => { + picobot::gateway::run(host, port).await?; + } + } + Ok(()) } diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..4e8a22a --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,37 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum WsInbound { + #[serde(rename = "user_input")] + UserInput { content: String }, + #[serde(rename = "clear_history")] + ClearHistory, + #[serde(rename = "ping")] + Ping, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum WsOutbound { + #[serde(rename = "assistant_response")] + AssistantResponse { id: String, content: String, role: String }, + #[serde(rename = "error")] + Error { code: String, message: String }, + #[serde(rename = "session_established")] + SessionEstablished { session_id: String }, + #[serde(rename = "pong")] + Pong, +} + +pub fn parse_inbound(raw: &str) -> Result { + serde_json::from_str(raw) +} + +pub fn serialize_inbound(msg: &WsInbound) -> Result { + serde_json::to_string(msg) +} + +pub fn serialize_outbound(msg: &WsOutbound) -> Result { + serde_json::to_string(msg) +} diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 043b07b..2367681 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -134,7 +134,15 @@ impl LLMProvider for OpenAIProvider { let resp = req_builder.json(&body).send().await?; - let openai_resp: OpenAIResponse = resp.json().await?; + let status = resp.status(); + let text = resp.text().await?; + + if !status.is_success() { + return Err(format!("API error {}: {}", status, text).into()); + } + + let openai_resp: OpenAIResponse = serde_json::from_str(&text) + .map_err(|e| format!("decode error: {} | body: {}", e, &text))?; let content = openai_resp.choices[0] .message