PicoBot/src/gateway/session_pool.rs

165 lines
5.6 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{Mutex, mpsc};
use crate::agent::AgentError;
use crate::protocol::WsOutbound;
use super::session::Session;
use super::session_factory::SessionFactory;
/// Returns `true` when `chat_id` is reserved for scheduled tasks, i.e. it begins
/// with the `scheduler/` prefix; every other chat id belongs to regular user traffic.
pub(crate) fn is_scheduler_chat_id(chat_id: &str) -> bool {
    const SCHEDULER_PREFIX: &str = "scheduler/";
    chat_id.strip_prefix(SCHEDULER_PREFIX).is_some()
}
/// Pool of agent sessions keyed by channel name.
///
/// Each channel can own two independent sessions: a main one for user messages and
/// a scheduler-only one. Clones of the pool share the same inner state through `Arc`.
#[derive(Clone)]
pub(crate) struct SessionPool {
/// Shared mutable state, guarded by a tokio `Mutex`.
inner: Arc<Mutex<SessionPoolInner>>,
/// Builds a new session the first time a channel needs one.
session_factory: SessionFactory,
/// Idle lifetime in hours used by `cleanup_expired_sessions`; `None` or `Some(0)`
/// disables expiry.
session_ttl_hours: Option<u64>,
}
/// Mutable state of a `SessionPool`; only accessed through the pool's mutex.
struct SessionPoolInner {
/// Main sessions (used for user messages), keyed by channel name.
sessions: HashMap<String, Arc<Mutex<Session>>>,
/// Scheduler-only sessions, keyed by channel name: separate instances so that
/// scheduled tasks do not contend with user messages for the same session lock.
scheduler_sessions: HashMap<String, Arc<Mutex<Session>>>,
/// Last activity time per channel name. A single entry covers both the main and the
/// scheduler session of that channel; read by `cleanup_expired_sessions`.
session_timestamps: HashMap<String, Instant>,
}
impl SessionPool {
    /// Creates an empty pool.
    ///
    /// * `session_factory` - builds a new session the first time a channel needs one.
    /// * `session_ttl_hours` - idle lifetime used by `cleanup_expired_sessions`;
    ///   `None` or `Some(0)` disables expiry.
    pub(crate) fn new(session_factory: SessionFactory, session_ttl_hours: Option<u64>) -> Self {
        Self {
            inner: Arc::new(Mutex::new(SessionPoolInner {
                sessions: HashMap::new(),
                scheduler_sessions: HashMap::new(),
                session_timestamps: HashMap::new(),
            })),
            session_factory,
            session_ttl_hours,
        }
    }

