165 lines
5.6 KiB
Rust
165 lines
5.6 KiB
Rust
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use std::time::Instant;
|
||
|
||
use tokio::sync::{Mutex, mpsc};
|
||
|
||
use crate::agent::AgentError;
|
||
use crate::protocol::WsOutbound;
|
||
|
||
use super::session::Session;
|
||
use super::session_factory::SessionFactory;
|
||
|
||
/// Returns `true` when `chat_id` is reserved for scheduled tasks, i.e. when it
/// begins with the literal prefix `"scheduler/"`.
///
/// The comparison is case-sensitive and the trailing slash is required.
pub(crate) fn is_scheduler_chat_id(chat_id: &str) -> bool {
    const SCHEDULER_PREFIX: &str = "scheduler/";

    chat_id.strip_prefix(SCHEDULER_PREFIX).is_some()
}
|
||
|
||
#[derive(Clone)]
|
||
pub(crate) struct SessionPool {
|
||
inner: Arc<Mutex<SessionPoolInner>>,
|
||
session_factory: SessionFactory,
|
||
session_ttl_hours: Option<u64>,
|
||
}
|
||
|
||
struct SessionPoolInner {
|
||
/// 主 Session:用于用户消息
|
||
sessions: HashMap<String, Arc<Mutex<Session>>>,
|
||
/// 定时任务专用 Session:独立的实例,避免与用户消息竞争锁
|
||
scheduler_sessions: HashMap<String, Arc<Mutex<Session>>>,
|
||
session_timestamps: HashMap<String, Instant>,
|
||
}
|
||
|
||
impl SessionPool {
|
||
pub(crate) fn new(session_factory: SessionFactory, session_ttl_hours: Option<u64>) -> Self {
|
||
Self {
|
||
inner: Arc::new(Mutex::new(SessionPoolInner {
|
||
sessions: HashMap::new(),
|
||
scheduler_sessions: HashMap::new(),
|
||
session_timestamps: HashMap::new(),
|
||
})),
|
||
session_factory,
|
||
session_ttl_hours,
|
||
}
|
||
}
|
||
|
||
/// 确保主 Session 存在(用于用户消息)
|
||
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||
self.ensure_session_internal(channel_name, false).await
|
||
}
|
||
|
||
/// 确保定时任务专用 Session 存在
|
||
pub(crate) async fn ensure_scheduler_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||
self.ensure_session_internal(channel_name, true).await
|
||
}
|
||
|
||
/// 内部方法:创建 Session(根据 is_scheduler 选择存储位置)
|
||
async fn ensure_session_internal(&self, channel_name: &str, is_scheduler: bool) -> Result<(), AgentError> {
|
||
let mut inner = self.inner.lock().await;
|
||
|
||
// 选择对应的存储
|
||
let sessions = if is_scheduler {
|
||
&mut inner.scheduler_sessions
|
||
} else {
|
||
&mut inner.sessions
|
||
};
|
||
|
||
// 简化:只检查 session 是否存在,不做超时判断
|
||
if sessions.contains_key(channel_name) {
|
||
return Ok(());
|
||
}
|
||
|
||
// Session 不存在则创建
|
||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||
let session = self
|
||
.session_factory
|
||
.create(channel_name.to_string(), user_tx)
|
||
.await?;
|
||
|
||
sessions.insert(channel_name.to_string(), Arc::new(Mutex::new(session)));
|
||
inner
|
||
.session_timestamps
|
||
.insert(channel_name.to_string(), Instant::now());
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 获取主 Session(用于用户消息)
|
||
pub(crate) async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
|
||
self.inner.lock().await.sessions.get(channel_name).cloned()
|
||
}
|
||
|
||
/// 获取定时任务专用 Session
|
||
pub(crate) async fn get_scheduler_session(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
|
||
self.inner.lock().await.scheduler_sessions.get(channel_name).cloned()
|
||
}
|
||
|
||
/// 根据 chat_id 自动选择 Session
|
||
/// - scheduler/ 开头:返回定时任务专用 Session
|
||
/// - 其他:返回主 Session
|
||
pub(crate) async fn get_for_chat_id(&self, channel_name: &str, chat_id: &str) -> Option<Arc<Mutex<Session>>> {
|
||
if is_scheduler_chat_id(chat_id) {
|
||
self.get_scheduler_session(channel_name).await
|
||
} else {
|
||
self.get(channel_name).await
|
||
}
|
||
}
|
||
|
||
/// 确保 Session 存在(根据 chat_id 自动选择)
|
||
pub(crate) async fn ensure_session_for_chat_id(&self, channel_name: &str, chat_id: &str) -> Result<(), AgentError> {
|
||
if is_scheduler_chat_id(chat_id) {
|
||
self.ensure_scheduler_session(channel_name).await
|
||
} else {
|
||
self.ensure_session(channel_name).await
|
||
}
|
||
}
|
||
|
||
pub(crate) async fn touch(&self, channel_name: &str) {
|
||
self.inner
|
||
.lock()
|
||
.await
|
||
.session_timestamps
|
||
.insert(channel_name.to_string(), Instant::now());
|
||
}
|
||
|
||
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
|
||
let ttl_hours = match self.session_ttl_hours {
|
||
Some(hours) if hours > 0 => hours,
|
||
_ => return 0,
|
||
};
|
||
|
||
let ttl_duration = std::time::Duration::from_secs(ttl_hours * 3600);
|
||
let mut inner = self.inner.lock().await;
|
||
let now = Instant::now();
|
||
|
||
let expired_channels: Vec<String> = inner
|
||
.session_timestamps
|
||
.iter()
|
||
.filter_map(|(channel_name, last_active)| {
|
||
let elapsed = now.duration_since(*last_active);
|
||
if elapsed >= ttl_duration {
|
||
tracing::info!(
|
||
channel = %channel_name,
|
||
elapsed_hours = elapsed.as_secs() / 3600,
|
||
ttl_hours = ttl_hours,
|
||
"Session expired, removing from memory pool"
|
||
);
|
||
Some(channel_name.clone())
|
||
} else {
|
||
None
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
for channel_name in &expired_channels {
|
||
// 清理主 Session
|
||
inner.sessions.remove(channel_name);
|
||
// 清理定时任务专用 Session
|
||
inner.scheduler_sessions.remove(channel_name);
|
||
inner.session_timestamps.remove(channel_name);
|
||
}
|
||
|
||
expired_channels.len()
|
||
}
|
||
}
|