重构消息处理逻辑,添加 MessageHandler trait,支持多用户会话,更新 FeishuChannel 和 SessionManager,增强错误处理
This commit is contained in:
parent
04736f9f46
commit
34ab439067
@ -58,6 +58,7 @@ impl AgentLoop {
|
||||
pub enum AgentError {
|
||||
ProviderCreation(String),
|
||||
LlmError(String),
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AgentError {
|
||||
@ -65,6 +66,7 @@ impl std::fmt::Display for AgentError {
|
||||
match self {
|
||||
AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
|
||||
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
|
||||
AgentError::Other(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::{broadcast, RwLock, Mutex};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use serde::Deserialize;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::{Message as ProstMessage, bytes::Bytes};
|
||||
|
||||
use crate::agent::AgentLoop;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::channels::manager::MessageHandler;
|
||||
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
||||
|
||||
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
|
||||
@ -134,25 +131,31 @@ pub struct FeishuChannel {
|
||||
running: Arc<RwLock<bool>>,
|
||||
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
|
||||
connected: Arc<RwLock<bool>>,
|
||||
/// Dedup: message_id -> timestamp (cleaned after 30 min)
|
||||
seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||
/// Agent for processing messages
|
||||
agent: Arc<Mutex<AgentLoop>>,
|
||||
/// Message handler for routing messages to Gateway
|
||||
message_handler: Arc<dyn MessageHandler>,
|
||||
}
|
||||
|
||||
/// Parsed message data from a Feishu frame
|
||||
struct ParsedMessage {
|
||||
message_id: String,
|
||||
open_id: String,
|
||||
chat_id: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
impl FeishuChannel {
|
||||
pub fn new(config: FeishuChannelConfig, provider_config: LLMProviderConfig) -> Result<Self, ChannelError> {
|
||||
let agent = AgentLoop::new(provider_config)
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to create agent: {}", e)))?;
|
||||
|
||||
pub fn new(
|
||||
config: FeishuChannelConfig,
|
||||
message_handler: Arc<dyn MessageHandler>,
|
||||
_provider_config: LLMProviderConfig,
|
||||
) -> Result<Self, ChannelError> {
|
||||
Ok(Self {
|
||||
config,
|
||||
http_client: reqwest::Client::new(),
|
||||
running: Arc::new(RwLock::new(false)),
|
||||
shutdown_tx: Arc::new(RwLock::new(None)),
|
||||
connected: Arc::new(RwLock::new(false)),
|
||||
seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
agent: Arc::new(Mutex::new(agent)),
|
||||
message_handler,
|
||||
})
|
||||
}
|
||||
|
||||
@ -259,22 +262,21 @@ impl FeishuChannel {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle incoming message - process through agent and send response
|
||||
/// Handle incoming message - delegate to message handler and send response
|
||||
async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> {
|
||||
println!("Feishu: processing message from {} in chat {}: {}", open_id, chat_id, content);
|
||||
|
||||
// Process through agent
|
||||
let user_msg = ChatMessage::user(content);
|
||||
let mut agent = self.agent.lock().await;
|
||||
let response = agent.process(user_msg).await
|
||||
.map_err(|e| ChannelError::Other(format!("Agent error: {}", e)))?;
|
||||
// Delegate to message handler (Gateway)
|
||||
let response = self.message_handler
|
||||
.handle_message("feishu", open_id, chat_id, content)
|
||||
.await?;
|
||||
|
||||
// Send response to the chat
|
||||
// Use open_id for p2p chats, chat_id for group chats
|
||||
let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id };
|
||||
let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
||||
|
||||
self.send_message(receive_id, receive_id_type, &response.content).await?;
|
||||
self.send_message(receive_id, receive_id_type, &response).await?;
|
||||
println!("Feishu: sent response to {}", receive_id);
|
||||
|
||||
Ok(())
|
||||
@ -293,8 +295,8 @@ impl FeishuChannel {
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Handle incoming binary PbFrame - returns Some(message_id) if we need to ack
|
||||
async fn handle_frame(&self, frame: &PbFrame) -> Result<Option<String>, ChannelError> {
|
||||
/// Handle incoming binary PbFrame - returns Some(ParsedMessage) if we need to ack
|
||||
async fn handle_frame(&self, frame: &PbFrame) -> Result<Option<ParsedMessage>, ChannelError> {
|
||||
// method 0 = CONTROL (ping/pong)
|
||||
if frame.method == 0 {
|
||||
return Ok(None);
|
||||
@ -325,20 +327,7 @@ impl FeishuChannel {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Deduplication check with TTL cleanup
|
||||
let message_id = payload_data.message.message_id.clone();
|
||||
{
|
||||
let mut seen = self.seen_ids.write().await;
|
||||
let now = Instant::now();
|
||||
|
||||
// Clean expired entries (older than 30 min)
|
||||
seen.retain(|_, ts| now.duration_since(*ts).as_secs() < 1800);
|
||||
|
||||
if seen.contains_key(&message_id) {
|
||||
return Ok(None);
|
||||
}
|
||||
seen.insert(message_id.clone(), now);
|
||||
}
|
||||
|
||||
let open_id = payload_data.sender.sender_id.open_id
|
||||
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
|
||||
@ -348,13 +337,12 @@ impl FeishuChannel {
|
||||
let msg_type = msg.message_type.as_str();
|
||||
let content = parse_message_content(msg_type, &msg.content);
|
||||
|
||||
// Handle the message - process and send response
|
||||
if let Err(e) = self.handle_message(&open_id, &chat_id, &content).await {
|
||||
eprintln!("Error handling message: {}", e);
|
||||
}
|
||||
|
||||
// Return message_id for ack
|
||||
Ok(Some(message_id))
|
||||
Ok(Some(ParsedMessage {
|
||||
message_id,
|
||||
open_id,
|
||||
chat_id,
|
||||
content,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Send acknowledgment for a message
|
||||
@ -416,13 +404,21 @@ impl FeishuChannel {
|
||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
|
||||
let bytes: Bytes = data;
|
||||
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
|
||||
// Handle the frame and get message_id for ack if needed
|
||||
// Parse the frame first
|
||||
match self.handle_frame(&frame).await {
|
||||
Ok(Some(_message_id)) => {
|
||||
Ok(Some(parsed)) => {
|
||||
// Send ACK immediately (Feishu requires within 3 s)
|
||||
if let Err(e) = Self::send_ack(&frame, &mut write).await {
|
||||
eprintln!("Error sending ack: {}", e);
|
||||
}
|
||||
|
||||
// Then process message asynchronously (don't await)
|
||||
let channel = self.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = channel.handle_message(&parsed.open_id, &parsed.chat_id, &parsed.content).await {
|
||||
eprintln!("Error handling message: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => {
|
||||
@ -595,6 +591,7 @@ impl Channel for FeishuChannel {
|
||||
}
|
||||
|
||||
fn is_running(&self) -> bool {
|
||||
false
|
||||
// Note: blocking read, acceptable for this use case
|
||||
self.running.try_read().map(|r| *r).unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,32 +1,51 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::channels::feishu::FeishuChannel;
|
||||
use crate::config::Config;
|
||||
use crate::config::{Config, FeishuChannelConfig};
|
||||
|
||||
/// MessageHandler trait - Channel 通过这个 trait 与业务逻辑解耦
|
||||
#[async_trait]
|
||||
pub trait MessageHandler: Send + Sync {
|
||||
async fn handle_message(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<String, ChannelError>;
|
||||
}
|
||||
|
||||
/// ChannelManager 管理所有 Channel
|
||||
#[derive(Clone)]
|
||||
pub struct ChannelManager {
|
||||
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>,
|
||||
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
|
||||
message_handler: Arc<dyn MessageHandler>,
|
||||
}
|
||||
|
||||
impl ChannelManager {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(message_handler: Arc<dyn MessageHandler>) -> Self {
|
||||
Self {
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
message_handler,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn init(&self, config: &Config) -> Result<(), ChannelError> {
|
||||
/// 获取 MessageHandler 用于让 Channel 调用
|
||||
pub fn get_handler(&self) -> Arc<dyn MessageHandler> {
|
||||
self.message_handler.clone()
|
||||
}
|
||||
|
||||
/// 初始化所有 Channel
|
||||
pub async fn init(&self, config: &Config, provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
|
||||
// Initialize Feishu channel if enabled
|
||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||
if feishu_config.enabled {
|
||||
let agent_name = &feishu_config.agent;
|
||||
let provider_config = config.get_provider_config(agent_name)
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to get provider config: {}", e)))?;
|
||||
|
||||
let channel = FeishuChannel::new(feishu_config.clone(), provider_config)
|
||||
let handler = self.get_handler();
|
||||
let channel = FeishuChannel::new(feishu_config.clone(), handler, provider_config)
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
||||
|
||||
self.channels
|
||||
@ -47,8 +66,6 @@ impl ChannelManager {
|
||||
println!("Starting channel: {}", name);
|
||||
if let Err(e) = channel.start().await {
|
||||
eprintln!("Warning: Failed to start channel {}: {}", name, e);
|
||||
// Channel failed to start - it should have logged why
|
||||
// Continue starting other channels
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
@ -66,13 +83,35 @@ impl ChannelManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel>> {
|
||||
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel + Send + Sync>> {
|
||||
self.channels.read().await.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChannelManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
/// Gateway 实现 MessageHandler trait
|
||||
#[derive(Clone)]
|
||||
pub struct GatewayMessageHandler {
|
||||
session_manager: crate::gateway::session::SessionManager,
|
||||
}
|
||||
|
||||
impl GatewayMessageHandler {
|
||||
pub fn new(session_manager: crate::gateway::session::SessionManager) -> Self {
|
||||
Self { session_manager }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageHandler for GatewayMessageHandler {
|
||||
async fn handle_message(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<String, ChannelError> {
|
||||
self.session_manager
|
||||
.handle_message(channel_name, sender_id, chat_id, content)
|
||||
.await
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
break;
|
||||
}
|
||||
"__CLEAR__" => {
|
||||
let inbound = WsInbound::ClearHistory;
|
||||
let inbound = WsInbound::ClearHistory { chat_id: None };
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
@ -67,7 +67,12 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let inbound = WsInbound::UserInput { content: msg.content };
|
||||
let inbound = WsInbound::UserInput {
|
||||
content: msg.content,
|
||||
channel: None,
|
||||
chat_id: None,
|
||||
sender_id: None,
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
if sender.send(Message::Text(text.into())).await.is_err() {
|
||||
eprintln!("Failed to send message");
|
||||
|
||||
@ -67,6 +67,8 @@ pub struct GatewayConfig {
|
||||
pub host: String,
|
||||
#[serde(default = "default_gateway_port")]
|
||||
pub port: u16,
|
||||
#[serde(default, rename = "session_ttl_hours")]
|
||||
pub session_ttl_hours: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -92,6 +94,7 @@ impl Default for GatewayConfig {
|
||||
Self {
|
||||
host: default_gateway_host(),
|
||||
port: default_gateway_port(),
|
||||
session_ttl_hours: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ use std::sync::Arc;
|
||||
use axum::{routing, Router};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use crate::channels::ChannelManager;
|
||||
use crate::channels::{ChannelManager, manager::GatewayMessageHandler};
|
||||
use crate::config::Config;
|
||||
use session::SessionManager;
|
||||
|
||||
@ -19,10 +19,20 @@ pub struct GatewayState {
|
||||
impl GatewayState {
|
||||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let config = Config::load_default()?;
|
||||
let channel_manager = ChannelManager::new();
|
||||
|
||||
// Get provider config for SessionManager
|
||||
let provider_config = config.get_provider_config("default")?;
|
||||
|
||||
// Session TTL from config (default 4 hours)
|
||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||
|
||||
let session_manager = SessionManager::new(session_ttl_hours, provider_config);
|
||||
let message_handler = Arc::new(GatewayMessageHandler::new(session_manager.clone()));
|
||||
let channel_manager = ChannelManager::new(message_handler);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
session_manager: SessionManager::new(),
|
||||
session_manager,
|
||||
channel_manager,
|
||||
})
|
||||
}
|
||||
@ -31,8 +41,11 @@ impl GatewayState {
|
||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let state = Arc::new(GatewayState::new()?);
|
||||
|
||||
// Get provider config for channels
|
||||
let provider_config = state.config.get_provider_config("default")?;
|
||||
|
||||
// Initialize and start channels
|
||||
state.channel_manager.init(&state.config).await?;
|
||||
state.channel_manager.init(&state.config, provider_config).await?;
|
||||
state.channel_manager.start_all().await?;
|
||||
|
||||
// CLI args override config file values
|
||||
|
||||
@ -1,67 +1,175 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::agent::AgentLoop;
|
||||
use crate::agent::{AgentLoop, AgentError};
|
||||
use crate::protocol::WsOutbound;
|
||||
|
||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||
pub struct Session {
|
||||
pub id: Uuid,
|
||||
pub agent_loop: Arc<Mutex<AgentLoop>>,
|
||||
pub channel_name: String,
|
||||
/// 按 chat_id 路由到不同 AgentLoop,支持多用户多会话
|
||||
chat_agents: HashMap<String, Arc<Mutex<AgentLoop>>>,
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
provider_config: LLMProviderConfig,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn new(
|
||||
channel_name: String,
|
||||
provider_config: LLMProviderConfig,
|
||||
user_tx: mpsc::Sender<WsOutbound>,
|
||||
) -> Result<Self, crate::agent::AgentError> {
|
||||
let agent_loop = AgentLoop::new(provider_config)?;
|
||||
) -> Result<Self, AgentError> {
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
agent_loop: Arc::new(Mutex::new(agent_loop)),
|
||||
channel_name,
|
||||
chat_agents: HashMap::new(),
|
||||
user_tx,
|
||||
provider_config,
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取或创建指定 chat_id 的 AgentLoop
|
||||
pub async fn get_or_create_agent(&mut self, chat_id: &str) -> Result<Arc<Mutex<AgentLoop>>, AgentError> {
|
||||
if let Some(agent) = self.chat_agents.get(chat_id) {
|
||||
return Ok(agent.clone());
|
||||
}
|
||||
let agent = AgentLoop::new(self.provider_config.clone())?;
|
||||
let arc = Arc::new(Mutex::new(agent));
|
||||
self.chat_agents.insert(chat_id.to_string(), arc.clone());
|
||||
Ok(arc)
|
||||
}
|
||||
|
||||
/// 获取指定 chat_id 的 AgentLoop(不创建)
|
||||
pub fn get_agent(&self, chat_id: &str) -> Option<Arc<Mutex<AgentLoop>>> {
|
||||
self.chat_agents.get(chat_id).cloned()
|
||||
}
|
||||
|
||||
/// 清除指定 chat_id 的历史
|
||||
pub async fn clear_chat_history(&mut self, chat_id: &str) {
|
||||
if let Some(agent) = self.chat_agents.get(chat_id) {
|
||||
agent.lock().await.clear_history();
|
||||
}
|
||||
}
|
||||
|
||||
/// 清除所有历史
|
||||
pub async fn clear_all_history(&mut self) {
|
||||
for agent in self.chat_agents.values() {
|
||||
agent.lock().await.clear_history();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&self, msg: WsOutbound) {
|
||||
let _ = self.user_tx.send(msg).await;
|
||||
}
|
||||
}
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||||
/// 使用 Arc<Mutex<SessionManager>> 以从 Arc 获取可变访问
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
sessions: RwLock<HashMap<Uuid, Arc<Session>>>,
|
||||
inner: Arc<Mutex<SessionManagerInner>>,
|
||||
provider_config: LLMProviderConfig,
|
||||
}
|
||||
|
||||
struct SessionManagerInner {
|
||||
sessions: HashMap<String, Arc<Mutex<Session>>>,
|
||||
session_timestamps: HashMap<String, Instant>,
|
||||
session_ttl: Duration,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Self {
|
||||
Self {
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
||||
sessions: HashMap::new(),
|
||||
session_timestamps: HashMap::new(),
|
||||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
||||
})),
|
||||
provider_config,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&self, session: Arc<Session>) {
|
||||
self.sessions.write().unwrap().insert(session.id, session);
|
||||
/// 确保 session 存在且未超时,超时则重建
|
||||
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) {
|
||||
last_active.elapsed() > inner.session_ttl
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if should_recreate {
|
||||
// 移除旧 session
|
||||
inner.sessions.remove(channel_name);
|
||||
|
||||
// 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS)
|
||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||
let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx).await?;
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
|
||||
inner.sessions.insert(channel_name.to_string(), arc.clone());
|
||||
inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
|
||||
}
|
||||
|
||||
pub fn remove(&self, id: &Uuid) {
|
||||
self.sessions.write().unwrap().remove(id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get(&self, id: &Uuid) -> Option<Arc<Session>> {
|
||||
self.sessions.read().unwrap().get(id).cloned()
|
||||
/// 获取 session(不检查超时)
|
||||
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
|
||||
let inner = self.inner.lock().await;
|
||||
inner.sessions.get(channel_name).cloned()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.sessions.read().unwrap().len()
|
||||
}
|
||||
/// 更新最后活跃时间
|
||||
pub async fn touch(&self, channel_name: &str) {
|
||||
let mut inner = self.inner.lock().await;
|
||||
inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
|
||||
}
|
||||
|
||||
impl Default for SessionManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
/// 处理消息:路由到对应 session 的 agent
|
||||
pub async fn handle_message(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
_sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<String, AgentError> {
|
||||
// 确保 session 存在(可能需要重建)
|
||||
self.ensure_session(channel_name).await?;
|
||||
|
||||
// 更新活跃时间
|
||||
self.touch(channel_name).await;
|
||||
|
||||
// 获取 session
|
||||
let session = self.get(channel_name).await
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
|
||||
// 获取或创建 chat_id 对应的 agent
|
||||
let mut session_guard = session.lock().await;
|
||||
let agent = session_guard.get_or_create_agent(chat_id).await?;
|
||||
drop(session_guard);
|
||||
|
||||
let mut agent = agent.lock().await;
|
||||
|
||||
// 处理消息
|
||||
let user_msg = ChatMessage::user(content);
|
||||
let response = agent.process(user_msg).await?;
|
||||
|
||||
Ok(response.content)
|
||||
}
|
||||
|
||||
/// 清除指定 session 的所有历史
|
||||
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
|
||||
if let Some(session) = self.get(channel_name).await {
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.clear_all_history().await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@ use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||||
use axum::extract::State;
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
|
||||
use super::{GatewayState, session::Session};
|
||||
@ -25,18 +25,21 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
}
|
||||
};
|
||||
|
||||
let session = match Session::new(provider_config, sender).await {
|
||||
Ok(s) => Arc::new(s),
|
||||
// CLI 使用独立的 session,channel_name = "cli-{uuid}"
|
||||
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
|
||||
|
||||
// 创建 CLI session
|
||||
let session = match Session::new(channel_name.clone(), provider_config, sender).await {
|
||||
Ok(s) => Arc::new(Mutex::new(s)),
|
||||
Err(e) => {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let session_id = session.id;
|
||||
state.session_manager.add(session.clone());
|
||||
let session_id = session.lock().await.id;
|
||||
|
||||
let _ = session.send(WsOutbound::SessionEstablished {
|
||||
let _ = session.lock().await.send(WsOutbound::SessionEstablished {
|
||||
session_id: session_id.to_string(),
|
||||
}).await;
|
||||
|
||||
@ -62,7 +65,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
handle_inbound(&session, inbound).await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = session.send(WsOutbound::Error {
|
||||
let _ = session.lock().await.send(WsOutbound::Error {
|
||||
code: "PARSE_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
}).await;
|
||||
@ -75,42 +78,51 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
state.session_manager.remove(&session_id);
|
||||
}
|
||||
|
||||
async fn handle_inbound(session: &Arc<Session>, inbound: WsInbound) {
|
||||
match inbound {
|
||||
WsInbound::UserInput { content } => {
|
||||
async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
||||
let inbound_clone = inbound.clone();
|
||||
|
||||
// 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id)
|
||||
let (content, chat_id) = match inbound_clone {
|
||||
WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => {
|
||||
// CLI 使用 session 中的 channel_name 作为标识
|
||||
// chat_id 使用传入的或使用默认
|
||||
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
|
||||
(content, chat_id)
|
||||
}
|
||||
_ => return,
|
||||
};
|
||||
|
||||
let user_msg = ChatMessage::user(content);
|
||||
let mut agent = session.agent_loop.lock().await;
|
||||
|
||||
let mut session_guard = session.lock().await;
|
||||
let agent = match session_guard.get_or_create_agent(&chat_id).await {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
let _ = session_guard.send(WsOutbound::Error {
|
||||
code: "AGENT_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
}).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
drop(session_guard);
|
||||
|
||||
let mut agent = agent.lock().await;
|
||||
match agent.process(user_msg).await {
|
||||
Ok(response) => {
|
||||
let _ = session.send(WsOutbound::AssistantResponse {
|
||||
let _ = session.lock().await.send(WsOutbound::AssistantResponse {
|
||||
id: response.id,
|
||||
content: response.content,
|
||||
role: response.role,
|
||||
}).await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = session.send(WsOutbound::Error {
|
||||
let _ = session.lock().await.send(WsOutbound::Error {
|
||||
code: "LLM_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
}).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
WsInbound::ClearHistory => {
|
||||
let mut agent = session.agent_loop.lock().await;
|
||||
agent.clear_history();
|
||||
let _ = session.send(WsOutbound::AssistantResponse {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
content: "History cleared.".to_string(),
|
||||
role: "system".to_string(),
|
||||
}).await;
|
||||
}
|
||||
WsInbound::Ping => {
|
||||
let _ = session.send(WsOutbound::Pong).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,9 +4,20 @@ use serde::{Deserialize, Serialize};
|
||||
#[serde(tag = "type")]
|
||||
pub enum WsInbound {
|
||||
#[serde(rename = "user_input")]
|
||||
UserInput { content: String },
|
||||
UserInput {
|
||||
content: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
channel: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
chat_id: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
sender_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "clear_history")]
|
||||
ClearHistory,
|
||||
ClearHistory {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
chat_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user