use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;

use serde::{Deserialize, Serialize};

use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
use crate::storage::{MemoryRecord, SessionStore};

use super::prompt::upsert_managed_agent_memory_summary;

/// System prompt for the step-1 request that organizes memories into merges, conflicts and
/// low-value ids.
const MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT: &str =
    include_str!("memory_maintenance_step1_system_prompt.md");
/// System prompt for the step-2 request that summarizes the memories left after step 1.
const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
    include_str!("memory_maintenance_step2_system_prompt.md");
/// Delay in milliseconds before each retry of a recoverable model failure. One final attempt
/// (with no further retry) follows the last delay.
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];

/// One stored memory offered to the organization model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate {
    /// Store id of the memory; the model refers to candidates by this id.
    pub(crate) id: String,
    pub(crate) namespace: String,
    pub(crate) key: String,
    pub(crate) content: String,
}

/// The set of candidates sent to the organization model for one scope.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenancePlan {
    pub(crate) candidates: Vec<MemoryMaintenanceCandidate>,
}

/// A model-proposed merge: the memories in `source_ids` are replaced by one memory stored under
/// `namespace` / `memory_key` with `content`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceMerge {
    pub(crate) source_ids: Vec<String>,
    pub(crate) namespace: String,
    pub(crate) memory_key: String,
    pub(crate) content: String,
}

/// A model-reported conflict between memories, kept for reporting only.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceConflict {
    pub(crate) source_ids: Vec<String>,
    pub(crate) note: String,
}

/// Decoded step-1 model reply.
///
/// Every list defaults to empty when the model omits it, so a reply such as
/// `{"merges": []}` still decodes instead of failing the whole maintenance run.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryOrganizationOutput {
    #[serde(default)]
    pub(crate) merges: Vec<MemoryMaintenanceMerge>,
    #[serde(default)]
    pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
    /// Ids of candidates the model judged not worth keeping.
    #[serde(default)]
    pub(crate) low_value_ids: Vec<String>,
}

/// Step-2 request payload: the memories that remain after organization.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemorySummaryInput {
    pub(crate) organized_memories: Vec<OrganizedMemory>,
}

/// One remaining memory as presented to the summary model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct OrganizedMemory {
    pub(crate) namespace: String,
    pub(crate) memory_key: String,
    pub(crate) content: String,
}

