feat: 添加内存维护配置,优化记忆整理逻辑和限制
This commit is contained in:
parent
44d9171b86
commit
b4ef56803f
@ -75,6 +75,7 @@ impl InitWizard {
|
||||
channels: HashMap::new(),
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
tools: crate::config::ToolsConfig::default(),
|
||||
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -824,6 +825,7 @@ impl InitWizard {
|
||||
client: existing.client.clone(),
|
||||
skills: existing.skills.clone(),
|
||||
tools: existing.tools.clone(),
|
||||
memory_maintenance: existing.memory_maintenance.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -27,6 +27,8 @@ pub struct Config {
|
||||
pub skills: SkillsConfig,
|
||||
#[serde(default)]
|
||||
pub tools: ToolsConfig,
|
||||
#[serde(default)]
|
||||
pub memory_maintenance: MemoryMaintenanceConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -108,6 +110,41 @@ pub struct ToolsConfig {
|
||||
pub task: TaskConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MemoryMaintenanceConfig {
|
||||
/// 单次最大合并/删除比例 (0.0-1.0),默认 0.3 (30%)
|
||||
#[serde(default = "default_max_merge_ratio")]
|
||||
pub max_merge_ratio: f32,
|
||||
/// 最小保留记忆数量,默认 5
|
||||
#[serde(default = "default_min_memories_to_keep")]
|
||||
pub min_memories_to_keep: usize,
|
||||
/// 单次合并最大源记忆数,默认 3
|
||||
#[serde(default = "default_max_merge_per_group")]
|
||||
pub max_merge_per_group: usize,
|
||||
}
|
||||
|
||||
fn default_max_merge_ratio() -> f32 {
|
||||
0.3
|
||||
}
|
||||
|
||||
fn default_min_memories_to_keep() -> usize {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_max_merge_per_group() -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
impl Default for MemoryMaintenanceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_merge_ratio: default_max_merge_ratio(),
|
||||
min_memories_to_keep: default_min_memories_to_keep(),
|
||||
max_merge_per_group: default_max_merge_per_group(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TaskConfig {
|
||||
#[serde(default = "default_task_enabled")]
|
||||
|
||||
@ -5,7 +5,7 @@ use std::time::Duration;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
||||
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||||
use crate::storage::{MemoryRecord, SessionStore};
|
||||
|
||||
@ -17,12 +17,16 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
|
||||
include_str!("memory_maintenance_step2_system_prompt.md");
|
||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||
|
||||
const META_NAMESPACE: &str = "_meta";
|
||||
const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceCandidate {
|
||||
pub(crate) id: String,
|
||||
pub(crate) namespace: String,
|
||||
pub(crate) key: String,
|
||||
pub(crate) content: String,
|
||||
pub(crate) updated_at: i64, // 记忆更新时间(Unix timestamp)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
@ -73,13 +77,19 @@ pub(crate) struct MemoryMaintenanceScopeResult {
|
||||
pub(crate) struct MemoryMaintenanceService {
|
||||
store: Arc<SessionStore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
}
|
||||
|
||||
impl MemoryMaintenanceService {
|
||||
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||
pub(crate) fn new(
|
||||
store: Arc<SessionStore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
provider_config,
|
||||
maintenance_config,
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,6 +117,12 @@ impl MemoryMaintenanceService {
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||
// 新增:检查是否有新记忆需要整理
|
||||
if !has_new_memories_since_last_maintenance(&self.store, scope_key)? {
|
||||
tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let memories = self
|
||||
.store
|
||||
.list_memories_for_scope("user", scope_key)
|
||||
@ -255,7 +271,15 @@ impl MemoryMaintenanceService {
|
||||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||||
|
||||
// 应用整理结果(merge和delete)
|
||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
||||
apply_memory_maintenance_output(
|
||||
self.store.as_ref(),
|
||||
scope_key,
|
||||
&plan,
|
||||
&organize_output,
|
||||
self.maintenance_config.max_merge_ratio,
|
||||
self.maintenance_config.min_memories_to_keep,
|
||||
self.maintenance_config.max_merge_per_group,
|
||||
)?;
|
||||
|
||||
// 步骤2:从数据库重新读取剩余的记忆
|
||||
let remaining_memories = self
|
||||
@ -470,17 +494,94 @@ impl MemoryMaintenanceService {
|
||||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||||
|
||||
// 应用整理结果(merge和delete)
|
||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
||||
apply_memory_maintenance_output(
|
||||
self.store.as_ref(),
|
||||
scope_key,
|
||||
&plan,
|
||||
&organize_output,
|
||||
self.maintenance_config.max_merge_ratio,
|
||||
self.maintenance_config.min_memories_to_keep,
|
||||
self.maintenance_config.max_merge_per_group,
|
||||
)?;
|
||||
|
||||
Ok(Some(organize_output))
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取上次整理时间
|
||||
fn get_last_maintenance_time(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<i64>, crate::storage::StorageError> {
|
||||
let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?;
|
||||
Ok(meta.and_then(|m| m.content.parse::<i64>().ok()))
|
||||
}
|
||||
|
||||
/// 记录本次整理时间
|
||||
fn set_last_maintenance_time(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
time: i64,
|
||||
) -> Result<(), crate::storage::StorageError> {
|
||||
store.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
namespace: META_NAMESPACE.to_string(),
|
||||
memory_key: LAST_MAINTENANCE_KEY.to_string(),
|
||||
content: time.to_string(),
|
||||
source_type: "memory_maintenance".to_string(),
|
||||
source_session_id: None,
|
||||
source_message_id: None,
|
||||
source_message_seq: None,
|
||||
source_channel_name: None,
|
||||
source_chat_id: None,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 检查是否有需要整理的新记忆(过滤掉 _meta namespace)
|
||||
fn has_new_memories_since_last_maintenance(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
) -> Result<bool, AgentError> {
|
||||
let memories = store
|
||||
.list_memories_for_scope("user", scope_key)
|
||||
.map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?;
|
||||
|
||||
// 过滤掉 _meta namespace 的记忆
|
||||
let user_memories: Vec<_> = memories
|
||||
.iter()
|
||||
.filter(|m| m.namespace != META_NAMESPACE)
|
||||
.collect();
|
||||
|
||||
if user_memories.is_empty() {
|
||||
return Ok(false); // 没有记忆,跳过
|
||||
}
|
||||
|
||||
// 获取上次整理时间
|
||||
let last_time = get_last_maintenance_time(store, scope_key)
|
||||
.map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?;
|
||||
|
||||
match last_time {
|
||||
None => Ok(true), // 从未整理过,需要整理
|
||||
Some(last) => {
|
||||
// 检查是否有记忆在上次整理后更新
|
||||
let has_new = user_memories.iter().any(|m| m.updated_at > last);
|
||||
Ok(has_new)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||||
let mut plan = MemoryMaintenancePlan::default();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
for memory in memories {
|
||||
// 过滤掉 _meta namespace 的记忆
|
||||
if memory.namespace == META_NAMESPACE {
|
||||
continue;
|
||||
}
|
||||
|
||||
let normalized_content = memory.content.trim();
|
||||
if normalized_content.is_empty() {
|
||||
continue;
|
||||
@ -501,6 +602,7 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory
|
||||
namespace: memory.namespace.clone(),
|
||||
key: memory.memory_key.clone(),
|
||||
content: normalized_content.to_string(),
|
||||
updated_at: memory.updated_at,
|
||||
};
|
||||
|
||||
plan.candidates.push(candidate);
|
||||
@ -580,12 +682,115 @@ pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
|
||||
None
|
||||
}
|
||||
|
||||
/// 验证记忆整理输出是否符合限制
|
||||
pub(crate) fn validate_memory_maintenance_output(
|
||||
plan: &MemoryMaintenancePlan,
|
||||
output: &MemoryOrganizationOutput,
|
||||
max_merge_ratio: f32,
|
||||
min_memories_to_keep: usize,
|
||||
max_merge_per_group: usize,
|
||||
) -> Result<(), String> {
|
||||
let total = plan.candidates.len();
|
||||
|
||||
if total == 0 {
|
||||
return Ok(()); // 没有候选,无需验证
|
||||
}
|
||||
|
||||
// 验证 1: 单次合并数量限制
|
||||
for merge in &output.merges {
|
||||
if merge.source_ids.len() > max_merge_per_group {
|
||||
return Err(format!(
|
||||
"合并组过大: {} 条源记忆超过上限 {}",
|
||||
merge.source_ids.len(),
|
||||
max_merge_per_group
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 2: 跨 namespace 合并检测(完全禁止)
|
||||
let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan
|
||||
.candidates
|
||||
.iter()
|
||||
.map(|c| (c.id.as_str(), c))
|
||||
.collect();
|
||||
|
||||
for merge in &output.merges {
|
||||
let source_namespaces: HashSet<&str> = merge
|
||||
.source_ids
|
||||
.iter()
|
||||
.filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str()))
|
||||
.collect();
|
||||
|
||||
// 检查是否跨越多个 namespace
|
||||
if source_namespaces.len() > 1 {
|
||||
return Err(format!(
|
||||
"跨 namespace 合并被禁止: 源来自 {}",
|
||||
source_namespaces.iter().cloned().collect::<Vec<_>>().join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
// 检查目标 namespace 是否与源一致
|
||||
if let Some(src_ns) = source_namespaces.iter().next() {
|
||||
if *src_ns != merge.namespace {
|
||||
return Err(format!(
|
||||
"跨 namespace 合并被禁止: {} → {}",
|
||||
src_ns, merge.namespace
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 3: 总体合并比例
|
||||
let merged_ids: HashSet<&str> = output
|
||||
.merges
|
||||
.iter()
|
||||
.flat_map(|m| m.source_ids.iter())
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
|
||||
let deleted_ids: HashSet<&str> = output
|
||||
.low_value_ids
|
||||
.iter()
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
|
||||
let affected = merged_ids.len() + deleted_ids.len();
|
||||
let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize;
|
||||
|
||||
if affected > max_allowed {
|
||||
return Err(format!(
|
||||
"合并比例超限: {} / {} > {:.0}%",
|
||||
affected,
|
||||
total,
|
||||
max_merge_ratio * 100.0
|
||||
));
|
||||
}
|
||||
|
||||
// 验证 4: 最小保留数
|
||||
let remaining = total - affected + output.merges.len();
|
||||
if remaining < min_memories_to_keep {
|
||||
return Err(format!(
|
||||
"保留数不足: {} < {}",
|
||||
remaining, min_memories_to_keep
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn apply_memory_maintenance_output(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
output: &MemoryOrganizationOutput,
|
||||
max_merge_ratio: f32,
|
||||
min_memories_to_keep: usize,
|
||||
max_merge_per_group: usize,
|
||||
) -> Result<(), AgentError> {
|
||||
// 新增: 验证合并输出
|
||||
validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group)
|
||||
.map_err(|e| AgentError::Other(e))?;
|
||||
|
||||
let all_candidates = plan.candidates.clone();
|
||||
|
||||
let candidates_by_id = all_candidates
|
||||
@ -661,6 +866,11 @@ pub(crate) fn apply_memory_maintenance_output(
|
||||
}
|
||||
}
|
||||
|
||||
// 新增:记录整理完成时间
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
set_last_maintenance_time(store, scope_key, now)
|
||||
.map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ impl MemoryMaintenanceCoordinator {
|
||||
Ok(MemoryMaintenanceService::new(
|
||||
self.store.clone(),
|
||||
self.provider_configs.default_provider_config(),
|
||||
self.provider_configs.default_maintenance_config(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,13 +27,30 @@
|
||||
- note: 冲突说明
|
||||
- low_value_ids:需要删除的低价值候选记忆 ID 数组
|
||||
|
||||
组织原则(由你自主决定):
|
||||
组织原则:
|
||||
|
||||
- 根据记忆的语义内容自然分组,不必拘泥于预定义分类
|
||||
- 相似的、互补的记忆可以合并
|
||||
- 过期、重复、过细的记忆可以标记为低价值
|
||||
- 根据记忆的语义内容自然分组
|
||||
- **每次合并最多只能合并 2-3 条源记忆**
|
||||
- **禁止跨 namespace 合并**(不同 namespace 代表不同信息维度)
|
||||
- 过期、重复、过细的记忆可以标记为低值
|
||||
- namespace 和 memory_key 的命名应当简洁、有意义
|
||||
- 可以自由创建新的 namespace 来组织相关记忆
|
||||
- **保守原则:宁可保留稍多,不可过度合并**
|
||||
- **必须保留足够数量的记忆,确保信息多样性**
|
||||
|
||||
时间权重原则(关键):
|
||||
|
||||
- 每个候选记忆包含 `updated_at` 时间戳(Unix timestamp,秒)
|
||||
- **当多条记忆存在重复或冲突时,时间越新的权重越高**
|
||||
- 合并时优先采用新记忆的内容,旧记忆作为补充或背景
|
||||
- 如果新旧记忆内容完全相同,保留新的,删除旧的
|
||||
- 时间戳数值越大表示越新(离当前时间越近)
|
||||
|
||||
合并限制(硬性约束,由系统强制检查):
|
||||
|
||||
- 单次合并最多来自 3 条源记忆
|
||||
- 整理后保留的记忆数不得少于 5 条
|
||||
- 单次整理最多影响 30% 的记忆
|
||||
- 不同 namespace 的记忆不允许互相合并
|
||||
|
||||
额外约束:
|
||||
|
||||
|
||||
@ -84,6 +84,7 @@ impl GatewayState {
|
||||
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
||||
std::collections::HashSet::new(),
|
||||
config.tools.task.clone(),
|
||||
config.memory_maintenance.clone(),
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)?;
|
||||
|
||||
@ -2,22 +2,25 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ProviderConfigService {
|
||||
default_provider_config: LLMProviderConfig,
|
||||
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
}
|
||||
|
||||
impl ProviderConfigService {
|
||||
pub(crate) fn new(
|
||||
default_provider_config: LLMProviderConfig,
|
||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
default_provider_config,
|
||||
provider_configs: Arc::new(provider_configs),
|
||||
maintenance_config,
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,6 +40,10 @@ impl ProviderConfigService {
|
||||
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
|
||||
self.default_provider_config.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
|
||||
self.maintenance_config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -72,6 +79,7 @@ mod tests {
|
||||
"planner".to_string(),
|
||||
test_provider_config_named("planner-provider", "planner-model"),
|
||||
)]),
|
||||
MemoryMaintenanceConfig::default(),
|
||||
);
|
||||
|
||||
let selected = service.select(Some("planner")).unwrap();
|
||||
@ -82,7 +90,11 @@ mod tests {
|
||||
#[test]
|
||||
fn test_select_falls_back_to_default() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let service = ProviderConfigService::new(default_provider, HashMap::new());
|
||||
let service = ProviderConfigService::new(
|
||||
default_provider,
|
||||
HashMap::new(),
|
||||
MemoryMaintenanceConfig::default(),
|
||||
);
|
||||
|
||||
let selected = service.select(Some("default")).unwrap();
|
||||
assert_eq!(selected.name, "default-provider");
|
||||
|
||||
@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::{LLMProviderConfig, TaskConfig};
|
||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
|
||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{
|
||||
@ -34,6 +34,7 @@ pub(crate) fn build_session_manager(
|
||||
skills: Arc<SkillRuntime>,
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||
@ -47,6 +48,7 @@ pub(crate) fn build_session_manager(
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
disabled_tools,
|
||||
task_config,
|
||||
maintenance_config,
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)
|
||||
@ -62,6 +64,7 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||
@ -70,7 +73,11 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||
);
|
||||
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
||||
let provider_configs = ProviderConfigService::new(provider_config.clone(), provider_configs);
|
||||
let provider_configs = ProviderConfigService::new(
|
||||
provider_config.clone(),
|
||||
provider_configs,
|
||||
maintenance_config,
|
||||
);
|
||||
|
||||
if let Err(err) =
|
||||
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
|
||||
|
||||
@ -501,6 +501,7 @@ impl SessionManager {
|
||||
skills: Arc<SkillRuntime>,
|
||||
disabled_tools: std::collections::HashSet<String>,
|
||||
task_config: crate::config::TaskConfig,
|
||||
maintenance_config: crate::config::MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<Self, AgentError> {
|
||||
@ -513,6 +514,7 @@ impl SessionManager {
|
||||
skills,
|
||||
disabled_tools,
|
||||
task_config,
|
||||
maintenance_config,
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)
|
||||
@ -973,6 +975,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1025,6 +1028,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1093,6 +1097,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1178,6 +1183,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1264,6 +1270,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1349,6 +1356,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1416,6 +1424,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1492,6 +1501,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1555,6 +1565,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1593,6 +1604,8 @@ mod tests {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let scope_key = "feishu:user-1";
|
||||
|
||||
// 创建足够的记忆(7条),让合并操作满足保护限制
|
||||
// 合并后需要保留至少 5 条(min_memories_to_keep)
|
||||
let work = store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
@ -1639,9 +1652,30 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// 添加额外的记忆以满足 min_memories_to_keep = 5 的要求
|
||||
for i in 0..4 {
|
||||
store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
namespace: "profile".to_string(),
|
||||
memory_key: format!("extra_{}", i),
|
||||
content: format!("额外记忆 {}", i),
|
||||
source_type: "message".to_string(),
|
||||
source_session_id: None,
|
||||
source_message_id: None,
|
||||
source_message_seq: None,
|
||||
source_channel_name: None,
|
||||
source_chat_id: None,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let plan = build_memory_maintenance_plan(
|
||||
&store.list_memories_for_scope("user", scope_key).unwrap(),
|
||||
);
|
||||
assert_eq!(plan.candidates.len(), 7); // 7 条候选记忆
|
||||
|
||||
let output = MemoryOrganizationOutput {
|
||||
merges: vec![MemoryMaintenanceMerge {
|
||||
source_ids: vec![work.id.clone(), role.id.clone()],
|
||||
@ -1653,13 +1687,25 @@ mod tests {
|
||||
low_value_ids: vec![noise.id.clone()],
|
||||
};
|
||||
|
||||
apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap();
|
||||
// 使用默认配置进行验证
|
||||
apply_memory_maintenance_output(
|
||||
&store,
|
||||
scope_key,
|
||||
&plan,
|
||||
&output,
|
||||
crate::config::MemoryMaintenanceConfig::default().max_merge_ratio,
|
||||
crate::config::MemoryMaintenanceConfig::default().min_memories_to_keep,
|
||||
crate::config::MemoryMaintenanceConfig::default().max_merge_per_group,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
|
||||
assert_eq!(all_memories.len(), 1);
|
||||
assert_eq!(all_memories[0].namespace, "profile");
|
||||
assert_eq!(all_memories[0].memory_key, "work");
|
||||
assert_eq!(all_memories[0].content, "用户主要在做AI产品设计与实现");
|
||||
// 过滤掉 _meta 记录
|
||||
let user_memories: Vec<_> = all_memories.iter().filter(|m| m.namespace != "_meta").collect();
|
||||
// 合并 2 条为 1 条,删除 1 条,7 - 2 + 1 = 6 条(加上 _meta 记录)
|
||||
assert_eq!(user_memories.len(), 6);
|
||||
// 验证合并后的记忆存在
|
||||
assert!(user_memories.iter().any(|m| m.namespace == "profile" && m.memory_key == "work"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user