重构项目结构,添加代理、网关和客户端模块,更新配置以支持默认网关设置,增强错误处理,添加 WebSocket 支持

This commit is contained in:
xiaoxixi 2026-04-06 16:36:17 +08:00
parent 8b1e6e7e06
commit 35d201f206
18 changed files with 848 additions and 33 deletions

View File

@ -1,5 +1,5 @@
[package]
name = "PicoBot"
name = "picobot"
version = "0.1.0"
edition = "2024"
@ -7,7 +7,14 @@ edition = "2024"
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
dotenv = "0.15"
serde = { version = "1.0", features = ["derive"] }
regex = "1.0"
serde_json = "1.0"
async-trait = "0.1"
thiserror = "1.0"
tokio = { version = "1.0", features = ["full"] }
uuid = { version = "1.0", features = ["v4"] }
axum = { version = "0.8", features = ["ws"] }
tokio-tungstenite = "0.26"
futures-util = "0.3"
clap = { version = "4", features = ["derive"] }
dirs = "5"

72
src/agent/agent_loop.rs Normal file
View File

@ -0,0 +1,72 @@
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
pub struct AgentLoop {
provider: Box<dyn LLMProvider>,
history: Vec<ChatMessage>,
}
impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider,
history: Vec::new(),
})
}
pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
self.history.push(user_message.clone());
let messages: Vec<Message> = self.history
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let request = ChatCompletionRequest {
messages,
temperature: None,
max_tokens: None,
tools: None,
};
let response = (*self.provider).chat(request).await
.map_err(|e| AgentError::LlmError(e.to_string()))?;
let assistant_message = ChatMessage::assistant(response.content);
self.history.push(assistant_message.clone());
Ok(assistant_message)
}
pub fn clear_history(&mut self) {
self.history.clear();
}
pub fn history(&self) -> &[ChatMessage] {
&self.history
}
}
#[derive(Debug)]
pub enum AgentError {
ProviderCreation(String),
LlmError(String),
}
impl std::fmt::Display for AgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
}
}
}
impl std::error::Error for AgentError {}

3
src/agent/mod.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod agent_loop;
pub use agent_loop::{AgentLoop, AgentError};

45
src/bus/message.rs Normal file
View File

@ -0,0 +1,45 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub id: String,
pub role: String,
pub content: String,
pub timestamp: i64,
}
impl ChatMessage {
pub fn user(content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "user".to_string(),
content: content.into(),
timestamp: current_timestamp(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
content: content.into(),
timestamp: current_timestamp(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "system".to_string(),
content: content.into(),
timestamp: current_timestamp(),
}
}
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}

44
src/bus/mod.rs Normal file
View File

@ -0,0 +1,44 @@
pub mod message;
pub use message::ChatMessage;
use tokio::sync::{mpsc, broadcast};
pub struct MessageBus {
user_tx: mpsc::Sender<ChatMessage>,
llm_tx: broadcast::Sender<ChatMessage>,
}
impl MessageBus {
pub fn new(buffer_size: usize) -> Self {
let (user_tx, _) = mpsc::channel(buffer_size);
let (llm_tx, _) = broadcast::channel(buffer_size);
Self { user_tx, llm_tx }
}
pub async fn send_user_input(&self, msg: ChatMessage) -> Result<(), BusError> {
self.user_tx.send(msg).await.map_err(|_| BusError::ChannelClosed)
}
pub fn send_llm_output(&self, msg: ChatMessage) -> Result<usize, BusError> {
self.llm_tx.send(msg).map_err(|_| BusError::ChannelClosed)
}
}
#[derive(Debug)]
pub enum BusError {
ChannelClosed,
SendError(usize),
}
impl std::fmt::Display for BusError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BusError::ChannelClosed => write!(f, "Channel closed"),
BusError::SendError(n) => write!(f, "Send error, {} receivers", n),
}
}
}
impl std::error::Error for BusError {}

50
src/cli/channel.rs Normal file
View File

