use std::sync::Arc; use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; use crate::config::LLMProviderConfig; use crate::gateway::prompt::{generate_system_environment_prompt, load_agent_prompt}; use crate::storage::PromptInjectionRepository; /// Agent 提示词提供者 /// /// 负责提供来自 AGENT.md、MEMORY_SUMMARY.md 和 builtin 的系统提示词, /// 以及动态生成的系统环境信息。 pub struct AgentPromptProvider { /// 重新注入间隔(用户消息数) _reinject_every: usize, /// LLM 提供者配置(用于生成系统环境信息) provider_config: LLMProviderConfig, /// 会话持久化仓库(用于记录注入状态) repository: Arc, } impl AgentPromptProvider { /// 创建新的 Agent 提示词提供者 pub fn new( reinject_every: usize, provider_config: LLMProviderConfig, repository: Arc, ) -> Self { Self { _reinject_every: reinject_every, provider_config, repository, } } /// 判断是否应该注入提示词 /// /// 规则:每次处理用户消息都动态注入系统提示词 fn should_inject(&self, _context: &SystemPromptContext) -> bool { // 每次消息都注入系统提示词 true } /// 记录注入事件 fn record_injection(&self, context: &SystemPromptContext) { if let Some(session_id) = &context.session_id { let _ = self.repository.mark_agent_prompt_reinjected(session_id); } } } impl SystemPromptProvider for AgentPromptProvider { fn build(&self, context: &SystemPromptContext) -> Option { // 检查是否需要注入 if !self.should_inject(context) { return None; } // 加载 Agent 提示词(AGENT.md + builtin + MEMORY_SUMMARY.md) let agent_prompt = load_agent_prompt().ok().flatten()?; // 生成系统环境信息 let env_info = generate_system_environment_prompt(&self.provider_config); // 记录注入事件 self.record_injection(context); Some(SystemPrompt { content: format!("{}\n\n{}", agent_prompt, env_info), context: Some("agent_prompt".to_string()), }) } } /// Agent 提示词提供者(简化版本,无持久化) /// /// 适用于不需要记录注入状态的场景(如测试或一次性任务)。 pub struct SimpleAgentPromptProvider { provider_config: LLMProviderConfig, } impl SimpleAgentPromptProvider { /// 创建新的简单 Agent 提示词提供者 pub fn new(provider_config: LLMProviderConfig) -> Self { Self { provider_config } } } impl SystemPromptProvider for SimpleAgentPromptProvider { fn build(&self, _context: &SystemPromptContext) -> Option { // 加载 Agent 提示词 let agent_prompt = load_agent_prompt().ok().flatten()?; // 生成系统环境信息 let env_info = generate_system_environment_prompt(&self.provider_config); Some(SystemPrompt { content: format!("{}\n\n{}", agent_prompt, env_info), context: Some("agent_prompt".to_string()), }) } } #[cfg(test)] mod tests { use super::*; use crate::storage::SessionStore; use std::collections::HashMap; fn test_config() -> LLMProviderConfig { LLMProviderConfig { provider_type: "openai".to_string(), name: "test".to_string(), base_url: "http://localhost".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), llm_timeout_secs: 120, memory_maintenance_timeout_secs: 600, model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, tool_result_max_chars: 100_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, } } fn test_context(user_count: usize, session_id: Option<&str>) -> SystemPromptContext { SystemPromptContext { session_id: session_id.map(|s| s.to_string()), chat_id: "test-chat".to_string(), user_message_count: user_count, } } #[test] fn test_simple_provider_builds_prompt() { let provider = SimpleAgentPromptProvider::new(test_config()); let context = test_context(0, None); // Simple provider always returns content let result = provider.build(&context).unwrap(); assert!(result.content.contains("PicoBot")); assert!(result.content.contains("操作系统:")); assert_eq!(result.context, Some("agent_prompt".to_string())); } }