feat: 重构存储逻辑,使用 ConversationRepository 和 PromptInjectionRepository 替代 SessionStore,优化会话和提示注入管理

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
ooodc 2026-04-28 15:55:27 +08:00
parent f48b132bb9
commit 891830779f
6 changed files with 159 additions and 40 deletions

View File

@ -2,20 +2,20 @@ use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
use crate::storage::SessionStore;
use crate::storage::PromptInjectionRepository;
use super::prompt::load_agent_prompt;
#[derive(Clone)]
pub(crate) struct PromptInjector {
store: Arc<SessionStore>,
repository: Arc<dyn PromptInjectionRepository>,
reinject_every: i64,
}
impl PromptInjector {
pub(crate) fn new(store: Arc<SessionStore>, reinject_every: u64) -> Self {
pub(crate) fn new(repository: Arc<dyn PromptInjectionRepository>, reinject_every: u64) -> Self {
Self {
store,
repository,
reinject_every: reinject_every as i64,
}
}
@ -48,12 +48,12 @@ impl PromptInjector {
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
let session_record = self
.store
.repository
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns =
self.store
let active_user_turns = self
.repository
.count_active_user_messages(session_id)
.map_err(|err| {
AgentError::Other(format!("count active user messages error: {}", err))
@ -66,7 +66,7 @@ impl PromptInjector {
{
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
self.store
self.repository
.mark_agent_prompt_reinjected(session_id)
.map_err(|err| {
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))

View File

@ -6,7 +6,7 @@ use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::scheduler::ScheduledAgentTaskOptions;
use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore};
use crate::storage::{ConversationRepository, SessionRecord, SessionStore, SkillEventRepository};
use crate::tools::ToolRegistry;
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
@ -105,6 +105,8 @@ impl Session {
agent_prompt_reinject_every: u64,
) -> Result<Self, AgentError> {
let agent_factory = AgentFactory::new(tools, skills.clone());
let conversations: Arc<dyn ConversationRepository> = store.clone();
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
Self::with_factories(
channel_name,
@ -113,7 +115,8 @@ impl Session {
skills,
agent_factory,
prompt_injector,
store,
conversations,
skill_events,
)
.await
}
@ -125,7 +128,8 @@ impl Session {
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
store: Arc<SessionStore>,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Result<Self, AgentError> {
Ok(Self {
id: Uuid::new_v4(),
@ -135,7 +139,12 @@ impl Session {
skills,
agent_factory,
compressor: ContextCompressor::from_provider_config(&provider_config),
history: SessionHistory::new(channel_name, prompt_injector, store),
history: SessionHistory::new(
channel_name,
prompt_injector,
conversations,
skill_events,
),
})
}
@ -274,8 +283,8 @@ impl Session {
self.history.reload_chat_history(chat_id)
}
pub(crate) fn store(&self) -> Arc<SessionStore> {
self.history.store()
pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
self.history.conversations()
}
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
@ -371,13 +380,16 @@ impl SessionManager {
.build(),
);
let agent_factory = AgentFactory::new(tools.clone(), skills.clone());
let conversations: Arc<dyn ConversationRepository> = store.clone();
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
let session_factory = SessionFactory::new(
provider_config.clone(),
skills.clone(),
agent_factory,
prompt_injector,
store.clone(),
conversations,
skill_events,
);
let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory);
let cli_sessions = CliSessionService::new(store.clone());

View File

@ -6,7 +6,7 @@ use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::storage::{ConversationRepository, SkillEventRepository};
use super::agent_factory::AgentFactory;
use super::prompt_injector::PromptInjector;
@ -18,7 +18,8 @@ pub(crate) struct SessionFactory {
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
store: Arc<SessionStore>,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
}
impl SessionFactory {
@ -27,14 +28,16 @@ impl SessionFactory {
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
store: Arc<SessionStore>,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Self {
Self {
provider_config,
skills,
agent_factory,
prompt_injector,
store,
conversations,
skill_events,
}
}
@ -50,7 +53,8 @@ impl SessionFactory {
self.skills.clone(),
self.agent_factory.clone(),
self.prompt_injector.clone(),
self.store.clone(),
self.conversations.clone(),
self.skill_events.clone(),
)
.await
}

View File

@ -3,7 +3,9 @@ use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::ChatMessage;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::storage::{
ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id,
};
use super::prompt_injector::PromptInjector;
@ -20,21 +22,24 @@ pub(crate) struct SessionHistory {
chat_histories: HashMap<String, Vec<ChatMessage>>,
compression_in_flight: HashSet<String>,
prompt_injector: PromptInjector,
store: Arc<SessionStore>,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
}
impl SessionHistory {
pub(crate) fn new(
channel_name: impl Into<String>,
prompt_injector: PromptInjector,
store: Arc<SessionStore>,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Self {
Self {
channel_name: channel_name.into(),
chat_histories: HashMap::new(),
compression_in_flight: HashSet::new(),
prompt_injector,
store,
conversations,
skill_events,
}
}
@ -46,7 +51,7 @@ impl SessionHistory {
&self,
chat_id: &str,
) -> Result<SessionRecord, AgentError> {
self.store
self.conversations
.ensure_channel_session(&self.channel_name, chat_id)
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
}
@ -57,7 +62,7 @@ impl SessionHistory {
}
let history = self
.store
.conversations
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
@ -103,7 +108,7 @@ impl SessionHistory {
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
}
self.store
self.conversations
.clear_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
}
@ -116,7 +121,7 @@ impl SessionHistory {
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
}
self.store
self.conversations
.reset_session(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
}
@ -127,7 +132,7 @@ impl SessionHistory {
message: ChatMessage,
) -> Result<(), AgentError> {
let session_id = self.persistent_session_id(chat_id);
self.store
self.conversations
.append_message(&session_id, &message)
.map_err(|err| {
AgentError::Other(format!("append message persistence error: {}", err))
@ -196,7 +201,7 @@ impl SessionHistory {
tracing::debug!(previous_total = total, "All chat histories cleared");
for chat_id in chat_ids {
self.store
self.conversations
.clear_messages(&self.persistent_session_id(&chat_id))
.map_err(|err| {
AgentError::Other(format!("clear history persistence error: {}", err))
@ -216,15 +221,15 @@ impl SessionHistory {
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history = self
.store
.conversations
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
Ok(())
}
pub(crate) fn store(&self) -> Arc<SessionStore> {
self.store.clone()
pub(crate) fn conversations(&self) -> Arc<dyn ConversationRepository> {
self.conversations.clone()
}
pub(crate) fn append_skill_event(
@ -234,7 +239,7 @@ impl SessionHistory {
skill_name: Option<&str>,
payload: &serde_json::Value,
) -> Result<(), AgentError> {
self.store
self.skill_events
.append_skill_event(
Some(&self.persistent_session_id(chat_id)),
event_type,

View File

@ -11,7 +11,10 @@ pub mod ports;
pub mod records;
pub use error::StorageError;
pub use ports::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
pub use ports::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
SkillEventRepository,
};
pub use records::{
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert, SessionRecord, SkillEventRecord,

View File

@ -1,7 +1,42 @@
use super::{
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert, SkillEventRecord, StorageError,
SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError,
};
use crate::bus::ChatMessage;
pub trait ConversationRepository: Send + Sync + 'static {
fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError>;
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError>;
fn reset_session(&self, session_id: &str) -> Result<(), StorageError>;
fn compact_active_history(
&self,
session_id: &str,
expected_reset_cutoff_seq: i64,
snapshot_end_seq: i64,
preserved_system_messages: &[ChatMessage],
summary_message: &ChatMessage,
preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError>;
}
pub trait PromptInjectionRepository: Send + Sync + 'static {
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError>;
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError>;
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError>;
}
pub trait MemoryRepository: Send + Sync + 'static {
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError>;
@ -86,6 +121,66 @@ pub trait SkillEventRepository: Send + Sync + 'static {
) -> Result<Vec<SkillEventRecord>, StorageError>;
}
impl ConversationRepository for super::SessionStore {
fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError> {
super::SessionStore::ensure_channel_session(self, channel_name, chat_id)
}
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
super::SessionStore::load_messages(self, session_id)
}
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
super::SessionStore::append_message(self, session_id, message)
}
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::clear_messages(self, session_id)
}
fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::reset_session(self, session_id)
}
fn compact_active_history(
&self,
session_id: &str,
expected_reset_cutoff_seq: i64,
snapshot_end_seq: i64,
preserved_system_messages: &[ChatMessage],
summary_message: &ChatMessage,
preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError> {
super::SessionStore::compact_active_history(
self,
session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
preserved_system_messages,
summary_message,
preserved_messages,
)
}
}
impl PromptInjectionRepository for super::SessionStore {
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
super::SessionStore::get_session(self, session_id)
}
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
super::SessionStore::count_active_user_messages(self, session_id)
}
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::mark_agent_prompt_reinjected(self, session_id)
}
}
impl MemoryRepository for super::SessionStore {
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
super::SessionStore::put_memory(self, input)