feat: 重构提示源加载逻辑,支持从多个源聚合提示内容,并优化内存摘要持久化

This commit is contained in:
ooodc 2026-05-05 18:27:58 +08:00
parent 495c8cdc7e
commit aa7f1d6160

View File

@ -4,23 +4,63 @@ use std::path::{Path, PathBuf};
use crate::agent::AgentError; use crate::agent::AgentError;
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md"); pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->"; #[derive(Clone)]
pub(crate) const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要"; struct PromptSource {
path: PathBuf,
default_content: Option<&'static str>,
}
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> { pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
let path = agent_prompt_path()?; load_prompt_from_sources(&prompt_sources()?)
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
} }
if !path.exists() { pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?; persist_memory_summary(&memory_summary_path()?, markdown_body)
} }
let content = fs::read_to_string(&path) fn prompt_sources() -> Result<Vec<PromptSource>, AgentError> {
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?; Ok(vec![
PromptSource {
path: agent_prompt_path()?,
default_content: Some(DEFAULT_AGENT_PROMPT),
},
PromptSource {
path: memory_summary_path()?,
default_content: None,
},
])
}
fn load_prompt_from_sources(sources: &[PromptSource]) -> Result<Option<String>, AgentError> {
let mut fragments = Vec::with_capacity(sources.len());
for source in sources {
if let Some(fragment) = read_prompt_fragment(source)? {
fragments.push(fragment);
}
}
if fragments.is_empty() {
return Ok(None);
}
Ok(Some(fragments.join("\n\n")))
}
fn read_prompt_fragment(source: &PromptSource) -> Result<Option<String>, AgentError> {
ensure_parent_dir(&source.path)?;
if !source.path.exists() {
if let Some(default_content) = source.default_content {
write_prompt_file(&source.path, default_content)?;
} else {
return Ok(None);
}
}
let content = fs::read_to_string(&source.path)
.map_err(|err| AgentError::Other(format!("read prompt file error: {}", err)))?;
let trimmed = content.trim(); let trimmed = content.trim();
if trimmed.is_empty() { if trimmed.is_empty() {
return Ok(None); return Ok(None);
@ -29,121 +69,151 @@ pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
Ok(Some(trimmed.to_string())) Ok(Some(trimmed.to_string()))
} }
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> { fn persist_memory_summary(path: &Path, markdown_body: &str) -> Result<(), AgentError> {
let path = agent_prompt_path()?; let trimmed = markdown_body.trim();
let existing = if path.exists() { if trimmed.is_empty() {
fs::read_to_string(&path) return Ok(());
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?
} else {
DEFAULT_AGENT_PROMPT.to_string()
};
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
write_agent_prompt(&path, &updated)
} }
pub(crate) fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String { write_prompt_file(path, trimmed)
let managed_block = render_managed_agent_memory_block(markdown_body);
if let (Some(start), Some(end)) = (
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
) {
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
let mut updated = String::new();
updated.push_str(existing[..start].trim_end());
updated.push_str("\n\n");
updated.push_str(&managed_block);
updated.push_str("\n\n");
updated.push_str(existing[end..].trim_start());
return updated.trim().to_string() + "\n";
} }
if let Some(reply_rules_index) = existing.find("## 回复规则") { fn write_prompt_file(path: &Path, content: &str) -> Result<(), AgentError> {
let mut updated = String::new(); ensure_parent_dir(path)?;
updated.push_str(existing[..reply_rules_index].trim_end());
updated.push_str("\n\n"); let normalized = content.trim_end().to_string() + "\n";
updated.push_str(&managed_block); let temp_path = path.with_extension("md.tmp");
updated.push_str("\n\n"); fs::write(&temp_path, normalized)
updated.push_str(existing[reply_rules_index..].trim_start()); .map_err(|err| AgentError::Other(format!("write prompt temp file error: {}", err)))?;
return updated.trim().to_string() + "\n"; fs::rename(&temp_path, path)
.map_err(|err| AgentError::Other(format!("replace prompt file error: {}", err)))?;
Ok(())
} }
let mut updated = existing.trim_end().to_string(); fn ensure_parent_dir(path: &Path) -> Result<(), AgentError> {
if !updated.is_empty() {
updated.push_str("\n\n");
}
updated.push_str(&managed_block);
updated.push('\n');
updated
}
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
format!(
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
markdown_body.trim()
)
}
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
if let Some(parent) = path.parent() { if let Some(parent) = path.parent() {
fs::create_dir_all(parent) fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?; .map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
} }
let temp_path = path.with_extension("md.tmp");
fs::write(&temp_path, content)
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
fs::rename(&temp_path, path)
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
Ok(()) Ok(())
} }
fn agent_prompt_path() -> Result<PathBuf, AgentError> { fn agent_prompt_path() -> Result<PathBuf, AgentError> {
Ok(agent_dir_path()?.join("AGENT.md"))
}
fn memory_summary_path() -> Result<PathBuf, AgentError> {
Ok(agent_dir_path()?.join("MEMORY_SUMMARY.md"))
}
fn agent_dir_path() -> Result<PathBuf, AgentError> {
let home = dirs::home_dir() let home = dirs::home_dir()
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?; .ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
Ok(home.join(".picobot").join("agent").join("AGENT.md")) Ok(home.join(".picobot").join("agent"))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tempfile::tempdir;
#[test] #[test]
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() { fn test_load_prompt_from_sources_aggregates_multiple_fragments_in_order() {
let original = let temp = tempdir().unwrap();
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n"; let agent_path = temp.path().join("AGENT.md");
let updated = upsert_managed_agent_memory_block( let memory_path = temp.path().join("MEMORY_SUMMARY.md");
original,
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
);
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap(); write_prompt_file(&agent_path, "# Agent\n静态规则").unwrap();
let reply_rules_pos = updated.find("## 回复规则").unwrap(); write_prompt_file(&memory_path, "## 用户记忆摘要\n- 偏好简洁").unwrap();
assert!(managed_pos < reply_rules_pos);
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE)); let prompt = load_prompt_from_sources(&[
assert!(updated.contains("用户在做AI产品")); PromptSource {
assert!(updated.contains("偏好简洁表达")); path: agent_path.clone(),
default_content: Some(DEFAULT_AGENT_PROMPT),
},
PromptSource {
path: memory_path.clone(),
default_content: None,
},
])
.unwrap()
.unwrap();
assert_eq!(prompt, "# Agent\n静态规则\n\n## 用户记忆摘要\n- 偏好简洁");
} }
#[test] #[test]
fn test_upsert_managed_agent_memory_block_replaces_existing_block() { fn test_load_prompt_from_sources_ignores_missing_optional_source() {
let original = format!( let temp = tempdir().unwrap();
"# PicoBot\n\n{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\nold\n{MANAGED_AGENT_MEMORY_BLOCK_END}\n\n## 回复规则\n- 简洁。\n" let agent_path = temp.path().join("AGENT.md");
); let memory_path = temp.path().join("MEMORY_SUMMARY.md");
let updated = upsert_managed_agent_memory_block(&original, "new"); write_prompt_file(&agent_path, "# Agent\n静态规则").unwrap();
assert!(updated.contains("new")); let prompt = load_prompt_from_sources(&[
assert!(!updated.contains("old")); PromptSource {
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_START).count(), 1); path: agent_path.clone(),
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_END).count(), 1); default_content: Some(DEFAULT_AGENT_PROMPT),
},
PromptSource {
path: memory_path.clone(),
default_content: None,
},
])
.unwrap()
.unwrap();
assert_eq!(prompt, "# Agent\n静态规则");
assert!(!memory_path.exists());
} }
#[test] #[test]
fn test_upsert_managed_agent_memory_block_trims_summary_body() { fn test_load_prompt_from_sources_creates_default_agent_prompt() {
let updated = upsert_managed_agent_memory_block("# PicoBot\n", "\n\nsummary\n\n"); let temp = tempdir().unwrap();
let agent_path = temp.path().join("AGENT.md");
assert!(updated.contains("\n\nsummary\n")); let prompt = load_prompt_from_sources(&[PromptSource {
assert!(!updated.contains("\n\nsummary\n\n\n")); path: agent_path.clone(),
default_content: Some("# Default Agent\n规则"),
}])
.unwrap()
.unwrap();
assert_eq!(prompt, "# Default Agent\n规则");
assert_eq!(fs::read_to_string(&agent_path).unwrap(), "# Default Agent\n规则\n");
}
#[test]
fn test_load_prompt_from_sources_returns_none_when_all_sources_empty() {
let temp = tempdir().unwrap();
let agent_path = temp.path().join("AGENT.md");
let memory_path = temp.path().join("MEMORY_SUMMARY.md");
write_prompt_file(&agent_path, " ").unwrap();
write_prompt_file(&memory_path, "\n\n").unwrap();
let prompt = load_prompt_from_sources(&[
PromptSource {
path: agent_path.clone(),
default_content: None,
},
PromptSource {
path: memory_path.clone(),
default_content: None,
},
])
.unwrap();
assert!(prompt.is_none());
}
#[test]
fn test_upsert_managed_agent_memory_summary_writes_trimmed_markdown() {
let temp = tempdir().unwrap();
let memory_path = temp.path().join("MEMORY_SUMMARY.md");
persist_memory_summary(&memory_path, "\n## 用户记忆摘要\n- 偏好简洁\n\n").unwrap();
assert_eq!(fs::read_to_string(&memory_path).unwrap(), "## 用户记忆摘要\n- 偏好简洁\n");
} }
} }