use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; use tokio::sync::{Mutex, mpsc}; use crate::agent::AgentError; use crate::protocol::WsOutbound; use super::session::Session; use super::session_factory::SessionFactory; /// 判断 chat_id 是否是定时任务专用(以 "scheduler/" 开头) pub(crate) fn is_scheduler_chat_id(chat_id: &str) -> bool { chat_id.starts_with("scheduler/") } #[derive(Clone)] pub(crate) struct SessionPool { inner: Arc>, session_factory: SessionFactory, session_ttl_hours: Option, } struct SessionPoolInner { /// 主 Session:用于用户消息 sessions: HashMap>>, /// 定时任务专用 Session:独立的实例,避免与用户消息竞争锁 scheduler_sessions: HashMap>>, session_timestamps: HashMap, } impl SessionPool { pub(crate) fn new(session_factory: SessionFactory, session_ttl_hours: Option) -> Self { Self { inner: Arc::new(Mutex::new(SessionPoolInner { sessions: HashMap::new(), scheduler_sessions: HashMap::new(), session_timestamps: HashMap::new(), })), session_factory, session_ttl_hours, } } /// 确保主 Session 存在(用于用户消息) pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { self.ensure_session_internal(channel_name, false).await } /// 确保定时任务专用 Session 存在 pub(crate) async fn ensure_scheduler_session(&self, channel_name: &str) -> Result<(), AgentError> { self.ensure_session_internal(channel_name, true).await } /// 内部方法:创建 Session(根据 is_scheduler 选择存储位置) async fn ensure_session_internal(&self, channel_name: &str, is_scheduler: bool) -> Result<(), AgentError> { let mut inner = self.inner.lock().await; // 选择对应的存储 let sessions = if is_scheduler { &mut inner.scheduler_sessions } else { &mut inner.sessions }; // 简化:只检查 session 是否存在,不做超时判断 if sessions.contains_key(channel_name) { return Ok(()); } // Session 不存在则创建 let (user_tx, _rx) = mpsc::channel::(100); let session = self .session_factory .create(channel_name.to_string(), user_tx) .await?; sessions.insert(channel_name.to_string(), Arc::new(Mutex::new(session))); inner .session_timestamps .insert(channel_name.to_string(), Instant::now()); Ok(()) } /// 获取主 Session(用于用户消息) pub(crate) async fn get(&self, channel_name: &str) -> Option>> { self.inner.lock().await.sessions.get(channel_name).cloned() } /// 获取定时任务专用 Session pub(crate) async fn get_scheduler_session(&self, channel_name: &str) -> Option>> { self.inner.lock().await.scheduler_sessions.get(channel_name).cloned() } /// 根据 chat_id 自动选择 Session /// - scheduler/ 开头:返回定时任务专用 Session /// - 其他:返回主 Session pub(crate) async fn get_for_chat_id(&self, channel_name: &str, chat_id: &str) -> Option>> { if is_scheduler_chat_id(chat_id) { self.get_scheduler_session(channel_name).await } else { self.get(channel_name).await } } /// 确保 Session 存在(根据 chat_id 自动选择) pub(crate) async fn ensure_session_for_chat_id(&self, channel_name: &str, chat_id: &str) -> Result<(), AgentError> { if is_scheduler_chat_id(chat_id) { self.ensure_scheduler_session(channel_name).await } else { self.ensure_session(channel_name).await } } pub(crate) async fn touch(&self, channel_name: &str) { self.inner .lock() .await .session_timestamps .insert(channel_name.to_string(), Instant::now()); } pub(crate) async fn cleanup_expired_sessions(&self) -> usize { let ttl_hours = match self.session_ttl_hours { Some(hours) if hours > 0 => hours, _ => return 0, }; let ttl_duration = std::time::Duration::from_secs(ttl_hours * 3600); let mut inner = self.inner.lock().await; let now = Instant::now(); let expired_channels: Vec = inner .session_timestamps .iter() .filter_map(|(channel_name, last_active)| { let elapsed = now.duration_since(*last_active); if elapsed >= ttl_duration { tracing::info!( channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, ttl_hours = ttl_hours, "Session expired, removing from memory pool" ); Some(channel_name.clone()) } else { None } }) .collect(); for channel_name in &expired_channels { // 清理主 Session inner.sessions.remove(channel_name); // 清理定时任务专用 Session inner.scheduler_sessions.remove(channel_name); inner.session_timestamps.remove(channel_name); } expired_channels.len() } }