@ -0,0 +1,50 @@
use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt};
pub struct CliChannel {
read: BufReader<tokio::io::Stdin>,
write: tokio::io::Stdout,
}
impl CliChannel {
pub fn new() -> Self {
Self {
read: BufReader::new(tokio::io::stdin()),
write: tokio::io::stdout(),
}
}
pub async fn read_line(&mut self, prompt: &str) -> Result<Option<String>, std::io::Error> {
print!("{}", prompt);
self.write.flush().await?;
let mut line = String::new();
let bytes_read = self.read.read_line(&mut line).await?;
if bytes_read == 0 {
return Ok(None);
}
Ok(Some(line.trim_end().to_string()))
}
pub async fn write_line(&mut self, content: &str) -> Result<(), std::io::Error> {
self.write.write_all(content.as_bytes()).await?;
self.write.write_all(b"\n").await?;
self.write.flush().await
}
pub async fn write_response(&mut self, content: &str) -> Result<(), std::io::Error> {
for line in content.lines() {
self.write.write_all(b" ").await?;
self.write.write_all(line.as_bytes()).await?;
self.write.write_all(b"\n").await?;
}
self.write.flush().await
}
}
impl Default for CliChannel {
fn default() -> Self {
Self::new()
}
}

70
src/cli/input.rs Normal file
View File

@ -0,0 +1,70 @@
use crate::bus::ChatMessage;
use super::channel::CliChannel;
pub struct InputHandler {
channel: CliChannel,
}
impl InputHandler {
pub fn new() -> Self {
Self {
channel: CliChannel::new(),
}
}
pub async fn read_input(&mut self, prompt: &str) -> Result<Option<ChatMessage>, InputError> {
match self.channel.read_line(prompt).await {
Ok(Some(line)) => {
if line.trim().is_empty() {
return Ok(None);
}
if let Some(cmd) = self.handle_special_commands(&line) {
return Ok(Some(cmd));
}
Ok(Some(ChatMessage::user(line)))
}
Ok(None) => Ok(None),
Err(e) => Err(InputError::IoError(e)),
}
}
pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> {
self.channel.write_line(content).await.map_err(InputError::IoError)
}
pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> {
self.channel.write_response(content).await.map_err(InputError::IoError)
}
fn handle_special_commands(&self, line: &str) -> Option<ChatMessage> {
match line.trim() {
"/quit" | "/exit" | "/q" => Some(ChatMessage::system("__EXIT__")),
"/clear" => Some(ChatMessage::system("__CLEAR__")),
_ => None,
}
}
}
impl Default for InputHandler {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub enum InputError {
IoError(std::io::Error),
}
impl std::fmt::Display for InputError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
InputError::IoError(e) => write!(f, "IO error: {}", e),
}
}
}
impl std::error::Error for InputError {}

5
src/cli/mod.rs Normal file
View File

@ -0,0 +1,5 @@
pub mod channel;
pub mod input;
pub use channel::CliChannel;
pub use input::InputHandler;

89
src/client/mod.rs Normal file
View File

@ -0,0 +1,89 @@
pub use crate::protocol::{WsInbound, WsOutbound, serialize_inbound, serialize_outbound};
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::cli::InputHandler;
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
serde_json::from_str(raw)
}
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
let (ws_stream, _) = connect_async(gateway_url).await?;
println!("Connected to gateway");
let (mut sender, mut receiver) = ws_stream.split();
let mut input = InputHandler::new();
input.write_output("picobot CLI - Type /quit to exit, /clear to clear history\n").await?;
// Main loop: poll both stdin and WebSocket
loop {
tokio::select! {
// Handle WebSocket messages
msg = receiver.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
let text = text.to_string();
if let Ok(outbound) = parse_message(&text) {
match outbound {
WsOutbound::AssistantResponse { content, .. } => {
input.write_response(&content).await?;
}
WsOutbound::Error { message, .. } => {
input.write_output(&format!("Error: {}", message)).await?;
}
WsOutbound::SessionEstablished { session_id } => {
input.write_output(&format!("Session: {}\n", session_id)).await?;
}
_ => {}
}
}
}
Some(Ok(Message::Close(_))) | None => {
input.write_output("Gateway disconnected").await?;
break;
}
_ => {}
}
}
// Handle stdin input
result = input.read_input("> ") => {
match result {
Ok(Some(msg)) => {
match msg.content.as_str() {
"__EXIT__" => {
input.write_output("Goodbye!").await?;
break;
}
"__CLEAR__" => {
let inbound = WsInbound::ClearHistory;
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
_ => {}
}
let inbound = WsInbound::UserInput { content: msg.content };
if let Ok(text) = serialize_inbound(&inbound) {
if sender.send(Message::Text(text.into())).await.is_err() {
eprintln!("Failed to send message");
break;
}
}
}
Ok(None) => break,
Err(e) => {
eprintln!("Input error: {}", e);
break;
}
}
}
}
}
Ok(())
}

