Refactor AgentLoop to manage history externally via SessionManager

- Removed internal history management from AgentLoop.
- Updated process method to accept conversation history as a parameter.
- Adjusted continue_with_tool_results to work with external history.
- Added OutboundDispatcher for handling outbound messages from MessageBus.
- Introduced InboundMessage and OutboundMessage structs for message handling.
- Updated Channel trait to include message handling and publishing to MessageBus.
- Refactored Session to manage chat histories instead of AgentLoop instances.
- Enhanced GatewayState to start message processing loops for inbound and outbound messages.
This commit is contained in:
xiaoski 2026-04-07 21:53:37 +08:00
parent 9834bd75cf
commit a051f83050
11 changed files with 534 additions and 231 deletions

View File

@ -4,9 +4,9 @@ use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Mess
use crate::tools::ToolRegistry;
use std::sync::Arc;
/// Stateless AgentLoop - history is managed externally by SessionManager
pub struct AgentLoop {
provider: Box<dyn LLMProvider>,
history: Vec<ChatMessage>,
tools: Arc<ToolRegistry>,
}
@ -17,7 +17,6 @@ impl AgentLoop {
Ok(Self {
provider,
history: Vec::new(),
tools: Arc::new(ToolRegistry::new()),
})
}
@ -28,7 +27,6 @@ impl AgentLoop {
Ok(Self {
provider,
history: Vec::new(),
tools,
})
}
@ -37,10 +35,10 @@ impl AgentLoop {
&self.tools
}
pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
self.history.push(user_message.clone());
let messages: Vec<Message> = self.history
/// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager.
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
let messages_for_llm: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
@ -50,7 +48,7 @@ impl AgentLoop {
})
.collect();
tracing::debug!(history_len = self.history.len(), "Sending request to LLM");
tracing::debug!(history_len = messages.len(), "Sending request to LLM");
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
@ -59,7 +57,7 @@ impl AgentLoop {
};
let request = ChatCompletionRequest {
messages,
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
@ -71,12 +69,18 @@ impl AgentLoop {
AgentError::LlmError(e.to_string())
})?;
tracing::debug!(response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received");
tracing::debug!(
response_len = response.content.len(),
tool_calls_len = response.tool_calls.len(),
"LLM response received"
);
if !response.tool_calls.is_empty() {
tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools");
let mut updated_messages = messages.clone();
let assistant_message = ChatMessage::assistant(response.content.clone());
self.history.push(assistant_message.clone());
updated_messages.push(assistant_message.clone());
let tool_results = self.execute_tools(&response.tool_calls).await;
@ -86,20 +90,18 @@ impl AgentLoop {
tool_call.name.clone(),
result.clone(),
);
self.history.push(tool_message);
updated_messages.push(tool_message);
}
return self.continue_with_tool_results(response.content).await;
return self.continue_with_tool_results(updated_messages).await;
}
let assistant_message = ChatMessage::assistant(response.content);
self.history.push(assistant_message.clone());
Ok(assistant_message)
}
async fn continue_with_tool_results(&mut self, _original_content: String) -> Result<ChatMessage, AgentError> {
let messages: Vec<Message> = self.history
async fn continue_with_tool_results(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
let messages_for_llm: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
@ -116,7 +118,7 @@ impl AgentLoop {
};
let request = ChatCompletionRequest {
messages,
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
@ -129,8 +131,6 @@ impl AgentLoop {
})?;
let assistant_message = ChatMessage::assistant(response.content);
self.history.push(assistant_message.clone());
Ok(assistant_message)
}
@ -168,16 +168,6 @@ impl AgentLoop {
}
}
}
pub fn clear_history(&mut self) {
let len = self.history.len();
self.history.clear();
tracing::debug!(previous_len = len, "Chat history cleared");
}
pub fn history(&self) -> &[ChatMessage] {
&self.history
}
}
#[derive(Debug)]

82
src/bus/dispatcher.rs Normal file
View File

