feat: 添加 AgentFactory 和 PromptInjector,重构工具注册逻辑以优化会话管理
This commit is contained in:
parent
008aba91ac
commit
c65921b5e8
57
src/gateway/agent_factory.rs
Normal file
57
src/gateway/agent_factory.rs
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::agent::{AgentError, AgentLoop};
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
use crate::skills::SkillRuntime;
|
||||||
|
use crate::storage::{SessionStore, persistent_session_id};
|
||||||
|
use crate::tools::{ToolContext, ToolRegistry};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub(crate) struct AgentFactory {
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
skills: Arc<SkillRuntime>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct AgentBuildRequest<'a> {
|
||||||
|
pub(crate) channel_name: &'a str,
|
||||||
|
pub(crate) chat_id: &'a str,
|
||||||
|
pub(crate) sender_id: Option<&'a str>,
|
||||||
|
pub(crate) message_id: Option<&'a str>,
|
||||||
|
pub(crate) provider_config: LLMProviderConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentFactory {
|
||||||
|
pub(crate) fn new(
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
skills: Arc<SkillRuntime>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
tools,
|
||||||
|
skills,
|
||||||
|
store,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
|
||||||
|
let session_id = persistent_session_id(request.channel_name, request.chat_id);
|
||||||
|
AgentLoop::with_tools_and_skills(
|
||||||
|
request.provider_config,
|
||||||
|
self.tools.clone(),
|
||||||
|
self.skills.clone(),
|
||||||
|
)
|
||||||
|
.map(|agent| {
|
||||||
|
agent
|
||||||
|
.with_skill_event_store(self.store.clone(), session_id.clone())
|
||||||
|
.with_tool_context(ToolContext {
|
||||||
|
channel_name: Some(request.channel_name.to_string()),
|
||||||
|
sender_id: request.sender_id.map(str::to_string),
|
||||||
|
chat_id: Some(request.chat_id.to_string()),
|
||||||
|
session_id: Some(session_id),
|
||||||
|
message_id: request.message_id.map(str::to_string),
|
||||||
|
message_seq: None,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
pub mod agent_factory;
|
||||||
pub mod agent_task_executor;
|
pub mod agent_task_executor;
|
||||||
pub mod cli_session;
|
pub mod cli_session;
|
||||||
pub mod command;
|
pub mod command;
|
||||||
@ -8,9 +9,11 @@ pub mod memory_maintenance;
|
|||||||
pub mod message_prepare;
|
pub mod message_prepare;
|
||||||
pub mod processor;
|
pub mod processor;
|
||||||
pub mod prompt;
|
pub mod prompt;
|
||||||
|
pub mod prompt_injector;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod session_factory;
|
pub mod session_factory;
|
||||||
pub mod session_pool;
|
pub mod session_pool;
|
||||||
|
pub mod tool_registry_factory;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
pub mod ws_adapter;
|
pub mod ws_adapter;
|
||||||
|
|
||||||
|
|||||||
86
src/gateway/prompt_injector.rs
Normal file
86
src/gateway/prompt_injector.rs
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::agent::AgentError;
|
||||||
|
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
|
||||||
|
use crate::storage::SessionStore;
|
||||||
|
|
||||||
|
use super::prompt::load_agent_prompt;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub(crate) struct PromptInjector {
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
reinject_every: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptInjector {
|
||||||
|
pub(crate) fn new(store: Arc<SessionStore>, reinject_every: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
reinject_every: reinject_every as i64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn ensure_initial_prompt<F>(
|
||||||
|
&self,
|
||||||
|
history_is_empty: bool,
|
||||||
|
mut append_message: F,
|
||||||
|
) -> Result<(), AgentError>
|
||||||
|
where
|
||||||
|
F: FnMut(ChatMessage) -> Result<(), AgentError>,
|
||||||
|
{
|
||||||
|
if !history_is_empty {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||||
|
append_message(Self::agent_prompt_message(agent_prompt))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn ensure_reinjected_prompt<F>(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
mut append_message: F,
|
||||||
|
) -> Result<(), AgentError>
|
||||||
|
where
|
||||||
|
F: FnMut(ChatMessage) -> Result<(), AgentError>,
|
||||||
|
{
|
||||||
|
let session_record = self
|
||||||
|
.store
|
||||||
|
.get_session(session_id)
|
||||||
|
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
||||||
|
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||||
|
let active_user_turns =
|
||||||
|
self.store
|
||||||
|
.count_active_user_messages(session_id)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!("count active user messages error: {}", err))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if self.reinject_every > 0
|
||||||
|
&& active_user_turns > 0
|
||||||
|
&& active_user_turns / self.reinject_every
|
||||||
|
> session_record.agent_prompt_reinjection_count
|
||||||
|
{
|
||||||
|
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||||
|
append_message(Self::agent_prompt_message(agent_prompt))?;
|
||||||
|
self.store
|
||||||
|
.mark_agent_prompt_reinjected(session_id)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn agent_prompt_message(agent_prompt: String) -> ChatMessage {
|
||||||
|
ChatMessage::system_with_context(
|
||||||
|
agent_prompt,
|
||||||
|
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,22 +1,19 @@
|
|||||||
use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler};
|
use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler};
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
|
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
|
||||||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
|
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||||
use crate::tools::{
|
use crate::tools::ToolRegistry;
|
||||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
|
|
||||||
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool,
|
|
||||||
TimeTool, ToolContext, ToolRegistry, WebFetchTool,
|
|
||||||
};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use super::agent_factory::{AgentBuildRequest, AgentFactory};
|
||||||
use super::cli_session::CliSessionService;
|
use super::cli_session::CliSessionService;
|
||||||
use super::execution::{
|
use super::execution::{
|
||||||
AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest,
|
AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest,
|
||||||
@ -31,9 +28,10 @@ use super::memory_maintenance::{
|
|||||||
use super::memory_maintenance::{
|
use super::memory_maintenance::{
|
||||||
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
|
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
|
||||||
};
|
};
|
||||||
use super::prompt::load_agent_prompt;
|
use super::prompt_injector::PromptInjector;
|
||||||
use super::session_factory::SessionFactory;
|
use super::session_factory::SessionFactory;
|
||||||
use super::session_pool::SessionPool;
|
use super::session_pool::SessionPool;
|
||||||
|
use super::tool_registry_factory::ToolRegistryFactory;
|
||||||
|
|
||||||
fn preview_text(content: &str, max_chars: usize) -> String {
|
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||||
let mut preview = content.chars().take(max_chars).collect::<String>();
|
let mut preview = content.chars().take(max_chars).collect::<String>();
|
||||||
@ -53,11 +51,11 @@ pub struct Session {
|
|||||||
compression_in_flight: HashSet<String>,
|
compression_in_flight: HashSet<String>,
|
||||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
tools: Arc<ToolRegistry>,
|
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
|
agent_factory: AgentFactory,
|
||||||
|
prompt_injector: PromptInjector,
|
||||||
compressor: ContextCompressor,
|
compressor: ContextCompressor,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
agent_prompt_reinject_every: i64,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct BusToolCallEmitter {
|
pub struct BusToolCallEmitter {
|
||||||
@ -125,6 +123,29 @@ impl Session {
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
|
) -> Result<Self, AgentError> {
|
||||||
|
let agent_factory = AgentFactory::new(tools, skills.clone(), store.clone());
|
||||||
|
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
|
||||||
|
Self::with_factories(
|
||||||
|
channel_name,
|
||||||
|
provider_config,
|
||||||
|
user_tx,
|
||||||
|
skills,
|
||||||
|
agent_factory,
|
||||||
|
prompt_injector,
|
||||||
|
store,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn with_factories(
|
||||||
|
channel_name: String,
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
user_tx: mpsc::Sender<WsOutbound>,
|
||||||
|
skills: Arc<SkillRuntime>,
|
||||||
|
agent_factory: AgentFactory,
|
||||||
|
prompt_injector: PromptInjector,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
@ -133,11 +154,11 @@ impl Session {
|
|||||||
compression_in_flight: HashSet::new(),
|
compression_in_flight: HashSet::new(),
|
||||||
user_tx,
|
user_tx,
|
||||||
provider_config: provider_config.clone(),
|
provider_config: provider_config.clone(),
|
||||||
tools,
|
|
||||||
skills,
|
skills,
|
||||||
|
agent_factory,
|
||||||
|
prompt_injector,
|
||||||
compressor: ContextCompressor::from_provider_config(&provider_config),
|
compressor: ContextCompressor::from_provider_config(&provider_config),
|
||||||
store,
|
store,
|
||||||
agent_prompt_reinject_every: agent_prompt_reinject_every as i64,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -172,40 +193,10 @@ impl Session {
|
|||||||
self.ensure_chat_loaded(chat_id)?;
|
self.ensure_chat_loaded(chat_id)?;
|
||||||
|
|
||||||
let session_id = self.persistent_session_id(chat_id);
|
let session_id = self.persistent_session_id(chat_id);
|
||||||
let session_record = self
|
let prompt_injector = self.prompt_injector.clone();
|
||||||
.store
|
prompt_injector.ensure_reinjected_prompt(&session_id, |message| {
|
||||||
.get_session(&session_id)
|
self.append_persisted_message(chat_id, message)
|
||||||
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
})
|
||||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
|
||||||
let active_user_turns =
|
|
||||||
self.store
|
|
||||||
.count_active_user_messages(&session_id)
|
|
||||||
.map_err(|err| {
|
|
||||||
AgentError::Other(format!("count active user messages error: {}", err))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
if self.agent_prompt_reinject_every > 0
|
|
||||||
&& active_user_turns > 0
|
|
||||||
&& active_user_turns / self.agent_prompt_reinject_every
|
|
||||||
> session_record.agent_prompt_reinjection_count
|
|
||||||
{
|
|
||||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
|
||||||
self.append_persisted_message(
|
|
||||||
chat_id,
|
|
||||||
ChatMessage::system_with_context(
|
|
||||||
agent_prompt,
|
|
||||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
|
||||||
),
|
|
||||||
)?;
|
|
||||||
self.store
|
|
||||||
.mark_agent_prompt_reinjected(&session_id)
|
|
||||||
.map_err(|err| {
|
|
||||||
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取或创建指定 chat_id 的会话历史
|
/// 获取或创建指定 chat_id 的会话历史
|
||||||
@ -437,20 +428,13 @@ impl Session {
|
|||||||
message_id: Option<&str>,
|
message_id: Option<&str>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
) -> Result<AgentLoop, AgentError> {
|
) -> Result<AgentLoop, AgentError> {
|
||||||
let session_id = self.persistent_session_id(chat_id);
|
self.agent_factory.create(AgentBuildRequest {
|
||||||
AgentLoop::with_tools_and_skills(provider_config, self.tools.clone(), self.skills.clone())
|
channel_name: &self.channel_name,
|
||||||
.map(|agent| {
|
chat_id,
|
||||||
agent
|
sender_id,
|
||||||
.with_skill_event_store(self.store.clone(), session_id.clone())
|
message_id,
|
||||||
.with_tool_context(ToolContext {
|
provider_config,
|
||||||
channel_name: Some(self.channel_name.clone()),
|
})
|
||||||
sender_id: sender_id.map(str::to_string),
|
|
||||||
chat_id: Some(chat_id.to_string()),
|
|
||||||
session_id: Some(session_id),
|
|
||||||
message_id: message_id.map(str::to_string),
|
|
||||||
message_seq: None,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
@ -463,17 +447,10 @@ impl Session {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
let prompt_injector = self.prompt_injector.clone();
|
||||||
self.append_persisted_message(
|
prompt_injector.ensure_initial_prompt(history_is_empty, |message| {
|
||||||
chat_id,
|
self.append_persisted_message(chat_id, message)
|
||||||
ChatMessage::system_with_context(
|
})
|
||||||
agent_prompt,
|
|
||||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
|
||||||
),
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -490,34 +467,6 @@ pub struct SessionManager {
|
|||||||
cli_sessions: CliSessionService,
|
cli_sessions: CliSessionService,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_tools(
|
|
||||||
skills: Arc<SkillRuntime>,
|
|
||||||
store: Arc<SessionStore>,
|
|
||||||
known_agents: HashSet<String>,
|
|
||||||
default_timezone: String,
|
|
||||||
) -> ToolRegistry {
|
|
||||||
let mut registry = ToolRegistry::new();
|
|
||||||
registry.register(CalculatorTool::new());
|
|
||||||
registry.register(TimeTool::new(default_timezone));
|
|
||||||
registry.register(FileReadTool::new());
|
|
||||||
registry.register(FileWriteTool::new());
|
|
||||||
registry.register(FileEditTool::new());
|
|
||||||
registry.register(MemorySearchTool::new(store.clone()));
|
|
||||||
registry.register(MemoryManageTool::new(store.clone()));
|
|
||||||
registry.register(SchedulerManageTool::new(store, known_agents));
|
|
||||||
registry.register(SkillListTool::new(skills.clone()));
|
|
||||||
registry.register(SkillManageTool::new(skills));
|
|
||||||
registry.register(BashTool::new());
|
|
||||||
registry.register(HttpRequestTool::new(
|
|
||||||
vec!["*".to_string()], // 允许所有域名,实际使用时建议限制
|
|
||||||
1_000_000, // max_response_size
|
|
||||||
30, // timeout_secs
|
|
||||||
false, // allow_private_hosts
|
|
||||||
));
|
|
||||||
registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs
|
|
||||||
registry
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SessionManager {
|
impl SessionManager {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
session_ttl_hours: u64,
|
session_ttl_hours: u64,
|
||||||
@ -540,18 +489,23 @@ impl SessionManager {
|
|||||||
tracing::warn!(error = %err, "Failed to record skill discovery event");
|
tracing::warn!(error = %err, "Failed to record skill discovery event");
|
||||||
}
|
}
|
||||||
|
|
||||||
let tools = Arc::new(default_tools(
|
let tools = Arc::new(
|
||||||
skills.clone(),
|
ToolRegistryFactory::new(
|
||||||
store.clone(),
|
skills.clone(),
|
||||||
known_agents,
|
store.clone(),
|
||||||
default_timezone,
|
known_agents,
|
||||||
));
|
default_timezone,
|
||||||
|
)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
|
let agent_factory = AgentFactory::new(tools.clone(), skills.clone(), store.clone());
|
||||||
|
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
|
||||||
let session_factory = SessionFactory::new(
|
let session_factory = SessionFactory::new(
|
||||||
provider_config.clone(),
|
provider_config.clone(),
|
||||||
tools.clone(),
|
|
||||||
skills.clone(),
|
skills.clone(),
|
||||||
|
agent_factory,
|
||||||
|
prompt_injector,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
agent_prompt_reinject_every,
|
|
||||||
);
|
);
|
||||||
let session_pool = SessionPool::new(session_ttl_hours, session_factory);
|
let session_pool = SessionPool::new(session_ttl_hours, session_factory);
|
||||||
let cli_sessions = CliSessionService::new(store.clone());
|
let cli_sessions = CliSessionService::new(store.clone());
|
||||||
@ -785,12 +739,15 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(
|
let tools = Arc::new(
|
||||||
skills.clone(),
|
ToolRegistryFactory::new(
|
||||||
store.clone(),
|
skills.clone(),
|
||||||
HashSet::new(),
|
store.clone(),
|
||||||
"Asia/Shanghai".to_string(),
|
HashSet::new(),
|
||||||
));
|
"Asia/Shanghai".to_string(),
|
||||||
|
)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -824,12 +781,15 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(
|
let tools = Arc::new(
|
||||||
skills.clone(),
|
ToolRegistryFactory::new(
|
||||||
store.clone(),
|
skills.clone(),
|
||||||
HashSet::new(),
|
store.clone(),
|
||||||
"Asia/Shanghai".to_string(),
|
HashSet::new(),
|
||||||
));
|
"Asia/Shanghai".to_string(),
|
||||||
|
)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -1603,12 +1563,15 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(
|
let tools = Arc::new(
|
||||||
skills.clone(),
|
ToolRegistryFactory::new(
|
||||||
store.clone(),
|
skills.clone(),
|
||||||
HashSet::new(),
|
store.clone(),
|
||||||
"Asia/Shanghai".to_string(),
|
HashSet::new(),
|
||||||
));
|
"Asia/Shanghai".to_string(),
|
||||||
|
)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -1635,12 +1598,15 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(
|
let tools = Arc::new(
|
||||||
skills.clone(),
|
ToolRegistryFactory::new(
|
||||||
store.clone(),
|
skills.clone(),
|
||||||
HashSet::new(),
|
store.clone(),
|
||||||
"Asia/Shanghai".to_string(),
|
HashSet::new(),
|
||||||
));
|
"Asia/Shanghai".to_string(),
|
||||||
|
)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -1695,12 +1661,15 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(
|
let tools = Arc::new(
|
||||||
skills.clone(),
|
ToolRegistryFactory::new(
|
||||||
store.clone(),
|
skills.clone(),
|
||||||
HashSet::new(),
|
store.clone(),
|
||||||
"Asia/Shanghai".to_string(),
|
HashSet::new(),
|
||||||
));
|
"Asia/Shanghai".to_string(),
|
||||||
|
)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -1738,7 +1707,9 @@ mod tests {
|
|||||||
fn test_default_tools_registers_get_time() {
|
fn test_default_tools_registers_get_time() {
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tools = default_tools(skills, store, HashSet::new(), "Asia/Shanghai".to_string());
|
let tools =
|
||||||
|
ToolRegistryFactory::new(skills, store, HashSet::new(), "Asia/Shanghai".to_string())
|
||||||
|
.build();
|
||||||
|
|
||||||
assert!(tools.tool_names().iter().any(|name| name == "get_time"));
|
assert!(tools.tool_names().iter().any(|name| name == "get_time"));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,33 +7,34 @@ use crate::config::LLMProviderConfig;
|
|||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::SessionStore;
|
use crate::storage::SessionStore;
|
||||||
use crate::tools::ToolRegistry;
|
|
||||||
|
|
||||||
|
use super::agent_factory::AgentFactory;
|
||||||
|
use super::prompt_injector::PromptInjector;
|
||||||
use super::session::Session;
|
use super::session::Session;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct SessionFactory {
|
pub(crate) struct SessionFactory {
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
tools: Arc<ToolRegistry>,
|
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
|
agent_factory: AgentFactory,
|
||||||
|
prompt_injector: PromptInjector,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
agent_prompt_reinject_every: u64,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionFactory {
|
impl SessionFactory {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
tools: Arc<ToolRegistry>,
|
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
|
agent_factory: AgentFactory,
|
||||||
|
prompt_injector: PromptInjector,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
agent_prompt_reinject_every: u64,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
provider_config,
|
provider_config,
|
||||||
tools,
|
|
||||||
skills,
|
skills,
|
||||||
|
agent_factory,
|
||||||
|
prompt_injector,
|
||||||
store,
|
store,
|
||||||
agent_prompt_reinject_every,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,14 +43,14 @@ impl SessionFactory {
|
|||||||
channel_name: impl Into<String>,
|
channel_name: impl Into<String>,
|
||||||
user_tx: mpsc::Sender<WsOutbound>,
|
user_tx: mpsc::Sender<WsOutbound>,
|
||||||
) -> Result<Session, AgentError> {
|
) -> Result<Session, AgentError> {
|
||||||
Session::new(
|
Session::with_factories(
|
||||||
channel_name.into(),
|
channel_name.into(),
|
||||||
self.provider_config.clone(),
|
self.provider_config.clone(),
|
||||||
user_tx,
|
user_tx,
|
||||||
self.tools.clone(),
|
|
||||||
self.skills.clone(),
|
self.skills.clone(),
|
||||||
|
self.agent_factory.clone(),
|
||||||
|
self.prompt_injector.clone(),
|
||||||
self.store.clone(),
|
self.store.clone(),
|
||||||
self.agent_prompt_reinject_every,
|
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|||||||
59
src/gateway/tool_registry_factory.rs
Normal file
59
src/gateway/tool_registry_factory.rs
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::skills::SkillRuntime;
|
||||||
|
use crate::storage::SessionStore;
|
||||||
|
use crate::tools::{
|
||||||
|
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
|
||||||
|
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool,
|
||||||
|
TimeTool, ToolRegistry, WebFetchTool,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub(crate) struct ToolRegistryFactory {
|
||||||
|
skills: Arc<SkillRuntime>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
known_agents: HashSet<String>,
|
||||||
|
default_timezone: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolRegistryFactory {
|
||||||
|
pub(crate) fn new(
|
||||||
|
skills: Arc<SkillRuntime>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
known_agents: HashSet<String>,
|
||||||
|
default_timezone: String,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
skills,
|
||||||
|
store,
|
||||||
|
known_agents,
|
||||||
|
default_timezone,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn build(&self) -> ToolRegistry {
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(CalculatorTool::new());
|
||||||
|
registry.register(TimeTool::new(self.default_timezone.clone()));
|
||||||
|
registry.register(FileReadTool::new());
|
||||||
|
registry.register(FileWriteTool::new());
|
||||||
|
registry.register(FileEditTool::new());
|
||||||
|
registry.register(MemorySearchTool::new(self.store.clone()));
|
||||||
|
registry.register(MemoryManageTool::new(self.store.clone()));
|
||||||
|
registry.register(SchedulerManageTool::new(
|
||||||
|
self.store.clone(),
|
||||||
|
self.known_agents.clone(),
|
||||||
|
));
|
||||||
|
registry.register(SkillListTool::new(self.skills.clone()));
|
||||||
|
registry.register(SkillManageTool::new(self.skills.clone()));
|
||||||
|
registry.register(BashTool::new());
|
||||||
|
registry.register(HttpRequestTool::new(
|
||||||
|
vec!["*".to_string()],
|
||||||
|
1_000_000,
|
||||||
|
30,
|
||||||
|
false,
|
||||||
|
));
|
||||||
|
registry.register(WebFetchTool::new(50_000, 30));
|
||||||
|
registry
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user