151 lines
4.9 KiB
Rust
151 lines
4.9 KiB
Rust
use std::sync::Arc;
|
||
|
||
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
|
||
use crate::config::LLMProviderConfig;
|
||
use crate::gateway::prompt::{generate_system_environment_prompt, load_agent_prompt};
|
||
use crate::storage::PromptInjectionRepository;
|
||
|
||
/// Agent 提示词提供者
|
||
///
|
||
/// 负责提供来自 AGENT.md、MEMORY_SUMMARY.md 和 builtin 的系统提示词,
|
||
/// 以及动态生成的系统环境信息。
|
||
pub struct AgentPromptProvider {
|
||
/// 重新注入间隔(用户消息数)
|
||
_reinject_every: usize,
|
||
/// LLM 提供者配置(用于生成系统环境信息)
|
||
provider_config: LLMProviderConfig,
|
||
/// 会话持久化仓库(用于记录注入状态)
|
||
repository: Arc<dyn PromptInjectionRepository>,
|
||
}
|
||
|
||
impl AgentPromptProvider {
|
||
/// 创建新的 Agent 提示词提供者
|
||
pub fn new(
|
||
reinject_every: usize,
|
||
provider_config: LLMProviderConfig,
|
||
repository: Arc<dyn PromptInjectionRepository>,
|
||
) -> Self {
|
||
Self {
|
||
_reinject_every: reinject_every,
|
||
provider_config,
|
||
repository,
|
||
}
|
||
}
|
||
|
||
/// 判断是否应该注入提示词
|
||
///
|
||
/// 规则:每次处理用户消息都动态注入系统提示词
|
||
fn should_inject(&self, _context: &SystemPromptContext) -> bool {
|
||
// 每次消息都注入系统提示词
|
||
true
|
||
}
|
||
|
||
/// 记录注入事件
|
||
fn record_injection(&self, context: &SystemPromptContext) {
|
||
if let Some(session_id) = &context.session_id {
|
||
let _ = self.repository.mark_agent_prompt_reinjected(session_id);
|
||
}
|
||
}
|
||
}
|
||
|
||
impl SystemPromptProvider for AgentPromptProvider {
|
||
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
|
||
// 检查是否需要注入
|
||
if !self.should_inject(context) {
|
||
return None;
|
||
}
|
||
|
||
// 加载 Agent 提示词(AGENT.md + builtin + MEMORY_SUMMARY.md)
|
||
let agent_prompt = load_agent_prompt().ok().flatten()?;
|
||
|
||
// 生成系统环境信息
|
||
let env_info = generate_system_environment_prompt(&self.provider_config);
|
||
|
||
// 记录注入事件
|
||
self.record_injection(context);
|
||
|
||
Some(SystemPrompt {
|
||
content: format!("{}\n\n{}", agent_prompt, env_info),
|
||
context: Some("agent_prompt".to_string()),
|
||
})
|
||
}
|
||
}
|
||
|
||
/// Agent 提示词提供者(简化版本,无持久化)
|
||
///
|
||
/// 适用于不需要记录注入状态的场景(如测试或一次性任务)。
|
||
pub struct SimpleAgentPromptProvider {
|
||
provider_config: LLMProviderConfig,
|
||
}
|
||
|
||
impl SimpleAgentPromptProvider {
|
||
/// 创建新的简单 Agent 提示词提供者
|
||
pub fn new(provider_config: LLMProviderConfig) -> Self {
|
||
Self { provider_config }
|
||
}
|
||
}
|
||
|
||
impl SystemPromptProvider for SimpleAgentPromptProvider {
|
||
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
|
||
// 加载 Agent 提示词
|
||
let agent_prompt = load_agent_prompt().ok().flatten()?;
|
||
|
||
// 生成系统环境信息
|
||
let env_info = generate_system_environment_prompt(&self.provider_config);
|
||
|
||
Some(SystemPrompt {
|
||
content: format!("{}\n\n{}", agent_prompt, env_info),
|
||
context: Some("agent_prompt".to_string()),
|
||
})
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal provider configuration pointing at a dummy local endpoint.
    fn test_config() -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "test".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            llm_timeout_secs: 120,
            memory_maintenance_timeout_secs: 600,
            model_id: "test-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            context_window_tokens: None,
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            tool_result_max_chars: 100_000,
            context_tool_result_trim_chars: 20_000,
            max_images_in_context: 1,
            max_image_age_rounds: 10,
        }
    }

    /// Builds a prompt context for chat "test-chat".
    ///
    /// The context uses the given user message count and optional session id.
    fn test_context(user_count: usize, session_id: Option<&str>) -> SystemPromptContext {
        SystemPromptContext {
            session_id: session_id.map(|s| s.to_string()),
            chat_id: "test-chat".to_string(),
            user_message_count: user_count,
        }
    }

    #[test]
    fn test_simple_provider_builds_prompt() {
        let provider = SimpleAgentPromptProvider::new(test_config());
        let context = test_context(0, None);

        // `build` returns `None` when `load_agent_prompt` yields no prompt.
        // This test relies on a prompt being available in the test environment.
        let result = provider
            .build(&context)
            .expect("agent prompt should be loadable in the test environment");
        assert!(result.content.contains("PicoBot"));
        assert!(result.content.contains("操作系统:"));
        assert_eq!(result.context, Some("agent_prompt".to_string()));
    }
}
|