361 lines
13 KiB
Rust
361 lines
13 KiB
Rust
use std::env;
|
|
use std::fs;
|
|
use std::path::{Path, PathBuf};
|
|
|
|
use crate::agent::AgentError;
|
|
use crate::config::LLMProviderConfig;
|
|
use crate::platform::atomic_rename;
|
|
|
|
/// Built-in default agent prompt, embedded into the binary at compile time from
/// `default_agent_prompt.md`. It is the first prompt source and is always injected.
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
/// Template (embedded at compile time from `agent_md_template.md`) that is written to the
/// user's `AGENT.md` when that file does not exist yet.
pub(crate) const AGENT_MD_TEMPLATE: &str = include_str!("agent_md_template.md");
|
|
|
|
/// Where one fragment of the agent prompt comes from.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PromptSource {
    /// Static built-in text (e.g. `DEFAULT_AGENT_PROMPT`); always injected verbatim and
    /// not editable by the user.
    Builtin(&'static str),
    /// User-editable file. It is created from `template` when missing and injected only if
    /// something other than HTML comments and blank lines remains in it.
    UserCustom {
        path: PathBuf,
        template: &'static str,
    },
    /// System-maintained file (e.g. the memory summary). Optional: skipped when absent and
    /// injected only if something other than HTML comments and blank lines remains in it.
    AutoGenerated(PathBuf),
}
|
|
|
|
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
|
|
load_prompt_from_sources(&prompt_sources()?)
|
|
}
|
|
|
|
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
|
|
persist_memory_summary(&memory_summary_path()?, markdown_body)
|
|
}
|
|
|
|
fn prompt_sources() -> Result<Vec<PromptSource>, AgentError> {
|
|
Ok(vec![
|
|
// 1. 系统默认 - 始终注入
|
|
PromptSource::Builtin(DEFAULT_AGENT_PROMPT),
|
|
// 2. 用户自定义 - 自动创建空白模板,可选注入
|
|
PromptSource::UserCustom {
|
|
path: agent_prompt_path()?,
|
|
template: AGENT_MD_TEMPLATE,
|
|
},
|
|
// 3. 记忆摘要 - 可选
|
|
PromptSource::AutoGenerated(memory_summary_path()?),
|
|
])
|
|
}
|
|
|
|
fn load_prompt_from_sources(sources: &[PromptSource]) -> Result<Option<String>, AgentError> {
|
|
let mut fragments = Vec::new();
|
|
|
|
for source in sources {
|
|
match source {
|
|
PromptSource::Builtin(content) => {
|
|
fragments.push(content.to_string());
|
|
}
|
|
PromptSource::UserCustom { path, template } => {
|
|
// 确保父目录存在
|
|
ensure_parent_dir(path)?;
|
|
// 文件不存在时创建空白模板
|
|
if !path.exists() {
|
|
fs::write(path, template)
|
|
.map_err(|err| AgentError::Other(format!("create AGENT.md template error: {}", err)))?;
|
|
}
|
|
// 读取内容,仅当非空(去除注释后)时注入
|
|
let content = fs::read_to_string(path)
|
|
.map_err(|err| AgentError::Other(format!("read AGENT.md error: {}", err)))?;
|
|
let trimmed = strip_comments_and_whitespace(&content);
|
|
if !trimmed.is_empty() {
|
|
fragments.push(trimmed.to_string());
|
|
}
|
|
}
|
|
PromptSource::AutoGenerated(path) => {
|
|
if path.exists() {
|
|
let content = fs::read_to_string(path)
|
|
.map_err(|err| AgentError::Other(format!("read MEMORY_SUMMARY.md error: {}", err)))?;
|
|
let without_comments = strip_comments_and_whitespace(&content);
|
|
if !without_comments.is_empty() {
|
|
fragments.push(without_comments);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if fragments.is_empty() {
|
|
Ok(None)
|
|
} else {
|
|
Ok(Some(fragments.join("\n\n")))
|
|
}
|
|
}
|
|
|
|
/// Removes HTML comments (`<!-- ... -->`) and blank lines from `content`.
///
/// Comments are removed non-greedily: each `<!--` pairs with the first `-->` after it.
/// This gives the same result as the regex `<!--[\s\S]*?-->`. An unterminated `<!--`
/// and everything after it are kept verbatim.
///
/// After comment removal, lines that are empty or whitespace-only are dropped. The
/// remaining lines are joined with `\n`, so `\r\n` endings are normalized and no trailing
/// newline is produced. An empty result means the input had no meaningful content.
fn strip_comments_and_whitespace(content: &str) -> String {
    const OPEN: &str = "<!--";
    const CLOSE: &str = "-->";

    // A linear scan instead of compiling a regex on every call (which also had an
    // `unwrap` panic path).
    let mut without_comments = String::with_capacity(content.len());
    let mut rest = content;
    while let Some(open_at) = rest.find(OPEN) {
        let after_open = &rest[open_at + OPEN.len()..];
        match after_open.find(CLOSE) {
            Some(close_at) => {
                without_comments.push_str(&rest[..open_at]);
                rest = &after_open[close_at + CLOSE.len()..];
            }
            // No closing marker follows, so no later `<!--` can be closed either.
            None => break,
        }
    }
    without_comments.push_str(rest);

    without_comments
        .lines()
        .filter(|line| !line.trim().is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}
|
|
|
|
/// 生成系统环境信息提示词
|
|
pub(crate) fn generate_system_environment_prompt(config: &LLMProviderConfig) -> String {
|
|
use std::env::consts::{ARCH, OS};
|
|
|
|
let os_name = match OS {
|
|
"windows" => "Windows",
|
|
"linux" => "Linux",
|
|
"macos" => "macOS",
|
|
"freebsd" => "FreeBSD",
|
|
_ => OS,
|
|
};
|
|
|
|
let shell = env::var("SHELL").unwrap_or_else(|_| "unknown".to_string());
|
|
let cwd = env::current_dir()
|
|
.map(|p| p.display().to_string())
|
|
.unwrap_or_else(|_| "unknown".to_string());
|
|
|
|
format!(
|
|
"## 系统环境\n- 操作系统: {}\n- 架构: {}\n- Shell: {}\n- 当前工作目录: {}\n- 模型提供商: {}\n- 模型: {}",
|
|
os_name, ARCH, shell, cwd, config.name, config.model_id
|
|
)
|
|
}
|
|
|
|
fn persist_memory_summary(path: &Path, markdown_body: &str) -> Result<(), AgentError> {
|
|
let trimmed = markdown_body.trim();
|
|
if trimmed.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
write_prompt_file(path, trimmed)
|
|
}
|
|
|
|
fn write_prompt_file(path: &Path, content: &str) -> Result<(), AgentError> {
|
|
ensure_parent_dir(path)?;
|
|
|
|
let normalized = content.trim_end().to_string() + "\n";
|
|
let temp_path = path.with_extension("md.tmp");
|
|
fs::write(&temp_path, normalized)
|
|
.map_err(|err| AgentError::Other(format!("write prompt temp file error: {}", err)))?;
|
|
|
|
// 使用平台抽象的原子重命名
|
|
atomic_rename(&temp_path, path)
|
|
.map_err(|err| AgentError::Other(format!("replace prompt file error: {}", err)))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn ensure_parent_dir(path: &Path) -> Result<(), AgentError> {
|
|
if let Some(parent) = path.parent() {
|
|
fs::create_dir_all(parent)
|
|
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn agent_prompt_path() -> Result<PathBuf, AgentError> {
|
|
Ok(agent_dir_path()?.join("AGENT.md"))
|
|
}
|
|
|
|
fn memory_summary_path() -> Result<PathBuf, AgentError> {
|
|
Ok(agent_dir_path()?.join("MEMORY_SUMMARY.md"))
|
|
}
|
|
|
|
fn agent_dir_path() -> Result<PathBuf, AgentError> {
|
|
let home = dirs::home_dir()
|
|
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
|
|
Ok(home.join(".picobot").join("agent"))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds the production-shaped source list: builtin, then `agent_path`, then `memory_path`.
    fn standard_sources(agent_path: &Path, memory_path: &Path) -> Vec<PromptSource> {
        vec![
            PromptSource::Builtin(DEFAULT_AGENT_PROMPT),
            PromptSource::UserCustom {
                path: agent_path.to_path_buf(),
                template: AGENT_MD_TEMPLATE,
            },
            PromptSource::AutoGenerated(memory_path.to_path_buf()),
        ]
    }

    #[test]
    fn test_load_prompt_from_sources_aggregates_multiple_fragments_in_order() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");

        // User-defined content.
        fs::write(&agent_path, "# Agent\n静态规则\n").unwrap();
        // Memory summary.
        fs::write(&memory_path, "## 用户记忆摘要\n- 偏好简洁\n").unwrap();

        let prompt = load_prompt_from_sources(&standard_sources(&agent_path, &memory_path))
            .unwrap()
            .unwrap();

        // An exact match pins content, order (builtin, AGENT.md, memory summary), the
        // blank-line separator, and the stripping of the files' trailing newlines.
        assert_eq!(
            prompt,
            format!(
                "{}\n\n# Agent\n静态规则\n\n## 用户记忆摘要\n- 偏好简洁",
                DEFAULT_AGENT_PROMPT
            )
        );
    }

    #[test]
    fn test_load_prompt_from_sources_skips_empty_user_custom() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");

        // AGENT.md holds only comments and whitespace.
        fs::write(&agent_path, "<!-- 注释 -->\n\n \n").unwrap();
        // Memory summary.
        fs::write(&memory_path, "## 用户记忆摘要\n").unwrap();

        let prompt = load_prompt_from_sources(&standard_sources(&agent_path, &memory_path))
            .unwrap()
            .unwrap();

        // Only builtin + memory summary; the comment-only AGENT.md contributes nothing.
        assert_eq!(prompt, format!("{}\n\n## 用户记忆摘要", DEFAULT_AGENT_PROMPT));
        assert!(!prompt.contains("<!-- 注释 -->"));
    }

    #[test]
    fn test_load_prompt_from_sources_creates_template_for_missing_agent_md() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");

        // AGENT.md does not exist yet.
        assert!(!agent_path.exists());

        let prompt = load_prompt_from_sources(&standard_sources(&agent_path, &memory_path))
            .unwrap()
            .unwrap();

        // AGENT.md is created from the template.
        assert!(agent_path.exists());
        let agent_content = fs::read_to_string(&agent_path).unwrap();
        assert!(agent_content.contains("# 自定义 Agent 配置"));
        // The template is all comments, so none of it reaches the prompt...
        assert!(!prompt.contains("自定义 Agent 配置"));
        // ...but the builtin prompt is always present.
        assert!(prompt.contains(DEFAULT_AGENT_PROMPT));

        // Real user content is injected, and the file is not reset to the template.
        let custom = "# 我的自定义配置\n\n我是 Rust 开发者\n";
        fs::write(&agent_path, custom).unwrap();
        let prompt_with_custom =
            load_prompt_from_sources(&standard_sources(&agent_path, &memory_path))
                .unwrap()
                .unwrap();
        assert_eq!(
            prompt_with_custom,
            format!("{}\n\n# 我的自定义配置\n我是 Rust 开发者", DEFAULT_AGENT_PROMPT)
        );
        assert_eq!(fs::read_to_string(&agent_path).unwrap(), custom);

        // Reducing AGENT.md to comments only stops the injection again.
        fs::write(&agent_path, "<!-- 全部注释 -->\n\n \n").unwrap();
        let prompt_empty = load_prompt_from_sources(&standard_sources(&agent_path, &memory_path))
            .unwrap()
            .unwrap();
        assert!(!prompt_empty.contains("我的自定义配置"));
        assert_eq!(prompt_empty, DEFAULT_AGENT_PROMPT);
    }

    #[test]
    fn test_load_prompt_from_sources_creates_missing_parent_dirs() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("nested").join("agent").join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");

        load_prompt_from_sources(&standard_sources(&agent_path, &memory_path)).unwrap();

        assert!(agent_path.exists());
    }

    #[test]
    fn test_load_prompt_from_sources_ignores_missing_optional_source() {
        let temp = tempdir().unwrap();
        let agent_path = temp.path().join("AGENT.md");
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");

        // User-defined content; MEMORY_SUMMARY.md is intentionally absent.
        fs::write(&agent_path, "# Agent\n静态规则\n").unwrap();

        let prompt = load_prompt_from_sources(&standard_sources(&agent_path, &memory_path))
            .unwrap()
            .unwrap();

        // Builtin + AGENT.md only, and the optional file is not created by the loader.
        assert_eq!(prompt, format!("{}\n\n# Agent\n静态规则", DEFAULT_AGENT_PROMPT));
        assert!(!memory_path.exists());
    }

    #[test]
    fn test_strip_comments_and_whitespace_removes_html_comments() {
        let content = "<!-- 注释 -->\n实际内容\n<!-- 另一个注释 -->";
        assert_eq!(strip_comments_and_whitespace(content), "实际内容");
    }

    #[test]
    fn test_strip_comments_and_whitespace_removes_multiline_comments() {
        let content = "前\n<!--\n多行\n注释\n-->\n后";
        assert_eq!(strip_comments_and_whitespace(content), "前\n后");
    }

    #[test]
    fn test_strip_comments_and_whitespace_keeps_unterminated_comment() {
        let content = "正文\n<!-- 未闭合的注释\n后续内容";
        assert_eq!(strip_comments_and_whitespace(content), content);
    }

    #[test]
    fn test_strip_comments_and_whitespace_removes_empty_lines() {
        let content = "\n\n实际内容\n\n \n\n另一行\n\n";
        assert_eq!(strip_comments_and_whitespace(content), "实际内容\n另一行");
    }

    #[test]
    fn test_strip_comments_and_whitespace_returns_empty_for_only_comments() {
        let content = "<!-- 注释1 -->\n<!-- 注释2 -->\n\n \n";
        assert!(strip_comments_and_whitespace(content).is_empty());
    }

    #[test]
    fn test_upsert_managed_agent_memory_summary_writes_trimmed_markdown() {
        let temp = tempdir().unwrap();
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");

        persist_memory_summary(&memory_path, "\n## 用户记忆摘要\n- 偏好简洁\n\n").unwrap();

        assert_eq!(
            fs::read_to_string(&memory_path).unwrap(),
            "## 用户记忆摘要\n- 偏好简洁\n"
        );
        // The temporary file used for the replacement must not be left behind.
        assert!(!temp.path().join("MEMORY_SUMMARY.md.tmp").exists());
    }

    #[test]
    fn test_persist_memory_summary_ignores_blank_input() {
        let temp = tempdir().unwrap();
        let memory_path = temp.path().join("MEMORY_SUMMARY.md");

        persist_memory_summary(&memory_path, " \n\t\n").unwrap();

        assert!(!memory_path.exists());
    }
}
|