添加 Feishu 通道支持,重构配置以包含通道设置,更新依赖项,增强错误处理
This commit is contained in:
parent
35d201f206
commit
04736f9f46
@ -4,17 +4,18 @@ version = "0.1.0"
|
|||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls"] }
|
||||||
dotenv = "0.15"
|
dotenv = "0.15"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
regex = "1.0"
|
regex = "1.0"
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
thiserror = "1.0"
|
thiserror = "2.0.18"
|
||||||
tokio = { version = "1.0", features = ["full"] }
|
tokio = { version = "1.0", features = ["full"] }
|
||||||
uuid = { version = "1.0", features = ["v4"] }
|
uuid = { version = "1.0", features = ["v4"] }
|
||||||
axum = { version = "0.8", features = ["ws"] }
|
axum = { version = "0.8", features = ["ws"] }
|
||||||
tokio-tungstenite = "0.26"
|
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
clap = { version = "4", features = ["derive"] }
|
clap = { version = "4", features = ["derive"] }
|
||||||
dirs = "5"
|
dirs = "6.0.0"
|
||||||
|
prost = "0.14"
|
||||||
|
|||||||
50
src/channels/base.rs
Normal file
50
src/channels/base.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct InboundMessage {
|
||||||
|
pub channel: String,
|
||||||
|
pub sender_id: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub content: String,
|
||||||
|
pub media: Vec<String>,
|
||||||
|
pub metadata: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct OutboundMessage {
|
||||||
|
pub channel: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub content: String,
|
||||||
|
pub media: Vec<String>,
|
||||||
|
pub metadata: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ChannelError {
|
||||||
|
ConfigError(String),
|
||||||
|
ConnectionError(String),
|
||||||
|
SendError(String),
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ChannelError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
ChannelError::ConfigError(s) => write!(f, "Config error: {}", s),
|
||||||
|
ChannelError::ConnectionError(s) => write!(f, "Connection error: {}", s),
|
||||||
|
ChannelError::SendError(s) => write!(f, "Send error: {}", s),
|
||||||
|
ChannelError::Other(s) => write!(f, "Error: {}", s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for ChannelError {}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Channel: Send + Sync + 'static {
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
async fn start(&self) -> Result<(), ChannelError>;
|
||||||
|
async fn stop(&self) -> Result<(), ChannelError>;
|
||||||
|
fn is_running(&self) -> bool;
|
||||||
|
}
|
||||||
600
src/channels/feishu.rs
Normal file
600
src/channels/feishu.rs
Normal file
@ -0,0 +1,600 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use tokio::sync::{broadcast, RwLock, Mutex};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use futures_util::{SinkExt, StreamExt};
|
||||||
|
use prost::{Message as ProstMessage, bytes::Bytes};
|
||||||
|
|
||||||
|
use crate::agent::AgentLoop;
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
|
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
||||||
|
|
||||||
|
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
|
||||||
|
const FEISHU_WS_BASE: &str = "https://open.feishu.cn";
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Protobuf types for Feishu WebSocket protocol (pbbp2.proto)
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Clone, PartialEq, prost::Message)]
|
||||||
|
struct PbHeader {
|
||||||
|
#[prost(string, tag = "1")]
|
||||||
|
pub key: String,
|
||||||
|
#[prost(string, tag = "2")]
|
||||||
|
pub value: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feishu WS frame.
|
||||||
|
/// method=0 → CONTROL (ping/pong) method=1 → DATA (events)
|
||||||
|
#[derive(Clone, PartialEq, prost::Message)]
|
||||||
|
struct PbFrame {
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub seq_id: u64,
|
||||||
|
#[prost(uint64, tag = "2")]
|
||||||
|
pub log_id: u64,
|
||||||
|
#[prost(int32, tag = "3")]
|
||||||
|
pub service: i32,
|
||||||
|
#[prost(int32, tag = "4")]
|
||||||
|
pub method: i32,
|
||||||
|
#[prost(message, repeated, tag = "5")]
|
||||||
|
pub headers: Vec<PbHeader>,
|
||||||
|
#[prost(bytes = "vec", optional, tag = "8")]
|
||||||
|
pub payload: Option<Vec<u8>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /callback/ws/endpoint response
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct WsEndpointResp {
|
||||||
|
code: i32,
|
||||||
|
msg: Option<String>,
|
||||||
|
data: Option<WsEndpoint>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct WsEndpoint {
|
||||||
|
#[serde(rename = "URL")]
|
||||||
|
url: String,
|
||||||
|
#[serde(default)]
|
||||||
|
client_config: Option<WsClientConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Default)]
|
||||||
|
struct WsClientConfig {
|
||||||
|
#[serde(rename = "PingInterval")]
|
||||||
|
ping_interval: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lark event envelope (method=1 / type=event payload)
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct LarkEvent {
|
||||||
|
header: LarkEventHeader,
|
||||||
|
event: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct LarkEventHeader {
|
||||||
|
event_type: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
event_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for LarkEventHeader {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("LarkEventHeader")
|
||||||
|
.field("event_type", &self.event_type)
|
||||||
|
.field("event_id", &self.event_id)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct MsgReceivePayload {
|
||||||
|
sender: LarkSender,
|
||||||
|
message: LarkMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct LarkSender {
|
||||||
|
sender_id: LarkSenderId,
|
||||||
|
#[serde(default)]
|
||||||
|
sender_type: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Default)]
|
||||||
|
struct LarkSenderId {
|
||||||
|
open_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
struct LarkMessage {
|
||||||
|
message_id: String,
|
||||||
|
chat_id: String,
|
||||||
|
chat_type: String,
|
||||||
|
message_type: String,
|
||||||
|
#[serde(default)]
|
||||||
|
content: String,
|
||||||
|
#[serde(default)]
|
||||||
|
mentions: Vec<serde_json::Value>,
|
||||||
|
#[serde(default)]
|
||||||
|
root_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
parent_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct FeishuChannel {
|
||||||
|
config: FeishuChannelConfig,
|
||||||
|
http_client: reqwest::Client,
|
||||||
|
running: Arc<RwLock<bool>>,
|
||||||
|
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
|
||||||
|
connected: Arc<RwLock<bool>>,
|
||||||
|
/// Dedup: message_id -> timestamp (cleaned after 30 min)
|
||||||
|
seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||||
|
/// Agent for processing messages
|
||||||
|
agent: Arc<Mutex<AgentLoop>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FeishuChannel {
|
||||||
|
pub fn new(config: FeishuChannelConfig, provider_config: LLMProviderConfig) -> Result<Self, ChannelError> {
|
||||||
|
let agent = AgentLoop::new(provider_config)
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to create agent: {}", e)))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
config,
|
||||||
|
http_client: reqwest::Client::new(),
|
||||||
|
running: Arc::new(RwLock::new(false)),
|
||||||
|
shutdown_tx: Arc::new(RwLock::new(None)),
|
||||||
|
connected: Arc::new(RwLock::new(false)),
|
||||||
|
seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
agent: Arc::new(Mutex::new(agent)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get WebSocket endpoint URL from Feishu API
|
||||||
|
async fn get_ws_endpoint(&self, client: &reqwest::Client) -> Result<(String, WsClientConfig), ChannelError> {
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{}/callback/ws/endpoint", FEISHU_WS_BASE))
|
||||||
|
.header("locale", "zh")
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"AppID": self.config.app_id,
|
||||||
|
"AppSecret": self.config.app_secret,
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
let endpoint_resp: WsEndpointResp = resp
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Failed to parse endpoint response: {}", e)))?;
|
||||||
|
|
||||||
|
if endpoint_resp.code != 0 {
|
||||||
|
return Err(ChannelError::ConnectionError(format!(
|
||||||
|
"WS endpoint failed: code={} msg={}",
|
||||||
|
endpoint_resp.code,
|
||||||
|
endpoint_resp.msg.as_deref().unwrap_or("unknown")
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ep = endpoint_resp.data
|
||||||
|
.ok_or_else(|| ChannelError::ConnectionError("Empty endpoint data".to_string()))?;
|
||||||
|
|
||||||
|
let client_config = ep.client_config.unwrap_or_default();
|
||||||
|
Ok((ep.url, client_config))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get tenant access token
|
||||||
|
async fn get_tenant_token(&self) -> Result<String, ChannelError> {
|
||||||
|
let resp = self.http_client
|
||||||
|
.post(format!("{}/auth/v3/tenant_access_token/internal", FEISHU_API_BASE))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"app_id": self.config.app_id,
|
||||||
|
"app_secret": self.config.app_secret,
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct TokenResponse {
|
||||||
|
code: i32,
|
||||||
|
tenant_access_token: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_resp: TokenResponse = resp
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to parse token response: {}", e)))?;
|
||||||
|
|
||||||
|
if token_resp.code != 0 {
|
||||||
|
return Err(ChannelError::Other("Auth failed".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
token_resp.tenant_access_token
|
||||||
|
.ok_or_else(|| ChannelError::Other("No token in response".to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a text message to Feishu chat
|
||||||
|
async fn send_message(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
|
||||||
|
let token = self.get_tenant_token().await?;
|
||||||
|
|
||||||
|
// For text message, content should be a JSON string: "{\"text\":\"hello\"}"
|
||||||
|
let text_content = serde_json::json!({ "text": content }).to_string();
|
||||||
|
|
||||||
|
let resp = self.http_client
|
||||||
|
.post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"receive_id": receive_id,
|
||||||
|
"msg_type": "text",
|
||||||
|
"content": text_content
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Send message HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct SendResp {
|
||||||
|
code: i32,
|
||||||
|
msg: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
let send_resp: SendResp = resp
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?;
|
||||||
|
|
||||||
|
if send_resp.code != 0 {
|
||||||
|
return Err(ChannelError::Other(format!("Send message failed: code={} msg={}", send_resp.code, send_resp.msg)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle incoming message - process through agent and send response
|
||||||
|
async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> {
|
||||||
|
println!("Feishu: processing message from {} in chat {}: {}", open_id, chat_id, content);
|
||||||
|
|
||||||
|
// Process through agent
|
||||||
|
let user_msg = ChatMessage::user(content);
|
||||||
|
let mut agent = self.agent.lock().await;
|
||||||
|
let response = agent.process(user_msg).await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Agent error: {}", e)))?;
|
||||||
|
|
||||||
|
// Send response to the chat
|
||||||
|
// Use open_id for p2p chats, chat_id for group chats
|
||||||
|
let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id };
|
||||||
|
let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
||||||
|
|
||||||
|
self.send_message(receive_id, receive_id_type, &response.content).await?;
|
||||||
|
println!("Feishu: sent response to {}", receive_id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract service_id from WebSocket URL query params
|
||||||
|
fn extract_service_id(url: &str) -> i32 {
|
||||||
|
url.split('?')
|
||||||
|
.nth(1)
|
||||||
|
.and_then(|qs| {
|
||||||
|
qs.split('&')
|
||||||
|
.find(|kv| kv.starts_with("service_id="))
|
||||||
|
.and_then(|kv| kv.split('=').nth(1))
|
||||||
|
.and_then(|v| v.parse::<i32>().ok())
|
||||||
|
})
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle incoming binary PbFrame - returns Some(message_id) if we need to ack
|
||||||
|
async fn handle_frame(&self, frame: &PbFrame) -> Result<Option<String>, ChannelError> {
|
||||||
|
// method 0 = CONTROL (ping/pong)
|
||||||
|
if frame.method == 0 {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// method 1 = DATA (events)
|
||||||
|
if frame.method != 1 {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload = frame.payload.as_deref()
|
||||||
|
.ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?;
|
||||||
|
|
||||||
|
// Parse the event JSON to get event_type from payload header
|
||||||
|
let event: LarkEvent = serde_json::from_slice(payload)
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?;
|
||||||
|
|
||||||
|
let event_type = event.header.event_type.as_str();
|
||||||
|
if event_type != "im.message.receive_v1" {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload_data: MsgReceivePayload = serde_json::from_value(event.event.clone())
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Parse payload error: {}", e)))?;
|
||||||
|
|
||||||
|
// Skip bot messages
|
||||||
|
if payload_data.sender.sender_type == "bot" {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deduplication check with TTL cleanup
|
||||||
|
let message_id = payload_data.message.message_id.clone();
|
||||||
|
{
|
||||||
|
let mut seen = self.seen_ids.write().await;
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
// Clean expired entries (older than 30 min)
|
||||||
|
seen.retain(|_, ts| now.duration_since(*ts).as_secs() < 1800);
|
||||||
|
|
||||||
|
if seen.contains_key(&message_id) {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
seen.insert(message_id.clone(), now);
|
||||||
|
}
|
||||||
|
|
||||||
|
let open_id = payload_data.sender.sender_id.open_id
|
||||||
|
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
|
||||||
|
|
||||||
|
let msg = payload_data.message;
|
||||||
|
let chat_id = msg.chat_id.clone();
|
||||||
|
let msg_type = msg.message_type.as_str();
|
||||||
|
let content = parse_message_content(msg_type, &msg.content);
|
||||||
|
|
||||||
|
// Handle the message - process and send response
|
||||||
|
if let Err(e) = self.handle_message(&open_id, &chat_id, &content).await {
|
||||||
|
eprintln!("Error handling message: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return message_id for ack
|
||||||
|
Ok(Some(message_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send acknowledgment for a message
|
||||||
|
async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> {
|
||||||
|
let mut ack = frame.clone();
|
||||||
|
ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec());
|
||||||
|
ack.headers.push(PbHeader {
|
||||||
|
key: "biz_rt".into(),
|
||||||
|
value: "0".into(),
|
||||||
|
});
|
||||||
|
write.send(tokio_tungstenite::tungstenite::Message::Binary(ack.encode_to_vec().into()))
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to send ack: {}", e)))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_ws_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> {
|
||||||
|
let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?;
|
||||||
|
|
||||||
|
let service_id = Self::extract_service_id(&wss_url);
|
||||||
|
println!("Feishu: connecting to {}", wss_url);
|
||||||
|
|
||||||
|
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)))?;
|
||||||
|
|
||||||
|
*self.connected.write().await = true;
|
||||||
|
println!("Feishu channel connected");
|
||||||
|
|
||||||
|
let (mut write, mut read) = ws_stream.split();
|
||||||
|
|
||||||
|
// Send initial ping
|
||||||
|
let ping_frame = PbFrame {
|
||||||
|
seq_id: 1,
|
||||||
|
log_id: 0,
|
||||||
|
service: service_id,
|
||||||
|
method: 0,
|
||||||
|
headers: vec![PbHeader {
|
||||||
|
key: "type".into(),
|
||||||
|
value: "ping".into(),
|
||||||
|
}],
|
||||||
|
payload: None,
|
||||||
|
};
|
||||||
|
write.send(tokio_tungstenite::tungstenite::Message::Binary(ping_frame.encode_to_vec().into()))
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Failed to send initial ping: {}", e)))?;
|
||||||
|
|
||||||
|
let ping_interval = client_config.ping_interval.unwrap_or(120).max(10);
|
||||||
|
let mut ping_interval_tok = tokio::time::interval(tokio::time::Duration::from_secs(ping_interval));
|
||||||
|
let mut seq: u64 = 1;
|
||||||
|
|
||||||
|
// Consume the immediate tick
|
||||||
|
ping_interval_tok.tick().await;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
msg = read.next() => {
|
||||||
|
match msg {
|
||||||
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
|
||||||
|
let bytes: Bytes = data;
|
||||||
|
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
|
||||||
|
// Handle the frame and get message_id for ack if needed
|
||||||
|
match self.handle_frame(&frame).await {
|
||||||
|
Ok(Some(_message_id)) => {
|
||||||
|
// Send ACK immediately (Feishu requires within 3 s)
|
||||||
|
if let Err(e) = Self::send_ack(&frame, &mut write).await {
|
||||||
|
eprintln!("Error sending ack: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => {}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Error handling frame: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(data))) => {
|
||||||
|
let pong = PbFrame {
|
||||||
|
seq_id: seq.wrapping_add(1),
|
||||||
|
log_id: 0,
|
||||||
|
service: service_id,
|
||||||
|
method: 0,
|
||||||
|
headers: vec![PbHeader {
|
||||||
|
key: "type".into(),
|
||||||
|
value: "pong".into(),
|
||||||
|
}],
|
||||||
|
payload: Some(data.to_vec()),
|
||||||
|
};
|
||||||
|
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
|
||||||
|
}
|
||||||
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Some(Err(e)) => {
|
||||||
|
eprintln!("WS error: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = ping_interval_tok.tick() => {
|
||||||
|
seq = seq.wrapping_add(1);
|
||||||
|
let ping = PbFrame {
|
||||||
|
seq_id: seq,
|
||||||
|
log_id: 0,
|
||||||
|
service: service_id,
|
||||||
|
method: 0,
|
||||||
|
headers: vec![PbHeader {
|
||||||
|
key: "type".into(),
|
||||||
|
value: "ping".into(),
|
||||||
|
}],
|
||||||
|
payload: None,
|
||||||
|
};
|
||||||
|
if write.send(tokio_tungstenite::tungstenite::Message::Binary(ping.encode_to_vec().into())).await.is_err() {
|
||||||
|
eprintln!("Feishu: ping failed, reconnecting");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = shutdown_rx.recv() => {
|
||||||
|
println!("Feishu channel shutdown signal received");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*self.connected.write().await = false;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_message_content(msg_type: &str, content: &str) -> String {
|
||||||
|
match msg_type {
|
||||||
|
"text" => {
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
|
parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string()
|
||||||
|
} else {
|
||||||
|
content.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"post" => {
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
|
let mut texts = vec![];
|
||||||
|
if let Some(post) = parsed.get("post") {
|
||||||
|
if let Some(content_arr) = post.get("content") {
|
||||||
|
if let Some(arr) = content_arr.as_array() {
|
||||||
|
for item in arr {
|
||||||
|
if let Some(arr2) = item.as_array() {
|
||||||
|
for inner in arr2 {
|
||||||
|
if let Some(text) = inner.get("text").and_then(|v| v.as_str()) {
|
||||||
|
texts.push(text.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if texts.is_empty() {
|
||||||
|
content.to_string()
|
||||||
|
} else {
|
||||||
|
texts.join("")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => content.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Channel for FeishuChannel {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"feishu"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start(&self) -> Result<(), ChannelError> {
|
||||||
|
if self.config.app_id.is_empty() || self.config.app_secret.is_empty() {
|
||||||
|
return Err(ChannelError::ConfigError(
|
||||||
|
"Feishu app_id or app_secret is not configured".to_string()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
*self.running.write().await = true;
|
||||||
|
|
||||||
|
let (shutdown_tx, _) = broadcast::channel(1);
|
||||||
|
*self.shutdown_tx.write().await = Some(shutdown_tx.clone());
|
||||||
|
|
||||||
|
let channel = self.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut consecutive_failures = 0;
|
||||||
|
let max_failures = 3;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if !*channel.running.read().await {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let shutdown_rx = shutdown_tx.subscribe();
|
||||||
|
match channel.run_ws_loop(shutdown_rx).await {
|
||||||
|
Ok(_) => {
|
||||||
|
println!("Feishu WebSocket disconnected");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
consecutive_failures += 1;
|
||||||
|
eprintln!("Feishu WebSocket error (attempt {}): {}", consecutive_failures, e);
|
||||||
|
if consecutive_failures >= max_failures {
|
||||||
|
eprintln!("Feishu channel: max failures reached, stopping");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !*channel.running.read().await {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("Feishu channel retrying in 5s...");
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
*channel.running.write().await = false;
|
||||||
|
println!("Feishu channel stopped");
|
||||||
|
});
|
||||||
|
|
||||||
|
println!("Feishu channel started");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stop(&self) -> Result<(), ChannelError> {
|
||||||
|
*self.running.write().await = false;
|
||||||
|
|
||||||
|
if let Some(tx) = self.shutdown_tx.write().await.take() {
|
||||||
|
let _ = tx.send(());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_running(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
78
src/channels/manager.rs
Normal file
78
src/channels/manager.rs
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
|
use crate::channels::feishu::FeishuChannel;
|
||||||
|
use crate::config::Config;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ChannelManager {
|
||||||
|
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelManager {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn init(&self, config: &Config) -> Result<(), ChannelError> {
|
||||||
|
// Initialize Feishu channel if enabled
|
||||||
|
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||||
|
if feishu_config.enabled {
|
||||||
|
let agent_name = &feishu_config.agent;
|
||||||
|
let provider_config = config.get_provider_config(agent_name)
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to get provider config: {}", e)))?;
|
||||||
|
|
||||||
|
let channel = FeishuChannel::new(feishu_config.clone(), provider_config)
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
||||||
|
|
||||||
|
self.channels
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert("feishu".to_string(), Arc::new(channel));
|
||||||
|
println!("Feishu channel registered");
|
||||||
|
} else {
|
||||||
|
println!("Feishu channel disabled in config");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_all(&self) -> Result<(), ChannelError> {
|
||||||
|
let channels = self.channels.read().await;
|
||||||
|
for (name, channel) in channels.iter() {
|
||||||
|
println!("Starting channel: {}", name);
|
||||||
|
if let Err(e) = channel.start().await {
|
||||||
|
eprintln!("Warning: Failed to start channel {}: {}", name, e);
|
||||||
|
// Channel failed to start - it should have logged why
|
||||||
|
// Continue starting other channels
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop_all(&self) -> Result<(), ChannelError> {
|
||||||
|
let mut channels = self.channels.write().await;
|
||||||
|
for (name, channel) in channels.iter() {
|
||||||
|
println!("Stopping channel: {}", name);
|
||||||
|
if let Err(e) = channel.stop().await {
|
||||||
|
eprintln!("Error stopping channel {}: {}", name, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
channels.clear();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel>> {
|
||||||
|
self.channels.read().await.get(name).cloned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ChannelManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
7
src/channels/mod.rs
Normal file
7
src/channels/mod.rs
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
pub mod base;
|
||||||
|
pub mod feishu;
|
||||||
|
pub mod manager;
|
||||||
|
|
||||||
|
pub use base::{Channel, ChannelError, InboundMessage, OutboundMessage};
|
||||||
|
pub use manager::ChannelManager;
|
||||||
|
pub use feishu::FeishuChannel;
|
||||||
@ -14,6 +14,24 @@ pub struct Config {
|
|||||||
pub gateway: GatewayConfig,
|
pub gateway: GatewayConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub client: ClientConfig,
|
pub client: ClientConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
pub channels: HashMap<String, FeishuChannelConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct FeishuChannelConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub enabled: bool,
|
||||||
|
pub app_id: String,
|
||||||
|
pub app_secret: String,
|
||||||
|
#[serde(default = "default_allow_from")]
|
||||||
|
pub allow_from: Vec<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub agent: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_allow_from() -> Vec<String> {
|
||||||
|
vec!["*".to_string()]
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@ -117,11 +135,13 @@ impl Config {
|
|||||||
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
|
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
load_env_file()?;
|
load_env_file()?;
|
||||||
let content = if path.exists() {
|
let content = if path.exists() {
|
||||||
|
println!("Config loaded from: {}", path.display());
|
||||||
fs::read_to_string(path)?
|
fs::read_to_string(path)?
|
||||||
} else {
|
} else {
|
||||||
// Fallback to current directory
|
// Fallback to current directory
|
||||||
let fallback = Path::new("config.json");
|
let fallback = Path::new("config.json");
|
||||||
if fallback.exists() {
|
if fallback.exists() {
|
||||||
|
println!("Config loaded from: {}", fallback.display());
|
||||||
fs::read_to_string(fallback)?
|
fs::read_to_string(fallback)?
|
||||||
} else {
|
} else {
|
||||||
return Err(Box::new(ConfigError::ConfigNotFound(
|
return Err(Box::new(ConfigError::ConfigNotFound(
|
||||||
|
|||||||
@ -6,20 +6,24 @@ use std::sync::Arc;
|
|||||||
use axum::{routing, Router};
|
use axum::{routing, Router};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
use crate::channels::ChannelManager;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use session::SessionManager;
|
use session::SessionManager;
|
||||||
|
|
||||||
pub struct GatewayState {
|
pub struct GatewayState {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
pub session_manager: SessionManager,
|
pub session_manager: SessionManager,
|
||||||
|
pub channel_manager: ChannelManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GatewayState {
|
impl GatewayState {
|
||||||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
let config = Config::load_default()?;
|
let config = Config::load_default()?;
|
||||||
|
let channel_manager = ChannelManager::new();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
config,
|
config,
|
||||||
session_manager: SessionManager::new(),
|
session_manager: SessionManager::new(),
|
||||||
|
channel_manager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -27,6 +31,10 @@ impl GatewayState {
|
|||||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let state = Arc::new(GatewayState::new()?);
|
let state = Arc::new(GatewayState::new()?);
|
||||||
|
|
||||||
|
// Initialize and start channels
|
||||||
|
state.channel_manager.init(&state.config).await?;
|
||||||
|
state.channel_manager.start_all().await?;
|
||||||
|
|
||||||
// CLI args override config file values
|
// CLI args override config file values
|
||||||
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
|
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
|
||||||
let bind_port = port.unwrap_or(state.config.gateway.port);
|
let bind_port = port.unwrap_or(state.config.gateway.port);
|
||||||
@ -34,11 +42,30 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
|
|||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/health", routing::get(http::health))
|
.route("/health", routing::get(http::health))
|
||||||
.route("/ws", routing::get(ws::ws_handler))
|
.route("/ws", routing::get(ws::ws_handler))
|
||||||
.with_state(state);
|
.with_state(state.clone());
|
||||||
|
|
||||||
let addr = format!("{}:{}", bind_host, bind_port);
|
let addr = format!("{}:{}", bind_host, bind_port);
|
||||||
let listener = TcpListener::bind(&addr).await?;
|
let listener = TcpListener::bind(&addr).await?;
|
||||||
println!("Gateway listening on {}", addr);
|
println!("Gateway listening on {}", addr);
|
||||||
axum::serve(listener, app).await?;
|
|
||||||
|
// Graceful shutdown using oneshot channel
|
||||||
|
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
||||||
|
let channel_manager = state.channel_manager.clone();
|
||||||
|
|
||||||
|
// Spawn ctrl_c handler
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tokio::signal::ctrl_c().await.ok();
|
||||||
|
println!("Shutting down...");
|
||||||
|
let _ = channel_manager.stop_all().await;
|
||||||
|
let _ = shutdown_tx.send(());
|
||||||
|
});
|
||||||
|
|
||||||
|
// Serve with graceful shutdown
|
||||||
|
axum::serve(listener, app)
|
||||||
|
.with_graceful_shutdown(async {
|
||||||
|
shutdown_rx.await.ok();
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,3 +6,4 @@ pub mod agent;
|
|||||||
pub mod gateway;
|
pub mod gateway;
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
|
pub mod channels;
|
||||||
|
|||||||
11
src/main.rs
11
src/main.rs
@ -32,18 +32,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load config for defaults
|
|
||||||
let config = match picobot::config::Config::load_default() {
|
|
||||||
Ok(cfg) => Some(cfg),
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!("Warning: Could not load config from ~/.config/picobot/config.json: {}", e);
|
|
||||||
eprintln!("Using built-in defaults. Run `picobot gateway` or `picobot agent` with --help for options.\n");
|
|
||||||
None
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match Command::parse() {
|
match Command::parse() {
|
||||||
Command::Agent { gateway_url } => {
|
Command::Agent { gateway_url } => {
|
||||||
|
let config = picobot::config::Config::load_default().ok();
|
||||||
let url = gateway_url
|
let url = gateway_url
|
||||||
.or_else(|| config.as_ref().map(|c| c.client.gateway_url.clone()))
|
.or_else(|| config.as_ref().map(|c| c.client.gateway_url.clone()))
|
||||||
.unwrap_or_else(|| "ws://127.0.0.1:19876/ws".to_string());
|
.unwrap_or_else(|| "ws://127.0.0.1:19876/ws".to_string());
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user