View File

@ -1,12 +1,19 @@
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
pub providers: HashMap<String, ProviderConfig>,
pub models: HashMap<String, ModelConfig>,
pub agents: HashMap<String, AgentConfig>,
#[serde(default)]
pub gateway: GatewayConfig,
#[serde(default)]
pub client: ClientConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -36,6 +43,49 @@ pub struct AgentConfig {
pub model: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GatewayConfig {
#[serde(default = "default_gateway_host")]
pub host: String,
#[serde(default = "default_gateway_port")]
pub port: u16,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ClientConfig {
#[serde(default = "default_gateway_url")]
pub gateway_url: String,
}
fn default_gateway_host() -> String {
"127.0.0.1".to_string()
}
fn default_gateway_port() -> u16 {
19876
}
fn default_gateway_url() -> String {
"ws://127.0.0.1:19876/ws".to_string()
}
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
}
}
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
gateway_url: default_gateway_url(),
}
}
}
#[derive(Debug, Clone)]
pub struct LLMProviderConfig {
pub provider_type: String,
@ -49,9 +99,37 @@ pub struct LLMProviderConfig {
pub model_extra: HashMap<String, serde_json::Value>,
}
fn get_default_config_path() -> PathBuf {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
home.join(".config").join("picobot").join("config.json")
}
impl Config {
pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
let content = fs::read_to_string(path)?;
Self::load_from(Path::new(path))
}
pub fn load_default() -> Result<Self, Box<dyn std::error::Error>> {
let path = get_default_config_path();
Self::load_from(&path)
}
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
load_env_file()?;
let content = if path.exists() {
fs::read_to_string(path)?
} else {
// Fallback to current directory
let fallback = Path::new("config.json");
if fallback.exists() {
fs::read_to_string(fallback)?
} else {
return Err(Box::new(ConfigError::ConfigNotFound(
path.to_string_lossy().to_string(),
)));
}
};
let content = resolve_env_placeholders(&content);
let config: Config = serde_json::from_str(&content)?;
Ok(config)
}
@ -82,6 +160,7 @@ impl Config {
#[derive(Debug)]
pub enum ConfigError {
ConfigNotFound(String),
AgentNotFound(String),
ProviderNotFound(String),
ModelNotFound(String),
@ -90,6 +169,7 @@ pub enum ConfigError {
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.config/picobot/config.json", path),
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
@ -99,6 +179,37 @@ impl std::fmt::Display for ConfigError {
impl std::error::Error for ConfigError {}
fn load_env_file() -> Result<(), Box<dyn std::error::Error>> {
let env_path = Path::new(".env");
if env_path.exists() {
let content = fs::read_to_string(env_path)?;
for line in content.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
if let Some((key, value)) = line.split_once('=') {
let key = key.trim();
let value = value.trim().trim_matches('"').trim_matches('\'');
if !value.is_empty() {
// SAFETY: Setting environment variables for the current process
// is safe as we're only modifying our own process state
unsafe { env::set_var(key, value) };
}
}
}
}
Ok(())
}
fn resolve_env_placeholders(content: &str) -> String {
let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex");
re.replace_all(content, |caps: &regex::Captures| {
let var_name = &caps[1];
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
}).to_string()
}
#[cfg(test)]
mod tests {
use super::*;
@ -129,4 +240,11 @@ mod tests {
assert_eq!(provider_config.model_id, "qwen-plus");
assert_eq!(provider_config.temperature, Some(0.0));
}
#[test]
fn test_default_gateway_config() {
let config = Config::load("config.json").unwrap();
assert_eq!(config.gateway.host, "0.0.0.0");
assert_eq!(config.gateway.port, 19876);
}
}

15
src/gateway/http.rs Normal file
View File

@ -0,0 +1,15 @@
use axum::Json;
use serde::Serialize;
#[derive(Serialize)]
pub struct HealthResponse {
status: String,
version: String,
}
pub async fn health() -> Json<HealthResponse> {
Json(HealthResponse {
status: "ok".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
})
}

44
src/gateway/mod.rs Normal file
View File

@ -0,0 +1,44 @@
pub mod http;
pub mod session;
pub mod ws;
use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
use crate::config::Config;
use session::SessionManager;
pub struct GatewayState {
pub config: Config,
pub session_manager: SessionManager,
}
impl GatewayState {
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
let config = Config::load_default()?;
Ok(Self {
config,
session_manager: SessionManager::new(),
})
}
}
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
let state = Arc::new(GatewayState::new()?);
// CLI args override config file values
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
let bind_port = port.unwrap_or(state.config.gateway.port);
let app = Router::new()
.route("/health", routing::get(http::health))
.route("/ws", routing::get(ws::ws_handler))
.with_state(state);
let addr = format!("{}:{}", bind_host, bind_port);
let listener = TcpListener::bind(&addr).await?;
println!("Gateway listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}

67
src/gateway/session.rs Normal file
View File

@ -0,0 +1,67 @@
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::config::LLMProviderConfig;
use crate::agent::AgentLoop;
use crate::protocol::WsOutbound;
pub struct Session {
pub id: Uuid,
pub agent_loop: Arc<Mutex<AgentLoop>>,
pub user_tx: mpsc::Sender<WsOutbound>,
}
impl Session {
pub async fn new(
provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
) -> Result<Self, crate::agent::AgentError> {
let agent_loop = AgentLoop::new(provider_config)?;
Ok(Self {
id: Uuid::new_v4(),
agent_loop: Arc::new(Mutex::new(agent_loop)),
user_tx,
})
}
pub async fn send(&self, msg: WsOutbound) {
let _ = self.user_tx.send(msg).await;
}
}
use std::collections::HashMap;
use std::sync::RwLock;
pub struct SessionManager {
sessions: RwLock<HashMap<Uuid, Arc<Session>>>,
}
impl SessionManager {
pub fn new() -> Self {
Self {
sessions: RwLock::new(HashMap::new()),
}
}
pub fn add(&self, session: Arc<Session>) {
self.sessions.write().unwrap().insert(session.id, session);
}
pub fn remove(&self, id: &Uuid) {
self.sessions.write().unwrap().remove(id);
}
pub fn get(&self, id: &Uuid) -> Option<Arc<Session>> {
self.sessions.read().unwrap().get(id).cloned()
}
pub fn len(&self) -> usize {
self.sessions.read().unwrap().len()
}
}
impl Default for SessionManager {
fn default() -> Self {
Self::new()
}
}

116
src/gateway/ws.rs Normal file
View File

@ -0,0 +1,116 @@
use std::sync::Arc;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
use super::{GatewayState, session::Session};
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async {
handle_socket(socket, state).await;
})
}
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let provider_config = match state.config.get_provider_config("default") {
Ok(cfg) => cfg,
Err(e) => {
eprintln!("Failed to get provider config: {}", e);
return;
}
};
let session = match Session::new(provider_config, sender).await {
Ok(s) => Arc::new(s),
Err(e) => {
eprintln!("Failed to create session: {}", e);
return;
}
};
let session_id = session.id;
state.session_manager.add(session.clone());
let _ = session.send(WsOutbound::SessionEstablished {
session_id: session_id.to_string(),
}).await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
break;
}
}
}
});
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let text = text.to_string();
match parse_inbound(&text) {
Ok(inbound) => {
handle_inbound(&session, inbound).await;
}
Err(e) => {
let _ = session.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
}).await;
}
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
break;
}
_ => {}
}
}
state.session_manager.remove(&session_id);
}
async fn handle_inbound(session: &Arc<Session>, inbound: WsInbound) {
match inbound {
WsInbound::UserInput { content } => {
let user_msg = ChatMessage::user(content);
let mut agent = session.agent_loop.lock().await;
match agent.process(user_msg).await {
Ok(response) => {
let _ = session.send(WsOutbound::AssistantResponse {
id: response.id,
content: response.content,
role: response.role,
}).await;
}
Err(e) => {
let _ = session.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: e.to_string(),
}).await;
}
}
}
WsInbound::ClearHistory => {
let mut agent = session.agent_loop.lock().await;
agent.clear_history();
let _ = session.send(WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
content: "History cleared.".to_string(),
role: "system".to_string(),
}).await;
}
WsInbound::Ping => {
let _ = session.send(WsOutbound::Pong).await;
}
}
}

