feat: 添加内存维护配置,优化记忆整理逻辑和限制

This commit is contained in:
ooodc 2026-05-23 16:05:11 +08:00
parent 44d9171b86
commit b4ef56803f
9 changed files with 352 additions and 19 deletions

View File

@ -75,6 +75,7 @@ impl InitWizard {
channels: HashMap::new(), channels: HashMap::new(),
skills: crate::config::SkillsConfig::default(), skills: crate::config::SkillsConfig::default(),
tools: crate::config::ToolsConfig::default(), tools: crate::config::ToolsConfig::default(),
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
} }
} }
@ -824,6 +825,7 @@ impl InitWizard {
client: existing.client.clone(), client: existing.client.clone(),
skills: existing.skills.clone(), skills: existing.skills.clone(),
tools: existing.tools.clone(), tools: existing.tools.clone(),
memory_maintenance: existing.memory_maintenance.clone(),
} }
} }

View File

@ -27,6 +27,8 @@ pub struct Config {
pub skills: SkillsConfig, pub skills: SkillsConfig,
#[serde(default)] #[serde(default)]
pub tools: ToolsConfig, pub tools: ToolsConfig,
#[serde(default)]
pub memory_maintenance: MemoryMaintenanceConfig,
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
@ -108,6 +110,41 @@ pub struct ToolsConfig {
pub task: TaskConfig, pub task: TaskConfig,
} }
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MemoryMaintenanceConfig {
/// 单次最大合并/删除比例 (0.0-1.0),默认 0.3 (30%)
#[serde(default = "default_max_merge_ratio")]
pub max_merge_ratio: f32,
/// 最小保留记忆数量,默认 5
#[serde(default = "default_min_memories_to_keep")]
pub min_memories_to_keep: usize,
/// 单次合并最大源记忆数,默认 3
#[serde(default = "default_max_merge_per_group")]
pub max_merge_per_group: usize,
}
fn default_max_merge_ratio() -> f32 {
0.3
}
fn default_min_memories_to_keep() -> usize {
5
}
fn default_max_merge_per_group() -> usize {
3
}
impl Default for MemoryMaintenanceConfig {
fn default() -> Self {
Self {
max_merge_ratio: default_max_merge_ratio(),
min_memories_to_keep: default_min_memories_to_keep(),
max_merge_per_group: default_max_merge_per_group(),
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TaskConfig { pub struct TaskConfig {
#[serde(default = "default_task_enabled")] #[serde(default = "default_task_enabled")]

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::config::LLMProviderConfig; use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider}; use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
use crate::storage::{MemoryRecord, SessionStore}; use crate::storage::{MemoryRecord, SessionStore};
@ -17,12 +17,16 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
include_str!("memory_maintenance_step2_system_prompt.md"); include_str!("memory_maintenance_step2_system_prompt.md");
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000]; const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
const META_NAMESPACE: &str = "_meta";
const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at";
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate { pub(crate) struct MemoryMaintenanceCandidate {
pub(crate) id: String, pub(crate) id: String,
pub(crate) namespace: String, pub(crate) namespace: String,
pub(crate) key: String, pub(crate) key: String,
pub(crate) content: String, pub(crate) content: String,
pub(crate) updated_at: i64, // 记忆更新时间Unix timestamp
} }
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
@ -73,13 +77,19 @@ pub(crate) struct MemoryMaintenanceScopeResult {
pub(crate) struct MemoryMaintenanceService { pub(crate) struct MemoryMaintenanceService {
store: Arc<SessionStore>, store: Arc<SessionStore>,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
maintenance_config: MemoryMaintenanceConfig,
} }
impl MemoryMaintenanceService { impl MemoryMaintenanceService {
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self { pub(crate) fn new(
store: Arc<SessionStore>,
provider_config: LLMProviderConfig,
maintenance_config: MemoryMaintenanceConfig,
) -> Self {
Self { Self {
store, store,
provider_config, provider_config,
maintenance_config,
} }
} }
@ -107,6 +117,12 @@ impl MemoryMaintenanceService {
&self, &self,
scope_key: &str, scope_key: &str,
) -> Result<Option<MemoryMaintenancePlan>, AgentError> { ) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
// 新增:检查是否有新记忆需要整理
if !has_new_memories_since_last_maintenance(&self.store, scope_key)? {
tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping");
return Ok(None);
}
let memories = self let memories = self
.store .store
.list_memories_for_scope("user", scope_key) .list_memories_for_scope("user", scope_key)
@ -255,7 +271,15 @@ impl MemoryMaintenanceService {
let organize_output = self.organize_plan(scope_key, &plan).await?; let organize_output = self.organize_plan(scope_key, &plan).await?;
// 应用整理结果merge和delete // 应用整理结果merge和delete
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?; apply_memory_maintenance_output(
self.store.as_ref(),
scope_key,
&plan,
&organize_output,
self.maintenance_config.max_merge_ratio,
self.maintenance_config.min_memories_to_keep,
self.maintenance_config.max_merge_per_group,
)?;
// 步骤2从数据库重新读取剩余的记忆 // 步骤2从数据库重新读取剩余的记忆
let remaining_memories = self let remaining_memories = self
@ -470,17 +494,94 @@ impl MemoryMaintenanceService {
let organize_output = self.organize_plan(scope_key, &plan).await?; let organize_output = self.organize_plan(scope_key, &plan).await?;
// 应用整理结果merge和delete // 应用整理结果merge和delete
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?; apply_memory_maintenance_output(
self.store.as_ref(),
scope_key,
&plan,
&organize_output,
self.maintenance_config.max_merge_ratio,
self.maintenance_config.min_memories_to_keep,
self.maintenance_config.max_merge_per_group,
)?;
Ok(Some(organize_output)) Ok(Some(organize_output))
} }
} }
/// 获取上次整理时间
fn get_last_maintenance_time(
store: &SessionStore,
scope_key: &str,
) -> Result<Option<i64>, crate::storage::StorageError> {
let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?;
Ok(meta.and_then(|m| m.content.parse::<i64>().ok()))
}
/// 记录本次整理时间
fn set_last_maintenance_time(
store: &SessionStore,
scope_key: &str,
time: i64,
) -> Result<(), crate::storage::StorageError> {
store.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: META_NAMESPACE.to_string(),
memory_key: LAST_MAINTENANCE_KEY.to_string(),
content: time.to_string(),
source_type: "memory_maintenance".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})?;
Ok(())
}
/// 检查是否有需要整理的新记忆(过滤掉 _meta namespace
fn has_new_memories_since_last_maintenance(
store: &SessionStore,
scope_key: &str,
) -> Result<bool, AgentError> {
let memories = store
.list_memories_for_scope("user", scope_key)
.map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?;
// 过滤掉 _meta namespace 的记忆
let user_memories: Vec<_> = memories
.iter()
.filter(|m| m.namespace != META_NAMESPACE)
.collect();
if user_memories.is_empty() {
return Ok(false); // 没有记忆,跳过
}
// 获取上次整理时间
let last_time = get_last_maintenance_time(store, scope_key)
.map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?;
match last_time {
None => Ok(true), // 从未整理过,需要整理
Some(last) => {
// 检查是否有记忆在上次整理后更新
let has_new = user_memories.iter().any(|m| m.updated_at > last);
Ok(has_new)
}
}
}
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan { pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
let mut plan = MemoryMaintenancePlan::default(); let mut plan = MemoryMaintenancePlan::default();
let mut seen = HashSet::new(); let mut seen = HashSet::new();
for memory in memories { for memory in memories {
// 过滤掉 _meta namespace 的记忆
if memory.namespace == META_NAMESPACE {
continue;
}
let normalized_content = memory.content.trim(); let normalized_content = memory.content.trim();
if normalized_content.is_empty() { if normalized_content.is_empty() {
continue; continue;
@ -501,6 +602,7 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory
namespace: memory.namespace.clone(), namespace: memory.namespace.clone(),
key: memory.memory_key.clone(), key: memory.memory_key.clone(),
content: normalized_content.to_string(), content: normalized_content.to_string(),
updated_at: memory.updated_at,
}; };
plan.candidates.push(candidate); plan.candidates.push(candidate);
@ -580,12 +682,115 @@ pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
None None
} }
/// 验证记忆整理输出是否符合限制
pub(crate) fn validate_memory_maintenance_output(
plan: &MemoryMaintenancePlan,
output: &MemoryOrganizationOutput,
max_merge_ratio: f32,
min_memories_to_keep: usize,
max_merge_per_group: usize,
) -> Result<(), String> {
let total = plan.candidates.len();
if total == 0 {
return Ok(()); // 没有候选,无需验证
}
// 验证 1: 单次合并数量限制
for merge in &output.merges {
if merge.source_ids.len() > max_merge_per_group {
return Err(format!(
"合并组过大: {} 条源记忆超过上限 {}",
merge.source_ids.len(),
max_merge_per_group
));
}
}
// 验证 2: 跨 namespace 合并检测(完全禁止)
let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan
.candidates
.iter()
.map(|c| (c.id.as_str(), c))
.collect();
for merge in &output.merges {
let source_namespaces: HashSet<&str> = merge
.source_ids
.iter()
.filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str()))
.collect();
// 检查是否跨越多个 namespace
if source_namespaces.len() > 1 {
return Err(format!(
"跨 namespace 合并被禁止: 源来自 {}",
source_namespaces.iter().cloned().collect::<Vec<_>>().join(", ")
));
}
// 检查目标 namespace 是否与源一致
if let Some(src_ns) = source_namespaces.iter().next() {
if *src_ns != merge.namespace {
return Err(format!(
"跨 namespace 合并被禁止: {} → {}",
src_ns, merge.namespace
));
}
}
}
// 验证 3: 总体合并比例
let merged_ids: HashSet<&str> = output
.merges
.iter()
.flat_map(|m| m.source_ids.iter())
.map(|s| s.as_str())
.collect();
let deleted_ids: HashSet<&str> = output
.low_value_ids
.iter()
.map(|s| s.as_str())
.collect();
let affected = merged_ids.len() + deleted_ids.len();
let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize;
if affected > max_allowed {
return Err(format!(
"合并比例超限: {} / {} > {:.0}%",
affected,
total,
max_merge_ratio * 100.0
));
}
// 验证 4: 最小保留数
let remaining = total - affected + output.merges.len();
if remaining < min_memories_to_keep {
return Err(format!(
"保留数不足: {} < {}",
remaining, min_memories_to_keep
));
}
Ok(())
}
pub(crate) fn apply_memory_maintenance_output( pub(crate) fn apply_memory_maintenance_output(
store: &SessionStore, store: &SessionStore,
scope_key: &str, scope_key: &str,
plan: &MemoryMaintenancePlan, plan: &MemoryMaintenancePlan,
output: &MemoryOrganizationOutput, output: &MemoryOrganizationOutput,
max_merge_ratio: f32,
min_memories_to_keep: usize,
max_merge_per_group: usize,
) -> Result<(), AgentError> { ) -> Result<(), AgentError> {
// 新增: 验证合并输出
validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group)
.map_err(|e| AgentError::Other(e))?;
let all_candidates = plan.candidates.clone(); let all_candidates = plan.candidates.clone();
let candidates_by_id = all_candidates let candidates_by_id = all_candidates
@ -661,6 +866,11 @@ pub(crate) fn apply_memory_maintenance_output(
} }
} }
// 新增:记录整理完成时间
let now = chrono::Utc::now().timestamp();
set_last_maintenance_time(store, scope_key, now)
.map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?;
Ok(()) Ok(())
} }

