PicoBot/src/agent/system_prompt.rs

194 lines
5.9 KiB
Rust

/// Context handed to a system prompt provider when it builds its prompt.
///
/// Derives `PartialEq`/`Eq` alongside `Debug`/`Clone` so that two contexts can
/// be compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SystemPromptContext {
    /// Session ID, if there is one.
    pub session_id: Option<String>,
    /// Chat ID.
    pub chat_id: String,
    /// Number of user messages so far (used to decide whether to re-inject).
    pub user_message_count: usize,
}
/// Result produced by a system prompt provider.
///
/// Derives `PartialEq`/`Eq` alongside `Debug`/`Clone` so that prompts can be
/// compared as whole values, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SystemPrompt {
    /// Prompt text.
    pub content: String,
    /// Context tag (e.g. "agent_prompt", "skill_index").
    pub context: Option<String>,
}
/// Trait for system prompt providers.
///
/// Implement this trait to supply system prompt content to the AgentLoop.
/// The prompt is built dynamically on every user request and is not persisted.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait SystemPromptProvider: Send + Sync + 'static {
/// Builds the system prompt for the given `context`.
///
/// Returns `None` when this provider has no content to inject.
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt>;
}
/// Combines the system prompts of several providers.
///
/// Calls every provider in order and merges the non-empty content into one
/// complete system prompt.
pub struct CompositeSystemPromptProvider {
// Boxed trait objects so differently-typed providers can share one list;
// they are consulted in the order stored here.
providers: Vec<Box<dyn SystemPromptProvider>>,
}
impl CompositeSystemPromptProvider {
    /// Creates a composite provider from `providers`.
    ///
    /// The providers are consulted in the order given.
    pub fn new(providers: Vec<Box<dyn SystemPromptProvider>>) -> Self {
        Self { providers }
    }

    /// Builds the combined system prompt.
    ///
    /// Every provider is asked for a prompt, in order. Providers that return
    /// `None` are skipped, and so are prompts whose content is empty or
    /// whitespace-only: they contribute nothing and would otherwise leave
    /// dangling `\n\n` separators in the output. The remaining contents are
    /// joined with `\n\n`, unmodified.
    ///
    /// Returns `None` when no provider contributed any content. Otherwise the
    /// result carries the context tag `"combined_system_prompt"`.
    pub fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
        let fragments: Vec<String> = self
            .providers
            .iter()
            .filter_map(|provider| provider.build(context))
            .map(|prompt| prompt.content)
            // A `Some` with blank content is treated like `None`.
            .filter(|content| !content.trim().is_empty())
            .collect();

        if fragments.is_empty() {
            return None;
        }

        Some(SystemPrompt {
            content: fragments.join("\n\n"),
            context: Some("combined_system_prompt".to_string()),
        })
    }
}
impl SystemPromptProvider for CompositeSystemPromptProvider {
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
self.build(context)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider that yields its fixed text, or `None` when that text is empty.
    struct TestProvider {
        content: &'static str,
    }

    impl SystemPromptProvider for TestProvider {
        fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
            if self.content.is_empty() {
                return None;
            }
            Some(SystemPrompt {
                content: self.content.to_string(),
                context: Some("test".to_string()),
            })
        }
    }

    /// Provider that only contributes once the user message count reaches its threshold.
    struct ConditionalProvider {
        user_message_count_threshold: usize,
        content: &'static str,
    }

    impl SystemPromptProvider for ConditionalProvider {
        fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
            if context.user_message_count < self.user_message_count_threshold {
                return None;
            }
            Some(SystemPrompt {
                content: self.content.to_string(),
                context: Some("conditional".to_string()),
            })
        }
    }

    /// Builds a context with fixed IDs and the given user message count.
    fn test_context(user_count: usize) -> SystemPromptContext {
        let session_id = Some("test-session".to_string());
        let chat_id = "test-chat".to_string();
        SystemPromptContext {
            session_id,
            chat_id,
            user_message_count: user_count,
        }
    }

    /// Boxes a `TestProvider` holding `content`.
    fn fixed(content: &'static str) -> Box<dyn SystemPromptProvider> {
        Box::new(TestProvider { content })
    }

    /// Boxes a `ConditionalProvider` with the given threshold and content.
    fn conditional(threshold: usize, content: &'static str) -> Box<dyn SystemPromptProvider> {
        Box::new(ConditionalProvider {
            user_message_count_threshold: threshold,
            content,
        })
    }

    #[test]
    fn test_composite_provider_joins_multiple_sources() {
        let providers = vec![fixed("First part"), fixed("Second part")];
        let composite = CompositeSystemPromptProvider::new(providers);

        let prompt = composite.build(&test_context(0)).unwrap();

        assert_eq!(prompt.content, "First part\n\nSecond part");
    }

    #[test]
    fn test_composite_provider_skips_empty_sources() {
        let providers = vec![fixed("First part"), fixed(""), fixed("Third part")];
        let composite = CompositeSystemPromptProvider::new(providers);

        let prompt = composite.build(&test_context(0)).unwrap();

        assert_eq!(prompt.content, "First part\n\nThird part");
    }

    #[test]
    fn test_composite_provider_returns_none_for_all_empty() {
        let providers = vec![fixed(""), fixed("")];
        let composite = CompositeSystemPromptProvider::new(providers);

        assert!(composite.build(&test_context(0)).is_none());
    }

    #[test]
    fn test_composite_provider_preserves_order() {
        let providers = vec![fixed("A"), fixed("B"), fixed("C")];
        let composite = CompositeSystemPromptProvider::new(providers);

        let prompt = composite.build(&test_context(0)).unwrap();

        assert_eq!(prompt.content, "A\n\nB\n\nC");
    }

    #[test]
    fn test_conditional_provider_respects_context() {
        let providers = vec![
            fixed("Always present"),
            conditional(5, "Conditional content"),
        ];
        let composite = CompositeSystemPromptProvider::new(providers);

        // Below the threshold the conditional provider stays silent.
        let below = composite.build(&test_context(3)).unwrap();
        assert_eq!(below.content, "Always present");

        // At the threshold the conditional provider contributes its content.
        let at_threshold = composite.build(&test_context(5)).unwrap();
        assert_eq!(
            at_threshold.content,
            "Always present\n\nConditional content"
        );
    }
}