PicoBot/src/gateway/agent_prompt_provider.rs

151 lines
4.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::Arc;
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
use crate::config::LLMProviderConfig;
use crate::gateway::prompt::{generate_system_environment_prompt, load_agent_prompt};
use crate::storage::PromptInjectionRepository;
/// Agent system-prompt provider.
///
/// Supplies the system prompt built from AGENT.md, MEMORY_SUMMARY.md and builtin
/// content (as loaded by `load_agent_prompt`), followed by dynamically generated
/// system-environment information.
pub struct AgentPromptProvider {
/// Re-injection interval, counted in user messages.
/// NOTE(review): stored but not consulted — `should_inject` currently injects on every message.
_reinject_every: usize,
/// LLM provider configuration (used to generate the system-environment section).
provider_config: LLMProviderConfig,
/// Session persistence repository (used to record injection state).
repository: Arc<dyn PromptInjectionRepository>,
}
impl AgentPromptProvider {
/// 创建新的 Agent 提示词提供者
pub fn new(
reinject_every: usize,
provider_config: LLMProviderConfig,
repository: Arc<dyn PromptInjectionRepository>,
) -> Self {
Self {
_reinject_every: reinject_every,
provider_config,
repository,
}
}
/// 判断是否应该注入提示词
///
/// 规则:每次处理用户消息都动态注入系统提示词
fn should_inject(&self, _context: &SystemPromptContext) -> bool {
// 每次消息都注入系统提示词
true
}
/// 记录注入事件
fn record_injection(&self, context: &SystemPromptContext) {
if let Some(session_id) = &context.session_id {
let _ = self.repository.mark_agent_prompt_reinjected(session_id);
}
}
}
impl SystemPromptProvider for AgentPromptProvider {
    /// Builds the agent system prompt for `context`.
    ///
    /// The content is the loaded agent prompt (AGENT.md + builtin + MEMORY_SUMMARY.md),
    /// a blank line, then the generated system-environment section. Returns `None` when
    /// injection is not wanted, or when `load_agent_prompt` errors or yields no prompt.
    /// On success the injection is recorded before the prompt is returned.
    fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
        if !self.should_inject(context) {
            return None;
        }

        let agent_prompt = match load_agent_prompt() {
            Ok(Some(prompt)) => prompt,
            // A load error and an absent prompt both mean "nothing to inject".
            Ok(None) | Err(_) => return None,
        };

        let environment = generate_system_environment_prompt(&self.provider_config);

        // Only record once we know a prompt is actually being produced.
        self.record_injection(context);

        let content = format!("{}\n\n{}", agent_prompt, environment);
        Some(SystemPrompt {
            content,
            context: Some("agent_prompt".to_string()),
        })
    }
}
/// Agent system-prompt provider (simplified version, without persistence).
///
/// Suitable for scenarios that do not need to record injection state
/// (e.g. tests or one-off tasks).
pub struct SimpleAgentPromptProvider {
/// LLM provider configuration used to generate the system-environment section.
provider_config: LLMProviderConfig,
}
impl SimpleAgentPromptProvider {
/// 创建新的简单 Agent 提示词提供者
pub fn new(provider_config: LLMProviderConfig) -> Self {
Self { provider_config }
}
}
impl SystemPromptProvider for SimpleAgentPromptProvider {
    /// Builds the agent system prompt without inspecting the context or recording anything.
    ///
    /// The content is the loaded agent prompt, a blank line, then the generated
    /// system-environment section. Returns `None` when `load_agent_prompt` errors or
    /// yields no prompt; the environment section is only generated when a prompt exists.
    fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
        load_agent_prompt().ok().flatten().map(|agent_prompt| {
            let environment = generate_system_environment_prompt(&self.provider_config);
            SystemPrompt {
                content: format!("{}\n\n{}", agent_prompt, environment),
                context: Some("agent_prompt".to_string()),
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // `crate::storage::SessionStore` was imported here but never used (unused_imports).
    use std::collections::HashMap;

    /// Minimal, well-formed provider configuration for tests.
    fn test_config() -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "test".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            llm_timeout_secs: 120,
            memory_maintenance_timeout_secs: 600,
            model_id: "test-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            context_window_tokens: None,
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            tool_result_max_chars: 100_000,
            context_tool_result_trim_chars: 20_000,
            max_images_in_context: 1,
            max_image_age_rounds: 10,
        }
    }

    /// Builds a prompt context with the given user-message count and optional session id.
    fn test_context(user_count: usize, session_id: Option<&str>) -> SystemPromptContext {
        SystemPromptContext {
            session_id: session_id.map(|s| s.to_string()),
            chat_id: "test-chat".to_string(),
            user_message_count: user_count,
        }
    }

    #[test]
    fn test_simple_provider_builds_prompt() {
        let provider = SimpleAgentPromptProvider::new(test_config());
        let context = test_context(0, None);
        // `build` returns `None` when `load_agent_prompt` errors or finds no prompt, so
        // report that case explicitly instead of panicking on a bare `unwrap`.
        let result = provider
            .build(&context)
            .expect("SimpleAgentPromptProvider::build returned None: load_agent_prompt produced no prompt");
        assert!(result.content.contains("PicoBot"));
        assert!(result.content.contains("操作系统:"));
        assert_eq!(result.context, Some("agent_prompt".to_string()));
    }
}