PicoBot/src/gateway/agent_factory.rs

81 lines
2.8 KiB
Rust

use std::sync::Arc;
use crate::agent::{AgentError, AgentLoop, CompositeSystemPromptProvider};
use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::skills::{SkillPromptProvider, SkillRuntime};
use crate::storage::persistent_session_id;
use crate::storage::PromptInjectionRepository;
use crate::tools::{ToolContext, ToolRegistry};
/// Holds the shared dependencies needed to assemble an [`AgentLoop`].
///
/// Cloning the factory clones the `Arc` handles and the stored provider
/// configuration; the tool registry, skill runtime and prompt repository
/// themselves are shared between clones.
#[derive(Clone)]
pub(crate) struct AgentFactory {
    /// Tool registry handed to every agent built by this factory.
    tools: Arc<ToolRegistry>,
    /// Skill runtime, passed both to the skill prompt provider and to the
    /// agent loop itself.
    skills: Arc<SkillRuntime>,
    /// Provider configuration supplied at construction time.
    ///
    /// NOTE(review): `create` uses the configuration carried by the build
    /// request rather than this field — confirm whether this is meant to act
    /// as a default for callers elsewhere.
    provider_config: LLMProviderConfig,
    /// Forwarded unchanged to `AgentPromptProvider::new`; presumably the
    /// interval at which the agent prompt is re-injected — verify against that
    /// provider.
    reinject_every: usize,
    /// Repository forwarded to `AgentPromptProvider::new` for each built agent.
    prompt_repository: Arc<dyn PromptInjectionRepository>,
}
/// Per-call inputs for [`AgentFactory::create`].
///
/// All string fields are borrowed for `'a`; `create` copies the ones it needs
/// into owned `String`s when it fills in the tool context.
pub(crate) struct AgentBuildRequest<'a> {
    /// Channel name; used to derive the session id and recorded in the tool
    /// context.
    pub(crate) channel_name: &'a str,
    /// Chat id used to derive the session id, and the fallback chat id for the
    /// tool context when `notification_chat_id` is `None`.
    pub(crate) session_chat_id: &'a str,
    /// Chat id that, when present, takes precedence over `session_chat_id` in
    /// the tool context.
    pub(crate) notification_chat_id: Option<&'a str>,
    /// Optional sender id copied into the tool context.
    pub(crate) sender_id: Option<&'a str>,
    /// Optional message id copied into the tool context.
    pub(crate) message_id: Option<&'a str>,
    /// Provider configuration used for both the agent prompt provider and the
    /// agent loop built for this request.
    pub(crate) provider_config: LLMProviderConfig,
}
impl AgentFactory {
pub(crate) fn new(
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
provider_config: LLMProviderConfig,
reinject_every: usize,
prompt_repository: Arc<dyn PromptInjectionRepository>,
) -> Self {
Self {
tools,
skills,
provider_config,
reinject_every,
prompt_repository,
}
}
pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
let session_id = persistent_session_id(request.channel_name, request.session_chat_id);
// 创建组合的系统提示词提供者
let system_prompt_provider = Arc::new(CompositeSystemPromptProvider::new(vec![
Box::new(AgentPromptProvider::new(
self.reinject_every,
request.provider_config.clone(),
self.prompt_repository.clone(),
)),
Box::new(SkillPromptProvider::new(self.skills.clone())),
]));
AgentLoop::with_tools_and_system_prompt_provider(
request.provider_config,
self.tools.clone(),
system_prompt_provider,
Some(self.skills.clone()),
)
.map(|agent| {
// notification_chat_id 优先,否则使用 session_chat_id
let tool_chat_id = request
.notification_chat_id
.unwrap_or(request.session_chat_id);
agent.with_tool_context(ToolContext {
channel_name: Some(request.channel_name.to_string()),
sender_id: request.sender_id.map(str::to_string),
chat_id: Some(tool_chat_id.to_string()),
session_id: Some(session_id),
message_id: request.message_id.map(str::to_string),
message_seq: None,
})
})
}
}