From 34ab439067c15291546f7bb1d9b7098cac81fdea Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Mon, 6 Apr 2026 22:38:41 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=B6=88=E6=81=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E6=B7=BB=E5=8A=A0=20Messag?= =?UTF-8?q?eHandler=20trait=EF=BC=8C=E6=94=AF=E6=8C=81=E5=A4=9A=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E4=BC=9A=E8=AF=9D=EF=BC=8C=E6=9B=B4=E6=96=B0=20Feishu?= =?UTF-8?q?Channel=20=E5=92=8C=20SessionManager=EF=BC=8C=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/agent_loop.rs | 2 + src/channels/feishu.rs | 91 +++++++++++------------ src/channels/manager.rs | 69 ++++++++++++++---- src/client/mod.rs | 9 ++- src/config/mod.rs | 3 + src/gateway/mod.rs | 21 +++++- src/gateway/session.rs | 158 +++++++++++++++++++++++++++++++++------- src/gateway/ws.rs | 88 ++++++++++++---------- src/protocol.rs | 15 +++- 9 files changed, 323 insertions(+), 133 deletions(-) diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index fbf60b3..ee2134b 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -58,6 +58,7 @@ impl AgentLoop { pub enum AgentError { ProviderCreation(String), LlmError(String), + Other(String), } impl std::fmt::Display for AgentError { @@ -65,6 +66,7 @@ impl std::fmt::Display for AgentError { match self { AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e), AgentError::LlmError(e) => write!(f, "LLM error: {}", e), + AgentError::Other(e) => write!(f, "{}", e), } } } diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs index 9f2aac6..e8ec13b 100644 --- a/src/channels/feishu.rs +++ b/src/channels/feishu.rs @@ -1,15 +1,12 @@ -use std::collections::HashMap; use std::sync::Arc; -use std::time::Instant; use async_trait::async_trait; -use tokio::sync::{broadcast, RwLock, Mutex}; +use tokio::sync::{broadcast, RwLock}; use serde::Deserialize; use futures_util::{SinkExt, StreamExt}; use prost::{Message as ProstMessage, bytes::Bytes}; -use crate::agent::AgentLoop; -use crate::bus::ChatMessage; use crate::channels::base::{Channel, ChannelError}; +use crate::channels::manager::MessageHandler; use crate::config::{FeishuChannelConfig, LLMProviderConfig}; const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis"; @@ -134,25 +131,31 @@ pub struct FeishuChannel { running: Arc>, shutdown_tx: Arc>>>, connected: Arc>, - /// Dedup: message_id -> timestamp (cleaned after 30 min) - seen_ids: Arc>>, - /// Agent for processing messages - agent: Arc>, + /// Message handler for routing messages to Gateway + message_handler: Arc, +} + +/// Parsed message data from a Feishu frame +struct ParsedMessage { + message_id: String, + open_id: String, + chat_id: String, + content: String, } impl FeishuChannel { - pub fn new(config: FeishuChannelConfig, provider_config: LLMProviderConfig) -> Result { - let agent = AgentLoop::new(provider_config) - .map_err(|e| ChannelError::Other(format!("Failed to create agent: {}", e)))?; - + pub fn new( + config: FeishuChannelConfig, + message_handler: Arc, + _provider_config: LLMProviderConfig, + ) -> Result { Ok(Self { config, http_client: reqwest::Client::new(), running: Arc::new(RwLock::new(false)), shutdown_tx: Arc::new(RwLock::new(None)), connected: Arc::new(RwLock::new(false)), - seen_ids: Arc::new(RwLock::new(HashMap::new())), - agent: Arc::new(Mutex::new(agent)), + message_handler, }) } @@ -259,22 +262,21 @@ impl FeishuChannel { Ok(()) } - /// Handle incoming message - process through agent and send response + /// Handle incoming message - delegate to message handler and send response async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> { println!("Feishu: processing message from {} in chat {}: {}", open_id, chat_id, content); - // Process through agent - let user_msg = ChatMessage::user(content); - let mut agent = self.agent.lock().await; - let response = agent.process(user_msg).await - .map_err(|e| ChannelError::Other(format!("Agent error: {}", e)))?; + // Delegate to message handler (Gateway) + let response = self.message_handler + .handle_message("feishu", open_id, chat_id, content) + .await?; // Send response to the chat // Use open_id for p2p chats, chat_id for group chats let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id }; let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" }; - self.send_message(receive_id, receive_id_type, &response.content).await?; + self.send_message(receive_id, receive_id_type, &response).await?; println!("Feishu: sent response to {}", receive_id); Ok(()) @@ -293,8 +295,8 @@ impl FeishuChannel { .unwrap_or(0) } - /// Handle incoming binary PbFrame - returns Some(message_id) if we need to ack - async fn handle_frame(&self, frame: &PbFrame) -> Result, ChannelError> { + /// Handle incoming binary PbFrame - returns Some(ParsedMessage) if we need to ack + async fn handle_frame(&self, frame: &PbFrame) -> Result, ChannelError> { // method 0 = CONTROL (ping/pong) if frame.method == 0 { return Ok(None); @@ -325,20 +327,7 @@ impl FeishuChannel { return Ok(None); } - // Deduplication check with TTL cleanup let message_id = payload_data.message.message_id.clone(); - { - let mut seen = self.seen_ids.write().await; - let now = Instant::now(); - - // Clean expired entries (older than 30 min) - seen.retain(|_, ts| now.duration_since(*ts).as_secs() < 1800); - - if seen.contains_key(&message_id) { - return Ok(None); - } - seen.insert(message_id.clone(), now); - } let open_id = payload_data.sender.sender_id.open_id .ok_or_else(|| ChannelError::Other("No open_id".to_string()))?; @@ -348,13 +337,12 @@ impl FeishuChannel { let msg_type = msg.message_type.as_str(); let content = parse_message_content(msg_type, &msg.content); - // Handle the message - process and send response - if let Err(e) = self.handle_message(&open_id, &chat_id, &content).await { - eprintln!("Error handling message: {}", e); - } - - // Return message_id for ack - Ok(Some(message_id)) + Ok(Some(ParsedMessage { + message_id, + open_id, + chat_id, + content, + })) } /// Send acknowledgment for a message @@ -416,13 +404,21 @@ impl FeishuChannel { Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => { let bytes: Bytes = data; if let Ok(frame) = PbFrame::decode(bytes.as_ref()) { - // Handle the frame and get message_id for ack if needed + // Parse the frame first match self.handle_frame(&frame).await { - Ok(Some(_message_id)) => { + Ok(Some(parsed)) => { // Send ACK immediately (Feishu requires within 3 s) if let Err(e) = Self::send_ack(&frame, &mut write).await { eprintln!("Error sending ack: {}", e); } + + // Then process message asynchronously (don't await) + let channel = self.clone(); + tokio::spawn(async move { + if let Err(e) = channel.handle_message(&parsed.open_id, &parsed.chat_id, &parsed.content).await { + eprintln!("Error handling message: {}", e); + } + }); } Ok(None) => {} Err(e) => { @@ -595,6 +591,7 @@ impl Channel for FeishuChannel { } fn is_running(&self) -> bool { - false + // Note: blocking read, acceptable for this use case + self.running.try_read().map(|r| *r).unwrap_or(false) } } diff --git a/src/channels/manager.rs b/src/channels/manager.rs index f048c37..2369d21 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -1,32 +1,51 @@ use std::collections::HashMap; use std::sync::Arc; +use async_trait::async_trait; use tokio::sync::RwLock; use crate::channels::base::{Channel, ChannelError}; use crate::channels::feishu::FeishuChannel; -use crate::config::Config; +use crate::config::{Config, FeishuChannelConfig}; +/// MessageHandler trait - Channel 通过这个 trait 与业务逻辑解耦 +#[async_trait] +pub trait MessageHandler: Send + Sync { + async fn handle_message( + &self, + channel_name: &str, + sender_id: &str, + chat_id: &str, + content: &str, + ) -> Result; +} + +/// ChannelManager 管理所有 Channel #[derive(Clone)] pub struct ChannelManager { - channels: Arc>>>, + channels: Arc>>>, + message_handler: Arc, } impl ChannelManager { - pub fn new() -> Self { + pub fn new(message_handler: Arc) -> Self { Self { channels: Arc::new(RwLock::new(HashMap::new())), + message_handler, } } - pub async fn init(&self, config: &Config) -> Result<(), ChannelError> { + /// 获取 MessageHandler 用于让 Channel 调用 + pub fn get_handler(&self) -> Arc { + self.message_handler.clone() + } + + /// 初始化所有 Channel + pub async fn init(&self, config: &Config, provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> { // Initialize Feishu channel if enabled if let Some(feishu_config) = config.channels.get("feishu") { if feishu_config.enabled { - let agent_name = &feishu_config.agent; - let provider_config = config.get_provider_config(agent_name) - .map_err(|e| ChannelError::Other(format!("Failed to get provider config: {}", e)))?; - - let channel = FeishuChannel::new(feishu_config.clone(), provider_config) + let handler = self.get_handler(); + let channel = FeishuChannel::new(feishu_config.clone(), handler, provider_config) .map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?; self.channels @@ -47,8 +66,6 @@ impl ChannelManager { println!("Starting channel: {}", name); if let Err(e) = channel.start().await { eprintln!("Warning: Failed to start channel {}: {}", name, e); - // Channel failed to start - it should have logged why - // Continue starting other channels } } Ok(()) @@ -66,13 +83,35 @@ impl ChannelManager { Ok(()) } - pub async fn get_channel(&self, name: &str) -> Option> { + pub async fn get_channel(&self, name: &str) -> Option> { self.channels.read().await.get(name).cloned() } } -impl Default for ChannelManager { - fn default() -> Self { - Self::new() +/// Gateway 实现 MessageHandler trait +#[derive(Clone)] +pub struct GatewayMessageHandler { + session_manager: crate::gateway::session::SessionManager, +} + +impl GatewayMessageHandler { + pub fn new(session_manager: crate::gateway::session::SessionManager) -> Self { + Self { session_manager } + } +} + +#[async_trait] +impl MessageHandler for GatewayMessageHandler { + async fn handle_message( + &self, + channel_name: &str, + sender_id: &str, + chat_id: &str, + content: &str, + ) -> Result { + self.session_manager + .handle_message(channel_name, sender_id, chat_id, content) + .await + .map_err(|e| ChannelError::Other(e.to_string())) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index de792f5..64d2b44 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -58,7 +58,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { break; } "__CLEAR__" => { - let inbound = WsInbound::ClearHistory; + let inbound = WsInbound::ClearHistory { chat_id: None }; if let Ok(text) = serialize_inbound(&inbound) { let _ = sender.send(Message::Text(text.into())).await; } @@ -67,7 +67,12 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { _ => {} } - let inbound = WsInbound::UserInput { content: msg.content }; + let inbound = WsInbound::UserInput { + content: msg.content, + channel: None, + chat_id: None, + sender_id: None, + }; if let Ok(text) = serialize_inbound(&inbound) { if sender.send(Message::Text(text.into())).await.is_err() { eprintln!("Failed to send message"); diff --git a/src/config/mod.rs b/src/config/mod.rs index 85155e2..d7310c5 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -67,6 +67,8 @@ pub struct GatewayConfig { pub host: String, #[serde(default = "default_gateway_port")] pub port: u16, + #[serde(default, rename = "session_ttl_hours")] + pub session_ttl_hours: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -92,6 +94,7 @@ impl Default for GatewayConfig { Self { host: default_gateway_host(), port: default_gateway_port(), + session_ttl_hours: None, } } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 2afa82a..c3ddded 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use axum::{routing, Router}; use tokio::net::TcpListener; -use crate::channels::ChannelManager; +use crate::channels::{ChannelManager, manager::GatewayMessageHandler}; use crate::config::Config; use session::SessionManager; @@ -19,10 +19,20 @@ pub struct GatewayState { impl GatewayState { pub fn new() -> Result> { let config = Config::load_default()?; - let channel_manager = ChannelManager::new(); + + // Get provider config for SessionManager + let provider_config = config.get_provider_config("default")?; + + // Session TTL from config (default 4 hours) + let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); + + let session_manager = SessionManager::new(session_ttl_hours, provider_config); + let message_handler = Arc::new(GatewayMessageHandler::new(session_manager.clone())); + let channel_manager = ChannelManager::new(message_handler); + Ok(Self { config, - session_manager: SessionManager::new(), + session_manager, channel_manager, }) } @@ -31,8 +41,11 @@ impl GatewayState { pub async fn run(host: Option, port: Option) -> Result<(), Box> { let state = Arc::new(GatewayState::new()?); + // Get provider config for channels + let provider_config = state.config.get_provider_config("default")?; + // Initialize and start channels - state.channel_manager.init(&state.config).await?; + state.channel_manager.init(&state.config, provider_config).await?; state.channel_manager.start_all().await?; // CLI args override config file values diff --git a/src/gateway/session.rs b/src/gateway/session.rs index bfc48a6..f01be10 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1,67 +1,175 @@ +use std::collections::HashMap; use std::sync::Arc; +use std::time::{Duration, Instant}; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; +use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; -use crate::agent::AgentLoop; +use crate::agent::{AgentLoop, AgentError}; use crate::protocol::WsOutbound; +/// Session 按 channel 隔离,每个 channel 一个 Session pub struct Session { pub id: Uuid, - pub agent_loop: Arc>, + pub channel_name: String, + /// 按 chat_id 路由到不同 AgentLoop,支持多用户多会话 + chat_agents: HashMap>>, pub user_tx: mpsc::Sender, + provider_config: LLMProviderConfig, } impl Session { pub async fn new( + channel_name: String, provider_config: LLMProviderConfig, user_tx: mpsc::Sender, - ) -> Result { - let agent_loop = AgentLoop::new(provider_config)?; + ) -> Result { Ok(Self { id: Uuid::new_v4(), - agent_loop: Arc::new(Mutex::new(agent_loop)), + channel_name, + chat_agents: HashMap::new(), user_tx, + provider_config, }) } + /// 获取或创建指定 chat_id 的 AgentLoop + pub async fn get_or_create_agent(&mut self, chat_id: &str) -> Result>, AgentError> { + if let Some(agent) = self.chat_agents.get(chat_id) { + return Ok(agent.clone()); + } + let agent = AgentLoop::new(self.provider_config.clone())?; + let arc = Arc::new(Mutex::new(agent)); + self.chat_agents.insert(chat_id.to_string(), arc.clone()); + Ok(arc) + } + + /// 获取指定 chat_id 的 AgentLoop(不创建) + pub fn get_agent(&self, chat_id: &str) -> Option>> { + self.chat_agents.get(chat_id).cloned() + } + + /// 清除指定 chat_id 的历史 + pub async fn clear_chat_history(&mut self, chat_id: &str) { + if let Some(agent) = self.chat_agents.get(chat_id) { + agent.lock().await.clear_history(); + } + } + + /// 清除所有历史 + pub async fn clear_all_history(&mut self) { + for agent in self.chat_agents.values() { + agent.lock().await.clear_history(); + } + } + pub async fn send(&self, msg: WsOutbound) { let _ = self.user_tx.send(msg).await; } } -use std::collections::HashMap; -use std::sync::RwLock; - +/// SessionManager 管理所有 Session,按 channel_name 路由 +/// 使用 Arc> 以从 Arc 获取可变访问 +#[derive(Clone)] pub struct SessionManager { - sessions: RwLock>>, + inner: Arc>, + provider_config: LLMProviderConfig, +} + +struct SessionManagerInner { + sessions: HashMap>>, + session_timestamps: HashMap, + session_ttl: Duration, } impl SessionManager { - pub fn new() -> Self { + pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Self { Self { - sessions: RwLock::new(HashMap::new()), + inner: Arc::new(Mutex::new(SessionManagerInner { + sessions: HashMap::new(), + session_timestamps: HashMap::new(), + session_ttl: Duration::from_secs(session_ttl_hours * 3600), + })), + provider_config, } } - pub fn add(&self, session: Arc) { - self.sessions.write().unwrap().insert(session.id, session); + /// 确保 session 存在且未超时,超时则重建 + pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { + let mut inner = self.inner.lock().await; + + let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) { + last_active.elapsed() > inner.session_ttl + } else { + true + }; + + if should_recreate { + // 移除旧 session + inner.sessions.remove(channel_name); + + // 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS) + let (user_tx, _rx) = mpsc::channel::(100); + let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx).await?; + let arc = Arc::new(Mutex::new(session)); + + inner.sessions.insert(channel_name.to_string(), arc.clone()); + inner.session_timestamps.insert(channel_name.to_string(), Instant::now()); + } + + Ok(()) } - pub fn remove(&self, id: &Uuid) { - self.sessions.write().unwrap().remove(id); + /// 获取 session(不检查超时) + pub async fn get(&self, channel_name: &str) -> Option>> { + let inner = self.inner.lock().await; + inner.sessions.get(channel_name).cloned() } - pub fn get(&self, id: &Uuid) -> Option> { - self.sessions.read().unwrap().get(id).cloned() + /// 更新最后活跃时间 + pub async fn touch(&self, channel_name: &str) { + let mut inner = self.inner.lock().await; + inner.session_timestamps.insert(channel_name.to_string(), Instant::now()); } - pub fn len(&self) -> usize { - self.sessions.read().unwrap().len() - } -} - -impl Default for SessionManager { - fn default() -> Self { - Self::new() + /// 处理消息:路由到对应 session 的 agent + pub async fn handle_message( + &self, + channel_name: &str, + _sender_id: &str, + chat_id: &str, + content: &str, + ) -> Result { + // 确保 session 存在(可能需要重建) + self.ensure_session(channel_name).await?; + + // 更新活跃时间 + self.touch(channel_name).await; + + // 获取 session + let session = self.get(channel_name).await + .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; + + // 获取或创建 chat_id 对应的 agent + let mut session_guard = session.lock().await; + let agent = session_guard.get_or_create_agent(chat_id).await?; + drop(session_guard); + + let mut agent = agent.lock().await; + + // 处理消息 + let user_msg = ChatMessage::user(content); + let response = agent.process(user_msg).await?; + + Ok(response.content) + } + + /// 清除指定 session 的所有历史 + pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> { + if let Some(session) = self.get(channel_name).await { + let mut session_guard = session.lock().await; + session_guard.clear_all_history().await; + } + Ok(()) } } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index f0a80a8..91e9b2d 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -3,7 +3,7 @@ use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; use axum::extract::State; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, Mutex}; use crate::bus::ChatMessage; use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound}; use super::{GatewayState, session::Session}; @@ -25,18 +25,21 @@ async fn handle_socket(ws: WebSocket, state: Arc) { } }; - let session = match Session::new(provider_config, sender).await { - Ok(s) => Arc::new(s), + // CLI 使用独立的 session,channel_name = "cli-{uuid}" + let channel_name = format!("cli-{}", uuid::Uuid::new_v4()); + + // 创建 CLI session + let session = match Session::new(channel_name.clone(), provider_config, sender).await { + Ok(s) => Arc::new(Mutex::new(s)), Err(e) => { eprintln!("Failed to create session: {}", e); return; } }; - let session_id = session.id; - state.session_manager.add(session.clone()); + let session_id = session.lock().await.id; - let _ = session.send(WsOutbound::SessionEstablished { + let _ = session.lock().await.send(WsOutbound::SessionEstablished { session_id: session_id.to_string(), }).await; @@ -62,7 +65,7 @@ async fn handle_socket(ws: WebSocket, state: Arc) { handle_inbound(&session, inbound).await; } Err(e) => { - let _ = session.send(WsOutbound::Error { + let _ = session.lock().await.send(WsOutbound::Error { code: "PARSE_ERROR".to_string(), message: e.to_string(), }).await; @@ -75,42 +78,51 @@ async fn handle_socket(ws: WebSocket, state: Arc) { _ => {} } } - - state.session_manager.remove(&session_id); } -async fn handle_inbound(session: &Arc, inbound: WsInbound) { - match inbound { - WsInbound::UserInput { content } => { - let user_msg = ChatMessage::user(content); - let mut agent = session.agent_loop.lock().await; - match agent.process(user_msg).await { - Ok(response) => { - let _ = session.send(WsOutbound::AssistantResponse { - id: response.id, - content: response.content, - role: response.role, - }).await; - } - Err(e) => { - let _ = session.send(WsOutbound::Error { - code: "LLM_ERROR".to_string(), - message: e.to_string(), - }).await; - } - } +async fn handle_inbound(session: &Arc>, inbound: WsInbound) { + let inbound_clone = inbound.clone(); + + // 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id) + let (content, chat_id) = match inbound_clone { + WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => { + // CLI 使用 session 中的 channel_name 作为标识 + // chat_id 使用传入的或使用默认 + let chat_id = chat_id.unwrap_or_else(|| "default".to_string()); + (content, chat_id) } - WsInbound::ClearHistory => { - let mut agent = session.agent_loop.lock().await; - agent.clear_history(); - let _ = session.send(WsOutbound::AssistantResponse { - id: uuid::Uuid::new_v4().to_string(), - content: "History cleared.".to_string(), - role: "system".to_string(), + _ => return, + }; + + let user_msg = ChatMessage::user(content); + + let mut session_guard = session.lock().await; + let agent = match session_guard.get_or_create_agent(&chat_id).await { + Ok(a) => a, + Err(e) => { + let _ = session_guard.send(WsOutbound::Error { + code: "AGENT_ERROR".to_string(), + message: e.to_string(), + }).await; + return; + } + }; + drop(session_guard); + + let mut agent = agent.lock().await; + match agent.process(user_msg).await { + Ok(response) => { + let _ = session.lock().await.send(WsOutbound::AssistantResponse { + id: response.id, + content: response.content, + role: response.role, }).await; } - WsInbound::Ping => { - let _ = session.send(WsOutbound::Pong).await; + Err(e) => { + let _ = session.lock().await.send(WsOutbound::Error { + code: "LLM_ERROR".to_string(), + message: e.to_string(), + }).await; } } } diff --git a/src/protocol.rs b/src/protocol.rs index 4e8a22a..8d24404 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -4,9 +4,20 @@ use serde::{Deserialize, Serialize}; #[serde(tag = "type")] pub enum WsInbound { #[serde(rename = "user_input")] - UserInput { content: String }, + UserInput { + content: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + channel: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + chat_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + sender_id: Option, + }, #[serde(rename = "clear_history")] - ClearHistory, + ClearHistory { + #[serde(default, skip_serializing_if = "Option::is_none")] + chat_id: Option, + }, #[serde(rename = "ping")] Ping, }