feat: 添加 CLI 会话服务和会话池,重构 SessionManager 以优化会话管理逻辑

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
ooodc 2026-04-28 11:55:55 +08:00
parent 14476bb101
commit 65bcf34b75
6 changed files with 249 additions and 156 deletions

View File

@ -30,14 +30,14 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot而是一个
1. Channel 接收外部消息 1. Channel 接收外部消息
2. MessageBus 将消息送入统一的 inbound 队列 2. MessageBus 将消息送入统一的 inbound 队列
3. Gateway 启动的 inbound processor 调用 SessionManager 处理消息 3. Gateway 启动的 InboundProcessor 调用 SessionManager 定位目标 Session
4. SessionManager 加载持久化历史、注入系统提示、运行 AgentLoop、执行工具调用 4. AgentExecutionService 准备上下文、运行 AgentLoop、执行工具调用并收集结果
5. 生成的 assistant / tool / system 消息写入 SQLite 5. 生成的 user / assistant / tool / system 消息按真实顺序写入 SQLite
6. OutboundDispatcher 将结果投递到目标通道 6. OutboundDispatcher 将结果投递到目标通道
主要模块如下: 主要模块如下:
- src/gateway网关入口、HTTP 健康检查、WebSocket 服务、Session 管理 - src/gateway网关入口、HTTP 健康检查、WebSocket 服务、Session 池、CLI 会话服务与 Agent 执行编排
- src/bus消息总线与消息结构定义 - src/bus消息总线与消息结构定义
- src/agentAgentLoop 与上下文压缩器 - src/agentAgentLoop 与上下文压缩器
- src/providers不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic - src/providers不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic

View File

