feat: 移除冗余的 provider_config 字段,优化 AgentFactory 和 Session 的构造逻辑

This commit is contained in:
oudecheng 2026-05-13 15:41:52 +08:00
parent a06fceaf0c
commit 1c1efcabf4
7 changed files with 1 additions and 97 deletions

View File

@ -539,6 +539,7 @@ pub trait SkillProvider: Send + Sync + 'static {
} }
#[derive(Default)] #[derive(Default)]
#[allow(dead_code)]
struct EmptySkillProvider; struct EmptySkillProvider;
impl SkillProvider for EmptySkillProvider { impl SkillProvider for EmptySkillProvider {

View File

@ -12,7 +12,6 @@ use crate::tools::{ToolContext, ToolRegistry};
pub(crate) struct AgentFactory { pub(crate) struct AgentFactory {
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
provider_config: LLMProviderConfig,
reinject_every: usize, reinject_every: usize,
prompt_repository: Arc<dyn PromptInjectionRepository>, prompt_repository: Arc<dyn PromptInjectionRepository>,
} }
@ -30,14 +29,12 @@ impl AgentFactory {
pub(crate) fn new( pub(crate) fn new(
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
provider_config: LLMProviderConfig,
reinject_every: usize, reinject_every: usize,
prompt_repository: Arc<dyn PromptInjectionRepository>, prompt_repository: Arc<dyn PromptInjectionRepository>,
) -> Self { ) -> Self {
Self { Self {
tools, tools,
skills, skills,
provider_config,
reinject_every, reinject_every,
prompt_repository, prompt_repository,
} }

View File

@ -12,7 +12,6 @@ pub mod message_prepare;
pub mod outbound_dispatcher; pub mod outbound_dispatcher;
pub mod processor; pub mod processor;
pub mod prompt; pub mod prompt;
pub mod prompt_injector;
pub mod provider_config_service; pub mod provider_config_service;
pub mod runtime; pub mod runtime;
pub mod scheduled_agent_task_service; pub mod scheduled_agent_task_service;

View File

@ -1,86 +0,0 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
use crate::storage::PromptInjectionRepository;
use super::prompt::load_agent_prompt;
#[derive(Clone)]
pub(crate) struct PromptInjector {
repository: Arc<dyn PromptInjectionRepository>,
reinject_every: i64,
}
impl PromptInjector {
pub(crate) fn new(repository: Arc<dyn PromptInjectionRepository>, reinject_every: u64) -> Self {
Self {
repository,
reinject_every: reinject_every as i64,
}
}
pub(crate) fn ensure_initial_prompt<F>(
&self,
history_is_empty: bool,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
if !history_is_empty {
return Ok(());
}
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
}
Ok(())
}
pub(crate) fn ensure_reinjected_prompt<F>(
&self,
session_id: &str,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
let session_record = self
.repository
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns = self
.repository
.count_active_user_messages(session_id)
.map_err(|err| {
AgentError::Other(format!("count active user messages error: {}", err))
})?;
if self.reinject_every > 0
&& active_user_turns > 0
&& active_user_turns / self.reinject_every
> session_record.agent_prompt_reinjection_count
{
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
self.repository
.mark_agent_prompt_reinjected(session_id)
.map_err(|err| {
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
})?;
}
}
Ok(())
}
fn agent_prompt_message(agent_prompt: String) -> ChatMessage {
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
)
}
}

View File

@ -89,7 +89,6 @@ pub(crate) fn build_session_manager_with_sender(
let agent_factory = AgentFactory::new( let agent_factory = AgentFactory::new(
tools.clone(), tools.clone(),
skills.clone(), skills.clone(),
provider_config.clone(),
agent_prompt_reinject_every as usize, agent_prompt_reinject_every as usize,
prompt_repository.clone(), prompt_repository.clone(),
); );

View File

@ -107,7 +107,6 @@ impl Session {
let agent_factory = AgentFactory::new( let agent_factory = AgentFactory::new(
tools, tools,
skills.clone(), skills.clone(),
provider_config.clone(),
agent_prompt_reinject_every as usize, agent_prompt_reinject_every as usize,
prompt_repository.clone(), prompt_repository.clone(),
); );
@ -146,7 +145,6 @@ impl Session {
channel_name, channel_name,
conversations, conversations,
skill_events, skill_events,
provider_config,
chat_history_ttl_hours, chat_history_ttl_hours,
), ),
}) })

View File

@ -3,7 +3,6 @@ use std::sync::Arc;
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::storage::{ use crate::storage::{
ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id, ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id,
}; };
@ -29,7 +28,6 @@ pub(crate) struct SessionHistory {
compression_in_flight: HashSet<String>, compression_in_flight: HashSet<String>,
conversations: Arc<dyn ConversationRepository>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>, skill_events: Arc<dyn SkillEventRepository>,
provider_config: LLMProviderConfig,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
} }
@ -38,7 +36,6 @@ impl SessionHistory {
channel_name: impl Into<String>, channel_name: impl Into<String>,
conversations: Arc<dyn ConversationRepository>, conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>, skill_events: Arc<dyn SkillEventRepository>,
provider_config: LLMProviderConfig,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
) -> Self { ) -> Self {
Self { Self {
@ -47,7 +44,6 @@ impl SessionHistory {
compression_in_flight: HashSet::new(), compression_in_flight: HashSet::new(),
conversations, conversations,
skill_events, skill_events,
provider_config,
chat_history_ttl_hours, chat_history_ttl_hours,
} }
} }