@ -0,0 +1,82 @@
use std::sync::Arc;
use tokio::sync::RwLock;
use std::collections::HashMap;
use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
/// OutboundDispatcher consumes outbound messages from the MessageBus
/// and dispatches them to the appropriate Channel
pub struct OutboundDispatcher {
bus: Arc<MessageBus>,
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
}
impl OutboundDispatcher {
pub fn new(bus: Arc<MessageBus>) -> Self {
Self {
bus,
channels: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Register a channel with the dispatcher
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
self.channels.write().await.insert(name.to_string(), channel);
}
/// Run the dispatcher loop - consumes from bus and dispatches to channels
pub async fn run(&self) {
tracing::info!("OutboundDispatcher started");
loop {
let msg = self.bus.consume_outbound().await;
tracing::debug!(
channel = %msg.channel,
chat_id = %msg.chat_id,
content_len = msg.content.len(),
"OutboundDispatcher received message"
);
let channel_name = msg.channel.clone();
let channel = self.channels.read().await.get(&channel_name).cloned();
match channel {
Some(ch) => {
if let Err(e) = self.send_with_retry(&*ch, msg).await {
tracing::error!(channel = %channel_name, error = %e, "Failed to send message after retries");
}
}
None => {
tracing::warn!(channel = %channel_name, "No channel found for message");
}
}
}
}
/// Send a message with exponential retry
async fn send_with_retry(
&self,
channel: &dyn Channel,
msg: OutboundMessage,
) -> Result<(), ChannelError> {
const DELAYS: [u64; 3] = [1, 2, 4];
for (i, delay) in DELAYS.iter().enumerate() {
match channel.send(msg.clone()).await {
Ok(()) => return Ok(()),
Err(e) if i < DELAYS.len() - 1 => {
tracing::warn!(
attempt = i + 1,
delay = delay,
error = %e,
"Send failed, retrying"
);
tokio::time::sleep(tokio::time::Duration::from_secs(*delay)).await;
}
Err(e) => return Err(e),
}
}
unreachable!()
}
}

View File

@ -1,5 +1,10 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
// ============================================================================
// ChatMessage - Legacy type used by AgentLoop for LLM conversation history
// ============================================================================
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub id: String,
@ -58,6 +63,51 @@ impl ChatMessage {
}
}
// ============================================================================
// InboundMessage - Message from Channel to Bus (user input)
// ============================================================================
#[derive(Debug, Clone)]
pub struct InboundMessage {
pub channel: String,
pub sender_id: String,
pub chat_id: String,
pub content: String,
pub timestamp: i64,
pub media: Vec<String>,
pub metadata: HashMap<String, String>,
}
impl InboundMessage {
pub fn session_key(&self) -> String {
format!("{}:{}", self.channel, self.chat_id)
}
}
// ============================================================================
// OutboundMessage - Message from Agent to Channel (bot response)
// ============================================================================
#[derive(Debug, Clone)]
pub struct OutboundMessage {
pub channel: String,
pub chat_id: String,
pub content: String,
pub reply_to: Option<String>,
pub media: Vec<String>,
pub metadata: HashMap<String, String>,
}
impl OutboundMessage {
pub fn is_stream_delta(&self) -> bool {
self.metadata.get("_stream_delta").is_some()
}
}
// ============================================================================
// Helpers
// ============================================================================
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)

View File

