From c65921b5e8d1f4022d6eb8c456daac0a3b6f4204 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Tue, 28 Apr 2026 13:06:00 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20AgentFactory=20?= =?UTF-8?q?=E5=92=8C=20PromptInjector=EF=BC=8C=E9=87=8D=E6=9E=84=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E6=B3=A8=E5=86=8C=E9=80=BB=E8=BE=91=E4=BB=A5=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E4=BC=9A=E8=AF=9D=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/gateway/agent_factory.rs | 57 +++++++ src/gateway/mod.rs | 3 + src/gateway/prompt_injector.rs | 86 ++++++++++ src/gateway/session.rs | 245 ++++++++++++--------------- src/gateway/session_factory.rs | 21 +-- src/gateway/tool_registry_factory.rs | 59 +++++++ 6 files changed, 324 insertions(+), 147 deletions(-) create mode 100644 src/gateway/agent_factory.rs create mode 100644 src/gateway/prompt_injector.rs create mode 100644 src/gateway/tool_registry_factory.rs diff --git a/src/gateway/agent_factory.rs b/src/gateway/agent_factory.rs new file mode 100644 index 0000000..c82ac41 --- /dev/null +++ b/src/gateway/agent_factory.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use crate::agent::{AgentError, AgentLoop}; +use crate::config::LLMProviderConfig; +use crate::skills::SkillRuntime; +use crate::storage::{SessionStore, persistent_session_id}; +use crate::tools::{ToolContext, ToolRegistry}; + +#[derive(Clone)] +pub(crate) struct AgentFactory { + tools: Arc, + skills: Arc, + store: Arc, +} + +pub(crate) struct AgentBuildRequest<'a> { + pub(crate) channel_name: &'a str, + pub(crate) chat_id: &'a str, + pub(crate) sender_id: Option<&'a str>, + pub(crate) message_id: Option<&'a str>, + pub(crate) provider_config: LLMProviderConfig, +} + +impl AgentFactory { + pub(crate) fn new( + tools: Arc, + skills: Arc, + store: Arc, + ) -> Self { + Self { + tools, + skills, + store, + } + } + + pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result { + let session_id = persistent_session_id(request.channel_name, request.chat_id); + AgentLoop::with_tools_and_skills( + request.provider_config, + self.tools.clone(), + self.skills.clone(), + ) + .map(|agent| { + agent + .with_skill_event_store(self.store.clone(), session_id.clone()) + .with_tool_context(ToolContext { + channel_name: Some(request.channel_name.to_string()), + sender_id: request.sender_id.map(str::to_string), + chat_id: Some(request.chat_id.to_string()), + session_id: Some(session_id), + message_id: request.message_id.map(str::to_string), + message_seq: None, + }) + }) + } +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e578c14..ff50c3c 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,3 +1,4 @@ +pub mod agent_factory; pub mod agent_task_executor; pub mod cli_session; pub mod command; @@ -8,9 +9,11 @@ pub mod memory_maintenance; pub mod message_prepare; pub mod processor; pub mod prompt; +pub mod prompt_injector; pub mod session; pub mod session_factory; pub mod session_pool; +pub mod tool_registry_factory; pub mod ws; pub mod ws_adapter; diff --git a/src/gateway/prompt_injector.rs b/src/gateway/prompt_injector.rs new file mode 100644 index 0000000..f4f6902 --- /dev/null +++ b/src/gateway/prompt_injector.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use crate::agent::AgentError; +use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT}; +use crate::storage::SessionStore; + +use super::prompt::load_agent_prompt; + +#[derive(Clone)] +pub(crate) struct PromptInjector { + store: Arc, + reinject_every: i64, +} + +impl PromptInjector { + pub(crate) fn new(store: Arc, reinject_every: u64) -> Self { + Self { + store, + reinject_every: reinject_every as i64, + } + } + + pub(crate) fn ensure_initial_prompt( + &self, + history_is_empty: bool, + mut append_message: F, + ) -> Result<(), AgentError> + where + F: FnMut(ChatMessage) -> Result<(), AgentError>, + { + if !history_is_empty { + return Ok(()); + } + + if let Some(agent_prompt) = load_agent_prompt()? { + append_message(Self::agent_prompt_message(agent_prompt))?; + } + + Ok(()) + } + + pub(crate) fn ensure_reinjected_prompt( + &self, + session_id: &str, + mut append_message: F, + ) -> Result<(), AgentError> + where + F: FnMut(ChatMessage) -> Result<(), AgentError>, + { + let session_record = self + .store + .get_session(session_id) + .map_err(|err| AgentError::Other(format!("get session error: {}", err)))? + .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; + let active_user_turns = + self.store + .count_active_user_messages(session_id) + .map_err(|err| { + AgentError::Other(format!("count active user messages error: {}", err)) + })?; + + if self.reinject_every > 0 + && active_user_turns > 0 + && active_user_turns / self.reinject_every + > session_record.agent_prompt_reinjection_count + { + if let Some(agent_prompt) = load_agent_prompt()? { + append_message(Self::agent_prompt_message(agent_prompt))?; + self.store + .mark_agent_prompt_reinjected(session_id) + .map_err(|err| { + AgentError::Other(format!("mark agent prompt reinjection error: {}", err)) + })?; + } + } + + Ok(()) + } + + fn agent_prompt_message(agent_prompt: String) -> ChatMessage { + ChatMessage::system_with_context( + agent_prompt, + Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()), + ) + } +} diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 2e9a54f..612119a 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1,22 +1,19 @@ use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler}; #[cfg(test)] use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT; -use crate::bus::{ChatMessage, MessageBus, OutboundMessage, SYSTEM_CONTEXT_AGENT_PROMPT}; +use crate::bus::{ChatMessage, MessageBus, OutboundMessage}; use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; use crate::skills::SkillRuntime; use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; -use crate::tools::{ - BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, - MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool, - TimeTool, ToolContext, ToolRegistry, WebFetchTool, -}; +use crate::tools::ToolRegistry; use async_trait::async_trait; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; +use super::agent_factory::{AgentBuildRequest, AgentFactory}; use super::cli_session::CliSessionService; use super::execution::{ AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest, @@ -31,9 +28,10 @@ use super::memory_maintenance::{ use super::memory_maintenance::{ MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService, }; -use super::prompt::load_agent_prompt; +use super::prompt_injector::PromptInjector; use super::session_factory::SessionFactory; use super::session_pool::SessionPool; +use super::tool_registry_factory::ToolRegistryFactory; fn preview_text(content: &str, max_chars: usize) -> String { let mut preview = content.chars().take(max_chars).collect::(); @@ -53,11 +51,11 @@ pub struct Session { compression_in_flight: HashSet, pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, - tools: Arc, skills: Arc, + agent_factory: AgentFactory, + prompt_injector: PromptInjector, compressor: ContextCompressor, store: Arc, - agent_prompt_reinject_every: i64, } pub struct BusToolCallEmitter { @@ -125,6 +123,29 @@ impl Session { skills: Arc, store: Arc, agent_prompt_reinject_every: u64, + ) -> Result { + let agent_factory = AgentFactory::new(tools, skills.clone(), store.clone()); + let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every); + Self::with_factories( + channel_name, + provider_config, + user_tx, + skills, + agent_factory, + prompt_injector, + store, + ) + .await + } + + pub(crate) async fn with_factories( + channel_name: String, + provider_config: LLMProviderConfig, + user_tx: mpsc::Sender, + skills: Arc, + agent_factory: AgentFactory, + prompt_injector: PromptInjector, + store: Arc, ) -> Result { Ok(Self { id: Uuid::new_v4(), @@ -133,11 +154,11 @@ impl Session { compression_in_flight: HashSet::new(), user_tx, provider_config: provider_config.clone(), - tools, skills, + agent_factory, + prompt_injector, compressor: ContextCompressor::from_provider_config(&provider_config), store, - agent_prompt_reinject_every: agent_prompt_reinject_every as i64, }) } @@ -172,40 +193,10 @@ impl Session { self.ensure_chat_loaded(chat_id)?; let session_id = self.persistent_session_id(chat_id); - let session_record = self - .store - .get_session(&session_id) - .map_err(|err| AgentError::Other(format!("get session error: {}", err)))? - .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; - let active_user_turns = - self.store - .count_active_user_messages(&session_id) - .map_err(|err| { - AgentError::Other(format!("count active user messages error: {}", err)) - })?; - - if self.agent_prompt_reinject_every > 0 - && active_user_turns > 0 - && active_user_turns / self.agent_prompt_reinject_every - > session_record.agent_prompt_reinjection_count - { - if let Some(agent_prompt) = load_agent_prompt()? { - self.append_persisted_message( - chat_id, - ChatMessage::system_with_context( - agent_prompt, - Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()), - ), - )?; - self.store - .mark_agent_prompt_reinjected(&session_id) - .map_err(|err| { - AgentError::Other(format!("mark agent prompt reinjection error: {}", err)) - })?; - } - } - - Ok(()) + let prompt_injector = self.prompt_injector.clone(); + prompt_injector.ensure_reinjected_prompt(&session_id, |message| { + self.append_persisted_message(chat_id, message) + }) } /// 获取或创建指定 chat_id 的会话历史 @@ -437,20 +428,13 @@ impl Session { message_id: Option<&str>, provider_config: LLMProviderConfig, ) -> Result { - let session_id = self.persistent_session_id(chat_id); - AgentLoop::with_tools_and_skills(provider_config, self.tools.clone(), self.skills.clone()) - .map(|agent| { - agent - .with_skill_event_store(self.store.clone(), session_id.clone()) - .with_tool_context(ToolContext { - channel_name: Some(self.channel_name.clone()), - sender_id: sender_id.map(str::to_string), - chat_id: Some(chat_id.to_string()), - session_id: Some(session_id), - message_id: message_id.map(str::to_string), - message_seq: None, - }) - }) + self.agent_factory.create(AgentBuildRequest { + channel_name: &self.channel_name, + chat_id, + sender_id, + message_id, + provider_config, + }) } fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> { @@ -463,17 +447,10 @@ impl Session { return Ok(()); } - if let Some(agent_prompt) = load_agent_prompt()? { - self.append_persisted_message( - chat_id, - ChatMessage::system_with_context( - agent_prompt, - Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()), - ), - )?; - } - - Ok(()) + let prompt_injector = self.prompt_injector.clone(); + prompt_injector.ensure_initial_prompt(history_is_empty, |message| { + self.append_persisted_message(chat_id, message) + }) } } @@ -490,34 +467,6 @@ pub struct SessionManager { cli_sessions: CliSessionService, } -fn default_tools( - skills: Arc, - store: Arc, - known_agents: HashSet, - default_timezone: String, -) -> ToolRegistry { - let mut registry = ToolRegistry::new(); - registry.register(CalculatorTool::new()); - registry.register(TimeTool::new(default_timezone)); - registry.register(FileReadTool::new()); - registry.register(FileWriteTool::new()); - registry.register(FileEditTool::new()); - registry.register(MemorySearchTool::new(store.clone())); - registry.register(MemoryManageTool::new(store.clone())); - registry.register(SchedulerManageTool::new(store, known_agents)); - registry.register(SkillListTool::new(skills.clone())); - registry.register(SkillManageTool::new(skills)); - registry.register(BashTool::new()); - registry.register(HttpRequestTool::new( - vec!["*".to_string()], // 允许所有域名,实际使用时建议限制 - 1_000_000, // max_response_size - 30, // timeout_secs - false, // allow_private_hosts - )); - registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs - registry -} - impl SessionManager { pub fn new( session_ttl_hours: u64, @@ -540,18 +489,23 @@ impl SessionManager { tracing::warn!(error = %err, "Failed to record skill discovery event"); } - let tools = Arc::new(default_tools( - skills.clone(), - store.clone(), - known_agents, - default_timezone, - )); + let tools = Arc::new( + ToolRegistryFactory::new( + skills.clone(), + store.clone(), + known_agents, + default_timezone, + ) + .build(), + ); + let agent_factory = AgentFactory::new(tools.clone(), skills.clone(), store.clone()); + let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every); let session_factory = SessionFactory::new( provider_config.clone(), - tools.clone(), skills.clone(), + agent_factory, + prompt_injector, store.clone(), - agent_prompt_reinject_every, ); let session_pool = SessionPool::new(session_ttl_hours, session_factory); let cli_sessions = CliSessionService::new(store.clone()); @@ -785,12 +739,15 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools( - skills.clone(), - store.clone(), - HashSet::new(), - "Asia/Shanghai".to_string(), - )); + let tools = Arc::new( + ToolRegistryFactory::new( + skills.clone(), + store.clone(), + HashSet::new(), + "Asia/Shanghai".to_string(), + ) + .build(), + ); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -824,12 +781,15 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools( - skills.clone(), - store.clone(), - HashSet::new(), - "Asia/Shanghai".to_string(), - )); + let tools = Arc::new( + ToolRegistryFactory::new( + skills.clone(), + store.clone(), + HashSet::new(), + "Asia/Shanghai".to_string(), + ) + .build(), + ); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -1603,12 +1563,15 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools( - skills.clone(), - store.clone(), - HashSet::new(), - "Asia/Shanghai".to_string(), - )); + let tools = Arc::new( + ToolRegistryFactory::new( + skills.clone(), + store.clone(), + HashSet::new(), + "Asia/Shanghai".to_string(), + ) + .build(), + ); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -1635,12 +1598,15 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools( - skills.clone(), - store.clone(), - HashSet::new(), - "Asia/Shanghai".to_string(), - )); + let tools = Arc::new( + ToolRegistryFactory::new( + skills.clone(), + store.clone(), + HashSet::new(), + "Asia/Shanghai".to_string(), + ) + .build(), + ); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -1695,12 +1661,15 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools( - skills.clone(), - store.clone(), - HashSet::new(), - "Asia/Shanghai".to_string(), - )); + let tools = Arc::new( + ToolRegistryFactory::new( + skills.clone(), + store.clone(), + HashSet::new(), + "Asia/Shanghai".to_string(), + ) + .build(), + ); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -1738,7 +1707,9 @@ mod tests { fn test_default_tools_registers_get_time() { let skills = Arc::new(SkillRuntime::default()); let store = Arc::new(SessionStore::in_memory().unwrap()); - let tools = default_tools(skills, store, HashSet::new(), "Asia/Shanghai".to_string()); + let tools = + ToolRegistryFactory::new(skills, store, HashSet::new(), "Asia/Shanghai".to_string()) + .build(); assert!(tools.tool_names().iter().any(|name| name == "get_time")); } diff --git a/src/gateway/session_factory.rs b/src/gateway/session_factory.rs index 40a53ca..5838475 100644 --- a/src/gateway/session_factory.rs +++ b/src/gateway/session_factory.rs @@ -7,33 +7,34 @@ use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; use crate::skills::SkillRuntime; use crate::storage::SessionStore; -use crate::tools::ToolRegistry; +use super::agent_factory::AgentFactory; +use super::prompt_injector::PromptInjector; use super::session::Session; #[derive(Clone)] pub(crate) struct SessionFactory { provider_config: LLMProviderConfig, - tools: Arc, skills: Arc, + agent_factory: AgentFactory, + prompt_injector: PromptInjector, store: Arc, - agent_prompt_reinject_every: u64, } impl SessionFactory { pub(crate) fn new( provider_config: LLMProviderConfig, - tools: Arc, skills: Arc, + agent_factory: AgentFactory, + prompt_injector: PromptInjector, store: Arc, - agent_prompt_reinject_every: u64, ) -> Self { Self { provider_config, - tools, skills, + agent_factory, + prompt_injector, store, - agent_prompt_reinject_every, } } @@ -42,14 +43,14 @@ impl SessionFactory { channel_name: impl Into, user_tx: mpsc::Sender, ) -> Result { - Session::new( + Session::with_factories( channel_name.into(), self.provider_config.clone(), user_tx, - self.tools.clone(), self.skills.clone(), + self.agent_factory.clone(), + self.prompt_injector.clone(), self.store.clone(), - self.agent_prompt_reinject_every, ) .await } diff --git a/src/gateway/tool_registry_factory.rs b/src/gateway/tool_registry_factory.rs new file mode 100644 index 0000000..fd585af --- /dev/null +++ b/src/gateway/tool_registry_factory.rs @@ -0,0 +1,59 @@ +use std::collections::HashSet; +use std::sync::Arc; + +use crate::skills::SkillRuntime; +use crate::storage::SessionStore; +use crate::tools::{ + BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, + MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool, + TimeTool, ToolRegistry, WebFetchTool, +}; + +pub(crate) struct ToolRegistryFactory { + skills: Arc, + store: Arc, + known_agents: HashSet, + default_timezone: String, +} + +impl ToolRegistryFactory { + pub(crate) fn new( + skills: Arc, + store: Arc, + known_agents: HashSet, + default_timezone: String, + ) -> Self { + Self { + skills, + store, + known_agents, + default_timezone, + } + } + + pub(crate) fn build(&self) -> ToolRegistry { + let mut registry = ToolRegistry::new(); + registry.register(CalculatorTool::new()); + registry.register(TimeTool::new(self.default_timezone.clone())); + registry.register(FileReadTool::new()); + registry.register(FileWriteTool::new()); + registry.register(FileEditTool::new()); + registry.register(MemorySearchTool::new(self.store.clone())); + registry.register(MemoryManageTool::new(self.store.clone())); + registry.register(SchedulerManageTool::new( + self.store.clone(), + self.known_agents.clone(), + )); + registry.register(SkillListTool::new(self.skills.clone())); + registry.register(SkillManageTool::new(self.skills.clone())); + registry.register(BashTool::new()); + registry.register(HttpRequestTool::new( + vec!["*".to_string()], + 1_000_000, + 30, + false, + )); + registry.register(WebFetchTool::new(50_000, 30)); + registry + } +}