PicoBot/src/gateway/memory_maintenance.rs

654 lines
22 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::agent::{AgentError, AgentRuntimeConfig};
use crate::config::LLMProviderConfig;
use crate::providers::{ChatCompletionRequest, Message, create_provider};
use crate::storage::{MemoryRecord, SessionStore};
use super::prompt::upsert_managed_agent_memory_summary;
/// System prompt for step 1 (memory organization), embedded at compile time.
const MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT: &str =
include_str!("memory_maintenance_step1_system_prompt.md");
/// System prompt for step 2 (summary generation), embedded at compile time.
const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
include_str!("memory_maintenance_step2_system_prompt.md");
/// Delays in milliseconds slept before each retry of a recoverable model error.
/// With two entries, a model request is attempted at most three times.
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
/// Bucket a memory is placed into within a [`MemoryMaintenancePlan`], chosen from the
/// memory's namespace by `memory_maintenance_category`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryMaintenanceCategory {
/// Namespaces `profile`, `facts`, `identity`.
UserFacts,
/// Namespaces `preferences`, `style`, `likes`.
Preferences,
/// Namespaces `patterns`, `behavior`, `habits`, `workflow`.
BehaviorPatterns,
/// Any other namespace.
Other,
}
/// One stored memory as presented to the step 1 organization model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate {
/// Store record id; the model refers to candidates by this id in merges and low-value lists.
pub(crate) id: String,
/// Namespace exactly as stored (not trimmed).
pub(crate) namespace: String,
/// Memory key exactly as stored (not trimmed).
pub(crate) key: String,
/// Memory content with surrounding whitespace trimmed.
pub(crate) content: String,
}
/// Non-blank, de-duplicated memories of one scope, grouped by category.
///
/// Built by `build_memory_maintenance_plan`; each bucket keeps input order. The whole plan is
/// serialized as the `candidates` field of the step 1 prompt.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenancePlan {
pub(crate) user_facts: Vec<MemoryMaintenanceCandidate>,
pub(crate) preferences: Vec<MemoryMaintenanceCandidate>,
pub(crate) behavior_patterns: Vec<MemoryMaintenanceCandidate>,
pub(crate) others: Vec<MemoryMaintenanceCandidate>,
}
/// A merge proposed by the step 1 model: several candidates collapsed into one memory.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceMerge {
/// Candidate ids being merged; ids not present in the plan are ignored when applying.
pub(crate) source_ids: Vec<String>,
/// Target namespace (trimmed before it is written).
pub(crate) namespace: String,
/// Target memory key (trimmed before it is written).
pub(crate) memory_key: String,
/// Merged content (trimmed before it is written).
pub(crate) content: String,
}
/// Conflicting candidates reported by the step 1 model.
///
/// `apply_memory_maintenance_output` does not act on conflicts; they are only carried in the
/// returned [`MemoryOrganizationOutput`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceConflict {
/// Candidate ids involved in the conflict.
pub(crate) source_ids: Vec<String>,
/// Model-written description of the conflict.
pub(crate) note: String,
}
/// Step 1 model reply, decoded from the JSON object in the model's response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryOrganizationOutput {
/// Merges to upsert; their sources are deleted afterwards.
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
/// Reported conflicts (informational only).
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
/// Candidate ids the model judged not worth keeping; they are deleted.
pub(crate) low_value_ids: Vec<String>,
}
/// Payload serialized as the user message of the step 2 (summary) request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemorySummaryInput {
pub(crate) scope_key: String,
/// Memories re-read from the store after step 1 was applied.
pub(crate) organized_memories: Vec<OrganizedMemory>,
}
/// One memory remaining after step 1, as sent to the summary model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct OrganizedMemory {
pub(crate) namespace: String,
pub(crate) memory_key: String,
pub(crate) content: String,
}
/// Outcome of running both maintenance steps for one scope.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceScopeResult {
pub(crate) scope_key: String,
/// The step 1 organization result that was applied to the store.
pub(crate) output: MemoryOrganizationOutput,
/// Marker-tagged step 2 summary; empty when no memories remained to summarize.
pub(crate) managed_markdown: String,
}
/// Runs two-step, model-assisted maintenance over stored `"user"` memories.
pub(crate) struct MemoryMaintenanceService {
/// Memory store that is read, rewritten, and pruned.
store: Arc<SessionStore>,
/// Provider configuration; a model client is created from it for each maintenance step.
provider_config: LLMProviderConfig,
}
impl MemoryMaintenanceService {
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
Self {
store,
provider_config,
}
}
pub(crate) fn build_plan_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
let memories = self
.store
.list_memories_for_scope("user", scope_key)
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
if memories.is_empty() {
return Ok(None);
}
Ok(Some(build_memory_maintenance_plan(&memories)))
}
pub(crate) async fn organize_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
return Ok(None);
};
self.organize_plan(scope_key, &plan).await.map(Some)
}
async fn organize_plan(
&self,
scope_key: &str,
plan: &MemoryMaintenancePlan,
) -> Result<MemoryOrganizationOutput, AgentError> {
let runtime_config = AgentRuntimeConfig::from(self.provider_config.clone());
let provider = create_provider(runtime_config.provider).map_err(|err| {
AgentError::Other(format!("create maintenance provider error: {}", err))
})?;
let request = ChatCompletionRequest {
messages: vec![
Message::system(MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT),
Message::user(
serde_json::to_string_pretty(&serde_json::json!({
"scope_key": scope_key,
"candidates": plan,
}))
.unwrap_or_else(|_| "{}".to_string()),
),
],
temperature: Some(0.0),
max_tokens: Some(1200),
tools: None,
};
let mut last_error = None;
let mut response = None;
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
.iter()
.copied()
.map(Some)
.chain(std::iter::once(None))
.enumerate()
{
match provider.chat(request.clone()).await {
Ok(success) => {
response = Some(success);
break;
}
Err(err) => {
let error_text = err.to_string();
let should_retry =
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
last_error = Some(error_text.clone());
if should_retry {
tracing::warn!(
scope_key = %scope_key,
attempt = attempt + 1,
retry_in_ms = delay_ms.unwrap_or_default(),
error = %error_text,
"Memory organization model request failed, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
.await;
continue;
}
return Err(AgentError::Other(format!(
"memory organization model error: {}",
error_text
)));
}
}
}
let response = response.ok_or_else(|| {
AgentError::Other(format!(
"memory organization model error: {}",
last_error.unwrap_or_else(|| "unknown provider error".to_string())
))
})?;
let raw_content = strip_json_code_fence(&response.content);
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
let output: MemoryOrganizationOutput =
serde_json::from_str(json_candidate).map_err(|err| {
tracing::error!(
scope_key = %scope_key,
error = %err,
raw_len = raw_content.len(),
raw_preview = %preview_text(raw_content, 400),
json_candidate_len = json_candidate.len(),
json_candidate_preview = %preview_text(json_candidate, 400),
"Memory maintenance JSON decode failed"
);
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
})?;
Ok(output)
}
pub(crate) async fn run_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
return Ok(None);
};
// 步骤1整理记忆不生成摘要
let organize_output = self.organize_plan(scope_key, &plan).await?;
// 应用整理结果merge和delete
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
// 步骤2从数据库重新读取剩余的记忆
let remaining_memories = self
.store
.list_memories_for_scope("user", scope_key)
.map_err(|err| {
AgentError::Other(format!("list remaining memories error: {}", err))
})?;
// 步骤2生成摘要
let managed_markdown = if remaining_memories.is_empty() {
String::new()
} else {
self.generate_summary(scope_key, &remaining_memories).await?
};
Ok(Some(MemoryMaintenanceScopeResult {
scope_key: scope_key.to_string(),
output: organize_output,
managed_markdown,
}))
}
async fn generate_summary(
&self,
scope_key: &str,
remaining_memories: &[MemoryRecord],
) -> Result<String, AgentError> {
let runtime_config = AgentRuntimeConfig::from(self.provider_config.clone());
let provider = create_provider(runtime_config.provider).map_err(|err| {
AgentError::Other(format!("create maintenance provider error: {}", err))
})?;
let input = MemorySummaryInput {
scope_key: scope_key.to_string(),
organized_memories: remaining_memories
.iter()
.map(|m| OrganizedMemory {
namespace: m.namespace.clone(),
memory_key: m.memory_key.clone(),
content: m.content.clone(),
})
.collect(),
};
let request = ChatCompletionRequest {
messages: vec![
Message::system(MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT),
Message::user(
serde_json::to_string_pretty(&serde_json::json!(input))
.unwrap_or_else(|_| "{}".to_string()),
),
],
temperature: Some(0.0),
max_tokens: Some(1000),
tools: None,
};
let mut last_error = None;
let mut response = None;
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
.iter()
.copied()
.map(Some)
.chain(std::iter::once(None))
.enumerate()
{
match provider.chat(request.clone()).await {
Ok(success) => {
response = Some(success);
break;
}
Err(err) => {
let error_text = err.to_string();
let should_retry =
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
last_error = Some(error_text.clone());
if should_retry {
tracing::warn!(
scope_key = %scope_key,
attempt = attempt + 1,
retry_in_ms = delay_ms.unwrap_or_default(),
error = %error_text,
"Memory summary model request failed, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
.await;
continue;
}
return Err(AgentError::Other(format!(
"memory summary model error: {}",
error_text
)));
}
}
}
let response = response.ok_or_else(|| {
AgentError::Other(format!(
"memory summary model error: {}",
last_error.unwrap_or_else(|| "unknown provider error".to_string())
))
})?;
let content = response.content.trim();
// 确保响应包含标签,如无则自动添加
let tagged_content = if content.contains("<!-- memory_summary_start -->") {
content.to_string()
} else {
format!(
"<!-- memory_summary_start -->\n{}\n<!-- memory_summary_end -->",
content
)
};
Ok(tagged_content)
}
pub(crate) async fn run_for_all_scopes(
&self,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
let scope_keys = self.store.list_memory_scope_keys("user").map_err(|err| {
AgentError::Other(format!("list memory scope keys error: {}", err))
})?;
let mut results = Vec::new();
for scope_key in scope_keys {
let result = match self.run_for_scope(&scope_key).await {
Ok(Some(result)) => result,
Ok(None) => continue,
Err(error) if is_recoverable_maintenance_scope_error(&error) => {
tracing::warn!(
scope_key = %scope_key,
error = %error,
"Memory maintenance skipped scope after recoverable model failure"
);
continue;
}
Err(error) => return Err(error),
};
results.push(result);
}
let combined_markdown = combine_managed_memory_markdown(
&results
.iter()
.map(|result| result.managed_markdown.clone())
.collect::<Vec<_>>(),
);
if !combined_markdown.is_empty() {
upsert_managed_agent_memory_summary(&combined_markdown)?;
}
Ok(results)
}
}
/// Groups `memories` into per-category maintenance candidates.
///
/// Each memory's content is trimmed, and blank memories are dropped. A later memory is also
/// dropped when its namespace and key (trimmed, ASCII-lowercased) and its trimmed content
/// repeat an earlier one. Surviving candidates keep the record's id, namespace, and key
/// unchanged. They are bucketed by `memory_maintenance_category` of their namespace,
/// preserving input order.
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
    let mut plan = MemoryMaintenancePlan::default();
    let mut fingerprints = HashSet::new();
    for record in memories {
        let content = record.content.trim();
        if content.is_empty() {
            continue;
        }
        // U+001F (unit separator) keeps the three parts from running into each other.
        let fingerprint = format!(
            "{}\u{1f}{}\u{1f}{}",
            record.namespace.trim().to_ascii_lowercase(),
            record.memory_key.trim().to_ascii_lowercase(),
            content
        );
        if !fingerprints.insert(fingerprint) {
            continue;
        }
        let bucket = match memory_maintenance_category(&record.namespace) {
            MemoryMaintenanceCategory::UserFacts => &mut plan.user_facts,
            MemoryMaintenanceCategory::Preferences => &mut plan.preferences,
            MemoryMaintenanceCategory::BehaviorPatterns => &mut plan.behavior_patterns,
            MemoryMaintenanceCategory::Other => &mut plan.others,
        };
        bucket.push(MemoryMaintenanceCandidate {
            id: record.id.clone(),
            namespace: record.namespace.clone(),
            key: record.memory_key.clone(),
            content: content.to_string(),
        });
    }
    plan
}
/// Maps a memory namespace to its maintenance category.
///
/// Matching ignores surrounding whitespace and ASCII case. Unrecognized namespaces map to
/// [`MemoryMaintenanceCategory::Other`].
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
    const NAMESPACE_TABLE: [(&[&str], MemoryMaintenanceCategory); 3] = [
        (
            &["profile", "facts", "identity"],
            MemoryMaintenanceCategory::UserFacts,
        ),
        (
            &["preferences", "style", "likes"],
            MemoryMaintenanceCategory::Preferences,
        ),
        (
            &["patterns", "behavior", "habits", "workflow"],
            MemoryMaintenanceCategory::BehaviorPatterns,
        ),
    ];
    let normalized = namespace.trim().to_ascii_lowercase();
    NAMESPACE_TABLE
        .iter()
        .find(|(names, _)| names.contains(&normalized.as_str()))
        .map_or(MemoryMaintenanceCategory::Other, |&(_, category)| category)
}
/// Returns `true` when a provider error message looks transient: a transport failure, an
/// HTTP 504 / gateway timeout, or any other timeout.
///
/// Matching is an ASCII case-insensitive substring search. `"504"` therefore matches anywhere
/// in the text, including inside unrelated numbers.
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
    const TRANSIENT_MARKERS: [&str; 6] = [
        "error sending request for url",
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];
    let lowered = error.to_ascii_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