@ -1,42 +1,86 @@
pub mod dispatcher;
pub mod message;
pub use message::ChatMessage;
pub use dispatcher::OutboundDispatcher;
pub use message::{ChatMessage, InboundMessage, OutboundMessage};
use tokio::sync::{mpsc, broadcast};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
// ============================================================================
// MessageBus - Async message queue for Channel <-> Agent communication
// ============================================================================
pub struct MessageBus {
user_tx: mpsc::Sender<ChatMessage>,
llm_tx: broadcast::Sender<ChatMessage>,
inbound_tx: mpsc::Sender<InboundMessage>,
outbound_tx: mpsc::Sender<OutboundMessage>,
inbound_rx: Mutex<mpsc::Receiver<InboundMessage>>,
outbound_rx: Mutex<mpsc::Receiver<OutboundMessage>>,
}
impl MessageBus {
pub fn new(buffer_size: usize) -> Self {
let (user_tx, _) = mpsc::channel(buffer_size);
let (llm_tx, _) = broadcast::channel(buffer_size);
Self { user_tx, llm_tx }
/// Create a new MessageBus with the given channel capacity
pub fn new(capacity: usize) -> Arc<Self> {
let (inbound_tx, inbound_rx) = mpsc::channel(capacity);
let (outbound_tx, outbound_rx) = mpsc::channel(capacity);
Arc::new(Self {
inbound_tx,
outbound_tx,
inbound_rx: Mutex::new(inbound_rx),
outbound_rx: Mutex::new(outbound_rx),
})
}
pub async fn send_user_input(&self, msg: ChatMessage) -> Result<(), BusError> {
self.user_tx.send(msg).await.map_err(|_| BusError::ChannelClosed)
/// Publish an inbound message (Channel -> Bus)
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> {
self.inbound_tx
.send(msg)
.await
.map_err(|_| BusError::Closed)
}
pub fn send_llm_output(&self, msg: ChatMessage) -> Result<usize, BusError> {
self.llm_tx.send(msg).map_err(|_| BusError::ChannelClosed)
/// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage {
self.inbound_rx
.lock()
.await
.recv()
.await
.expect("bus inbound closed")
}
/// Publish an outbound message (Agent -> Bus)
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> {
self.outbound_tx
.send(msg)
.await
.map_err(|_| BusError::Closed)
}
/// Consume an outbound message (Dispatcher -> Bus)
pub async fn consume_outbound(&self) -> OutboundMessage {
self.outbound_rx
.lock()
.await
.recv()
.await
.expect("bus outbound closed")
}
}
// ============================================================================
// BusError
// ============================================================================
#[derive(Debug)]
pub enum BusError {
ChannelClosed,
SendError(usize),
Closed,
}
impl std::fmt::Display for BusError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BusError::ChannelClosed => write!(f, "Channel closed"),
BusError::SendError(n) => write!(f, "Send error, {} receivers", n),
BusError::Closed => write!(f, "Bus channel closed"),
}
}
}

View File

@ -1,30 +1,14 @@
use std::collections::HashMap;
use async_trait::async_trait;
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct InboundMessage {
pub channel: String,
pub sender_id: String,
pub chat_id: String,
pub content: String,
pub media: Vec<String>,
pub metadata: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub struct OutboundMessage {
pub channel: String,
pub chat_id: String,
pub content: String,
pub media: Vec<String>,
pub metadata: HashMap<String, String>,
}
use crate::bus::{BusError, InboundMessage, MessageBus, OutboundMessage};
#[derive(Debug)]
pub enum ChannelError {
ConfigError(String),
ConnectionError(String),
SendError(String),
BusError(String),
Other(String),
}
@ -34,6 +18,7 @@ impl std::fmt::Display for ChannelError {
ChannelError::ConfigError(s) => write!(f, "Config error: {}", s),
ChannelError::ConnectionError(s) => write!(f, "Connection error: {}", s),
ChannelError::SendError(s) => write!(f, "Send error: {}", s),
ChannelError::BusError(s) => write!(f, "Bus error: {}", s),
ChannelError::Other(s) => write!(f, "Error: {}", s),
}
}
@ -41,10 +26,73 @@ impl std::fmt::Display for ChannelError {
impl std::error::Error for ChannelError {}
impl From<BusError> for ChannelError {
fn from(e: BusError) -> Self {
ChannelError::BusError(e.to_string())
}
}
#[async_trait]
pub trait Channel: Send + Sync + 'static {
fn name(&self) -> &str;
async fn start(&self) -> Result<(), ChannelError>;
async fn stop(&self) -> Result<(), ChannelError>;
fn is_running(&self) -> bool;
/// Start the channel with a reference to the MessageBus
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError>;
/// Stop the channel
async fn stop(&self) -> Result<(), ChannelError>;
/// Send a message to the channel (called by OutboundDispatcher)
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError>;
/// Send a streaming delta (optional, for channels that support it)
async fn send_delta(&self, chat_id: &str, delta: &str) -> Result<(), ChannelError> {
let _ = chat_id;
let _ = delta;
Ok(())
}
/// Check if a sender is allowed to use this channel
fn is_allowed(&self, _sender_id: &str) -> bool {
true
}
/// Handle an inbound message: check permissions and publish to bus
async fn handle_and_publish(
&self,
bus: &Arc<MessageBus>,
sender_id: &str,
chat_id: &str,
content: &str,
) -> Result<(), ChannelError> {
if !self.is_allowed(sender_id) {
tracing::warn!(
channel = %self.name(),
sender = %sender_id,
"Access denied"
);
return Ok(());
}
let msg = InboundMessage {
channel: self.name().to_string(),
sender_id: sender_id.to_string(),
chat_id: chat_id.to_string(),
content: content.to_string(),
timestamp: current_timestamp(),
media: vec![],
metadata: std::collections::HashMap::new(),
};
bus.publish_inbound(msg).await?;
Ok(())
}
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}