// Derives for `MemoryMaintenanceScopeResult`, which is declared immediately after this line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceScopeResult { pub(crate) scope_key: String, pub(crate) output: MemoryOrganizationOutput, pub(crate) managed_markdown: String, } pub(crate) struct MemoryMaintenanceService { store: Arc, provider_config: LLMProviderConfig, } impl MemoryMaintenanceService { pub(crate) fn new(store: Arc, provider_config: LLMProviderConfig) -> Self { Self { store, provider_config, } } /// 创建记忆整理专用的 provider,使用 memory_maintenance_timeout_secs 作为超时时间 fn create_maintenance_provider( &self, ) -> Result, crate::providers::ProviderError> { let config = &self.provider_config; let runtime_config = ProviderRuntimeConfig { provider_type: config.provider_type.clone(), name: config.name.clone(), base_url: config.base_url.clone(), api_key: config.api_key.clone(), extra_headers: config.extra_headers.clone(), llm_timeout_secs: config.memory_maintenance_timeout_secs, model_id: config.model_id.clone(), temperature: config.temperature, max_tokens: config.max_tokens, model_extra: config.model_extra.clone(), }; create_provider(runtime_config) } pub(crate) fn build_plan_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { let memories = self .store .list_memories_for_scope("user", scope_key) .map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?; if memories.is_empty() { return Ok(None); } Ok(Some(build_memory_maintenance_plan(&memories))) } pub(crate) async fn organize_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { let Some(plan) = self.build_plan_for_scope(scope_key)? 
else { return Ok(None); }; self.organize_plan(scope_key, &plan).await.map(Some) } async fn organize_plan( &self, scope_key: &str, plan: &MemoryMaintenancePlan, ) -> Result { let provider = self.create_maintenance_provider().map_err(|err| { AgentError::Other(format!("create maintenance provider error: {}", err)) })?; let request = ChatCompletionRequest { messages: vec![ Message::system(MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT), Message::user( serde_json::to_string_pretty(&serde_json::json!({ "scope_key": scope_key, "candidates": plan, })) .unwrap_or_else(|_| "{}".to_string()), ), ], temperature: Some(0.0), max_tokens: Some(1200), tools: None, }; let mut last_error = None; let mut response = None; for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS .iter() .copied() .map(Some) .chain(std::iter::once(None)) .enumerate() { match provider.chat(request.clone()).await { Ok(success) => { response = Some(success); break; } Err(err) => { let error_text = err.to_string(); let should_retry = delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text); last_error = Some(error_text.clone()); if should_retry { tracing::warn!( scope_key = %scope_key, attempt = attempt + 1, retry_in_ms = delay_ms.unwrap_or_default(), error = %error_text, "Memory organization model request failed, retrying" ); tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())) .await; continue; } return Err(AgentError::Other(format!( "memory organization model error: {}", error_text ))); } } } let response = response.ok_or_else(|| { AgentError::Other(format!( "memory organization model error: {}", last_error.unwrap_or_else(|| "unknown provider error".to_string()) )) })?; let raw_content = strip_json_code_fence(&response.content); let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content); let output: MemoryOrganizationOutput = serde_json::from_str(json_candidate).map_err(|err| { tracing::error!( scope_key = %scope_key, error = %err, raw_len = 
raw_content.len(), raw_preview = %preview_text(raw_content, 400), json_candidate_len = json_candidate.len(), json_candidate_preview = %preview_text(json_candidate, 400), "Memory maintenance JSON decode failed" ); AgentError::Other(format!("memory maintenance JSON decode error: {}", err)) })?; Ok(output) } #[cfg_attr(not(test), allow(dead_code))] pub(crate) async fn run_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { let Some(plan) = self.build_plan_for_scope(scope_key)? else { return Ok(None); }; // 步骤1:整理记忆(不生成摘要) let organize_output = self.organize_plan(scope_key, &plan).await?; // 应用整理结果(merge和delete) apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?; // 步骤2:从数据库重新读取剩余的记忆 let remaining_memories = self .store .list_memories_for_scope("user", scope_key) .map_err(|err| { AgentError::Other(format!("list remaining memories error: {}", err)) })?; // 步骤2:生成摘要 let managed_markdown = if remaining_memories.is_empty() { String::new() } else { self.generate_summary(scope_key, &remaining_memories).await? 
}; Ok(Some(MemoryMaintenanceScopeResult { scope_key: scope_key.to_string(), output: organize_output, managed_markdown, })) } async fn generate_summary( &self, scope_key: &str, remaining_memories: &[MemoryRecord], ) -> Result { let provider = self.create_maintenance_provider().map_err(|err| { AgentError::Other(format!("create maintenance provider error: {}", err)) })?; let input = MemorySummaryInput { organized_memories: remaining_memories .iter() .map(|m| OrganizedMemory { namespace: m.namespace.clone(), memory_key: m.memory_key.clone(), content: m.content.clone(), }) .collect(), }; let request = ChatCompletionRequest { messages: vec![ Message::system(MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT), Message::user( serde_json::to_string_pretty(&serde_json::json!(input)) .unwrap_or_else(|_| "{}".to_string()), ), ], temperature: Some(0.0), max_tokens: Some(1000), tools: None, }; let mut last_error = None; let mut response = None; for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS .iter() .copied() .map(Some) .chain(std::iter::once(None)) .enumerate() { match provider.chat(request.clone()).await { Ok(success) => { response = Some(success); break; } Err(err) => { let error_text = err.to_string(); let should_retry = delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text); last_error = Some(error_text.clone()); if should_retry { tracing::warn!( scope_key = %scope_key, attempt = attempt + 1, retry_in_ms = delay_ms.unwrap_or_default(), error = %error_text, "Memory summary model request failed, retrying" ); tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())) .await; continue; } return Err(AgentError::Other(format!( "memory summary model error: {}", error_text ))); } } } let response = response.ok_or_else(|| { AgentError::Other(format!( "memory summary model error: {}", last_error.unwrap_or_else(|| "unknown provider error".to_string()) )) })?; let content = response.content.trim(); // 确保响应包含标签,如无则自动添加 let tagged_content = if 
content.contains("") { content.to_string() } else { format!( "\n{}\n", content ) }; Ok(tagged_content) } pub(crate) async fn run_for_all_scopes( &self, ) -> Result, AgentError> { let scope_keys = self.store.list_memory_scope_keys("user").map_err(|err| { AgentError::Other(format!("list memory scope keys error: {}", err)) })?; if scope_keys.is_empty() { return Ok(None); } // 步骤1:逐个 scope 整理记忆(merge/delete),但不生成摘要 let mut all_outputs = Vec::new(); for scope_key in &scope_keys { match self.run_organize_for_scope(scope_key).await { Ok(Some(output)) => all_outputs.push((scope_key.clone(), output)), Ok(None) => continue, Err(error) if is_recoverable_maintenance_scope_error(&error) => { tracing::warn!( scope_key = %scope_key, error = %error, "Memory maintenance skipped scope after recoverable model failure" ); continue; } Err(error) => return Err(error), } } if all_outputs.is_empty() { return Ok(None); } // 步骤2:收集所有 scope 整理后的剩余记忆 let mut all_remaining_memories = Vec::new(); for (scope_key, _) in &all_outputs { let memories = self .store .list_memories_for_scope("user", scope_key) .map_err(|err| { AgentError::Other(format!( "list remaining memories for scope {} error: {}", scope_key, err )) })?; all_remaining_memories.extend(memories); } // 步骤3:统一生成一个摘要 let managed_markdown = if all_remaining_memories.is_empty() { String::new() } else { self.generate_summary("all", &all_remaining_memories).await? 
}; if !managed_markdown.is_empty() { upsert_managed_agent_memory_summary(&managed_markdown)?; } // 合并所有输出用于返回 let combined_output = MemoryOrganizationOutput { merges: all_outputs .iter() .flat_map(|(_, o)| o.merges.clone()) .collect(), conflicts: all_outputs .iter() .flat_map(|(_, o)| o.conflicts.clone()) .collect(), low_value_ids: all_outputs .iter() .flat_map(|(_, o)| o.low_value_ids.clone()) .collect(), }; Ok(Some(MemoryMaintenanceScopeResult { scope_key: "all".to_string(), output: combined_output, managed_markdown, })) } /// 仅执行整理步骤(organize + apply),不生成摘要 async fn run_organize_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { let Some(plan) = self.build_plan_for_scope(scope_key)? else { return Ok(None); }; // 步骤1:整理记忆 let organize_output = self.organize_plan(scope_key, &plan).await?; // 应用整理结果(merge和delete) apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?; Ok(Some(organize_output)) } } pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan { let mut plan = MemoryMaintenancePlan::default(); let mut seen = HashSet::new(); for memory in memories { let normalized_content = memory.content.trim(); if normalized_content.is_empty() { continue; } let dedupe_key = format!( "{}\u{1f}{}\u{1f}{}", memory.namespace.trim().to_ascii_lowercase(), memory.memory_key.trim().to_ascii_lowercase(), normalized_content ); if !seen.insert(dedupe_key) { continue; } let candidate = MemoryMaintenanceCandidate { id: memory.id.clone(), namespace: memory.namespace.clone(), key: memory.memory_key.clone(), content: normalized_content.to_string(), }; plan.candidates.push(candidate); } plan } pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool { let normalized = error.to_ascii_lowercase(); normalized.contains("error sending request for url") || normalized.contains("504") || normalized.contains("gateway timeout") || normalized.contains("stream timeout") || normalized.contains("timed 
out") || normalized.contains("timeout") } fn is_recoverable_maintenance_scope_error(error: &AgentError) -> bool { is_recoverable_maintenance_llm_error(&error.to_string()) } pub(crate) fn strip_json_code_fence(content: &str) -> &str { let trimmed = content.trim(); if let Some(rest) = trimmed.strip_prefix("```json") { return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed); } if let Some(rest) = trimmed.strip_prefix("```") { return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed); } trimmed } pub(crate) fn extract_json_object(content: &str) -> Option<&str> { let mut start = None; let mut depth = 0usize; let mut in_string = false; let mut escaped = false; for (index, ch) in content.char_indices() { if in_string { if escaped { escaped = false; continue; } match ch { '\\' => escaped = true, '"' => in_string = false, _ => {} } continue; } match ch { '"' => in_string = true, '{' => { if start.is_none() { start = Some(index); } depth += 1; } '}' => { if depth == 0 { continue; } depth -= 1; if depth == 0 { let start = start?; let end = index + ch.len_utf8(); return Some(content[start..end].trim()); } } _ => {} } } None } pub(crate) fn apply_memory_maintenance_output( store: &SessionStore, scope_key: &str, plan: &MemoryMaintenancePlan, output: &MemoryOrganizationOutput, ) -> Result<(), AgentError> { let all_candidates = plan.candidates.clone(); let candidates_by_id = all_candidates .iter() .map(|candidate| (candidate.id.as_str(), candidate)) .collect::>(); let mut deleted_ids = HashSet::new(); for merge in &output.merges { if merge.source_ids.is_empty() { continue; } let source_candidates = merge .source_ids .iter() .filter_map(|id| candidates_by_id.get(id.as_str()).copied()) .collect::>(); if source_candidates.is_empty() { continue; } let existing_target_id = source_candidates .iter() .find(|candidate| { candidate.namespace == merge.namespace && candidate.key == merge.memory_key }) .map(|candidate| candidate.id.clone()); store 
.put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: scope_key.to_string(), namespace: merge.namespace.trim().to_string(), memory_key: merge.memory_key.trim().to_string(), content: merge.content.trim().to_string(), source_type: "memory_maintenance".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?; for candidate in source_candidates { if existing_target_id .as_ref() .is_some_and(|target_id| target_id == &candidate.id) { continue; } if deleted_ids.insert(candidate.id.clone()) { store .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) .map_err(|err| { AgentError::Other(format!("delete merged source memory error: {}", err)) })?; } } } for memory_id in &output.low_value_ids { if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) { if deleted_ids.insert(candidate.id.clone()) { store .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) .map_err(|err| { AgentError::Other(format!("delete low value memory error: {}", err)) })?; } } } Ok(()) } fn preview_text(content: &str, max_chars: usize) -> String { let mut preview = content.chars().take(max_chars).collect::(); if content.chars().count() > max_chars { preview.push_str("..."); } preview.replace('\n', "\\n") } #[cfg(test)] mod tests {}