902 lines
30 KiB
Rust
902 lines
30 KiB
Rust
use std::collections::{HashMap, HashSet};
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
|
||
use serde::{Deserialize, Serialize};
|
||
|
||
use crate::agent::AgentError;
|
||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
||
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||
use crate::storage::{MemoryRecord, SessionStore};
|
||
|
||
use super::prompt::upsert_managed_agent_memory_summary;
|
||
|
||
const MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT: &str =
|
||
include_str!("memory_maintenance_step1_system_prompt.md");
|
||
const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
|
||
include_str!("memory_maintenance_step2_system_prompt.md");
|
||
/// Delays (milliseconds) slept before each retry of a maintenance model
/// call. One final attempt follows the last delay, so a call is tried
/// `len() + 1` times in total.
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1000, 3000];

/// Namespace reserved for maintenance bookkeeping rows; such rows are never
/// offered to the organization model.
const META_NAMESPACE: &str = "_meta";

/// Memory key (inside `META_NAMESPACE`) holding the last maintenance time.
const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at";

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenanceCandidate {
|
||
pub(crate) id: String,
|
||
pub(crate) namespace: String,
|
||
pub(crate) key: String,
|
||
pub(crate) content: String,
|
||
pub(crate) updated_at: i64, // 记忆更新时间(Unix timestamp)
|
||
}
|
||
|
||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenancePlan {
|
||
pub(crate) candidates: Vec<MemoryMaintenanceCandidate>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenanceMerge {
|
||
pub(crate) source_ids: Vec<String>,
|
||
pub(crate) namespace: String,
|
||
pub(crate) memory_key: String,
|
||
pub(crate) content: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenanceConflict {
|
||
pub(crate) source_ids: Vec<String>,
|
||
pub(crate) note: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryOrganizationOutput {
|
||
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
|
||
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
|
||
pub(crate) low_value_ids: Vec<String>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemorySummaryInput {
|
||
pub(crate) organized_memories: Vec<OrganizedMemory>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct OrganizedMemory {
|
||
pub(crate) namespace: String,
|
||
pub(crate) memory_key: String,
|
||
pub(crate) content: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenanceScopeResult {
|
||
pub(crate) scope_key: String,
|
||
pub(crate) output: MemoryOrganizationOutput,
|
||
pub(crate) managed_markdown: String,
|
||
}
|
||
|
||
pub(crate) struct MemoryMaintenanceService {
|
||
store: Arc<SessionStore>,
|
||
provider_config: LLMProviderConfig,
|
||
maintenance_config: MemoryMaintenanceConfig,
|
||
}
|
||
|
||
impl MemoryMaintenanceService {
|
||
pub(crate) fn new(
|
||
store: Arc<SessionStore>,
|
||
provider_config: LLMProviderConfig,
|
||
maintenance_config: MemoryMaintenanceConfig,
|
||
) -> Self {
|
||
Self {
|
||
store,
|
||
provider_config,
|
||
maintenance_config,
|
||
}
|
||
}
|
||
|
||
/// 创建记忆整理专用的 provider,使用 memory_maintenance_timeout_secs 作为超时时间
|
||
fn create_maintenance_provider(
|
||
&self,
|
||
) -> Result<Box<dyn crate::providers::LLMProvider>, crate::providers::ProviderError> {
|
||
let config = &self.provider_config;
|
||
let runtime_config = ProviderRuntimeConfig {
|
||
provider_type: config.provider_type.clone(),
|
||
name: config.name.clone(),
|
||
base_url: config.base_url.clone(),
|
||
api_key: config.api_key.clone(),
|
||
extra_headers: config.extra_headers.clone(),
|
||
llm_timeout_secs: config.memory_maintenance_timeout_secs,
|
||
model_id: config.model_id.clone(),
|
||
temperature: config.temperature,
|
||
max_tokens: config.max_tokens,
|
||
model_extra: config.model_extra.clone(),
|
||
};
|
||
create_provider(runtime_config)
|
||
}
|
||
|
||
pub(crate) fn build_plan_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||
// 新增:检查是否有新记忆需要整理
|
||
if !has_new_memories_since_last_maintenance(&self.store, scope_key)? {
|
||
tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping");
|
||
return Ok(None);
|
||
}
|
||
|
||
let memories = self
|
||
.store
|
||
.list_memories_for_scope("user", scope_key)
|
||
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
|
||
|
||
if memories.is_empty() {
|
||
return Ok(None);
|
||
}
|
||
|
||
// 记忆数量不足最小保留数时,无需整理,直接跳过
|
||
// 避免浪费 LLM token 并触发无意义的 "保留数不足" 错误
|
||
if memories.len() < self.maintenance_config.min_memories_to_keep {
|
||
tracing::info!(
|
||
scope_key = %scope_key,
|
||
count = memories.len(),
|
||
min_required = self.maintenance_config.min_memories_to_keep,
|
||
"Skipping scope: not enough memories to organize"
|
||
);
|
||
return Ok(None);
|
||
}
|
||
|
||
Ok(Some(build_memory_maintenance_plan(&memories)))
|
||
}
|
||
|
||
pub(crate) async fn organize_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
|
||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||
return Ok(None);
|
||
};
|
||
|
||
self.organize_plan(scope_key, &plan).await.map(Some)
|
||
}
|
||
|
||
async fn organize_plan(
|
||
&self,
|
||
scope_key: &str,
|
||
plan: &MemoryMaintenancePlan,
|
||
) -> Result<MemoryOrganizationOutput, AgentError> {
|
||
let provider = self.create_maintenance_provider().map_err(|err| {
|
||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||
})?;
|
||
|
||
let request = ChatCompletionRequest {
|
||
messages: vec![
|
||
Message::system(MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT),
|
||
Message::user(
|
||
serde_json::to_string_pretty(&serde_json::json!({
|
||
"scope_key": scope_key,
|
||
"candidates": plan,
|
||
}))
|
||
.unwrap_or_else(|_| "{}".to_string()),
|
||
),
|
||
],
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(4000),
|
||
tools: None,
|
||
};
|
||
|
||
let mut last_error = None;
|
||
|
||
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||
.iter()
|
||
.copied()
|
||
.map(Some)
|
||
.chain(std::iter::once(None))
|
||
.enumerate()
|
||
{
|
||
let response = match provider.chat(request.clone()).await {
|
||
Ok(success) => success,
|
||
Err(err) => {
|
||
let error_text = err.to_string();
|
||
let should_retry =
|
||
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||
last_error = Some(error_text.clone());
|
||
|
||
if should_retry {
|
||
tracing::warn!(
|
||
scope_key = %scope_key,
|
||
attempt = attempt + 1,
|
||
retry_in_ms = delay_ms.unwrap_or_default(),
|
||
error = %error_text,
|
||
"Memory organization model request failed, retrying"
|
||
);
|
||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||
.await;
|
||
continue;
|
||
}
|
||
|
||
return Err(AgentError::Other(format!(
|
||
"memory organization model error: {}",
|
||
error_text
|
||
)));
|
||
}
|
||
};
|
||
|
||
let raw_content = strip_json_code_fence(&response.content);
|
||
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
|
||
|
||
match serde_json::from_str::<MemoryOrganizationOutput>(json_candidate) {
|
||
Ok(parsed) => return Ok(parsed),
|
||
Err(err) => {
|
||
let error_msg = err.to_string();
|
||
let is_truncated = error_msg.contains("EOF while parsing")
|
||
|| error_msg.contains("expected");
|
||
|
||
let should_retry = delay_ms.is_some() && is_truncated;
|
||
last_error = Some(error_msg.clone());
|
||
|
||
if should_retry {
|
||
tracing::warn!(
|
||
scope_key = %scope_key,
|
||
attempt = attempt + 1,
|
||
retry_in_ms = delay_ms.unwrap_or_default(),
|
||
error = %error_msg,
|
||
raw_len = raw_content.len(),
|
||
"Memory organization JSON parse failed (possibly truncated), retrying"
|
||
);
|
||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||
.await;
|
||
continue;
|
||
}
|
||
|
||
tracing::error!(
|
||
scope_key = %scope_key,
|
||
error = %err,
|
||
raw_len = raw_content.len(),
|
||
raw_preview = %preview_text(raw_content, 400),
|
||
json_candidate_len = json_candidate.len(),
|
||
json_candidate_preview = %preview_text(json_candidate, 400),
|
||
"Memory maintenance JSON decode failed"
|
||
);
|
||
return Err(AgentError::Other(format!(
|
||
"memory maintenance JSON decode error: {}",
|
||
err
|
||
)));
|
||
}
|
||
}
|
||
}
|
||
|
||
Err(AgentError::Other(format!(
|
||
"memory organization failed after retries: {}",
|
||
last_error.unwrap_or_else(|| "unknown error".to_string())
|
||
)))
|
||
}
|
||
|
||
#[cfg_attr(not(test), allow(dead_code))]
|
||
pub(crate) async fn run_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
|
||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||
return Ok(None);
|
||
};
|
||
|
||
// 步骤1:整理记忆(不生成摘要)
|
||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||
|
||
// 应用整理结果(merge和delete)
|
||
apply_memory_maintenance_output(
|
||
self.store.as_ref(),
|
||
scope_key,
|
||
&plan,
|
||
&organize_output,
|
||
self.maintenance_config.max_merge_ratio,
|
||
self.maintenance_config.min_memories_to_keep,
|
||
self.maintenance_config.max_merge_per_group,
|
||
)?;
|
||
|
||
// 步骤2:从数据库重新读取剩余的记忆
|
||
let remaining_memories = self
|
||
.store
|
||
.list_memories_for_scope("user", scope_key)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!("list remaining memories error: {}", err))
|
||
})?;
|
||
|
||
// 步骤2:生成摘要
|
||
let managed_markdown = if remaining_memories.is_empty() {
|
||
String::new()
|
||
} else {
|
||
self.generate_summary(scope_key, &remaining_memories).await?
|
||
};
|
||
|
||
Ok(Some(MemoryMaintenanceScopeResult {
|
||
scope_key: scope_key.to_string(),
|
||
output: organize_output,
|
||
managed_markdown,
|
||
}))
|
||
}
|
||
|
||
async fn generate_summary(
|
||
&self,
|
||
scope_key: &str,
|
||
remaining_memories: &[MemoryRecord],
|
||
) -> Result<String, AgentError> {
|
||
let provider = self.create_maintenance_provider().map_err(|err| {
|
||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||
})?;
|
||
|
||
let input = MemorySummaryInput {
|
||
organized_memories: remaining_memories
|
||
.iter()
|
||
.map(|m| OrganizedMemory {
|
||
namespace: m.namespace.clone(),
|
||
memory_key: m.memory_key.clone(),
|
||
content: m.content.clone(),
|
||
})
|
||
.collect(),
|
||
};
|
||
|
||
let request = ChatCompletionRequest {
|
||
messages: vec![
|
||
Message::system(MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT),
|
||
Message::user(
|
||
serde_json::to_string_pretty(&serde_json::json!(input))
|
||
.unwrap_or_else(|_| "{}".to_string()),
|
||
),
|
||
],
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(1000),
|
||
tools: None,
|
||
};
|
||
|
||
let mut last_error = None;
|
||
let mut response = None;
|
||
|
||
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||
.iter()
|
||
.copied()
|
||
.map(Some)
|
||
.chain(std::iter::once(None))
|
||
.enumerate()
|
||
{
|
||
match provider.chat(request.clone()).await {
|
||
Ok(success) => {
|
||
response = Some(success);
|
||
break;
|
||
}
|
||
Err(err) => {
|
||
let error_text = err.to_string();
|
||
let should_retry =
|
||
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||
last_error = Some(error_text.clone());
|
||
|
||
if should_retry {
|
||
tracing::warn!(
|
||
scope_key = %scope_key,
|
||
attempt = attempt + 1,
|
||
retry_in_ms = delay_ms.unwrap_or_default(),
|
||
error = %error_text,
|
||
"Memory summary model request failed, retrying"
|
||
);
|
||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||
.await;
|
||
continue;
|
||
}
|
||
|
||
return Err(AgentError::Other(format!(
|
||
"memory summary model error: {}",
|
||
error_text
|
||
)));
|
||
}
|
||
}
|
||
}
|
||
|
||
let response = response.ok_or_else(|| {
|
||
AgentError::Other(format!(
|
||
"memory summary model error: {}",
|
||
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
||
))
|
||
})?;
|
||
|
||
let content = response.content.trim();
|
||
// 确保响应包含标签,如无则自动添加
|
||
let tagged_content = if content.contains("<!-- memory_summary_start -->") {
|
||
content.to_string()
|
||
} else {
|
||
format!(
|
||
"<!-- memory_summary_start -->\n{}\n<!-- memory_summary_end -->",
|
||
content
|
||
)
|
||
};
|
||
|
||
Ok(tagged_content)
|
||
}
|
||
|
||
pub(crate) async fn run_for_all_scopes(
|
||
&self,
|
||
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
|
||
let scope_keys = self.store.list_memory_scope_keys("user").map_err(|err| {
|
||
AgentError::Other(format!("list memory scope keys error: {}", err))
|
||
})?;
|
||
|
||
if scope_keys.is_empty() {
|
||
return Ok(None);
|
||
}
|
||
|
||
// 步骤1:逐个 scope 整理记忆(merge/delete),但不生成摘要
|
||
let mut all_outputs = Vec::new();
|
||
for scope_key in &scope_keys {
|
||
match self.run_organize_for_scope(scope_key).await {
|
||
Ok(Some(output)) => all_outputs.push((scope_key.clone(), output)),
|
||
Ok(None) => continue,
|
||
Err(error) if is_recoverable_maintenance_scope_error(&error) => {
|
||
tracing::warn!(
|
||
scope_key = %scope_key,
|
||
error = %error,
|
||
"Memory maintenance skipped scope after recoverable model failure"
|
||
);
|
||
continue;
|
||
}
|
||
Err(error) => return Err(error),
|
||
}
|
||
}
|
||
|
||
if all_outputs.is_empty() {
|
||
return Ok(None);
|
||
}
|
||
|
||
// 步骤2:收集所有 scope 整理后的剩余记忆
|
||
let mut all_remaining_memories = Vec::new();
|
||
for (scope_key, _) in &all_outputs {
|
||
let memories = self
|
||
.store
|
||
.list_memories_for_scope("user", scope_key)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!(
|
||
"list remaining memories for scope {} error: {}",
|
||
scope_key, err
|
||
))
|
||
})?;
|
||
all_remaining_memories.extend(memories);
|
||
}
|
||
|
||
// 步骤3:统一生成一个摘要
|
||
let managed_markdown = if all_remaining_memories.is_empty() {
|
||
String::new()
|
||
} else {
|
||
self.generate_summary("all", &all_remaining_memories).await?
|
||
};
|
||
|
||
if !managed_markdown.is_empty() {
|
||
upsert_managed_agent_memory_summary(&managed_markdown)?;
|
||
}
|
||
|
||
// 合并所有输出用于返回
|
||
let combined_output = MemoryOrganizationOutput {
|
||
merges: all_outputs
|
||
.iter()
|
||
.flat_map(|(_, o)| o.merges.clone())
|
||
.collect(),
|
||
conflicts: all_outputs
|
||
.iter()
|
||
.flat_map(|(_, o)| o.conflicts.clone())
|
||
.collect(),
|
||
low_value_ids: all_outputs
|
||
.iter()
|
||
.flat_map(|(_, o)| o.low_value_ids.clone())
|
||
.collect(),
|
||
};
|
||
|
||
Ok(Some(MemoryMaintenanceScopeResult {
|
||
scope_key: "all".to_string(),
|
||
output: combined_output,
|
||
managed_markdown,
|
||
}))
|
||
}
|
||
|
||
/// 仅执行整理步骤(organize + apply),不生成摘要
|
||
async fn run_organize_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
|
||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||
return Ok(None);
|
||
};
|
||
|
||
// 步骤1:整理记忆
|
||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||
|
||
// 应用整理结果(merge和delete)
|
||
apply_memory_maintenance_output(
|
||
self.store.as_ref(),
|
||
scope_key,
|
||
&plan,
|
||
&organize_output,
|
||
self.maintenance_config.max_merge_ratio,
|
||
self.maintenance_config.min_memories_to_keep,
|
||
self.maintenance_config.max_merge_per_group,
|
||
)?;
|
||
|
||
Ok(Some(organize_output))
|
||
}
|
||
}
|
||
|
||
/// 获取上次整理时间
|
||
fn get_last_maintenance_time(
|
||
store: &SessionStore,
|
||
scope_key: &str,
|
||
) -> Result<Option<i64>, crate::storage::StorageError> {
|
||
let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?;
|
||
Ok(meta.and_then(|m| m.content.parse::<i64>().ok()))
|
||
}
|
||
|
||
/// 记录本次整理时间
|
||
fn set_last_maintenance_time(
|
||
store: &SessionStore,
|
||
scope_key: &str,
|
||
time: i64,
|
||
) -> Result<(), crate::storage::StorageError> {
|
||
store.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: scope_key.to_string(),
|
||
namespace: META_NAMESPACE.to_string(),
|
||
memory_key: LAST_MAINTENANCE_KEY.to_string(),
|
||
content: time.to_string(),
|
||
source_type: "memory_maintenance".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 检查是否有需要整理的新记忆(过滤掉 _meta namespace)
|
||
fn has_new_memories_since_last_maintenance(
|
||
store: &SessionStore,
|
||
scope_key: &str,
|
||
) -> Result<bool, AgentError> {
|
||
let memories = store
|
||
.list_memories_for_scope("user", scope_key)
|
||
.map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?;
|
||
|
||
// 过滤掉 _meta namespace 的记忆
|
||
let user_memories: Vec<_> = memories
|
||
.iter()
|
||
.filter(|m| m.namespace != META_NAMESPACE)
|
||
.collect();
|
||
|
||
if user_memories.is_empty() {
|
||
return Ok(false); // 没有记忆,跳过
|
||
}
|
||
|
||
// 获取上次整理时间
|
||
let last_time = get_last_maintenance_time(store, scope_key)
|
||
.map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?;
|
||
|
||
match last_time {
|
||
None => Ok(true), // 从未整理过,需要整理
|
||
Some(last) => {
|
||
// 检查是否有记忆在上次整理后更新
|
||
let has_new = user_memories.iter().any(|m| m.updated_at > last);
|
||
Ok(has_new)
|
||
}
|
||
}
|
||
}
|
||
|
||
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||
let mut plan = MemoryMaintenancePlan::default();
|
||
let mut seen = HashSet::new();
|
||
|
||
for memory in memories {
|
||
// 过滤掉 _meta namespace 的记忆
|
||
if memory.namespace == META_NAMESPACE {
|
||
continue;
|
||
}
|
||
|
||
let normalized_content = memory.content.trim();
|
||
if normalized_content.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let dedupe_key = format!(
|
||
"{}\u{1f}{}\u{1f}{}",
|
||
memory.namespace.trim().to_ascii_lowercase(),
|
||
memory.memory_key.trim().to_ascii_lowercase(),
|
||
normalized_content
|
||
);
|
||
if !seen.insert(dedupe_key) {
|
||
continue;
|
||
}
|
||
|
||
let candidate = MemoryMaintenanceCandidate {
|
||
id: memory.id.clone(),
|
||
namespace: memory.namespace.clone(),
|
||
key: memory.memory_key.clone(),
|
||
content: normalized_content.to_string(),
|
||
updated_at: memory.updated_at,
|
||
};
|
||
|
||
plan.candidates.push(candidate);
|
||
}
|
||
|
||
plan
|
||
}
|
||
|
||
/// Returns `true` when a maintenance failure should not fail the whole job.
///
/// Two kinds of error qualify:
/// - transient transport/gateway failures, which callers retry;
/// - validator rejections of one scope's model output (see
///   `validate_memory_maintenance_output`), which callers skip by dropping
///   that scope instead of aborting the run.
///
/// All four validator rejections are listed. Previously only "保留数不足"
/// and "合并比例超限" were, so an oversized merge group or a cross-namespace
/// merge proposed for one scope aborted maintenance for every remaining
/// scope.
///
/// Matching is by substring: ASCII markers are matched case-insensitively;
/// the validator messages are matched as written.
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
    // Transport / gateway failures worth retrying.
    const TRANSIENT_MARKERS: &[&str] = &[
        "error sending request for url",
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];
    // Validation rejections: skip this scope, keep the job going.
    const VALIDATION_MARKERS: &[&str] = &[
        "保留数不足",
        "合并比例超限",
        "合并组过大",
        "跨 namespace 合并被禁止",
    ];

    let normalized = error.to_ascii_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|marker| normalized.contains(*marker))
        || VALIDATION_MARKERS
            .iter()
            .any(|marker| error.contains(*marker))
}