    /// Ensures the main session (used for user messages) exists for `channel_name`.
    ///
    /// # Errors
    /// Propagates any `AgentError` returned by the session factory.
    pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
        self.ensure_session_internal(channel_name, false).await
    }

    /// Ensures the scheduler-only session exists for `channel_name`.
    ///
    /// # Errors
    /// Propagates any `AgentError` returned by the session factory.
    pub(crate) async fn ensure_scheduler_session(
        &self,
        channel_name: &str,
    ) -> Result<(), AgentError> {
        self.ensure_session_internal(channel_name, true).await
    }

    /// Creates the session for `channel_name` in the map selected by `is_scheduler`
    /// (scheduler map when `true`, main map otherwise) unless it already exists.
    ///
    /// An existing session is left untouched: no expiry check is done here and its
    /// activity timestamp is not refreshed. A newly created session records `now` as
    /// the channel's activity time.
    ///
    /// # Errors
    /// Propagates any `AgentError` returned by the session factory; nothing is
    /// inserted in that case.
    async fn ensure_session_internal(
        &self,
        channel_name: &str,
        is_scheduler: bool,
    ) -> Result<(), AgentError> {
        // The pool lock is held across `create().await`, so two concurrent callers can
        // never both build a session for the same channel.
        // NOTE(review): this also blocks every other pool operation (get/touch/cleanup
        // for all channels) while a session is being created.
        let mut inner = self.inner.lock().await;

        let sessions = if is_scheduler {
            &mut inner.scheduler_sessions
        } else {
            &mut inner.sessions
        };
        if sessions.contains_key(channel_name) {
            return Ok(());
        }

        // `_rx` is dropped when this function returns, so the session's outbound sender
        // ends up with no receiver. NOTE(review): presumably pooled sessions do not rely
        // on this channel for delivery — confirm against `SessionFactory::create`.
        let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
        let session = self
            .session_factory
            .create(channel_name.to_string(), user_tx)
            .await?;
        sessions.insert(channel_name.to_string(), Arc::new(Mutex::new(session)));

        // Timestamps are keyed by channel only, so creating one kind of session also
        // refreshes the idle time of the channel's other session.
        inner
            .session_timestamps
            .insert(channel_name.to_string(), Instant::now());
        Ok(())
    }

    /// Returns the main session (used for user messages) for `channel_name`, if any.
    pub(crate) async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
        self.inner.lock().await.sessions.get(channel_name).cloned()
    }

    /// Returns the scheduler-only session for `channel_name`, if any.
    pub(crate) async fn get_scheduler_session(
        &self,
        channel_name: &str,
    ) -> Option<Arc<Mutex<Session>>> {
        self.inner
            .lock()
            .await
            .scheduler_sessions
            .get(channel_name)
            .cloned()
    }

    /// Returns the session that should handle `chat_id` on `channel_name`:
    /// - chat ids starting with `scheduler/` get the scheduler-only session;
    /// - every other chat id gets the main session.
    pub(crate) async fn get_for_chat_id(
        &self,
        channel_name: &str,
        chat_id: &str,
    ) -> Option<Arc<Mutex<Session>>> {
        if is_scheduler_chat_id(chat_id) {
            self.get_scheduler_session(channel_name).await
        } else {
            self.get(channel_name).await
        }
    }

    /// Ensures the session that should handle `chat_id` exists, using the same
    /// selection rule as `get_for_chat_id`.
    ///
    /// # Errors
    /// Propagates any `AgentError` returned by the session factory.
    pub(crate) async fn ensure_session_for_chat_id(
        &self,
        channel_name: &str,
        chat_id: &str,
    ) -> Result<(), AgentError> {
        if is_scheduler_chat_id(chat_id) {
            self.ensure_scheduler_session(channel_name).await
        } else {
            self.ensure_session(channel_name).await
        }
    }

    /// Marks `channel_name` as active now, postponing its expiry.
    ///
    /// Only channels that already have a session in the pool are updated. Touching an
    /// unknown channel is a no-op, so no orphan timestamp (one with no session behind
    /// it) is left behind. Such an entry would never be freed while expiry is disabled
    /// and would inflate the count returned by `cleanup_expired_sessions`.
    pub(crate) async fn touch(&self, channel_name: &str) {
        let mut inner = self.inner.lock().await;
        if let Some(last_active) = inner.session_timestamps.get_mut(channel_name) {
            *last_active = Instant::now();
        }
    }

    /// Removes every channel whose last activity is at least `session_ttl_hours` old.
    ///
    /// For each expired channel, the main session, the scheduler session and the
    /// timestamp are all removed from the pool. Callers still holding an `Arc` to a
    /// removed session keep it alive until they drop it.
    ///
    /// Returns the number of channels removed. Returns 0 without locking the pool when
    /// no positive TTL is configured.
    pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
        let ttl_hours = match self.session_ttl_hours {
            Some(hours) if hours > 0 => hours,
            _ => return 0,
        };
        // Saturate instead of overflowing. A huge configured TTL must mean "effectively
        // never expires", not a panic (debug) or a wrapped, tiny TTL (release).
        let ttl_duration = std::time::Duration::from_secs(ttl_hours.saturating_mul(3600));

        let mut inner = self.inner.lock().await;
        let now = Instant::now();
        let expired_channels: Vec<String> = inner
            .session_timestamps
            .iter()
            .filter_map(|(channel_name, last_active)| {
                let elapsed = now.duration_since(*last_active);
                if elapsed >= ttl_duration {
                    tracing::info!(
                        channel = %channel_name,
                        elapsed_hours = elapsed.as_secs() / 3600,
                        ttl_hours = ttl_hours,
                        "Session expired, removing from memory pool"
                    );
                    Some(channel_name.clone())
                } else {
                    None
                }
            })
            .collect();

        for channel_name in &expired_channels {
            // Both session kinds share the channel's timestamp, so both expire together.
            inner.sessions.remove(channel_name);
            inner.scheduler_sessions.remove(channel_name);
            inner.session_timestamps.remove(channel_name);
        }
        expired_channels.len()
    }
}