feat: 重构存储逻辑,使用 ConversationRepository 和 PromptInjectionRepository 替代 SessionStore,优化会话和提示注入管理
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
f48b132bb9
commit
891830779f
@ -2,20 +2,20 @@ use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
|
||||
use crate::storage::SessionStore;
|
||||
use crate::storage::PromptInjectionRepository;
|
||||
|
||||
use super::prompt::load_agent_prompt;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct PromptInjector {
|
||||
store: Arc<SessionStore>,
|
||||
repository: Arc<dyn PromptInjectionRepository>,
|
||||
reinject_every: i64,
|
||||
}
|
||||
|
||||
impl PromptInjector {
|
||||
pub(crate) fn new(store: Arc<SessionStore>, reinject_every: u64) -> Self {
|
||||
pub(crate) fn new(repository: Arc<dyn PromptInjectionRepository>, reinject_every: u64) -> Self {
|
||||
Self {
|
||||
store,
|
||||
repository,
|
||||
reinject_every: reinject_every as i64,
|
||||
}
|
||||
}
|
||||
@ -48,12 +48,12 @@ impl PromptInjector {
|
||||
F: FnMut(ChatMessage) -> Result<(), AgentError>,
|
||||
{
|
||||
let session_record = self
|
||||
.store
|
||||
.repository
|
||||
.get_session(session_id)
|
||||
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
let active_user_turns =
|
||||
self.store
|
||||
let active_user_turns = self
|
||||
.repository
|
||||
.count_active_user_messages(session_id)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("count active user messages error: {}", err))
|
||||
@ -66,7 +66,7 @@ impl PromptInjector {
|
||||
{
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
append_message(Self::agent_prompt_message(agent_prompt))?;
|
||||
self.store
|
||||
self.repository
|
||||
.mark_agent_prompt_reinjected(session_id)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
|
||||
|
||||
@ -6,7 +6,7 @@ use crate::config::LLMProviderConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::scheduler::ScheduledAgentTaskOptions;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SessionRecord, SessionStore};
|
||||
use crate::storage::{ConversationRepository, SessionRecord, SessionStore, SkillEventRepository};
|
||||
use crate::tools::ToolRegistry;
|
||||
use async_trait::async_trait;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@ -105,6 +105,8 @@ impl Session {
|
||||
agent_prompt_reinject_every: u64,
|
||||
) -> Result<Self, AgentError> {
|
||||
let agent_factory = AgentFactory::new(tools, skills.clone());
|
||||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||||
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
||||
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
|
||||
Self::with_factories(
|
||||
channel_name,
|
||||
@ -113,7 +115,8 @@ impl Session {
|
||||
skills,
|
||||
agent_factory,
|
||||
prompt_injector,
|
||||
store,
|
||||
conversations,
|
||||
skill_events,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@ -125,7 +128,8 @@ impl Session {
|
||||
skills: Arc<SkillRuntime>,
|
||||
agent_factory: AgentFactory,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
) -> Result<Self, AgentError> {
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
@ -135,7 +139,12 @@ impl Session {
|
||||
skills,
|
||||
agent_factory,
|
||||
compressor: ContextCompressor::from_provider_config(&provider_config),
|
||||
history: SessionHistory::new(channel_name, prompt_injector, store),
|
||||
history: SessionHistory::new(
|
||||
channel_name,
|
||||
prompt_injector,
|
||||
conversations,
|
||||
skill_events,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
@ -274,8 +283,8 @@ impl Session {
|
||||
self.history.reload_chat_history(chat_id)
|
||||
}
|
||||
|
||||
pub(crate) fn store(&self) -> Arc<SessionStore> {
|
||||
self.history.store()
|
||||
pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
|
||||
self.history.conversations()
|
||||
}
|
||||
|
||||
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
|
||||
@ -371,13 +380,16 @@ impl SessionManager {
|
||||
.build(),
|
||||
);
|
||||
let agent_factory = AgentFactory::new(tools.clone(), skills.clone());
|
||||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||||
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
||||
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
|
||||
let session_factory = SessionFactory::new(
|
||||
provider_config.clone(),
|
||||
skills.clone(),
|
||||
agent_factory,
|
||||
prompt_injector,
|
||||
store.clone(),
|
||||
conversations,
|
||||
skill_events,
|
||||
);
|
||||
let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory);
|
||||
let cli_sessions = CliSessionService::new(store.clone());
|
||||
|
||||
@ -6,7 +6,7 @@ use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::SessionStore;
|
||||
use crate::storage::{ConversationRepository, SkillEventRepository};
|
||||
|
||||
use super::agent_factory::AgentFactory;
|
||||
use super::prompt_injector::PromptInjector;
|
||||
@ -18,7 +18,8 @@ pub(crate) struct SessionFactory {
|
||||
skills: Arc<SkillRuntime>,
|
||||
agent_factory: AgentFactory,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
}
|
||||
|
||||
impl SessionFactory {
|
||||
@ -27,14 +28,16 @@ impl SessionFactory {
|
||||
skills: Arc<SkillRuntime>,
|
||||
agent_factory: AgentFactory,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider_config,
|
||||
skills,
|
||||
agent_factory,
|
||||
prompt_injector,
|
||||
store,
|
||||
conversations,
|
||||
skill_events,
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,7 +53,8 @@ impl SessionFactory {
|
||||
self.skills.clone(),
|
||||
self.agent_factory.clone(),
|
||||
self.prompt_injector.clone(),
|
||||
self.store.clone(),
|
||||
self.conversations.clone(),
|
||||
self.skill_events.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@ -3,7 +3,9 @@ use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||
use crate::storage::{
|
||||
ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id,
|
||||
};
|
||||
|
||||
use super::prompt_injector::PromptInjector;
|
||||
|
||||
@ -20,21 +22,24 @@ pub(crate) struct SessionHistory {
|
||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||
compression_in_flight: HashSet<String>,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
}
|
||||
|
||||
impl SessionHistory {
|
||||
pub(crate) fn new(
|
||||
channel_name: impl Into<String>,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel_name: channel_name.into(),
|
||||
chat_histories: HashMap::new(),
|
||||
compression_in_flight: HashSet::new(),
|
||||
prompt_injector,
|
||||
store,
|
||||
conversations,
|
||||
skill_events,
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,7 +51,7 @@ impl SessionHistory {
|
||||
&self,
|
||||
chat_id: &str,
|
||||
) -> Result<SessionRecord, AgentError> {
|
||||
self.store
|
||||
self.conversations
|
||||
.ensure_channel_session(&self.channel_name, chat_id)
|
||||
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
|
||||
}
|
||||
@ -57,7 +62,7 @@ impl SessionHistory {
|
||||
}
|
||||
|
||||
let history = self
|
||||
.store
|
||||
.conversations
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
@ -103,7 +108,7 @@ impl SessionHistory {
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
||||
}
|
||||
|
||||
self.store
|
||||
self.conversations
|
||||
.clear_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
|
||||
}
|
||||
@ -116,7 +121,7 @@ impl SessionHistory {
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
|
||||
}
|
||||
|
||||
self.store
|
||||
self.conversations
|
||||
.reset_session(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
|
||||
}
|
||||
@ -127,7 +132,7 @@ impl SessionHistory {
|
||||
message: ChatMessage,
|
||||
) -> Result<(), AgentError> {
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
self.store
|
||||
self.conversations
|
||||
.append_message(&session_id, &message)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("append message persistence error: {}", err))
|
||||
@ -196,7 +201,7 @@ impl SessionHistory {
|
||||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||||
|
||||
for chat_id in chat_ids {
|
||||
self.store
|
||||
self.conversations
|
||||
.clear_messages(&self.persistent_session_id(&chat_id))
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("clear history persistence error: {}", err))
|
||||
@ -216,15 +221,15 @@ impl SessionHistory {
|
||||
|
||||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
let history = self
|
||||
.store
|
||||
.conversations
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?;
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn store(&self) -> Arc<SessionStore> {
|
||||
self.store.clone()
|
||||
pub(crate) fn conversations(&self) -> Arc<dyn ConversationRepository> {
|
||||
self.conversations.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn append_skill_event(
|
||||
@ -234,7 +239,7 @@ impl SessionHistory {
|
||||
skill_name: Option<&str>,
|
||||
payload: &serde_json::Value,
|
||||
) -> Result<(), AgentError> {
|
||||
self.store
|
||||
self.skill_events
|
||||
.append_skill_event(
|
||||
Some(&self.persistent_session_id(chat_id)),
|
||||
event_type,
|
||||
|
||||
@ -11,7 +11,10 @@ pub mod ports;
|
||||
pub mod records;
|
||||
|
||||
pub use error::StorageError;
|
||||
pub use ports::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||||
pub use ports::{
|
||||
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
||||
SkillEventRepository,
|
||||
};
|
||||
pub use records::{
|
||||
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
|
||||
SchedulerJobUpsert, SessionRecord, SkillEventRecord,
|
||||
|
||||
@ -1,7 +1,42 @@
|
||||
use super::{
|
||||
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
|
||||
SchedulerJobUpsert, SkillEventRecord, StorageError,
|
||||
SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError,
|
||||
};
|
||||
use crate::bus::ChatMessage;
|
||||
|
||||
pub trait ConversationRepository: Send + Sync + 'static {
|
||||
fn ensure_channel_session(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
chat_id: &str,
|
||||
) -> Result<SessionRecord, StorageError>;
|
||||
|
||||
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
|
||||
|
||||
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
|
||||
|
||||
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError>;
|
||||
|
||||
fn reset_session(&self, session_id: &str) -> Result<(), StorageError>;
|
||||
|
||||
fn compact_active_history(
|
||||
&self,
|
||||
session_id: &str,
|
||||
expected_reset_cutoff_seq: i64,
|
||||
snapshot_end_seq: i64,
|
||||
preserved_system_messages: &[ChatMessage],
|
||||
summary_message: &ChatMessage,
|
||||
preserved_messages: &[ChatMessage],
|
||||
) -> Result<bool, StorageError>;
|
||||
}
|
||||
|
||||
pub trait PromptInjectionRepository: Send + Sync + 'static {
|
||||
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError>;
|
||||
|
||||
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError>;
|
||||
|
||||
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError>;
|
||||
}
|
||||
|
||||
pub trait MemoryRepository: Send + Sync + 'static {
|
||||
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError>;
|
||||
@ -86,6 +121,66 @@ pub trait SkillEventRepository: Send + Sync + 'static {
|
||||
) -> Result<Vec<SkillEventRecord>, StorageError>;
|
||||
}
|
||||
|
||||
impl ConversationRepository for super::SessionStore {
|
||||
fn ensure_channel_session(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
chat_id: &str,
|
||||
) -> Result<SessionRecord, StorageError> {
|
||||
super::SessionStore::ensure_channel_session(self, channel_name, chat_id)
|
||||
}
|
||||
|
||||
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
super::SessionStore::load_messages(self, session_id)
|
||||
}
|
||||
|
||||
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
||||
super::SessionStore::append_message(self, session_id, message)
|
||||
}
|
||||
|
||||
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
super::SessionStore::clear_messages(self, session_id)
|
||||
}
|
||||
|
||||
fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
super::SessionStore::reset_session(self, session_id)
|
||||
}
|
||||
|
||||
fn compact_active_history(
|
||||
&self,
|
||||
session_id: &str,
|
||||
expected_reset_cutoff_seq: i64,
|
||||
snapshot_end_seq: i64,
|
||||
preserved_system_messages: &[ChatMessage],
|
||||
summary_message: &ChatMessage,
|
||||
preserved_messages: &[ChatMessage],
|
||||
) -> Result<bool, StorageError> {
|
||||
super::SessionStore::compact_active_history(
|
||||
self,
|
||||
session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
preserved_system_messages,
|
||||
summary_message,
|
||||
preserved_messages,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptInjectionRepository for super::SessionStore {
|
||||
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
|
||||
super::SessionStore::get_session(self, session_id)
|
||||
}
|
||||
|
||||
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
|
||||
super::SessionStore::count_active_user_messages(self, session_id)
|
||||
}
|
||||
|
||||
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
super::SessionStore::mark_agent_prompt_reinjected(self, session_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryRepository for super::SessionStore {
|
||||
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
|
||||
super::SessionStore::put_memory(self, input)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user