From 65bcf34b750776342052a0bea7c445312e1dedee Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Tue, 28 Apr 2026 11:55:55 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20CLI=20=E4=BC=9A?= =?UTF-8?q?=E8=AF=9D=E6=9C=8D=E5=8A=A1=E5=92=8C=E4=BC=9A=E8=AF=9D=E6=B1=A0?= =?UTF-8?q?=EF=BC=8C=E9=87=8D=E6=9E=84=20SessionManager=20=E4=BB=A5?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BC=9A=E8=AF=9D=E7=AE=A1=E7=90=86=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- README.md | 8 +- src/gateway/cli_session.rs | 57 ++++++++++++ src/gateway/mod.rs | 2 + src/gateway/session.rs | 174 +++++++----------------------------- src/gateway/session_pool.rs | 133 +++++++++++++++++++++++++++ src/gateway/ws.rs | 31 +++++-- 6 files changed, 249 insertions(+), 156 deletions(-) create mode 100644 src/gateway/cli_session.rs create mode 100644 src/gateway/session_pool.rs diff --git a/README.md b/README.md index e699793..a2114e6 100644 --- a/README.md +++ b/README.md @@ -30,14 +30,14 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个 1. Channel 接收外部消息 2. MessageBus 将消息送入统一的 inbound 队列 -3. Gateway 启动的 inbound processor 调用 SessionManager 处理消息 -4. SessionManager 加载持久化历史、注入系统提示、运行 AgentLoop、执行工具调用 -5. 生成的 assistant / tool / system 消息写入 SQLite +3. Gateway 启动的 InboundProcessor 调用 SessionManager 定位目标 Session +4. AgentExecutionService 准备上下文、运行 AgentLoop、执行工具调用并收集结果 +5. 生成的 user / assistant / tool / system 消息按真实顺序写入 SQLite 6. OutboundDispatcher 将结果投递到目标通道 主要模块如下: -- src/gateway:网关入口、HTTP 健康检查、WebSocket 服务、Session 管理 +- src/gateway:网关入口、HTTP 健康检查、WebSocket 服务、Session 池、CLI 会话服务与 Agent 执行编排 - src/bus:消息总线与消息结构定义 - src/agent:AgentLoop 与上下文压缩器 - src/providers:不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic diff --git a/src/gateway/cli_session.rs b/src/gateway/cli_session.rs new file mode 100644 index 0000000..c07d8e1 --- /dev/null +++ b/src/gateway/cli_session.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use crate::agent::AgentError; +use crate::storage::{SessionRecord, SessionStore}; + +#[derive(Clone)] +pub(crate) struct CliSessionService { + store: Arc, +} + +impl CliSessionService { + pub(crate) fn new(store: Arc) -> Self { + Self { store } + } + + pub(crate) fn create(&self, title: Option<&str>) -> Result { + self.store + .create_cli_session(title) + .map_err(|err| AgentError::Other(format!("create session error: {}", err))) + } + + pub(crate) fn get(&self, session_id: &str) -> Result, AgentError> { + self.store + .get_session(session_id) + .map_err(|err| AgentError::Other(format!("get session error: {}", err))) + } + + pub(crate) fn list(&self, include_archived: bool) -> Result, AgentError> { + self.store + .list_sessions("cli", include_archived) + .map_err(|err| AgentError::Other(format!("list sessions error: {}", err))) + } + + pub(crate) fn rename(&self, session_id: &str, title: &str) -> Result<(), AgentError> { + self.store + .rename_session(session_id, title) + .map_err(|err| AgentError::Other(format!("rename session error: {}", err))) + } + + pub(crate) fn archive(&self, session_id: &str) -> Result<(), AgentError> { + self.store + .archive_session(session_id) + .map_err(|err| AgentError::Other(format!("archive session error: {}", err))) + } + + pub(crate) fn delete(&self, session_id: &str) -> Result<(), AgentError> { + self.store + .delete_session(session_id) + .map_err(|err| AgentError::Other(format!("delete session error: {}", err))) + } + + pub(crate) fn clear_messages(&self, session_id: &str) -> Result<(), AgentError> { + self.store + .clear_messages(session_id) + .map_err(|err| AgentError::Other(format!("clear session error: {}", err))) + } +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 28a29af..1f98e5f 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,9 +1,11 @@ +pub mod cli_session; pub mod execution; pub mod http; pub mod memory_maintenance; pub mod processor; pub mod prompt; pub mod session; +pub mod session_pool; pub mod ws; use axum::{Router, routing}; diff --git a/src/gateway/session.rs b/src/gateway/session.rs index dd2a8c4..4a86d91 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -14,10 +14,10 @@ use crate::tools::{ use async_trait::async_trait; use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use std::time::{Duration, Instant}; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; +use super::cli_session::CliSessionService; use super::execution::{ AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest, select_provider_config, should_display_message_to_user, @@ -32,6 +32,7 @@ use super::memory_maintenance::{ MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService, }; use super::prompt::load_agent_prompt; +use super::session_pool::SessionPool; fn preview_text(content: &str, max_chars: usize) -> String { let mut preview = content.chars().take(max_chars).collect::(); @@ -488,20 +489,14 @@ impl Session { /// SessionManager 管理所有 Session,按 channel_name 路由 #[derive(Clone)] pub struct SessionManager { - inner: Arc>, provider_config: LLMProviderConfig, provider_configs: Arc>, tools: Arc, skills: Arc, store: Arc, - agent_prompt_reinject_every: u64, show_tool_results: bool, -} - -struct SessionManagerInner { - sessions: HashMap>>, - session_timestamps: HashMap, - session_ttl: Duration, + session_pool: SessionPool, + cli_sessions: CliSessionService, } fn default_tools( @@ -677,24 +672,31 @@ impl SessionManager { tracing::warn!(error = %err, "Failed to record skill discovery event"); } + let tools = Arc::new(default_tools( + skills.clone(), + store.clone(), + known_agents, + default_timezone, + )); + let session_pool = SessionPool::new( + session_ttl_hours, + agent_prompt_reinject_every, + provider_config.clone(), + tools.clone(), + skills.clone(), + store.clone(), + ); + let cli_sessions = CliSessionService::new(store.clone()); + Ok(Self { - inner: Arc::new(Mutex::new(SessionManagerInner { - sessions: HashMap::new(), - session_timestamps: HashMap::new(), - session_ttl: Duration::from_secs(session_ttl_hours * 3600), - })), provider_config, provider_configs: Arc::new(provider_configs), - tools: Arc::new(default_tools( - skills.clone(), - store.clone(), - known_agents, - default_timezone, - )), + tools, skills, store, - agent_prompt_reinject_every, show_tool_results, + session_pool, + cli_sessions, }) } @@ -714,6 +716,10 @@ impl SessionManager { self.skills.clone() } + pub(crate) fn cli_sessions(&self) -> CliSessionService { + self.cli_sessions.clone() + } + #[cfg_attr(not(test), allow(dead_code))] pub(crate) async fn summarize_memory_maintenance_for_scope( &self, @@ -747,141 +753,23 @@ impl SessionManager { select_provider_config(&self.provider_config, &self.provider_configs, agent_name) } - pub fn create_cli_session(&self, title: Option<&str>) -> Result { - self.store - .create_cli_session(title) - .map_err(|err| AgentError::Other(format!("create session error: {}", err))) - } - - pub fn get_session_record( - &self, - session_id: &str, - ) -> Result, AgentError> { - self.store - .get_session(session_id) - .map_err(|err| AgentError::Other(format!("get session error: {}", err))) - } - - pub fn list_cli_sessions( - &self, - include_archived: bool, - ) -> Result, AgentError> { - self.store - .list_sessions("cli", include_archived) - .map_err(|err| AgentError::Other(format!("list sessions error: {}", err))) - } - - pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), AgentError> { - self.store - .rename_session(session_id, title) - .map_err(|err| AgentError::Other(format!("rename session error: {}", err))) - } - - pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> { - self.store - .archive_session(session_id) - .map_err(|err| AgentError::Other(format!("archive session error: {}", err))) - } - - pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> { - self.store - .delete_session(session_id) - .map_err(|err| AgentError::Other(format!("delete session error: {}", err))) - } - - pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> { - self.store - .clear_messages(session_id) - .map_err(|err| AgentError::Other(format!("clear session error: {}", err))) - } - - pub fn load_session_messages(&self, session_id: &str) -> Result, AgentError> { - self.store - .load_messages(session_id) - .map_err(|err| AgentError::Other(format!("load messages error: {}", err))) - } - /// 确保 session 存在且未超时,超时则重建 pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { - let mut inner = self.inner.lock().await; - - let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) - { - let elapsed = last_active.elapsed(); - if elapsed > inner.session_ttl { - tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating"); - true - } else { - false - } - } else { - #[cfg(debug_assertions)] - tracing::debug!(channel = %channel_name, "Creating new session"); - true - }; - - if should_recreate { - // 移除旧 session - inner.sessions.remove(channel_name); - - // 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS) - let (user_tx, _rx) = mpsc::channel::(100); - let session = Session::new( - channel_name.to_string(), - self.provider_config.clone(), - user_tx, - self.tools.clone(), - self.skills.clone(), - self.store.clone(), - self.agent_prompt_reinject_every, - ) - .await?; - let arc = Arc::new(Mutex::new(session)); - - inner.sessions.insert(channel_name.to_string(), arc.clone()); - inner - .session_timestamps - .insert(channel_name.to_string(), Instant::now()); - } - - Ok(()) + self.session_pool.ensure_session(channel_name).await } /// 获取 session(不检查超时) pub async fn get(&self, channel_name: &str) -> Option>> { - let inner = self.inner.lock().await; - inner.sessions.get(channel_name).cloned() + self.session_pool.get(channel_name).await } /// 更新最后活跃时间 pub async fn touch(&self, channel_name: &str) { - let mut inner = self.inner.lock().await; - inner - .session_timestamps - .insert(channel_name.to_string(), Instant::now()); + self.session_pool.touch(channel_name).await; } pub async fn cleanup_expired_sessions(&self) -> usize { - let mut inner = self.inner.lock().await; - let now = Instant::now(); - let expired_channels: Vec = inner - .session_timestamps - .iter() - .filter_map(|(channel_name, last_active)| { - if now.duration_since(*last_active) > inner.session_ttl { - Some(channel_name.clone()) - } else { - None - } - }) - .collect(); - - for channel_name in &expired_channels { - inner.sessions.remove(channel_name); - inner.session_timestamps.remove(channel_name); - } - - expired_channels.len() + self.session_pool.cleanup_expired_sessions().await } /// 处理消息:路由到对应 session 的 agent diff --git a/src/gateway/session_pool.rs b/src/gateway/session_pool.rs new file mode 100644 index 0000000..e0f0d1a --- /dev/null +++ b/src/gateway/session_pool.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use tokio::sync::{Mutex, mpsc}; + +use crate::agent::AgentError; +use crate::config::LLMProviderConfig; +use crate::protocol::WsOutbound; +use crate::skills::SkillRuntime; +use crate::storage::SessionStore; +use crate::tools::ToolRegistry; + +use super::session::Session; + +#[derive(Clone)] +pub(crate) struct SessionPool { + inner: Arc>, + provider_config: LLMProviderConfig, + tools: Arc, + skills: Arc, + store: Arc, + agent_prompt_reinject_every: u64, +} + +struct SessionPoolInner { + sessions: HashMap>>, + session_timestamps: HashMap, + session_ttl: Duration, +} + +impl SessionPool { + pub(crate) fn new( + session_ttl_hours: u64, + agent_prompt_reinject_every: u64, + provider_config: LLMProviderConfig, + tools: Arc, + skills: Arc, + store: Arc, + ) -> Self { + Self { + inner: Arc::new(Mutex::new(SessionPoolInner { + sessions: HashMap::new(), + session_timestamps: HashMap::new(), + session_ttl: Duration::from_secs(session_ttl_hours * 3600), + })), + provider_config, + tools, + skills, + store, + agent_prompt_reinject_every, + } + } + + pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { + let mut inner = self.inner.lock().await; + + let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) + { + let elapsed = last_active.elapsed(); + if elapsed > inner.session_ttl { + tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating"); + true + } else { + false + } + } else { + #[cfg(debug_assertions)] + tracing::debug!(channel = %channel_name, "Creating new session"); + true + }; + + if should_recreate { + inner.sessions.remove(channel_name); + + let (user_tx, _rx) = mpsc::channel::(100); + let session = Session::new( + channel_name.to_string(), + self.provider_config.clone(), + user_tx, + self.tools.clone(), + self.skills.clone(), + self.store.clone(), + self.agent_prompt_reinject_every, + ) + .await?; + + inner + .sessions + .insert(channel_name.to_string(), Arc::new(Mutex::new(session))); + inner + .session_timestamps + .insert(channel_name.to_string(), Instant::now()); + } + + Ok(()) + } + + pub(crate) async fn get(&self, channel_name: &str) -> Option>> { + self.inner.lock().await.sessions.get(channel_name).cloned() + } + + pub(crate) async fn touch(&self, channel_name: &str) { + self.inner + .lock() + .await + .session_timestamps + .insert(channel_name.to_string(), Instant::now()); + } + + pub(crate) async fn cleanup_expired_sessions(&self) -> usize { + let mut inner = self.inner.lock().await; + let now = Instant::now(); + let expired_channels: Vec = inner + .session_timestamps + .iter() + .filter_map(|(channel_name, last_active)| { + if now.duration_since(*last_active) > inner.session_ttl { + Some(channel_name.clone()) + } else { + None + } + }) + .collect(); + + for channel_name in &expired_channels { + inner.sessions.remove(channel_name); + inner.session_timestamps.remove(channel_name); + } + + expired_channels.len() + } +} diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 816e8f3..87ce66f 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -52,7 +52,8 @@ async fn handle_socket(ws: WebSocket, state: Arc) { } }; - let initial_record = match state.session_manager.create_cli_session(None) { + let cli_sessions = state.session_manager.cli_sessions(); + let initial_record = match cli_sessions.create(None) { Ok(record) => record, Err(e) => { tracing::error!(error = %e, "Failed to create initial CLI session"); @@ -346,7 +347,10 @@ async fn handle_inbound( let target = session_id .or(chat_id) .unwrap_or_else(|| current_session_id.clone()); - state.session_manager.clear_session_messages(&target)?; + state + .session_manager + .cli_sessions() + .clear_messages(&target)?; let mut session_guard = session.lock().await; session_guard.remove_history(&target); @@ -356,7 +360,10 @@ async fn handle_inbound( Ok(()) } WsInbound::CreateSession { title } => { - let record = state.session_manager.create_cli_session(title.as_deref())?; + let record = state + .session_manager + .cli_sessions() + .create(title.as_deref())?; *current_session_id = record.id.clone(); let mut session_guard = session.lock().await; @@ -370,7 +377,10 @@ async fn handle_inbound( Ok(()) } WsInbound::ListSessions { include_archived } => { - let records = state.session_manager.list_cli_sessions(include_archived)?; + let records = state + .session_manager + .cli_sessions() + .list(include_archived)?; let summaries = records.into_iter().map(to_session_summary).collect(); let session_guard = session.lock().await; @@ -383,7 +393,7 @@ async fn handle_inbound( Ok(()) } WsInbound::LoadSession { session_id } => { - let Some(record) = state.session_manager.get_session_record(&session_id)? else { + let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else { let session_guard = session.lock().await; let _ = session_guard .send(WsOutbound::Error { @@ -408,7 +418,10 @@ async fn handle_inbound( } WsInbound::RenameSession { session_id, title } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state.session_manager.rename_session(&target, &title)?; + state + .session_manager + .cli_sessions() + .rename(&target, &title)?; let session_guard = session.lock().await; let _ = session_guard .send(WsOutbound::SessionRenamed { @@ -420,7 +433,7 @@ async fn handle_inbound( } WsInbound::ArchiveSession { session_id } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state.session_manager.archive_session(&target)?; + state.session_manager.cli_sessions().archive(&target)?; let session_guard = session.lock().await; let _ = session_guard .send(WsOutbound::SessionArchived { session_id: target }) @@ -429,10 +442,10 @@ async fn handle_inbound( } WsInbound::DeleteSession { session_id } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); - state.session_manager.delete_session(&target)?; + state.session_manager.cli_sessions().delete(&target)?; let replacement = if target == *current_session_id { - Some(state.session_manager.create_cli_session(None)?) + Some(state.session_manager.cli_sessions().create(None)?) } else { None };