feat: 添加 AgentFactory 和 PromptInjector,重构工具注册逻辑以优化会话管理
This commit is contained in:
parent
008aba91ac
commit
c65921b5e8
57
src/gateway/agent_factory.rs
Normal file
57
src/gateway/agent_factory.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::{AgentError, AgentLoop};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SessionStore, persistent_session_id};
|
||||
use crate::tools::{ToolContext, ToolRegistry};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct AgentFactory {
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
}
|
||||
|
||||
pub(crate) struct AgentBuildRequest<'a> {
|
||||
pub(crate) channel_name: &'a str,
|
||||
pub(crate) chat_id: &'a str,
|
||||
pub(crate) sender_id: Option<&'a str>,
|
||||
pub(crate) message_id: Option<&'a str>,
|
||||
pub(crate) provider_config: LLMProviderConfig,
|
||||
}
|
||||
|
||||
impl AgentFactory {
|
||||
pub(crate) fn new(
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tools,
|
||||
skills,
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
|
||||
let session_id = persistent_session_id(request.channel_name, request.chat_id);
|
||||
AgentLoop::with_tools_and_skills(
|
||||
request.provider_config,
|
||||
self.tools.clone(),
|
||||
self.skills.clone(),
|
||||
)
|
||||
.map(|agent| {
|
||||
agent
|
||||
.with_skill_event_store(self.store.clone(), session_id.clone())
|
||||
.with_tool_context(ToolContext {
|
||||
channel_name: Some(request.channel_name.to_string()),
|
||||
sender_id: request.sender_id.map(str::to_string),
|
||||
chat_id: Some(request.chat_id.to_string()),
|
||||
session_id: Some(session_id),
|
||||
message_id: request.message_id.map(str::to_string),
|
||||
message_seq: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,4 @@
|
||||
pub mod agent_factory;
|
||||
pub mod agent_task_executor;
|
||||
pub mod cli_session;
|
||||
pub mod command;
|
||||
@ -8,9 +9,11 @@ pub mod memory_maintenance;
|
||||
pub mod message_prepare;
|
||||
pub mod processor;
|
||||
pub mod prompt;
|
||||
pub mod prompt_injector;
|
||||
pub mod session;
|
||||
pub mod session_factory;
|
||||
pub mod session_pool;
|
||||
pub mod tool_registry_factory;
|
||||
pub mod ws;
|
||||
pub mod ws_adapter;
|
||||
|
||||
|
||||
86
src/gateway/prompt_injector.rs
Normal file
86
src/gateway/prompt_injector.rs
Normal file
@ -0,0 +1,86 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
|
||||
use crate::storage::SessionStore;
|
||||
|
||||
use super::prompt::load_agent_prompt;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct PromptInjector {
|
||||
store: Arc<SessionStore>,
|
||||
reinject_every: i64,
|
||||
}
|
||||
|
||||
impl PromptInjector {
|
||||
pub(crate) fn new(store: Arc<SessionStore>, reinject_every: u64) -> Self {
|
||||
Self {
|
||||
store,
|
||||
reinject_every: reinject_every as i64,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn ensure_initial_prompt<F>(
|
||||
&self,
|
||||
history_is_empty: bool,
|
||||
mut append_message: F,
|
||||
) -> Result<(), AgentError>
|
||||
where
|
||||
F: FnMut(ChatMessage) -> Result<(), AgentError>,
|
||||
{
|
||||
if !history_is_empty {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
append_message(Self::agent_prompt_message(agent_prompt))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn ensure_reinjected_prompt<F>(
|
||||
&self,
|
||||
session_id: &str,
|
||||
mut append_message: F,
|
||||
) -> Result<(), AgentError>
|
||||
where
|
||||
F: FnMut(ChatMessage) -> Result<(), AgentError>,
|
||||
{
|
||||
let session_record = self
|
||||
.store
|
||||
.get_session(session_id)
|
||||
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
let active_user_turns =
|
||||
self.store
|
||||
.count_active_user_messages(session_id)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("count active user messages error: {}", err))
|
||||
})?;
|
||||
|
||||
if self.reinject_every > 0
|
||||
&& active_user_turns > 0
|
||||
&& active_user_turns / self.reinject_every
|
||||
> session_record.agent_prompt_reinjection_count
|
||||
{
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
append_message(Self::agent_prompt_message(agent_prompt))?;
|
||||
self.store
|
||||
.mark_agent_prompt_reinjected(session_id)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn agent_prompt_message(agent_prompt: String) -> ChatMessage {
|
||||
ChatMessage::system_with_context(
|
||||
agent_prompt,
|
||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -1,22 +1,19 @@
|
||||
use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler};
|
||||
#[cfg(test)]
|
||||
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
|
||||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
|
||||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
|
||||
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool,
|
||||
TimeTool, ToolContext, ToolRegistry, WebFetchTool,
|
||||
};
|
||||
use crate::tools::ToolRegistry;
|
||||
use async_trait::async_trait;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::agent_factory::{AgentBuildRequest, AgentFactory};
|
||||
use super::cli_session::CliSessionService;
|
||||
use super::execution::{
|
||||
AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest,
|
||||
@ -31,9 +28,10 @@ use super::memory_maintenance::{
|
||||
use super::memory_maintenance::{
|
||||
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
|
||||
};
|
||||
use super::prompt::load_agent_prompt;
|
||||
use super::prompt_injector::PromptInjector;
|
||||
use super::session_factory::SessionFactory;
|
||||
use super::session_pool::SessionPool;
|
||||
use super::tool_registry_factory::ToolRegistryFactory;
|
||||
|
||||
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
let mut preview = content.chars().take(max_chars).collect::<String>();
|
||||
@ -53,11 +51,11 @@ pub struct Session {
|
||||
compression_in_flight: HashSet<String>,
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
agent_factory: AgentFactory,
|
||||
prompt_injector: PromptInjector,
|
||||
compressor: ContextCompressor,
|
||||
store: Arc<SessionStore>,
|
||||
agent_prompt_reinject_every: i64,
|
||||
}
|
||||
|
||||
pub struct BusToolCallEmitter {
|
||||
@ -125,6 +123,29 @@ impl Session {
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
agent_prompt_reinject_every: u64,
|
||||
) -> Result<Self, AgentError> {
|
||||
let agent_factory = AgentFactory::new(tools, skills.clone(), store.clone());
|
||||
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
|
||||
Self::with_factories(
|
||||
channel_name,
|
||||
provider_config,
|
||||
user_tx,
|
||||
skills,
|
||||
agent_factory,
|
||||
prompt_injector,
|
||||
store,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn with_factories(
|
||||
channel_name: String,
|
||||
provider_config: LLMProviderConfig,
|
||||
user_tx: mpsc::Sender<WsOutbound>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
agent_factory: AgentFactory,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
) -> Result<Self, AgentError> {
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
@ -133,11 +154,11 @@ impl Session {
|
||||
compression_in_flight: HashSet::new(),
|
||||
user_tx,
|
||||
provider_config: provider_config.clone(),
|
||||
tools,
|
||||
skills,
|
||||
agent_factory,
|
||||
prompt_injector,
|
||||
compressor: ContextCompressor::from_provider_config(&provider_config),
|
||||
store,
|
||||
agent_prompt_reinject_every: agent_prompt_reinject_every as i64,
|
||||
})
|
||||
}
|
||||
|
||||
@ -172,40 +193,10 @@ impl Session {
|
||||
self.ensure_chat_loaded(chat_id)?;
|
||||
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
let session_record = self
|
||||
.store
|
||||
.get_session(&session_id)
|
||||
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
let active_user_turns =
|
||||
self.store
|
||||
.count_active_user_messages(&session_id)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("count active user messages error: {}", err))
|
||||
})?;
|
||||
|
||||
if self.agent_prompt_reinject_every > 0
|
||||
&& active_user_turns > 0
|
||||
&& active_user_turns / self.agent_prompt_reinject_every
|
||||
> session_record.agent_prompt_reinjection_count
|
||||
{
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
self.append_persisted_message(
|
||||
chat_id,
|
||||
ChatMessage::system_with_context(
|
||||
agent_prompt,
|
||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||
),
|
||||
)?;
|
||||
self.store
|
||||
.mark_agent_prompt_reinjected(&session_id)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
let prompt_injector = self.prompt_injector.clone();
|
||||
prompt_injector.ensure_reinjected_prompt(&session_id, |message| {
|
||||
self.append_persisted_message(chat_id, message)
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取或创建指定 chat_id 的会话历史
|
||||
@ -437,20 +428,13 @@ impl Session {
|
||||
message_id: Option<&str>,
|
||||
provider_config: LLMProviderConfig,
|
||||
) -> Result<AgentLoop, AgentError> {
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
AgentLoop::with_tools_and_skills(provider_config, self.tools.clone(), self.skills.clone())
|
||||
.map(|agent| {
|
||||
agent
|
||||
.with_skill_event_store(self.store.clone(), session_id.clone())
|
||||
.with_tool_context(ToolContext {
|
||||
channel_name: Some(self.channel_name.clone()),
|
||||
sender_id: sender_id.map(str::to_string),
|
||||
chat_id: Some(chat_id.to_string()),
|
||||
session_id: Some(session_id),
|
||||
message_id: message_id.map(str::to_string),
|
||||
message_seq: None,
|
||||
})
|
||||
})
|
||||
self.agent_factory.create(AgentBuildRequest {
|
||||
channel_name: &self.channel_name,
|
||||
chat_id,
|
||||
sender_id,
|
||||
message_id,
|
||||
provider_config,
|
||||
})
|
||||
}
|
||||
|
||||
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
@ -463,17 +447,10 @@ impl Session {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
self.append_persisted_message(
|
||||
chat_id,
|
||||
ChatMessage::system_with_context(
|
||||
agent_prompt,
|
||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||
),
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
let prompt_injector = self.prompt_injector.clone();
|
||||
prompt_injector.ensure_initial_prompt(history_is_empty, |message| {
|
||||
self.append_persisted_message(chat_id, message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -490,34 +467,6 @@ pub struct SessionManager {
|
||||
cli_sessions: CliSessionService,
|
||||
}
|
||||
|
||||
fn default_tools(
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
known_agents: HashSet<String>,
|
||||
default_timezone: String,
|
||||
) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(CalculatorTool::new());
|
||||
registry.register(TimeTool::new(default_timezone));
|
||||
registry.register(FileReadTool::new());
|
||||
registry.register(FileWriteTool::new());
|
||||
registry.register(FileEditTool::new());
|
||||
registry.register(MemorySearchTool::new(store.clone()));
|
||||
registry.register(MemoryManageTool::new(store.clone()));
|
||||
registry.register(SchedulerManageTool::new(store, known_agents));
|
||||
registry.register(SkillListTool::new(skills.clone()));
|
||||
registry.register(SkillManageTool::new(skills));
|
||||
registry.register(BashTool::new());
|
||||
registry.register(HttpRequestTool::new(
|
||||
vec!["*".to_string()], // 允许所有域名,实际使用时建议限制
|
||||
1_000_000, // max_response_size
|
||||
30, // timeout_secs
|
||||
false, // allow_private_hosts
|
||||
));
|
||||
registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs
|
||||
registry
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(
|
||||
session_ttl_hours: u64,
|
||||
@ -540,18 +489,23 @@ impl SessionManager {
|
||||
tracing::warn!(error = %err, "Failed to record skill discovery event");
|
||||
}
|
||||
|
||||
let tools = Arc::new(default_tools(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
known_agents,
|
||||
default_timezone,
|
||||
));
|
||||
let tools = Arc::new(
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
known_agents,
|
||||
default_timezone,
|
||||
)
|
||||
.build(),
|
||||
);
|
||||
let agent_factory = AgentFactory::new(tools.clone(), skills.clone(), store.clone());
|
||||
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
|
||||
let session_factory = SessionFactory::new(
|
||||
provider_config.clone(),
|
||||
tools.clone(),
|
||||
skills.clone(),
|
||||
agent_factory,
|
||||
prompt_injector,
|
||||
store.clone(),
|
||||
agent_prompt_reinject_every,
|
||||
);
|
||||
let session_pool = SessionPool::new(session_ttl_hours, session_factory);
|
||||
let cli_sessions = CliSessionService::new(store.clone());
|
||||
@ -785,12 +739,15 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
));
|
||||
let tools = Arc::new(
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
.build(),
|
||||
);
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -824,12 +781,15 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
));
|
||||
let tools = Arc::new(
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
.build(),
|
||||
);
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -1603,12 +1563,15 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
));
|
||||
let tools = Arc::new(
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
.build(),
|
||||
);
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -1635,12 +1598,15 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
));
|
||||
let tools = Arc::new(
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
.build(),
|
||||
);
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -1695,12 +1661,15 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
));
|
||||
let tools = Arc::new(
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
)
|
||||
.build(),
|
||||
);
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -1738,7 +1707,9 @@ mod tests {
|
||||
fn test_default_tools_registers_get_time() {
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tools = default_tools(skills, store, HashSet::new(), "Asia/Shanghai".to_string());
|
||||
let tools =
|
||||
ToolRegistryFactory::new(skills, store, HashSet::new(), "Asia/Shanghai".to_string())
|
||||
.build();
|
||||
|
||||
assert!(tools.tool_names().iter().any(|name| name == "get_time"));
|
||||
}
|
||||
|
||||
@ -7,33 +7,34 @@ use crate::config::LLMProviderConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::SessionStore;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
use super::agent_factory::AgentFactory;
|
||||
use super::prompt_injector::PromptInjector;
|
||||
use super::session::Session;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SessionFactory {
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
agent_factory: AgentFactory,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
agent_prompt_reinject_every: u64,
|
||||
}
|
||||
|
||||
impl SessionFactory {
|
||||
pub(crate) fn new(
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
agent_factory: AgentFactory,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
agent_prompt_reinject_every: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider_config,
|
||||
tools,
|
||||
skills,
|
||||
agent_factory,
|
||||
prompt_injector,
|
||||
store,
|
||||
agent_prompt_reinject_every,
|
||||
}
|
||||
}
|
||||
|
||||
@ -42,14 +43,14 @@ impl SessionFactory {
|
||||
channel_name: impl Into<String>,
|
||||
user_tx: mpsc::Sender<WsOutbound>,
|
||||
) -> Result<Session, AgentError> {
|
||||
Session::new(
|
||||
Session::with_factories(
|
||||
channel_name.into(),
|
||||
self.provider_config.clone(),
|
||||
user_tx,
|
||||
self.tools.clone(),
|
||||
self.skills.clone(),
|
||||
self.agent_factory.clone(),
|
||||
self.prompt_injector.clone(),
|
||||
self.store.clone(),
|
||||
self.agent_prompt_reinject_every,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
59
src/gateway/tool_registry_factory.rs
Normal file
59
src/gateway/tool_registry_factory.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::SessionStore;
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
|
||||
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool,
|
||||
TimeTool, ToolRegistry, WebFetchTool,
|
||||
};
|
||||
|
||||
pub(crate) struct ToolRegistryFactory {
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
known_agents: HashSet<String>,
|
||||
default_timezone: String,
|
||||
}
|
||||
|
||||
impl ToolRegistryFactory {
|
||||
pub(crate) fn new(
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
known_agents: HashSet<String>,
|
||||
default_timezone: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
skills,
|
||||
store,
|
||||
known_agents,
|
||||
default_timezone,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build(&self) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(CalculatorTool::new());
|
||||
registry.register(TimeTool::new(self.default_timezone.clone()));
|
||||
registry.register(FileReadTool::new());
|
||||
registry.register(FileWriteTool::new());
|
||||
registry.register(FileEditTool::new());
|
||||
registry.register(MemorySearchTool::new(self.store.clone()));
|
||||
registry.register(MemoryManageTool::new(self.store.clone()));
|
||||
registry.register(SchedulerManageTool::new(
|
||||
self.store.clone(),
|
||||
self.known_agents.clone(),
|
||||
));
|
||||
registry.register(SkillListTool::new(self.skills.clone()));
|
||||
registry.register(SkillManageTool::new(self.skills.clone()));
|
||||
registry.register(BashTool::new());
|
||||
registry.register(HttpRequestTool::new(
|
||||
vec!["*".to_string()],
|
||||
1_000_000,
|
||||
30,
|
||||
false,
|
||||
));
|
||||
registry.register(WebFetchTool::new(50_000, 30));
|
||||
registry
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user