View File

@ -40,6 +40,7 @@ impl MemoryMaintenanceCoordinator {
Ok(MemoryMaintenanceService::new( Ok(MemoryMaintenanceService::new(
self.store.clone(), self.store.clone(),
self.provider_configs.default_provider_config(), self.provider_configs.default_provider_config(),
self.provider_configs.default_maintenance_config(),
)) ))
} }
} }

View File

@ -27,13 +27,30 @@
- note: 冲突说明 - note: 冲突说明
- low_value_ids需要删除的低价值候选记忆 ID 数组 - low_value_ids需要删除的低价值候选记忆 ID 数组
组织原则(由你自主决定) 组织原则:
- 根据记忆的语义内容自然分组,不必拘泥于预定义分类 - 根据记忆的语义内容自然分组
- 相似的、互补的记忆可以合并 - **每次合并最多只能合并 2-3 条源记忆**
- 过期、重复、过细的记忆可以标记为低价值 - **禁止跨 namespace 合并**(不同 namespace 代表不同信息维度)
- 过期、重复、过细的记忆可以标记为低值
- namespace 和 memory_key 的命名应当简洁、有意义 - namespace 和 memory_key 的命名应当简洁、有意义
- 可以自由创建新的 namespace 来组织相关记忆 - **保守原则:宁可保留稍多,不可过度合并**
- **必须保留足够数量的记忆,确保信息多样性**
时间权重原则(关键):
- 每个候选记忆包含 `updated_at` 时间戳Unix timestamp
- **当多条记忆存在重复或冲突时,时间越新的权重越高**
- 合并时优先采用新记忆的内容,旧记忆作为补充或背景
- 如果新旧记忆内容完全相同,保留新的,删除旧的
- 时间戳数值越大表示越新(离当前时间越近)
合并限制(硬性约束,由系统强制检查):
- 单次合并最多来自 3 条源记忆
- 整理后保留的记忆数不得少于 5 条
- 单次整理最多影响 30% 的记忆
- 不同 namespace 的记忆不允许互相合并
额外约束: 额外约束:

