diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index d6166fc..11d5422 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -4,9 +4,9 @@ use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Mess use crate::tools::ToolRegistry; use std::sync::Arc; +/// Stateless AgentLoop - history is managed externally by SessionManager pub struct AgentLoop { provider: Box, - history: Vec, tools: Arc, } @@ -17,7 +17,6 @@ impl AgentLoop { Ok(Self { provider, - history: Vec::new(), tools: Arc::new(ToolRegistry::new()), }) } @@ -28,7 +27,6 @@ impl AgentLoop { Ok(Self { provider, - history: Vec::new(), tools, }) } @@ -37,10 +35,10 @@ impl AgentLoop { &self.tools } - pub async fn process(&mut self, user_message: ChatMessage) -> Result { - self.history.push(user_message.clone()); - - let messages: Vec = self.history + /// Process a message using the provided conversation history. + /// History management is handled externally by SessionManager. + pub async fn process(&self, messages: Vec) -> Result { + let messages_for_llm: Vec = messages .iter() .map(|m| Message { role: m.role.clone(), @@ -50,7 +48,7 @@ impl AgentLoop { }) .collect(); - tracing::debug!(history_len = self.history.len(), "Sending request to LLM"); + tracing::debug!(history_len = messages.len(), "Sending request to LLM"); let tools = if self.tools.has_tools() { Some(self.tools.get_definitions()) @@ -59,7 +57,7 @@ impl AgentLoop { }; let request = ChatCompletionRequest { - messages, + messages: messages_for_llm, temperature: None, max_tokens: None, tools, @@ -71,12 +69,18 @@ impl AgentLoop { AgentError::LlmError(e.to_string()) })?; - tracing::debug!(response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received"); + tracing::debug!( + response_len = response.content.len(), + tool_calls_len = response.tool_calls.len(), + "LLM response received" + ); if !response.tool_calls.is_empty() { tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools"); + + let mut updated_messages = messages.clone(); let assistant_message = ChatMessage::assistant(response.content.clone()); - self.history.push(assistant_message.clone()); + updated_messages.push(assistant_message.clone()); let tool_results = self.execute_tools(&response.tool_calls).await; @@ -86,20 +90,18 @@ impl AgentLoop { tool_call.name.clone(), result.clone(), ); - self.history.push(tool_message); + updated_messages.push(tool_message); } - return self.continue_with_tool_results(response.content).await; + return self.continue_with_tool_results(updated_messages).await; } let assistant_message = ChatMessage::assistant(response.content); - self.history.push(assistant_message.clone()); - Ok(assistant_message) } - async fn continue_with_tool_results(&mut self, _original_content: String) -> Result { - let messages: Vec = self.history + async fn continue_with_tool_results(&self, messages: Vec) -> Result { + let messages_for_llm: Vec = messages .iter() .map(|m| Message { role: m.role.clone(), @@ -116,7 +118,7 @@ impl AgentLoop { }; let request = ChatCompletionRequest { - messages, + messages: messages_for_llm, temperature: None, max_tokens: None, tools, @@ -129,8 +131,6 @@ impl AgentLoop { })?; let assistant_message = ChatMessage::assistant(response.content); - self.history.push(assistant_message.clone()); - Ok(assistant_message) } @@ -168,16 +168,6 @@ impl AgentLoop { } } } - - pub fn clear_history(&mut self) { - let len = self.history.len(); - self.history.clear(); - tracing::debug!(previous_len = len, "Chat history cleared"); - } - - pub fn history(&self) -> &[ChatMessage] { - &self.history - } } #[derive(Debug)] diff --git a/src/bus/dispatcher.rs b/src/bus/dispatcher.rs new file mode 100644 index 0000000..2621c50 --- /dev/null +++ b/src/bus/dispatcher.rs @@ -0,0 +1,82 @@ +use std::sync::Arc; +use tokio::sync::RwLock; +use std::collections::HashMap; + +use crate::bus::{MessageBus, OutboundMessage}; +use crate::channels::base::{Channel, ChannelError}; + +/// OutboundDispatcher consumes outbound messages from the MessageBus +/// and dispatches them to the appropriate Channel +pub struct OutboundDispatcher { + bus: Arc, + channels: Arc>>>, +} + +impl OutboundDispatcher { + pub fn new(bus: Arc) -> Self { + Self { + bus, + channels: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Register a channel with the dispatcher + pub async fn register_channel(&self, name: &str, channel: Arc) { + self.channels.write().await.insert(name.to_string(), channel); + } + + /// Run the dispatcher loop - consumes from bus and dispatches to channels + pub async fn run(&self) { + tracing::info!("OutboundDispatcher started"); + + loop { + let msg = self.bus.consume_outbound().await; + tracing::debug!( + channel = %msg.channel, + chat_id = %msg.chat_id, + content_len = msg.content.len(), + "OutboundDispatcher received message" + ); + + let channel_name = msg.channel.clone(); + let channel = self.channels.read().await.get(&channel_name).cloned(); + + match channel { + Some(ch) => { + if let Err(e) = self.send_with_retry(&*ch, msg).await { + tracing::error!(channel = %channel_name, error = %e, "Failed to send message after retries"); + } + } + None => { + tracing::warn!(channel = %channel_name, "No channel found for message"); + } + } + } + } + + /// Send a message with exponential retry + async fn send_with_retry( + &self, + channel: &dyn Channel, + msg: OutboundMessage, + ) -> Result<(), ChannelError> { + const DELAYS: [u64; 3] = [1, 2, 4]; + + for (i, delay) in DELAYS.iter().enumerate() { + match channel.send(msg.clone()).await { + Ok(()) => return Ok(()), + Err(e) if i < DELAYS.len() - 1 => { + tracing::warn!( + attempt = i + 1, + delay = delay, + error = %e, + "Send failed, retrying" + ); + tokio::time::sleep(tokio::time::Duration::from_secs(*delay)).await; + } + Err(e) => return Err(e), + } + } + unreachable!() + } +} diff --git a/src/bus/message.rs b/src/bus/message.rs index 71f3294..f003af1 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -1,5 +1,10 @@ +use std::collections::HashMap; use serde::{Deserialize, Serialize}; +// ============================================================================ +// ChatMessage - Legacy type used by AgentLoop for LLM conversation history +// ============================================================================ + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { pub id: String, @@ -58,6 +63,51 @@ impl ChatMessage { } } +// ============================================================================ +// InboundMessage - Message from Channel to Bus (user input) +// ============================================================================ + +#[derive(Debug, Clone)] +pub struct InboundMessage { + pub channel: String, + pub sender_id: String, + pub chat_id: String, + pub content: String, + pub timestamp: i64, + pub media: Vec, + pub metadata: HashMap, +} + +impl InboundMessage { + pub fn session_key(&self) -> String { + format!("{}:{}", self.channel, self.chat_id) + } +} + +// ============================================================================ +// OutboundMessage - Message from Agent to Channel (bot response) +// ============================================================================ + +#[derive(Debug, Clone)] +pub struct OutboundMessage { + pub channel: String, + pub chat_id: String, + pub content: String, + pub reply_to: Option, + pub media: Vec, + pub metadata: HashMap, +} + +impl OutboundMessage { + pub fn is_stream_delta(&self) -> bool { + self.metadata.get("_stream_delta").is_some() + } +} + +// ============================================================================ +// Helpers +// ============================================================================ + fn current_timestamp() -> i64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) diff --git a/src/bus/mod.rs b/src/bus/mod.rs index 91eb63d..08a724b 100644 --- a/src/bus/mod.rs +++ b/src/bus/mod.rs @@ -1,42 +1,86 @@ +pub mod dispatcher; pub mod message; -pub use message::ChatMessage; +pub use dispatcher::OutboundDispatcher; +pub use message::{ChatMessage, InboundMessage, OutboundMessage}; -use tokio::sync::{mpsc, broadcast}; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; + +// ============================================================================ +// MessageBus - Async message queue for Channel <-> Agent communication +// ============================================================================ pub struct MessageBus { - user_tx: mpsc::Sender, - llm_tx: broadcast::Sender, + inbound_tx: mpsc::Sender, + outbound_tx: mpsc::Sender, + inbound_rx: Mutex>, + outbound_rx: Mutex>, } impl MessageBus { - pub fn new(buffer_size: usize) -> Self { - let (user_tx, _) = mpsc::channel(buffer_size); - let (llm_tx, _) = broadcast::channel(buffer_size); - - Self { user_tx, llm_tx } + /// Create a new MessageBus with the given channel capacity + pub fn new(capacity: usize) -> Arc { + let (inbound_tx, inbound_rx) = mpsc::channel(capacity); + let (outbound_tx, outbound_rx) = mpsc::channel(capacity); + Arc::new(Self { + inbound_tx, + outbound_tx, + inbound_rx: Mutex::new(inbound_rx), + outbound_rx: Mutex::new(outbound_rx), + }) } - pub async fn send_user_input(&self, msg: ChatMessage) -> Result<(), BusError> { - self.user_tx.send(msg).await.map_err(|_| BusError::ChannelClosed) + /// Publish an inbound message (Channel -> Bus) + pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> { + self.inbound_tx + .send(msg) + .await + .map_err(|_| BusError::Closed) } - pub fn send_llm_output(&self, msg: ChatMessage) -> Result { - self.llm_tx.send(msg).map_err(|_| BusError::ChannelClosed) + /// Consume an inbound message (Agent -> Bus) + pub async fn consume_inbound(&self) -> InboundMessage { + self.inbound_rx + .lock() + .await + .recv() + .await + .expect("bus inbound closed") + } + + /// Publish an outbound message (Agent -> Bus) + pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> { + self.outbound_tx + .send(msg) + .await + .map_err(|_| BusError::Closed) + } + + /// Consume an outbound message (Dispatcher -> Bus) + pub async fn consume_outbound(&self) -> OutboundMessage { + self.outbound_rx + .lock() + .await + .recv() + .await + .expect("bus outbound closed") } } +// ============================================================================ +// BusError +// ============================================================================ + #[derive(Debug)] pub enum BusError { - ChannelClosed, - SendError(usize), + Closed, } impl std::fmt::Display for BusError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - BusError::ChannelClosed => write!(f, "Channel closed"), - BusError::SendError(n) => write!(f, "Send error, {} receivers", n), + BusError::Closed => write!(f, "Bus channel closed"), } } } diff --git a/src/channels/base.rs b/src/channels/base.rs index 5fec82b..33f7872 100644 --- a/src/channels/base.rs +++ b/src/channels/base.rs @@ -1,30 +1,14 @@ -use std::collections::HashMap; use async_trait::async_trait; +use std::sync::Arc; -#[derive(Debug, Clone)] -pub struct InboundMessage { - pub channel: String, - pub sender_id: String, - pub chat_id: String, - pub content: String, - pub media: Vec, - pub metadata: HashMap, -} - -#[derive(Debug, Clone)] -pub struct OutboundMessage { - pub channel: String, - pub chat_id: String, - pub content: String, - pub media: Vec, - pub metadata: HashMap, -} +use crate::bus::{BusError, InboundMessage, MessageBus, OutboundMessage}; #[derive(Debug)] pub enum ChannelError { ConfigError(String), ConnectionError(String), SendError(String), + BusError(String), Other(String), } @@ -34,6 +18,7 @@ impl std::fmt::Display for ChannelError { ChannelError::ConfigError(s) => write!(f, "Config error: {}", s), ChannelError::ConnectionError(s) => write!(f, "Connection error: {}", s), ChannelError::SendError(s) => write!(f, "Send error: {}", s), + ChannelError::BusError(s) => write!(f, "Bus error: {}", s), ChannelError::Other(s) => write!(f, "Error: {}", s), } } @@ -41,10 +26,73 @@ impl std::fmt::Display for ChannelError { impl std::error::Error for ChannelError {} +impl From for ChannelError { + fn from(e: BusError) -> Self { + ChannelError::BusError(e.to_string()) + } +} + #[async_trait] pub trait Channel: Send + Sync + 'static { fn name(&self) -> &str; - async fn start(&self) -> Result<(), ChannelError>; - async fn stop(&self) -> Result<(), ChannelError>; fn is_running(&self) -> bool; + + /// Start the channel with a reference to the MessageBus + async fn start(&self, bus: Arc) -> Result<(), ChannelError>; + + /// Stop the channel + async fn stop(&self) -> Result<(), ChannelError>; + + /// Send a message to the channel (called by OutboundDispatcher) + async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError>; + + /// Send a streaming delta (optional, for channels that support it) + async fn send_delta(&self, chat_id: &str, delta: &str) -> Result<(), ChannelError> { + let _ = chat_id; + let _ = delta; + Ok(()) + } + + /// Check if a sender is allowed to use this channel + fn is_allowed(&self, _sender_id: &str) -> bool { + true + } + + /// Handle an inbound message: check permissions and publish to bus + async fn handle_and_publish( + &self, + bus: &Arc, + sender_id: &str, + chat_id: &str, + content: &str, + ) -> Result<(), ChannelError> { + if !self.is_allowed(sender_id) { + tracing::warn!( + channel = %self.name(), + sender = %sender_id, + "Access denied" + ); + return Ok(()); + } + + let msg = InboundMessage { + channel: self.name().to_string(), + sender_id: sender_id.to_string(), + chat_id: chat_id.to_string(), + content: content.to_string(), + timestamp: current_timestamp(), + media: vec![], + metadata: std::collections::HashMap::new(), + }; + + bus.publish_inbound(msg).await?; + Ok(()) + } +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as i64 } diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs index c982acb..4068b7f 100644 --- a/src/channels/feishu.rs +++ b/src/channels/feishu.rs @@ -5,8 +5,8 @@ use serde::Deserialize; use futures_util::{SinkExt, StreamExt}; use prost::{Message as ProstMessage, bytes::Bytes}; +use crate::bus::{MessageBus, OutboundMessage}; use crate::channels::base::{Channel, ChannelError}; -use crate::channels::manager::MessageHandler; use crate::config::{FeishuChannelConfig, LLMProviderConfig}; const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis"; @@ -131,8 +131,6 @@ pub struct FeishuChannel { running: Arc>, shutdown_tx: Arc>>>, connected: Arc>, - /// Message handler for routing messages to Gateway - message_handler: Arc, } /// Parsed message data from a Feishu frame @@ -146,7 +144,6 @@ struct ParsedMessage { impl FeishuChannel { pub fn new( config: FeishuChannelConfig, - message_handler: Arc, _provider_config: LLMProviderConfig, ) -> Result { Ok(Self { @@ -155,7 +152,6 @@ impl FeishuChannel { running: Arc::new(RwLock::new(false)), shutdown_tx: Arc::new(RwLock::new(None)), connected: Arc::new(RwLock::new(false)), - message_handler, }) } @@ -224,11 +220,10 @@ impl FeishuChannel { .ok_or_else(|| ChannelError::Other("No token in response".to_string())) } - /// Send a text message to Feishu chat - async fn send_message(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> { + /// Send a text message to Feishu chat (implements Channel trait) + async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> { let token = self.get_tenant_token().await?; - // For text message, content should be a JSON string: "{\"text\":\"hello\"}" let text_content = serde_json::json!({ "text": content }).to_string(); let resp = self.http_client @@ -262,26 +257,6 @@ impl FeishuChannel { Ok(()) } - /// Handle incoming message - delegate to message handler and send response - async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> { - tracing::info!(open_id, chat_id, "Processing message from Feishu"); - - // Delegate to message handler (Gateway) - let response = self.message_handler - .handle_message("feishu", open_id, chat_id, content) - .await?; - - // Send response to the chat - // Use open_id for p2p chats, chat_id for group chats - let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id }; - let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" }; - - self.send_message(receive_id, receive_id_type, &response).await?; - tracing::info!(receive_id, "Sent response to Feishu"); - - Ok(()) - } - /// Extract service_id from WebSocket URL query params fn extract_service_id(url: &str) -> i32 { url.split('?') @@ -310,7 +285,6 @@ impl FeishuChannel { let payload = frame.payload.as_deref() .ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?; - // Parse the event JSON to get event_type from payload header let event: LarkEvent = serde_json::from_slice(payload) .map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?; @@ -359,7 +333,7 @@ impl FeishuChannel { Ok(()) } - async fn run_ws_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> { + async fn run_ws_loop(&self, bus: Arc, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> { let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?; let service_id = Self::extract_service_id(&wss_url); @@ -404,7 +378,6 @@ impl FeishuChannel { Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => { let bytes: Bytes = data; if let Ok(frame) = PbFrame::decode(bytes.as_ref()) { - // Parse the frame first match self.handle_frame(&frame).await { Ok(Some(parsed)) => { // Send ACK immediately (Feishu requires within 3 s) @@ -412,11 +385,12 @@ impl FeishuChannel { tracing::error!(error = %e, "Failed to send ACK to Feishu"); } - // Then process message asynchronously (don't await) + // Publish to bus asynchronously let channel = self.clone(); + let bus = bus.clone(); tokio::spawn(async move { - if let Err(e) = channel.handle_message(&parsed.open_id, &parsed.chat_id, &parsed.content).await { - tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to handle Feishu message"); + if let Err(e) = channel.handle_and_publish(&bus, &parsed.open_id, &parsed.chat_id, &parsed.content).await { + tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to publish Feishu message to bus"); } }); } @@ -528,7 +502,7 @@ impl Channel for FeishuChannel { "feishu" } - async fn start(&self) -> Result<(), ChannelError> { + async fn start(&self, bus: Arc) -> Result<(), ChannelError> { if self.config.app_id.is_empty() || self.config.app_secret.is_empty() { return Err(ChannelError::ConfigError( "Feishu app_id or app_secret is not configured".to_string() @@ -541,6 +515,7 @@ impl Channel for FeishuChannel { *self.shutdown_tx.write().await = Some(shutdown_tx.clone()); let channel = self.clone(); + let bus = bus.clone(); tokio::spawn(async move { let mut consecutive_failures = 0; let max_failures = 3; @@ -551,7 +526,7 @@ impl Channel for FeishuChannel { } let shutdown_rx = shutdown_tx.subscribe(); - match channel.run_ws_loop(shutdown_rx).await { + match channel.run_ws_loop(bus.clone(), shutdown_rx).await { Ok(_) => { tracing::info!("Feishu WebSocket disconnected"); } @@ -592,7 +567,13 @@ impl Channel for FeishuChannel { } fn is_running(&self) -> bool { - // Note: blocking read, acceptable for this use case self.running.try_read().map(|r| *r).unwrap_or(false) } + + async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { + let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { &msg.reply_to.as_ref().unwrap_or(&msg.chat_id) }; + let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" }; + + self.send_message_to_feishu(receive_id, receive_id_type, &msg.content).await + } } diff --git a/src/channels/manager.rs b/src/channels/manager.rs index 6b1a0ec..568639e 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -1,51 +1,38 @@ use std::collections::HashMap; use std::sync::Arc; -use async_trait::async_trait; use tokio::sync::RwLock; +use crate::bus::{MessageBus, OutboundMessage}; use crate::channels::base::{Channel, ChannelError}; use crate::channels::feishu::FeishuChannel; -use crate::config::{Config, FeishuChannelConfig}; +use crate::config::Config; -/// MessageHandler trait - Channel 通过这个 trait 与业务逻辑解耦 -#[async_trait] -pub trait MessageHandler: Send + Sync { - async fn handle_message( - &self, - channel_name: &str, - sender_id: &str, - chat_id: &str, - content: &str, - ) -> Result; -} - -/// ChannelManager 管理所有 Channel +/// ChannelManager manages all Channel instances and the MessageBus #[derive(Clone)] pub struct ChannelManager { channels: Arc>>>, - message_handler: Arc, + bus: Arc, } impl ChannelManager { - pub fn new(message_handler: Arc) -> Self { + pub fn new() -> Self { Self { channels: Arc::new(RwLock::new(HashMap::new())), - message_handler, + bus: MessageBus::new(100), } } - /// 获取 MessageHandler 用于让 Channel 调用 - pub fn get_handler(&self) -> Arc { - self.message_handler.clone() + /// Get a reference to the MessageBus + pub fn bus(&self) -> Arc { + self.bus.clone() } - /// 初始化所有 Channel - pub async fn init(&self, config: &Config, provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> { + /// Initialize all Channel instances from config + pub async fn init(&self, config: &Config, _provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> { // Initialize Feishu channel if enabled if let Some(feishu_config) = config.channels.get("feishu") { if feishu_config.enabled { - let handler = self.get_handler(); - let channel = FeishuChannel::new(feishu_config.clone(), handler, provider_config) + let channel = FeishuChannel::new(feishu_config.clone(), _provider_config) .map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?; self.channels @@ -62,9 +49,10 @@ impl ChannelManager { pub async fn start_all(&self) -> Result<(), ChannelError> { let channels = self.channels.read().await; + let bus = self.bus.clone(); for (name, channel) in channels.iter() { tracing::info!(channel = %name, "Starting channel"); - if let Err(e) = channel.start().await { + if let Err(e) = channel.start(bus.clone()).await { tracing::error!(channel = %name, error = %e, "Failed to start channel"); } } @@ -86,32 +74,14 @@ impl ChannelManager { pub async fn get_channel(&self, name: &str) -> Option> { self.channels.read().await.get(name).cloned() } -} -/// Gateway 实现 MessageHandler trait -#[derive(Clone)] -pub struct GatewayMessageHandler { - session_manager: crate::gateway::session::SessionManager, -} - -impl GatewayMessageHandler { - pub fn new(session_manager: crate::gateway::session::SessionManager) -> Self { - Self { session_manager } - } -} - -#[async_trait] -impl MessageHandler for GatewayMessageHandler { - async fn handle_message( - &self, - channel_name: &str, - sender_id: &str, - chat_id: &str, - content: &str, - ) -> Result { - self.session_manager - .handle_message(channel_name, sender_id, chat_id, content) - .await - .map_err(|e| ChannelError::Other(e.to_string())) + /// Dispatch an outbound message to the appropriate channel + pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> { + let channel_name = &msg.channel; + if let Some(channel) = self.get_channel(channel_name).await { + channel.send(msg).await + } else { + Err(ChannelError::Other(format!("Channel not found: {}", channel_name))) + } } } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 41c4d17..d8a5b1d 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -2,6 +2,6 @@ pub mod base; pub mod feishu; pub mod manager; -pub use base::{Channel, ChannelError, InboundMessage, OutboundMessage}; +pub use base::{Channel, ChannelError}; pub use manager::ChannelManager; pub use feishu::FeishuChannel; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 88c9165..1cd5a37 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -6,7 +6,8 @@ use std::sync::Arc; use axum::{routing, Router}; use tokio::net::TcpListener; -use crate::channels::{ChannelManager, manager::GatewayMessageHandler}; +use crate::bus::{MessageBus, OutboundDispatcher}; +use crate::channels::ChannelManager; use crate::config::Config; use crate::logging; use session::SessionManager; @@ -15,6 +16,7 @@ pub struct GatewayState { pub config: Config, pub session_manager: SessionManager, pub channel_manager: ChannelManager, + pub bus: Arc, } impl GatewayState { @@ -28,15 +30,76 @@ impl GatewayState { let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); let session_manager = SessionManager::new(session_ttl_hours, provider_config); - let message_handler = Arc::new(GatewayMessageHandler::new(session_manager.clone())); - let channel_manager = ChannelManager::new(message_handler); + let channel_manager = ChannelManager::new(); + let bus = channel_manager.bus(); Ok(Self { config, session_manager, channel_manager, + bus, }) } + + /// Start the message processing loops + pub async fn start_message_processing(&self) { + let bus_for_inbound = self.bus.clone(); + let bus_for_outbound = self.bus.clone(); + let session_manager = self.session_manager.clone(); + + // Spawn inbound message processor + // This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound + tokio::spawn(async move { + tracing::info!("Inbound processor started"); + loop { + let inbound = bus_for_inbound.consume_inbound().await; + tracing::debug!( + channel = %inbound.channel, + chat_id = %inbound.chat_id, + "Processing inbound message" + ); + + // Process via session manager + match session_manager.handle_message( + &inbound.channel, + &inbound.sender_id, + &inbound.chat_id, + &inbound.content, + ).await { + Ok(response_content) => { + let outbound = crate::bus::OutboundMessage { + channel: inbound.channel, + chat_id: inbound.chat_id, + content: response_content, + reply_to: None, + media: vec![], + metadata: std::collections::HashMap::new(), + }; + if let Err(e) = bus_for_inbound.publish_outbound(outbound).await { + tracing::error!(error = %e, "Failed to publish outbound"); + } + } + Err(e) => { + tracing::error!(error = %e, "Failed to handle message"); + } + } + } + }); + + // Spawn outbound dispatcher + let dispatcher = OutboundDispatcher::new(bus_for_outbound); + let channel_manager = self.channel_manager.clone(); + + // Register channels with dispatcher + if let Some(channel) = channel_manager.get_channel("feishu").await { + dispatcher.register_channel("feishu", channel).await; + } + + tokio::spawn(async move { + tracing::info!("Outbound dispatcher started"); + dispatcher.run().await; + }); + } } pub async fn run(host: Option, port: Option) -> Result<(), Box> { @@ -50,9 +113,12 @@ pub async fn run(host: Option, port: Option) -> Result<(), Box>>, + /// 按 chat_id 路由到不同会话历史,支持多用户多会话 + chat_histories: HashMap>, pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, tools: Arc, @@ -30,52 +31,65 @@ impl Session { Ok(Self { id: Uuid::new_v4(), channel_name, - chat_agents: HashMap::new(), + chat_histories: HashMap::new(), user_tx, provider_config, tools, }) } - /// 获取或创建指定 chat_id 的 AgentLoop - pub async fn get_or_create_agent(&mut self, chat_id: &str) -> Result>, AgentError> { - if let Some(agent) = self.chat_agents.get(chat_id) { - tracing::trace!(chat_id = %chat_id, "Reusing existing agent"); - return Ok(agent.clone()); - } - tracing::debug!(chat_id = %chat_id, "Creating new agent for chat"); - let agent = AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())?; - let arc = Arc::new(Mutex::new(agent)); - self.chat_agents.insert(chat_id.to_string(), arc.clone()); - Ok(arc) + /// 获取或创建指定 chat_id 的会话历史 + pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec { + self.chat_histories + .entry(chat_id.to_string()) + .or_insert_with(Vec::new) } - /// 获取指定 chat_id 的 AgentLoop(不创建) - pub fn get_agent(&self, chat_id: &str) -> Option>> { - self.chat_agents.get(chat_id).cloned() + /// 获取指定 chat_id 的会话历史(不创建) + pub fn get_history(&self, chat_id: &str) -> Option<&Vec> { + self.chat_histories.get(chat_id) + } + + /// 添加用户消息到指定 chat_id 的历史 + pub fn add_user_message(&mut self, chat_id: &str, content: &str) { + let history = self.get_or_create_history(chat_id); + history.push(ChatMessage::user(content)); + } + + /// 添加助手响应到指定 chat_id 的历史 + pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) { + if let Some(history) = self.chat_histories.get_mut(chat_id) { + history.push(message); + } } /// 清除指定 chat_id 的历史 - pub async fn clear_chat_history(&mut self, chat_id: &str) { - if let Some(agent) = self.chat_agents.get(chat_id) { - agent.lock().await.clear_history(); + pub fn clear_chat_history(&mut self, chat_id: &str) { + if let Some(history) = self.chat_histories.get_mut(chat_id) { + let len = history.len(); + history.clear(); + tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared"); } } /// 清除所有历史 - pub async fn clear_all_history(&mut self) { - for agent in self.chat_agents.values() { - agent.lock().await.clear_history(); - } + pub fn clear_all_history(&mut self) { + let total: usize = self.chat_histories.values().map(|h| h.len()).sum(); + self.chat_histories.clear(); + tracing::debug!(previous_total = total, "All chat histories cleared"); } pub async fn send(&self, msg: WsOutbound) { let _ = self.user_tx.send(msg).await; } + + /// 创建一个临时的 AgentLoop 实例来处理消息 + pub fn create_agent(&self) -> Result { + AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone()) + } } /// SessionManager 管理所有 Session,按 channel_name 路由 -/// 使用 Arc> 以从 Arc 获取可变访问 #[derive(Clone)] pub struct SessionManager { inner: Arc>, @@ -135,7 +149,13 @@ impl SessionManager { // 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS) let (user_tx, _rx) = mpsc::channel::(100); - let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx, self.tools.clone()).await?; + let session = Session::new( + channel_name.to_string(), + self.provider_config.clone(), + user_tx, + self.tools.clone(), + ) + .await?; let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(channel_name.to_string(), arc.clone()); @@ -165,7 +185,12 @@ impl SessionManager { chat_id: &str, content: &str, ) -> Result { - tracing::debug!(channel = %channel_name, chat_id = %chat_id, content_len = content.len(), "Routing message to agent"); + tracing::debug!( + channel = %channel_name, + chat_id = %chat_id, + content_len = content.len(), + "Routing message to agent" + ); // 确保 session 存在(可能需要重建) self.ensure_session(channel_name).await?; @@ -174,21 +199,37 @@ impl SessionManager { self.touch(channel_name).await; // 获取 session - let session = self.get(channel_name).await + let session = self + .get(channel_name) + .await .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; - // 获取或创建 chat_id 对应的 agent - let mut session_guard = session.lock().await; - let agent = session_guard.get_or_create_agent(chat_id).await?; - drop(session_guard); - - let mut agent = agent.lock().await; - // 处理消息 - let user_msg = ChatMessage::user(content); - let response = agent.process(user_msg).await?; + let response = { + let mut session_guard = session.lock().await; - tracing::debug!(channel = %channel_name, chat_id = %chat_id, response_len = response.content.len(), "Agent response received"); + // 添加用户消息到历史 + session_guard.add_user_message(chat_id, content); + + // 获取完整历史 + let history = session_guard.get_or_create_history(chat_id).clone(); + + // 创建 agent 并处理 + let agent = session_guard.create_agent()?; + let response = agent.process(history).await?; + + // 添加助手响应到历史 + session_guard.add_assistant_message(chat_id, response.clone()); + + response + }; + + tracing::debug!( + channel = %channel_name, + chat_id = %chat_id, + response_len = response.content.len(), + "Agent response received" + ); Ok(response.content) } @@ -197,7 +238,7 @@ impl SessionManager { pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> { if let Some(session) = self.get(channel_name).await { let mut session_guard = session.lock().await; - session_guard.clear_all_history().await; + session_guard.clear_all_history(); } Ok(()) } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index a10d0b6..81f7cd0 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -4,7 +4,6 @@ use axum::extract::State; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use tokio::sync::{mpsc, Mutex}; -use crate::bus::ChatMessage; use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound}; use super::{GatewayState, session::Session}; @@ -29,7 +28,14 @@ async fn handle_socket(ws: WebSocket, state: Arc) { let channel_name = format!("cli-{}", uuid::Uuid::new_v4()); // 创建 CLI session - let session = match Session::new(channel_name.clone(), provider_config, sender, state.session_manager.tools()).await { + let session = match Session::new( + channel_name.clone(), + provider_config, + sender, + state.session_manager.tools(), + ) + .await + { Ok(s) => Arc::new(Mutex::new(s)), Err(e) => { tracing::error!(error = %e, "Failed to create session"); @@ -40,9 +46,13 @@ async fn handle_socket(ws: WebSocket, state: Arc) { let session_id = session.lock().await.id; tracing::info!(session_id = %session_id, "CLI session established"); - let _ = session.lock().await.send(WsOutbound::SessionEstablished { - session_id: session_id.to_string(), - }).await; + let _ = session + .lock() + .await + .send(WsOutbound::SessionEstablished { + session_id: session_id.to_string(), + }) + .await; let (mut ws_sender, mut ws_receiver) = ws.split(); @@ -69,10 +79,14 @@ async fn handle_socket(ws: WebSocket, state: Arc) { } Err(e) => { tracing::warn!(error = %e, "Failed to parse inbound message"); - let _ = session.lock().await.send(WsOutbound::Error { - code: "PARSE_ERROR".to_string(), - message: e.to_string(), - }).await; + let _ = session + .lock() + .await + .send(WsOutbound::Error { + code: "PARSE_ERROR".to_string(), + message: e.to_string(), + }) + .await; } } } @@ -92,7 +106,12 @@ async fn handle_inbound(session: &Arc>, inbound: WsInbound) { // 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id) let (content, chat_id) = match inbound_clone { - WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => { + WsInbound::UserInput { + content, + channel: _, + chat_id, + sender_id: _, + } => { // CLI 使用 session 中的 channel_name 作为标识 // chat_id 使用传入的或使用默认 let chat_id = chat_id.unwrap_or_else(|| "default".to_string()); @@ -101,38 +120,50 @@ async fn handle_inbound(session: &Arc>, inbound: WsInbound) { _ => return, }; - let user_msg = ChatMessage::user(content); - let mut session_guard = session.lock().await; - let agent = match session_guard.get_or_create_agent(&chat_id).await { + + // 添加用户消息到历史 + session_guard.add_user_message(&chat_id, &content); + + // 获取完整历史 + let history = session_guard.get_or_create_history(&chat_id).clone(); + + // 创建 agent 并处理 + let agent = match session_guard.create_agent() { Ok(a) => a, Err(e) => { - tracing::error!(chat_id = %chat_id, error = %e, "Failed to get or create agent"); - let _ = session_guard.send(WsOutbound::Error { - code: "AGENT_ERROR".to_string(), - message: e.to_string(), - }).await; + tracing::error!(chat_id = %chat_id, error = %e, "Failed to create agent"); + let _ = session_guard + .send(WsOutbound::Error { + code: "AGENT_ERROR".to_string(), + message: e.to_string(), + }) + .await; return; } }; - drop(session_guard); - let mut agent = agent.lock().await; - match agent.process(user_msg).await { + match agent.process(history).await { Ok(response) => { tracing::debug!(chat_id = %chat_id, "Agent response sent"); - let _ = session.lock().await.send(WsOutbound::AssistantResponse { - id: response.id, - content: response.content, - role: response.role, - }).await; + // 添加助手响应到历史 + session_guard.add_assistant_message(&chat_id, response.clone()); + let _ = session_guard + .send(WsOutbound::AssistantResponse { + id: response.id, + content: response.content, + role: response.role, + }) + .await; } Err(e) => { tracing::error!(chat_id = %chat_id, error = %e, "Agent process error"); - let _ = session.lock().await.send(WsOutbound::Error { - code: "LLM_ERROR".to_string(), - message: e.to_string(), - }).await; + let _ = session_guard + .send(WsOutbound::Error { + code: "LLM_ERROR".to_string(), + message: e.to_string(), + }) + .await; } } }