// PicoBot/src/gateway/prompt.rs
//
// 361 lines
// 13 KiB
// Rust

use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::platform::atomic_rename;
/// Built-in default agent prompt, embedded at compile time from `default_agent_prompt.md`.
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
/// Template text embedded at compile time from `agent_md_template.md`; it is written to the
/// user's `AGENT.md` when that file does not exist yet.
pub(crate) const AGENT_MD_TEMPLATE: &str = include_str!("agent_md_template.md");
/// Origin of one fragment of the agent prompt.
#[derive(Clone)]
enum PromptSource {
    /// Built-in content: the system default, always injected and not user-editable.
    Builtin(&'static str),
    /// User-authored file: created from `template` when missing, injected only when it
    /// still has content after comments and blank lines are stripped.
    UserCustom {
        path: PathBuf,
        template: &'static str,
    },
    /// System-maintained file: injected only when it exists and has content.
    AutoGenerated(PathBuf),
}
/// Assembles the agent prompt from the default set of prompt sources.
///
/// Resolves the source list first, then hands it to `load_prompt_from_sources`;
/// any error from either step is returned unchanged.
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
    let sources = prompt_sources()?;
    load_prompt_from_sources(&sources)
}
/// Stores `markdown_body` in the managed memory summary file.
///
/// Resolves the summary file location and delegates the actual write to
/// `persist_memory_summary`.
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
    let summary_path = memory_summary_path()?;
    persist_memory_summary(&summary_path, markdown_body)
}
/// Returns the prompt sources in injection order:
/// built-in default, user `AGENT.md`, then the memory summary.
fn prompt_sources() -> Result<Vec<PromptSource>, AgentError> {
    // 1. System default: always injected.
    let builtin = PromptSource::Builtin(DEFAULT_AGENT_PROMPT);
    // 2. User customisation: created from the template when missing, optionally injected.
    let user_custom = PromptSource::UserCustom {
        path: agent_prompt_path()?,
        template: AGENT_MD_TEMPLATE,
    };
    // 3. Memory summary: optional.
    let memory_summary = PromptSource::AutoGenerated(memory_summary_path()?);

    Ok(vec![builtin, user_custom, memory_summary])
}
fn load_prompt_from_sources(sources: &[PromptSource]) -> Result<Option<String>, AgentError> {
let mut fragments = Vec::new();
for source in sources {
match source {
PromptSource::Builtin(content) => {
fragments.push(content.to_string());
}
PromptSource::UserCustom { path, template } => {
// 确保父目录存在
ensure_parent_dir(path)?;
// 文件不存在时创建空白模板
if !path.exists() {
fs::write(path, template)
.map_err(|err| AgentError::Other(format!("create AGENT.md template error: {}", err)))?;
}
// 读取内容,仅当非空(去除注释后)时注入
let content = fs::read_to_string(path)
.map_err(|err| AgentError::Other(format!("read AGENT.md error: {}", err)))?;
let trimmed = strip_comments_and_whitespace(&content);
if !trimmed.is_empty() {
fragments.push(trimmed.to_string());
}
}
PromptSource::AutoGenerated(path) => {
if path.exists() {
let content = fs::read_to_string(path)
.map_err(|err| AgentError::Other(format!("read MEMORY_SUMMARY.md error: {}", err)))?;
let without_comments = strip_comments_and_whitespace(&content);
if !without_comments.is_empty() {
fragments.push(without_comments);
}
}
}
}
}
if fragments.is_empty() {
Ok(None)
} else {
Ok(Some(fragments.join("\n\n")))
}
}
/// Removes HTML comments (`<!-- ... -->`, possibly spanning several lines) and then
/// drops every line that is empty or whitespace-only.
///
/// Each comment runs from `<!--` to the nearest following `-->`. An opening marker
/// with no closing marker is not a comment and is kept as ordinary text. Surviving
/// lines are not trimmed; they are joined with `\n`, so the result is empty exactly
/// when the input holds nothing but comments and whitespace.
fn strip_comments_and_whitespace(content: &str) -> String {
    const OPEN: &str = "<!--";
    const CLOSE: &str = "-->";

    // A plain substring scan avoids compiling a regular expression on every call.
    let mut without_comments = String::with_capacity(content.len());
    let mut rest = content;
    while let Some(open_at) = rest.find(OPEN) {
        let after_open = &rest[open_at + OPEN.len()..];
        match after_open.find(CLOSE) {
            Some(close_at) => {
                without_comments.push_str(&rest[..open_at]);
                rest = &after_open[close_at + CLOSE.len()..];
            }
            // Unterminated comment: nothing further can match, keep the tail verbatim.
            None => break,
        }
    }
    without_comments.push_str(rest);

    without_comments
        .lines()
        .filter(|line| !line.trim().is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}
/// Renders the "system environment" prompt section.
///
/// The section lists the operating system, CPU architecture, the value of the
/// `SHELL` environment variable, the current working directory, and the
/// `name` / `model_id` fields of `config`. The shell and the working directory
/// fall back to `"unknown"` when they cannot be read.
pub(crate) fn generate_system_environment_prompt(config: &LLMProviderConfig) -> String {
    // Map the compile-time OS identifier to a display name; unlisted targets are
    // shown with their raw identifier.
    let os_label = match env::consts::OS {
        "windows" => "Windows",
        "linux" => "Linux",
        "macos" => "macOS",
        "freebsd" => "FreeBSD",
        other => other,
    };

    let shell = match env::var("SHELL") {
        Ok(value) => value,
        Err(_) => "unknown".to_string(),
    };
    let cwd = match env::current_dir() {
        Ok(dir) => dir.display().to_string(),
        Err(_) => "unknown".to_string(),
    };

    format!(
        "## 系统环境\n- 操作系统: {}\n- 架构: {}\n- Shell: {}\n- 当前工作目录: {}\n- 模型提供商: {}\n- 模型: {}",
        os_label,
        env::consts::ARCH,
        shell,
        cwd,
        config.name,
        config.model_id
    )
}
/// Writes the trimmed `markdown_body` to `path`.
///
/// A body that is empty after trimming is a no-op: nothing is written and
/// `Ok(())` is returned.
fn persist_memory_summary(path: &Path, markdown_body: &str) -> Result<(), AgentError> {
    match markdown_body.trim() {
        "" => Ok(()),
        body => write_prompt_file(path, body),
    }
}
/// Replaces the file at `path` with `content`, normalised to end in exactly one newline.
///
/// The text is first written to a sibling temp file (`path` with its extension
/// replaced by `md.tmp`) and then moved into place with `atomic_rename`.
/// If either step fails, the temp file is removed on a best-effort basis so that a
/// failed update does not leave a stale `.md.tmp` behind.
///
/// # Errors
///
/// Returns `AgentError::Other` when the parent directory cannot be created, the temp
/// file cannot be written, or the rename fails.
fn write_prompt_file(path: &Path, content: &str) -> Result<(), AgentError> {
    ensure_parent_dir(path)?;

    let mut normalized = content.trim_end().to_owned();
    normalized.push('\n');

    let temp_path = path.with_extension("md.tmp");
    if let Err(err) = fs::write(&temp_path, normalized) {
        // A failed write may still have created a partial file; the cleanup outcome
        // is deliberately ignored because the write error is what matters.
        let _ = fs::remove_file(&temp_path);
        return Err(AgentError::Other(format!(
            "write prompt temp file error: {}",
            err
        )));
    }

    // Move the temp file into place through the platform-abstracted rename.
    if let Err(err) = atomic_rename(&temp_path, path) {
        let _ = fs::remove_file(&temp_path);
        return Err(AgentError::Other(format!(
            "replace prompt file error: {}",
            err
        )));
    }

    Ok(())
}
fn ensure_parent_dir(path: &Path) -> Result<(), AgentError> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
Ok(())
}
/// Path of the user-editable `AGENT.md` inside the agent directory.
fn agent_prompt_path() -> Result<PathBuf, AgentError> {
    agent_dir_path().map(|dir| dir.join("AGENT.md"))
}
/// Path of the system-maintained `MEMORY_SUMMARY.md` inside the agent directory.
fn memory_summary_path() -> Result<PathBuf, AgentError> {
    agent_dir_path().map(|dir| dir.join("MEMORY_SUMMARY.md"))
}
/// Directory holding the agent prompt files: `<home>/.picobot/agent`.
///
/// Fails with `AgentError::Other` when the home directory cannot be determined.
fn agent_dir_path() -> Result<PathBuf, AgentError> {
    match dirs::home_dir() {
        Some(home) => Ok(home.join(".picobot").join("agent")),
        None => Err(AgentError::Other("home directory not found".to_string())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Loads the prompt from the standard three sources (built-in default, user
    /// `AGENT.md`, memory summary) rooted at the given test paths, and unwraps it.
    fn load_prompt(agent_path: &Path, memory_path: &Path) -> String {
        load_prompt_from_sources(&[
            PromptSource::Builtin(DEFAULT_AGENT_PROMPT),
            PromptSource::UserCustom {
                path: agent_path.to_path_buf(),
                template: AGENT_MD_TEMPLATE,
            },
            PromptSource::AutoGenerated(memory_path.to_path_buf()),
        ])
        .unwrap()
        .unwrap()
    }

    #[test]
    fn test_load_prompt_from_sources_aggregates_multiple_fragments_in_order() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");
        // User-authored content.
        fs::write(&agent_path, "# Agent\n静态规则\n").unwrap();
        // Memory summary.
        fs::write(&memory_path, "## 用户记忆摘要\n- 偏好简洁\n").unwrap();

        let prompt = load_prompt(&agent_path, &memory_path);

        // All three parts are present: system default + user custom + memory summary.
        assert!(prompt.contains(DEFAULT_AGENT_PROMPT));
        assert!(prompt.contains("# Agent\n静态规则"));
        assert!(prompt.contains("## 用户记忆摘要"));
        // The prompt opens with the system default.
        let parts: Vec<_> = prompt.split("\n\n").collect();
        assert!(parts[0].contains("# PicoBot 代理配置"));
        // The full order is pinned: default first, then user content, then the
        // memory summary, each separated by a blank line.
        assert!(prompt.starts_with(DEFAULT_AGENT_PROMPT));
        assert!(prompt.ends_with("\n\n# Agent\n静态规则\n\n## 用户记忆摘要\n- 偏好简洁"));
    }

    #[test]
    fn test_load_prompt_from_sources_skips_empty_user_custom() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");
        // AGENT.md is effectively empty (only a comment and whitespace).
        fs::write(&agent_path, "<!-- 注释 -->\n\n \n").unwrap();
        // Memory summary.
        fs::write(&memory_path, "## 用户记忆摘要\n").unwrap();

        let prompt = load_prompt(&agent_path, &memory_path);

        // Only the system default and the memory summary are included.
        assert!(prompt.contains(DEFAULT_AGENT_PROMPT));
        assert!(prompt.contains("## 用户记忆摘要"));
        assert!(!prompt.contains("<!-- 注释 -->"));
    }

    #[test]
    fn test_load_prompt_from_sources_creates_template_for_missing_agent_md() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");
        // AGENT.md does not exist yet.
        assert!(!agent_path.exists());

        let prompt = load_prompt(&agent_path, &memory_path);

        // AGENT.md has been created.
        assert!(agent_path.exists());
        // It holds the template text (inside comments).
        let agent_content = fs::read_to_string(&agent_path).unwrap();
        assert!(agent_content.contains("# 自定义 Agent 配置"));
        // The template is all comments, so nothing from it reaches the prompt.
        assert!(!prompt.contains("自定义 Agent 配置"));
        // The system default is always present.
        assert!(prompt.contains(DEFAULT_AGENT_PROMPT));

        // Once real content is added, it is injected.
        fs::write(&agent_path, "# 我的自定义配置\n\n我是 Rust 开发者\n").unwrap();
        let prompt_with_custom = load_prompt(&agent_path, &memory_path);
        assert!(prompt_with_custom.contains("# 我的自定义配置"));
        assert!(prompt_with_custom.contains("我是 Rust 开发者"));
        assert!(prompt_with_custom.contains(DEFAULT_AGENT_PROMPT));

        // Reducing AGENT.md to comments only stops the injection again.
        fs::write(&agent_path, "<!-- 全部注释 -->\n\n \n").unwrap();
        let prompt_empty = load_prompt(&agent_path, &memory_path);
        assert!(!prompt_empty.contains("我的自定义配置"));
        assert!(prompt_empty.contains(DEFAULT_AGENT_PROMPT));
    }

    #[test]
    fn test_load_prompt_from_sources_ignores_missing_optional_source() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");
        // User-authored content; MEMORY_SUMMARY.md is deliberately not created.
        fs::write(&agent_path, "# Agent\n静态规则\n").unwrap();

        let prompt = load_prompt(&agent_path, &memory_path);

        // System default and user content are included, the memory summary is not.
        assert!(prompt.contains(DEFAULT_AGENT_PROMPT));
        assert!(prompt.contains("# Agent\n静态规则"));
        assert!(!prompt.contains("MEMORY_SUMMARY"));
    }

    #[test]
    fn test_strip_comments_and_whitespace_removes_html_comments() {
        let content = "<!-- 注释 -->\n实际内容\n<!-- 另一个注释 -->";
        let result = strip_comments_and_whitespace(content);
        assert!(!result.contains("<!--"));
        assert!(result.contains("实际内容"));
    }

    #[test]
    fn test_strip_comments_and_whitespace_removes_empty_lines() {
        let content = "\n\n实际内容\n\n \n\n另一行\n\n";
        let result = strip_comments_and_whitespace(content);
        assert_eq!(result, "实际内容\n另一行");
    }

    #[test]
    fn test_strip_comments_and_whitespace_returns_empty_for_only_comments() {
        let content = "<!-- 注释1 -->\n<!-- 注释2 -->\n\n \n";
        let result = strip_comments_and_whitespace(content);
        assert!(result.is_empty());
    }

    #[test]
    fn test_upsert_managed_agent_memory_summary_writes_trimmed_markdown() {
        let temp = tempdir().unwrap();
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");
        persist_memory_summary(&memory_path, "\n## 用户记忆摘要\n- 偏好简洁\n\n").unwrap();
        assert_eq!(
            fs::read_to_string(&memory_path).unwrap(),
            "## 用户记忆摘要\n- 偏好简洁\n"
        );
    }
}