feat: 添加 AgentFactory 和 PromptInjector,重构工具注册逻辑以优化会话管理

This commit is contained in:
ooodc 2026-04-28 13:06:00 +08:00
parent 008aba91ac
commit c65921b5e8
6 changed files with 324 additions and 147 deletions

View File

@ -0,0 +1,57 @@
use std::sync::Arc;
use crate::agent::{AgentError, AgentLoop};
use crate::config::LLMProviderConfig;
use crate::skills::SkillRuntime;
use crate::storage::{SessionStore, persistent_session_id};
use crate::tools::{ToolContext, ToolRegistry};
#[derive(Clone)]
pub(crate) struct AgentFactory {
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
}
pub(crate) struct AgentBuildRequest<'a> {
pub(crate) channel_name: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) sender_id: Option<&'a str>,
pub(crate) message_id: Option<&'a str>,
pub(crate) provider_config: LLMProviderConfig,
}
impl AgentFactory {
pub(crate) fn new(
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
) -> Self {
Self {
tools,
skills,
store,
}
}
pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
let session_id = persistent_session_id(request.channel_name, request.chat_id);
AgentLoop::with_tools_and_skills(
request.provider_config,
self.tools.clone(),
self.skills.clone(),
)
.map(|agent| {
agent
.with_skill_event_store(self.store.clone(), session_id.clone())
.with_tool_context(ToolContext {
channel_name: Some(request.channel_name.to_string()),
sender_id: request.sender_id.map(str::to_string),
chat_id: Some(request.chat_id.to_string()),
session_id: Some(session_id),
message_id: request.message_id.map(str::to_string),
message_seq: None,
})
})
}
}

View File

@ -1,3 +1,4 @@
pub mod agent_factory;
pub mod agent_task_executor;
pub mod cli_session;
pub mod command;
@ -8,9 +9,11 @@ pub mod memory_maintenance;
pub mod message_prepare;
pub mod processor;
pub mod prompt;
pub mod prompt_injector;
pub mod session;
pub mod session_factory;
pub mod session_pool;
pub mod tool_registry_factory;
pub mod ws;
pub mod ws_adapter;

View File

@ -0,0 +1,86 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
use crate::storage::SessionStore;
use super::prompt::load_agent_prompt;
#[derive(Clone)]
pub(crate) struct PromptInjector {
store: Arc<SessionStore>,
reinject_every: i64,
}
impl PromptInjector {
pub(crate) fn new(store: Arc<SessionStore>, reinject_every: u64) -> Self {
Self {
store,
reinject_every: reinject_every as i64,
}
}
pub(crate) fn ensure_initial_prompt<F>(
&self,
history_is_empty: bool,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
if !history_is_empty {
return Ok(());
}
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
}
Ok(())
}
pub(crate) fn ensure_reinjected_prompt<F>(
&self,
session_id: &str,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
let session_record = self
.store
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns =
self.store
.count_active_user_messages(session_id)
.map_err(|err| {
AgentError::Other(format!("count active user messages error: {}", err))
})?;
if self.reinject_every > 0
&& active_user_turns > 0
&& active_user_turns / self.reinject_every
> session_record.agent_prompt_reinjection_count
{
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
self.store
.mark_agent_prompt_reinjected(session_id)
.map_err(|err| {
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
})?;
}
}
Ok(())
}
fn agent_prompt_message(agent_prompt: String) -> ChatMessage {
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
)
}
}

View File

