From aa7f1d61602d0900422483af1564fd1d5a5ed67b Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Tue, 5 May 2026 18:27:58 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E6=8F=90=E7=A4=BA?= =?UTF-8?q?=E6=BA=90=E5=8A=A0=E8=BD=BD=E9=80=BB=E8=BE=91=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E4=BB=8E=E5=A4=9A=E4=B8=AA=E6=BA=90=E8=81=9A=E5=90=88?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=E5=86=85=E5=AE=B9=EF=BC=8C=E5=B9=B6=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=86=85=E5=AD=98=E6=91=98=E8=A6=81=E6=8C=81=E4=B9=85?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/gateway/prompt.rs | 258 +++++++++++++++++++++++++++--------------- 1 file changed, 164 insertions(+), 94 deletions(-) diff --git a/src/gateway/prompt.rs b/src/gateway/prompt.rs index 068a517..f99052e 100644 --- a/src/gateway/prompt.rs +++ b/src/gateway/prompt.rs @@ -4,23 +4,63 @@ use std::path::{Path, PathBuf}; use crate::agent::AgentError; pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md"); -pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_START: &str = ""; -pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_END: &str = ""; -pub(crate) const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要"; + +#[derive(Clone)] +struct PromptSource { + path: PathBuf, + default_content: Option<&'static str>, +} pub(crate) fn load_agent_prompt() -> Result, AgentError> { - let path = agent_prompt_path()?; - if let Some(parent) = path.parent() { - fs::create_dir_all(parent) - .map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?; + load_prompt_from_sources(&prompt_sources()?) +} + +pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> { + persist_memory_summary(&memory_summary_path()?, markdown_body) +} + +fn prompt_sources() -> Result, AgentError> { + Ok(vec![ + PromptSource { + path: agent_prompt_path()?, + default_content: Some(DEFAULT_AGENT_PROMPT), + }, + PromptSource { + path: memory_summary_path()?, + default_content: None, + }, + ]) +} + +fn load_prompt_from_sources(sources: &[PromptSource]) -> Result, AgentError> { + let mut fragments = Vec::with_capacity(sources.len()); + + for source in sources { + if let Some(fragment) = read_prompt_fragment(source)? { + fragments.push(fragment); + } } - if !path.exists() { - write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?; + if fragments.is_empty() { + return Ok(None); } - let content = fs::read_to_string(&path) - .map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?; + Ok(Some(fragments.join("\n\n"))) +} + +fn read_prompt_fragment(source: &PromptSource) -> Result, AgentError> { + ensure_parent_dir(&source.path)?; + + if !source.path.exists() { + if let Some(default_content) = source.default_content { + write_prompt_file(&source.path, default_content)?; + } else { + return Ok(None); + } + } + + let content = fs::read_to_string(&source.path) + .map_err(|err| AgentError::Other(format!("read prompt file error: {}", err)))?; let trimmed = content.trim(); if trimmed.is_empty() { return Ok(None); @@ -29,121 +69,151 @@ pub(crate) fn load_agent_prompt() -> Result, AgentError> { Ok(Some(trimmed.to_string())) } -pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> { - let path = agent_prompt_path()?; - let existing = if path.exists() { - fs::read_to_string(&path) - .map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))? - } else { - DEFAULT_AGENT_PROMPT.to_string() - }; - let updated = upsert_managed_agent_memory_block(&existing, markdown_body); - write_agent_prompt(&path, &updated) -} - -pub(crate) fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String { - let managed_block = render_managed_agent_memory_block(markdown_body); - - if let (Some(start), Some(end)) = ( - existing.find(MANAGED_AGENT_MEMORY_BLOCK_START), - existing.find(MANAGED_AGENT_MEMORY_BLOCK_END), - ) { - let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len(); - let mut updated = String::new(); - updated.push_str(existing[..start].trim_end()); - updated.push_str("\n\n"); - updated.push_str(&managed_block); - updated.push_str("\n\n"); - updated.push_str(existing[end..].trim_start()); - return updated.trim().to_string() + "\n"; +fn persist_memory_summary(path: &Path, markdown_body: &str) -> Result<(), AgentError> { + let trimmed = markdown_body.trim(); + if trimmed.is_empty() { + return Ok(()); } - if let Some(reply_rules_index) = existing.find("## 回复规则") { - let mut updated = String::new(); - updated.push_str(existing[..reply_rules_index].trim_end()); - updated.push_str("\n\n"); - updated.push_str(&managed_block); - updated.push_str("\n\n"); - updated.push_str(existing[reply_rules_index..].trim_start()); - return updated.trim().to_string() + "\n"; - } - - let mut updated = existing.trim_end().to_string(); - if !updated.is_empty() { - updated.push_str("\n\n"); - } - updated.push_str(&managed_block); - updated.push('\n'); - updated + write_prompt_file(path, trimmed) } -fn render_managed_agent_memory_block(markdown_body: &str) -> String { - format!( - "{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}", - markdown_body.trim() - ) +fn write_prompt_file(path: &Path, content: &str) -> Result<(), AgentError> { + ensure_parent_dir(path)?; + + let normalized = content.trim_end().to_string() + "\n"; + let temp_path = path.with_extension("md.tmp"); + fs::write(&temp_path, normalized) + .map_err(|err| AgentError::Other(format!("write prompt temp file error: {}", err)))?; + fs::rename(&temp_path, path) + .map_err(|err| AgentError::Other(format!("replace prompt file error: {}", err)))?; + Ok(()) } -fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> { +fn ensure_parent_dir(path: &Path) -> Result<(), AgentError> { if let Some(parent) = path.parent() { fs::create_dir_all(parent) .map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?; } - - let temp_path = path.with_extension("md.tmp"); - fs::write(&temp_path, content) - .map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?; - fs::rename(&temp_path, path) - .map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?; Ok(()) } fn agent_prompt_path() -> Result { + Ok(agent_dir_path()?.join("AGENT.md")) +} + +fn memory_summary_path() -> Result { + Ok(agent_dir_path()?.join("MEMORY_SUMMARY.md")) +} + +fn agent_dir_path() -> Result { let home = dirs::home_dir() .ok_or_else(|| AgentError::Other("home directory not found".to_string()))?; - Ok(home.join(".picobot").join("agent").join("AGENT.md")) + Ok(home.join(".picobot").join("agent")) } #[cfg(test)] mod tests { use super::*; + use tempfile::tempdir; #[test] - fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() { - let original = - "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n"; - let updated = upsert_managed_agent_memory_block( - original, - "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达", - ); + fn test_load_prompt_from_sources_aggregates_multiple_fragments_in_order() { + let temp = tempdir().unwrap(); + let agent_path = temp.path().join("AGENT.md"); + let memory_path = temp.path().join("MEMORY_SUMMARY.md"); - let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap(); - let reply_rules_pos = updated.find("## 回复规则").unwrap(); - assert!(managed_pos < reply_rules_pos); - assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE)); - assert!(updated.contains("用户在做AI产品")); - assert!(updated.contains("偏好简洁表达")); + write_prompt_file(&agent_path, "# Agent\n静态规则").unwrap(); + write_prompt_file(&memory_path, "## 用户记忆摘要\n- 偏好简洁").unwrap(); + + let prompt = load_prompt_from_sources(&[ + PromptSource { + path: agent_path.clone(), + default_content: Some(DEFAULT_AGENT_PROMPT), + }, + PromptSource { + path: memory_path.clone(), + default_content: None, + }, + ]) + .unwrap() + .unwrap(); + + assert_eq!(prompt, "# Agent\n静态规则\n\n## 用户记忆摘要\n- 偏好简洁"); } #[test] - fn test_upsert_managed_agent_memory_block_replaces_existing_block() { - let original = format!( - "# PicoBot\n\n{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\nold\n{MANAGED_AGENT_MEMORY_BLOCK_END}\n\n## 回复规则\n- 简洁。\n" - ); + fn test_load_prompt_from_sources_ignores_missing_optional_source() { + let temp = tempdir().unwrap(); + let agent_path = temp.path().join("AGENT.md"); + let memory_path = temp.path().join("MEMORY_SUMMARY.md"); - let updated = upsert_managed_agent_memory_block(&original, "new"); + write_prompt_file(&agent_path, "# Agent\n静态规则").unwrap(); - assert!(updated.contains("new")); - assert!(!updated.contains("old")); - assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_START).count(), 1); - assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_END).count(), 1); + let prompt = load_prompt_from_sources(&[ + PromptSource { + path: agent_path.clone(), + default_content: Some(DEFAULT_AGENT_PROMPT), + }, + PromptSource { + path: memory_path.clone(), + default_content: None, + }, + ]) + .unwrap() + .unwrap(); + + assert_eq!(prompt, "# Agent\n静态规则"); + assert!(!memory_path.exists()); } #[test] - fn test_upsert_managed_agent_memory_block_trims_summary_body() { - let updated = upsert_managed_agent_memory_block("# PicoBot\n", "\n\nsummary\n\n"); + fn test_load_prompt_from_sources_creates_default_agent_prompt() { + let temp = tempdir().unwrap(); + let agent_path = temp.path().join("AGENT.md"); - assert!(updated.contains("\n\nsummary\n")); - assert!(!updated.contains("\n\nsummary\n\n\n")); + let prompt = load_prompt_from_sources(&[PromptSource { + path: agent_path.clone(), + default_content: Some("# Default Agent\n规则"), + }]) + .unwrap() + .unwrap(); + + assert_eq!(prompt, "# Default Agent\n规则"); + assert_eq!(fs::read_to_string(&agent_path).unwrap(), "# Default Agent\n规则\n"); + } + + #[test] + fn test_load_prompt_from_sources_returns_none_when_all_sources_empty() { + let temp = tempdir().unwrap(); + let agent_path = temp.path().join("AGENT.md"); + let memory_path = temp.path().join("MEMORY_SUMMARY.md"); + + write_prompt_file(&agent_path, " ").unwrap(); + write_prompt_file(&memory_path, "\n\n").unwrap(); + + let prompt = load_prompt_from_sources(&[ + PromptSource { + path: agent_path.clone(), + default_content: None, + }, + PromptSource { + path: memory_path.clone(), + default_content: None, + }, + ]) + .unwrap(); + + assert!(prompt.is_none()); + } + + #[test] + fn test_upsert_managed_agent_memory_summary_writes_trimmed_markdown() { + let temp = tempdir().unwrap(); + let memory_path = temp.path().join("MEMORY_SUMMARY.md"); + + persist_memory_summary(&memory_path, "\n## 用户记忆摘要\n- 偏好简洁\n\n").unwrap(); + + assert_eq!(fs::read_to_string(&memory_path).unwrap(), "## 用户记忆摘要\n- 偏好简洁\n"); } }