View File

@ -84,6 +84,7 @@ impl GatewayState {
Arc::new(BusSessionMessageSender::new(bus.clone())), Arc::new(BusSessionMessageSender::new(bus.clone())),
std::collections::HashSet::new(), std::collections::HashSet::new(),
config.tools.task.clone(), config.tools.task.clone(),
config.memory_maintenance.clone(),
chat_history_ttl_hours, chat_history_ttl_hours,
session_ttl_hours, session_ttl_hours,
)?; )?;

View File

@ -2,22 +2,25 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::config::LLMProviderConfig; use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct ProviderConfigService { pub(crate) struct ProviderConfigService {
default_provider_config: LLMProviderConfig, default_provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>, provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
maintenance_config: MemoryMaintenanceConfig,
} }
impl ProviderConfigService { impl ProviderConfigService {
pub(crate) fn new( pub(crate) fn new(
default_provider_config: LLMProviderConfig, default_provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>, provider_configs: HashMap<String, LLMProviderConfig>,
maintenance_config: MemoryMaintenanceConfig,
) -> Self { ) -> Self {
Self { Self {
default_provider_config, default_provider_config,
provider_configs: Arc::new(provider_configs), provider_configs: Arc::new(provider_configs),
maintenance_config,
} }
} }
@ -37,6 +40,10 @@ impl ProviderConfigService {
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig { pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
self.default_provider_config.clone() self.default_provider_config.clone()
} }
pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
self.maintenance_config.clone()
}
} }
#[cfg(test)] #[cfg(test)]
@ -72,6 +79,7 @@ mod tests {
"planner".to_string(), "planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"), test_provider_config_named("planner-provider", "planner-model"),
)]), )]),
MemoryMaintenanceConfig::default(),
); );
let selected = service.select(Some("planner")).unwrap(); let selected = service.select(Some("planner")).unwrap();
@ -82,7 +90,11 @@ mod tests {
#[test] #[test]
fn test_select_falls_back_to_default() { fn test_select_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model"); let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(default_provider, HashMap::new()); let service = ProviderConfigService::new(
default_provider,
HashMap::new(),
MemoryMaintenanceConfig::default(),
);
let selected = service.select(Some("default")).unwrap(); let selected = service.select(Some("default")).unwrap();
assert_eq!(selected.name, "default-provider"); assert_eq!(selected.name, "default-provider");

View File

@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, TaskConfig}; use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
use crate::gateway::tool_registry_factory::ToolRegistryFactory; use crate::gateway::tool_registry_factory::ToolRegistryFactory;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{ use crate::storage::{
@ -34,6 +34,7 @@ pub(crate) fn build_session_manager(
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
disabled_tools: HashSet<String>, disabled_tools: HashSet<String>,
task_config: TaskConfig, task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>, session_ttl_hours: Option<u64>,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> { ) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
@ -47,6 +48,7 @@ pub(crate) fn build_session_manager(
Arc::new(NoopSessionMessageSender), Arc::new(NoopSessionMessageSender),
disabled_tools, disabled_tools,
task_config, task_config,
maintenance_config,
chat_history_ttl_hours, chat_history_ttl_hours,
session_ttl_hours, session_ttl_hours,
) )
@ -62,6 +64,7 @@ pub(crate) fn build_session_manager_with_sender(
session_message_sender: Arc<dyn SessionMessageSender>, session_message_sender: Arc<dyn SessionMessageSender>,
disabled_tools: HashSet<String>, disabled_tools: HashSet<String>,
task_config: TaskConfig, task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>, session_ttl_hours: Option<u64>,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> { ) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
@ -70,7 +73,11 @@ pub(crate) fn build_session_manager_with_sender(
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?, .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
); );
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>(); let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
let provider_configs = ProviderConfigService::new(provider_config.clone(), provider_configs); let provider_configs = ProviderConfigService::new(
provider_config.clone(),
provider_configs,
maintenance_config,
);
if let Err(err) = if let Err(err) =
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())

View File

@ -501,6 +501,7 @@ impl SessionManager {
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
disabled_tools: std::collections::HashSet<String>, disabled_tools: std::collections::HashSet<String>,
task_config: crate::config::TaskConfig, task_config: crate::config::TaskConfig,
maintenance_config: crate::config::MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>, chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>, session_ttl_hours: Option<u64>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
@ -513,6 +514,7 @@ impl SessionManager {
skills, skills,
disabled_tools, disabled_tools,
task_config, task_config,
maintenance_config,
chat_history_ttl_hours, chat_history_ttl_hours,
session_ttl_hours, session_ttl_hours,
) )
@ -973,6 +975,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1025,6 +1028,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1093,6 +1097,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1178,6 +1183,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1264,6 +1270,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1349,6 +1356,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1416,6 +1424,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1492,6 +1501,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1555,6 +1565,7 @@ mod tests {
Arc::new(SkillRuntime::default()), Arc::new(SkillRuntime::default()),
HashSet::new(), HashSet::new(),
crate::config::TaskConfig::default(), crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4), Some(4),
Some(24), Some(24),
) )
@ -1593,6 +1604,8 @@ mod tests {
let store = SessionStore::in_memory().unwrap(); let store = SessionStore::in_memory().unwrap();
let scope_key = "feishu:user-1"; let scope_key = "feishu:user-1";
// 创建足够的记忆7条让合并操作满足保护限制
// 合并后需要保留至少 5 条min_memories_to_keep
let work = store let work = store
.put_memory(&crate::storage::MemoryUpsert { .put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(), scope_kind: "user".to_string(),
@ -1639,9 +1652,30 @@ mod tests {
}) })
.unwrap(); .unwrap();
// 添加额外的记忆以满足 min_memories_to_keep = 5 的要求
for i in 0..4 {
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: "profile".to_string(),
memory_key: format!("extra_{}", i),
content: format!("额外记忆 {}", i),
source_type: "message".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.unwrap();
}
let plan = build_memory_maintenance_plan( let plan = build_memory_maintenance_plan(
&store.list_memories_for_scope("user", scope_key).unwrap(), &store.list_memories_for_scope("user", scope_key).unwrap(),
); );
assert_eq!(plan.candidates.len(), 7); // 7 条候选记忆
let output = MemoryOrganizationOutput { let output = MemoryOrganizationOutput {
merges: vec![MemoryMaintenanceMerge { merges: vec![MemoryMaintenanceMerge {
source_ids: vec![work.id.clone(), role.id.clone()], source_ids: vec![work.id.clone(), role.id.clone()],
@ -1653,13 +1687,25 @@ mod tests {
low_value_ids: vec![noise.id.clone()], low_value_ids: vec![noise.id.clone()],
}; };
apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap(); // 使用默认配置进行验证
apply_memory_maintenance_output(
&store,
scope_key,
&plan,
&output,
crate::config::MemoryMaintenanceConfig::default().max_merge_ratio,
crate::config::MemoryMaintenanceConfig::default().min_memories_to_keep,
crate::config::MemoryMaintenanceConfig::default().max_merge_per_group,
)
.unwrap();
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap(); let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
assert_eq!(all_memories.len(), 1); // 过滤掉 _meta 记录
assert_eq!(all_memories[0].namespace, "profile"); let user_memories: Vec<_> = all_memories.iter().filter(|m| m.namespace != "_meta").collect();
assert_eq!(all_memories[0].memory_key, "work"); // 合并 2 条为 1 条,删除 1 条7 - 2 + 1 = 6 条(加上 _meta 记录)
assert_eq!(all_memories[0].content, "用户主要在做AI产品设计与实现"); assert_eq!(user_memories.len(), 6);
// 验证合并后的记忆存在
assert!(user_memories.iter().any(|m| m.namespace == "profile" && m.memory_key == "work"));
} }
#[test] #[test]