/// Returns `true` when the display text of a per-scope error is classified as transient by
/// [`is_recoverable_maintenance_llm_error`].
fn is_recoverable_maintenance_scope_error(error: &AgentError) -> bool {
    let message = error.to_string();
    is_recoverable_maintenance_llm_error(&message)
}
/// Removes a surrounding Markdown code fence (```` ```json ```` or bare ```` ``` ````) from a
/// model reply.
///
/// The input is trimmed first. If it starts with a fence but does not end with ```` ``` ````,
/// the trimmed input is returned unchanged. Other language tags are not stripped
/// (e.g. ```` ```JSON ```` leaves `JSON` in front of the body).
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
    let trimmed = content.trim();
    trimmed
        .strip_prefix("```json")
        .or_else(|| trimmed.strip_prefix("```"))
        .and_then(|inner| inner.strip_suffix("```"))
        .map_or(trimmed, str::trim)
}
/// Returns the first balanced `{ ... }` object in `content`, or `None` if no object closes.
///
/// Double-quoted strings (with backslash escapes) are skipped everywhere in the text, so
/// braces inside them do not count. A `}` seen before any `{` is ignored.
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
    let mut object_start: Option<usize> = None;
    let mut nesting = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;
    for (pos, c) in content.char_indices() {
        if inside_string {
            match (after_backslash, c) {
                (true, _) => after_backslash = false,
                (false, '\\') => after_backslash = true,
                (false, '"') => inside_string = false,
                _ => {}
            }
            continue;
        }
        match c {
            '"' => inside_string = true,
            '{' => {
                object_start.get_or_insert(pos);
                nesting += 1;
            }
            '}' if nesting > 0 => {
                nesting -= 1;
                if nesting == 0 {
                    let begin = object_start?;
                    return Some(content[begin..pos + c.len_utf8()].trim());
                }
            }
            _ => {}
        }
    }
    None
}
/// Merges per-scope summary chunks into one Markdown document.
///
/// Chunks are trimmed, and blank ones are ignored. A chunk is dropped when its set of
/// non-blank, trimmed lines is a strict subset of another chunk's set, or when its text
/// exactly repeats a chunk already kept. The surviving chunks keep their input order and are
/// joined with a blank line.
pub(crate) fn combine_managed_memory_markdown(chunks: &[String]) -> String {
    let normalized_chunks = chunks
        .iter()
        .map(|chunk| chunk.trim())
        .filter(|chunk| !chunk.is_empty())
        .collect::<Vec<_>>();
    // Build each chunk's line set once, rather than once per (chunk, other) pair.
    let line_sets = normalized_chunks
        .iter()
        .map(|chunk| {
            chunk
                .lines()
                .map(str::trim)
                .filter(|line| !line.is_empty())
                .collect::<HashSet<_>>()
        })
        .collect::<Vec<_>>();
    let mut combined: Vec<&str> = Vec::new();
    for (index, (&chunk, chunk_lines)) in normalized_chunks.iter().zip(&line_sets).enumerate() {
        let is_strict_subset_of_other =
            line_sets
                .iter()
                .enumerate()
                .any(|(other_index, other_lines)| {
                    other_index != index
                        && chunk_lines.len() < other_lines.len()
                        && chunk_lines.is_subset(other_lines)
                });
        if !is_strict_subset_of_other && !combined.contains(&chunk) {
            combined.push(chunk);
        }
    }
    combined.join("\n\n")
}
pub(crate) fn apply_memory_maintenance_output(
store: &SessionStore,
scope_key: &str,
plan: &MemoryMaintenancePlan,
output: &MemoryOrganizationOutput,
) -> Result<(), AgentError> {
let all_candidates = plan
.user_facts
.iter()
.chain(plan.preferences.iter())
.chain(plan.behavior_patterns.iter())
.chain(plan.others.iter())
.cloned()
.collect::<Vec<_>>();
let candidates_by_id = all_candidates
.iter()
.map(|candidate| (candidate.id.as_str(), candidate))
.collect::<HashMap<_, _>>();
let mut deleted_ids = HashSet::new();
for merge in &output.merges {
if merge.source_ids.is_empty() {
continue;
}
let source_candidates = merge
.source_ids
.iter()
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
.collect::<Vec<_>>();
if source_candidates.is_empty() {
continue;
}
let existing_target_id = source_candidates
.iter()
.find(|candidate| {
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
})
.map(|candidate| candidate.id.clone());
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: merge.namespace.trim().to_string(),
memory_key: merge.memory_key.trim().to_string(),
content: merge.content.trim().to_string(),
source_type: "memory_maintenance".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
for candidate in source_candidates {
if existing_target_id
.as_ref()
.is_some_and(|target_id| target_id == &candidate.id)
{
continue;
}
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete merged source memory error: {}", err))
})?;
}
}
}
for memory_id in &output.low_value_ids {
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete low value memory error: {}", err))
})?;
}
}
}
Ok(())
}
/// Returns at most `max_chars` characters of `content` for log previews.
///
/// `...` is appended when the text was cut, and newlines are escaped as the two characters
/// `\n` so the preview stays on one log line.
fn preview_text(content: &str, max_chars: usize) -> String {
    let mut remaining = content.chars();
    let mut snippet: String = remaining.by_ref().take(max_chars).collect();
    // Any character left over means the content was longer than the preview.
    if remaining.next().is_some() {
        snippet.push_str("...");
    }
    snippet.replace('\n', "\\n")
}