From a06fceaf0cccf695fe39e31dbf5b726fcbc382cf Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Wed, 13 May 2026 14:55:50 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E7=B3=BB=E7=BB=9F?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=E8=AF=8D=E6=8F=90=E4=BE=9B=E8=80=85=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=8A=A8=E6=80=81=E6=B3=A8=E5=85=A5=E5=92=8C?= =?UTF-8?q?=E7=BB=84=E5=90=88=E5=A4=9A=E4=B8=AA=E6=8F=90=E7=A4=BA=E8=AF=8D?= =?UTF-8?q?=E6=BA=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + src/agent/agent_loop.rs | 84 ++++++++++-- src/agent/mod.rs | 4 + src/agent/system_prompt.rs | 193 +++++++++++++++++++++++++++ src/gateway/agent_factory.rs | 42 +++++- src/gateway/agent_prompt_provider.rs | 175 ++++++++++++++++++++++++ src/gateway/command.rs | 11 +- src/gateway/execution.rs | 30 ++++- src/gateway/mod.rs | 1 + src/gateway/runtime.rs | 13 +- src/gateway/session.rs | 46 ++++--- src/gateway/session_factory.rs | 5 - src/gateway/session_history.rs | 46 +------ src/skills/mod.rs | 29 +++- 14 files changed, 576 insertions(+), 104 deletions(-) create mode 100644 src/agent/system_prompt.rs create mode 100644 src/gateway/agent_prompt_provider.rs diff --git a/.gitignore b/.gitignore index 16b6129..0244034 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ Cargo.lock PicoBot.code-workspace .picobot .claude +output diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index edddb38..5a298f3 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,4 +1,5 @@ use crate::agent::AgentRuntimeConfig; +use crate::agent::{SystemPromptContext, SystemPromptProvider}; use crate::bus::ChatMessage; use crate::bus::message::ToolMessageState; use crate::domain::messages::{ContentBlock, ToolCall}; @@ -508,7 +509,10 @@ pub struct AgentLoop { runtime_config: AgentRuntimeConfig, provider: Box, tools: Arc, - skills: Arc, + /// 系统提示词提供者(统一注入 Agent 和 Skill 提示词) + system_prompt_provider: Option>, + /// Skill 提供者(用于匹配错误提示) + skills: Option>, tool_context: ToolContext, observer: Option>, emitted_message_handler: Option>, @@ -554,7 +558,8 @@ impl AgentLoop { runtime_config, provider, tools: Arc::new(ToolRegistry::new()), - skills: Arc::new(EmptySkillProvider), + system_prompt_provider: None, + skills: None, tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, @@ -575,7 +580,8 @@ impl AgentLoop { runtime_config, provider, tools, - skills: Arc::new(EmptySkillProvider), + system_prompt_provider: None, + skills: None, tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, @@ -597,6 +603,34 @@ impl AgentLoop { runtime_config, provider, tools, + system_prompt_provider: None, + skills: Some(skills), + tool_context: ToolContext::default(), + observer: None, + emitted_message_handler: None, + max_iterations, + }) + } + + /// 使用系统提示词提供者创建 AgentLoop + /// + /// 这是新的推荐方式,支持统一注入 Agent 和 Skill 提示词。 + pub fn with_tools_and_system_prompt_provider( + config: impl Into, + tools: Arc, + system_prompt_provider: Arc, + skills: Option>, + ) -> Result { + let runtime_config = config.into(); + let max_iterations = runtime_config.max_tool_iterations; + let provider = create_provider(runtime_config.provider.clone()) + .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; + + Ok(Self { + runtime_config, + provider, + tools, + system_prompt_provider: Some(system_prompt_provider), skills, tool_context: ToolContext::default(), observer: None, @@ -632,9 +666,14 @@ impl AgentLoop { /// it loops back to the LLM with the tool results until either: /// - The LLM returns no more tool calls (final response) /// - Maximum iterations are reached + /// + /// # 参数 + /// - `messages`: 会话历史消息 + /// - `system_prompt_context`: 系统提示词上下文(用于动态注入,可选) pub async fn process( &self, mut messages: Vec, + system_prompt_context: Option<&SystemPromptContext>, ) -> Result { #[cfg(debug_assertions)] tracing::debug!( @@ -660,9 +699,17 @@ impl AgentLoop { }; let image_count = count_supported_image_media_refs(&messages); + + // 构建系统提示词(统一注入 Agent 和 Skill 提示词) + let system_prompt = system_prompt_context.and_then(|ctx| { + self.system_prompt_provider + .as_ref() + .and_then(|provider| provider.build(ctx)) + }); + let mut text_only_messages: Vec = Vec::with_capacity(messages.len() + 2); - if let Some(skill_prompt) = self.skills.system_index_prompt() { - text_only_messages.push(Message::system(skill_prompt.clone())); + if let Some(ref prompt) = system_prompt { + text_only_messages.push(Message::system(prompt.content.clone())); } text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message)); @@ -673,8 +720,9 @@ impl AgentLoop { ); let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 2); - if let Some(skill_prompt) = self.skills.system_index_prompt() { - messages_for_llm.push(Message::system(skill_prompt)); + // 使用相同的系统提示词(已构建) + if let Some(ref prompt) = system_prompt { + messages_for_llm.push(Message::system(prompt.content.clone())); } messages_for_llm.extend( messages @@ -863,19 +911,27 @@ impl AgentLoop { ); messages.push(summary_request); - // Convert messages to LLM format + // Convert messages to LLM format (使用系统提示词提供者) let image_count = count_supported_image_media_refs(&messages); let mut text_only_messages: Vec = Vec::with_capacity(messages.len() + 1); - if let Some(skill_prompt) = self.skills.system_index_prompt() { - text_only_messages.push(Message::system(skill_prompt)); + if let Some(ref provider) = self.system_prompt_provider { + if let Some(ctx) = system_prompt_context { + if let Some(prompt) = provider.build(ctx) { + text_only_messages.push(Message::system(prompt.content.clone())); + } + } } text_only_messages.extend(messages.iter().map(chat_message_to_text_only_llm_message)); let image_tokens = image_token_budget_for_request(&self.runtime_config, &text_only_messages, None); let mut image_budget = ImageInlineBudget::new(image_tokens, image_count); let mut messages_for_llm: Vec = Vec::with_capacity(messages.len() + 1); - if let Some(skill_prompt) = self.skills.system_index_prompt() { - messages_for_llm.push(Message::system(skill_prompt)); + if let Some(ref provider) = self.system_prompt_provider { + if let Some(ctx) = system_prompt_context { + if let Some(prompt) = provider.build(ctx) { + messages_for_llm.push(Message::system(prompt.content.clone())); + } + } } messages_for_llm.extend( messages @@ -1025,7 +1081,9 @@ impl AgentLoop { Some(t) => t, None => { tracing::warn!(tool = %tool_call.name, "Tool not found"); - let skill_hint = self.skills.matching_skill_summary(&tool_call.name); + let skill_hint = self.skills + .as_ref() + .and_then(|s| s.matching_skill_summary(&tool_call.name)); let error = match skill_hint { Some(summary) => format!( "Tool '{}' not found. A skill with the same name exists: {}. Skills are not tools. Call skill_activate with {{\"name\": \"{}\"}} first.", diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 686bb55..1150691 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,9 +1,13 @@ pub mod agent_loop; pub mod context_compressor; pub mod runtime_config; +pub mod system_prompt; pub use agent_loop::{ AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler, SkillProvider, }; pub use context_compressor::ContextCompressor; pub use runtime_config::AgentRuntimeConfig; +pub use system_prompt::{ + CompositeSystemPromptProvider, SystemPrompt, SystemPromptContext, SystemPromptProvider, +}; diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs new file mode 100644 index 0000000..04ebd62 --- /dev/null +++ b/src/agent/system_prompt.rs @@ -0,0 +1,193 @@ +/// 系统提示词提供者的上下文 +#[derive(Debug, Clone)] +pub struct SystemPromptContext { + /// 会话 ID + pub session_id: Option, + /// 聊天 ID + pub chat_id: String, + /// 用户消息计数(用于判断是否重新注入) + pub user_message_count: usize, +} + +/// 系统提示词结果 +#[derive(Debug, Clone)] +pub struct SystemPrompt { + /// 提示词内容 + pub content: String, + /// 上下文标记(如 "agent_prompt", "skill_index" 等) + pub context: Option, +} + +/// 系统提示词提供者 trait +/// +/// 实现此 trait 可以为 AgentLoop 提供系统提示词内容。 +/// 每次用户请求时动态构建,不持久化。 +pub trait SystemPromptProvider: Send + Sync + 'static { + /// 构建系统提示词 + /// + /// 返回 `None` 表示此提供者没有内容要注入。 + fn build(&self, context: &SystemPromptContext) -> Option; +} + +/// 组合多个提供者的系统提示词 +/// +/// 按顺序调用所有提供者,合并非空内容为完整的系统提示词。 +pub struct CompositeSystemPromptProvider { + providers: Vec>, +} + +impl CompositeSystemPromptProvider { + /// 创建新的组合提供者 + pub fn new(providers: Vec>) -> Self { + Self { providers } + } + + /// 构建组合后的系统提示词 + /// + /// 按顺序收集所有非空提供者的内容,用 `\n\n` 连接。 + pub fn build(&self, context: &SystemPromptContext) -> Option { + let fragments: Vec = self + .providers + .iter() + .filter_map(|p| p.build(context)) + .map(|p| p.content) + .collect(); + + if fragments.is_empty() { + None + } else { + Some(SystemPrompt { + content: fragments.join("\n\n"), + context: Some("combined_system_prompt".to_string()), + }) + } + } +} + +impl SystemPromptProvider for CompositeSystemPromptProvider { + fn build(&self, context: &SystemPromptContext) -> Option { + self.build(context) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestProvider { + content: &'static str, + } + + impl SystemPromptProvider for TestProvider { + fn build(&self, _context: &SystemPromptContext) -> Option { + if self.content.is_empty() { + None + } else { + Some(SystemPrompt { + content: self.content.to_string(), + context: Some("test".to_string()), + }) + } + } + } + + struct ConditionalProvider { + user_message_count_threshold: usize, + content: &'static str, + } + + impl SystemPromptProvider for ConditionalProvider { + fn build(&self, context: &SystemPromptContext) -> Option { + if context.user_message_count >= self.user_message_count_threshold { + Some(SystemPrompt { + content: self.content.to_string(), + context: Some("conditional".to_string()), + }) + } else { + None + } + } + } + + fn test_context(user_count: usize) -> SystemPromptContext { + SystemPromptContext { + session_id: Some("test-session".to_string()), + chat_id: "test-chat".to_string(), + user_message_count: user_count, + } + } + + #[test] + fn test_composite_provider_joins_multiple_sources() { + let composite = CompositeSystemPromptProvider::new(vec![ + Box::new(TestProvider { + content: "First part", + }), + Box::new(TestProvider { + content: "Second part", + }), + ]); + + let result = composite.build(&test_context(0)).unwrap(); + assert_eq!(result.content, "First part\n\nSecond part"); + } + + #[test] + fn test_composite_provider_skips_empty_sources() { + let composite = CompositeSystemPromptProvider::new(vec![ + Box::new(TestProvider { + content: "First part", + }), + Box::new(TestProvider { content: "" }), + Box::new(TestProvider { + content: "Third part", + }), + ]); + + let result = composite.build(&test_context(0)).unwrap(); + assert_eq!(result.content, "First part\n\nThird part"); + } + + #[test] + fn test_composite_provider_returns_none_for_all_empty() { + let composite = CompositeSystemPromptProvider::new(vec![ + Box::new(TestProvider { content: "" }), + Box::new(TestProvider { content: "" }), + ]); + + assert!(composite.build(&test_context(0)).is_none()); + } + + #[test] + fn test_composite_provider_preserves_order() { + let composite = CompositeSystemPromptProvider::new(vec![ + Box::new(TestProvider { content: "A" }), + Box::new(TestProvider { content: "B" }), + Box::new(TestProvider { content: "C" }), + ]); + + let result = composite.build(&test_context(0)).unwrap(); + assert_eq!(result.content, "A\n\nB\n\nC"); + } + + #[test] + fn test_conditional_provider_respects_context() { + let composite = CompositeSystemPromptProvider::new(vec![ + Box::new(TestProvider { + content: "Always present", + }), + Box::new(ConditionalProvider { + user_message_count_threshold: 5, + content: "Conditional content", + }), + ]); + + // User message count < 5, conditional provider returns None + let result1 = composite.build(&test_context(3)).unwrap(); + assert_eq!(result1.content, "Always present"); + + // User message count >= 5, conditional provider returns Some + let result2 = composite.build(&test_context(5)).unwrap(); + assert_eq!(result2.content, "Always present\n\nConditional content"); + } +} diff --git a/src/gateway/agent_factory.rs b/src/gateway/agent_factory.rs index dc3e35d..6f8030f 100644 --- a/src/gateway/agent_factory.rs +++ b/src/gateway/agent_factory.rs @@ -1,14 +1,20 @@ use std::sync::Arc; -use crate::agent::{AgentError, AgentLoop, SkillProvider}; +use crate::agent::{AgentError, AgentLoop, CompositeSystemPromptProvider}; use crate::config::LLMProviderConfig; +use crate::gateway::agent_prompt_provider::AgentPromptProvider; +use crate::skills::{SkillPromptProvider, SkillRuntime}; use crate::storage::persistent_session_id; +use crate::storage::PromptInjectionRepository; use crate::tools::{ToolContext, ToolRegistry}; #[derive(Clone)] pub(crate) struct AgentFactory { tools: Arc, - skills: Arc, + skills: Arc, + provider_config: LLMProviderConfig, + reinject_every: usize, + prompt_repository: Arc, } pub(crate) struct AgentBuildRequest<'a> { @@ -21,16 +27,40 @@ pub(crate) struct AgentBuildRequest<'a> { } impl AgentFactory { - pub(crate) fn new(tools: Arc, skills: Arc) -> Self { - Self { tools, skills } + pub(crate) fn new( + tools: Arc, + skills: Arc, + provider_config: LLMProviderConfig, + reinject_every: usize, + prompt_repository: Arc, + ) -> Self { + Self { + tools, + skills, + provider_config, + reinject_every, + prompt_repository, + } } pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result { let session_id = persistent_session_id(request.channel_name, request.session_chat_id); - AgentLoop::with_tools_and_skill_provider( + + // 创建组合的系统提示词提供者 + let system_prompt_provider = Arc::new(CompositeSystemPromptProvider::new(vec![ + Box::new(AgentPromptProvider::new( + self.reinject_every, + request.provider_config.clone(), + self.prompt_repository.clone(), + )), + Box::new(SkillPromptProvider::new(self.skills.clone())), + ])); + + AgentLoop::with_tools_and_system_prompt_provider( request.provider_config, self.tools.clone(), - self.skills.clone(), + system_prompt_provider, + Some(self.skills.clone()), ) .map(|agent| { // notification_chat_id 优先,否则使用 session_chat_id diff --git a/src/gateway/agent_prompt_provider.rs b/src/gateway/agent_prompt_provider.rs new file mode 100644 index 0000000..6f8e857 --- /dev/null +++ b/src/gateway/agent_prompt_provider.rs @@ -0,0 +1,175 @@ +use std::sync::Arc; + +use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; +use crate::config::LLMProviderConfig; +use crate::gateway::prompt::{generate_system_environment_prompt, load_agent_prompt}; +use crate::storage::PromptInjectionRepository; + +/// Agent 提示词提供者 +/// +/// 负责提供来自 AGENT.md、MEMORY_SUMMARY.md 和 builtin 的系统提示词, +/// 以及动态生成的系统环境信息。 +pub struct AgentPromptProvider { + /// 重新注入间隔(用户消息数) + reinject_every: usize, + /// LLM 提供者配置(用于生成系统环境信息) + provider_config: LLMProviderConfig, + /// 会话持久化仓库(用于记录注入状态) + repository: Arc, +} + +impl AgentPromptProvider { + /// 创建新的 Agent 提示词提供者 + pub fn new( + reinject_every: usize, + provider_config: LLMProviderConfig, + repository: Arc, + ) -> Self { + Self { + reinject_every, + provider_config, + repository, + } + } + + /// 判断是否应该注入提示词 + /// + /// 规则: + /// - 首次对话(user_message_count == 0)始终注入 + /// - 如果设置了 reinject_every > 0,且满足间隔条件,则重新注入 + fn should_inject(&self, context: &SystemPromptContext) -> bool { + // 首次对话始终注入 + if context.user_message_count == 0 { + return true; + } + + // 检查是否需要重新注入 + if self.reinject_every == 0 { + return false; + } + + // 获取会话注入计数 + let session_id = match &context.session_id { + Some(id) => id, + None => return false, + }; + + let reinjection_count = self + .repository + .get_session(session_id) + .ok() + .flatten() + .map(|session| session.agent_prompt_reinjection_count as usize) + .unwrap_or(0); + + let expected_reinjections = context.user_message_count / self.reinject_every; + + expected_reinjections > reinjection_count + } + + /// 记录注入事件 + fn record_injection(&self, context: &SystemPromptContext) { + if let Some(session_id) = &context.session_id { + let _ = self.repository.mark_agent_prompt_reinjected(session_id); + } + } +} + +impl SystemPromptProvider for AgentPromptProvider { + fn build(&self, context: &SystemPromptContext) -> Option { + // 检查是否需要注入 + if !self.should_inject(context) { + return None; + } + + // 加载 Agent 提示词(AGENT.md + builtin + MEMORY_SUMMARY.md) + let agent_prompt = load_agent_prompt().ok().flatten()?; + + // 生成系统环境信息 + let env_info = generate_system_environment_prompt(&self.provider_config); + + // 记录注入事件 + self.record_injection(context); + + Some(SystemPrompt { + content: format!("{}\n\n{}", agent_prompt, env_info), + context: Some("agent_prompt".to_string()), + }) + } +} + +/// Agent 提示词提供者(简化版本,无持久化) +/// +/// 适用于不需要记录注入状态的场景(如测试或一次性任务)。 +pub struct SimpleAgentPromptProvider { + provider_config: LLMProviderConfig, +} + +impl SimpleAgentPromptProvider { + /// 创建新的简单 Agent 提示词提供者 + pub fn new(provider_config: LLMProviderConfig) -> Self { + Self { provider_config } + } +} + +impl SystemPromptProvider for SimpleAgentPromptProvider { + fn build(&self, _context: &SystemPromptContext) -> Option { + // 加载 Agent 提示词 + let agent_prompt = load_agent_prompt().ok().flatten()?; + + // 生成系统环境信息 + let env_info = generate_system_environment_prompt(&self.provider_config); + + Some(SystemPrompt { + content: format!("{}\n\n{}", agent_prompt, env_info), + context: Some("agent_prompt".to_string()), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::SessionStore; + use std::collections::HashMap; + + fn test_config() -> LLMProviderConfig { + LLMProviderConfig { + provider_type: "openai".to_string(), + name: "test".to_string(), + base_url: "http://localhost".to_string(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, + model_id: "test-model".to_string(), + temperature: Some(0.0), + max_tokens: Some(32), + context_window_tokens: None, + model_extra: HashMap::new(), + max_tool_iterations: 1, + tool_result_max_chars: 20_000, + context_tool_result_trim_chars: 20_000, + } + } + + fn test_context(user_count: usize, session_id: Option<&str>) -> SystemPromptContext { + SystemPromptContext { + session_id: session_id.map(|s| s.to_string()), + chat_id: "test-chat".to_string(), + user_message_count: user_count, + } + } + + #[test] + fn test_simple_provider_builds_prompt() { + let provider = SimpleAgentPromptProvider::new(test_config()); + let context = test_context(0, None); + + // Simple provider always returns content + let result = provider.build(&context).unwrap(); + assert!(result.content.contains("PicoBot")); + assert!(result.content.contains("操作系统:")); + assert_eq!(result.context, Some("agent_prompt".to_string())); + } +} diff --git a/src/gateway/command.rs b/src/gateway/command.rs index 7100b7b..03ad8f0 100644 --- a/src/gateway/command.rs +++ b/src/gateway/command.rs @@ -118,13 +118,14 @@ mod tests { .load_all_messages(&session.persistent_session_id("chat-1")) .unwrap() .len(), - 3, + // 新设计:系统提示词不再持久化,只有 1 条用户消息 + 1, ); session.ensure_chat_loaded("chat-1").unwrap(); let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 2); - assert_eq!(history[0].role, "system"); + // 新设计:系统提示词不再持久化到历史记录 + assert_eq!(history.len(), 0); } #[tokio::test] @@ -157,8 +158,8 @@ mod tests { .ensure_agent_prompt_before_user_message("chat-1") .unwrap(); + // 新设计:系统提示词不再持久化到历史记录 let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 2); - assert_eq!(history[0].role, "system"); + assert_eq!(history.len(), 0); } } diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index 2530b36..f525480 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::agent::{AgentError, AgentProcessResult, EmittedMessageHandler}; +use crate::agent::{AgentError, AgentProcessResult, EmittedMessageHandler, SystemPromptContext}; use crate::bus::message::ToolMessageState; use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDULED_PROMPT}; use crate::config::LLMProviderConfig; @@ -132,7 +132,7 @@ impl AgentExecutionService { &self, request: MessageExecutionRequest<'_>, ) -> Result, AgentError> { - let (history, agent, user_message) = { + let (history, agent, user_message, user_message_count) = { let mut session_guard = request.session.lock().await; session_guard.ensure_persistent_session(request.chat_id)?; @@ -167,6 +167,7 @@ impl AgentExecutionService { session_guard.append_persisted_message(request.chat_id, user_message.clone())?; let history = session_guard.get_or_create_history(request.chat_id).clone(); + let user_message_count = history.iter().filter(|m| m.role == "user").count(); session_guard.record_skill_offer(request.chat_id)?; let mut agent = session_guard.create_agent( @@ -178,10 +179,17 @@ impl AgentExecutionService { agent = agent.with_emitted_message_handler(handler); } - (history, agent, user_message) + (history, agent, user_message, user_message_count) }; - let result = agent.process(history).await?; + // 构建系统提示词上下文 + let system_prompt_context = SystemPromptContext { + session_id: Some(format!("{}:{}", request.channel_name, request.chat_id)), + chat_id: request.chat_id.to_string(), + user_message_count, + }; + + let result = agent.process(history, Some(&system_prompt_context)).await?; let metadata = HashMap::new(); self.finalize_result_and_schedule_compaction( @@ -203,7 +211,7 @@ impl AgentExecutionService { &self, request: ScheduledExecutionRequest<'_>, ) -> Result, AgentError> { - let (history, agent, user_message) = { + let (history, agent, user_message, user_message_count) = { let mut session_guard = request.session.lock().await; session_guard.ensure_persistent_session(request.chat_id)?; @@ -229,6 +237,7 @@ impl AgentExecutionService { session_guard.append_persisted_message(request.chat_id, user_message.clone())?; let history = session_guard.get_or_create_history(request.chat_id).clone(); + let user_message_count = history.iter().filter(|m| m.role == "user").count(); session_guard.record_skill_offer(request.chat_id)?; let agent = session_guard.create_agent_with_provider_config( @@ -239,10 +248,17 @@ impl AgentExecutionService { request.provider_config.clone(), )?; - (history, agent, user_message) + (history, agent, user_message, user_message_count) }; - let result = agent.process(history).await?; + // 构建系统提示词上下文 + let system_prompt_context = SystemPromptContext { + session_id: Some(format!("{}:{}", request.channel_name, request.chat_id)), + chat_id: request.chat_id.to_string(), + user_message_count, + }; + + let result = agent.process(history, Some(&system_prompt_context)).await?; self.finalize_result_and_schedule_compaction( request.session.clone(), diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index aace5cb..2eed0bd 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,4 +1,5 @@ pub mod agent_factory; +pub mod agent_prompt_provider; pub mod agent_task_executor; pub mod cli_session; pub mod command; diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index 3652ba3..aba2be4 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -14,7 +14,6 @@ use crate::tools::{NoopSessionMessageSender, SessionMessageSender, ToolRegistry} use super::agent_factory::AgentFactory; use super::cli_session::CliSessionService; use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator; -use super::prompt_injector::PromptInjector; use super::provider_config_service::ProviderConfigService; use super::scheduled_agent_task_service::ScheduledAgentTaskService; use super::session::{SessionManager, SessionManagerServices}; @@ -86,15 +85,19 @@ pub(crate) fn build_session_manager_with_sender( .build(), ); - let agent_factory = AgentFactory::new(tools.clone(), skills.clone()); - let conversations: Arc = store.clone(); let prompt_repository: Arc = store.clone(); - let prompt_injector = PromptInjector::new(prompt_repository, agent_prompt_reinject_every); + let agent_factory = AgentFactory::new( + tools.clone(), + skills.clone(), + provider_config.clone(), + agent_prompt_reinject_every as usize, + prompt_repository.clone(), + ); + let conversations: Arc = store.clone(); let session_factory = SessionFactory::new( provider_config.clone(), skills.clone(), agent_factory, - prompt_injector, conversations, skill_events, chat_history_ttl_hours, diff --git a/src/gateway/session.rs b/src/gateway/session.rs index e0e5a5f..be63512 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -6,7 +6,7 @@ use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; use crate::scheduler::ScheduledAgentTaskOptions; use crate::skills::SkillRuntime; -use crate::storage::{ConversationRepository, SessionRecord, SessionStore, SkillEventRepository}; +use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository}; use crate::tools::ToolRegistry; use async_trait::async_trait; use std::collections::HashMap; @@ -25,7 +25,6 @@ use super::memory_maintenance::{ }; use super::memory_maintenance::{MemoryMaintenanceScopeResult, MemoryOrganizationOutput}; use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator; -use super::prompt_injector::PromptInjector; use super::scheduled_agent_task_service::ScheduledAgentTaskService; use super::session_history::SessionHistory; use super::session_lifecycle::SessionLifecycleService; @@ -102,17 +101,22 @@ impl Session { agent_prompt_reinject_every: u64, chat_history_ttl_hours: Option, ) -> Result { - let agent_factory = AgentFactory::new(tools, skills.clone()); let conversations: Arc = store.clone(); let skill_events: Arc = store.clone(); - let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every); + let prompt_repository: Arc = store.clone(); + let agent_factory = AgentFactory::new( + tools, + skills.clone(), + provider_config.clone(), + agent_prompt_reinject_every as usize, + prompt_repository.clone(), + ); Self::with_factories( channel_name, provider_config, user_tx, skills, agent_factory, - prompt_injector, conversations, skill_events, chat_history_ttl_hours, @@ -126,7 +130,6 @@ impl Session { user_tx: mpsc::Sender, skills: Arc, agent_factory: AgentFactory, - prompt_injector: PromptInjector, conversations: Arc, skill_events: Arc, chat_history_ttl_hours: Option, @@ -141,7 +144,6 @@ impl Session { compressor: ContextCompressor::from_provider_config(&provider_config), history: SessionHistory::new( channel_name, - prompt_injector, conversations, skill_events, provider_config, @@ -1562,9 +1564,8 @@ mod tests { session.ensure_chat_loaded("chat-1").unwrap(); let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 2); - assert_eq!(history[0].role, "system"); - assert!(history[0].content.contains("PicoBot 代理配置")); + // 新设计:系统提示词不再持久化到历史记录,而是每次请求时动态注入 + assert_eq!(history.len(), 0); } #[tokio::test] @@ -1611,28 +1612,32 @@ mod tests { .ensure_agent_prompt_before_user_message("chat-1") .unwrap(); + // 新设计:系统提示词不再持久化到历史记录 let history = session.get_history("chat-1").unwrap(); - let system_messages = history + let user_messages = history .iter() - .filter(|message| message.role == "system") + .filter(|message| message.role == "user") .count(); - assert_eq!(system_messages, 3); + assert_eq!(user_messages, 100); + // 注入计数在实际处理请求时由 AgentPromptProvider 更新 + // 此处仅为模拟调用,不会触发实际注入 let stored = store .get_session(&session.persistent_session_id("chat-1")) .unwrap() .unwrap(); - assert_eq!(stored.agent_prompt_reinjection_count, 1); + // 初始值为 0,只有在实际 process 调用时才会更新 + assert_eq!(stored.agent_prompt_reinjection_count, 0); session .ensure_agent_prompt_before_user_message("chat-1") .unwrap(); let history = session.get_history("chat-1").unwrap(); - let system_messages = history + let user_messages = history .iter() - .filter(|message| message.role == "system") + .filter(|message| message.role == "user") .count(); - assert_eq!(system_messages, 3); + assert_eq!(user_messages, 100); } #[tokio::test] @@ -1679,12 +1684,13 @@ mod tests { .ensure_agent_prompt_before_user_message("chat-1") .unwrap(); + // 新设计:系统提示词不再持久化到历史记录 let history = session.get_history("chat-1").unwrap(); - let system_messages = history + let user_messages = history .iter() - .filter(|message| message.role == "system") + .filter(|message| message.role == "user") .count(); - assert_eq!(system_messages, 2); + assert_eq!(user_messages, 100); } #[test] diff --git a/src/gateway/session_factory.rs b/src/gateway/session_factory.rs index 741e98d..e9670d3 100644 --- a/src/gateway/session_factory.rs +++ b/src/gateway/session_factory.rs @@ -9,7 +9,6 @@ use crate::skills::SkillRuntime; use crate::storage::{ConversationRepository, SkillEventRepository}; use super::agent_factory::AgentFactory; -use super::prompt_injector::PromptInjector; use super::session::Session; #[derive(Clone)] @@ -17,7 +16,6 @@ pub(crate) struct SessionFactory { provider_config: LLMProviderConfig, skills: Arc, agent_factory: AgentFactory, - prompt_injector: PromptInjector, conversations: Arc, skill_events: Arc, chat_history_ttl_hours: Option, @@ -28,7 +26,6 @@ impl SessionFactory { provider_config: LLMProviderConfig, skills: Arc, agent_factory: AgentFactory, - prompt_injector: PromptInjector, conversations: Arc, skill_events: Arc, chat_history_ttl_hours: Option, @@ -37,7 +34,6 @@ impl SessionFactory { provider_config, skills, agent_factory, - prompt_injector, conversations, skill_events, chat_history_ttl_hours, @@ -55,7 +51,6 @@ impl SessionFactory { user_tx, self.skills.clone(), self.agent_factory.clone(), - self.prompt_injector.clone(), self.conversations.clone(), self.skill_events.clone(), self.chat_history_ttl_hours, diff --git a/src/gateway/session_history.rs b/src/gateway/session_history.rs index 90247c8..59c6073 100644 --- a/src/gateway/session_history.rs +++ b/src/gateway/session_history.rs @@ -8,9 +8,6 @@ use crate::storage::{ ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id, }; -use super::prompt::generate_system_environment_prompt; -use super::prompt_injector::PromptInjector; - fn preview_text(content: &str, max_chars: usize) -> String { let mut preview = content.chars().take(max_chars).collect::(); if content.chars().count() > max_chars { @@ -30,7 +27,6 @@ pub(crate) struct SessionHistory { channel_name: String, chat_histories: HashMap>, compression_in_flight: HashSet, - prompt_injector: PromptInjector, conversations: Arc, skill_events: Arc, provider_config: LLMProviderConfig, @@ -40,7 +36,6 @@ pub(crate) struct SessionHistory { impl SessionHistory { pub(crate) fn new( channel_name: impl Into, - prompt_injector: PromptInjector, conversations: Arc, skill_events: Arc, provider_config: LLMProviderConfig, @@ -50,7 +45,6 @@ impl SessionHistory { channel_name: channel_name.into(), chat_histories: HashMap::new(), compression_in_flight: HashSet::new(), - prompt_injector, conversations, skill_events, provider_config, @@ -96,7 +90,7 @@ impl SessionHistory { // 原有逻辑 if self.chat_histories.contains_key(chat_id) { - return self.ensure_initial_agent_prompt(chat_id); + return Ok(()); } let history = self @@ -104,21 +98,15 @@ impl SessionHistory { .load_messages(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?; self.chat_histories.insert(chat_id.to_string(), history); - self.ensure_initial_agent_prompt(chat_id)?; Ok(()) } pub(crate) fn ensure_agent_prompt_before_user_message( &mut self, - chat_id: &str, + _chat_id: &str, ) -> Result<(), AgentError> { - self.ensure_chat_loaded(chat_id)?; - - let session_id = self.persistent_session_id(chat_id); - let prompt_injector = self.prompt_injector.clone(); - prompt_injector.ensure_reinjected_prompt(&session_id, |message| { - self.append_persisted_message(chat_id, message) - }) + // 提示词现在由 AgentPromptProvider 统一处理,不需要在此处注入 + Ok(()) } pub(crate) fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec { @@ -286,30 +274,4 @@ impl SessionHistory { ) .map_err(|err| AgentError::Other(format!("append skill event error: {}", err))) } - - fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> { - let history_is_empty = self - .get_history(chat_id) - .map(|history| history.is_empty()) - .unwrap_or(true); - - if !history_is_empty { - return Ok(()); - } - - // 注入 Agent Prompt - let prompt_injector = self.prompt_injector.clone(); - prompt_injector.ensure_initial_prompt(history_is_empty, |message| { - self.append_persisted_message(chat_id, message) - })?; - - // 注入系统环境提示词 - let env_prompt = generate_system_environment_prompt(&self.provider_config); - self.append_persisted_message( - chat_id, - ChatMessage::system(env_prompt), - )?; - - Ok(()) - } } diff --git a/src/skills/mod.rs b/src/skills/mod.rs index 280ed9b..da069ce 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -4,7 +4,7 @@ use serde_json::json; use std::collections::{HashMap, HashSet}; use std::fs; use std::path::{Path, PathBuf}; -use std::sync::RwLock; +use std::sync::{Arc, RwLock}; #[cfg(test)] static SKILL_TEST_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); @@ -865,6 +865,33 @@ fn split_frontmatter(content: &str) -> Option<(&str, &str)> { // 使用 platform 模块提供的 xml_escape 和 path_to_uri 函数 +// SkillPromptProvider 实现 +use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; + +/// Skill 提示词提供者 +/// +/// 负责提供技能的系统索引提示词(system_index_prompt)。 +pub struct SkillPromptProvider { + skills: Arc, +} + +impl SkillPromptProvider { + /// 创建新的 Skill 提示词提供者 + pub fn new(skills: Arc) -> Self { + Self { skills } + } +} + +impl SystemPromptProvider for SkillPromptProvider { + fn build(&self, _context: &SystemPromptContext) -> Option { + // 调用 SkillRuntime 的 system_index_prompt 方法 + self.skills.system_index_prompt().map(|content| SystemPrompt { + content, + context: Some("skill_index".to_string()), + }) + } +} + #[cfg(test)] mod tests { use super::*;