@ -1,22 +1,19 @@
use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler};
#[cfg(test)]
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
use crate::bus::{ChatMessage, MessageBus, OutboundMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool,
TimeTool, ToolContext, ToolRegistry, WebFetchTool,
};
use crate::tools::ToolRegistry;
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use super::agent_factory::{AgentBuildRequest, AgentFactory};
use super::cli_session::CliSessionService;
use super::execution::{
AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest,
@ -31,9 +28,10 @@ use super::memory_maintenance::{
use super::memory_maintenance::{
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
};
use super::prompt::load_agent_prompt;
use super::prompt_injector::PromptInjector;
use super::session_factory::SessionFactory;
use super::session_pool::SessionPool;
use super::tool_registry_factory::ToolRegistryFactory;
fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>();
@ -53,11 +51,11 @@ pub struct Session {
compression_in_flight: HashSet<String>,
pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
compressor: ContextCompressor,
store: Arc<SessionStore>,
agent_prompt_reinject_every: i64,
}
pub struct BusToolCallEmitter {
@ -125,6 +123,29 @@ impl Session {
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
) -> Result<Self, AgentError> {
let agent_factory = AgentFactory::new(tools, skills.clone(), store.clone());
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
Self::with_factories(
channel_name,
provider_config,
user_tx,
skills,
agent_factory,
prompt_injector,
store,
)
.await
}
pub(crate) async fn with_factories(
channel_name: String,
provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
store: Arc<SessionStore>,
) -> Result<Self, AgentError> {
Ok(Self {
id: Uuid::new_v4(),
@ -133,11 +154,11 @@ impl Session {
compression_in_flight: HashSet::new(),
user_tx,
provider_config: provider_config.clone(),
tools,
skills,
agent_factory,
prompt_injector,
compressor: ContextCompressor::from_provider_config(&provider_config),
store,
agent_prompt_reinject_every: agent_prompt_reinject_every as i64,
})
}
@ -172,40 +193,10 @@ impl Session {
self.ensure_chat_loaded(chat_id)?;
let session_id = self.persistent_session_id(chat_id);
let session_record = self
.store
.get_session(&session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns =
self.store
.count_active_user_messages(&session_id)
.map_err(|err| {
AgentError::Other(format!("count active user messages error: {}", err))
})?;
if self.agent_prompt_reinject_every > 0
&& active_user_turns > 0
&& active_user_turns / self.agent_prompt_reinject_every
> session_record.agent_prompt_reinjection_count
{
if let Some(agent_prompt) = load_agent_prompt()? {
self.append_persisted_message(
chat_id,
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
),
)?;
self.store
.mark_agent_prompt_reinjected(&session_id)
.map_err(|err| {
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
})?;
}
}
Ok(())
let prompt_injector = self.prompt_injector.clone();
prompt_injector.ensure_reinjected_prompt(&session_id, |message| {
self.append_persisted_message(chat_id, message)
})
}
/// 获取或创建指定 chat_id 的会话历史
@ -437,20 +428,13 @@ impl Session {
message_id: Option<&str>,
provider_config: LLMProviderConfig,
) -> Result<AgentLoop, AgentError> {
let session_id = self.persistent_session_id(chat_id);
AgentLoop::with_tools_and_skills(provider_config, self.tools.clone(), self.skills.clone())
.map(|agent| {
agent
.with_skill_event_store(self.store.clone(), session_id.clone())
.with_tool_context(ToolContext {
channel_name: Some(self.channel_name.clone()),
sender_id: sender_id.map(str::to_string),
chat_id: Some(chat_id.to_string()),
session_id: Some(session_id),
message_id: message_id.map(str::to_string),
message_seq: None,
})
})
self.agent_factory.create(AgentBuildRequest {
channel_name: &self.channel_name,
chat_id,
sender_id,
message_id,
provider_config,
})
}
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
@ -463,17 +447,10 @@ impl Session {
return Ok(());
}
if let Some(agent_prompt) = load_agent_prompt()? {
self.append_persisted_message(
chat_id,
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
),
)?;
}
Ok(())
let prompt_injector = self.prompt_injector.clone();
prompt_injector.ensure_initial_prompt(history_is_empty, |message| {
self.append_persisted_message(chat_id, message)
})
}
}
@ -490,34 +467,6 @@ pub struct SessionManager {
cli_sessions: CliSessionService,
}
fn default_tools(
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
known_agents: HashSet<String>,
default_timezone: String,
) -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(TimeTool::new(default_timezone));
registry.register(FileReadTool::new());
registry.register(FileWriteTool::new());
registry.register(FileEditTool::new());
registry.register(MemorySearchTool::new(store.clone()));
registry.register(MemoryManageTool::new(store.clone()));
registry.register(SchedulerManageTool::new(store, known_agents));
registry.register(SkillListTool::new(skills.clone()));
registry.register(SkillManageTool::new(skills));
registry.register(BashTool::new());
registry.register(HttpRequestTool::new(
vec!["*".to_string()], // 允许所有域名,实际使用时建议限制
1_000_000, // max_response_size
30, // timeout_secs
false, // allow_private_hosts
));
registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs
registry
}
impl SessionManager {
pub fn new(
session_ttl_hours: u64,
@ -540,18 +489,23 @@ impl SessionManager {
tracing::warn!(error = %err, "Failed to record skill discovery event");
}
let tools = Arc::new(default_tools(
skills.clone(),
store.clone(),
known_agents,
default_timezone,
));
let tools = Arc::new(
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
known_agents,
default_timezone,
)
.build(),
);
let agent_factory = AgentFactory::new(tools.clone(), skills.clone(), store.clone());
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
let session_factory = SessionFactory::new(
provider_config.clone(),
tools.clone(),
skills.clone(),
agent_factory,
prompt_injector,
store.clone(),
agent_prompt_reinject_every,
);
let session_pool = SessionPool::new(session_ttl_hours, session_factory);
let cli_sessions = CliSessionService::new(store.clone());
@ -785,12 +739,15 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
));
let tools = Arc::new(
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
.build(),
);
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -824,12 +781,15 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
));
let tools = Arc::new(
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
.build(),
);
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -1603,12 +1563,15 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
));
let tools = Arc::new(
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
.build(),
);
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -1635,12 +1598,15 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
));
let tools = Arc::new(
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
.build(),
);
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -1695,12 +1661,15 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
));
let tools = Arc::new(
ToolRegistryFactory::new(
skills.clone(),
store.clone(),
HashSet::new(),
"Asia/Shanghai".to_string(),
)
.build(),
);
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -1738,7 +1707,9 @@ mod tests {
fn test_default_tools_registers_get_time() {
let skills = Arc::new(SkillRuntime::default());
let store = Arc::new(SessionStore::in_memory().unwrap());
let tools = default_tools(skills, store, HashSet::new(), "Asia/Shanghai".to_string());
let tools =
ToolRegistryFactory::new(skills, store, HashSet::new(), "Asia/Shanghai".to_string())
.build();
assert!(tools.tool_names().iter().any(|name| name == "get_time"));
}

View File

@ -7,33 +7,34 @@ use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::ToolRegistry;
use super::agent_factory::AgentFactory;
use super::prompt_injector::PromptInjector;
use super::session::Session;
#[derive(Clone)]
pub(crate) struct SessionFactory {
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
}
impl SessionFactory {
pub(crate) fn new(
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
) -> Self {
Self {
provider_config,
tools,
skills,
agent_factory,
prompt_injector,
store,
agent_prompt_reinject_every,
}
}
@ -42,14 +43,14 @@ impl SessionFactory {
channel_name: impl Into<String>,
user_tx: mpsc::Sender<WsOutbound>,
) -> Result<Session, AgentError> {
Session::new(
Session::with_factories(
channel_name.into(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
self.skills.clone(),
self.agent_factory.clone(),
self.prompt_injector.clone(),
self.store.clone(),
self.agent_prompt_reinject_every,
)
.await
}

View File

@ -0,0 +1,59 @@
use std::collections::HashSet;
use std::sync::Arc;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool,
TimeTool, ToolRegistry, WebFetchTool,
};
pub(crate) struct ToolRegistryFactory {
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
known_agents: HashSet<String>,
default_timezone: String,
}
impl ToolRegistryFactory {
pub(crate) fn new(
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
known_agents: HashSet<String>,
default_timezone: String,
) -> Self {
Self {
skills,
store,
known_agents,
default_timezone,
}
}
pub(crate) fn build(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(TimeTool::new(self.default_timezone.clone()));
registry.register(FileReadTool::new());
registry.register(FileWriteTool::new());
registry.register(FileEditTool::new());
registry.register(MemorySearchTool::new(self.store.clone()));
registry.register(MemoryManageTool::new(self.store.clone()));
registry.register(SchedulerManageTool::new(
self.store.clone(),
self.known_agents.clone(),
));
registry.register(SkillListTool::new(self.skills.clone()));
registry.register(SkillManageTool::new(self.skills.clone()));
registry.register(BashTool::new());
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
registry.register(WebFetchTool::new(50_000, 30));
registry
}
}