Compare commits

...

2 Commits

15 changed files with 572 additions and 196 deletions

1
.gitignore vendored
View File

@ -10,3 +10,4 @@ Cargo.lock
PicoBot.code-workspace
.picobot
.claude
output

View File

@ -1,4 +1,5 @@
use crate::agent::AgentRuntimeConfig;
use crate::agent::{SystemPromptContext, SystemPromptProvider};
use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::domain::messages::{ContentBlock, ToolCall};
@ -508,7 +509,10 @@ pub struct AgentLoop {
runtime_config: AgentRuntimeConfig,
provider: Box<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
skills: Arc<dyn SkillProvider>,
/// 系统提示词提供者(统一注入 Agent 和 Skill 提示词)
system_prompt_provider: Option<Arc<dyn SystemPromptProvider>>,
/// Skill 提供者(用于匹配错误提示)
skills: Option<Arc<dyn SkillProvider>>,
tool_context: ToolContext,
observer: Option<Arc<dyn Observer>>,
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
@ -535,6 +539,7 @@ pub trait SkillProvider: Send + Sync + 'static {
}
#[derive(Default)]
#[allow(dead_code)]
struct EmptySkillProvider;
impl SkillProvider for EmptySkillProvider {
@ -554,7 +559,8 @@ impl AgentLoop {
runtime_config,
provider,
tools: Arc::new(ToolRegistry::new()),
skills: Arc::new(EmptySkillProvider),
system_prompt_provider: None,
skills: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
@ -575,7 +581,8 @@ impl AgentLoop {
runtime_config,
provider,
tools,
skills: Arc::new(EmptySkillProvider),
system_prompt_provider: None,
skills: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
@ -597,6 +604,34 @@ impl AgentLoop {
runtime_config,
provider,
tools,
system_prompt_provider: None,
skills: Some(skills),
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
/// 使用系统提示词提供者创建 AgentLoop
///
/// 这是新的推荐方式,支持统一注入 Agent 和 Skill 提示词。
pub fn with_tools_and_system_prompt_provider(
config: impl Into<AgentRuntimeConfig>,
tools: Arc<ToolRegistry>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
skills: Option<Arc<dyn SkillProvider>>,
) -> Result<Self, AgentError> {
let runtime_config = config.into();
let max_iterations = runtime_config.max_tool_iterations;
let provider = create_provider(runtime_config.provider.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
runtime_config,
provider,
tools,
system_prompt_provider: Some(system_prompt_provider),
skills,
tool_context: ToolContext::default(),
observer: None,
@ -632,9 +667,14 @@ impl AgentLoop {
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached
///
/// # 参数
/// - `messages`: 会话历史消息
/// - `system_prompt_context`: 系统提示词上下文(用于动态注入,可选)
pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
system_prompt_context: Option<&SystemPromptContext>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(
@ -660,9 +700,17 @@ impl AgentLoop {
};
let image_count = count_supported_image_media_refs(&messages);
// 构建系统提示词(统一注入 Agent 和 Skill 提示词)
let system_prompt = system_prompt_context.and_then(|ctx| {
self.system_prompt_provider
.as_ref()
.and_then(|provider| provider.build(ctx))
});
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 2);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
text_only_messages.push(Message::system(skill_prompt.clone()));
if let Some(ref prompt) = system_prompt {
text_only_messages.push(Message::system(prompt.content.clone()));
}
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
@ -673,8 +721,9 @@ impl AgentLoop {
);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 2);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
// 使用相同的系统提示词(已构建)
if let Some(ref prompt) = system_prompt {
messages_for_llm.push(Message::system(prompt.content.clone()));
}
messages_for_llm.extend(
messages
@ -863,19 +912,27 @@ impl AgentLoop {
);
messages.push(summary_request);
// Convert messages to LLM format
// Convert messages to LLM format (使用系统提示词提供者)
let image_count = count_supported_image_media_refs(&messages);
let mut text_only_messages: Vec<Message> = Vec::with_capacity(messages.len() + 1);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
text_only_messages.push(Message::system(skill_prompt));
if let Some(ref provider) = self.system_prompt_provider {
if let Some(ctx) = system_prompt_context {
if let Some(prompt) = provider.build(ctx) {
text_only_messages.push(Message::system(prompt.content.clone()));
}
}
}
text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message));
let image_tokens =
image_token_budget_for_request(&self.runtime_config, &text_only_messages, None);
let mut image_budget = ImageInlineBudget::new(image_tokens, image_count);
let mut messages_for_llm: Vec<Message> = Vec::with_capacity(messages.len() + 1);
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
if let Some(ref provider) = self.system_prompt_provider {
if let Some(ctx) = system_prompt_context {
if let Some(prompt) = provider.build(ctx) {
messages_for_llm.push(Message::system(prompt.content.clone()));
}
}
}
messages_for_llm.extend(
messages
@ -1025,7 +1082,9 @@ impl AgentLoop {
Some(t) => t,
None => {
tracing::warn!(tool = %tool_call.name, "Tool not found");
let skill_hint = self.skills.matching_skill_summary(&tool_call.name);
let skill_hint = self.skills
.as_ref()
.and_then(|s| s.matching_skill_summary(&tool_call.name));
let error = match skill_hint {
Some(summary) => format!(
"Tool '{}' not found. A skill with the same name exists: {}. Skills are not tools. Call skill_activate with {{\"name\": \"{}\"}} first.",

View File

@ -1,9 +1,13 @@
pub mod agent_loop;
pub mod context_compressor;
pub mod runtime_config;
pub mod system_prompt;
pub use agent_loop::{
AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler, SkillProvider,
};
pub use context_compressor::ContextCompressor;
pub use runtime_config::AgentRuntimeConfig;
pub use system_prompt::{
CompositeSystemPromptProvider, SystemPrompt, SystemPromptContext, SystemPromptProvider,
};

193
src/agent/system_prompt.rs Normal file
View File

@ -0,0 +1,193 @@
/// 系统提示词提供者的上下文
#[derive(Debug, Clone)]
pub struct SystemPromptContext {
/// 会话 ID
pub session_id: Option<String>,
/// 聊天 ID
pub chat_id: String,
/// 用户消息计数(用于判断是否重新注入)
pub user_message_count: usize,
}
/// 系统提示词结果
#[derive(Debug, Clone)]
pub struct SystemPrompt {
/// 提示词内容
pub content: String,
/// 上下文标记(如 "agent_prompt", "skill_index" 等)
pub context: Option<String>,
}
/// 系统提示词提供者 trait
///
/// 实现此 trait 可以为 AgentLoop 提供系统提示词内容。
/// 每次用户请求时动态构建,不持久化。
pub trait SystemPromptProvider: Send + Sync + 'static {
/// 构建系统提示词
///
/// 返回 `None` 表示此提供者没有内容要注入。
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt>;
}
/// 组合多个提供者的系统提示词
///
/// 按顺序调用所有提供者,合并非空内容为完整的系统提示词。
pub struct CompositeSystemPromptProvider {
providers: Vec<Box<dyn SystemPromptProvider>>,
}
impl CompositeSystemPromptProvider {
/// 创建新的组合提供者
pub fn new(providers: Vec<Box<dyn SystemPromptProvider>>) -> Self {
Self { providers }
}
/// 构建组合后的系统提示词
///
/// 按顺序收集所有非空提供者的内容,用 `\n\n` 连接。
pub fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
let fragments: Vec<String> = self
.providers
.iter()
.filter_map(|p| p.build(context))
.map(|p| p.content)
.collect();
if fragments.is_empty() {
None
} else {
Some(SystemPrompt {
content: fragments.join("\n\n"),
context: Some("combined_system_prompt".to_string()),
})
}
}
}
impl SystemPromptProvider for CompositeSystemPromptProvider {
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
self.build(context)
}
}
#[cfg(test)]
mod tests {
use super::*;
struct TestProvider {
content: &'static str,
}
impl SystemPromptProvider for TestProvider {
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
if self.content.is_empty() {
None
} else {
Some(SystemPrompt {
content: self.content.to_string(),
context: Some("test".to_string()),
})
}
}
}
struct ConditionalProvider {
user_message_count_threshold: usize,
content: &'static str,
}
impl SystemPromptProvider for ConditionalProvider {
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
if context.user_message_count >= self.user_message_count_threshold {
Some(SystemPrompt {
content: self.content.to_string(),
context: Some("conditional".to_string()),
})
} else {
None
}
}
}
fn test_context(user_count: usize) -> SystemPromptContext {
SystemPromptContext {
session_id: Some("test-session".to_string()),
chat_id: "test-chat".to_string(),
user_message_count: user_count,
}
}
#[test]
fn test_composite_provider_joins_multiple_sources() {
let composite = CompositeSystemPromptProvider::new(vec![
Box::new(TestProvider {
content: "First part",
}),
Box::new(TestProvider {
content: "Second part",
}),
]);
let result = composite.build(&test_context(0)).unwrap();
assert_eq!(result.content, "First part\n\nSecond part");
}
#[test]
fn test_composite_provider_skips_empty_sources() {
let composite = CompositeSystemPromptProvider::new(vec![
Box::new(TestProvider {
content: "First part",
}),
Box::new(TestProvider { content: "" }),
Box::new(TestProvider {
content: "Third part",
}),
]);
let result = composite.build(&test_context(0)).unwrap();
assert_eq!(result.content, "First part\n\nThird part");
}
#[test]
fn test_composite_provider_returns_none_for_all_empty() {
let composite = CompositeSystemPromptProvider::new(vec![
Box::new(TestProvider { content: "" }),
Box::new(TestProvider { content: "" }),
]);
assert!(composite.build(&test_context(0)).is_none());
}
#[test]
fn test_composite_provider_preserves_order() {
let composite = CompositeSystemPromptProvider::new(vec![
Box::new(TestProvider { content: "A" }),
Box::new(TestProvider { content: "B" }),
Box::new(TestProvider { content: "C" }),
]);
let result = composite.build(&test_context(0)).unwrap();
assert_eq!(result.content, "A\n\nB\n\nC");
}
#[test]
fn test_conditional_provider_respects_context() {
let composite = CompositeSystemPromptProvider::new(vec![
Box::new(TestProvider {
content: "Always present",
}),
Box::new(ConditionalProvider {
user_message_count_threshold: 5,
content: "Conditional content",
}),
]);
// User message count < 5, conditional provider returns None
let result1 = composite.build(&test_context(3)).unwrap();
assert_eq!(result1.content, "Always present");
// User message count >= 5, conditional provider returns Some
let result2 = composite.build(&test_context(5)).unwrap();
assert_eq!(result2.content, "Always present\n\nConditional content");
}
}

View File

@ -1,14 +1,19 @@
use std::sync::Arc;
use crate::agent::{AgentError, AgentLoop, SkillProvider};
use crate::agent::{AgentError, AgentLoop, CompositeSystemPromptProvider};
use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::skills::{SkillPromptProvider, SkillRuntime};
use crate::storage::persistent_session_id;
use crate::storage::PromptInjectionRepository;
use crate::tools::{ToolContext, ToolRegistry};
#[derive(Clone)]
pub(crate) struct AgentFactory {
tools: Arc<ToolRegistry>,
skills: Arc<dyn SkillProvider>,
skills: Arc<SkillRuntime>,
reinject_every: usize,
prompt_repository: Arc<dyn PromptInjectionRepository>,
}
pub(crate) struct AgentBuildRequest<'a> {
@ -21,16 +26,38 @@ pub(crate) struct AgentBuildRequest<'a> {
}
impl AgentFactory {
pub(crate) fn new(tools: Arc<ToolRegistry>, skills: Arc<dyn SkillProvider>) -> Self {
Self { tools, skills }
pub(crate) fn new(
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
reinject_every: usize,
prompt_repository: Arc<dyn PromptInjectionRepository>,
) -> Self {
Self {
tools,
skills,
reinject_every,
prompt_repository,
}
}
pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
let session_id = persistent_session_id(request.channel_name, request.session_chat_id);
AgentLoop::with_tools_and_skill_provider(
// 创建组合的系统提示词提供者
let system_prompt_provider = Arc::new(CompositeSystemPromptProvider::new(vec![
Box::new(AgentPromptProvider::new(
self.reinject_every,
request.provider_config.clone(),
self.prompt_repository.clone(),
)),
Box::new(SkillPromptProvider::new(self.skills.clone())),
]));
AgentLoop::with_tools_and_system_prompt_provider(
request.provider_config,
self.tools.clone(),
self.skills.clone(),
system_prompt_provider,
Some(self.skills.clone()),
)
.map(|agent| {
// notification_chat_id 优先,否则使用 session_chat_id

View File

@ -0,0 +1,175 @@
use std::sync::Arc;
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
use crate::config::LLMProviderConfig;
use crate::gateway::prompt::{generate_system_environment_prompt, load_agent_prompt};
use crate::storage::PromptInjectionRepository;
/// Agent 提示词提供者
///
/// 负责提供来自 AGENT.md、MEMORY_SUMMARY.md 和 builtin 的系统提示词,
/// 以及动态生成的系统环境信息。
pub struct AgentPromptProvider {
/// 重新注入间隔(用户消息数)
reinject_every: usize,
/// LLM 提供者配置(用于生成系统环境信息)
provider_config: LLMProviderConfig,
/// 会话持久化仓库(用于记录注入状态)
repository: Arc<dyn PromptInjectionRepository>,
}
impl AgentPromptProvider {
/// 创建新的 Agent 提示词提供者
pub fn new(
reinject_every: usize,
provider_config: LLMProviderConfig,
repository: Arc<dyn PromptInjectionRepository>,
) -> Self {
Self {
reinject_every,
provider_config,
repository,
}
}
/// 判断是否应该注入提示词
///
/// 规则:
/// - 首次对话user_message_count == 0始终注入
/// - 如果设置了 reinject_every > 0且满足间隔条件则重新注入
fn should_inject(&self, context: &SystemPromptContext) -> bool {
// 首次对话始终注入
if context.user_message_count == 0 {
return true;
}
// 检查是否需要重新注入
if self.reinject_every == 0 {
return false;
}
// 获取会话注入计数
let session_id = match &context.session_id {
Some(id) => id,
None => return false,
};
let reinjection_count = self
.repository
.get_session(session_id)
.ok()
.flatten()
.map(|session| session.agent_prompt_reinjection_count as usize)
.unwrap_or(0);
let expected_reinjections = context.user_message_count / self.reinject_every;
expected_reinjections > reinjection_count
}
/// 记录注入事件
fn record_injection(&self, context: &SystemPromptContext) {
if let Some(session_id) = &context.session_id {
let _ = self.repository.mark_agent_prompt_reinjected(session_id);
}
}
}
impl SystemPromptProvider for AgentPromptProvider {
fn build(&self, context: &SystemPromptContext) -> Option<SystemPrompt> {
// 检查是否需要注入
if !self.should_inject(context) {
return None;
}
// 加载 Agent 提示词AGENT.md + builtin + MEMORY_SUMMARY.md
let agent_prompt = load_agent_prompt().ok().flatten()?;
// 生成系统环境信息
let env_info = generate_system_environment_prompt(&self.provider_config);
// 记录注入事件
self.record_injection(context);
Some(SystemPrompt {
content: format!("{}\n\n{}", agent_prompt, env_info),
context: Some("agent_prompt".to_string()),
})
}
}
/// Agent 提示词提供者(简化版本,无持久化)
///
/// 适用于不需要记录注入状态的场景(如测试或一次性任务)。
pub struct SimpleAgentPromptProvider {
provider_config: LLMProviderConfig,
}
impl SimpleAgentPromptProvider {
/// 创建新的简单 Agent 提示词提供者
pub fn new(provider_config: LLMProviderConfig) -> Self {
Self { provider_config }
}
}
impl SystemPromptProvider for SimpleAgentPromptProvider {
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
// 加载 Agent 提示词
let agent_prompt = load_agent_prompt().ok().flatten()?;
// 生成系统环境信息
let env_info = generate_system_environment_prompt(&self.provider_config);
Some(SystemPrompt {
content: format!("{}\n\n{}", agent_prompt, env_info),
context: Some("agent_prompt".to_string()),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::SessionStore;
use std::collections::HashMap;
fn test_config() -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: "test".to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
memory_maintenance_timeout_secs: 600,
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
fn test_context(user_count: usize, session_id: Option<&str>) -> SystemPromptContext {
SystemPromptContext {
session_id: session_id.map(|s| s.to_string()),
chat_id: "test-chat".to_string(),
user_message_count: user_count,
}
}
#[test]
fn test_simple_provider_builds_prompt() {
let provider = SimpleAgentPromptProvider::new(test_config());
let context = test_context(0, None);
// Simple provider always returns content
let result = provider.build(&context).unwrap();
assert!(result.content.contains("PicoBot"));
assert!(result.content.contains("操作系统:"));
assert_eq!(result.context, Some("agent_prompt".to_string()));
}
}

View File

@ -118,13 +118,14 @@ mod tests {
.load_all_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.len(),
3,
// 新设计:系统提示词不再持久化,只有 1 条用户消息
1,
);
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 2);
assert_eq!(history[0].role, "system");
// 新设计:系统提示词不再持久化到历史记录
assert_eq!(history.len(), 0);
}
#[tokio::test]
@ -157,8 +158,8 @@ mod tests {
.ensure_agent_prompt_before_user_message("chat-1")
.unwrap();
// 新设计:系统提示词不再持久化到历史记录
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 2);
assert_eq!(history[0].role, "system");
assert_eq!(history.len(), 0);
}
}

View File

@ -1,7 +1,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::{AgentError, AgentProcessResult, EmittedMessageHandler};
use crate::agent::{AgentError, AgentProcessResult, EmittedMessageHandler, SystemPromptContext};
use crate::bus::message::ToolMessageState;
use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDULED_PROMPT};
use crate::config::LLMProviderConfig;
@ -132,7 +132,7 @@ impl AgentExecutionService {
&self,
request: MessageExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let (history, agent, user_message) = {
let (history, agent, user_message, user_message_count) = {
let mut session_guard = request.session.lock().await;
session_guard.ensure_persistent_session(request.chat_id)?;
@ -167,6 +167,7 @@ impl AgentExecutionService {
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
let history = session_guard.get_or_create_history(request.chat_id).clone();
let user_message_count = history.iter().filter(|m| m.role == "user").count();
session_guard.record_skill_offer(request.chat_id)?;
let mut agent = session_guard.create_agent(
@ -178,10 +179,17 @@ impl AgentExecutionService {
agent = agent.with_emitted_message_handler(handler);
}
(history, agent, user_message)
(history, agent, user_message, user_message_count)
};
let result = agent.process(history).await?;
// 构建系统提示词上下文
let system_prompt_context = SystemPromptContext {
session_id: Some(format!("{}:{}", request.channel_name, request.chat_id)),
chat_id: request.chat_id.to_string(),
user_message_count,
};
let result = agent.process(history, Some(&system_prompt_context)).await?;
let metadata = HashMap::new();
self.finalize_result_and_schedule_compaction(
@ -203,7 +211,7 @@ impl AgentExecutionService {
&self,
request: ScheduledExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let (history, agent, user_message) = {
let (history, agent, user_message, user_message_count) = {
let mut session_guard = request.session.lock().await;
session_guard.ensure_persistent_session(request.chat_id)?;
@ -229,6 +237,7 @@ impl AgentExecutionService {
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
let history = session_guard.get_or_create_history(request.chat_id).clone();
let user_message_count = history.iter().filter(|m| m.role == "user").count();
session_guard.record_skill_offer(request.chat_id)?;
let agent = session_guard.create_agent_with_provider_config(
@ -239,10 +248,17 @@ impl AgentExecutionService {
request.provider_config.clone(),
)?;
(history, agent, user_message)
(history, agent, user_message, user_message_count)
};
let result = agent.process(history).await?;
// 构建系统提示词上下文
let system_prompt_context = SystemPromptContext {
session_id: Some(format!("{}:{}", request.channel_name, request.chat_id)),
chat_id: request.chat_id.to_string(),
user_message_count,
};
let result = agent.process(history, Some(&system_prompt_context)).await?;
self.finalize_result_and_schedule_compaction(
request.session.clone(),

View File

@ -1,4 +1,5 @@
pub mod agent_factory;
pub mod agent_prompt_provider;
pub mod agent_task_executor;
pub mod cli_session;
pub mod command;
@ -11,7 +12,6 @@ pub mod message_prepare;
pub mod outbound_dispatcher;
pub mod processor;
pub mod prompt;
pub mod prompt_injector;
pub mod provider_config_service;
pub mod runtime;
pub mod scheduled_agent_task_service;

View File

@ -1,86 +0,0 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
use crate::storage::PromptInjectionRepository;
use super::prompt::load_agent_prompt;
#[derive(Clone)]
pub(crate) struct PromptInjector {
repository: Arc<dyn PromptInjectionRepository>,
reinject_every: i64,
}
impl PromptInjector {
pub(crate) fn new(repository: Arc<dyn PromptInjectionRepository>, reinject_every: u64) -> Self {
Self {
repository,
reinject_every: reinject_every as i64,
}
}
pub(crate) fn ensure_initial_prompt<F>(
&self,
history_is_empty: bool,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
if !history_is_empty {
return Ok(());
}
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
}
Ok(())
}
pub(crate) fn ensure_reinjected_prompt<F>(
&self,
session_id: &str,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
let session_record = self
.repository
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns = self
.repository
.count_active_user_messages(session_id)
.map_err(|err| {
AgentError::Other(format!("count active user messages error: {}", err))
})?;
if self.reinject_every > 0
&& active_user_turns > 0
&& active_user_turns / self.reinject_every
> session_record.agent_prompt_reinjection_count
{
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
self.repository
.mark_agent_prompt_reinjected(session_id)
.map_err(|err| {
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
})?;
}
}
Ok(())
}
fn agent_prompt_message(agent_prompt: String) -> ChatMessage {
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
)
}
}

View File

@ -14,7 +14,6 @@ use crate::tools::{NoopSessionMessageSender, SessionMessageSender, ToolRegistry}
use super::agent_factory::AgentFactory;
use super::cli_session::CliSessionService;
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
use super::prompt_injector::PromptInjector;
use super::provider_config_service::ProviderConfigService;
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
use super::session::{SessionManager, SessionManagerServices};
@ -86,15 +85,18 @@ pub(crate) fn build_session_manager_with_sender(
.build(),
);
let agent_factory = AgentFactory::new(tools.clone(), skills.clone());
let conversations: Arc<dyn ConversationRepository> = store.clone();
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
let prompt_injector = PromptInjector::new(prompt_repository, agent_prompt_reinject_every);
let agent_factory = AgentFactory::new(
tools.clone(),
skills.clone(),
agent_prompt_reinject_every as usize,
prompt_repository.clone(),
);
let conversations: Arc<dyn ConversationRepository> = store.clone();
let session_factory = SessionFactory::new(
provider_config.clone(),
skills.clone(),
agent_factory,
prompt_injector,
conversations,
skill_events,
chat_history_ttl_hours,

View File

@ -6,7 +6,7 @@ use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::scheduler::ScheduledAgentTaskOptions;
use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, SessionRecord, SessionStore, SkillEventRepository};
use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository};
use crate::tools::ToolRegistry;
use async_trait::async_trait;
use std::collections::HashMap;
@ -25,7 +25,6 @@ use super::memory_maintenance::{
};
use super::memory_maintenance::{MemoryMaintenanceScopeResult, MemoryOrganizationOutput};
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
use super::prompt_injector::PromptInjector;
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
use super::session_history::SessionHistory;
use super::session_lifecycle::SessionLifecycleService;
@ -102,17 +101,21 @@ impl Session {
agent_prompt_reinject_every: u64,
chat_history_ttl_hours: Option<u64>,
) -> Result<Self, AgentError> {
let agent_factory = AgentFactory::new(tools, skills.clone());
let conversations: Arc<dyn ConversationRepository> = store.clone();
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
let agent_factory = AgentFactory::new(
tools,
skills.clone(),
agent_prompt_reinject_every as usize,
prompt_repository.clone(),
);
Self::with_factories(
channel_name,
provider_config,
user_tx,
skills,
agent_factory,
prompt_injector,
conversations,
skill_events,
chat_history_ttl_hours,
@ -126,7 +129,6 @@ impl Session {
user_tx: mpsc::Sender<WsOutbound>,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
chat_history_ttl_hours: Option<u64>,
@ -141,10 +143,8 @@ impl Session {
compressor: ContextCompressor::from_provider_config(&provider_config),
history: SessionHistory::new(
channel_name,
prompt_injector,
conversations,
skill_events,
provider_config,
chat_history_ttl_hours,
),
})
@ -1562,9 +1562,8 @@ mod tests {
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 2);
assert_eq!(history[0].role, "system");
assert!(history[0].content.contains("PicoBot 代理配置"));
// 新设计:系统提示词不再持久化到历史记录,而是每次请求时动态注入
assert_eq!(history.len(), 0);
}
#[tokio::test]
@ -1611,28 +1610,32 @@ mod tests {
.ensure_agent_prompt_before_user_message("chat-1")
.unwrap();
// 新设计:系统提示词不再持久化到历史记录
let history = session.get_history("chat-1").unwrap();
let system_messages = history
let user_messages = history
.iter()
.filter(|message| message.role == "system")
.filter(|message| message.role == "user")
.count();
assert_eq!(system_messages, 3);
assert_eq!(user_messages, 100);
// 注入计数在实际处理请求时由 AgentPromptProvider 更新
// 此处仅为模拟调用,不会触发实际注入
let stored = store
.get_session(&session.persistent_session_id("chat-1"))
.unwrap()
.unwrap();
assert_eq!(stored.agent_prompt_reinjection_count, 1);
// 初始值为 0只有在实际 process 调用时才会更新
assert_eq!(stored.agent_prompt_reinjection_count, 0);
session
.ensure_agent_prompt_before_user_message("chat-1")
.unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history
let user_messages = history
.iter()
.filter(|message| message.role == "system")
.filter(|message| message.role == "user")
.count();
assert_eq!(system_messages, 3);
assert_eq!(user_messages, 100);
}
#[tokio::test]
@ -1679,12 +1682,13 @@ mod tests {
.ensure_agent_prompt_before_user_message("chat-1")
.unwrap();
// 新设计:系统提示词不再持久化到历史记录
let history = session.get_history("chat-1").unwrap();
let system_messages = history
let user_messages = history
.iter()
.filter(|message| message.role == "system")
.filter(|message| message.role == "user")
.count();
assert_eq!(system_messages, 2);
assert_eq!(user_messages, 100);
}
#[test]

View File

@ -9,7 +9,6 @@ use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, SkillEventRepository};
use super::agent_factory::AgentFactory;
use super::prompt_injector::PromptInjector;
use super::session::Session;
#[derive(Clone)]
@ -17,7 +16,6 @@ pub(crate) struct SessionFactory {
provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
chat_history_ttl_hours: Option<u64>,
@ -28,7 +26,6 @@ impl SessionFactory {
provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
chat_history_ttl_hours: Option<u64>,
@ -37,7 +34,6 @@ impl SessionFactory {
provider_config,
skills,
agent_factory,
prompt_injector,
conversations,
skill_events,
chat_history_ttl_hours,
@ -55,7 +51,6 @@ impl SessionFactory {
user_tx,
self.skills.clone(),
self.agent_factory.clone(),
self.prompt_injector.clone(),
self.conversations.clone(),
self.skill_events.clone(),
self.chat_history_ttl_hours,

View File

@ -3,14 +3,10 @@ use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::storage::{
ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id,
};
use super::prompt::generate_system_environment_prompt;
use super::prompt_injector::PromptInjector;
fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>();
if content.chars().count() > max_chars {
@ -30,30 +26,24 @@ pub(crate) struct SessionHistory {
channel_name: String,
chat_histories: HashMap<String, Vec<ChatMessage>>,
compression_in_flight: HashSet<String>,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
provider_config: LLMProviderConfig,
chat_history_ttl_hours: Option<u64>,
}
impl SessionHistory {
pub(crate) fn new(
channel_name: impl Into<String>,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
provider_config: LLMProviderConfig,
chat_history_ttl_hours: Option<u64>,
) -> Self {
Self {
channel_name: channel_name.into(),
chat_histories: HashMap::new(),
compression_in_flight: HashSet::new(),
prompt_injector,
conversations,
skill_events,
provider_config,
chat_history_ttl_hours,
}
}
@ -96,7 +86,7 @@ impl SessionHistory {
// 原有逻辑
if self.chat_histories.contains_key(chat_id) {
return self.ensure_initial_agent_prompt(chat_id);
return Ok(());
}
let history = self
@ -104,21 +94,15 @@ impl SessionHistory {
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
self.ensure_initial_agent_prompt(chat_id)?;
Ok(())
}
pub(crate) fn ensure_agent_prompt_before_user_message(
&mut self,
chat_id: &str,
_chat_id: &str,
) -> Result<(), AgentError> {
self.ensure_chat_loaded(chat_id)?;
let session_id = self.persistent_session_id(chat_id);
let prompt_injector = self.prompt_injector.clone();
prompt_injector.ensure_reinjected_prompt(&session_id, |message| {
self.append_persisted_message(chat_id, message)
})
// 提示词现在由 AgentPromptProvider 统一处理,不需要在此处注入
Ok(())
}
pub(crate) fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
@ -286,30 +270,4 @@ impl SessionHistory {
)
.map_err(|err| AgentError::Other(format!("append skill event error: {}", err)))
}
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history_is_empty = self
.get_history(chat_id)
.map(|history| history.is_empty())
.unwrap_or(true);
if !history_is_empty {
return Ok(());
}
// 注入 Agent Prompt
let prompt_injector = self.prompt_injector.clone();
prompt_injector.ensure_initial_prompt(history_is_empty, |message| {
self.append_persisted_message(chat_id, message)
})?;
// 注入系统环境提示词
let env_prompt = generate_system_environment_prompt(&self.provider_config);
self.append_persisted_message(
chat_id,
ChatMessage::system(env_prompt),
)?;
Ok(())
}
}

View File

@ -4,7 +4,7 @@ use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::RwLock;
use std::sync::{Arc, RwLock};
#[cfg(test)]
static SKILL_TEST_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
@ -865,6 +865,33 @@ fn split_frontmatter(content: &str) -> Option<(&str, &str)> {
// 使用 platform 模块提供的 xml_escape 和 path_to_uri 函数
// SkillPromptProvider 实现
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
/// Skill 提示词提供者
///
/// 负责提供技能的系统索引提示词system_index_prompt
pub struct SkillPromptProvider {
skills: Arc<SkillRuntime>,
}
impl SkillPromptProvider {
/// 创建新的 Skill 提示词提供者
pub fn new(skills: Arc<SkillRuntime>) -> Self {
Self { skills }
}
}
impl SystemPromptProvider for SkillPromptProvider {
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
// 调用 SkillRuntime 的 system_index_prompt 方法
self.skills.system_index_prompt().map(|content| SystemPrompt {
content,
context: Some("skill_index".to_string()),
})
}
}
#[cfg(test)]
mod tests {
use super::*;