use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Duration; use serde::{Deserialize, Serialize}; use crate::agent::AgentError; use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig}; use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider}; use crate::storage::{MemoryRecord, SessionStore}; use super::prompt::upsert_managed_agent_memory_summary; const MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_step1_system_prompt.md"); const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_step2_system_prompt.md"); const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000]; const META_NAMESPACE: &str = "_meta"; const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at"; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct MemoryMaintenanceCandidate { pub(crate) id: String, pub(crate) namespace: String, pub(crate) key: String, pub(crate) content: String, pub(crate) updated_at: i64, // 记忆更新时间(Unix timestamp) } #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct MemoryMaintenancePlan { pub(crate) candidates: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct MemoryMaintenanceMerge { pub(crate) source_ids: Vec, pub(crate) namespace: String, pub(crate) memory_key: String, pub(crate) content: String, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct MemoryMaintenanceConflict { pub(crate) source_ids: Vec, pub(crate) note: String, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct MemoryOrganizationOutput { pub(crate) merges: Vec, pub(crate) conflicts: Vec, pub(crate) low_value_ids: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct MemorySummaryInput { pub(crate) organized_memories: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct OrganizedMemory { pub(crate) namespace: String, pub(crate) memory_key: String, pub(crate) content: String, } #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct MemoryMaintenanceScopeResult { pub(crate) scope_key: String, pub(crate) output: MemoryOrganizationOutput, pub(crate) managed_markdown: String, } pub(crate) struct MemoryMaintenanceService { store: Arc, provider_config: LLMProviderConfig, maintenance_config: MemoryMaintenanceConfig, } impl MemoryMaintenanceService { pub(crate) fn new( store: Arc, provider_config: LLMProviderConfig, maintenance_config: MemoryMaintenanceConfig, ) -> Self { Self { store, provider_config, maintenance_config, } } /// 创建记忆整理专用的 provider,使用 memory_maintenance_timeout_secs 作为超时时间 fn create_maintenance_provider( &self, ) -> Result, crate::providers::ProviderError> { let config = &self.provider_config; let runtime_config = ProviderRuntimeConfig { provider_type: config.provider_type.clone(), name: config.name.clone(), base_url: config.base_url.clone(), api_key: config.api_key.clone(), extra_headers: config.extra_headers.clone(), llm_timeout_secs: config.memory_maintenance_timeout_secs, model_id: config.model_id.clone(), temperature: config.temperature, max_tokens: config.max_tokens, model_extra: config.model_extra.clone(), }; create_provider(runtime_config) } pub(crate) fn build_plan_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { // 新增:检查是否有新记忆需要整理 if !has_new_memories_since_last_maintenance(&self.store, scope_key)? { tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping"); return Ok(None); } let memories = self .store .list_memories_for_scope("user", scope_key) .map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?; if memories.is_empty() { return Ok(None); } // 记忆数量不足最小保留数时,无需整理,直接跳过 // 避免浪费 LLM token 并触发无意义的 "保留数不足" 错误 if memories.len() < self.maintenance_config.min_memories_to_keep { tracing::info!( scope_key = %scope_key, count = memories.len(), min_required = self.maintenance_config.min_memories_to_keep, "Skipping scope: not enough memories to organize" ); return Ok(None); } Ok(Some(build_memory_maintenance_plan(&memories))) } pub(crate) async fn organize_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { let Some(plan) = self.build_plan_for_scope(scope_key)? else { return Ok(None); }; self.organize_plan(scope_key, &plan).await.map(Some) } async fn organize_plan( &self, scope_key: &str, plan: &MemoryMaintenancePlan, ) -> Result { let provider = self.create_maintenance_provider().map_err(|err| { AgentError::Other(format!("create maintenance provider error: {}", err)) })?; let request = ChatCompletionRequest { messages: vec![ Message::system(MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT), Message::user( serde_json::to_string_pretty(&serde_json::json!({ "scope_key": scope_key, "candidates": plan, })) .unwrap_or_else(|_| "{}".to_string()), ), ], temperature: Some(0.0), max_tokens: Some(4000), tools: None, }; let mut last_error = None; for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS .iter() .copied() .map(Some) .chain(std::iter::once(None)) .enumerate() { let response = match provider.chat(request.clone()).await { Ok(success) => success, Err(err) => { let error_text = err.to_string(); let should_retry = delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text); last_error = Some(error_text.clone()); if should_retry { tracing::warn!( scope_key = %scope_key, attempt = attempt + 1, retry_in_ms = delay_ms.unwrap_or_default(), error = %error_text, "Memory organization model request failed, retrying" ); tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())) .await; continue; } return Err(AgentError::Other(format!( "memory organization model error: {}", error_text ))); } }; let raw_content = strip_json_code_fence(&response.content); let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content); match serde_json::from_str::(json_candidate) { Ok(parsed) => return Ok(parsed), Err(err) => { let error_msg = err.to_string(); let is_truncated = error_msg.contains("EOF while parsing") || error_msg.contains("expected"); let should_retry = delay_ms.is_some() && is_truncated; last_error = Some(error_msg.clone()); if should_retry { tracing::warn!( scope_key = %scope_key, attempt = attempt + 1, retry_in_ms = delay_ms.unwrap_or_default(), error = %error_msg, raw_len = raw_content.len(), "Memory organization JSON parse failed (possibly truncated), retrying" ); tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())) .await; continue; } tracing::error!( scope_key = %scope_key, error = %err, raw_len = raw_content.len(), raw_preview = %preview_text(raw_content, 400), json_candidate_len = json_candidate.len(), json_candidate_preview = %preview_text(json_candidate, 400), "Memory maintenance JSON decode failed" ); return Err(AgentError::Other(format!( "memory maintenance JSON decode error: {}", err ))); } } } Err(AgentError::Other(format!( "memory organization failed after retries: {}", last_error.unwrap_or_else(|| "unknown error".to_string()) ))) } #[cfg_attr(not(test), allow(dead_code))] pub(crate) async fn run_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { let Some(plan) = self.build_plan_for_scope(scope_key)? else { return Ok(None); }; // 步骤1:整理记忆(不生成摘要) let organize_output = self.organize_plan(scope_key, &plan).await?; // 应用整理结果(merge和delete) apply_memory_maintenance_output( self.store.as_ref(), scope_key, &plan, &organize_output, self.maintenance_config.max_merge_ratio, self.maintenance_config.min_memories_to_keep, self.maintenance_config.max_merge_per_group, )?; // 步骤2:从数据库重新读取剩余的记忆 let remaining_memories = self .store .list_memories_for_scope("user", scope_key) .map_err(|err| { AgentError::Other(format!("list remaining memories error: {}", err)) })?; // 步骤2:生成摘要 let managed_markdown = if remaining_memories.is_empty() { String::new() } else { self.generate_summary(scope_key, &remaining_memories).await? }; Ok(Some(MemoryMaintenanceScopeResult { scope_key: scope_key.to_string(), output: organize_output, managed_markdown, })) } async fn generate_summary( &self, scope_key: &str, remaining_memories: &[MemoryRecord], ) -> Result { let provider = self.create_maintenance_provider().map_err(|err| { AgentError::Other(format!("create maintenance provider error: {}", err)) })?; let input = MemorySummaryInput { organized_memories: remaining_memories .iter() .map(|m| OrganizedMemory { namespace: m.namespace.clone(), memory_key: m.memory_key.clone(), content: m.content.clone(), }) .collect(), }; let request = ChatCompletionRequest { messages: vec![ Message::system(MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT), Message::user( serde_json::to_string_pretty(&serde_json::json!(input)) .unwrap_or_else(|_| "{}".to_string()), ), ], temperature: Some(0.0), max_tokens: Some(1000), tools: None, }; let mut last_error = None; let mut response = None; for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS .iter() .copied() .map(Some) .chain(std::iter::once(None)) .enumerate() { match provider.chat(request.clone()).await { Ok(success) => { response = Some(success); break; } Err(err) => { let error_text = err.to_string(); let should_retry = delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text); last_error = Some(error_text.clone()); if should_retry { tracing::warn!( scope_key = %scope_key, attempt = attempt + 1, retry_in_ms = delay_ms.unwrap_or_default(), error = %error_text, "Memory summary model request failed, retrying" ); tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())) .await; continue; } return Err(AgentError::Other(format!( "memory summary model error: {}", error_text ))); } } } let response = response.ok_or_else(|| { AgentError::Other(format!( "memory summary model error: {}", last_error.unwrap_or_else(|| "unknown provider error".to_string()) )) })?; let content = response.content.trim(); // 确保响应包含标签,如无则自动添加 let tagged_content = if content.contains("") { content.to_string() } else { format!( "\n{}\n", content ) }; Ok(tagged_content) } pub(crate) async fn run_for_all_scopes( &self, ) -> Result, AgentError> { let scope_keys = self.store.list_memory_scope_keys("user").map_err(|err| { AgentError::Other(format!("list memory scope keys error: {}", err)) })?; if scope_keys.is_empty() { return Ok(None); } // 步骤1:逐个 scope 整理记忆(merge/delete),但不生成摘要 let mut all_outputs = Vec::new(); for scope_key in &scope_keys { match self.run_organize_for_scope(scope_key).await { Ok(Some(output)) => all_outputs.push((scope_key.clone(), output)), Ok(None) => continue, Err(error) if is_recoverable_maintenance_scope_error(&error) => { tracing::warn!( scope_key = %scope_key, error = %error, "Memory maintenance skipped scope after recoverable model failure" ); continue; } Err(error) => return Err(error), } } if all_outputs.is_empty() { return Ok(None); } // 步骤2:收集所有 scope 整理后的剩余记忆 let mut all_remaining_memories = Vec::new(); for (scope_key, _) in &all_outputs { let memories = self .store .list_memories_for_scope("user", scope_key) .map_err(|err| { AgentError::Other(format!( "list remaining memories for scope {} error: {}", scope_key, err )) })?; all_remaining_memories.extend(memories); } // 步骤3:统一生成一个摘要 let managed_markdown = if all_remaining_memories.is_empty() { String::new() } else { self.generate_summary("all", &all_remaining_memories).await? }; if !managed_markdown.is_empty() { upsert_managed_agent_memory_summary(&managed_markdown)?; } // 合并所有输出用于返回 let combined_output = MemoryOrganizationOutput { merges: all_outputs .iter() .flat_map(|(_, o)| o.merges.clone()) .collect(), conflicts: all_outputs .iter() .flat_map(|(_, o)| o.conflicts.clone()) .collect(), low_value_ids: all_outputs .iter() .flat_map(|(_, o)| o.low_value_ids.clone()) .collect(), }; Ok(Some(MemoryMaintenanceScopeResult { scope_key: "all".to_string(), output: combined_output, managed_markdown, })) } /// 仅执行整理步骤(organize + apply),不生成摘要 async fn run_organize_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { let Some(plan) = self.build_plan_for_scope(scope_key)? else { return Ok(None); }; // 步骤1:整理记忆 let organize_output = self.organize_plan(scope_key, &plan).await?; // 应用整理结果(merge和delete) apply_memory_maintenance_output( self.store.as_ref(), scope_key, &plan, &organize_output, self.maintenance_config.max_merge_ratio, self.maintenance_config.min_memories_to_keep, self.maintenance_config.max_merge_per_group, )?; Ok(Some(organize_output)) } } /// 获取上次整理时间 fn get_last_maintenance_time( store: &SessionStore, scope_key: &str, ) -> Result, crate::storage::StorageError> { let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?; Ok(meta.and_then(|m| m.content.parse::().ok())) } /// 记录本次整理时间 fn set_last_maintenance_time( store: &SessionStore, scope_key: &str, time: i64, ) -> Result<(), crate::storage::StorageError> { store.put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: scope_key.to_string(), namespace: META_NAMESPACE.to_string(), memory_key: LAST_MAINTENANCE_KEY.to_string(), content: time.to_string(), source_type: "memory_maintenance".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, })?; Ok(()) } /// 检查是否有需要整理的新记忆(过滤掉 _meta namespace) fn has_new_memories_since_last_maintenance( store: &SessionStore, scope_key: &str, ) -> Result { let memories = store .list_memories_for_scope("user", scope_key) .map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?; // 过滤掉 _meta namespace 的记忆 let user_memories: Vec<_> = memories .iter() .filter(|m| m.namespace != META_NAMESPACE) .collect(); if user_memories.is_empty() { return Ok(false); // 没有记忆,跳过 } // 获取上次整理时间 let last_time = get_last_maintenance_time(store, scope_key) .map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?; match last_time { None => Ok(true), // 从未整理过,需要整理 Some(last) => { // 检查是否有记忆在上次整理后更新 let has_new = user_memories.iter().any(|m| m.updated_at > last); Ok(has_new) } } } pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan { let mut plan = MemoryMaintenancePlan::default(); let mut seen = HashSet::new(); for memory in memories { // 过滤掉 _meta namespace 的记忆 if memory.namespace == META_NAMESPACE { continue; } let normalized_content = memory.content.trim(); if normalized_content.is_empty() { continue; } let dedupe_key = format!( "{}\u{1f}{}\u{1f}{}", memory.namespace.trim().to_ascii_lowercase(), memory.memory_key.trim().to_ascii_lowercase(), normalized_content ); if !seen.insert(dedupe_key) { continue; } let candidate = MemoryMaintenanceCandidate { id: memory.id.clone(), namespace: memory.namespace.clone(), key: memory.memory_key.clone(), content: normalized_content.to_string(), updated_at: memory.updated_at, }; plan.candidates.push(candidate); } plan } pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool { let normalized = error.to_ascii_lowercase(); normalized.contains("error sending request for url") || normalized.contains("504") || normalized.contains("gateway timeout") || normalized.contains("stream timeout") || normalized.contains("timed out") || normalized.contains("timeout") // 验证拒绝 — 记忆太少,跳过本次 scope 即可,不应视为作业失败 || error.contains("保留数不足") || error.contains("合并比例超限") } fn is_recoverable_maintenance_scope_error(error: &AgentError) -> bool { is_recoverable_maintenance_llm_error(&error.to_string()) } pub(crate) fn strip_json_code_fence(content: &str) -> &str { let trimmed = content.trim(); if let Some(rest) = trimmed.strip_prefix("```json") { return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed); } if let Some(rest) = trimmed.strip_prefix("```") { return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed); } trimmed } pub(crate) fn extract_json_object(content: &str) -> Option<&str> { let mut start = None; let mut depth = 0usize; let mut in_string = false; let mut escaped = false; for (index, ch) in content.char_indices() { if in_string { if escaped { escaped = false; continue; } match ch { '\\' => escaped = true, '"' => in_string = false, _ => {} } continue; } match ch { '"' => in_string = true, '{' => { if start.is_none() { start = Some(index); } depth += 1; } '}' => { if depth == 0 { continue; } depth -= 1; if depth == 0 { let start = start?; let end = index + ch.len_utf8(); return Some(content[start..end].trim()); } } _ => {} } } None } /// 验证记忆整理输出是否符合限制 pub(crate) fn validate_memory_maintenance_output( plan: &MemoryMaintenancePlan, output: &MemoryOrganizationOutput, max_merge_ratio: f32, min_memories_to_keep: usize, max_merge_per_group: usize, ) -> Result<(), String> { let total = plan.candidates.len(); if total == 0 { return Ok(()); // 没有候选,无需验证 } // 验证 1: 单次合并数量限制 for merge in &output.merges { if merge.source_ids.len() > max_merge_per_group { return Err(format!( "合并组过大: {} 条源记忆超过上限 {}", merge.source_ids.len(), max_merge_per_group )); } } // 验证 2: 跨 namespace 合并检测(完全禁止) let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan .candidates .iter() .map(|c| (c.id.as_str(), c)) .collect(); for merge in &output.merges { let source_namespaces: HashSet<&str> = merge .source_ids .iter() .filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str())) .collect(); // 检查是否跨越多个 namespace if source_namespaces.len() > 1 { return Err(format!( "跨 namespace 合并被禁止: 源来自 {}", source_namespaces.iter().cloned().collect::>().join(", ") )); } // 检查目标 namespace 是否与源一致 if let Some(src_ns) = source_namespaces.iter().next() { if *src_ns != merge.namespace { return Err(format!( "跨 namespace 合并被禁止: {} → {}", src_ns, merge.namespace )); } } } // 验证 3: 总体合并比例 let merged_ids: HashSet<&str> = output .merges .iter() .flat_map(|m| m.source_ids.iter()) .map(|s| s.as_str()) .collect(); let deleted_ids: HashSet<&str> = output .low_value_ids .iter() .map(|s| s.as_str()) .collect(); let affected = merged_ids.len() + deleted_ids.len(); let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize; if affected > max_allowed { return Err(format!( "合并比例超限: {} / {} > {:.0}%", affected, total, max_merge_ratio * 100.0 )); } // 验证 4: 最小保留数 let remaining = total - affected + output.merges.len(); if remaining < min_memories_to_keep { return Err(format!( "保留数不足: {} < {}", remaining, min_memories_to_keep )); } Ok(()) } pub(crate) fn apply_memory_maintenance_output( store: &SessionStore, scope_key: &str, plan: &MemoryMaintenancePlan, output: &MemoryOrganizationOutput, max_merge_ratio: f32, min_memories_to_keep: usize, max_merge_per_group: usize, ) -> Result<(), AgentError> { // 新增: 验证合并输出 validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group) .map_err(|e| AgentError::Other(e))?; let all_candidates = plan.candidates.clone(); let candidates_by_id = all_candidates .iter() .map(|candidate| (candidate.id.as_str(), candidate)) .collect::>(); let mut deleted_ids = HashSet::new(); for merge in &output.merges { if merge.source_ids.is_empty() { continue; } let source_candidates = merge .source_ids .iter() .filter_map(|id| candidates_by_id.get(id.as_str()).copied()) .collect::>(); if source_candidates.is_empty() { continue; } let existing_target_id = source_candidates .iter() .find(|candidate| { candidate.namespace == merge.namespace && candidate.key == merge.memory_key }) .map(|candidate| candidate.id.clone()); store .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: scope_key.to_string(), namespace: merge.namespace.trim().to_string(), memory_key: merge.memory_key.trim().to_string(), content: merge.content.trim().to_string(), source_type: "memory_maintenance".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?; for candidate in source_candidates { if existing_target_id .as_ref() .is_some_and(|target_id| target_id == &candidate.id) { continue; } if deleted_ids.insert(candidate.id.clone()) { store .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) .map_err(|err| { AgentError::Other(format!("delete merged source memory error: {}", err)) })?; } } } for memory_id in &output.low_value_ids { if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) { if deleted_ids.insert(candidate.id.clone()) { store .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) .map_err(|err| { AgentError::Other(format!("delete low value memory error: {}", err)) })?; } } } // 新增:记录整理完成时间 let now = chrono::Utc::now().timestamp(); set_last_maintenance_time(store, scope_key, now) .map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?; Ok(()) } fn preview_text(content: &str, max_chars: usize) -> String { let mut preview = content.chars().take(max_chars).collect::(); if content.chars().count() > max_chars { preview.push_str("..."); } preview.replace('\n', "\\n") } #[cfg(test)] mod tests {}