View File

@ -1,2 +1,8 @@
pub mod config;
pub mod providers;
pub mod bus;
pub mod cli;
pub mod agent;
pub mod gateway;
pub mod client;
pub mod protocol;

View File

@ -1,38 +1,57 @@
mod config;
mod providers;
use clap::{Parser, CommandFactory};
use config::Config;
use providers::{create_provider, ChatCompletionRequest, Message};
#[derive(Parser)]
#[command(name = "picobot")]
#[command(about = "A CLI chatbot", long_about = None)]
enum Command {
/// Connect to gateway
Agent {
/// Gateway WebSocket URL (e.g., ws://127.0.0.1:19876/ws)
#[arg(long)]
gateway_url: Option<String>,
},
/// Start gateway server
Gateway {
/// Host to bind to
#[arg(long)]
host: Option<String>,
/// Port to listen on
#[arg(long)]
port: Option<u16>,
},
}
#[tokio::main]
async fn main() {
// Load config
let config = Config::load("config.json").expect("Failed to load config.json");
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut cmd = Command::command();
// Get provider config for "default" agent
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
// If no arguments, print help
if std::env::args().len() <= 1 {
cmd.print_help()?;
println!();
return Ok(());
}
// Create provider
let provider = create_provider(provider_config).expect("Failed to create provider");
println!("Provider type: {}", provider.ptype());
println!("Provider name: {}", provider.name());
println!("Model ID: {}", provider.model_id());
// Create request (no model ID needed - it's baked into the provider)
let request = ChatCompletionRequest {
messages: vec![Message {
role: "user".to_string(),
content: "Hello!".to_string(),
}],
temperature: None, // Will use config default if not provided
max_tokens: None, // Will use config default if not provided
tools: None,
// Load config for defaults
let config = match picobot::config::Config::load_default() {
Ok(cfg) => Some(cfg),
Err(e) => {
eprintln!("Warning: Could not load config from ~/.config/picobot/config.json: {}", e);
eprintln!("Using built-in defaults. Run `picobot gateway` or `picobot agent` with --help for options.\n");
None
}
};
// Example usage:
// match provider.chat(request).await {
// Ok(resp) => println!("Response: {}", resp.content),
// Err(e) => eprintln!("Error: {}", e),
// }
match Command::parse() {
Command::Agent { gateway_url } => {
let url = gateway_url
.or_else(|| config.as_ref().map(|c| c.client.gateway_url.clone()))
.unwrap_or_else(|| "ws://127.0.0.1:19876/ws".to_string());
picobot::client::run(&url).await?;
}
Command::Gateway { host, port } => {
picobot::gateway::run(host, port).await?;
}
}
Ok(())
}

37
src/protocol.rs Normal file
View File

@ -0,0 +1,37 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WsInbound {
#[serde(rename = "user_input")]
UserInput { content: String },
#[serde(rename = "clear_history")]
ClearHistory,
#[serde(rename = "ping")]
Ping,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WsOutbound {
#[serde(rename = "assistant_response")]
AssistantResponse { id: String, content: String, role: String },
#[serde(rename = "error")]
Error { code: String, message: String },
#[serde(rename = "session_established")]
SessionEstablished { session_id: String },
#[serde(rename = "pong")]
Pong,
}
pub fn parse_inbound(raw: &str) -> Result<WsInbound, serde_json::Error> {
serde_json::from_str(raw)
}
pub fn serialize_inbound(msg: &WsInbound) -> Result<String, serde_json::Error> {
serde_json::to_string(msg)
}
pub fn serialize_outbound(msg: &WsOutbound) -> Result<String, serde_json::Error> {
serde_json::to_string(msg)
}

View File

@ -134,7 +134,15 @@ impl LLMProvider for OpenAIProvider {
let resp = req_builder.json(&body).send().await?;
let openai_resp: OpenAIResponse = resp.json().await?;
let status = resp.status();
let text = resp.text().await?;
if !status.is_success() {
return Err(format!("API error {}: {}", status, text).into());
}
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
let content = openai_resp.choices[0]
.message