View File

@ -5,8 +5,8 @@ use serde::Deserialize;
use futures_util::{SinkExt, StreamExt};
use prost::{Message as ProstMessage, bytes::Bytes};
use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
use crate::channels::manager::MessageHandler;
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
@ -131,8 +131,6 @@ pub struct FeishuChannel {
running: Arc<RwLock<bool>>,
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
connected: Arc<RwLock<bool>>,
/// Message handler for routing messages to Gateway
message_handler: Arc<dyn MessageHandler>,
}
/// Parsed message data from a Feishu frame
@ -146,7 +144,6 @@ struct ParsedMessage {
impl FeishuChannel {
pub fn new(
config: FeishuChannelConfig,
message_handler: Arc<dyn MessageHandler>,
_provider_config: LLMProviderConfig,
) -> Result<Self, ChannelError> {
Ok(Self {
@ -155,7 +152,6 @@ impl FeishuChannel {
running: Arc::new(RwLock::new(false)),
shutdown_tx: Arc::new(RwLock::new(None)),
connected: Arc::new(RwLock::new(false)),
message_handler,
})
}
@ -224,11 +220,10 @@ impl FeishuChannel {
.ok_or_else(|| ChannelError::Other("No token in response".to_string()))
}
/// Send a text message to Feishu chat
async fn send_message(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
/// Send a text message to Feishu chat (implements Channel trait)
async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
let token = self.get_tenant_token().await?;
// For text message, content should be a JSON string: "{\"text\":\"hello\"}"
let text_content = serde_json::json!({ "text": content }).to_string();
let resp = self.http_client
@ -262,26 +257,6 @@ impl FeishuChannel {
Ok(())
}
/// Handle incoming message - delegate to message handler and send response
async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> {
tracing::info!(open_id, chat_id, "Processing message from Feishu");
// Delegate to message handler (Gateway)
let response = self.message_handler
.handle_message("feishu", open_id, chat_id, content)
.await?;
// Send response to the chat
// Use open_id for p2p chats, chat_id for group chats
let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id };
let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
self.send_message(receive_id, receive_id_type, &response).await?;
tracing::info!(receive_id, "Sent response to Feishu");
Ok(())
}
/// Extract service_id from WebSocket URL query params
fn extract_service_id(url: &str) -> i32 {
url.split('?')
@ -310,7 +285,6 @@ impl FeishuChannel {
let payload = frame.payload.as_deref()
.ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?;
// Parse the event JSON to get event_type from payload header
let event: LarkEvent = serde_json::from_slice(payload)
.map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?;
@ -359,7 +333,7 @@ impl FeishuChannel {
Ok(())
}
async fn run_ws_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> {
async fn run_ws_loop(&self, bus: Arc<MessageBus>, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> {
let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?;
let service_id = Self::extract_service_id(&wss_url);
@ -404,7 +378,6 @@ impl FeishuChannel {
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
let bytes: Bytes = data;
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
// Parse the frame first
match self.handle_frame(&frame).await {
Ok(Some(parsed)) => {
// Send ACK immediately (Feishu requires within 3 s)
@ -412,11 +385,12 @@ impl FeishuChannel {
tracing::error!(error = %e, "Failed to send ACK to Feishu");
}
// Then process message asynchronously (don't await)
// Publish to bus asynchronously
let channel = self.clone();
let bus = bus.clone();
tokio::spawn(async move {
if let Err(e) = channel.handle_message(&parsed.open_id, &parsed.chat_id, &parsed.content).await {
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to handle Feishu message");
if let Err(e) = channel.handle_and_publish(&bus, &parsed.open_id, &parsed.chat_id, &parsed.content).await {
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to publish Feishu message to bus");
}
});
}
@ -528,7 +502,7 @@ impl Channel for FeishuChannel {
"feishu"
}
async fn start(&self) -> Result<(), ChannelError> {
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError> {
if self.config.app_id.is_empty() || self.config.app_secret.is_empty() {
return Err(ChannelError::ConfigError(
"Feishu app_id or app_secret is not configured".to_string()
@ -541,6 +515,7 @@ impl Channel for FeishuChannel {
*self.shutdown_tx.write().await = Some(shutdown_tx.clone());
let channel = self.clone();
let bus = bus.clone();
tokio::spawn(async move {
let mut consecutive_failures = 0;
let max_failures = 3;
@ -551,7 +526,7 @@ impl Channel for FeishuChannel {
}
let shutdown_rx = shutdown_tx.subscribe();
match channel.run_ws_loop(shutdown_rx).await {
match channel.run_ws_loop(bus.clone(), shutdown_rx).await {
Ok(_) => {
tracing::info!("Feishu WebSocket disconnected");
}
@ -592,7 +567,13 @@ impl Channel for FeishuChannel {
}
fn is_running(&self) -> bool {
// Note: blocking read, acceptable for this use case
self.running.try_read().map(|r| *r).unwrap_or(false)
}
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { &msg.reply_to.as_ref().unwrap_or(&msg.chat_id) };
let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
self.send_message_to_feishu(receive_id, receive_id_type, &msg.content).await
}
}

View File

@ -1,51 +1,38 @@
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
use crate::channels::feishu::FeishuChannel;
use crate::config::{Config, FeishuChannelConfig};
use crate::config::Config;
/// MessageHandler trait - Channel 通过这个 trait 与业务逻辑解耦
#[async_trait]
pub trait MessageHandler: Send + Sync {
async fn handle_message(
&self,
channel_name: &str,
sender_id: &str,
chat_id: &str,
content: &str,
) -> Result<String, ChannelError>;
}
/// ChannelManager 管理所有 Channel
/// ChannelManager manages all Channel instances and the MessageBus
#[derive(Clone)]
pub struct ChannelManager {
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
message_handler: Arc<dyn MessageHandler>,
bus: Arc<MessageBus>,
}
impl ChannelManager {
pub fn new(message_handler: Arc<dyn MessageHandler>) -> Self {
pub fn new() -> Self {
Self {
channels: Arc::new(RwLock::new(HashMap::new())),
message_handler,
bus: MessageBus::new(100),
}
}
/// 获取 MessageHandler 用于让 Channel 调用
pub fn get_handler(&self) -> Arc<dyn MessageHandler> {
self.message_handler.clone()
/// Get a reference to the MessageBus
pub fn bus(&self) -> Arc<MessageBus> {
self.bus.clone()
}
/// 初始化所有 Channel
pub async fn init(&self, config: &Config, provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
/// Initialize all Channel instances from config
pub async fn init(&self, config: &Config, _provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
// Initialize Feishu channel if enabled
if let Some(feishu_config) = config.channels.get("feishu") {
if feishu_config.enabled {
let handler = self.get_handler();
let channel = FeishuChannel::new(feishu_config.clone(), handler, provider_config)
let channel = FeishuChannel::new(feishu_config.clone(), _provider_config)
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
self.channels
@ -62,9 +49,10 @@ impl ChannelManager {
pub async fn start_all(&self) -> Result<(), ChannelError> {
let channels = self.channels.read().await;
let bus = self.bus.clone();
for (name, channel) in channels.iter() {
tracing::info!(channel = %name, "Starting channel");
if let Err(e) = channel.start().await {
if let Err(e) = channel.start(bus.clone()).await {
tracing::error!(channel = %name, error = %e, "Failed to start channel");
}
}
@ -86,32 +74,14 @@ impl ChannelManager {
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel + Send + Sync>> {
self.channels.read().await.get(name).cloned()
}
}
/// Gateway 实现 MessageHandler trait
#[derive(Clone)]
pub struct GatewayMessageHandler {
session_manager: crate::gateway::session::SessionManager,
}
impl GatewayMessageHandler {
pub fn new(session_manager: crate::gateway::session::SessionManager) -> Self {
Self { session_manager }
}
}
#[async_trait]
impl MessageHandler for GatewayMessageHandler {
async fn handle_message(
&self,
channel_name: &str,
sender_id: &str,
chat_id: &str,
content: &str,
) -> Result<String, ChannelError> {
self.session_manager
.handle_message(channel_name, sender_id, chat_id, content)
.await
.map_err(|e| ChannelError::Other(e.to_string()))
/// Dispatch an outbound message to the appropriate channel
pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let channel_name = &msg.channel;
if let Some(channel) = self.get_channel(channel_name).await {
channel.send(msg).await
} else {
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
}
}
}

View File

@ -2,6 +2,6 @@ pub mod base;
pub mod feishu;
pub mod manager;
pub use base::{Channel, ChannelError, InboundMessage, OutboundMessage};
pub use base::{Channel, ChannelError};
pub use manager::ChannelManager;
pub use feishu::FeishuChannel;

View File

@ -6,7 +6,8 @@ use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
use crate::channels::{ChannelManager, manager::GatewayMessageHandler};
use crate::bus::{MessageBus, OutboundDispatcher};
use crate::channels::ChannelManager;
use crate::config::Config;
use crate::logging;
use session::SessionManager;
@ -15,6 +16,7 @@ pub struct GatewayState {
pub config: Config,
pub session_manager: SessionManager,
pub channel_manager: ChannelManager,
pub bus: Arc<MessageBus>,
}
impl GatewayState {
@ -28,15 +30,76 @@ impl GatewayState {
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
let session_manager = SessionManager::new(session_ttl_hours, provider_config);
let message_handler = Arc::new(GatewayMessageHandler::new(session_manager.clone()));
let channel_manager = ChannelManager::new(message_handler);
let channel_manager = ChannelManager::new();
let bus = channel_manager.bus();
Ok(Self {
config,
session_manager,
channel_manager,
bus,
})
}
/// Start the message processing loops
pub async fn start_message_processing(&self) {
let bus_for_inbound = self.bus.clone();
let bus_for_outbound = self.bus.clone();
let session_manager = self.session_manager.clone();
// Spawn inbound message processor
// This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound
tokio::spawn(async move {
tracing::info!("Inbound processor started");
loop {
let inbound = bus_for_inbound.consume_inbound().await;
tracing::debug!(
channel = %inbound.channel,
chat_id = %inbound.chat_id,
"Processing inbound message"
);
// Process via session manager
match session_manager.handle_message(
&inbound.channel,
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
).await {
Ok(response_content) => {
let outbound = crate::bus::OutboundMessage {
channel: inbound.channel,
chat_id: inbound.chat_id,
content: response_content,
reply_to: None,
media: vec![],
metadata: std::collections::HashMap::new(),
};
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
tracing::error!(error = %e, "Failed to publish outbound");
}
}
Err(e) => {
tracing::error!(error = %e, "Failed to handle message");
}
}
}
});
// Spawn outbound dispatcher
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
let channel_manager = self.channel_manager.clone();
// Register channels with dispatcher
if let Some(channel) = channel_manager.get_channel("feishu").await {
dispatcher.register_channel("feishu", channel).await;
}
tokio::spawn(async move {
tracing::info!("Outbound dispatcher started");
dispatcher.run().await;
});
}
}
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
@ -50,9 +113,12 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
let provider_config = state.config.get_provider_config("default")?;
// Initialize and start channels
state.channel_manager.init(&state.config, provider_config).await?;
state.channel_manager.init(&state.config, provider_config.clone()).await?;
state.channel_manager.start_all().await?;
// Start message processing (inbound processor + outbound dispatcher)
state.start_message_processing().await;
// CLI args override config file values
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
let bind_port = port.unwrap_or(state.config.gateway.port);

View File

@ -10,11 +10,12 @@ use crate::protocol::WsOutbound;
use crate::tools::{CalculatorTool, ToolRegistry};
/// Session 按 channel 隔离,每个 channel 一个 Session
/// History 按 chat_id 隔离,由 Session 统一管理
pub struct Session {
pub id: Uuid,
pub channel_name: String,
/// 按 chat_id 路由到不同 AgentLoop,支持多用户多会话
chat_agents: HashMap<String, Arc<Mutex<AgentLoop>>>,
/// 按 chat_id 路由到不同会话历史,支持多用户多会话
chat_histories: HashMap<String, Vec<ChatMessage>>,
pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
@ -30,52 +31,65 @@ impl Session {
Ok(Self {
id: Uuid::new_v4(),
channel_name,
chat_agents: HashMap::new(),
chat_histories: HashMap::new(),
user_tx,
provider_config,
tools,
})
}
/// 获取或创建指定 chat_id 的 AgentLoop
pub async fn get_or_create_agent(&mut self, chat_id: &str) -> Result<Arc<Mutex<AgentLoop>>, AgentError> {
if let Some(agent) = self.chat_agents.get(chat_id) {
tracing::trace!(chat_id = %chat_id, "Reusing existing agent");
return Ok(agent.clone());
}
tracing::debug!(chat_id = %chat_id, "Creating new agent for chat");
let agent = AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())?;
let arc = Arc::new(Mutex::new(agent));
self.chat_agents.insert(chat_id.to_string(), arc.clone());
Ok(arc)
/// 获取或创建指定 chat_id 的会话历史
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
self.chat_histories
.entry(chat_id.to_string())
.or_insert_with(Vec::new)
}
/// 获取指定 chat_id 的 AgentLoop不创建
pub fn get_agent(&self, chat_id: &str) -> Option<Arc<Mutex<AgentLoop>>> {
self.chat_agents.get(chat_id).cloned()
/// 获取指定 chat_id 的会话历史(不创建)
pub fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
self.chat_histories.get(chat_id)
}
/// 添加用户消息到指定 chat_id 的历史
pub fn add_user_message(&mut self, chat_id: &str, content: &str) {
let history = self.get_or_create_history(chat_id);
history.push(ChatMessage::user(content));
}
/// 添加助手响应到指定 chat_id 的历史
pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
history.push(message);
}
}
/// 清除指定 chat_id 的历史
pub async fn clear_chat_history(&mut self, chat_id: &str) {
if let Some(agent) = self.chat_agents.get(chat_id) {
agent.lock().await.clear_history();
pub fn clear_chat_history(&mut self, chat_id: &str) {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
let len = history.len();
history.clear();
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
}
}
/// 清除所有历史
pub async fn clear_all_history(&mut self) {
for agent in self.chat_agents.values() {
agent.lock().await.clear_history();
}
pub fn clear_all_history(&mut self) {
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
self.chat_histories.clear();
tracing::debug!(previous_total = total, "All chat histories cleared");
}
pub async fn send(&self, msg: WsOutbound) {
let _ = self.user_tx.send(msg).await;
}
/// 创建一个临时的 AgentLoop 实例来处理消息
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())
}
}
/// SessionManager 管理所有 Session按 channel_name 路由
/// 使用 Arc<Mutex<SessionManager>> 以从 Arc 获取可变访问
#[derive(Clone)]
pub struct SessionManager {
inner: Arc<Mutex<SessionManagerInner>>,
@ -135,7 +149,13 @@ impl SessionManager {
// 创建新 session使用临时 user_tx因为 Feishu 不通过 WS
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx, self.tools.clone()).await?;
let session = Session::new(
channel_name.to_string(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
)
.await?;
let arc = Arc::new(Mutex::new(session));
inner.sessions.insert(channel_name.to_string(), arc.clone());
@ -165,7 +185,12 @@ impl SessionManager {
chat_id: &str,
content: &str,
) -> Result<String, AgentError> {
tracing::debug!(channel = %channel_name, chat_id = %chat_id, content_len = content.len(), "Routing message to agent");
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
content_len = content.len(),
"Routing message to agent"
);
// 确保 session 存在(可能需要重建)
self.ensure_session(channel_name).await?;
@ -174,21 +199,37 @@ impl SessionManager {
self.touch(channel_name).await;
// 获取 session
let session = self.get(channel_name).await
let session = self
.get(channel_name)
.await
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
// 获取或创建 chat_id 对应的 agent
let mut session_guard = session.lock().await;
let agent = session_guard.get_or_create_agent(chat_id).await?;
drop(session_guard);
let mut agent = agent.lock().await;
// 处理消息
let user_msg = ChatMessage::user(content);
let response = agent.process(user_msg).await?;
let response = {
let mut session_guard = session.lock().await;
tracing::debug!(channel = %channel_name, chat_id = %chat_id, response_len = response.content.len(), "Agent response received");
// 添加用户消息到历史
session_guard.add_user_message(chat_id, content);
// 获取完整历史
let history = session_guard.get_or_create_history(chat_id).clone();
// 创建 agent 并处理
let agent = session_guard.create_agent()?;
let response = agent.process(history).await?;
// 添加助手响应到历史
session_guard.add_assistant_message(chat_id, response.clone());
response
};
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
response_len = response.content.len(),
"Agent response received"
);
Ok(response.content)
}
@ -197,7 +238,7 @@ impl SessionManager {
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
if let Some(session) = self.get(channel_name).await {
let mut session_guard = session.lock().await;
session_guard.clear_all_history().await;
session_guard.clear_all_history();
}
Ok(())
}

View File

@ -4,7 +4,6 @@ use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
use super::{GatewayState, session::Session};
@ -29,7 +28,14 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
// 创建 CLI session
let session = match Session::new(channel_name.clone(), provider_config, sender, state.session_manager.tools()).await {
let session = match Session::new(
channel_name.clone(),
provider_config,
sender,
state.session_manager.tools(),
)
.await
{
Ok(s) => Arc::new(Mutex::new(s)),
Err(e) => {
tracing::error!(error = %e, "Failed to create session");
@ -40,9 +46,13 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let session_id = session.lock().await.id;
tracing::info!(session_id = %session_id, "CLI session established");
let _ = session.lock().await.send(WsOutbound::SessionEstablished {
let _ = session
.lock()
.await
.send(WsOutbound::SessionEstablished {
session_id: session_id.to_string(),
}).await;
})
.await;
let (mut ws_sender, mut ws_receiver) = ws.split();
@ -69,10 +79,14 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = session.lock().await.send(WsOutbound::Error {
let _ = session
.lock()
.await
.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
}).await;
})
.await;
}
}
}
@ -92,7 +106,12 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
// 提取 content 和 chat_idCLI 使用 session id 作为 chat_id
let (content, chat_id) = match inbound_clone {
WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => {
WsInbound::UserInput {
content,
channel: _,
chat_id,
sender_id: _,
} => {
// CLI 使用 session 中的 channel_name 作为标识
// chat_id 使用传入的或使用默认
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
@ -101,38 +120,50 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
_ => return,
};
let user_msg = ChatMessage::user(content);
let mut session_guard = session.lock().await;
let agent = match session_guard.get_or_create_agent(&chat_id).await {
// 添加用户消息到历史
session_guard.add_user_message(&chat_id, &content);
// 获取完整历史
let history = session_guard.get_or_create_history(&chat_id).clone();
// 创建 agent 并处理
let agent = match session_guard.create_agent() {
Ok(a) => a,
Err(e) => {
tracing::error!(chat_id = %chat_id, error = %e, "Failed to get or create agent");
let _ = session_guard.send(WsOutbound::Error {
tracing::error!(chat_id = %chat_id, error = %e, "Failed to create agent");
let _ = session_guard
.send(WsOutbound::Error {
code: "AGENT_ERROR".to_string(),
message: e.to_string(),
}).await;
})
.await;
return;
}
};
drop(session_guard);
let mut agent = agent.lock().await;
match agent.process(user_msg).await {
match agent.process(history).await {
Ok(response) => {
tracing::debug!(chat_id = %chat_id, "Agent response sent");
let _ = session.lock().await.send(WsOutbound::AssistantResponse {
// 添加助手响应到历史
session_guard.add_assistant_message(&chat_id, response.clone());
let _ = session_guard
.send(WsOutbound::AssistantResponse {
id: response.id,
content: response.content,
role: response.role,
}).await;
})
.await;
}
Err(e) => {
tracing::error!(chat_id = %chat_id, error = %e, "Agent process error");
let _ = session.lock().await.send(WsOutbound::Error {
let _ = session_guard
.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: e.to_string(),
}).await;
})
.await;
}
}
}