重构项目结构,添加代理、网关和客户端模块,更新配置以支持默认网关设置,增强错误处理,添加 WebSocket 支持
This commit is contained in:
parent
8b1e6e7e06
commit
35d201f206
@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "PicoBot"
|
||||
name = "picobot"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
@ -7,7 +7,14 @@ edition = "2024"
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||
dotenv = "0.15"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
regex = "1.0"
|
||||
serde_json = "1.0"
|
||||
async-trait = "0.1"
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
axum = { version = "0.8", features = ["ws"] }
|
||||
tokio-tungstenite = "0.26"
|
||||
futures-util = "0.3"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
dirs = "5"
|
||||
|
||||
72
src/agent/agent_loop.rs
Normal file
72
src/agent/agent_loop.rs
Normal file
@ -0,0 +1,72 @@
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
|
||||
|
||||
pub struct AgentLoop {
|
||||
provider: Box<dyn LLMProvider>,
|
||||
history: Vec<ChatMessage>,
|
||||
}
|
||||
|
||||
impl AgentLoop {
|
||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||
let provider = create_provider(provider_config)
|
||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
history: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
|
||||
self.history.push(user_message.clone());
|
||||
|
||||
let messages: Vec<Message> = self.history
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = (*self.provider).chat(request).await
|
||||
.map_err(|e| AgentError::LlmError(e.to_string()))?;
|
||||
|
||||
let assistant_message = ChatMessage::assistant(response.content);
|
||||
self.history.push(assistant_message.clone());
|
||||
|
||||
Ok(assistant_message)
|
||||
}
|
||||
|
||||
pub fn clear_history(&mut self) {
|
||||
self.history.clear();
|
||||
}
|
||||
|
||||
pub fn history(&self) -> &[ChatMessage] {
|
||||
&self.history
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AgentError {
|
||||
ProviderCreation(String),
|
||||
LlmError(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AgentError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
|
||||
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for AgentError {}
|
||||
3
src/agent/mod.rs
Normal file
3
src/agent/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod agent_loop;
|
||||
|
||||
pub use agent_loop::{AgentLoop, AgentError};
|
||||
45
src/bus/message.rs
Normal file
45
src/bus/message.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub id: String,
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub timestamp: i64,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "user".to_string(),
|
||||
content: content.into(),
|
||||
timestamp: current_timestamp(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: content.into(),
|
||||
timestamp: current_timestamp(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "system".to_string(),
|
||||
content: content.into(),
|
||||
timestamp: current_timestamp(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as i64
|
||||
}
|
||||
44
src/bus/mod.rs
Normal file
44
src/bus/mod.rs
Normal file
@ -0,0 +1,44 @@
|
||||
pub mod message;
|
||||
|
||||
pub use message::ChatMessage;
|
||||
|
||||
use tokio::sync::{mpsc, broadcast};
|
||||
|
||||
pub struct MessageBus {
|
||||
user_tx: mpsc::Sender<ChatMessage>,
|
||||
llm_tx: broadcast::Sender<ChatMessage>,
|
||||
}
|
||||
|
||||
impl MessageBus {
|
||||
pub fn new(buffer_size: usize) -> Self {
|
||||
let (user_tx, _) = mpsc::channel(buffer_size);
|
||||
let (llm_tx, _) = broadcast::channel(buffer_size);
|
||||
|
||||
Self { user_tx, llm_tx }
|
||||
}
|
||||
|
||||
pub async fn send_user_input(&self, msg: ChatMessage) -> Result<(), BusError> {
|
||||
self.user_tx.send(msg).await.map_err(|_| BusError::ChannelClosed)
|
||||
}
|
||||
|
||||
pub fn send_llm_output(&self, msg: ChatMessage) -> Result<usize, BusError> {
|
||||
self.llm_tx.send(msg).map_err(|_| BusError::ChannelClosed)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BusError {
|
||||
ChannelClosed,
|
||||
SendError(usize),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BusError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
BusError::ChannelClosed => write!(f, "Channel closed"),
|
||||
BusError::SendError(n) => write!(f, "Send error, {} receivers", n),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for BusError {}
|
||||
50
src/cli/channel.rs
Normal file
50
src/cli/channel.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt};
|
||||
|
||||
pub struct CliChannel {
|
||||
read: BufReader<tokio::io::Stdin>,
|
||||
write: tokio::io::Stdout,
|
||||
}
|
||||
|
||||
impl CliChannel {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
read: BufReader::new(tokio::io::stdin()),
|
||||
write: tokio::io::stdout(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_line(&mut self, prompt: &str) -> Result<Option<String>, std::io::Error> {
|
||||
print!("{}", prompt);
|
||||
self.write.flush().await?;
|
||||
|
||||
let mut line = String::new();
|
||||
let bytes_read = self.read.read_line(&mut line).await?;
|
||||
|
||||
if bytes_read == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(line.trim_end().to_string()))
|
||||
}
|
||||
|
||||
pub async fn write_line(&mut self, content: &str) -> Result<(), std::io::Error> {
|
||||
self.write.write_all(content.as_bytes()).await?;
|
||||
self.write.write_all(b"\n").await?;
|
||||
self.write.flush().await
|
||||
}
|
||||
|
||||
pub async fn write_response(&mut self, content: &str) -> Result<(), std::io::Error> {
|
||||
for line in content.lines() {
|
||||
self.write.write_all(b" ").await?;
|
||||
self.write.write_all(line.as_bytes()).await?;
|
||||
self.write.write_all(b"\n").await?;
|
||||
}
|
||||
self.write.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CliChannel {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
70
src/cli/input.rs
Normal file
70
src/cli/input.rs
Normal file
@ -0,0 +1,70 @@
|
||||
use crate::bus::ChatMessage;
|
||||
|
||||
use super::channel::CliChannel;
|
||||
|
||||
pub struct InputHandler {
|
||||
channel: CliChannel,
|
||||
}
|
||||
|
||||
impl InputHandler {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
channel: CliChannel::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_input(&mut self, prompt: &str) -> Result<Option<ChatMessage>, InputError> {
|
||||
match self.channel.read_line(prompt).await {
|
||||
Ok(Some(line)) => {
|
||||
if line.trim().is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(cmd) = self.handle_special_commands(&line) {
|
||||
return Ok(Some(cmd));
|
||||
}
|
||||
|
||||
Ok(Some(ChatMessage::user(line)))
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(e) => Err(InputError::IoError(e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> {
|
||||
self.channel.write_line(content).await.map_err(InputError::IoError)
|
||||
}
|
||||
|
||||
pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> {
|
||||
self.channel.write_response(content).await.map_err(InputError::IoError)
|
||||
}
|
||||
|
||||
fn handle_special_commands(&self, line: &str) -> Option<ChatMessage> {
|
||||
match line.trim() {
|
||||
"/quit" | "/exit" | "/q" => Some(ChatMessage::system("__EXIT__")),
|
||||
"/clear" => Some(ChatMessage::system("__CLEAR__")),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InputHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum InputError {
|
||||
IoError(std::io::Error),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for InputError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
InputError::IoError(e) => write!(f, "IO error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for InputError {}
|
||||
5
src/cli/mod.rs
Normal file
5
src/cli/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
pub mod channel;
|
||||
pub mod input;
|
||||
|
||||
pub use channel::CliChannel;
|
||||
pub use input::InputHandler;
|
||||
89
src/client/mod.rs
Normal file
89
src/client/mod.rs
Normal file
@ -0,0 +1,89 @@
|
||||
pub use crate::protocol::{WsInbound, WsOutbound, serialize_inbound, serialize_outbound};
|
||||
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
use crate::cli::InputHandler;
|
||||
|
||||
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
||||
serde_json::from_str(raw)
|
||||
}
|
||||
|
||||
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (ws_stream, _) = connect_async(gateway_url).await?;
|
||||
println!("Connected to gateway");
|
||||
|
||||
let (mut sender, mut receiver) = ws_stream.split();
|
||||
|
||||
let mut input = InputHandler::new();
|
||||
input.write_output("picobot CLI - Type /quit to exit, /clear to clear history\n").await?;
|
||||
|
||||
// Main loop: poll both stdin and WebSocket
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Handle WebSocket messages
|
||||
msg = receiver.next() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
let text = text.to_string();
|
||||
if let Ok(outbound) = parse_message(&text) {
|
||||
match outbound {
|
||||
WsOutbound::AssistantResponse { content, .. } => {
|
||||
input.write_response(&content).await?;
|
||||
}
|
||||
WsOutbound::Error { message, .. } => {
|
||||
input.write_output(&format!("Error: {}", message)).await?;
|
||||
}
|
||||
WsOutbound::SessionEstablished { session_id } => {
|
||||
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
input.write_output("Gateway disconnected").await?;
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
// Handle stdin input
|
||||
result = input.read_input("> ") => {
|
||||
match result {
|
||||
Ok(Some(msg)) => {
|
||||
match msg.content.as_str() {
|
||||
"__EXIT__" => {
|
||||
input.write_output("Goodbye!").await?;
|
||||
break;
|
||||
}
|
||||
"__CLEAR__" => {
|
||||
let inbound = WsInbound::ClearHistory;
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let inbound = WsInbound::UserInput { content: msg.content };
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
if sender.send(Message::Text(text.into())).await.is_err() {
|
||||
eprintln!("Failed to send message");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(e) => {
|
||||
eprintln!("Input error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -1,12 +1,19 @@
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Config {
|
||||
pub providers: HashMap<String, ProviderConfig>,
|
||||
pub models: HashMap<String, ModelConfig>,
|
||||
pub agents: HashMap<String, AgentConfig>,
|
||||
#[serde(default)]
|
||||
pub gateway: GatewayConfig,
|
||||
#[serde(default)]
|
||||
pub client: ClientConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -36,6 +43,49 @@ pub struct AgentConfig {
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GatewayConfig {
|
||||
#[serde(default = "default_gateway_host")]
|
||||
pub host: String,
|
||||
#[serde(default = "default_gateway_port")]
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ClientConfig {
|
||||
#[serde(default = "default_gateway_url")]
|
||||
pub gateway_url: String,
|
||||
}
|
||||
|
||||
fn default_gateway_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_gateway_port() -> u16 {
|
||||
19876
|
||||
}
|
||||
|
||||
fn default_gateway_url() -> String {
|
||||
"ws://127.0.0.1:19876/ws".to_string()
|
||||
}
|
||||
|
||||
impl Default for GatewayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_gateway_host(),
|
||||
port: default_gateway_port(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gateway_url: default_gateway_url(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LLMProviderConfig {
|
||||
pub provider_type: String,
|
||||
@ -49,9 +99,37 @@ pub struct LLMProviderConfig {
|
||||
pub model_extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
fn get_default_config_path() -> PathBuf {
|
||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
home.join(".config").join("picobot").join("config.json")
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let content = fs::read_to_string(path)?;
|
||||
Self::load_from(Path::new(path))
|
||||
}
|
||||
|
||||
pub fn load_default() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let path = get_default_config_path();
|
||||
Self::load_from(&path)
|
||||
}
|
||||
|
||||
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
load_env_file()?;
|
||||
let content = if path.exists() {
|
||||
fs::read_to_string(path)?
|
||||
} else {
|
||||
// Fallback to current directory
|
||||
let fallback = Path::new("config.json");
|
||||
if fallback.exists() {
|
||||
fs::read_to_string(fallback)?
|
||||
} else {
|
||||
return Err(Box::new(ConfigError::ConfigNotFound(
|
||||
path.to_string_lossy().to_string(),
|
||||
)));
|
||||
}
|
||||
};
|
||||
let content = resolve_env_placeholders(&content);
|
||||
let config: Config = serde_json::from_str(&content)?;
|
||||
Ok(config)
|
||||
}
|
||||
@ -82,6 +160,7 @@ impl Config {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ConfigError {
|
||||
ConfigNotFound(String),
|
||||
AgentNotFound(String),
|
||||
ProviderNotFound(String),
|
||||
ModelNotFound(String),
|
||||
@ -90,6 +169,7 @@ pub enum ConfigError {
|
||||
impl std::fmt::Display for ConfigError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.config/picobot/config.json", path),
|
||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||
@ -99,6 +179,37 @@ impl std::fmt::Display for ConfigError {
|
||||
|
||||
impl std::error::Error for ConfigError {}
|
||||
|
||||
fn load_env_file() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let env_path = Path::new(".env");
|
||||
if env_path.exists() {
|
||||
let content = fs::read_to_string(env_path)?;
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
let key = key.trim();
|
||||
let value = value.trim().trim_matches('"').trim_matches('\'');
|
||||
if !value.is_empty() {
|
||||
// SAFETY: Setting environment variables for the current process
|
||||
// is safe as we're only modifying our own process state
|
||||
unsafe { env::set_var(key, value) };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn resolve_env_placeholders(content: &str) -> String {
|
||||
let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex");
|
||||
re.replace_all(content, |caps: ®ex::Captures| {
|
||||
let var_name = &caps[1];
|
||||
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
||||
}).to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -129,4 +240,11 @@ mod tests {
|
||||
assert_eq!(provider_config.model_id, "qwen-plus");
|
||||
assert_eq!(provider_config.temperature, Some(0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_gateway_config() {
|
||||
let config = Config::load("config.json").unwrap();
|
||||
assert_eq!(config.gateway.host, "0.0.0.0");
|
||||
assert_eq!(config.gateway.port, 19876);
|
||||
}
|
||||
}
|
||||
|
||||
15
src/gateway/http.rs
Normal file
15
src/gateway/http.rs
Normal file
@ -0,0 +1,15 @@
|
||||
use axum::Json;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct HealthResponse {
|
||||
status: String,
|
||||
version: String,
|
||||
}
|
||||
|
||||
pub async fn health() -> Json<HealthResponse> {
|
||||
Json(HealthResponse {
|
||||
status: "ok".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
})
|
||||
}
|
||||
44
src/gateway/mod.rs
Normal file
44
src/gateway/mod.rs
Normal file
@ -0,0 +1,44 @@
|
||||
pub mod http;
|
||||
pub mod session;
|
||||
pub mod ws;
|
||||
|
||||
use std::sync::Arc;
|
||||
use axum::{routing, Router};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use crate::config::Config;
|
||||
use session::SessionManager;
|
||||
|
||||
pub struct GatewayState {
|
||||
pub config: Config,
|
||||
pub session_manager: SessionManager,
|
||||
}
|
||||
|
||||
impl GatewayState {
|
||||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let config = Config::load_default()?;
|
||||
Ok(Self {
|
||||
config,
|
||||
session_manager: SessionManager::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let state = Arc::new(GatewayState::new()?);
|
||||
|
||||
// CLI args override config file values
|
||||
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
|
||||
let bind_port = port.unwrap_or(state.config.gateway.port);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", routing::get(http::health))
|
||||
.route("/ws", routing::get(ws::ws_handler))
|
||||
.with_state(state);
|
||||
|
||||
let addr = format!("{}:{}", bind_host, bind_port);
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
println!("Gateway listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
Ok(())
|
||||
}
|
||||
67
src/gateway/session.rs
Normal file
67
src/gateway/session.rs
Normal file
@ -0,0 +1,67 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::agent::AgentLoop;
|
||||
use crate::protocol::WsOutbound;
|
||||
|
||||
pub struct Session {
|
||||
pub id: Uuid,
|
||||
pub agent_loop: Arc<Mutex<AgentLoop>>,
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn new(
|
||||
provider_config: LLMProviderConfig,
|
||||
user_tx: mpsc::Sender<WsOutbound>,
|
||||
) -> Result<Self, crate::agent::AgentError> {
|
||||
let agent_loop = AgentLoop::new(provider_config)?;
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
agent_loop: Arc::new(Mutex::new(agent_loop)),
|
||||
user_tx,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn send(&self, msg: WsOutbound) {
|
||||
let _ = self.user_tx.send(msg).await;
|
||||
}
|
||||
}
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
pub struct SessionManager {
|
||||
sessions: RwLock<HashMap<Uuid, Arc<Session>>>,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&self, session: Arc<Session>) {
|
||||
self.sessions.write().unwrap().insert(session.id, session);
|
||||
}
|
||||
|
||||
pub fn remove(&self, id: &Uuid) {
|
||||
self.sessions.write().unwrap().remove(id);
|
||||
}
|
||||
|
||||
pub fn get(&self, id: &Uuid) -> Option<Arc<Session>> {
|
||||
self.sessions.read().unwrap().get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.sessions.read().unwrap().len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
116
src/gateway/ws.rs
Normal file
116
src/gateway/ws.rs
Normal file
@ -0,0 +1,116 @@
|
||||
use std::sync::Arc;
|
||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||||
use axum::extract::State;
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio::sync::mpsc;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
|
||||
use super::{GatewayState, session::Session};
|
||||
|
||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||
ws.on_upgrade(|socket| async {
|
||||
handle_socket(socket, state).await;
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
||||
|
||||
let provider_config = match state.config.get_provider_config("default") {
|
||||
Ok(cfg) => cfg,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to get provider config: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let session = match Session::new(provider_config, sender).await {
|
||||
Ok(s) => Arc::new(s),
|
||||
Err(e) => {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let session_id = session.id;
|
||||
state.session_manager.add(session.clone());
|
||||
|
||||
let _ = session.send(WsOutbound::SessionEstablished {
|
||||
session_id: session_id.to_string(),
|
||||
}).await;
|
||||
|
||||
let (mut ws_sender, mut ws_receiver) = ws.split();
|
||||
|
||||
let mut receiver = receiver;
|
||||
tokio::spawn(async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
if let Ok(text) = serialize_outbound(&msg) {
|
||||
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
while let Some(msg) = ws_receiver.next().await {
|
||||
match msg {
|
||||
Ok(WsMessage::Text(text)) => {
|
||||
let text = text.to_string();
|
||||
match parse_inbound(&text) {
|
||||
Ok(inbound) => {
|
||||
handle_inbound(&session, inbound).await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = session.send(WsOutbound::Error {
|
||||
code: "PARSE_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
}).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(WsMessage::Close(_)) | Err(_) => {
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
state.session_manager.remove(&session_id);
|
||||
}
|
||||
|
||||
async fn handle_inbound(session: &Arc<Session>, inbound: WsInbound) {
|
||||
match inbound {
|
||||
WsInbound::UserInput { content } => {
|
||||
let user_msg = ChatMessage::user(content);
|
||||
let mut agent = session.agent_loop.lock().await;
|
||||
match agent.process(user_msg).await {
|
||||
Ok(response) => {
|
||||
let _ = session.send(WsOutbound::AssistantResponse {
|
||||
id: response.id,
|
||||
content: response.content,
|
||||
role: response.role,
|
||||
}).await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = session.send(WsOutbound::Error {
|
||||
code: "LLM_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
}).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
WsInbound::ClearHistory => {
|
||||
let mut agent = session.agent_loop.lock().await;
|
||||
agent.clear_history();
|
||||
let _ = session.send(WsOutbound::AssistantResponse {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
content: "History cleared.".to_string(),
|
||||
role: "system".to_string(),
|
||||
}).await;
|
||||
}
|
||||
WsInbound::Ping => {
|
||||
let _ = session.send(WsOutbound::Pong).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,2 +1,8 @@
|
||||
pub mod config;
|
||||
pub mod providers;
|
||||
pub mod bus;
|
||||
pub mod cli;
|
||||
pub mod agent;
|
||||
pub mod gateway;
|
||||
pub mod client;
|
||||
pub mod protocol;
|
||||
|
||||
79
src/main.rs
79
src/main.rs
@ -1,38 +1,57 @@
|
||||
mod config;
|
||||
mod providers;
|
||||
use clap::{Parser, CommandFactory};
|
||||
|
||||
use config::Config;
|
||||
use providers::{create_provider, ChatCompletionRequest, Message};
|
||||
#[derive(Parser)]
|
||||
#[command(name = "picobot")]
|
||||
#[command(about = "A CLI chatbot", long_about = None)]
|
||||
enum Command {
|
||||
/// Connect to gateway
|
||||
Agent {
|
||||
/// Gateway WebSocket URL (e.g., ws://127.0.0.1:19876/ws)
|
||||
#[arg(long)]
|
||||
gateway_url: Option<String>,
|
||||
},
|
||||
/// Start gateway server
|
||||
Gateway {
|
||||
/// Host to bind to
|
||||
#[arg(long)]
|
||||
host: Option<String>,
|
||||
/// Port to listen on
|
||||
#[arg(long)]
|
||||
port: Option<u16>,
|
||||
},
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Load config
|
||||
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut cmd = Command::command();
|
||||
|
||||
// Get provider config for "default" agent
|
||||
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
||||
// If no arguments, print help
|
||||
if std::env::args().len() <= 1 {
|
||||
cmd.print_help()?;
|
||||
println!();
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create provider
|
||||
let provider = create_provider(provider_config).expect("Failed to create provider");
|
||||
|
||||
println!("Provider type: {}", provider.ptype());
|
||||
println!("Provider name: {}", provider.name());
|
||||
println!("Model ID: {}", provider.model_id());
|
||||
|
||||
// Create request (no model ID needed - it's baked into the provider)
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hello!".to_string(),
|
||||
}],
|
||||
temperature: None, // Will use config default if not provided
|
||||
max_tokens: None, // Will use config default if not provided
|
||||
tools: None,
|
||||
// Load config for defaults
|
||||
let config = match picobot::config::Config::load_default() {
|
||||
Ok(cfg) => Some(cfg),
|
||||
Err(e) => {
|
||||
eprintln!("Warning: Could not load config from ~/.config/picobot/config.json: {}", e);
|
||||
eprintln!("Using built-in defaults. Run `picobot gateway` or `picobot agent` with --help for options.\n");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Example usage:
|
||||
// match provider.chat(request).await {
|
||||
// Ok(resp) => println!("Response: {}", resp.content),
|
||||
// Err(e) => eprintln!("Error: {}", e),
|
||||
// }
|
||||
match Command::parse() {
|
||||
Command::Agent { gateway_url } => {
|
||||
let url = gateway_url
|
||||
.or_else(|| config.as_ref().map(|c| c.client.gateway_url.clone()))
|
||||
.unwrap_or_else(|| "ws://127.0.0.1:19876/ws".to_string());
|
||||
picobot::client::run(&url).await?;
|
||||
}
|
||||
Command::Gateway { host, port } => {
|
||||
picobot::gateway::run(host, port).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
37
src/protocol.rs
Normal file
37
src/protocol.rs
Normal file
@ -0,0 +1,37 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum WsInbound {
|
||||
#[serde(rename = "user_input")]
|
||||
UserInput { content: String },
|
||||
#[serde(rename = "clear_history")]
|
||||
ClearHistory,
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum WsOutbound {
|
||||
#[serde(rename = "assistant_response")]
|
||||
AssistantResponse { id: String, content: String, role: String },
|
||||
#[serde(rename = "error")]
|
||||
Error { code: String, message: String },
|
||||
#[serde(rename = "session_established")]
|
||||
SessionEstablished { session_id: String },
|
||||
#[serde(rename = "pong")]
|
||||
Pong,
|
||||
}
|
||||
|
||||
pub fn parse_inbound(raw: &str) -> Result<WsInbound, serde_json::Error> {
|
||||
serde_json::from_str(raw)
|
||||
}
|
||||
|
||||
pub fn serialize_inbound(msg: &WsInbound) -> Result<String, serde_json::Error> {
|
||||
serde_json::to_string(msg)
|
||||
}
|
||||
|
||||
pub fn serialize_outbound(msg: &WsOutbound) -> Result<String, serde_json::Error> {
|
||||
serde_json::to_string(msg)
|
||||
}
|
||||
@ -134,7 +134,15 @@ impl LLMProvider for OpenAIProvider {
|
||||
|
||||
let resp = req_builder.json(&body).send().await?;
|
||||
|
||||
let openai_resp: OpenAIResponse = resp.json().await?;
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(format!("API error {}: {}", status, text).into());
|
||||
}
|
||||
|
||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
||||
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
|
||||
|
||||
let content = openai_resp.choices[0]
|
||||
.message
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user