feat: 添加 CLI 会话服务和会话池,重构 SessionManager 以优化会话管理逻辑

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
ooodc 2026-04-28 11:55:55 +08:00
parent 14476bb101
commit 65bcf34b75
6 changed files with 249 additions and 156 deletions

View File

@ -30,14 +30,14 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot而是一个
1. Channel 接收外部消息
2. MessageBus 将消息送入统一的 inbound 队列
3. Gateway 启动的 inbound processor 调用 SessionManager 处理消息
4. SessionManager 加载持久化历史、注入系统提示、运行 AgentLoop、执行工具调用
5. 生成的 assistant / tool / system 消息写入 SQLite
3. Gateway 启动的 InboundProcessor 调用 SessionManager 定位目标 Session
4. AgentExecutionService 准备上下文、运行 AgentLoop、执行工具调用并收集结果
5. 生成的 user / assistant / tool / system 消息按真实顺序写入 SQLite
6. OutboundDispatcher 将结果投递到目标通道
主要模块如下:
- src/gateway网关入口、HTTP 健康检查、WebSocket 服务、Session 管理
- src/gateway网关入口、HTTP 健康检查、WebSocket 服务、Session 池、CLI 会话服务与 Agent 执行编排
- src/bus消息总线与消息结构定义
- src/agentAgentLoop 与上下文压缩器
- src/providers不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic

View File

