From 891830779f9ab6a9e85f394b8864f90cdfd2263a Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Tue, 28 Apr 2026 15:55:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E5=AD=98=E5=82=A8?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BD=BF=E7=94=A8=20ConversationRep?= =?UTF-8?q?ository=20=E5=92=8C=20PromptInjectionRepository=20=E6=9B=BF?= =?UTF-8?q?=E4=BB=A3=20SessionStore=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BC=9A?= =?UTF-8?q?=E8=AF=9D=E5=92=8C=E6=8F=90=E7=A4=BA=E6=B3=A8=E5=85=A5=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- src/gateway/prompt_injector.rs | 24 ++++----- src/gateway/session.rs | 26 ++++++--- src/gateway/session_factory.rs | 14 +++-- src/gateway/session_history.rs | 33 +++++++----- src/storage/mod.rs | 5 +- src/storage/ports.rs | 97 +++++++++++++++++++++++++++++++++- 6 files changed, 159 insertions(+), 40 deletions(-) diff --git a/src/gateway/prompt_injector.rs b/src/gateway/prompt_injector.rs index f4f6902..648271d 100644 --- a/src/gateway/prompt_injector.rs +++ b/src/gateway/prompt_injector.rs @@ -2,20 +2,20 @@ use std::sync::Arc; use crate::agent::AgentError; use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT}; -use crate::storage::SessionStore; +use crate::storage::PromptInjectionRepository; use super::prompt::load_agent_prompt; #[derive(Clone)] pub(crate) struct PromptInjector { - store: Arc, + repository: Arc, reinject_every: i64, } impl PromptInjector { - pub(crate) fn new(store: Arc, reinject_every: u64) -> Self { + pub(crate) fn new(repository: Arc, reinject_every: u64) -> Self { Self { - store, + repository, reinject_every: reinject_every as i64, } } @@ -48,16 +48,16 @@ impl PromptInjector { F: FnMut(ChatMessage) -> Result<(), AgentError>, { let session_record = self - .store + .repository .get_session(session_id) .map_err(|err| AgentError::Other(format!("get session error: {}", err)))? .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; - let active_user_turns = - self.store - .count_active_user_messages(session_id) - .map_err(|err| { - AgentError::Other(format!("count active user messages error: {}", err)) - })?; + let active_user_turns = self + .repository + .count_active_user_messages(session_id) + .map_err(|err| { + AgentError::Other(format!("count active user messages error: {}", err)) + })?; if self.reinject_every > 0 && active_user_turns > 0 @@ -66,7 +66,7 @@ impl PromptInjector { { if let Some(agent_prompt) = load_agent_prompt()? { append_message(Self::agent_prompt_message(agent_prompt))?; - self.store + self.repository .mark_agent_prompt_reinjected(session_id) .map_err(|err| { AgentError::Other(format!("mark agent prompt reinjection error: {}", err)) diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 2bfc4e3..6da2d00 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -6,7 +6,7 @@ use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; use crate::scheduler::ScheduledAgentTaskOptions; use crate::skills::SkillRuntime; -use crate::storage::{SessionRecord, SessionStore}; +use crate::storage::{ConversationRepository, SessionRecord, SessionStore, SkillEventRepository}; use crate::tools::ToolRegistry; use async_trait::async_trait; use std::collections::{HashMap, HashSet}; @@ -105,6 +105,8 @@ impl Session { agent_prompt_reinject_every: u64, ) -> Result { let agent_factory = AgentFactory::new(tools, skills.clone()); + let conversations: Arc = store.clone(); + let skill_events: Arc = store.clone(); let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every); Self::with_factories( channel_name, @@ -113,7 +115,8 @@ impl Session { skills, agent_factory, prompt_injector, - store, + conversations, + skill_events, ) .await } @@ -125,7 +128,8 @@ impl Session { skills: Arc, agent_factory: AgentFactory, prompt_injector: PromptInjector, - store: Arc, + conversations: Arc, + skill_events: Arc, ) -> Result { Ok(Self { id: Uuid::new_v4(), @@ -135,7 +139,12 @@ impl Session { skills, agent_factory, compressor: ContextCompressor::from_provider_config(&provider_config), - history: SessionHistory::new(channel_name, prompt_injector, store), + history: SessionHistory::new( + channel_name, + prompt_injector, + conversations, + skill_events, + ), }) } @@ -274,8 +283,8 @@ impl Session { self.history.reload_chat_history(chat_id) } - pub(crate) fn store(&self) -> Arc { - self.history.store() + pub(crate) fn store(&self) -> Arc { + self.history.conversations() } pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> { @@ -371,13 +380,16 @@ impl SessionManager { .build(), ); let agent_factory = AgentFactory::new(tools.clone(), skills.clone()); + let conversations: Arc = store.clone(); + let skill_events: Arc = store.clone(); let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every); let session_factory = SessionFactory::new( provider_config.clone(), skills.clone(), agent_factory, prompt_injector, - store.clone(), + conversations, + skill_events, ); let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory); let cli_sessions = CliSessionService::new(store.clone()); diff --git a/src/gateway/session_factory.rs b/src/gateway/session_factory.rs index 5838475..8484c4f 100644 --- a/src/gateway/session_factory.rs +++ b/src/gateway/session_factory.rs @@ -6,7 +6,7 @@ use crate::agent::AgentError; use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; use crate::skills::SkillRuntime; -use crate::storage::SessionStore; +use crate::storage::{ConversationRepository, SkillEventRepository}; use super::agent_factory::AgentFactory; use super::prompt_injector::PromptInjector; @@ -18,7 +18,8 @@ pub(crate) struct SessionFactory { skills: Arc, agent_factory: AgentFactory, prompt_injector: PromptInjector, - store: Arc, + conversations: Arc, + skill_events: Arc, } impl SessionFactory { @@ -27,14 +28,16 @@ impl SessionFactory { skills: Arc, agent_factory: AgentFactory, prompt_injector: PromptInjector, - store: Arc, + conversations: Arc, + skill_events: Arc, ) -> Self { Self { provider_config, skills, agent_factory, prompt_injector, - store, + conversations, + skill_events, } } @@ -50,7 +53,8 @@ impl SessionFactory { self.skills.clone(), self.agent_factory.clone(), self.prompt_injector.clone(), - self.store.clone(), + self.conversations.clone(), + self.skill_events.clone(), ) .await } diff --git a/src/gateway/session_history.rs b/src/gateway/session_history.rs index 8f38048..c209e75 100644 --- a/src/gateway/session_history.rs +++ b/src/gateway/session_history.rs @@ -3,7 +3,9 @@ use std::sync::Arc; use crate::agent::AgentError; use crate::bus::ChatMessage; -use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; +use crate::storage::{ + ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id, +}; use super::prompt_injector::PromptInjector; @@ -20,21 +22,24 @@ pub(crate) struct SessionHistory { chat_histories: HashMap>, compression_in_flight: HashSet, prompt_injector: PromptInjector, - store: Arc, + conversations: Arc, + skill_events: Arc, } impl SessionHistory { pub(crate) fn new( channel_name: impl Into, prompt_injector: PromptInjector, - store: Arc, + conversations: Arc, + skill_events: Arc, ) -> Self { Self { channel_name: channel_name.into(), chat_histories: HashMap::new(), compression_in_flight: HashSet::new(), prompt_injector, - store, + conversations, + skill_events, } } @@ -46,7 +51,7 @@ impl SessionHistory { &self, chat_id: &str, ) -> Result { - self.store + self.conversations .ensure_channel_session(&self.channel_name, chat_id) .map_err(|err| AgentError::Other(format!("session persistence error: {}", err))) } @@ -57,7 +62,7 @@ impl SessionHistory { } let history = self - .store + .conversations .load_messages(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?; self.chat_histories.insert(chat_id.to_string(), history); @@ -103,7 +108,7 @@ impl SessionHistory { tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared"); } - self.store + self.conversations .clear_messages(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err))) } @@ -116,7 +121,7 @@ impl SessionHistory { tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory"); } - self.store + self.conversations .reset_session(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err))) } @@ -127,7 +132,7 @@ impl SessionHistory { message: ChatMessage, ) -> Result<(), AgentError> { let session_id = self.persistent_session_id(chat_id); - self.store + self.conversations .append_message(&session_id, &message) .map_err(|err| { AgentError::Other(format!("append message persistence error: {}", err)) @@ -196,7 +201,7 @@ impl SessionHistory { tracing::debug!(previous_total = total, "All chat histories cleared"); for chat_id in chat_ids { - self.store + self.conversations .clear_messages(&self.persistent_session_id(&chat_id)) .map_err(|err| { AgentError::Other(format!("clear history persistence error: {}", err)) @@ -216,15 +221,15 @@ impl SessionHistory { pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { let history = self - .store + .conversations .load_messages(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?; self.chat_histories.insert(chat_id.to_string(), history); Ok(()) } - pub(crate) fn store(&self) -> Arc { - self.store.clone() + pub(crate) fn conversations(&self) -> Arc { + self.conversations.clone() } pub(crate) fn append_skill_event( @@ -234,7 +239,7 @@ impl SessionHistory { skill_name: Option<&str>, payload: &serde_json::Value, ) -> Result<(), AgentError> { - self.store + self.skill_events .append_skill_event( Some(&self.persistent_session_id(chat_id)), event_type, diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 130bdfb..a209c33 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -11,7 +11,10 @@ pub mod ports; pub mod records; pub use error::StorageError; -pub use ports::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; +pub use ports::{ + ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository, + SkillEventRepository, +}; pub use records::{ MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord, diff --git a/src/storage/ports.rs b/src/storage/ports.rs index 697e93d..a2edc72 100644 --- a/src/storage/ports.rs +++ b/src/storage/ports.rs @@ -1,7 +1,42 @@ use super::{ MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, - SchedulerJobUpsert, SkillEventRecord, StorageError, + SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError, }; +use crate::bus::ChatMessage; + +pub trait ConversationRepository: Send + Sync + 'static { + fn ensure_channel_session( + &self, + channel_name: &str, + chat_id: &str, + ) -> Result; + + fn load_messages(&self, session_id: &str) -> Result, StorageError>; + + fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>; + + fn clear_messages(&self, session_id: &str) -> Result<(), StorageError>; + + fn reset_session(&self, session_id: &str) -> Result<(), StorageError>; + + fn compact_active_history( + &self, + session_id: &str, + expected_reset_cutoff_seq: i64, + snapshot_end_seq: i64, + preserved_system_messages: &[ChatMessage], + summary_message: &ChatMessage, + preserved_messages: &[ChatMessage], + ) -> Result; +} + +pub trait PromptInjectionRepository: Send + Sync + 'static { + fn get_session(&self, session_id: &str) -> Result, StorageError>; + + fn count_active_user_messages(&self, session_id: &str) -> Result; + + fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError>; +} pub trait MemoryRepository: Send + Sync + 'static { fn put_memory(&self, input: &MemoryUpsert) -> Result; @@ -86,6 +121,66 @@ pub trait SkillEventRepository: Send + Sync + 'static { ) -> Result, StorageError>; } +impl ConversationRepository for super::SessionStore { + fn ensure_channel_session( + &self, + channel_name: &str, + chat_id: &str, + ) -> Result { + super::SessionStore::ensure_channel_session(self, channel_name, chat_id) + } + + fn load_messages(&self, session_id: &str) -> Result, StorageError> { + super::SessionStore::load_messages(self, session_id) + } + + fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> { + super::SessionStore::append_message(self, session_id, message) + } + + fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> { + super::SessionStore::clear_messages(self, session_id) + } + + fn reset_session(&self, session_id: &str) -> Result<(), StorageError> { + super::SessionStore::reset_session(self, session_id) + } + + fn compact_active_history( + &self, + session_id: &str, + expected_reset_cutoff_seq: i64, + snapshot_end_seq: i64, + preserved_system_messages: &[ChatMessage], + summary_message: &ChatMessage, + preserved_messages: &[ChatMessage], + ) -> Result { + super::SessionStore::compact_active_history( + self, + session_id, + expected_reset_cutoff_seq, + snapshot_end_seq, + preserved_system_messages, + summary_message, + preserved_messages, + ) + } +} + +impl PromptInjectionRepository for super::SessionStore { + fn get_session(&self, session_id: &str) -> Result, StorageError> { + super::SessionStore::get_session(self, session_id) + } + + fn count_active_user_messages(&self, session_id: &str) -> Result { + super::SessionStore::count_active_user_messages(self, session_id) + } + + fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> { + super::SessionStore::mark_agent_prompt_reinjected(self, session_id) + } +} + impl MemoryRepository for super::SessionStore { fn put_memory(&self, input: &MemoryUpsert) -> Result { super::SessionStore::put_memory(self, input)