feat: 重构存储逻辑,使用 ConversationRepository 和 PromptInjectionRepository 替代 SessionStore,优化会话和提示注入管理

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
ooodc 2026-04-28 15:55:27 +08:00
parent f48b132bb9
commit 891830779f
6 changed files with 159 additions and 40 deletions

View File

@ -2,20 +2,20 @@ use std::sync::Arc;
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT}; use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
use crate::storage::SessionStore; use crate::storage::PromptInjectionRepository;
use super::prompt::load_agent_prompt; use super::prompt::load_agent_prompt;
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct PromptInjector { pub(crate) struct PromptInjector {
store: Arc<SessionStore>, repository: Arc<dyn PromptInjectionRepository>,
reinject_every: i64, reinject_every: i64,
} }
impl PromptInjector { impl PromptInjector {
pub(crate) fn new(store: Arc<SessionStore>, reinject_every: u64) -> Self { pub(crate) fn new(repository: Arc<dyn PromptInjectionRepository>, reinject_every: u64) -> Self {
Self { Self {
store, repository,
reinject_every: reinject_every as i64, reinject_every: reinject_every as i64,
} }
} }
@ -48,16 +48,16 @@ impl PromptInjector {
F: FnMut(ChatMessage) -> Result<(), AgentError>, F: FnMut(ChatMessage) -> Result<(), AgentError>,
{ {
let session_record = self let session_record = self
.store .repository
.get_session(session_id) .get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))? .map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?; .ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns = let active_user_turns = self
self.store .repository
.count_active_user_messages(session_id) .count_active_user_messages(session_id)
.map_err(|err| { .map_err(|err| {
AgentError::Other(format!("count active user messages error: {}", err)) AgentError::Other(format!("count active user messages error: {}", err))
})?; })?;
if self.reinject_every > 0 if self.reinject_every > 0
&& active_user_turns > 0 && active_user_turns > 0
@ -66,7 +66,7 @@ impl PromptInjector {
{ {
if let Some(agent_prompt) = load_agent_prompt()? { if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?; append_message(Self::agent_prompt_message(agent_prompt))?;
self.store self.repository
.mark_agent_prompt_reinjected(session_id) .mark_agent_prompt_reinjected(session_id)
.map_err(|err| { .map_err(|err| {
AgentError::Other(format!("mark agent prompt reinjection error: {}", err)) AgentError::Other(format!("mark agent prompt reinjection error: {}", err))

View File

@ -6,7 +6,7 @@ use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use crate::scheduler::ScheduledAgentTaskOptions; use crate::scheduler::ScheduledAgentTaskOptions;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore}; use crate::storage::{ConversationRepository, SessionRecord, SessionStore, SkillEventRepository};
use crate::tools::ToolRegistry; use crate::tools::ToolRegistry;
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
@ -105,6 +105,8 @@ impl Session {
agent_prompt_reinject_every: u64, agent_prompt_reinject_every: u64,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let agent_factory = AgentFactory::new(tools, skills.clone()); let agent_factory = AgentFactory::new(tools, skills.clone());
let conversations: Arc<dyn ConversationRepository> = store.clone();
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every); let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
Self::with_factories( Self::with_factories(
channel_name, channel_name,
@ -113,7 +115,8 @@ impl Session {
skills, skills,
agent_factory, agent_factory,
prompt_injector, prompt_injector,
store, conversations,
skill_events,
) )
.await .await
} }
@ -125,7 +128,8 @@ impl Session {
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
agent_factory: AgentFactory, agent_factory: AgentFactory,
prompt_injector: PromptInjector, prompt_injector: PromptInjector,
store: Arc<SessionStore>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
Ok(Self { Ok(Self {
id: Uuid::new_v4(), id: Uuid::new_v4(),
@ -135,7 +139,12 @@ impl Session {
skills, skills,
agent_factory, agent_factory,
compressor: ContextCompressor::from_provider_config(&provider_config), compressor: ContextCompressor::from_provider_config(&provider_config),
history: SessionHistory::new(channel_name, prompt_injector, store), history: SessionHistory::new(
channel_name,
prompt_injector,
conversations,
skill_events,
),
}) })
} }
@ -274,8 +283,8 @@ impl Session {
self.history.reload_chat_history(chat_id) self.history.reload_chat_history(chat_id)
} }
pub(crate) fn store(&self) -> Arc<SessionStore> { pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
self.history.store() self.history.conversations()
} }
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> { pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
@ -371,13 +380,16 @@ impl SessionManager {
.build(), .build(),
); );
let agent_factory = AgentFactory::new(tools.clone(), skills.clone()); let agent_factory = AgentFactory::new(tools.clone(), skills.clone());
let conversations: Arc<dyn ConversationRepository> = store.clone();
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every); let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
let session_factory = SessionFactory::new( let session_factory = SessionFactory::new(
provider_config.clone(), provider_config.clone(),
skills.clone(), skills.clone(),
agent_factory, agent_factory,
prompt_injector, prompt_injector,
store.clone(), conversations,
skill_events,
); );
let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory); let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory);
let cli_sessions = CliSessionService::new(store.clone()); let cli_sessions = CliSessionService::new(store.clone());

View File

@ -6,7 +6,7 @@ use crate::agent::AgentError;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::SessionStore; use crate::storage::{ConversationRepository, SkillEventRepository};
use super::agent_factory::AgentFactory; use super::agent_factory::AgentFactory;
use super::prompt_injector::PromptInjector; use super::prompt_injector::PromptInjector;
@ -18,7 +18,8 @@ pub(crate) struct SessionFactory {
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
agent_factory: AgentFactory, agent_factory: AgentFactory,
prompt_injector: PromptInjector, prompt_injector: PromptInjector,
store: Arc<SessionStore>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
} }
impl SessionFactory { impl SessionFactory {
@ -27,14 +28,16 @@ impl SessionFactory {
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
agent_factory: AgentFactory, agent_factory: AgentFactory,
prompt_injector: PromptInjector, prompt_injector: PromptInjector,
store: Arc<SessionStore>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Self { ) -> Self {
Self { Self {
provider_config, provider_config,
skills, skills,
agent_factory, agent_factory,
prompt_injector, prompt_injector,
store, conversations,
skill_events,
} }
} }
@ -50,7 +53,8 @@ impl SessionFactory {
self.skills.clone(), self.skills.clone(),
self.agent_factory.clone(), self.agent_factory.clone(),
self.prompt_injector.clone(), self.prompt_injector.clone(),
self.store.clone(), self.conversations.clone(),
self.skill_events.clone(),
) )
.await .await
} }

View File

@ -3,7 +3,9 @@ use std::sync::Arc;
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::storage::{
ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id,
};
use super::prompt_injector::PromptInjector; use super::prompt_injector::PromptInjector;
@ -20,21 +22,24 @@ pub(crate) struct SessionHistory {
chat_histories: HashMap<String, Vec<ChatMessage>>, chat_histories: HashMap<String, Vec<ChatMessage>>,
compression_in_flight: HashSet<String>, compression_in_flight: HashSet<String>,
prompt_injector: PromptInjector, prompt_injector: PromptInjector,
store: Arc<SessionStore>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
} }
impl SessionHistory { impl SessionHistory {
pub(crate) fn new( pub(crate) fn new(
channel_name: impl Into<String>, channel_name: impl Into<String>,
prompt_injector: PromptInjector, prompt_injector: PromptInjector,
store: Arc<SessionStore>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Self { ) -> Self {
Self { Self {
channel_name: channel_name.into(), channel_name: channel_name.into(),
chat_histories: HashMap::new(), chat_histories: HashMap::new(),
compression_in_flight: HashSet::new(), compression_in_flight: HashSet::new(),
prompt_injector, prompt_injector,
store, conversations,
skill_events,
} }
} }
@ -46,7 +51,7 @@ impl SessionHistory {
&self, &self,
chat_id: &str, chat_id: &str,
) -> Result<SessionRecord, AgentError> { ) -> Result<SessionRecord, AgentError> {
self.store self.conversations
.ensure_channel_session(&self.channel_name, chat_id) .ensure_channel_session(&self.channel_name, chat_id)
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err))) .map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
} }
@ -57,7 +62,7 @@ impl SessionHistory {
} }
let history = self let history = self
.store .conversations
.load_messages(&self.persistent_session_id(chat_id)) .load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?; .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history); self.chat_histories.insert(chat_id.to_string(), history);
@ -103,7 +108,7 @@ impl SessionHistory {
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared"); tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
} }
self.store self.conversations
.clear_messages(&self.persistent_session_id(chat_id)) .clear_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err))) .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
} }
@ -116,7 +121,7 @@ impl SessionHistory {
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory"); tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
} }
self.store self.conversations
.reset_session(&self.persistent_session_id(chat_id)) .reset_session(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err))) .map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
} }
@ -127,7 +132,7 @@ impl SessionHistory {
message: ChatMessage, message: ChatMessage,
) -> Result<(), AgentError> { ) -> Result<(), AgentError> {
let session_id = self.persistent_session_id(chat_id); let session_id = self.persistent_session_id(chat_id);
self.store self.conversations
.append_message(&session_id, &message) .append_message(&session_id, &message)
.map_err(|err| { .map_err(|err| {
AgentError::Other(format!("append message persistence error: {}", err)) AgentError::Other(format!("append message persistence error: {}", err))
@ -196,7 +201,7 @@ impl SessionHistory {
tracing::debug!(previous_total = total, "All chat histories cleared"); tracing::debug!(previous_total = total, "All chat histories cleared");
for chat_id in chat_ids { for chat_id in chat_ids {
self.store self.conversations
.clear_messages(&self.persistent_session_id(&chat_id)) .clear_messages(&self.persistent_session_id(&chat_id))
.map_err(|err| { .map_err(|err| {
AgentError::Other(format!("clear history persistence error: {}", err)) AgentError::Other(format!("clear history persistence error: {}", err))
@ -216,15 +221,15 @@ impl SessionHistory {
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history = self let history = self
.store .conversations
.load_messages(&self.persistent_session_id(chat_id)) .load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?; .map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history); self.chat_histories.insert(chat_id.to_string(), history);
Ok(()) Ok(())
} }
pub(crate) fn store(&self) -> Arc<SessionStore> { pub(crate) fn conversations(&self) -> Arc<dyn ConversationRepository> {
self.store.clone() self.conversations.clone()
} }
pub(crate) fn append_skill_event( pub(crate) fn append_skill_event(
@ -234,7 +239,7 @@ impl SessionHistory {
skill_name: Option<&str>, skill_name: Option<&str>,
payload: &serde_json::Value, payload: &serde_json::Value,
) -> Result<(), AgentError> { ) -> Result<(), AgentError> {
self.store self.skill_events
.append_skill_event( .append_skill_event(
Some(&self.persistent_session_id(chat_id)), Some(&self.persistent_session_id(chat_id)),
event_type, event_type,

View File

@ -11,7 +11,10 @@ pub mod ports;
pub mod records; pub mod records;
pub use error::StorageError; pub use error::StorageError;
pub use ports::{MemoryRepository, SchedulerJobRepository, SkillEventRepository}; pub use ports::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
SkillEventRepository,
};
pub use records::{ pub use records::{
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert, SessionRecord, SkillEventRecord, SchedulerJobUpsert, SessionRecord, SkillEventRecord,

View File

@ -1,7 +1,42 @@
use super::{ use super::{
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert, SkillEventRecord, StorageError, SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError,
}; };
use crate::bus::ChatMessage;
pub trait ConversationRepository: Send + Sync + 'static {
fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError>;
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError>;
fn reset_session(&self, session_id: &str) -> Result<(), StorageError>;
fn compact_active_history(
&self,
session_id: &str,
expected_reset_cutoff_seq: i64,
snapshot_end_seq: i64,
preserved_system_messages: &[ChatMessage],
summary_message: &ChatMessage,
preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError>;
}
pub trait PromptInjectionRepository: Send + Sync + 'static {
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError>;
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError>;
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError>;
}
pub trait MemoryRepository: Send + Sync + 'static { pub trait MemoryRepository: Send + Sync + 'static {
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError>; fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError>;
@ -86,6 +121,66 @@ pub trait SkillEventRepository: Send + Sync + 'static {
) -> Result<Vec<SkillEventRecord>, StorageError>; ) -> Result<Vec<SkillEventRecord>, StorageError>;
} }
impl ConversationRepository for super::SessionStore {
fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError> {
super::SessionStore::ensure_channel_session(self, channel_name, chat_id)
}
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
super::SessionStore::load_messages(self, session_id)
}
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
super::SessionStore::append_message(self, session_id, message)
}
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::clear_messages(self, session_id)
}
fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::reset_session(self, session_id)
}
fn compact_active_history(
&self,
session_id: &str,
expected_reset_cutoff_seq: i64,
snapshot_end_seq: i64,
preserved_system_messages: &[ChatMessage],
summary_message: &ChatMessage,
preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError> {
super::SessionStore::compact_active_history(
self,
session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
preserved_system_messages,
summary_message,
preserved_messages,
)
}
}
impl PromptInjectionRepository for super::SessionStore {
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
super::SessionStore::get_session(self, session_id)
}
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
super::SessionStore::count_active_user_messages(self, session_id)
}
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::mark_agent_prompt_reinjected(self, session_id)
}
}
impl MemoryRepository for super::SessionStore { impl MemoryRepository for super::SessionStore {
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> { fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
super::SessionStore::put_memory(self, input) super::SessionStore::put_memory(self, input)