feat: 添加内存维护配置,优化记忆整理逻辑和限制

This commit is contained in:
ooodc 2026-05-23 16:05:11 +08:00
parent 44d9171b86
commit b4ef56803f
9 changed files with 352 additions and 19 deletions

View File

@ -75,6 +75,7 @@ impl InitWizard {
channels: HashMap::new(),
skills: crate::config::SkillsConfig::default(),
tools: crate::config::ToolsConfig::default(),
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
}
}
@ -824,6 +825,7 @@ impl InitWizard {
client: existing.client.clone(),
skills: existing.skills.clone(),
tools: existing.tools.clone(),
memory_maintenance: existing.memory_maintenance.clone(),
}
}

View File

@ -27,6 +27,8 @@ pub struct Config {
pub skills: SkillsConfig,
#[serde(default)]
pub tools: ToolsConfig,
#[serde(default)]
pub memory_maintenance: MemoryMaintenanceConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -108,6 +110,41 @@ pub struct ToolsConfig {
pub task: TaskConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MemoryMaintenanceConfig {
/// 单次最大合并/删除比例 (0.0-1.0),默认 0.3 (30%)
#[serde(default = "default_max_merge_ratio")]
pub max_merge_ratio: f32,
/// 最小保留记忆数量,默认 5
#[serde(default = "default_min_memories_to_keep")]
pub min_memories_to_keep: usize,
/// 单次合并最大源记忆数,默认 3
#[serde(default = "default_max_merge_per_group")]
pub max_merge_per_group: usize,
}
fn default_max_merge_ratio() -> f32 {
0.3
}
fn default_min_memories_to_keep() -> usize {
5
}
fn default_max_merge_per_group() -> usize {
3
}
impl Default for MemoryMaintenanceConfig {
fn default() -> Self {
Self {
max_merge_ratio: default_max_merge_ratio(),
min_memories_to_keep: default_min_memories_to_keep(),
max_merge_per_group: default_max_merge_per_group(),
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TaskConfig {
#[serde(default = "default_task_enabled")]

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
use crate::storage::{MemoryRecord, SessionStore};
@ -17,12 +17,16 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
include_str!("memory_maintenance_step2_system_prompt.md");
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
const META_NAMESPACE: &str = "_meta";
const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at";
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate {
pub(crate) id: String,
pub(crate) namespace: String,
pub(crate) key: String,
pub(crate) content: String,
pub(crate) updated_at: i64, // 记忆更新时间Unix timestamp
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
@ -73,13 +77,19 @@ pub(crate) struct MemoryMaintenanceScopeResult {
pub(crate) struct MemoryMaintenanceService {
store: Arc<SessionStore>,
provider_config: LLMProviderConfig,
maintenance_config: MemoryMaintenanceConfig,
}
impl MemoryMaintenanceService {
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
pub(crate) fn new(
store: Arc<SessionStore>,
provider_config: LLMProviderConfig,
maintenance_config: MemoryMaintenanceConfig,
) -> Self {
Self {
store,
provider_config,
maintenance_config,
}
}
@ -107,6 +117,12 @@ impl MemoryMaintenanceService {
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
// 新增:检查是否有新记忆需要整理
if !has_new_memories_since_last_maintenance(&self.store, scope_key)? {
tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping");
return Ok(None);
}
let memories = self
.store
.list_memories_for_scope("user", scope_key)
@ -255,7 +271,15 @@ impl MemoryMaintenanceService {
let organize_output = self.organize_plan(scope_key, &plan).await?;
// 应用整理结果merge和delete
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
apply_memory_maintenance_output(
self.store.as_ref(),
scope_key,
&plan,
&organize_output,
self.maintenance_config.max_merge_ratio,
self.maintenance_config.min_memories_to_keep,
self.maintenance_config.max_merge_per_group,
)?;
// 步骤2从数据库重新读取剩余的记忆
let remaining_memories = self
@ -470,17 +494,94 @@ impl MemoryMaintenanceService {
let organize_output = self.organize_plan(scope_key, &plan).await?;
// 应用整理结果merge和delete
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
apply_memory_maintenance_output(
self.store.as_ref(),
scope_key,
&plan,
&organize_output,
self.maintenance_config.max_merge_ratio,
self.maintenance_config.min_memories_to_keep,
self.maintenance_config.max_merge_per_group,
)?;
Ok(Some(organize_output))
}
}
/// 获取上次整理时间
fn get_last_maintenance_time(
store: &SessionStore,
scope_key: &str,
) -> Result<Option<i64>, crate::storage::StorageError> {
let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?;
Ok(meta.and_then(|m| m.content.parse::<i64>().ok()))
}
/// 记录本次整理时间
fn set_last_maintenance_time(
store: &SessionStore,
scope_key: &str,
time: i64,
) -> Result<(), crate::storage::StorageError> {
store.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: META_NAMESPACE.to_string(),
memory_key: LAST_MAINTENANCE_KEY.to_string(),
content: time.to_string(),
source_type: "memory_maintenance".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})?;
Ok(())
}
/// 检查是否有需要整理的新记忆(过滤掉 _meta namespace
fn has_new_memories_since_last_maintenance(
store: &SessionStore,
scope_key: &str,
) -> Result<bool, AgentError> {
let memories = store
.list_memories_for_scope("user", scope_key)
.map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?;
// 过滤掉 _meta namespace 的记忆
let user_memories: Vec<_> = memories
.iter()
.filter(|m| m.namespace != META_NAMESPACE)
.collect();
if user_memories.is_empty() {
return Ok(false); // 没有记忆,跳过
}
// 获取上次整理时间
let last_time = get_last_maintenance_time(store, scope_key)
.map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?;
match last_time {
None => Ok(true), // 从未整理过,需要整理
Some(last) => {
// 检查是否有记忆在上次整理后更新
let has_new = user_memories.iter().any(|m| m.updated_at > last);
Ok(has_new)
}
}
}
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
let mut plan = MemoryMaintenancePlan::default();
let mut seen = HashSet::new();
for memory in memories {
// 过滤掉 _meta namespace 的记忆
if memory.namespace == META_NAMESPACE {
continue;
}
let normalized_content = memory.content.trim();
if normalized_content.is_empty() {
continue;
@ -501,6 +602,7 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory
namespace: memory.namespace.clone(),
key: memory.memory_key.clone(),
content: normalized_content.to_string(),
updated_at: memory.updated_at,
};
plan.candidates.push(candidate);
@ -580,12 +682,115 @@ pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
None
}
/// 验证记忆整理输出是否符合限制
pub(crate) fn validate_memory_maintenance_output(
plan: &MemoryMaintenancePlan,
output: &MemoryOrganizationOutput,
max_merge_ratio: f32,
min_memories_to_keep: usize,
max_merge_per_group: usize,
) -> Result<(), String> {
let total = plan.candidates.len();
if total == 0 {
return Ok(()); // 没有候选,无需验证
}
// 验证 1: 单次合并数量限制
for merge in &output.merges {
if merge.source_ids.len() > max_merge_per_group {
return Err(format!(
"合并组过大: {} 条源记忆超过上限 {}",
merge.source_ids.len(),
max_merge_per_group
));
}
}
// 验证 2: 跨 namespace 合并检测(完全禁止)
let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan
.candidates
.iter()
.map(|c| (c.id.as_str(), c))
.collect();
for merge in &output.merges {
let source_namespaces: HashSet<&str> = merge
.source_ids
.iter()
.filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str()))
.collect();
// 检查是否跨越多个 namespace
if source_namespaces.len() > 1 {
return Err(format!(
"跨 namespace 合并被禁止: 源来自 {}",
source_namespaces.iter().cloned().collect::<Vec<_>>().join(", ")
));
}
// 检查目标 namespace 是否与源一致
if let Some(src_ns) = source_namespaces.iter().next() {
if *src_ns != merge.namespace {
return Err(format!(
"跨 namespace 合并被禁止: {} → {}",
src_ns, merge.namespace
));
}
}
}
// 验证 3: 总体合并比例
let merged_ids: HashSet<&str> = output
.merges
.iter()
.flat_map(|m| m.source_ids.iter())
.map(|s| s.as_str())
.collect();
let deleted_ids: HashSet<&str> = output
.low_value_ids
.iter()
.map(|s| s.as_str())
.collect();
let affected = merged_ids.len() + deleted_ids.len();
let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize;
if affected > max_allowed {
return Err(format!(
"合并比例超限: {} / {} > {:.0}%",
affected,
total,
max_merge_ratio * 100.0
));
}
// 验证 4: 最小保留数
let remaining = total - affected + output.merges.len();
if remaining < min_memories_to_keep {
return Err(format!(
"保留数不足: {} < {}",
remaining, min_memories_to_keep
));
}
Ok(())
}
pub(crate) fn apply_memory_maintenance_output(
store: &SessionStore,
scope_key: &str,
plan: &MemoryMaintenancePlan,
output: &MemoryOrganizationOutput,
max_merge_ratio: f32,
min_memories_to_keep: usize,
max_merge_per_group: usize,
) -> Result<(), AgentError> {
// 新增: 验证合并输出
validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group)
.map_err(|e| AgentError::Other(e))?;
let all_candidates = plan.candidates.clone();
let candidates_by_id = all_candidates
@ -661,6 +866,11 @@ pub(crate) fn apply_memory_maintenance_output(
}
}
// 新增:记录整理完成时间
let now = chrono::Utc::now().timestamp();
set_last_maintenance_time(store, scope_key, now)
.map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?;
Ok(())
}

View File

@ -40,6 +40,7 @@ impl MemoryMaintenanceCoordinator {
Ok(MemoryMaintenanceService::new(
self.store.clone(),
self.provider_configs.default_provider_config(),
self.provider_configs.default_maintenance_config(),
))
}
}

View File

@ -27,15 +27,32 @@
- note: 冲突说明
- low_value_ids需要删除的低价值候选记忆 ID 数组
组织原则(由你自主决定)
组织原则:
- 根据记忆的语义内容自然分组,不必拘泥于预定义分类
- 相似的、互补的记忆可以合并
- 过期、重复、过细的记忆可以标记为低价值
- 根据记忆的语义内容自然分组
- **每次合并最多只能合并 2-3 条源记忆**
- **禁止跨 namespace 合并**(不同 namespace 代表不同信息维度)
- 过期、重复、过细的记忆可以标记为低值
- namespace 和 memory_key 的命名应当简洁、有意义
- 可以自由创建新的 namespace 来组织相关记忆
- **保守原则:宁可保留稍多,不可过度合并**
- **必须保留足够数量的记忆,确保信息多样性**
时间权重原则(关键):
- 每个候选记忆包含 `updated_at` 时间戳Unix timestamp
- **当多条记忆存在重复或冲突时,时间越新的权重越高**
- 合并时优先采用新记忆的内容,旧记忆作为补充或背景
- 如果新旧记忆内容完全相同,保留新的,删除旧的
- 时间戳数值越大表示越新(离当前时间越近)
合并限制(硬性约束,由系统强制检查):
- 单次合并最多来自 3 条源记忆
- 整理后保留的记忆数不得少于 5 条
- 单次整理最多影响 30% 的记忆
- 不同 namespace 的记忆不允许互相合并
额外约束:
- 只能引用输入里出现过的候选 id。
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。

View File

@ -84,6 +84,7 @@ impl GatewayState {
Arc::new(BusSessionMessageSender::new(bus.clone())),
std::collections::HashSet::new(),
config.tools.task.clone(),
config.memory_maintenance.clone(),
chat_history_ttl_hours,
session_ttl_hours,
)?;

View File

@ -2,22 +2,25 @@ use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
#[derive(Clone)]
pub(crate) struct ProviderConfigService {
default_provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
maintenance_config: MemoryMaintenanceConfig,
}
impl ProviderConfigService {
pub(crate) fn new(
default_provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
maintenance_config: MemoryMaintenanceConfig,
) -> Self {
Self {
default_provider_config,
provider_configs: Arc::new(provider_configs),
maintenance_config,
}
}
@ -37,6 +40,10 @@ impl ProviderConfigService {
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
self.default_provider_config.clone()
}
pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
self.maintenance_config.clone()
}
}
#[cfg(test)]
@ -72,6 +79,7 @@ mod tests {
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]),
MemoryMaintenanceConfig::default(),
);
let selected = service.select(Some("planner")).unwrap();
@ -82,7 +90,11 @@ mod tests {
#[test]
fn test_select_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(default_provider, HashMap::new());
let service = ProviderConfigService::new(
default_provider,
HashMap::new(),
MemoryMaintenanceConfig::default(),
);
let selected = service.select(Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");

View File

@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, TaskConfig};
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
use crate::skills::SkillRuntime;
use crate::storage::{
@ -34,6 +34,7 @@ pub(crate) fn build_session_manager(
skills: Arc<SkillRuntime>,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
@ -47,6 +48,7 @@ pub(crate) fn build_session_manager(
Arc::new(NoopSessionMessageSender),
disabled_tools,
task_config,
maintenance_config,
chat_history_ttl_hours,
session_ttl_hours,
)
@ -62,6 +64,7 @@ pub(crate) fn build_session_manager_with_sender(
session_message_sender: Arc<dyn SessionMessageSender>,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
@ -70,7 +73,11 @@ pub(crate) fn build_session_manager_with_sender(
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
);
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
let provider_configs = ProviderConfigService::new(provider_config.clone(), provider_configs);
let provider_configs = ProviderConfigService::new(
provider_config.clone(),
provider_configs,
maintenance_config,
);
if let Err(err) =
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())

View File

@ -501,6 +501,7 @@ impl SessionManager {
skills: Arc<SkillRuntime>,
disabled_tools: std::collections::HashSet<String>,
task_config: crate::config::TaskConfig,
maintenance_config: crate::config::MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<Self, AgentError> {
@ -513,6 +514,7 @@ impl SessionManager {
skills,
disabled_tools,
task_config,
maintenance_config,
chat_history_ttl_hours,
session_ttl_hours,
)
@ -973,6 +975,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1025,6 +1028,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1093,6 +1097,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1178,6 +1183,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1264,6 +1270,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1349,6 +1356,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1416,6 +1424,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1492,6 +1501,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1555,6 +1565,7 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1593,6 +1604,8 @@ mod tests {
let store = SessionStore::in_memory().unwrap();
let scope_key = "feishu:user-1";
// 创建足够的记忆7条让合并操作满足保护限制
// 合并后需要保留至少 5 条min_memories_to_keep
let work = store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
@ -1639,9 +1652,30 @@ mod tests {
})
.unwrap();
// 添加额外的记忆以满足 min_memories_to_keep = 5 的要求
for i in 0..4 {
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: "profile".to_string(),
memory_key: format!("extra_{}", i),
content: format!("额外记忆 {}", i),
source_type: "message".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.unwrap();
}
let plan = build_memory_maintenance_plan(
&store.list_memories_for_scope("user", scope_key).unwrap(),
);
assert_eq!(plan.candidates.len(), 7); // 7 条候选记忆
let output = MemoryOrganizationOutput {
merges: vec![MemoryMaintenanceMerge {
source_ids: vec![work.id.clone(), role.id.clone()],
@ -1653,13 +1687,25 @@ mod tests {
low_value_ids: vec![noise.id.clone()],
};
apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap();
// 使用默认配置进行验证
apply_memory_maintenance_output(
&store,
scope_key,
&plan,
&output,
crate::config::MemoryMaintenanceConfig::default().max_merge_ratio,
crate::config::MemoryMaintenanceConfig::default().min_memories_to_keep,
crate::config::MemoryMaintenanceConfig::default().max_merge_per_group,
)
.unwrap();
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
assert_eq!(all_memories.len(), 1);
assert_eq!(all_memories[0].namespace, "profile");
assert_eq!(all_memories[0].memory_key, "work");
assert_eq!(all_memories[0].content, "用户主要在做AI产品设计与实现");
// 过滤掉 _meta 记录
let user_memories: Vec<_> = all_memories.iter().filter(|m| m.namespace != "_meta").collect();
// 合并 2 条为 1 条,删除 1 条7 - 2 + 1 = 6 条(加上 _meta 记录)
assert_eq!(user_memories.len(), 6);
// 验证合并后的记忆存在
assert!(user_memories.iter().any(|m| m.namespace == "profile" && m.memory_key == "work"));
}
#[test]