@ -0,0 +1,57 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::storage::{SessionRecord, SessionStore};
#[derive(Clone)]
pub(crate) struct CliSessionService {
store: Arc<SessionStore>,
}
impl CliSessionService {
pub(crate) fn new(store: Arc<SessionStore>) -> Self {
Self { store }
}
pub(crate) fn create(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
self.store
.create_cli_session(title)
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
pub(crate) fn get(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
self.store
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
}
pub(crate) fn list(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> {
self.store
.list_sessions("cli", include_archived)
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
}
pub(crate) fn rename(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
self.store
.rename_session(session_id, title)
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
}
pub(crate) fn archive(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.archive_session(session_id)
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
}
pub(crate) fn delete(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.delete_session(session_id)
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
}
pub(crate) fn clear_messages(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.clear_messages(session_id)
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
}
}

View File

@ -1,9 +1,11 @@
pub mod cli_session;
pub mod execution; pub mod execution;
pub mod http; pub mod http;
pub mod memory_maintenance; pub mod memory_maintenance;
pub mod processor; pub mod processor;
pub mod prompt; pub mod prompt;
pub mod session; pub mod session;
pub mod session_pool;
pub mod ws; pub mod ws;
use axum::{Router, routing}; use axum::{Router, routing};

View File

@ -14,10 +14,10 @@ use crate::tools::{
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc}; use tokio::sync::{Mutex, mpsc};
use uuid::Uuid; use uuid::Uuid;
use super::cli_session::CliSessionService;
use super::execution::{ use super::execution::{
AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest, AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest,
select_provider_config, should_display_message_to_user, select_provider_config, should_display_message_to_user,
@ -32,6 +32,7 @@ use super::memory_maintenance::{
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService, MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
}; };
use super::prompt::load_agent_prompt; use super::prompt::load_agent_prompt;
use super::session_pool::SessionPool;
fn preview_text(content: &str, max_chars: usize) -> String { fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>(); let mut preview = content.chars().take(max_chars).collect::<String>();
@ -488,20 +489,14 @@ impl Session {
/// SessionManager 管理所有 Session按 channel_name 路由 /// SessionManager 管理所有 Session按 channel_name 路由
#[derive(Clone)] #[derive(Clone)]
pub struct SessionManager { pub struct SessionManager {
inner: Arc<Mutex<SessionManagerInner>>,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>, provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
store: Arc<SessionStore>, store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
show_tool_results: bool, show_tool_results: bool,
} session_pool: SessionPool,
cli_sessions: CliSessionService,
struct SessionManagerInner {
sessions: HashMap<String, Arc<Mutex<Session>>>,
session_timestamps: HashMap<String, Instant>,
session_ttl: Duration,
} }
fn default_tools( fn default_tools(
@ -677,24 +672,31 @@ impl SessionManager {
tracing::warn!(error = %err, "Failed to record skill discovery event"); tracing::warn!(error = %err, "Failed to record skill discovery event");
} }
let tools = Arc::new(default_tools(
skills.clone(),
store.clone(),
known_agents,
default_timezone,
));
let session_pool = SessionPool::new(
session_ttl_hours,
agent_prompt_reinject_every,
provider_config.clone(),
tools.clone(),
skills.clone(),
store.clone(),
);
let cli_sessions = CliSessionService::new(store.clone());
Ok(Self { Ok(Self {
inner: Arc::new(Mutex::new(SessionManagerInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
provider_config, provider_config,
provider_configs: Arc::new(provider_configs), provider_configs: Arc::new(provider_configs),
tools: Arc::new(default_tools( tools,
skills.clone(),
store.clone(),
known_agents,
default_timezone,
)),
skills, skills,
store, store,
agent_prompt_reinject_every,
show_tool_results, show_tool_results,
session_pool,
cli_sessions,
}) })
} }
@ -714,6 +716,10 @@ impl SessionManager {
self.skills.clone() self.skills.clone()
} }
pub(crate) fn cli_sessions(&self) -> CliSessionService {
self.cli_sessions.clone()
}
#[cfg_attr(not(test), allow(dead_code))] #[cfg_attr(not(test), allow(dead_code))]
pub(crate) async fn summarize_memory_maintenance_for_scope( pub(crate) async fn summarize_memory_maintenance_for_scope(
&self, &self,
@ -747,141 +753,23 @@ impl SessionManager {
select_provider_config(&self.provider_config, &self.provider_configs, agent_name) select_provider_config(&self.provider_config, &self.provider_configs, agent_name)
} }
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
self.store
.create_cli_session(title)
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
pub fn get_session_record(
&self,
session_id: &str,
) -> Result<Option<SessionRecord>, AgentError> {
self.store
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
}
pub fn list_cli_sessions(
&self,
include_archived: bool,
) -> Result<Vec<SessionRecord>, AgentError> {
self.store
.list_sessions("cli", include_archived)
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
}
pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
self.store
.rename_session(session_id, title)
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
}
pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.archive_session(session_id)
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
}
pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.delete_session(session_id)
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
}
pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.clear_messages(session_id)
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
}
pub fn load_session_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, AgentError> {
self.store
.load_messages(session_id)
.map_err(|err| AgentError::Other(format!("load messages error: {}", err)))
}
/// 确保 session 存在且未超时,超时则重建 /// 确保 session 存在且未超时,超时则重建
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await; self.session_pool.ensure_session(channel_name).await
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name)
{
let elapsed = last_active.elapsed();
if elapsed > inner.session_ttl {
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
true
} else {
false
}
} else {
#[cfg(debug_assertions)]
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
if should_recreate {
// 移除旧 session
inner.sessions.remove(channel_name);
// 创建新 session使用临时 user_tx因为 Feishu 不通过 WS
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new(
channel_name.to_string(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
self.skills.clone(),
self.store.clone(),
self.agent_prompt_reinject_every,
)
.await?;
let arc = Arc::new(Mutex::new(session));
inner.sessions.insert(channel_name.to_string(), arc.clone());
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
Ok(())
} }
/// 获取 session不检查超时 /// 获取 session不检查超时
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> { pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
let inner = self.inner.lock().await; self.session_pool.get(channel_name).await
inner.sessions.get(channel_name).cloned()
} }
/// 更新最后活跃时间 /// 更新最后活跃时间
pub async fn touch(&self, channel_name: &str) { pub async fn touch(&self, channel_name: &str) {
let mut inner = self.inner.lock().await; self.session_pool.touch(channel_name).await;
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
} }
pub async fn cleanup_expired_sessions(&self) -> usize { pub async fn cleanup_expired_sessions(&self) -> usize {
let mut inner = self.inner.lock().await; self.session_pool.cleanup_expired_sessions().await
let now = Instant::now();
let expired_channels: Vec<String> = inner
.session_timestamps
.iter()
.filter_map(|(channel_name, last_active)| {
if now.duration_since(*last_active) > inner.session_ttl {
Some(channel_name.clone())
} else {
None
}
})
.collect();
for channel_name in &expired_channels {
inner.sessions.remove(channel_name);
inner.session_timestamps.remove(channel_name);
}
expired_channels.len()
} }
/// 处理消息:路由到对应 session 的 agent /// 处理消息:路由到对应 session 的 agent

133
src/gateway/session_pool.rs Normal file
View File

@ -0,0 +1,133 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::ToolRegistry;
use super::session::Session;
#[derive(Clone)]
pub(crate) struct SessionPool {
inner: Arc<Mutex<SessionPoolInner>>,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
}
struct SessionPoolInner {
sessions: HashMap<String, Arc<Mutex<Session>>>,
session_timestamps: HashMap<String, Instant>,
session_ttl: Duration,
}
impl SessionPool {
pub(crate) fn new(
session_ttl_hours: u64,
agent_prompt_reinject_every: u64,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
) -> Self {
Self {
inner: Arc::new(Mutex::new(SessionPoolInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
provider_config,
tools,
skills,
store,
agent_prompt_reinject_every,
}
}
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await;
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name)
{
let elapsed = last_active.elapsed();
if elapsed > inner.session_ttl {
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
true
} else {
false
}
} else {
#[cfg(debug_assertions)]
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
if should_recreate {
inner.sessions.remove(channel_name);
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new(
channel_name.to_string(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
self.skills.clone(),
self.store.clone(),
self.agent_prompt_reinject_every,
)
.await?;
inner
.sessions
.insert(channel_name.to_string(), Arc::new(Mutex::new(session)));
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
Ok(())
}
pub(crate) async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
self.inner.lock().await.sessions.get(channel_name).cloned()
}
pub(crate) async fn touch(&self, channel_name: &str) {
self.inner
.lock()
.await
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
let mut inner = self.inner.lock().await;
let now = Instant::now();
let expired_channels: Vec<String> = inner
.session_timestamps
.iter()
.filter_map(|(channel_name, last_active)| {
if now.duration_since(*last_active) > inner.session_ttl {
Some(channel_name.clone())
} else {
None
}
})
.collect();
for channel_name in &expired_channels {
inner.sessions.remove(channel_name);
inner.session_timestamps.remove(channel_name);
}
expired_channels.len()
}
}

View File

@ -52,7 +52,8 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
} }
}; };
let initial_record = match state.session_manager.create_cli_session(None) { let cli_sessions = state.session_manager.cli_sessions();
let initial_record = match cli_sessions.create(None) {
Ok(record) => record, Ok(record) => record,
Err(e) => { Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session"); tracing::error!(error = %e, "Failed to create initial CLI session");
@ -346,7 +347,10 @@ async fn handle_inbound(
let target = session_id let target = session_id
.or(chat_id) .or(chat_id)
.unwrap_or_else(|| current_session_id.clone()); .unwrap_or_else(|| current_session_id.clone());
state.session_manager.clear_session_messages(&target)?; state
.session_manager
.cli_sessions()
.clear_messages(&target)?;
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
session_guard.remove_history(&target); session_guard.remove_history(&target);
@ -356,7 +360,10 @@ async fn handle_inbound(
Ok(()) Ok(())
} }
WsInbound::CreateSession { title } => { WsInbound::CreateSession { title } => {
let record = state.session_manager.create_cli_session(title.as_deref())?; let record = state
.session_manager
.cli_sessions()
.create(title.as_deref())?;
*current_session_id = record.id.clone(); *current_session_id = record.id.clone();
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
@ -370,7 +377,10 @@ async fn handle_inbound(
Ok(()) Ok(())
} }
WsInbound::ListSessions { include_archived } => { WsInbound::ListSessions { include_archived } => {
let records = state.session_manager.list_cli_sessions(include_archived)?; let records = state
.session_manager
.cli_sessions()
.list(include_archived)?;
let summaries = records.into_iter().map(to_session_summary).collect(); let summaries = records.into_iter().map(to_session_summary).collect();
let session_guard = session.lock().await; let session_guard = session.lock().await;
@ -383,7 +393,7 @@ async fn handle_inbound(
Ok(()) Ok(())
} }
WsInbound::LoadSession { session_id } => { WsInbound::LoadSession { session_id } => {
let Some(record) = state.session_manager.get_session_record(&session_id)? else { let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
let session_guard = session.lock().await; let session_guard = session.lock().await;
let _ = session_guard let _ = session_guard
.send(WsOutbound::Error { .send(WsOutbound::Error {
@ -408,7 +418,10 @@ async fn handle_inbound(
} }
WsInbound::RenameSession { session_id, title } => { WsInbound::RenameSession { session_id, title } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone()); let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.rename_session(&target, &title)?; state
.session_manager
.cli_sessions()
.rename(&target, &title)?;
let session_guard = session.lock().await; let session_guard = session.lock().await;
let _ = session_guard let _ = session_guard
.send(WsOutbound::SessionRenamed { .send(WsOutbound::SessionRenamed {
@ -420,7 +433,7 @@ async fn handle_inbound(
} }
WsInbound::ArchiveSession { session_id } => { WsInbound::ArchiveSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone()); let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.archive_session(&target)?; state.session_manager.cli_sessions().archive(&target)?;
let session_guard = session.lock().await; let session_guard = session.lock().await;
let _ = session_guard let _ = session_guard
.send(WsOutbound::SessionArchived { session_id: target }) .send(WsOutbound::SessionArchived { session_id: target })
@ -429,10 +442,10 @@ async fn handle_inbound(
} }
WsInbound::DeleteSession { session_id } => { WsInbound::DeleteSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone()); let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.delete_session(&target)?; state.session_manager.cli_sessions().delete(&target)?;
let replacement = if target == *current_session_id { let replacement = if target == *current_session_id {
Some(state.session_manager.create_cli_session(None)?) Some(state.session_manager.cli_sessions().create(None)?)
} else { } else {
None None
}; };