@ -0,0 +1,57 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::storage::{SessionRecord, SessionStore};
#[derive(Clone)]
pub(crate) struct CliSessionService {
store: Arc<SessionStore>,
}
impl CliSessionService {
pub(crate) fn new(store: Arc<SessionStore>) -> Self {
Self { store }
}
pub(crate) fn create(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
self.store
.create_cli_session(title)
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
pub(crate) fn get(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
self.store
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
}
pub(crate) fn list(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> {
self.store
.list_sessions("cli", include_archived)
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
}
pub(crate) fn rename(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
self.store
.rename_session(session_id, title)
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
}
pub(crate) fn archive(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.archive_session(session_id)
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
}
pub(crate) fn delete(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.delete_session(session_id)
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
}
pub(crate) fn clear_messages(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.clear_messages(session_id)
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
}
}

View File

@ -1,9 +1,11 @@
pub mod cli_session;
pub mod execution;
pub mod http;
pub mod memory_maintenance;
pub mod processor;
pub mod prompt;
pub mod session;
pub mod session_pool;
pub mod ws;
use axum::{Router, routing};

View File

@ -14,10 +14,10 @@ use crate::tools::{
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use super::cli_session::CliSessionService;
use super::execution::{
AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest,
select_provider_config, should_display_message_to_user,
@ -32,6 +32,7 @@ use super::memory_maintenance::{
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
};
use super::prompt::load_agent_prompt;
use super::session_pool::SessionPool;
fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>();
@ -488,20 +489,14 @@ impl Session {
/// SessionManager 管理所有 Session按 channel_name 路由
#[derive(Clone)]
pub struct SessionManager {
inner: Arc<Mutex<SessionManagerInner>>,
provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
show_tool_results: bool,
}
struct SessionManagerInner {
sessions: HashMap<String, Arc<Mutex<Session>>>,
session_timestamps: HashMap<String, Instant>,
session_ttl: Duration,
session_pool: SessionPool,
cli_sessions: CliSessionService,
}
fn default_tools(
@ -677,24 +672,31 @@ impl SessionManager {
tracing::warn!(error = %err, "Failed to record skill discovery event");
}
Ok(Self {
inner: Arc::new(Mutex::new(SessionManagerInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
provider_config,
provider_configs: Arc::new(provider_configs),
tools: Arc::new(default_tools(
let tools = Arc::new(default_tools(
skills.clone(),
store.clone(),
known_agents,
default_timezone,
)),
));
let session_pool = SessionPool::new(
session_ttl_hours,
agent_prompt_reinject_every,
provider_config.clone(),
tools.clone(),
skills.clone(),
store.clone(),
);
let cli_sessions = CliSessionService::new(store.clone());
Ok(Self {
provider_config,
provider_configs: Arc::new(provider_configs),
tools,
skills,
store,
agent_prompt_reinject_every,
show_tool_results,
session_pool,
cli_sessions,
})
}
@ -714,6 +716,10 @@ impl SessionManager {
self.skills.clone()
}
pub(crate) fn cli_sessions(&self) -> CliSessionService {
self.cli_sessions.clone()
}
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) async fn summarize_memory_maintenance_for_scope(
&self,
@ -747,141 +753,23 @@ impl SessionManager {
select_provider_config(&self.provider_config, &self.provider_configs, agent_name)
}
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
self.store
.create_cli_session(title)
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
pub fn get_session_record(
&self,
session_id: &str,
) -> Result<Option<SessionRecord>, AgentError> {
self.store
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
}
pub fn list_cli_sessions(
&self,
include_archived: bool,
) -> Result<Vec<SessionRecord>, AgentError> {
self.store
.list_sessions("cli", include_archived)
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
}
pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
self.store
.rename_session(session_id, title)
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
}
pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.archive_session(session_id)
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
}
pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.delete_session(session_id)
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
}
pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.clear_messages(session_id)
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
}
pub fn load_session_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, AgentError> {
self.store
.load_messages(session_id)
.map_err(|err| AgentError::Other(format!("load messages error: {}", err)))
}
/// 确保 session 存在且未超时,超时则重建
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await;
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name)
{
let elapsed = last_active.elapsed();
if elapsed > inner.session_ttl {
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
true
} else {
false
}
} else {
#[cfg(debug_assertions)]
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
if should_recreate {
// 移除旧 session
inner.sessions.remove(channel_name);
// 创建新 session使用临时 user_tx因为 Feishu 不通过 WS
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new(
channel_name.to_string(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
self.skills.clone(),
self.store.clone(),
self.agent_prompt_reinject_every,
)
.await?;
let arc = Arc::new(Mutex::new(session));
inner.sessions.insert(channel_name.to_string(), arc.clone());
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
Ok(())
self.session_pool.ensure_session(channel_name).await
}
/// 获取 session不检查超时
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
let inner = self.inner.lock().await;
inner.sessions.get(channel_name).cloned()
self.session_pool.get(channel_name).await
}
/// 更新最后活跃时间
pub async fn touch(&self, channel_name: &str) {
let mut inner = self.inner.lock().await;
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
self.session_pool.touch(channel_name).await;
}
pub async fn cleanup_expired_sessions(&self) -> usize {
let mut inner = self.inner.lock().await;
let now = Instant::now();
let expired_channels: Vec<String> = inner
.session_timestamps
.iter()
.filter_map(|(channel_name, last_active)| {
if now.duration_since(*last_active) > inner.session_ttl {
Some(channel_name.clone())
} else {
None
}
})
.collect();
for channel_name in &expired_channels {
inner.sessions.remove(channel_name);
inner.session_timestamps.remove(channel_name);
}
expired_channels.len()
self.session_pool.cleanup_expired_sessions().await
}
/// 处理消息:路由到对应 session 的 agent

133
src/gateway/session_pool.rs Normal file
View File

@ -0,0 +1,133 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::ToolRegistry;
use super::session::Session;
#[derive(Clone)]
pub(crate) struct SessionPool {
inner: Arc<Mutex<SessionPoolInner>>,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
}
struct SessionPoolInner {
sessions: HashMap<String, Arc<Mutex<Session>>>,
session_timestamps: HashMap<String, Instant>,
session_ttl: Duration,
}
impl SessionPool {
pub(crate) fn new(
session_ttl_hours: u64,
agent_prompt_reinject_every: u64,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
) -> Self {
Self {
inner: Arc::new(Mutex::new(SessionPoolInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
provider_config,
tools,
skills,
store,
agent_prompt_reinject_every,
}
}
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await;
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name)
{
let elapsed = last_active.elapsed();
if elapsed > inner.session_ttl {
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
true
} else {
false
}
} else {
#[cfg(debug_assertions)]
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
if should_recreate {
inner.sessions.remove(channel_name);
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new(
channel_name.to_string(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
self.skills.clone(),
self.store.clone(),
self.agent_prompt_reinject_every,
)
.await?;
inner
.sessions
.insert(channel_name.to_string(), Arc::new(Mutex::new(session)));
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
Ok(())
}
pub(crate) async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
self.inner.lock().await.sessions.get(channel_name).cloned()
}
pub(crate) async fn touch(&self, channel_name: &str) {
self.inner
.lock()
.await
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
let mut inner = self.inner.lock().await;
let now = Instant::now();
let expired_channels: Vec<String> = inner
.session_timestamps
.iter()
.filter_map(|(channel_name, last_active)| {
if now.duration_since(*last_active) > inner.session_ttl {
Some(channel_name.clone())
} else {
None
}
})
.collect();
for channel_name in &expired_channels {
inner.sessions.remove(channel_name);
inner.session_timestamps.remove(channel_name);
}
expired_channels.len()
}
}

View File

@ -52,7 +52,8 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
}
};
let initial_record = match state.session_manager.create_cli_session(None) {
let cli_sessions = state.session_manager.cli_sessions();
let initial_record = match cli_sessions.create(None) {
Ok(record) => record,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session");
@ -346,7 +347,10 @@ async fn handle_inbound(
let target = session_id
.or(chat_id)
.unwrap_or_else(|| current_session_id.clone());
state.session_manager.clear_session_messages(&target)?;
state
.session_manager
.cli_sessions()
.clear_messages(&target)?;
let mut session_guard = session.lock().await;
session_guard.remove_history(&target);
@ -356,7 +360,10 @@ async fn handle_inbound(
Ok(())
}
WsInbound::CreateSession { title } => {
let record = state.session_manager.create_cli_session(title.as_deref())?;
let record = state
.session_manager
.cli_sessions()
.create(title.as_deref())?;
*current_session_id = record.id.clone();
let mut session_guard = session.lock().await;
@ -370,7 +377,10 @@ async fn handle_inbound(
Ok(())
}
WsInbound::ListSessions { include_archived } => {
let records = state.session_manager.list_cli_sessions(include_archived)?;
let records = state
.session_manager
.cli_sessions()
.list(include_archived)?;
let summaries = records.into_iter().map(to_session_summary).collect();
let session_guard = session.lock().await;
@ -383,7 +393,7 @@ async fn handle_inbound(
Ok(())
}
WsInbound::LoadSession { session_id } => {
let Some(record) = state.session_manager.get_session_record(&session_id)? else {
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::Error {
@ -408,7 +418,10 @@ async fn handle_inbound(
}
WsInbound::RenameSession { session_id, title } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.rename_session(&target, &title)?;
state
.session_manager
.cli_sessions()
.rename(&target, &title)?;
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionRenamed {
@ -420,7 +433,7 @@ async fn handle_inbound(
}
WsInbound::ArchiveSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.archive_session(&target)?;
state.session_manager.cli_sessions().archive(&target)?;
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionArchived { session_id: target })
@ -429,10 +442,10 @@ async fn handle_inbound(
}
WsInbound::DeleteSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.delete_session(&target)?;
state.session_manager.cli_sessions().delete(&target)?;
let replacement = if target == *current_session_id {
Some(state.session_manager.create_cli_session(None)?)
Some(state.session_manager.cli_sessions().create(None)?)
} else {
None
};