From 04736f9f46217e1ece03d5097e70a09cecc55448 Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Mon, 6 Apr 2026 18:43:53 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20Feishu=20=E9=80=9A?= =?UTF-8?q?=E9=81=93=E6=94=AF=E6=8C=81=EF=BC=8C=E9=87=8D=E6=9E=84=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E4=BB=A5=E5=8C=85=E5=90=AB=E9=80=9A=E9=81=93=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=EF=BC=8C=E6=9B=B4=E6=96=B0=E4=BE=9D=E8=B5=96=E9=A1=B9?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=BC=BA=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 9 +- src/channels/base.rs | 50 ++++ src/channels/feishu.rs | 600 ++++++++++++++++++++++++++++++++++++++++ src/channels/manager.rs | 78 ++++++ src/channels/mod.rs | 7 + src/config/mod.rs | 20 ++ src/gateway/mod.rs | 31 ++- src/lib.rs | 1 + src/main.rs | 11 +- 9 files changed, 791 insertions(+), 16 deletions(-) create mode 100644 src/channels/base.rs create mode 100644 src/channels/feishu.rs create mode 100644 src/channels/manager.rs create mode 100644 src/channels/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 63fd16e..af5c027 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,17 +4,18 @@ version = "0.1.0" edition = "2024" [dependencies] -reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls"] } dotenv = "0.15" serde = { version = "1.0", features = ["derive"] } regex = "1.0" serde_json = "1.0" async-trait = "0.1" -thiserror = "1.0" +thiserror = "2.0.18" tokio = { version = "1.0", features = ["full"] } uuid = { version = "1.0", features = ["v4"] } axum = { version = "0.8", features = ["ws"] } -tokio-tungstenite = "0.26" +tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] } futures-util = "0.3" clap = { version = "4", features = ["derive"] } -dirs = "5" +dirs = "6.0.0" +prost = "0.14" diff --git a/src/channels/base.rs b/src/channels/base.rs new file mode 100644 index 0000000..5fec82b --- /dev/null +++ b/src/channels/base.rs @@ -0,0 +1,50 @@ +use std::collections::HashMap; +use async_trait::async_trait; + +#[derive(Debug, Clone)] +pub struct InboundMessage { + pub channel: String, + pub sender_id: String, + pub chat_id: String, + pub content: String, + pub media: Vec, + pub metadata: HashMap, +} + +#[derive(Debug, Clone)] +pub struct OutboundMessage { + pub channel: String, + pub chat_id: String, + pub content: String, + pub media: Vec, + pub metadata: HashMap, +} + +#[derive(Debug)] +pub enum ChannelError { + ConfigError(String), + ConnectionError(String), + SendError(String), + Other(String), +} + +impl std::fmt::Display for ChannelError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ChannelError::ConfigError(s) => write!(f, "Config error: {}", s), + ChannelError::ConnectionError(s) => write!(f, "Connection error: {}", s), + ChannelError::SendError(s) => write!(f, "Send error: {}", s), + ChannelError::Other(s) => write!(f, "Error: {}", s), + } + } +} + +impl std::error::Error for ChannelError {} + +#[async_trait] +pub trait Channel: Send + Sync + 'static { + fn name(&self) -> &str; + async fn start(&self) -> Result<(), ChannelError>; + async fn stop(&self) -> Result<(), ChannelError>; + fn is_running(&self) -> bool; +} diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs new file mode 100644 index 0000000..9f2aac6 --- /dev/null +++ b/src/channels/feishu.rs @@ -0,0 +1,600 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Instant; +use async_trait::async_trait; +use tokio::sync::{broadcast, RwLock, Mutex}; +use serde::Deserialize; +use futures_util::{SinkExt, StreamExt}; +use prost::{Message as ProstMessage, bytes::Bytes}; + +use crate::agent::AgentLoop; +use crate::bus::ChatMessage; +use crate::channels::base::{Channel, ChannelError}; +use crate::config::{FeishuChannelConfig, LLMProviderConfig}; + +const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis"; +const FEISHU_WS_BASE: &str = "https://open.feishu.cn"; + +// ───────────────────────────────────────────────────────────────────────────── +// Protobuf types for Feishu WebSocket protocol (pbbp2.proto) +// ───────────────────────────────────────────────────────────────────────────── + +#[derive(Clone, PartialEq, prost::Message)] +struct PbHeader { + #[prost(string, tag = "1")] + pub key: String, + #[prost(string, tag = "2")] + pub value: String, +} + +/// Feishu WS frame. +/// method=0 → CONTROL (ping/pong) method=1 → DATA (events) +#[derive(Clone, PartialEq, prost::Message)] +struct PbFrame { + #[prost(uint64, tag = "1")] + pub seq_id: u64, + #[prost(uint64, tag = "2")] + pub log_id: u64, + #[prost(int32, tag = "3")] + pub service: i32, + #[prost(int32, tag = "4")] + pub method: i32, + #[prost(message, repeated, tag = "5")] + pub headers: Vec, + #[prost(bytes = "vec", optional, tag = "8")] + pub payload: Option>, +} + +/// POST /callback/ws/endpoint response +#[derive(Deserialize)] +struct WsEndpointResp { + code: i32, + msg: Option, + data: Option, +} + +#[derive(Deserialize)] +struct WsEndpoint { + #[serde(rename = "URL")] + url: String, + #[serde(default)] + client_config: Option, +} + +#[derive(Deserialize, Default)] +struct WsClientConfig { + #[serde(rename = "PingInterval")] + ping_interval: Option, +} + +/// Lark event envelope (method=1 / type=event payload) +#[derive(Deserialize)] +struct LarkEvent { + header: LarkEventHeader, + event: serde_json::Value, +} + +#[derive(Deserialize)] +struct LarkEventHeader { + event_type: String, + #[allow(dead_code)] + event_id: String, +} + +impl std::fmt::Debug for LarkEventHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LarkEventHeader") + .field("event_type", &self.event_type) + .field("event_id", &self.event_id) + .finish() + } +} + +#[derive(Deserialize)] +struct MsgReceivePayload { + sender: LarkSender, + message: LarkMessage, +} + +#[derive(Deserialize)] +struct LarkSender { + sender_id: LarkSenderId, + #[serde(default)] + sender_type: String, +} + +#[derive(Deserialize, Default)] +struct LarkSenderId { + open_id: Option, +} + +#[derive(Deserialize)] +#[allow(dead_code)] +struct LarkMessage { + message_id: String, + chat_id: String, + chat_type: String, + message_type: String, + #[serde(default)] + content: String, + #[serde(default)] + mentions: Vec, + #[serde(default)] + root_id: Option, + #[serde(default)] + parent_id: Option, +} + +// ───────────────────────────────────────────────────────────────────────────── + +#[derive(Clone)] +pub struct FeishuChannel { + config: FeishuChannelConfig, + http_client: reqwest::Client, + running: Arc>, + shutdown_tx: Arc>>>, + connected: Arc>, + /// Dedup: message_id -> timestamp (cleaned after 30 min) + seen_ids: Arc>>, + /// Agent for processing messages + agent: Arc>, +} + +impl FeishuChannel { + pub fn new(config: FeishuChannelConfig, provider_config: LLMProviderConfig) -> Result { + let agent = AgentLoop::new(provider_config) + .map_err(|e| ChannelError::Other(format!("Failed to create agent: {}", e)))?; + + Ok(Self { + config, + http_client: reqwest::Client::new(), + running: Arc::new(RwLock::new(false)), + shutdown_tx: Arc::new(RwLock::new(None)), + connected: Arc::new(RwLock::new(false)), + seen_ids: Arc::new(RwLock::new(HashMap::new())), + agent: Arc::new(Mutex::new(agent)), + }) + } + + /// Get WebSocket endpoint URL from Feishu API + async fn get_ws_endpoint(&self, client: &reqwest::Client) -> Result<(String, WsClientConfig), ChannelError> { + let resp = client + .post(format!("{}/callback/ws/endpoint", FEISHU_WS_BASE)) + .header("locale", "zh") + .json(&serde_json::json!({ + "AppID": self.config.app_id, + "AppSecret": self.config.app_secret, + })) + .send() + .await + .map_err(|e| ChannelError::ConnectionError(format!("HTTP error: {}", e)))?; + + let endpoint_resp: WsEndpointResp = resp + .json() + .await + .map_err(|e| ChannelError::ConnectionError(format!("Failed to parse endpoint response: {}", e)))?; + + if endpoint_resp.code != 0 { + return Err(ChannelError::ConnectionError(format!( + "WS endpoint failed: code={} msg={}", + endpoint_resp.code, + endpoint_resp.msg.as_deref().unwrap_or("unknown") + ))); + } + + let ep = endpoint_resp.data + .ok_or_else(|| ChannelError::ConnectionError("Empty endpoint data".to_string()))?; + + let client_config = ep.client_config.unwrap_or_default(); + Ok((ep.url, client_config)) + } + + /// Get tenant access token + async fn get_tenant_token(&self) -> Result { + let resp = self.http_client + .post(format!("{}/auth/v3/tenant_access_token/internal", FEISHU_API_BASE)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "app_id": self.config.app_id, + "app_secret": self.config.app_secret, + })) + .send() + .await + .map_err(|e| ChannelError::ConnectionError(format!("HTTP error: {}", e)))?; + + #[derive(Deserialize)] + struct TokenResponse { + code: i32, + tenant_access_token: Option, + } + + let token_resp: TokenResponse = resp + .json() + .await + .map_err(|e| ChannelError::Other(format!("Failed to parse token response: {}", e)))?; + + if token_resp.code != 0 { + return Err(ChannelError::Other("Auth failed".to_string())); + } + + token_resp.tenant_access_token + .ok_or_else(|| ChannelError::Other("No token in response".to_string())) + } + + /// Send a text message to Feishu chat + async fn send_message(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> { + let token = self.get_tenant_token().await?; + + // For text message, content should be a JSON string: "{\"text\":\"hello\"}" + let text_content = serde_json::json!({ "text": content }).to_string(); + + let resp = self.http_client + .post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .json(&serde_json::json!({ + "receive_id": receive_id, + "msg_type": "text", + "content": text_content + })) + .send() + .await + .map_err(|e| ChannelError::ConnectionError(format!("Send message HTTP error: {}", e)))?; + + #[derive(Deserialize)] + struct SendResp { + code: i32, + msg: String, + } + + let send_resp: SendResp = resp + .json() + .await + .map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?; + + if send_resp.code != 0 { + return Err(ChannelError::Other(format!("Send message failed: code={} msg={}", send_resp.code, send_resp.msg))); + } + + Ok(()) + } + + /// Handle incoming message - process through agent and send response + async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> { + println!("Feishu: processing message from {} in chat {}: {}", open_id, chat_id, content); + + // Process through agent + let user_msg = ChatMessage::user(content); + let mut agent = self.agent.lock().await; + let response = agent.process(user_msg).await + .map_err(|e| ChannelError::Other(format!("Agent error: {}", e)))?; + + // Send response to the chat + // Use open_id for p2p chats, chat_id for group chats + let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id }; + let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" }; + + self.send_message(receive_id, receive_id_type, &response.content).await?; + println!("Feishu: sent response to {}", receive_id); + + Ok(()) + } + + /// Extract service_id from WebSocket URL query params + fn extract_service_id(url: &str) -> i32 { + url.split('?') + .nth(1) + .and_then(|qs| { + qs.split('&') + .find(|kv| kv.starts_with("service_id=")) + .and_then(|kv| kv.split('=').nth(1)) + .and_then(|v| v.parse::().ok()) + }) + .unwrap_or(0) + } + + /// Handle incoming binary PbFrame - returns Some(message_id) if we need to ack + async fn handle_frame(&self, frame: &PbFrame) -> Result, ChannelError> { + // method 0 = CONTROL (ping/pong) + if frame.method == 0 { + return Ok(None); + } + + // method 1 = DATA (events) + if frame.method != 1 { + return Ok(None); + } + + let payload = frame.payload.as_deref() + .ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?; + + // Parse the event JSON to get event_type from payload header + let event: LarkEvent = serde_json::from_slice(payload) + .map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?; + + let event_type = event.header.event_type.as_str(); + if event_type != "im.message.receive_v1" { + return Ok(None); + } + + let payload_data: MsgReceivePayload = serde_json::from_value(event.event.clone()) + .map_err(|e| ChannelError::Other(format!("Parse payload error: {}", e)))?; + + // Skip bot messages + if payload_data.sender.sender_type == "bot" { + return Ok(None); + } + + // Deduplication check with TTL cleanup + let message_id = payload_data.message.message_id.clone(); + { + let mut seen = self.seen_ids.write().await; + let now = Instant::now(); + + // Clean expired entries (older than 30 min) + seen.retain(|_, ts| now.duration_since(*ts).as_secs() < 1800); + + if seen.contains_key(&message_id) { + return Ok(None); + } + seen.insert(message_id.clone(), now); + } + + let open_id = payload_data.sender.sender_id.open_id + .ok_or_else(|| ChannelError::Other("No open_id".to_string()))?; + + let msg = payload_data.message; + let chat_id = msg.chat_id.clone(); + let msg_type = msg.message_type.as_str(); + let content = parse_message_content(msg_type, &msg.content); + + // Handle the message - process and send response + if let Err(e) = self.handle_message(&open_id, &chat_id, &content).await { + eprintln!("Error handling message: {}", e); + } + + // Return message_id for ack + Ok(Some(message_id)) + } + + /// Send acknowledgment for a message + async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> { + let mut ack = frame.clone(); + ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec()); + ack.headers.push(PbHeader { + key: "biz_rt".into(), + value: "0".into(), + }); + write.send(tokio_tungstenite::tungstenite::Message::Binary(ack.encode_to_vec().into())) + .await + .map_err(|e| ChannelError::Other(format!("Failed to send ack: {}", e)))?; + Ok(()) + } + + async fn run_ws_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> { + let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?; + + let service_id = Self::extract_service_id(&wss_url); + println!("Feishu: connecting to {}", wss_url); + + let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url) + .await + .map_err(|e| ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)))?; + + *self.connected.write().await = true; + println!("Feishu channel connected"); + + let (mut write, mut read) = ws_stream.split(); + + // Send initial ping + let ping_frame = PbFrame { + seq_id: 1, + log_id: 0, + service: service_id, + method: 0, + headers: vec![PbHeader { + key: "type".into(), + value: "ping".into(), + }], + payload: None, + }; + write.send(tokio_tungstenite::tungstenite::Message::Binary(ping_frame.encode_to_vec().into())) + .await + .map_err(|e| ChannelError::ConnectionError(format!("Failed to send initial ping: {}", e)))?; + + let ping_interval = client_config.ping_interval.unwrap_or(120).max(10); + let mut ping_interval_tok = tokio::time::interval(tokio::time::Duration::from_secs(ping_interval)); + let mut seq: u64 = 1; + + // Consume the immediate tick + ping_interval_tok.tick().await; + + loop { + tokio::select! { + msg = read.next() => { + match msg { + Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => { + let bytes: Bytes = data; + if let Ok(frame) = PbFrame::decode(bytes.as_ref()) { + // Handle the frame and get message_id for ack if needed + match self.handle_frame(&frame).await { + Ok(Some(_message_id)) => { + // Send ACK immediately (Feishu requires within 3 s) + if let Err(e) = Self::send_ack(&frame, &mut write).await { + eprintln!("Error sending ack: {}", e); + } + } + Ok(None) => {} + Err(e) => { + eprintln!("Error handling frame: {}", e); + } + } + } + } + Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(data))) => { + let pong = PbFrame { + seq_id: seq.wrapping_add(1), + log_id: 0, + service: service_id, + method: 0, + headers: vec![PbHeader { + key: "type".into(), + value: "pong".into(), + }], + payload: Some(data.to_vec()), + }; + let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await; + } + Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => { + break; + } + Some(Err(e)) => { + eprintln!("WS error: {}", e); + break; + } + _ => {} + } + } + _ = ping_interval_tok.tick() => { + seq = seq.wrapping_add(1); + let ping = PbFrame { + seq_id: seq, + log_id: 0, + service: service_id, + method: 0, + headers: vec![PbHeader { + key: "type".into(), + value: "ping".into(), + }], + payload: None, + }; + if write.send(tokio_tungstenite::tungstenite::Message::Binary(ping.encode_to_vec().into())).await.is_err() { + eprintln!("Feishu: ping failed, reconnecting"); + break; + } + } + _ = shutdown_rx.recv() => { + println!("Feishu channel shutdown signal received"); + break; + } + } + } + + *self.connected.write().await = false; + Ok(()) + } +} + +fn parse_message_content(msg_type: &str, content: &str) -> String { + match msg_type { + "text" => { + if let Ok(parsed) = serde_json::from_str::(content) { + parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string() + } else { + content.to_string() + } + } + "post" => { + if let Ok(parsed) = serde_json::from_str::(content) { + let mut texts = vec![]; + if let Some(post) = parsed.get("post") { + if let Some(content_arr) = post.get("content") { + if let Some(arr) = content_arr.as_array() { + for item in arr { + if let Some(arr2) = item.as_array() { + for inner in arr2 { + if let Some(text) = inner.get("text").and_then(|v| v.as_str()) { + texts.push(text.to_string()); + } + } + } + } + } + } + } + if texts.is_empty() { + content.to_string() + } else { + texts.join("") + } + } else { + content.to_string() + } + } + _ => content.to_string(), + } +} + +#[async_trait] +impl Channel for FeishuChannel { + fn name(&self) -> &str { + "feishu" + } + + async fn start(&self) -> Result<(), ChannelError> { + if self.config.app_id.is_empty() || self.config.app_secret.is_empty() { + return Err(ChannelError::ConfigError( + "Feishu app_id or app_secret is not configured".to_string() + )); + } + + *self.running.write().await = true; + + let (shutdown_tx, _) = broadcast::channel(1); + *self.shutdown_tx.write().await = Some(shutdown_tx.clone()); + + let channel = self.clone(); + tokio::spawn(async move { + let mut consecutive_failures = 0; + let max_failures = 3; + + loop { + if !*channel.running.read().await { + break; + } + + let shutdown_rx = shutdown_tx.subscribe(); + match channel.run_ws_loop(shutdown_rx).await { + Ok(_) => { + println!("Feishu WebSocket disconnected"); + } + Err(e) => { + consecutive_failures += 1; + eprintln!("Feishu WebSocket error (attempt {}): {}", consecutive_failures, e); + if consecutive_failures >= max_failures { + eprintln!("Feishu channel: max failures reached, stopping"); + break; + } + } + } + + if !*channel.running.read().await { + break; + } + + println!("Feishu channel retrying in 5s..."); + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + } + + *channel.running.write().await = false; + println!("Feishu channel stopped"); + }); + + println!("Feishu channel started"); + Ok(()) + } + + async fn stop(&self) -> Result<(), ChannelError> { + *self.running.write().await = false; + + if let Some(tx) = self.shutdown_tx.write().await.take() { + let _ = tx.send(()); + } + + Ok(()) + } + + fn is_running(&self) -> bool { + false + } +} diff --git a/src/channels/manager.rs b/src/channels/manager.rs new file mode 100644 index 0000000..f048c37 --- /dev/null +++ b/src/channels/manager.rs @@ -0,0 +1,78 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +use crate::channels::base::{Channel, ChannelError}; +use crate::channels::feishu::FeishuChannel; +use crate::config::Config; + +#[derive(Clone)] +pub struct ChannelManager { + channels: Arc>>>, +} + +impl ChannelManager { + pub fn new() -> Self { + Self { + channels: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn init(&self, config: &Config) -> Result<(), ChannelError> { + // Initialize Feishu channel if enabled + if let Some(feishu_config) = config.channels.get("feishu") { + if feishu_config.enabled { + let agent_name = &feishu_config.agent; + let provider_config = config.get_provider_config(agent_name) + .map_err(|e| ChannelError::Other(format!("Failed to get provider config: {}", e)))?; + + let channel = FeishuChannel::new(feishu_config.clone(), provider_config) + .map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?; + + self.channels + .write() + .await + .insert("feishu".to_string(), Arc::new(channel)); + println!("Feishu channel registered"); + } else { + println!("Feishu channel disabled in config"); + } + } + Ok(()) + } + + pub async fn start_all(&self) -> Result<(), ChannelError> { + let channels = self.channels.read().await; + for (name, channel) in channels.iter() { + println!("Starting channel: {}", name); + if let Err(e) = channel.start().await { + eprintln!("Warning: Failed to start channel {}: {}", name, e); + // Channel failed to start - it should have logged why + // Continue starting other channels + } + } + Ok(()) + } + + pub async fn stop_all(&self) -> Result<(), ChannelError> { + let mut channels = self.channels.write().await; + for (name, channel) in channels.iter() { + println!("Stopping channel: {}", name); + if let Err(e) = channel.stop().await { + eprintln!("Error stopping channel {}: {}", name, e); + } + } + channels.clear(); + Ok(()) + } + + pub async fn get_channel(&self, name: &str) -> Option> { + self.channels.read().await.get(name).cloned() + } +} + +impl Default for ChannelManager { + fn default() -> Self { + Self::new() + } +} diff --git a/src/channels/mod.rs b/src/channels/mod.rs new file mode 100644 index 0000000..41c4d17 --- /dev/null +++ b/src/channels/mod.rs @@ -0,0 +1,7 @@ +pub mod base; +pub mod feishu; +pub mod manager; + +pub use base::{Channel, ChannelError, InboundMessage, OutboundMessage}; +pub use manager::ChannelManager; +pub use feishu::FeishuChannel; diff --git a/src/config/mod.rs b/src/config/mod.rs index 822f558..85155e2 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -14,6 +14,24 @@ pub struct Config { pub gateway: GatewayConfig, #[serde(default)] pub client: ClientConfig, + #[serde(default)] + pub channels: HashMap, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FeishuChannelConfig { + #[serde(default)] + pub enabled: bool, + pub app_id: String, + pub app_secret: String, + #[serde(default = "default_allow_from")] + pub allow_from: Vec, + #[serde(default)] + pub agent: String, +} + +fn default_allow_from() -> Vec { + vec!["*".to_string()] } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -117,11 +135,13 @@ impl Config { fn load_from(path: &Path) -> Result> { load_env_file()?; let content = if path.exists() { + println!("Config loaded from: {}", path.display()); fs::read_to_string(path)? } else { // Fallback to current directory let fallback = Path::new("config.json"); if fallback.exists() { + println!("Config loaded from: {}", fallback.display()); fs::read_to_string(fallback)? } else { return Err(Box::new(ConfigError::ConfigNotFound( diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 15ce939..2afa82a 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -6,20 +6,24 @@ use std::sync::Arc; use axum::{routing, Router}; use tokio::net::TcpListener; +use crate::channels::ChannelManager; use crate::config::Config; use session::SessionManager; pub struct GatewayState { pub config: Config, pub session_manager: SessionManager, + pub channel_manager: ChannelManager, } impl GatewayState { pub fn new() -> Result> { let config = Config::load_default()?; + let channel_manager = ChannelManager::new(); Ok(Self { config, session_manager: SessionManager::new(), + channel_manager, }) } } @@ -27,6 +31,10 @@ impl GatewayState { pub async fn run(host: Option, port: Option) -> Result<(), Box> { let state = Arc::new(GatewayState::new()?); + // Initialize and start channels + state.channel_manager.init(&state.config).await?; + state.channel_manager.start_all().await?; + // CLI args override config file values let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone()); let bind_port = port.unwrap_or(state.config.gateway.port); @@ -34,11 +42,30 @@ pub async fn run(host: Option, port: Option) -> Result<(), Box(); + let channel_manager = state.channel_manager.clone(); + + // Spawn ctrl_c handler + tokio::spawn(async move { + tokio::signal::ctrl_c().await.ok(); + println!("Shutting down..."); + let _ = channel_manager.stop_all().await; + let _ = shutdown_tx.send(()); + }); + + // Serve with graceful shutdown + axum::serve(listener, app) + .with_graceful_shutdown(async { + shutdown_rx.await.ok(); + }) + .await?; + Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 76e66e5..305f795 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,3 +6,4 @@ pub mod agent; pub mod gateway; pub mod client; pub mod protocol; +pub mod channels; diff --git a/src/main.rs b/src/main.rs index ab2871a..81d73ce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,18 +32,9 @@ async fn main() -> Result<(), Box> { return Ok(()); } - // Load config for defaults - let config = match picobot::config::Config::load_default() { - Ok(cfg) => Some(cfg), - Err(e) => { - eprintln!("Warning: Could not load config from ~/.config/picobot/config.json: {}", e); - eprintln!("Using built-in defaults. Run `picobot gateway` or `picobot agent` with --help for options.\n"); - None - } - }; - match Command::parse() { Command::Agent { gateway_url } => { + let config = picobot::config::Config::load_default().ok(); let url = gateway_url .or_else(|| config.as_ref().map(|c| c.client.gateway_url.clone())) .unwrap_or_else(|| "ws://127.0.0.1:19876/ws".to_string());