Compare commits
No commits in common. "04736f9f46217e1ece03d5097e70a09cecc55448" and "5dc13ea7ceff45001c261808210752f7957ef31c" have entirely different histories.
04736f9f46
...
5dc13ea7ce
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +0,0 @@
|
|||||||
/target
|
|
||||||
21
Cargo.toml
21
Cargo.toml
@ -1,21 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "picobot"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2024"
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls"] }
|
|
||||||
dotenv = "0.15"
|
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
|
||||||
regex = "1.0"
|
|
||||||
serde_json = "1.0"
|
|
||||||
async-trait = "0.1"
|
|
||||||
thiserror = "2.0.18"
|
|
||||||
tokio = { version = "1.0", features = ["full"] }
|
|
||||||
uuid = { version = "1.0", features = ["v4"] }
|
|
||||||
axum = { version = "0.8", features = ["ws"] }
|
|
||||||
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }
|
|
||||||
futures-util = "0.3"
|
|
||||||
clap = { version = "4", features = ["derive"] }
|
|
||||||
dirs = "6.0.0"
|
|
||||||
prost = "0.14"
|
|
||||||
@ -1,72 +0,0 @@
|
|||||||
use crate::bus::ChatMessage;
|
|
||||||
use crate::config::LLMProviderConfig;
|
|
||||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
|
|
||||||
|
|
||||||
pub struct AgentLoop {
|
|
||||||
provider: Box<dyn LLMProvider>,
|
|
||||||
history: Vec<ChatMessage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AgentLoop {
|
|
||||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
|
||||||
let provider = create_provider(provider_config)
|
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
provider,
|
|
||||||
history: Vec::new(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
|
|
||||||
self.history.push(user_message.clone());
|
|
||||||
|
|
||||||
let messages: Vec<Message> = self.history
|
|
||||||
.iter()
|
|
||||||
.map(|m| Message {
|
|
||||||
role: m.role.clone(),
|
|
||||||
content: m.content.clone(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages,
|
|
||||||
temperature: None,
|
|
||||||
max_tokens: None,
|
|
||||||
tools: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = (*self.provider).chat(request).await
|
|
||||||
.map_err(|e| AgentError::LlmError(e.to_string()))?;
|
|
||||||
|
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
|
||||||
self.history.push(assistant_message.clone());
|
|
||||||
|
|
||||||
Ok(assistant_message)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn clear_history(&mut self) {
|
|
||||||
self.history.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn history(&self) -> &[ChatMessage] {
|
|
||||||
&self.history
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum AgentError {
|
|
||||||
ProviderCreation(String),
|
|
||||||
LlmError(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for AgentError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
|
|
||||||
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for AgentError {}
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
pub mod agent_loop;
|
|
||||||
|
|
||||||
pub use agent_loop::{AgentLoop, AgentError};
|
|
||||||
@ -1,45 +0,0 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ChatMessage {
|
|
||||||
pub id: String,
|
|
||||||
pub role: String,
|
|
||||||
pub content: String,
|
|
||||||
pub timestamp: i64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatMessage {
|
|
||||||
pub fn user(content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: content.into(),
|
|
||||||
timestamp: current_timestamp(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn assistant(content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: content.into(),
|
|
||||||
timestamp: current_timestamp(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn system(content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
|
||||||
role: "system".to_string(),
|
|
||||||
content: content.into(),
|
|
||||||
timestamp: current_timestamp(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn current_timestamp() -> i64 {
|
|
||||||
std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.unwrap()
|
|
||||||
.as_millis() as i64
|
|
||||||
}
|
|
||||||
@ -1,44 +0,0 @@
|
|||||||
pub mod message;
|
|
||||||
|
|
||||||
pub use message::ChatMessage;
|
|
||||||
|
|
||||||
use tokio::sync::{mpsc, broadcast};
|
|
||||||
|
|
||||||
pub struct MessageBus {
|
|
||||||
user_tx: mpsc::Sender<ChatMessage>,
|
|
||||||
llm_tx: broadcast::Sender<ChatMessage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MessageBus {
|
|
||||||
pub fn new(buffer_size: usize) -> Self {
|
|
||||||
let (user_tx, _) = mpsc::channel(buffer_size);
|
|
||||||
let (llm_tx, _) = broadcast::channel(buffer_size);
|
|
||||||
|
|
||||||
Self { user_tx, llm_tx }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send_user_input(&self, msg: ChatMessage) -> Result<(), BusError> {
|
|
||||||
self.user_tx.send(msg).await.map_err(|_| BusError::ChannelClosed)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_llm_output(&self, msg: ChatMessage) -> Result<usize, BusError> {
|
|
||||||
self.llm_tx.send(msg).map_err(|_| BusError::ChannelClosed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum BusError {
|
|
||||||
ChannelClosed,
|
|
||||||
SendError(usize),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for BusError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
BusError::ChannelClosed => write!(f, "Channel closed"),
|
|
||||||
BusError::SendError(n) => write!(f, "Send error, {} receivers", n),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for BusError {}
|
|
||||||
@ -1,50 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct InboundMessage {
|
|
||||||
pub channel: String,
|
|
||||||
pub sender_id: String,
|
|
||||||
pub chat_id: String,
|
|
||||||
pub content: String,
|
|
||||||
pub media: Vec<String>,
|
|
||||||
pub metadata: HashMap<String, String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct OutboundMessage {
|
|
||||||
pub channel: String,
|
|
||||||
pub chat_id: String,
|
|
||||||
pub content: String,
|
|
||||||
pub media: Vec<String>,
|
|
||||||
pub metadata: HashMap<String, String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum ChannelError {
|
|
||||||
ConfigError(String),
|
|
||||||
ConnectionError(String),
|
|
||||||
SendError(String),
|
|
||||||
Other(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for ChannelError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
ChannelError::ConfigError(s) => write!(f, "Config error: {}", s),
|
|
||||||
ChannelError::ConnectionError(s) => write!(f, "Connection error: {}", s),
|
|
||||||
ChannelError::SendError(s) => write!(f, "Send error: {}", s),
|
|
||||||
ChannelError::Other(s) => write!(f, "Error: {}", s),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for ChannelError {}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait Channel: Send + Sync + 'static {
|
|
||||||
fn name(&self) -> &str;
|
|
||||||
async fn start(&self) -> Result<(), ChannelError>;
|
|
||||||
async fn stop(&self) -> Result<(), ChannelError>;
|
|
||||||
fn is_running(&self) -> bool;
|
|
||||||
}
|
|
||||||
@ -1,600 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Instant;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use tokio::sync::{broadcast, RwLock, Mutex};
|
|
||||||
use serde::Deserialize;
|
|
||||||
use futures_util::{SinkExt, StreamExt};
|
|
||||||
use prost::{Message as ProstMessage, bytes::Bytes};
|
|
||||||
|
|
||||||
use crate::agent::AgentLoop;
|
|
||||||
use crate::bus::ChatMessage;
|
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
|
||||||
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
|
||||||
|
|
||||||
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
|
|
||||||
const FEISHU_WS_BASE: &str = "https://open.feishu.cn";
|
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
// Protobuf types for Feishu WebSocket protocol (pbbp2.proto)
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
#[derive(Clone, PartialEq, prost::Message)]
|
|
||||||
struct PbHeader {
|
|
||||||
#[prost(string, tag = "1")]
|
|
||||||
pub key: String,
|
|
||||||
#[prost(string, tag = "2")]
|
|
||||||
pub value: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Feishu WS frame.
|
|
||||||
/// method=0 → CONTROL (ping/pong) method=1 → DATA (events)
|
|
||||||
#[derive(Clone, PartialEq, prost::Message)]
|
|
||||||
struct PbFrame {
|
|
||||||
#[prost(uint64, tag = "1")]
|
|
||||||
pub seq_id: u64,
|
|
||||||
#[prost(uint64, tag = "2")]
|
|
||||||
pub log_id: u64,
|
|
||||||
#[prost(int32, tag = "3")]
|
|
||||||
pub service: i32,
|
|
||||||
#[prost(int32, tag = "4")]
|
|
||||||
pub method: i32,
|
|
||||||
#[prost(message, repeated, tag = "5")]
|
|
||||||
pub headers: Vec<PbHeader>,
|
|
||||||
#[prost(bytes = "vec", optional, tag = "8")]
|
|
||||||
pub payload: Option<Vec<u8>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// POST /callback/ws/endpoint response
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct WsEndpointResp {
|
|
||||||
code: i32,
|
|
||||||
msg: Option<String>,
|
|
||||||
data: Option<WsEndpoint>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct WsEndpoint {
|
|
||||||
#[serde(rename = "URL")]
|
|
||||||
url: String,
|
|
||||||
#[serde(default)]
|
|
||||||
client_config: Option<WsClientConfig>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Default)]
|
|
||||||
struct WsClientConfig {
|
|
||||||
#[serde(rename = "PingInterval")]
|
|
||||||
ping_interval: Option<u64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Lark event envelope (method=1 / type=event payload)
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct LarkEvent {
|
|
||||||
header: LarkEventHeader,
|
|
||||||
event: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct LarkEventHeader {
|
|
||||||
event_type: String,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
event_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for LarkEventHeader {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
f.debug_struct("LarkEventHeader")
|
|
||||||
.field("event_type", &self.event_type)
|
|
||||||
.field("event_id", &self.event_id)
|
|
||||||
.finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct MsgReceivePayload {
|
|
||||||
sender: LarkSender,
|
|
||||||
message: LarkMessage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct LarkSender {
|
|
||||||
sender_id: LarkSenderId,
|
|
||||||
#[serde(default)]
|
|
||||||
sender_type: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Default)]
|
|
||||||
struct LarkSenderId {
|
|
||||||
open_id: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
#[allow(dead_code)]
|
|
||||||
struct LarkMessage {
|
|
||||||
message_id: String,
|
|
||||||
chat_id: String,
|
|
||||||
chat_type: String,
|
|
||||||
message_type: String,
|
|
||||||
#[serde(default)]
|
|
||||||
content: String,
|
|
||||||
#[serde(default)]
|
|
||||||
mentions: Vec<serde_json::Value>,
|
|
||||||
#[serde(default)]
|
|
||||||
root_id: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
parent_id: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct FeishuChannel {
|
|
||||||
config: FeishuChannelConfig,
|
|
||||||
http_client: reqwest::Client,
|
|
||||||
running: Arc<RwLock<bool>>,
|
|
||||||
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
|
|
||||||
connected: Arc<RwLock<bool>>,
|
|
||||||
/// Dedup: message_id -> timestamp (cleaned after 30 min)
|
|
||||||
seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
|
||||||
/// Agent for processing messages
|
|
||||||
agent: Arc<Mutex<AgentLoop>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FeishuChannel {
|
|
||||||
pub fn new(config: FeishuChannelConfig, provider_config: LLMProviderConfig) -> Result<Self, ChannelError> {
|
|
||||||
let agent = AgentLoop::new(provider_config)
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Failed to create agent: {}", e)))?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
config,
|
|
||||||
http_client: reqwest::Client::new(),
|
|
||||||
running: Arc::new(RwLock::new(false)),
|
|
||||||
shutdown_tx: Arc::new(RwLock::new(None)),
|
|
||||||
connected: Arc::new(RwLock::new(false)),
|
|
||||||
seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
|
||||||
agent: Arc::new(Mutex::new(agent)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get WebSocket endpoint URL from Feishu API
|
|
||||||
async fn get_ws_endpoint(&self, client: &reqwest::Client) -> Result<(String, WsClientConfig), ChannelError> {
|
|
||||||
let resp = client
|
|
||||||
.post(format!("{}/callback/ws/endpoint", FEISHU_WS_BASE))
|
|
||||||
.header("locale", "zh")
|
|
||||||
.json(&serde_json::json!({
|
|
||||||
"AppID": self.config.app_id,
|
|
||||||
"AppSecret": self.config.app_secret,
|
|
||||||
}))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::ConnectionError(format!("HTTP error: {}", e)))?;
|
|
||||||
|
|
||||||
let endpoint_resp: WsEndpointResp = resp
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::ConnectionError(format!("Failed to parse endpoint response: {}", e)))?;
|
|
||||||
|
|
||||||
if endpoint_resp.code != 0 {
|
|
||||||
return Err(ChannelError::ConnectionError(format!(
|
|
||||||
"WS endpoint failed: code={} msg={}",
|
|
||||||
endpoint_resp.code,
|
|
||||||
endpoint_resp.msg.as_deref().unwrap_or("unknown")
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let ep = endpoint_resp.data
|
|
||||||
.ok_or_else(|| ChannelError::ConnectionError("Empty endpoint data".to_string()))?;
|
|
||||||
|
|
||||||
let client_config = ep.client_config.unwrap_or_default();
|
|
||||||
Ok((ep.url, client_config))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get tenant access token
|
|
||||||
async fn get_tenant_token(&self) -> Result<String, ChannelError> {
|
|
||||||
let resp = self.http_client
|
|
||||||
.post(format!("{}/auth/v3/tenant_access_token/internal", FEISHU_API_BASE))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.json(&serde_json::json!({
|
|
||||||
"app_id": self.config.app_id,
|
|
||||||
"app_secret": self.config.app_secret,
|
|
||||||
}))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::ConnectionError(format!("HTTP error: {}", e)))?;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct TokenResponse {
|
|
||||||
code: i32,
|
|
||||||
tenant_access_token: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
let token_resp: TokenResponse = resp
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Failed to parse token response: {}", e)))?;
|
|
||||||
|
|
||||||
if token_resp.code != 0 {
|
|
||||||
return Err(ChannelError::Other("Auth failed".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
token_resp.tenant_access_token
|
|
||||||
.ok_or_else(|| ChannelError::Other("No token in response".to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send a text message to Feishu chat
|
|
||||||
async fn send_message(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
|
|
||||||
let token = self.get_tenant_token().await?;
|
|
||||||
|
|
||||||
// For text message, content should be a JSON string: "{\"text\":\"hello\"}"
|
|
||||||
let text_content = serde_json::json!({ "text": content }).to_string();
|
|
||||||
|
|
||||||
let resp = self.http_client
|
|
||||||
.post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("Authorization", format!("Bearer {}", token))
|
|
||||||
.json(&serde_json::json!({
|
|
||||||
"receive_id": receive_id,
|
|
||||||
"msg_type": "text",
|
|
||||||
"content": text_content
|
|
||||||
}))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::ConnectionError(format!("Send message HTTP error: {}", e)))?;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct SendResp {
|
|
||||||
code: i32,
|
|
||||||
msg: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
let send_resp: SendResp = resp
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?;
|
|
||||||
|
|
||||||
if send_resp.code != 0 {
|
|
||||||
return Err(ChannelError::Other(format!("Send message failed: code={} msg={}", send_resp.code, send_resp.msg)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handle incoming message - process through agent and send response
|
|
||||||
async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> {
|
|
||||||
println!("Feishu: processing message from {} in chat {}: {}", open_id, chat_id, content);
|
|
||||||
|
|
||||||
// Process through agent
|
|
||||||
let user_msg = ChatMessage::user(content);
|
|
||||||
let mut agent = self.agent.lock().await;
|
|
||||||
let response = agent.process(user_msg).await
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Agent error: {}", e)))?;
|
|
||||||
|
|
||||||
// Send response to the chat
|
|
||||||
// Use open_id for p2p chats, chat_id for group chats
|
|
||||||
let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id };
|
|
||||||
let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
|
||||||
|
|
||||||
self.send_message(receive_id, receive_id_type, &response.content).await?;
|
|
||||||
println!("Feishu: sent response to {}", receive_id);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Extract service_id from WebSocket URL query params
|
|
||||||
fn extract_service_id(url: &str) -> i32 {
|
|
||||||
url.split('?')
|
|
||||||
.nth(1)
|
|
||||||
.and_then(|qs| {
|
|
||||||
qs.split('&')
|
|
||||||
.find(|kv| kv.starts_with("service_id="))
|
|
||||||
.and_then(|kv| kv.split('=').nth(1))
|
|
||||||
.and_then(|v| v.parse::<i32>().ok())
|
|
||||||
})
|
|
||||||
.unwrap_or(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handle incoming binary PbFrame - returns Some(message_id) if we need to ack
|
|
||||||
async fn handle_frame(&self, frame: &PbFrame) -> Result<Option<String>, ChannelError> {
|
|
||||||
// method 0 = CONTROL (ping/pong)
|
|
||||||
if frame.method == 0 {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
// method 1 = DATA (events)
|
|
||||||
if frame.method != 1 {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
let payload = frame.payload.as_deref()
|
|
||||||
.ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?;
|
|
||||||
|
|
||||||
// Parse the event JSON to get event_type from payload header
|
|
||||||
let event: LarkEvent = serde_json::from_slice(payload)
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?;
|
|
||||||
|
|
||||||
let event_type = event.header.event_type.as_str();
|
|
||||||
if event_type != "im.message.receive_v1" {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
let payload_data: MsgReceivePayload = serde_json::from_value(event.event.clone())
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Parse payload error: {}", e)))?;
|
|
||||||
|
|
||||||
// Skip bot messages
|
|
||||||
if payload_data.sender.sender_type == "bot" {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deduplication check with TTL cleanup
|
|
||||||
let message_id = payload_data.message.message_id.clone();
|
|
||||||
{
|
|
||||||
let mut seen = self.seen_ids.write().await;
|
|
||||||
let now = Instant::now();
|
|
||||||
|
|
||||||
// Clean expired entries (older than 30 min)
|
|
||||||
seen.retain(|_, ts| now.duration_since(*ts).as_secs() < 1800);
|
|
||||||
|
|
||||||
if seen.contains_key(&message_id) {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
seen.insert(message_id.clone(), now);
|
|
||||||
}
|
|
||||||
|
|
||||||
let open_id = payload_data.sender.sender_id.open_id
|
|
||||||
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
|
|
||||||
|
|
||||||
let msg = payload_data.message;
|
|
||||||
let chat_id = msg.chat_id.clone();
|
|
||||||
let msg_type = msg.message_type.as_str();
|
|
||||||
let content = parse_message_content(msg_type, &msg.content);
|
|
||||||
|
|
||||||
// Handle the message - process and send response
|
|
||||||
if let Err(e) = self.handle_message(&open_id, &chat_id, &content).await {
|
|
||||||
eprintln!("Error handling message: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return message_id for ack
|
|
||||||
Ok(Some(message_id))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send acknowledgment for a message
|
|
||||||
async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> {
|
|
||||||
let mut ack = frame.clone();
|
|
||||||
ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec());
|
|
||||||
ack.headers.push(PbHeader {
|
|
||||||
key: "biz_rt".into(),
|
|
||||||
value: "0".into(),
|
|
||||||
});
|
|
||||||
write.send(tokio_tungstenite::tungstenite::Message::Binary(ack.encode_to_vec().into()))
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Failed to send ack: {}", e)))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_ws_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> {
|
|
||||||
let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?;
|
|
||||||
|
|
||||||
let service_id = Self::extract_service_id(&wss_url);
|
|
||||||
println!("Feishu: connecting to {}", wss_url);
|
|
||||||
|
|
||||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)))?;
|
|
||||||
|
|
||||||
*self.connected.write().await = true;
|
|
||||||
println!("Feishu channel connected");
|
|
||||||
|
|
||||||
let (mut write, mut read) = ws_stream.split();
|
|
||||||
|
|
||||||
// Send initial ping
|
|
||||||
let ping_frame = PbFrame {
|
|
||||||
seq_id: 1,
|
|
||||||
log_id: 0,
|
|
||||||
service: service_id,
|
|
||||||
method: 0,
|
|
||||||
headers: vec![PbHeader {
|
|
||||||
key: "type".into(),
|
|
||||||
value: "ping".into(),
|
|
||||||
}],
|
|
||||||
payload: None,
|
|
||||||
};
|
|
||||||
write.send(tokio_tungstenite::tungstenite::Message::Binary(ping_frame.encode_to_vec().into()))
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::ConnectionError(format!("Failed to send initial ping: {}", e)))?;
|
|
||||||
|
|
||||||
let ping_interval = client_config.ping_interval.unwrap_or(120).max(10);
|
|
||||||
let mut ping_interval_tok = tokio::time::interval(tokio::time::Duration::from_secs(ping_interval));
|
|
||||||
let mut seq: u64 = 1;
|
|
||||||
|
|
||||||
// Consume the immediate tick
|
|
||||||
ping_interval_tok.tick().await;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
msg = read.next() => {
|
|
||||||
match msg {
|
|
||||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
|
|
||||||
let bytes: Bytes = data;
|
|
||||||
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
|
|
||||||
// Handle the frame and get message_id for ack if needed
|
|
||||||
match self.handle_frame(&frame).await {
|
|
||||||
Ok(Some(_message_id)) => {
|
|
||||||
// Send ACK immediately (Feishu requires within 3 s)
|
|
||||||
if let Err(e) = Self::send_ack(&frame, &mut write).await {
|
|
||||||
eprintln!("Error sending ack: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(None) => {}
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!("Error handling frame: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(data))) => {
|
|
||||||
let pong = PbFrame {
|
|
||||||
seq_id: seq.wrapping_add(1),
|
|
||||||
log_id: 0,
|
|
||||||
service: service_id,
|
|
||||||
method: 0,
|
|
||||||
headers: vec![PbHeader {
|
|
||||||
key: "type".into(),
|
|
||||||
value: "pong".into(),
|
|
||||||
}],
|
|
||||||
payload: Some(data.to_vec()),
|
|
||||||
};
|
|
||||||
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
|
|
||||||
}
|
|
||||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Some(Err(e)) => {
|
|
||||||
eprintln!("WS error: {}", e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = ping_interval_tok.tick() => {
|
|
||||||
seq = seq.wrapping_add(1);
|
|
||||||
let ping = PbFrame {
|
|
||||||
seq_id: seq,
|
|
||||||
log_id: 0,
|
|
||||||
service: service_id,
|
|
||||||
method: 0,
|
|
||||||
headers: vec![PbHeader {
|
|
||||||
key: "type".into(),
|
|
||||||
value: "ping".into(),
|
|
||||||
}],
|
|
||||||
payload: None,
|
|
||||||
};
|
|
||||||
if write.send(tokio_tungstenite::tungstenite::Message::Binary(ping.encode_to_vec().into())).await.is_err() {
|
|
||||||
eprintln!("Feishu: ping failed, reconnecting");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = shutdown_rx.recv() => {
|
|
||||||
println!("Feishu channel shutdown signal received");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
*self.connected.write().await = false;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_message_content(msg_type: &str, content: &str) -> String {
|
|
||||||
match msg_type {
|
|
||||||
"text" => {
|
|
||||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
|
||||||
parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string()
|
|
||||||
} else {
|
|
||||||
content.to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"post" => {
|
|
||||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
|
||||||
let mut texts = vec![];
|
|
||||||
if let Some(post) = parsed.get("post") {
|
|
||||||
if let Some(content_arr) = post.get("content") {
|
|
||||||
if let Some(arr) = content_arr.as_array() {
|
|
||||||
for item in arr {
|
|
||||||
if let Some(arr2) = item.as_array() {
|
|
||||||
for inner in arr2 {
|
|
||||||
if let Some(text) = inner.get("text").and_then(|v| v.as_str()) {
|
|
||||||
texts.push(text.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if texts.is_empty() {
|
|
||||||
content.to_string()
|
|
||||||
} else {
|
|
||||||
texts.join("")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
content.to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => content.to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Channel for FeishuChannel {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"feishu"
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start(&self) -> Result<(), ChannelError> {
|
|
||||||
if self.config.app_id.is_empty() || self.config.app_secret.is_empty() {
|
|
||||||
return Err(ChannelError::ConfigError(
|
|
||||||
"Feishu app_id or app_secret is not configured".to_string()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
*self.running.write().await = true;
|
|
||||||
|
|
||||||
let (shutdown_tx, _) = broadcast::channel(1);
|
|
||||||
*self.shutdown_tx.write().await = Some(shutdown_tx.clone());
|
|
||||||
|
|
||||||
let channel = self.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let mut consecutive_failures = 0;
|
|
||||||
let max_failures = 3;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
if !*channel.running.read().await {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let shutdown_rx = shutdown_tx.subscribe();
|
|
||||||
match channel.run_ws_loop(shutdown_rx).await {
|
|
||||||
Ok(_) => {
|
|
||||||
println!("Feishu WebSocket disconnected");
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
consecutive_failures += 1;
|
|
||||||
eprintln!("Feishu WebSocket error (attempt {}): {}", consecutive_failures, e);
|
|
||||||
if consecutive_failures >= max_failures {
|
|
||||||
eprintln!("Feishu channel: max failures reached, stopping");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !*channel.running.read().await {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("Feishu channel retrying in 5s...");
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
*channel.running.write().await = false;
|
|
||||||
println!("Feishu channel stopped");
|
|
||||||
});
|
|
||||||
|
|
||||||
println!("Feishu channel started");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn stop(&self) -> Result<(), ChannelError> {
|
|
||||||
*self.running.write().await = false;
|
|
||||||
|
|
||||||
if let Some(tx) = self.shutdown_tx.write().await.take() {
|
|
||||||
let _ = tx.send(());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_running(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,78 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::RwLock;
|
|
||||||
|
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
|
||||||
use crate::channels::feishu::FeishuChannel;
|
|
||||||
use crate::config::Config;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct ChannelManager {
|
|
||||||
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChannelManager {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn init(&self, config: &Config) -> Result<(), ChannelError> {
|
|
||||||
// Initialize Feishu channel if enabled
|
|
||||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
|
||||||
if feishu_config.enabled {
|
|
||||||
let agent_name = &feishu_config.agent;
|
|
||||||
let provider_config = config.get_provider_config(agent_name)
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Failed to get provider config: {}", e)))?;
|
|
||||||
|
|
||||||
let channel = FeishuChannel::new(feishu_config.clone(), provider_config)
|
|
||||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
|
||||||
|
|
||||||
self.channels
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.insert("feishu".to_string(), Arc::new(channel));
|
|
||||||
println!("Feishu channel registered");
|
|
||||||
} else {
|
|
||||||
println!("Feishu channel disabled in config");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn start_all(&self) -> Result<(), ChannelError> {
|
|
||||||
let channels = self.channels.read().await;
|
|
||||||
for (name, channel) in channels.iter() {
|
|
||||||
println!("Starting channel: {}", name);
|
|
||||||
if let Err(e) = channel.start().await {
|
|
||||||
eprintln!("Warning: Failed to start channel {}: {}", name, e);
|
|
||||||
// Channel failed to start - it should have logged why
|
|
||||||
// Continue starting other channels
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn stop_all(&self) -> Result<(), ChannelError> {
|
|
||||||
let mut channels = self.channels.write().await;
|
|
||||||
for (name, channel) in channels.iter() {
|
|
||||||
println!("Stopping channel: {}", name);
|
|
||||||
if let Err(e) = channel.stop().await {
|
|
||||||
eprintln!("Error stopping channel {}: {}", name, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
channels.clear();
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel>> {
|
|
||||||
self.channels.read().await.get(name).cloned()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for ChannelManager {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,7 +0,0 @@
|
|||||||
pub mod base;
|
|
||||||
pub mod feishu;
|
|
||||||
pub mod manager;
|
|
||||||
|
|
||||||
pub use base::{Channel, ChannelError, InboundMessage, OutboundMessage};
|
|
||||||
pub use manager::ChannelManager;
|
|
||||||
pub use feishu::FeishuChannel;
|
|
||||||
@ -1,50 +0,0 @@
|
|||||||
use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt};
|
|
||||||
|
|
||||||
pub struct CliChannel {
|
|
||||||
read: BufReader<tokio::io::Stdin>,
|
|
||||||
write: tokio::io::Stdout,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CliChannel {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
read: BufReader::new(tokio::io::stdin()),
|
|
||||||
write: tokio::io::stdout(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn read_line(&mut self, prompt: &str) -> Result<Option<String>, std::io::Error> {
|
|
||||||
print!("{}", prompt);
|
|
||||||
self.write.flush().await?;
|
|
||||||
|
|
||||||
let mut line = String::new();
|
|
||||||
let bytes_read = self.read.read_line(&mut line).await?;
|
|
||||||
|
|
||||||
if bytes_read == 0 {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Some(line.trim_end().to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_line(&mut self, content: &str) -> Result<(), std::io::Error> {
|
|
||||||
self.write.write_all(content.as_bytes()).await?;
|
|
||||||
self.write.write_all(b"\n").await?;
|
|
||||||
self.write.flush().await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_response(&mut self, content: &str) -> Result<(), std::io::Error> {
|
|
||||||
for line in content.lines() {
|
|
||||||
self.write.write_all(b" ").await?;
|
|
||||||
self.write.write_all(line.as_bytes()).await?;
|
|
||||||
self.write.write_all(b"\n").await?;
|
|
||||||
}
|
|
||||||
self.write.flush().await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for CliChannel {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,70 +0,0 @@
|
|||||||
use crate::bus::ChatMessage;
|
|
||||||
|
|
||||||
use super::channel::CliChannel;
|
|
||||||
|
|
||||||
pub struct InputHandler {
|
|
||||||
channel: CliChannel,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl InputHandler {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
channel: CliChannel::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn read_input(&mut self, prompt: &str) -> Result<Option<ChatMessage>, InputError> {
|
|
||||||
match self.channel.read_line(prompt).await {
|
|
||||||
Ok(Some(line)) => {
|
|
||||||
if line.trim().is_empty() {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(cmd) = self.handle_special_commands(&line) {
|
|
||||||
return Ok(Some(cmd));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Some(ChatMessage::user(line)))
|
|
||||||
}
|
|
||||||
Ok(None) => Ok(None),
|
|
||||||
Err(e) => Err(InputError::IoError(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> {
|
|
||||||
self.channel.write_line(content).await.map_err(InputError::IoError)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> {
|
|
||||||
self.channel.write_response(content).await.map_err(InputError::IoError)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_special_commands(&self, line: &str) -> Option<ChatMessage> {
|
|
||||||
match line.trim() {
|
|
||||||
"/quit" | "/exit" | "/q" => Some(ChatMessage::system("__EXIT__")),
|
|
||||||
"/clear" => Some(ChatMessage::system("__CLEAR__")),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for InputHandler {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum InputError {
|
|
||||||
IoError(std::io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for InputError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
InputError::IoError(e) => write!(f, "IO error: {}", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for InputError {}
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
pub mod channel;
|
|
||||||
pub mod input;
|
|
||||||
|
|
||||||
pub use channel::CliChannel;
|
|
||||||
pub use input::InputHandler;
|
|
||||||
@ -1,89 +0,0 @@
|
|||||||
pub use crate::protocol::{WsInbound, WsOutbound, serialize_inbound, serialize_outbound};
|
|
||||||
|
|
||||||
use futures_util::{SinkExt, StreamExt};
|
|
||||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
|
||||||
|
|
||||||
use crate::cli::InputHandler;
|
|
||||||
|
|
||||||
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
|
||||||
serde_json::from_str(raw)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
let (ws_stream, _) = connect_async(gateway_url).await?;
|
|
||||||
println!("Connected to gateway");
|
|
||||||
|
|
||||||
let (mut sender, mut receiver) = ws_stream.split();
|
|
||||||
|
|
||||||
let mut input = InputHandler::new();
|
|
||||||
input.write_output("picobot CLI - Type /quit to exit, /clear to clear history\n").await?;
|
|
||||||
|
|
||||||
// Main loop: poll both stdin and WebSocket
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
// Handle WebSocket messages
|
|
||||||
msg = receiver.next() => {
|
|
||||||
match msg {
|
|
||||||
Some(Ok(Message::Text(text))) => {
|
|
||||||
let text = text.to_string();
|
|
||||||
if let Ok(outbound) = parse_message(&text) {
|
|
||||||
match outbound {
|
|
||||||
WsOutbound::AssistantResponse { content, .. } => {
|
|
||||||
input.write_response(&content).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::Error { message, .. } => {
|
|
||||||
input.write_output(&format!("Error: {}", message)).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::SessionEstablished { session_id } => {
|
|
||||||
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(Ok(Message::Close(_))) | None => {
|
|
||||||
input.write_output("Gateway disconnected").await?;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Handle stdin input
|
|
||||||
result = input.read_input("> ") => {
|
|
||||||
match result {
|
|
||||||
Ok(Some(msg)) => {
|
|
||||||
match msg.content.as_str() {
|
|
||||||
"__EXIT__" => {
|
|
||||||
input.write_output("Goodbye!").await?;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
"__CLEAR__" => {
|
|
||||||
let inbound = WsInbound::ClearHistory;
|
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
let inbound = WsInbound::UserInput { content: msg.content };
|
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
|
||||||
if sender.send(Message::Text(text.into())).await.is_err() {
|
|
||||||
eprintln!("Failed to send message");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(None) => break,
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!("Input error: {}", e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@ -1,270 +0,0 @@
|
|||||||
use regex::Regex;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::env;
|
|
||||||
use std::fs;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Config {
|
|
||||||
pub providers: HashMap<String, ProviderConfig>,
|
|
||||||
pub models: HashMap<String, ModelConfig>,
|
|
||||||
pub agents: HashMap<String, AgentConfig>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub gateway: GatewayConfig,
|
|
||||||
#[serde(default)]
|
|
||||||
pub client: ClientConfig,
|
|
||||||
#[serde(default)]
|
|
||||||
pub channels: HashMap<String, FeishuChannelConfig>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct FeishuChannelConfig {
|
|
||||||
#[serde(default)]
|
|
||||||
pub enabled: bool,
|
|
||||||
pub app_id: String,
|
|
||||||
pub app_secret: String,
|
|
||||||
#[serde(default = "default_allow_from")]
|
|
||||||
pub allow_from: Vec<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub agent: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_allow_from() -> Vec<String> {
|
|
||||||
vec!["*".to_string()]
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ProviderConfig {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub provider_type: String,
|
|
||||||
pub base_url: String,
|
|
||||||
pub api_key: String,
|
|
||||||
#[serde(default)]
|
|
||||||
pub extra_headers: HashMap<String, String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ModelConfig {
|
|
||||||
pub model_id: String,
|
|
||||||
#[serde(default)]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub extra: HashMap<String, serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct AgentConfig {
|
|
||||||
pub provider: String,
|
|
||||||
pub model: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct GatewayConfig {
|
|
||||||
#[serde(default = "default_gateway_host")]
|
|
||||||
pub host: String,
|
|
||||||
#[serde(default = "default_gateway_port")]
|
|
||||||
pub port: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ClientConfig {
|
|
||||||
#[serde(default = "default_gateway_url")]
|
|
||||||
pub gateway_url: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_gateway_host() -> String {
|
|
||||||
"127.0.0.1".to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_gateway_port() -> u16 {
|
|
||||||
19876
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_gateway_url() -> String {
|
|
||||||
"ws://127.0.0.1:19876/ws".to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for GatewayConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
host: default_gateway_host(),
|
|
||||||
port: default_gateway_port(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for ClientConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
gateway_url: default_gateway_url(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct LLMProviderConfig {
|
|
||||||
pub provider_type: String,
|
|
||||||
pub name: String,
|
|
||||||
pub base_url: String,
|
|
||||||
pub api_key: String,
|
|
||||||
pub extra_headers: HashMap<String, String>,
|
|
||||||
pub model_id: String,
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
pub model_extra: HashMap<String, serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_default_config_path() -> PathBuf {
|
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
|
||||||
home.join(".config").join("picobot").join("config.json")
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
|
|
||||||
pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
|
||||||
Self::load_from(Path::new(path))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_default() -> Result<Self, Box<dyn std::error::Error>> {
|
|
||||||
let path = get_default_config_path();
|
|
||||||
Self::load_from(&path)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
|
|
||||||
load_env_file()?;
|
|
||||||
let content = if path.exists() {
|
|
||||||
println!("Config loaded from: {}", path.display());
|
|
||||||
fs::read_to_string(path)?
|
|
||||||
} else {
|
|
||||||
// Fallback to current directory
|
|
||||||
let fallback = Path::new("config.json");
|
|
||||||
if fallback.exists() {
|
|
||||||
println!("Config loaded from: {}", fallback.display());
|
|
||||||
fs::read_to_string(fallback)?
|
|
||||||
} else {
|
|
||||||
return Err(Box::new(ConfigError::ConfigNotFound(
|
|
||||||
path.to_string_lossy().to_string(),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let content = resolve_env_placeholders(&content);
|
|
||||||
let config: Config = serde_json::from_str(&content)?;
|
|
||||||
Ok(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
|
|
||||||
let agent = self.agents.get(agent_name)
|
|
||||||
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
|
|
||||||
|
|
||||||
let provider = self.providers.get(&agent.provider)
|
|
||||||
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
|
|
||||||
|
|
||||||
let model = self.models.get(&agent.model)
|
|
||||||
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
|
|
||||||
|
|
||||||
Ok(LLMProviderConfig {
|
|
||||||
provider_type: provider.provider_type.clone(),
|
|
||||||
name: agent.provider.clone(),
|
|
||||||
base_url: provider.base_url.clone(),
|
|
||||||
api_key: provider.api_key.clone(),
|
|
||||||
extra_headers: provider.extra_headers.clone(),
|
|
||||||
model_id: model.model_id.clone(),
|
|
||||||
temperature: model.temperature,
|
|
||||||
max_tokens: model.max_tokens,
|
|
||||||
model_extra: model.extra.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum ConfigError {
|
|
||||||
ConfigNotFound(String),
|
|
||||||
AgentNotFound(String),
|
|
||||||
ProviderNotFound(String),
|
|
||||||
ModelNotFound(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for ConfigError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.config/picobot/config.json", path),
|
|
||||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
|
||||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
|
||||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for ConfigError {}
|
|
||||||
|
|
||||||
fn load_env_file() -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
let env_path = Path::new(".env");
|
|
||||||
if env_path.exists() {
|
|
||||||
let content = fs::read_to_string(env_path)?;
|
|
||||||
for line in content.lines() {
|
|
||||||
let line = line.trim();
|
|
||||||
if line.is_empty() || line.starts_with('#') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if let Some((key, value)) = line.split_once('=') {
|
|
||||||
let key = key.trim();
|
|
||||||
let value = value.trim().trim_matches('"').trim_matches('\'');
|
|
||||||
if !value.is_empty() {
|
|
||||||
// SAFETY: Setting environment variables for the current process
|
|
||||||
// is safe as we're only modifying our own process state
|
|
||||||
unsafe { env::set_var(key, value) };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_env_placeholders(content: &str) -> String {
|
|
||||||
let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex");
|
|
||||||
re.replace_all(content, |caps: ®ex::Captures| {
|
|
||||||
let var_name = &caps[1];
|
|
||||||
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
|
||||||
}).to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_config_load() {
|
|
||||||
let config = Config::load("config.json").unwrap();
|
|
||||||
|
|
||||||
// Check providers
|
|
||||||
assert!(config.providers.contains_key("volcengine"));
|
|
||||||
assert!(config.providers.contains_key("aliyun"));
|
|
||||||
|
|
||||||
// Check models
|
|
||||||
assert!(config.models.contains_key("doubao-seed-2-0-lite-260215"));
|
|
||||||
assert!(config.models.contains_key("qwen-plus"));
|
|
||||||
|
|
||||||
// Check agents
|
|
||||||
assert!(config.agents.contains_key("default"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_get_provider_config() {
|
|
||||||
let config = Config::load("config.json").unwrap();
|
|
||||||
let provider_config = config.get_provider_config("default").unwrap();
|
|
||||||
|
|
||||||
assert_eq!(provider_config.provider_type, "openai");
|
|
||||||
assert_eq!(provider_config.name, "aliyun");
|
|
||||||
assert_eq!(provider_config.model_id, "qwen-plus");
|
|
||||||
assert_eq!(provider_config.temperature, Some(0.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_default_gateway_config() {
|
|
||||||
let config = Config::load("config.json").unwrap();
|
|
||||||
assert_eq!(config.gateway.host, "0.0.0.0");
|
|
||||||
assert_eq!(config.gateway.port, 19876);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,15 +0,0 @@
|
|||||||
use axum::Json;
|
|
||||||
use serde::Serialize;
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct HealthResponse {
|
|
||||||
status: String,
|
|
||||||
version: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn health() -> Json<HealthResponse> {
|
|
||||||
Json(HealthResponse {
|
|
||||||
status: "ok".to_string(),
|
|
||||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -1,71 +0,0 @@
|
|||||||
pub mod http;
|
|
||||||
pub mod session;
|
|
||||||
pub mod ws;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use axum::{routing, Router};
|
|
||||||
use tokio::net::TcpListener;
|
|
||||||
|
|
||||||
use crate::channels::ChannelManager;
|
|
||||||
use crate::config::Config;
|
|
||||||
use session::SessionManager;
|
|
||||||
|
|
||||||
pub struct GatewayState {
|
|
||||||
pub config: Config,
|
|
||||||
pub session_manager: SessionManager,
|
|
||||||
pub channel_manager: ChannelManager,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GatewayState {
|
|
||||||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
|
||||||
let config = Config::load_default()?;
|
|
||||||
let channel_manager = ChannelManager::new();
|
|
||||||
Ok(Self {
|
|
||||||
config,
|
|
||||||
session_manager: SessionManager::new(),
|
|
||||||
channel_manager,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
let state = Arc::new(GatewayState::new()?);
|
|
||||||
|
|
||||||
// Initialize and start channels
|
|
||||||
state.channel_manager.init(&state.config).await?;
|
|
||||||
state.channel_manager.start_all().await?;
|
|
||||||
|
|
||||||
// CLI args override config file values
|
|
||||||
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
|
|
||||||
let bind_port = port.unwrap_or(state.config.gateway.port);
|
|
||||||
|
|
||||||
let app = Router::new()
|
|
||||||
.route("/health", routing::get(http::health))
|
|
||||||
.route("/ws", routing::get(ws::ws_handler))
|
|
||||||
.with_state(state.clone());
|
|
||||||
|
|
||||||
let addr = format!("{}:{}", bind_host, bind_port);
|
|
||||||
let listener = TcpListener::bind(&addr).await?;
|
|
||||||
println!("Gateway listening on {}", addr);
|
|
||||||
|
|
||||||
// Graceful shutdown using oneshot channel
|
|
||||||
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
|
||||||
let channel_manager = state.channel_manager.clone();
|
|
||||||
|
|
||||||
// Spawn ctrl_c handler
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tokio::signal::ctrl_c().await.ok();
|
|
||||||
println!("Shutting down...");
|
|
||||||
let _ = channel_manager.stop_all().await;
|
|
||||||
let _ = shutdown_tx.send(());
|
|
||||||
});
|
|
||||||
|
|
||||||
// Serve with graceful shutdown
|
|
||||||
axum::serve(listener, app)
|
|
||||||
.with_graceful_shutdown(async {
|
|
||||||
shutdown_rx.await.ok();
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@ -1,67 +0,0 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::{Mutex, mpsc};
|
|
||||||
use uuid::Uuid;
|
|
||||||
use crate::config::LLMProviderConfig;
|
|
||||||
use crate::agent::AgentLoop;
|
|
||||||
use crate::protocol::WsOutbound;
|
|
||||||
|
|
||||||
pub struct Session {
|
|
||||||
pub id: Uuid,
|
|
||||||
pub agent_loop: Arc<Mutex<AgentLoop>>,
|
|
||||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Session {
|
|
||||||
pub async fn new(
|
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
user_tx: mpsc::Sender<WsOutbound>,
|
|
||||||
) -> Result<Self, crate::agent::AgentError> {
|
|
||||||
let agent_loop = AgentLoop::new(provider_config)?;
|
|
||||||
Ok(Self {
|
|
||||||
id: Uuid::new_v4(),
|
|
||||||
agent_loop: Arc::new(Mutex::new(agent_loop)),
|
|
||||||
user_tx,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send(&self, msg: WsOutbound) {
|
|
||||||
let _ = self.user_tx.send(msg).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::RwLock;
|
|
||||||
|
|
||||||
pub struct SessionManager {
|
|
||||||
sessions: RwLock<HashMap<Uuid, Arc<Session>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SessionManager {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
sessions: RwLock::new(HashMap::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add(&self, session: Arc<Session>) {
|
|
||||||
self.sessions.write().unwrap().insert(session.id, session);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn remove(&self, id: &Uuid) {
|
|
||||||
self.sessions.write().unwrap().remove(id);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get(&self, id: &Uuid) -> Option<Arc<Session>> {
|
|
||||||
self.sessions.read().unwrap().get(id).cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn len(&self) -> usize {
|
|
||||||
self.sessions.read().unwrap().len()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for SessionManager {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,116 +0,0 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
|
||||||
use axum::extract::State;
|
|
||||||
use axum::response::Response;
|
|
||||||
use futures_util::{SinkExt, StreamExt};
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
use crate::bus::ChatMessage;
|
|
||||||
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
|
|
||||||
use super::{GatewayState, session::Session};
|
|
||||||
|
|
||||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
|
||||||
ws.on_upgrade(|socket| async {
|
|
||||||
handle_socket(socket, state).await;
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|
||||||
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
|
||||||
|
|
||||||
let provider_config = match state.config.get_provider_config("default") {
|
|
||||||
Ok(cfg) => cfg,
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!("Failed to get provider config: {}", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let session = match Session::new(provider_config, sender).await {
|
|
||||||
Ok(s) => Arc::new(s),
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!("Failed to create session: {}", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let session_id = session.id;
|
|
||||||
state.session_manager.add(session.clone());
|
|
||||||
|
|
||||||
let _ = session.send(WsOutbound::SessionEstablished {
|
|
||||||
session_id: session_id.to_string(),
|
|
||||||
}).await;
|
|
||||||
|
|
||||||
let (mut ws_sender, mut ws_receiver) = ws.split();
|
|
||||||
|
|
||||||
let mut receiver = receiver;
|
|
||||||
tokio::spawn(async move {
|
|
||||||
while let Some(msg) = receiver.recv().await {
|
|
||||||
if let Ok(text) = serialize_outbound(&msg) {
|
|
||||||
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
while let Some(msg) = ws_receiver.next().await {
|
|
||||||
match msg {
|
|
||||||
Ok(WsMessage::Text(text)) => {
|
|
||||||
let text = text.to_string();
|
|
||||||
match parse_inbound(&text) {
|
|
||||||
Ok(inbound) => {
|
|
||||||
handle_inbound(&session, inbound).await;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let _ = session.send(WsOutbound::Error {
|
|
||||||
code: "PARSE_ERROR".to_string(),
|
|
||||||
message: e.to_string(),
|
|
||||||
}).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(WsMessage::Close(_)) | Err(_) => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
state.session_manager.remove(&session_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_inbound(session: &Arc<Session>, inbound: WsInbound) {
|
|
||||||
match inbound {
|
|
||||||
WsInbound::UserInput { content } => {
|
|
||||||
let user_msg = ChatMessage::user(content);
|
|
||||||
let mut agent = session.agent_loop.lock().await;
|
|
||||||
match agent.process(user_msg).await {
|
|
||||||
Ok(response) => {
|
|
||||||
let _ = session.send(WsOutbound::AssistantResponse {
|
|
||||||
id: response.id,
|
|
||||||
content: response.content,
|
|
||||||
role: response.role,
|
|
||||||
}).await;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let _ = session.send(WsOutbound::Error {
|
|
||||||
code: "LLM_ERROR".to_string(),
|
|
||||||
message: e.to_string(),
|
|
||||||
}).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
WsInbound::ClearHistory => {
|
|
||||||
let mut agent = session.agent_loop.lock().await;
|
|
||||||
agent.clear_history();
|
|
||||||
let _ = session.send(WsOutbound::AssistantResponse {
|
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
|
||||||
content: "History cleared.".to_string(),
|
|
||||||
role: "system".to_string(),
|
|
||||||
}).await;
|
|
||||||
}
|
|
||||||
WsInbound::Ping => {
|
|
||||||
let _ = session.send(WsOutbound::Pong).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,9 +0,0 @@
|
|||||||
pub mod config;
|
|
||||||
pub mod providers;
|
|
||||||
pub mod bus;
|
|
||||||
pub mod cli;
|
|
||||||
pub mod agent;
|
|
||||||
pub mod gateway;
|
|
||||||
pub mod client;
|
|
||||||
pub mod protocol;
|
|
||||||
pub mod channels;
|
|
||||||
48
src/main.rs
48
src/main.rs
@ -1,48 +0,0 @@
|
|||||||
use clap::{Parser, CommandFactory};
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
#[command(name = "picobot")]
|
|
||||||
#[command(about = "A CLI chatbot", long_about = None)]
|
|
||||||
enum Command {
|
|
||||||
/// Connect to gateway
|
|
||||||
Agent {
|
|
||||||
/// Gateway WebSocket URL (e.g., ws://127.0.0.1:19876/ws)
|
|
||||||
#[arg(long)]
|
|
||||||
gateway_url: Option<String>,
|
|
||||||
},
|
|
||||||
/// Start gateway server
|
|
||||||
Gateway {
|
|
||||||
/// Host to bind to
|
|
||||||
#[arg(long)]
|
|
||||||
host: Option<String>,
|
|
||||||
/// Port to listen on
|
|
||||||
#[arg(long)]
|
|
||||||
port: Option<u16>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
let mut cmd = Command::command();
|
|
||||||
|
|
||||||
// If no arguments, print help
|
|
||||||
if std::env::args().len() <= 1 {
|
|
||||||
cmd.print_help()?;
|
|
||||||
println!();
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
match Command::parse() {
|
|
||||||
Command::Agent { gateway_url } => {
|
|
||||||
let config = picobot::config::Config::load_default().ok();
|
|
||||||
let url = gateway_url
|
|
||||||
.or_else(|| config.as_ref().map(|c| c.client.gateway_url.clone()))
|
|
||||||
.unwrap_or_else(|| "ws://127.0.0.1:19876/ws".to_string());
|
|
||||||
picobot::client::run(&url).await?;
|
|
||||||
}
|
|
||||||
Command::Gateway { host, port } => {
|
|
||||||
picobot::gateway::run(host, port).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@ -1,37 +0,0 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum WsInbound {
|
|
||||||
#[serde(rename = "user_input")]
|
|
||||||
UserInput { content: String },
|
|
||||||
#[serde(rename = "clear_history")]
|
|
||||||
ClearHistory,
|
|
||||||
#[serde(rename = "ping")]
|
|
||||||
Ping,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum WsOutbound {
|
|
||||||
#[serde(rename = "assistant_response")]
|
|
||||||
AssistantResponse { id: String, content: String, role: String },
|
|
||||||
#[serde(rename = "error")]
|
|
||||||
Error { code: String, message: String },
|
|
||||||
#[serde(rename = "session_established")]
|
|
||||||
SessionEstablished { session_id: String },
|
|
||||||
#[serde(rename = "pong")]
|
|
||||||
Pong,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn parse_inbound(raw: &str) -> Result<WsInbound, serde_json::Error> {
|
|
||||||
serde_json::from_str(raw)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn serialize_inbound(msg: &WsInbound) -> Result<String, serde_json::Error> {
|
|
||||||
serde_json::to_string(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn serialize_outbound(msg: &WsOutbound) -> Result<String, serde_json::Error> {
|
|
||||||
serde_json::to_string(msg)
|
|
||||||
}
|
|
||||||
@ -1,198 +0,0 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use reqwest::Client;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
|
||||||
use super::traits::Usage;
|
|
||||||
|
|
||||||
pub struct AnthropicProvider {
|
|
||||||
client: Client,
|
|
||||||
name: String,
|
|
||||||
api_key: String,
|
|
||||||
base_url: String,
|
|
||||||
extra_headers: HashMap<String, String>,
|
|
||||||
model_id: String,
|
|
||||||
temperature: Option<f32>,
|
|
||||||
max_tokens: Option<u32>,
|
|
||||||
model_extra: HashMap<String, serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AnthropicProvider {
|
|
||||||
pub fn new(
|
|
||||||
name: String,
|
|
||||||
api_key: String,
|
|
||||||
base_url: String,
|
|
||||||
extra_headers: HashMap<String, String>,
|
|
||||||
model_id: String,
|
|
||||||
temperature: Option<f32>,
|
|
||||||
max_tokens: Option<u32>,
|
|
||||||
model_extra: HashMap<String, serde_json::Value>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
client: Client::new(),
|
|
||||||
name,
|
|
||||||
api_key,
|
|
||||||
base_url,
|
|
||||||
extra_headers,
|
|
||||||
model_id,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
model_extra,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct AnthropicRequest {
|
|
||||||
model: String,
|
|
||||||
messages: Vec<AnthropicMessage>,
|
|
||||||
max_tokens: u32,
|
|
||||||
temperature: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
tools: Option<Vec<AnthropicTool>>,
|
|
||||||
#[serde(flatten)]
|
|
||||||
extra: HashMap<String, serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct AnthropicMessage {
|
|
||||||
role: String,
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct AnthropicTool {
|
|
||||||
name: String,
|
|
||||||
description: String,
|
|
||||||
input_schema: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct AnthropicResponse {
|
|
||||||
id: String,
|
|
||||||
model: String,
|
|
||||||
content: Vec<AnthropicContent>,
|
|
||||||
usage: AnthropicUsage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
|
||||||
enum AnthropicContent {
|
|
||||||
Text { text: String },
|
|
||||||
Thinking { thinking: String },
|
|
||||||
#[serde(rename = "tool_use")]
|
|
||||||
ToolUse {
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
input: serde_json::Value,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct AnthropicUsage {
|
|
||||||
input_tokens: u32,
|
|
||||||
output_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl LLMProvider for AnthropicProvider {
|
|
||||||
async fn chat(
|
|
||||||
&self,
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let url = format!("{}/v1/messages", self.base_url);
|
|
||||||
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
|
|
||||||
|
|
||||||
let tools = request.tools.map(|tools| {
|
|
||||||
tools
|
|
||||||
.iter()
|
|
||||||
.map(|t: &Tool| AnthropicTool {
|
|
||||||
name: t.function.name.clone(),
|
|
||||||
description: t.function.description.clone(),
|
|
||||||
input_schema: t.function.parameters.clone(),
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
});
|
|
||||||
|
|
||||||
let body = AnthropicRequest {
|
|
||||||
model: self.model_id.clone(),
|
|
||||||
messages: request
|
|
||||||
.messages
|
|
||||||
.iter()
|
|
||||||
.map(|m| AnthropicMessage {
|
|
||||||
role: m.role.clone(),
|
|
||||||
content: m.content.clone(),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
max_tokens,
|
|
||||||
temperature: request.temperature.or(self.temperature),
|
|
||||||
tools,
|
|
||||||
extra: self.model_extra.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut req_builder = self
|
|
||||||
.client
|
|
||||||
.post(&url)
|
|
||||||
.header("x-api-key", &self.api_key)
|
|
||||||
.header("anthropic-version", "2023-06-01")
|
|
||||||
.header("Content-Type", "application/json");
|
|
||||||
|
|
||||||
for (key, value) in &self.extra_headers {
|
|
||||||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await?;
|
|
||||||
|
|
||||||
let anthropic_resp: AnthropicResponse = resp.json().await?;
|
|
||||||
|
|
||||||
let mut content = String::new();
|
|
||||||
let mut tool_calls = Vec::new();
|
|
||||||
|
|
||||||
for c in &anthropic_resp.content {
|
|
||||||
match c {
|
|
||||||
AnthropicContent::Text { text } => {
|
|
||||||
if !text.is_empty() {
|
|
||||||
if !content.is_empty() {
|
|
||||||
content.push('\n');
|
|
||||||
}
|
|
||||||
content.push_str(text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AnthropicContent::Thinking { .. } => {}
|
|
||||||
AnthropicContent::ToolUse { id, name, input } => {
|
|
||||||
tool_calls.push(ToolCall {
|
|
||||||
id: id.clone(),
|
|
||||||
name: name.clone(),
|
|
||||||
arguments: input.clone(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ChatCompletionResponse {
|
|
||||||
id: anthropic_resp.id,
|
|
||||||
model: anthropic_resp.model,
|
|
||||||
content,
|
|
||||||
tool_calls,
|
|
||||||
usage: Usage {
|
|
||||||
prompt_tokens: anthropic_resp.usage.input_tokens,
|
|
||||||
completion_tokens: anthropic_resp.usage.output_tokens,
|
|
||||||
total_tokens: anthropic_resp.usage.input_tokens
|
|
||||||
+ anthropic_resp.usage.output_tokens,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ptype(&self) -> &str {
|
|
||||||
"anthropic"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
&self.name
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model_id(&self) -> &str {
|
|
||||||
&self.model_id
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,50 +0,0 @@
|
|||||||
pub mod traits;
|
|
||||||
pub mod openai;
|
|
||||||
pub mod anthropic;
|
|
||||||
|
|
||||||
pub use self::openai::OpenAIProvider;
|
|
||||||
pub use self::anthropic::AnthropicProvider;
|
|
||||||
|
|
||||||
use crate::config::LLMProviderConfig;
|
|
||||||
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
|
|
||||||
|
|
||||||
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
|
||||||
match config.provider_type.as_str() {
|
|
||||||
"openai" => Ok(Box::new(OpenAIProvider::new(
|
|
||||||
config.name,
|
|
||||||
config.api_key,
|
|
||||||
config.base_url,
|
|
||||||
config.extra_headers,
|
|
||||||
config.model_id,
|
|
||||||
config.temperature,
|
|
||||||
config.max_tokens,
|
|
||||||
config.model_extra,
|
|
||||||
))),
|
|
||||||
"anthropic" => Ok(Box::new(AnthropicProvider::new(
|
|
||||||
config.name,
|
|
||||||
config.api_key,
|
|
||||||
config.base_url,
|
|
||||||
config.extra_headers,
|
|
||||||
config.model_id,
|
|
||||||
config.temperature,
|
|
||||||
config.max_tokens,
|
|
||||||
config.model_extra,
|
|
||||||
))),
|
|
||||||
_ => Err(ProviderError::UnknownProviderType(config.provider_type)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum ProviderError {
|
|
||||||
UnknownProviderType(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for ProviderError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
ProviderError::UnknownProviderType(t) => write!(f, "Unknown provider type: {}", t),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for ProviderError {}
|
|
||||||
@ -1,189 +0,0 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use reqwest::Client;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json::json;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
|
||||||
use super::traits::Usage;
|
|
||||||
|
|
||||||
pub struct OpenAIProvider {
|
|
||||||
client: Client,
|
|
||||||
name: String,
|
|
||||||
api_key: String,
|
|
||||||
base_url: String,
|
|
||||||
extra_headers: HashMap<String, String>,
|
|
||||||
model_id: String,
|
|
||||||
temperature: Option<f32>,
|
|
||||||
max_tokens: Option<u32>,
|
|
||||||
model_extra: HashMap<String, serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAIProvider {
|
|
||||||
pub fn new(
|
|
||||||
name: String,
|
|
||||||
api_key: String,
|
|
||||||
base_url: String,
|
|
||||||
extra_headers: HashMap<String, String>,
|
|
||||||
model_id: String,
|
|
||||||
temperature: Option<f32>,
|
|
||||||
max_tokens: Option<u32>,
|
|
||||||
model_extra: HashMap<String, serde_json::Value>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
client: Client::new(),
|
|
||||||
name,
|
|
||||||
api_key,
|
|
||||||
base_url,
|
|
||||||
extra_headers,
|
|
||||||
model_id,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
model_extra,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIResponse {
|
|
||||||
id: String,
|
|
||||||
model: String,
|
|
||||||
choices: Vec<OpenAIChoice>,
|
|
||||||
#[serde(default)]
|
|
||||||
usage: OpenAIUsage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIChoice {
|
|
||||||
message: OpenAIMessage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIMessage {
|
|
||||||
#[serde(default)]
|
|
||||||
content: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
name: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
tool_calls: Vec<OpenAIToolCall>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIToolCall {
|
|
||||||
id: String,
|
|
||||||
#[serde(rename = "function")]
|
|
||||||
function: OAIFunction,
|
|
||||||
#[serde(default)]
|
|
||||||
index: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OAIFunction {
|
|
||||||
name: String,
|
|
||||||
arguments: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Default)]
|
|
||||||
struct OpenAIUsage {
|
|
||||||
#[serde(default)]
|
|
||||||
prompt_tokens: u32,
|
|
||||||
#[serde(default)]
|
|
||||||
completion_tokens: u32,
|
|
||||||
#[serde(default)]
|
|
||||||
total_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl LLMProvider for OpenAIProvider {
|
|
||||||
async fn chat(
|
|
||||||
&self,
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
|
||||||
|
|
||||||
let mut body = json!({
|
|
||||||
"model": self.model_id,
|
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
|
||||||
"max_tokens": request.max_tokens.or(self.max_tokens),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add model extra fields
|
|
||||||
for (key, value) in &self.model_extra {
|
|
||||||
body[key] = value.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(tools) = &request.tools {
|
|
||||||
body["tools"] = json!(tools);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut req_builder = self
|
|
||||||
.client
|
|
||||||
.post(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.header("Content-Type", "application/json");
|
|
||||||
|
|
||||||
for (key, value) in &self.extra_headers {
|
|
||||||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await?;
|
|
||||||
|
|
||||||
let status = resp.status();
|
|
||||||
let text = resp.text().await?;
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
return Err(format!("API error {}: {}", status, text).into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
|
||||||
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
|
|
||||||
|
|
||||||
let content = openai_resp.choices[0]
|
|
||||||
.message
|
|
||||||
.content
|
|
||||||
.as_ref()
|
|
||||||
.unwrap_or(&String::new())
|
|
||||||
.clone();
|
|
||||||
|
|
||||||
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
|
|
||||||
.message
|
|
||||||
.tool_calls
|
|
||||||
.iter()
|
|
||||||
.map(|tc| ToolCall {
|
|
||||||
id: tc.id.clone(),
|
|
||||||
name: tc.function.name.clone(),
|
|
||||||
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(ChatCompletionResponse {
|
|
||||||
id: openai_resp.id,
|
|
||||||
model: openai_resp.model,
|
|
||||||
content,
|
|
||||||
tool_calls,
|
|
||||||
usage: Usage {
|
|
||||||
prompt_tokens: openai_resp.usage.prompt_tokens,
|
|
||||||
completion_tokens: openai_resp.usage.completion_tokens,
|
|
||||||
total_tokens: openai_resp.usage.total_tokens,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ptype(&self) -> &str {
|
|
||||||
"openai"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
&self.name
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model_id(&self) -> &str {
|
|
||||||
&self.model_id
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,67 +0,0 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Message {
|
|
||||||
pub role: String,
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Tool {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub tool_type: String,
|
|
||||||
pub function: ToolFunction,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ToolFunction {
|
|
||||||
pub name: String,
|
|
||||||
pub description: String,
|
|
||||||
pub parameters: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ToolCall {
|
|
||||||
pub id: String,
|
|
||||||
pub name: String,
|
|
||||||
pub arguments: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ChatCompletionRequest {
|
|
||||||
pub messages: Vec<Message>,
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
pub tools: Option<Vec<Tool>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ChatCompletionResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub model: String,
|
|
||||||
pub content: String,
|
|
||||||
pub tool_calls: Vec<ToolCall>,
|
|
||||||
pub usage: Usage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: u32,
|
|
||||||
pub completion_tokens: u32,
|
|
||||||
pub total_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait LLMProvider: Send + Sync {
|
|
||||||
async fn chat(
|
|
||||||
&self,
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
|
|
||||||
fn ptype(&self) -> &str;
|
|
||||||
|
|
||||||
fn name(&self) -> &str;
|
|
||||||
|
|
||||||
fn model_id(&self) -> &str;
|
|
||||||
}
|
|
||||||
@ -1,12 +0,0 @@
|
|||||||
# Copy this file to test.env and fill in your API keys
|
|
||||||
# cp tests/test.env.example tests/test.env
|
|
||||||
|
|
||||||
# Anthropic Configuration
|
|
||||||
ANTHROPIIC_BASE_URL=https://api.anthropic.com/v1
|
|
||||||
ANTHROPIIC_API_KEY=your_anthropic_api_key_here
|
|
||||||
ANTHROPIIC_MODEL_NAME=claude-3-5-sonnet-20241022
|
|
||||||
|
|
||||||
# OpenAI Configuration
|
|
||||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
|
||||||
OPENAI_API_KEY=your_openai_api_key_here
|
|
||||||
OPENAI_MODEL_NAME=gpt-4
|
|
||||||
@ -1,94 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use PicoBot::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
|
|
||||||
use PicoBot::config::{Config, LLMProviderConfig};
|
|
||||||
|
|
||||||
fn load_config() -> Option<LLMProviderConfig> {
|
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
|
||||||
|
|
||||||
let openai_base_url = std::env::var("OPENAI_BASE_URL").ok()?;
|
|
||||||
let openai_api_key = std::env::var("OPENAI_API_KEY").ok()?;
|
|
||||||
let openai_model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string());
|
|
||||||
|
|
||||||
if openai_api_key.contains("your_") {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(LLMProviderConfig {
|
|
||||||
provider_type: "openai".to_string(),
|
|
||||||
name: "test_openai".to_string(),
|
|
||||||
base_url: openai_base_url,
|
|
||||||
api_key: openai_api_key,
|
|
||||||
extra_headers: HashMap::new(),
|
|
||||||
model_id: openai_model,
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(100),
|
|
||||||
model_extra: HashMap::new(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_request(content: &str) -> ChatCompletionRequest {
|
|
||||||
ChatCompletionRequest {
|
|
||||||
messages: vec![Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: content.to_string(),
|
|
||||||
}],
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(100),
|
|
||||||
tools: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[ignore]
|
|
||||||
async fn test_openai_simple_completion() {
|
|
||||||
let config = load_config()
|
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
|
||||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
|
||||||
|
|
||||||
assert!(!response.id.is_empty());
|
|
||||||
assert!(!response.content.is_empty());
|
|
||||||
assert!(response.usage.total_tokens > 0);
|
|
||||||
assert!(response.content.to_lowercase().contains("ok"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[ignore]
|
|
||||||
async fn test_openai_conversation() {
|
|
||||||
let config = load_config()
|
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages: vec![
|
|
||||||
Message { role: "user".to_string(), content: "My name is Alice".to_string() },
|
|
||||||
Message { role: "assistant".to_string(), content: "Hello Alice!".to_string() },
|
|
||||||
Message { role: "user".to_string(), content: "What is my name?".to_string() },
|
|
||||||
],
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(50),
|
|
||||||
tools: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = provider.chat(request).await.unwrap();
|
|
||||||
assert!(response.content.to_lowercase().contains("alice"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[ignore]
|
|
||||||
async fn test_config_load() {
|
|
||||||
// Test that config.json can be loaded and provider config created
|
|
||||||
let config = Config::load("config.json").expect("Failed to load config.json");
|
|
||||||
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
|
||||||
|
|
||||||
assert_eq!(provider_config.provider_type, "openai");
|
|
||||||
assert_eq!(provider_config.name, "aliyun");
|
|
||||||
assert_eq!(provider_config.model_id, "qwen-plus");
|
|
||||||
|
|
||||||
let provider = create_provider(provider_config).expect("Failed to create provider");
|
|
||||||
assert_eq!(provider.ptype(), "openai");
|
|
||||||
assert_eq!(provider.name(), "aliyun");
|
|
||||||
assert_eq!(provider.model_id(), "qwen-plus");
|
|
||||||
}
|
|
||||||
@ -1,65 +0,0 @@
|
|||||||
use PicoBot::providers::{ChatCompletionRequest, Message};
|
|
||||||
|
|
||||||
/// Test that message with special characters is properly escaped
|
|
||||||
#[test]
|
|
||||||
fn test_message_special_characters() {
|
|
||||||
let msg = Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "Hello \"world\"\nNew line\tTab".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&msg).unwrap();
|
|
||||||
let deserialized: Message = serde_json::from_str(&json).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(deserialized.content, "Hello \"world\"\nNew line\tTab");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test that multi-line system prompt is preserved
|
|
||||||
#[test]
|
|
||||||
fn test_multiline_system_prompt() {
|
|
||||||
let messages = vec![
|
|
||||||
Message {
|
|
||||||
role: "system".to_string(),
|
|
||||||
content: "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "Hi".to_string(),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&messages[0]).unwrap();
|
|
||||||
assert!(json.contains("helpful assistant"));
|
|
||||||
assert!(json.contains("rules"));
|
|
||||||
assert!(json.contains("1. Be kind"));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test ChatCompletionRequest serialization (without model field)
|
|
||||||
#[test]
|
|
||||||
fn test_chat_request_serialization() {
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages: vec![
|
|
||||||
Message {
|
|
||||||
role: "system".to_string(),
|
|
||||||
content: "You are helpful".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "Hello".to_string(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: Some(0.7),
|
|
||||||
max_tokens: Some(100),
|
|
||||||
tools: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&request).unwrap();
|
|
||||||
|
|
||||||
// Verify structure
|
|
||||||
assert!(json.contains(r#""role":"system""#));
|
|
||||||
assert!(json.contains(r#""role":"user""#));
|
|
||||||
assert!(json.contains(r#""content":"You are helpful""#));
|
|
||||||
assert!(json.contains(r#""content":"Hello""#));
|
|
||||||
assert!(json.contains(r#""temperature":0.7"#));
|
|
||||||
assert!(json.contains(r#""max_tokens":100"#));
|
|
||||||
}
|
|
||||||
@ -1,147 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use PicoBot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
|
||||||
use PicoBot::config::LLMProviderConfig;
|
|
||||||
|
|
||||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
|
||||||
|
|
||||||
let openai_base_url = std::env::var("OPENAI_BASE_URL").ok()?;
|
|
||||||
let openai_api_key = std::env::var("OPENAI_API_KEY").ok()?;
|
|
||||||
let openai_model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string());
|
|
||||||
|
|
||||||
if openai_api_key.contains("your_") {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(LLMProviderConfig {
|
|
||||||
provider_type: "openai".to_string(),
|
|
||||||
name: "test_openai".to_string(),
|
|
||||||
base_url: openai_base_url,
|
|
||||||
api_key: openai_api_key,
|
|
||||||
extra_headers: HashMap::new(),
|
|
||||||
model_id: openai_model,
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(100),
|
|
||||||
model_extra: HashMap::new(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn make_weather_tool() -> Tool {
|
|
||||||
Tool {
|
|
||||||
tool_type: "function".to_string(),
|
|
||||||
function: ToolFunction {
|
|
||||||
name: "get_weather".to_string(),
|
|
||||||
description: "Get current weather for a city".to_string(),
|
|
||||||
parameters: serde_json::json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The city name"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"]
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[ignore]
|
|
||||||
async fn test_openai_tool_call() {
|
|
||||||
let config = load_openai_config()
|
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages: vec![Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "What is the weather in Tokyo?".to_string(),
|
|
||||||
}],
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(200),
|
|
||||||
tools: Some(vec![make_weather_tool()]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = provider.chat(request).await.unwrap();
|
|
||||||
|
|
||||||
// Should have tool calls
|
|
||||||
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
|
||||||
|
|
||||||
let tool_call = &response.tool_calls[0];
|
|
||||||
assert_eq!(tool_call.name, "get_weather");
|
|
||||||
assert!(tool_call.arguments.get("city").is_some());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[ignore]
|
|
||||||
async fn test_openai_tool_call_with_manual_execution() {
|
|
||||||
let config = load_openai_config()
|
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
|
||||||
|
|
||||||
// First request with tool
|
|
||||||
let request1 = ChatCompletionRequest {
|
|
||||||
messages: vec![Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "What is the weather in Tokyo?".to_string(),
|
|
||||||
}],
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(200),
|
|
||||||
tools: Some(vec![make_weather_tool()]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let response1 = provider.chat(request1).await.unwrap();
|
|
||||||
let tool_call = response1.tool_calls.first()
|
|
||||||
.expect("Expected tool call");
|
|
||||||
assert_eq!(tool_call.name, "get_weather");
|
|
||||||
|
|
||||||
// Second request with tool result
|
|
||||||
let request2 = ChatCompletionRequest {
|
|
||||||
messages: vec![
|
|
||||||
Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "What is the weather in Tokyo?".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: r#"I'll check the weather for you using the get_weather tool."#.to_string(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(200),
|
|
||||||
tools: Some(vec![make_weather_tool()]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let response2 = provider.chat(request2).await.unwrap();
|
|
||||||
|
|
||||||
// Should have a response
|
|
||||||
assert!(!response2.content.is_empty() || !response2.tool_calls.is_empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[ignore]
|
|
||||||
async fn test_openai_no_tool_when_not_provided() {
|
|
||||||
let config = load_openai_config()
|
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages: vec![Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "Say hello in one word.".to_string(),
|
|
||||||
}],
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(10),
|
|
||||||
tools: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = provider.chat(request).await.unwrap();
|
|
||||||
|
|
||||||
// Should NOT have tool calls
|
|
||||||
assert!(response.tool_calls.is_empty());
|
|
||||||
assert!(!response.content.is_empty());
|
|
||||||
}
|
|
||||||
Loading…
x
Reference in New Issue
Block a user