194 lines
5.9 KiB
Rust
194 lines
5.9 KiB
Rust
/// Context handed to every [`SystemPromptProvider`] when a system prompt is built.
///
/// Derives `PartialEq`/`Eq` so contexts can be compared directly (e.g. in tests
/// or when checking whether a rebuild is needed).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SystemPromptContext {
    /// Session ID, if one is associated with this request.
    pub session_id: Option<String>,
    /// Chat ID.
    pub chat_id: String,
    /// Number of user messages so far; providers may use it to decide whether
    /// to (re-)inject their content.
    pub user_message_count: usize,
}
|
|
|
|
/// A system prompt fragment produced by a provider.
///
/// Derives `PartialEq`/`Eq` so produced prompts can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SystemPrompt {
    /// The prompt text.
    pub content: String,
    /// Optional tag identifying where the prompt came from
    /// (e.g. `"agent_prompt"`, `"skill_index"`).
    pub context: Option<String>,
}
|
|
|
|
/// 系统提示词提供者 trait
|
|
///
|
|
/// 实现此 trait 可以为 AgentLoop 提供系统提示词内容。
|
|
/// 每次用户请求时动态构建,不持久化。
|
|
pub trait SystemPromptProvider: Send + Sync + 'static {
|
|
/// 构建系统提示词
|
|
///
|
|
/// 返回 `None` 表示此提供者没有内容要注入。
|
|
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt>;
|
|
}
|
|
|
|
/// 组合多个提供者的系统提示词
|
|
///
|
|
/// 按顺序调用所有提供者,合并非空内容为完整的系统提示词。
|
|
pub struct CompositeSystemPromptProvider {
|
|
providers: Vec<Box<dyn SystemPromptProvider>>,
|
|
}
|
|
|
|
impl CompositeSystemPromptProvider {
|
|
/// 创建新的组合提供者
|
|
pub fn new(providers: Vec<Box<dyn SystemPromptProvider>>) -> Self {
|
|
Self { providers }
|
|
}
|
|
|
|
/// 构建组合后的系统提示词
|
|
///
|
|
/// 按顺序收集所有非空提供者的内容,用 `\n\n` 连接。
|
|
pub fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
|
|
let fragments: Vec<String> = self
|
|
.providers
|
|
.iter()
|
|
.filter_map(|p| p.build(context))
|
|
.map(|p| p.content)
|
|
.collect();
|
|
|
|
if fragments.is_empty() {
|
|
None
|
|
} else {
|
|
Some(SystemPrompt {
|
|
content: fragments.join("\n\n"),
|
|
context: Some("combined_system_prompt".to_string()),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
impl SystemPromptProvider for CompositeSystemPromptProvider {
|
|
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
|
|
self.build(context)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that always yields its fixed `content`, or `None` when it is empty.
    struct TestProvider {
        content: &'static str,
    }

    impl SystemPromptProvider for TestProvider {
        fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
            let has_content = !self.content.is_empty();
            has_content.then(|| SystemPrompt {
                content: self.content.to_string(),
                context: Some("test".to_string()),
            })
        }
    }

    /// Test double that only contributes once enough user messages have been seen.
    struct ConditionalProvider {
        user_message_count_threshold: usize,
        content: &'static str,
    }

    impl SystemPromptProvider for ConditionalProvider {
        fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
            let threshold_reached =
                context.user_message_count >= self.user_message_count_threshold;
            threshold_reached.then(|| SystemPrompt {
                content: self.content.to_string(),
                context: Some("conditional".to_string()),
            })
        }
    }

    /// Builds a context with fixed session/chat IDs and the given user message count.
    fn context_with_user_messages(user_count: usize) -> SystemPromptContext {
        SystemPromptContext {
            session_id: Some("test-session".to_string()),
            chat_id: "test-chat".to_string(),
            user_message_count: user_count,
        }
    }

    /// Boxes a [`TestProvider`] with the given content.
    fn fixed(content: &'static str) -> Box<dyn SystemPromptProvider> {
        Box::new(TestProvider { content })
    }

    /// Runs `composite` for a context with `user_count` user messages and returns
    /// the combined content, panicking if no prompt was produced.
    fn combined(composite: &CompositeSystemPromptProvider, user_count: usize) -> String {
        composite
            .build(&context_with_user_messages(user_count))
            .expect("composite should produce a prompt")
            .content
    }

    #[test]
    fn test_composite_provider_joins_multiple_sources() {
        let composite =
            CompositeSystemPromptProvider::new(vec![fixed("First part"), fixed("Second part")]);
        assert_eq!(combined(&composite, 0), "First part\n\nSecond part");
    }

    #[test]
    fn test_composite_provider_skips_empty_sources() {
        let composite = CompositeSystemPromptProvider::new(vec![
            fixed("First part"),
            fixed(""),
            fixed("Third part"),
        ]);
        assert_eq!(combined(&composite, 0), "First part\n\nThird part");
    }

    #[test]
    fn test_composite_provider_returns_none_for_all_empty() {
        let composite = CompositeSystemPromptProvider::new(vec![fixed(""), fixed("")]);
        let outcome = composite.build(&context_with_user_messages(0));
        assert!(outcome.is_none());
    }

    #[test]
    fn test_composite_provider_preserves_order() {
        let composite = CompositeSystemPromptProvider::new(
            ["A", "B", "C"].iter().copied().map(fixed).collect(),
        );
        assert_eq!(combined(&composite, 0), "A\n\nB\n\nC");
    }

    #[test]
    fn test_conditional_provider_respects_context() {
        let composite = CompositeSystemPromptProvider::new(vec![
            fixed("Always present"),
            Box::new(ConditionalProvider {
                user_message_count_threshold: 5,
                content: "Conditional content",
            }),
        ]);

        // Below the threshold (3 < 5) the conditional provider stays silent.
        assert_eq!(combined(&composite, 3), "Always present");

        // At the threshold (5 >= 5) it contributes after the unconditional one.
        assert_eq!(
            combined(&composite, 5),
            "Always present\n\nConditional content"
        );
    }
}
|