fn is_recoverable_maintenance_scope_error(error: &AgentError) -> bool {
|
||
is_recoverable_maintenance_llm_error(&error.to_string())
|
||
}
|
||
|
||
/// Removes a surrounding markdown code fence from a model response.
///
/// The input is trimmed first. If it opens with a ```` ```json ```` or
/// ```` ``` ```` fence *and* ends with ```` ``` ````, the trimmed inner text
/// is returned. Otherwise the trimmed input is returned unchanged; this
/// includes an unterminated fence.
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
    let body = content.trim();
    let after_opening = body
        .strip_prefix("```json")
        .or_else(|| body.strip_prefix("```"));

    match after_opening.and_then(|rest| rest.strip_suffix("```")) {
        Some(inner) => inner.trim(),
        None => body,
    }
}

/// Returns the first balanced `{ ... }` object found in `content`, or `None`.
///
/// Braces inside double-quoted strings are ignored, and backslash escapes
/// inside strings are honoured. A `}` seen before any `{` is ignored. If
/// the first object is never closed, `None` is returned.
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
    let mut open_at: Option<usize> = None;
    let mut nesting: usize = 0;
    let mut inside_quotes = false;
    let mut after_backslash = false;

    for (pos, c) in content.char_indices() {
        if inside_quotes {
            if after_backslash {
                after_backslash = false;
            } else if c == '\\' {
                after_backslash = true;
            } else if c == '"' {
                inside_quotes = false;
            }
            continue;
        }

        if c == '"' {
            inside_quotes = true;
        } else if c == '{' {
            open_at.get_or_insert(pos);
            nesting += 1;
        } else if c == '}' && nesting > 0 {
            nesting -= 1;
            if nesting == 0 {
                let begin = open_at?;
                return Some(content[begin..pos + c.len_utf8()].trim());
            }
        }
    }

    None
}

/// 验证记忆整理输出是否符合限制
|
||
pub(crate) fn validate_memory_maintenance_output(
|
||
plan: &MemoryMaintenancePlan,
|
||
output: &MemoryOrganizationOutput,
|
||
max_merge_ratio: f32,
|
||
min_memories_to_keep: usize,
|
||
max_merge_per_group: usize,
|
||
) -> Result<(), String> {
|
||
let total = plan.candidates.len();
|
||
|
||
if total == 0 {
|
||
return Ok(()); // 没有候选,无需验证
|
||
}
|
||
|
||
// 验证 1: 单次合并数量限制
|
||
for merge in &output.merges {
|
||
if merge.source_ids.len() > max_merge_per_group {
|
||
return Err(format!(
|
||
"合并组过大: {} 条源记忆超过上限 {}",
|
||
merge.source_ids.len(),
|
||
max_merge_per_group
|
||
));
|
||
}
|
||
}
|
||
|
||
// 验证 2: 跨 namespace 合并检测(完全禁止)
|
||
let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan
|
||
.candidates
|
||
.iter()
|
||
.map(|c| (c.id.as_str(), c))
|
||
.collect();
|
||
|
||
for merge in &output.merges {
|
||
let source_namespaces: HashSet<&str> = merge
|
||
.source_ids
|
||
.iter()
|
||
.filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str()))
|
||
.collect();
|
||
|
||
// 检查是否跨越多个 namespace
|
||
if source_namespaces.len() > 1 {
|
||
return Err(format!(
|
||
"跨 namespace 合并被禁止: 源来自 {}",
|
||
source_namespaces.iter().cloned().collect::<Vec<_>>().join(", ")
|
||
));
|
||
}
|
||
|
||
// 检查目标 namespace 是否与源一致
|
||
if let Some(src_ns) = source_namespaces.iter().next() {
|
||
if *src_ns != merge.namespace {
|
||
return Err(format!(
|
||
"跨 namespace 合并被禁止: {} → {}",
|
||
src_ns, merge.namespace
|
||
));
|
||
}
|
||
}
|
||
}
|
||
|
||
// 验证 3: 总体合并比例
|
||
let merged_ids: HashSet<&str> = output
|
||
.merges
|
||
.iter()
|
||
.flat_map(|m| m.source_ids.iter())
|
||
.map(|s| s.as_str())
|
||
.collect();
|
||
|
||
let deleted_ids: HashSet<&str> = output
|
||
.low_value_ids
|
||
.iter()
|
||
.map(|s| s.as_str())
|
||
.collect();
|
||
|
||
let affected = merged_ids.len() + deleted_ids.len();
|
||
let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize;
|
||
|
||
if affected > max_allowed {
|
||
return Err(format!(
|
||
"合并比例超限: {} / {} > {:.0}%",
|
||
affected,
|
||
total,
|
||
max_merge_ratio * 100.0
|
||
));
|
||
}
|
||
|
||
// 验证 4: 最小保留数
|
||
let remaining = total - affected + output.merges.len();
|
||
if remaining < min_memories_to_keep {
|
||
return Err(format!(
|
||
"保留数不足: {} < {}",
|
||
remaining, min_memories_to_keep
|
||
));
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub(crate) fn apply_memory_maintenance_output(
|
||
store: &SessionStore,
|
||
scope_key: &str,
|
||
plan: &MemoryMaintenancePlan,
|
||
output: &MemoryOrganizationOutput,
|
||
max_merge_ratio: f32,
|
||
min_memories_to_keep: usize,
|
||
max_merge_per_group: usize,
|
||
) -> Result<(), AgentError> {
|
||
// 新增: 验证合并输出
|
||
validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group)
|
||
.map_err(|e| AgentError::Other(e))?;
|
||
|
||
let all_candidates = plan.candidates.clone();
|
||
|
||
let candidates_by_id = all_candidates
|
||
.iter()
|
||
.map(|candidate| (candidate.id.as_str(), candidate))
|
||
.collect::<HashMap<_, _>>();
|
||
|
||
let mut deleted_ids = HashSet::new();
|
||
|
||
for merge in &output.merges {
|
||
if merge.source_ids.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let source_candidates = merge
|
||
.source_ids
|
||
.iter()
|
||
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
|
||
.collect::<Vec<_>>();
|
||
if source_candidates.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let existing_target_id = source_candidates
|
||
.iter()
|
||
.find(|candidate| {
|
||
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
|
||
})
|
||
.map(|candidate| candidate.id.clone());
|
||
|
||
store
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: scope_key.to_string(),
|
||
namespace: merge.namespace.trim().to_string(),
|
||
memory_key: merge.memory_key.trim().to_string(),
|
||
content: merge.content.trim().to_string(),
|
||
source_type: "memory_maintenance".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
|
||
|
||
for candidate in source_candidates {
|
||
if existing_target_id
|
||
.as_ref()
|
||
.is_some_and(|target_id| target_id == &candidate.id)
|
||
{
|
||
continue;
|
||
}
|
||
if deleted_ids.insert(candidate.id.clone()) {
|
||
store
|
||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!("delete merged source memory error: {}", err))
|
||
})?;
|
||
}
|
||
}
|
||
}
|
||
|
||
for memory_id in &output.low_value_ids {
|
||
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
|
||
if deleted_ids.insert(candidate.id.clone()) {
|
||
store
|
||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!("delete low value memory error: {}", err))
|
||
})?;
|
||
}
|
||
}
|
||
}
|
||
|
||
// 新增:记录整理完成时间
|
||
let now = chrono::Utc::now().timestamp();
|
||
set_last_maintenance_time(store, scope_key, now)
|
||
.map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// Returns at most `max_chars` characters of `content` for logging.
///
/// Appends `...` when the text was cut, and renders newlines as a literal
/// `\n` so the preview stays on one log line.
fn preview_text(content: &str, max_chars: usize) -> String {
    let mut rest = content.chars();
    let mut shown: String = rest.by_ref().take(max_chars).collect();
    // Any character left over means the preview was truncated.
    if rest.next().is_some() {
        shown.push_str("...");
    }
    shown.replace('\n', "\\n")
}

#[cfg(test)]
mod tests {
    // No unit tests are defined in this module yet.
}