feat: 添加内存维护配置,优化记忆整理逻辑和限制
This commit is contained in:
parent
44d9171b86
commit
b4ef56803f
@ -75,6 +75,7 @@ impl InitWizard {
|
|||||||
channels: HashMap::new(),
|
channels: HashMap::new(),
|
||||||
skills: crate::config::SkillsConfig::default(),
|
skills: crate::config::SkillsConfig::default(),
|
||||||
tools: crate::config::ToolsConfig::default(),
|
tools: crate::config::ToolsConfig::default(),
|
||||||
|
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -824,6 +825,7 @@ impl InitWizard {
|
|||||||
client: existing.client.clone(),
|
client: existing.client.clone(),
|
||||||
skills: existing.skills.clone(),
|
skills: existing.skills.clone(),
|
||||||
tools: existing.tools.clone(),
|
tools: existing.tools.clone(),
|
||||||
|
memory_maintenance: existing.memory_maintenance.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -27,6 +27,8 @@ pub struct Config {
|
|||||||
pub skills: SkillsConfig,
|
pub skills: SkillsConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub tools: ToolsConfig,
|
pub tools: ToolsConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
pub memory_maintenance: MemoryMaintenanceConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@ -108,6 +110,41 @@ pub struct ToolsConfig {
|
|||||||
pub task: TaskConfig,
|
pub task: TaskConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct MemoryMaintenanceConfig {
|
||||||
|
/// 单次最大合并/删除比例 (0.0-1.0),默认 0.3 (30%)
|
||||||
|
#[serde(default = "default_max_merge_ratio")]
|
||||||
|
pub max_merge_ratio: f32,
|
||||||
|
/// 最小保留记忆数量,默认 5
|
||||||
|
#[serde(default = "default_min_memories_to_keep")]
|
||||||
|
pub min_memories_to_keep: usize,
|
||||||
|
/// 单次合并最大源记忆数,默认 3
|
||||||
|
#[serde(default = "default_max_merge_per_group")]
|
||||||
|
pub max_merge_per_group: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_max_merge_ratio() -> f32 {
|
||||||
|
0.3
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_min_memories_to_keep() -> usize {
|
||||||
|
5
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_max_merge_per_group() -> usize {
|
||||||
|
3
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MemoryMaintenanceConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_merge_ratio: default_max_merge_ratio(),
|
||||||
|
min_memories_to_keep: default_min_memories_to_keep(),
|
||||||
|
max_merge_per_group: default_max_merge_per_group(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct TaskConfig {
|
pub struct TaskConfig {
|
||||||
#[serde(default = "default_task_enabled")]
|
#[serde(default = "default_task_enabled")]
|
||||||
|
|||||||
@ -5,7 +5,7 @@ use std::time::Duration;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
||||||
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||||||
use crate::storage::{MemoryRecord, SessionStore};
|
use crate::storage::{MemoryRecord, SessionStore};
|
||||||
|
|
||||||
@ -17,12 +17,16 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
|
|||||||
include_str!("memory_maintenance_step2_system_prompt.md");
|
include_str!("memory_maintenance_step2_system_prompt.md");
|
||||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||||
|
|
||||||
|
const META_NAMESPACE: &str = "_meta";
|
||||||
|
const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at";
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
pub(crate) struct MemoryMaintenanceCandidate {
|
pub(crate) struct MemoryMaintenanceCandidate {
|
||||||
pub(crate) id: String,
|
pub(crate) id: String,
|
||||||
pub(crate) namespace: String,
|
pub(crate) namespace: String,
|
||||||
pub(crate) key: String,
|
pub(crate) key: String,
|
||||||
pub(crate) content: String,
|
pub(crate) content: String,
|
||||||
|
pub(crate) updated_at: i64, // 记忆更新时间(Unix timestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
@ -73,13 +77,19 @@ pub(crate) struct MemoryMaintenanceScopeResult {
|
|||||||
pub(crate) struct MemoryMaintenanceService {
|
pub(crate) struct MemoryMaintenanceService {
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
|
maintenance_config: MemoryMaintenanceConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MemoryMaintenanceService {
|
impl MemoryMaintenanceService {
|
||||||
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
pub(crate) fn new(
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
maintenance_config: MemoryMaintenanceConfig,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
store,
|
store,
|
||||||
provider_config,
|
provider_config,
|
||||||
|
maintenance_config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,6 +117,12 @@ impl MemoryMaintenanceService {
|
|||||||
&self,
|
&self,
|
||||||
scope_key: &str,
|
scope_key: &str,
|
||||||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||||
|
// 新增:检查是否有新记忆需要整理
|
||||||
|
if !has_new_memories_since_last_maintenance(&self.store, scope_key)? {
|
||||||
|
tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping");
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
let memories = self
|
let memories = self
|
||||||
.store
|
.store
|
||||||
.list_memories_for_scope("user", scope_key)
|
.list_memories_for_scope("user", scope_key)
|
||||||
@ -255,7 +271,15 @@ impl MemoryMaintenanceService {
|
|||||||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||||||
|
|
||||||
// 应用整理结果(merge和delete)
|
// 应用整理结果(merge和delete)
|
||||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
apply_memory_maintenance_output(
|
||||||
|
self.store.as_ref(),
|
||||||
|
scope_key,
|
||||||
|
&plan,
|
||||||
|
&organize_output,
|
||||||
|
self.maintenance_config.max_merge_ratio,
|
||||||
|
self.maintenance_config.min_memories_to_keep,
|
||||||
|
self.maintenance_config.max_merge_per_group,
|
||||||
|
)?;
|
||||||
|
|
||||||
// 步骤2:从数据库重新读取剩余的记忆
|
// 步骤2:从数据库重新读取剩余的记忆
|
||||||
let remaining_memories = self
|
let remaining_memories = self
|
||||||
@ -470,17 +494,94 @@ impl MemoryMaintenanceService {
|
|||||||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||||||
|
|
||||||
// 应用整理结果(merge和delete)
|
// 应用整理结果(merge和delete)
|
||||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
apply_memory_maintenance_output(
|
||||||
|
self.store.as_ref(),
|
||||||
|
scope_key,
|
||||||
|
&plan,
|
||||||
|
&organize_output,
|
||||||
|
self.maintenance_config.max_merge_ratio,
|
||||||
|
self.maintenance_config.min_memories_to_keep,
|
||||||
|
self.maintenance_config.max_merge_per_group,
|
||||||
|
)?;
|
||||||
|
|
||||||
Ok(Some(organize_output))
|
Ok(Some(organize_output))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 获取上次整理时间
|
||||||
|
fn get_last_maintenance_time(
|
||||||
|
store: &SessionStore,
|
||||||
|
scope_key: &str,
|
||||||
|
) -> Result<Option<i64>, crate::storage::StorageError> {
|
||||||
|
let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?;
|
||||||
|
Ok(meta.and_then(|m| m.content.parse::<i64>().ok()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 记录本次整理时间
|
||||||
|
fn set_last_maintenance_time(
|
||||||
|
store: &SessionStore,
|
||||||
|
scope_key: &str,
|
||||||
|
time: i64,
|
||||||
|
) -> Result<(), crate::storage::StorageError> {
|
||||||
|
store.put_memory(&crate::storage::MemoryUpsert {
|
||||||
|
scope_kind: "user".to_string(),
|
||||||
|
scope_key: scope_key.to_string(),
|
||||||
|
namespace: META_NAMESPACE.to_string(),
|
||||||
|
memory_key: LAST_MAINTENANCE_KEY.to_string(),
|
||||||
|
content: time.to_string(),
|
||||||
|
source_type: "memory_maintenance".to_string(),
|
||||||
|
source_session_id: None,
|
||||||
|
source_message_id: None,
|
||||||
|
source_message_seq: None,
|
||||||
|
source_channel_name: None,
|
||||||
|
source_chat_id: None,
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查是否有需要整理的新记忆(过滤掉 _meta namespace)
|
||||||
|
fn has_new_memories_since_last_maintenance(
|
||||||
|
store: &SessionStore,
|
||||||
|
scope_key: &str,
|
||||||
|
) -> Result<bool, AgentError> {
|
||||||
|
let memories = store
|
||||||
|
.list_memories_for_scope("user", scope_key)
|
||||||
|
.map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?;
|
||||||
|
|
||||||
|
// 过滤掉 _meta namespace 的记忆
|
||||||
|
let user_memories: Vec<_> = memories
|
||||||
|
.iter()
|
||||||
|
.filter(|m| m.namespace != META_NAMESPACE)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if user_memories.is_empty() {
|
||||||
|
return Ok(false); // 没有记忆,跳过
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取上次整理时间
|
||||||
|
let last_time = get_last_maintenance_time(store, scope_key)
|
||||||
|
.map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?;
|
||||||
|
|
||||||
|
match last_time {
|
||||||
|
None => Ok(true), // 从未整理过,需要整理
|
||||||
|
Some(last) => {
|
||||||
|
// 检查是否有记忆在上次整理后更新
|
||||||
|
let has_new = user_memories.iter().any(|m| m.updated_at > last);
|
||||||
|
Ok(has_new)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||||||
let mut plan = MemoryMaintenancePlan::default();
|
let mut plan = MemoryMaintenancePlan::default();
|
||||||
let mut seen = HashSet::new();
|
let mut seen = HashSet::new();
|
||||||
|
|
||||||
for memory in memories {
|
for memory in memories {
|
||||||
|
// 过滤掉 _meta namespace 的记忆
|
||||||
|
if memory.namespace == META_NAMESPACE {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let normalized_content = memory.content.trim();
|
let normalized_content = memory.content.trim();
|
||||||
if normalized_content.is_empty() {
|
if normalized_content.is_empty() {
|
||||||
continue;
|
continue;
|
||||||
@ -501,6 +602,7 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory
|
|||||||
namespace: memory.namespace.clone(),
|
namespace: memory.namespace.clone(),
|
||||||
key: memory.memory_key.clone(),
|
key: memory.memory_key.clone(),
|
||||||
content: normalized_content.to_string(),
|
content: normalized_content.to_string(),
|
||||||
|
updated_at: memory.updated_at,
|
||||||
};
|
};
|
||||||
|
|
||||||
plan.candidates.push(candidate);
|
plan.candidates.push(candidate);
|
||||||
@ -580,12 +682,115 @@ pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 验证记忆整理输出是否符合限制
|
||||||
|
pub(crate) fn validate_memory_maintenance_output(
|
||||||
|
plan: &MemoryMaintenancePlan,
|
||||||
|
output: &MemoryOrganizationOutput,
|
||||||
|
max_merge_ratio: f32,
|
||||||
|
min_memories_to_keep: usize,
|
||||||
|
max_merge_per_group: usize,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let total = plan.candidates.len();
|
||||||
|
|
||||||
|
if total == 0 {
|
||||||
|
return Ok(()); // 没有候选,无需验证
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 1: 单次合并数量限制
|
||||||
|
for merge in &output.merges {
|
||||||
|
if merge.source_ids.len() > max_merge_per_group {
|
||||||
|
return Err(format!(
|
||||||
|
"合并组过大: {} 条源记忆超过上限 {}",
|
||||||
|
merge.source_ids.len(),
|
||||||
|
max_merge_per_group
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 2: 跨 namespace 合并检测(完全禁止)
|
||||||
|
let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan
|
||||||
|
.candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| (c.id.as_str(), c))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for merge in &output.merges {
|
||||||
|
let source_namespaces: HashSet<&str> = merge
|
||||||
|
.source_ids
|
||||||
|
.iter()
|
||||||
|
.filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 检查是否跨越多个 namespace
|
||||||
|
if source_namespaces.len() > 1 {
|
||||||
|
return Err(format!(
|
||||||
|
"跨 namespace 合并被禁止: 源来自 {}",
|
||||||
|
source_namespaces.iter().cloned().collect::<Vec<_>>().join(", ")
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查目标 namespace 是否与源一致
|
||||||
|
if let Some(src_ns) = source_namespaces.iter().next() {
|
||||||
|
if *src_ns != merge.namespace {
|
||||||
|
return Err(format!(
|
||||||
|
"跨 namespace 合并被禁止: {} → {}",
|
||||||
|
src_ns, merge.namespace
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 3: 总体合并比例
|
||||||
|
let merged_ids: HashSet<&str> = output
|
||||||
|
.merges
|
||||||
|
.iter()
|
||||||
|
.flat_map(|m| m.source_ids.iter())
|
||||||
|
.map(|s| s.as_str())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let deleted_ids: HashSet<&str> = output
|
||||||
|
.low_value_ids
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.as_str())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let affected = merged_ids.len() + deleted_ids.len();
|
||||||
|
let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize;
|
||||||
|
|
||||||
|
if affected > max_allowed {
|
||||||
|
return Err(format!(
|
||||||
|
"合并比例超限: {} / {} > {:.0}%",
|
||||||
|
affected,
|
||||||
|
total,
|
||||||
|
max_merge_ratio * 100.0
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 4: 最小保留数
|
||||||
|
let remaining = total - affected + output.merges.len();
|
||||||
|
if remaining < min_memories_to_keep {
|
||||||
|
return Err(format!(
|
||||||
|
"保留数不足: {} < {}",
|
||||||
|
remaining, min_memories_to_keep
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn apply_memory_maintenance_output(
|
pub(crate) fn apply_memory_maintenance_output(
|
||||||
store: &SessionStore,
|
store: &SessionStore,
|
||||||
scope_key: &str,
|
scope_key: &str,
|
||||||
plan: &MemoryMaintenancePlan,
|
plan: &MemoryMaintenancePlan,
|
||||||
output: &MemoryOrganizationOutput,
|
output: &MemoryOrganizationOutput,
|
||||||
|
max_merge_ratio: f32,
|
||||||
|
min_memories_to_keep: usize,
|
||||||
|
max_merge_per_group: usize,
|
||||||
) -> Result<(), AgentError> {
|
) -> Result<(), AgentError> {
|
||||||
|
// 新增: 验证合并输出
|
||||||
|
validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group)
|
||||||
|
.map_err(|e| AgentError::Other(e))?;
|
||||||
|
|
||||||
let all_candidates = plan.candidates.clone();
|
let all_candidates = plan.candidates.clone();
|
||||||
|
|
||||||
let candidates_by_id = all_candidates
|
let candidates_by_id = all_candidates
|
||||||
@ -661,6 +866,11 @@ pub(crate) fn apply_memory_maintenance_output(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 新增:记录整理完成时间
|
||||||
|
let now = chrono::Utc::now().timestamp();
|
||||||
|
set_last_maintenance_time(store, scope_key, now)
|
||||||
|
.map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -40,6 +40,7 @@ impl MemoryMaintenanceCoordinator {
|
|||||||
Ok(MemoryMaintenanceService::new(
|
Ok(MemoryMaintenanceService::new(
|
||||||
self.store.clone(),
|
self.store.clone(),
|
||||||
self.provider_configs.default_provider_config(),
|
self.provider_configs.default_provider_config(),
|
||||||
|
self.provider_configs.default_maintenance_config(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,15 +27,32 @@
|
|||||||
- note: 冲突说明
|
- note: 冲突说明
|
||||||
- low_value_ids:需要删除的低价值候选记忆 ID 数组
|
- low_value_ids:需要删除的低价值候选记忆 ID 数组
|
||||||
|
|
||||||
组织原则(由你自主决定):
|
组织原则:
|
||||||
|
|
||||||
- 根据记忆的语义内容自然分组,不必拘泥于预定义分类
|
- 根据记忆的语义内容自然分组
|
||||||
- 相似的、互补的记忆可以合并
|
- **每次合并最多只能合并 2-3 条源记忆**
|
||||||
- 过期、重复、过细的记忆可以标记为低价值
|
- **禁止跨 namespace 合并**(不同 namespace 代表不同信息维度)
|
||||||
|
- 过期、重复、过细的记忆可以标记为低值
|
||||||
- namespace 和 memory_key 的命名应当简洁、有意义
|
- namespace 和 memory_key 的命名应当简洁、有意义
|
||||||
- 可以自由创建新的 namespace 来组织相关记忆
|
- **保守原则:宁可保留稍多,不可过度合并**
|
||||||
|
- **必须保留足够数量的记忆,确保信息多样性**
|
||||||
|
|
||||||
|
时间权重原则(关键):
|
||||||
|
|
||||||
|
- 每个候选记忆包含 `updated_at` 时间戳(Unix timestamp,秒)
|
||||||
|
- **当多条记忆存在重复或冲突时,时间越新的权重越高**
|
||||||
|
- 合并时优先采用新记忆的内容,旧记忆作为补充或背景
|
||||||
|
- 如果新旧记忆内容完全相同,保留新的,删除旧的
|
||||||
|
- 时间戳数值越大表示越新(离当前时间越近)
|
||||||
|
|
||||||
|
合并限制(硬性约束,由系统强制检查):
|
||||||
|
|
||||||
|
- 单次合并最多来自 3 条源记忆
|
||||||
|
- 整理后保留的记忆数不得少于 5 条
|
||||||
|
- 单次整理最多影响 30% 的记忆
|
||||||
|
- 不同 namespace 的记忆不允许互相合并
|
||||||
|
|
||||||
额外约束:
|
额外约束:
|
||||||
|
|
||||||
- 只能引用输入里出现过的候选 id。
|
- 只能引用输入里出现过的候选 id。
|
||||||
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。
|
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。
|
||||||
@ -84,6 +84,7 @@ impl GatewayState {
|
|||||||
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
||||||
std::collections::HashSet::new(),
|
std::collections::HashSet::new(),
|
||||||
config.tools.task.clone(),
|
config.tools.task.clone(),
|
||||||
|
config.memory_maintenance.clone(),
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)?;
|
)?;
|
||||||
|
|||||||
@ -2,22 +2,25 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct ProviderConfigService {
|
pub(crate) struct ProviderConfigService {
|
||||||
default_provider_config: LLMProviderConfig,
|
default_provider_config: LLMProviderConfig,
|
||||||
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
||||||
|
maintenance_config: MemoryMaintenanceConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProviderConfigService {
|
impl ProviderConfigService {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
default_provider_config: LLMProviderConfig,
|
default_provider_config: LLMProviderConfig,
|
||||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||||
|
maintenance_config: MemoryMaintenanceConfig,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
default_provider_config,
|
default_provider_config,
|
||||||
provider_configs: Arc::new(provider_configs),
|
provider_configs: Arc::new(provider_configs),
|
||||||
|
maintenance_config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,6 +40,10 @@ impl ProviderConfigService {
|
|||||||
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
|
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
|
||||||
self.default_provider_config.clone()
|
self.default_provider_config.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
|
||||||
|
self.maintenance_config.clone()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -72,6 +79,7 @@ mod tests {
|
|||||||
"planner".to_string(),
|
"planner".to_string(),
|
||||||
test_provider_config_named("planner-provider", "planner-model"),
|
test_provider_config_named("planner-provider", "planner-model"),
|
||||||
)]),
|
)]),
|
||||||
|
MemoryMaintenanceConfig::default(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let selected = service.select(Some("planner")).unwrap();
|
let selected = service.select(Some("planner")).unwrap();
|
||||||
@ -82,7 +90,11 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_select_falls_back_to_default() {
|
fn test_select_falls_back_to_default() {
|
||||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||||
let service = ProviderConfigService::new(default_provider, HashMap::new());
|
let service = ProviderConfigService::new(
|
||||||
|
default_provider,
|
||||||
|
HashMap::new(),
|
||||||
|
MemoryMaintenanceConfig::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let selected = service.select(Some("default")).unwrap();
|
let selected = service.select(Some("default")).unwrap();
|
||||||
assert_eq!(selected.name, "default-provider");
|
assert_eq!(selected.name, "default-provider");
|
||||||
|
|||||||
@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
use crate::config::{LLMProviderConfig, TaskConfig};
|
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
|
||||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{
|
use crate::storage::{
|
||||||
@ -34,6 +34,7 @@ pub(crate) fn build_session_manager(
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
task_config: TaskConfig,
|
||||||
|
maintenance_config: MemoryMaintenanceConfig,
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||||
@ -47,6 +48,7 @@ pub(crate) fn build_session_manager(
|
|||||||
Arc::new(NoopSessionMessageSender),
|
Arc::new(NoopSessionMessageSender),
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config,
|
task_config,
|
||||||
|
maintenance_config,
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)
|
)
|
||||||
@ -62,6 +64,7 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
task_config: TaskConfig,
|
||||||
|
maintenance_config: MemoryMaintenanceConfig,
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||||
@ -70,7 +73,11 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||||
);
|
);
|
||||||
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
||||||
let provider_configs = ProviderConfigService::new(provider_config.clone(), provider_configs);
|
let provider_configs = ProviderConfigService::new(
|
||||||
|
provider_config.clone(),
|
||||||
|
provider_configs,
|
||||||
|
maintenance_config,
|
||||||
|
);
|
||||||
|
|
||||||
if let Err(err) =
|
if let Err(err) =
|
||||||
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
|
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
|
||||||
|
|||||||
@ -501,6 +501,7 @@ impl SessionManager {
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
disabled_tools: std::collections::HashSet<String>,
|
disabled_tools: std::collections::HashSet<String>,
|
||||||
task_config: crate::config::TaskConfig,
|
task_config: crate::config::TaskConfig,
|
||||||
|
maintenance_config: crate::config::MemoryMaintenanceConfig,
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
@ -513,6 +514,7 @@ impl SessionManager {
|
|||||||
skills,
|
skills,
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config,
|
task_config,
|
||||||
|
maintenance_config,
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)
|
)
|
||||||
@ -973,6 +975,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1025,6 +1028,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1093,6 +1097,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1178,6 +1183,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1264,6 +1270,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1349,6 +1356,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1416,6 +1424,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1492,6 +1501,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1555,6 +1565,7 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
|
crate::config::MemoryMaintenanceConfig::default(),
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1593,6 +1604,8 @@ mod tests {
|
|||||||
let store = SessionStore::in_memory().unwrap();
|
let store = SessionStore::in_memory().unwrap();
|
||||||
let scope_key = "feishu:user-1";
|
let scope_key = "feishu:user-1";
|
||||||
|
|
||||||
|
// 创建足够的记忆(7条),让合并操作满足保护限制
|
||||||
|
// 合并后需要保留至少 5 条(min_memories_to_keep)
|
||||||
let work = store
|
let work = store
|
||||||
.put_memory(&crate::storage::MemoryUpsert {
|
.put_memory(&crate::storage::MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
@ -1639,9 +1652,30 @@ mod tests {
|
|||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
// 添加额外的记忆以满足 min_memories_to_keep = 5 的要求
|
||||||
|
for i in 0..4 {
|
||||||
|
store
|
||||||
|
.put_memory(&crate::storage::MemoryUpsert {
|
||||||
|
scope_kind: "user".to_string(),
|
||||||
|
scope_key: scope_key.to_string(),
|
||||||
|
namespace: "profile".to_string(),
|
||||||
|
memory_key: format!("extra_{}", i),
|
||||||
|
content: format!("额外记忆 {}", i),
|
||||||
|
source_type: "message".to_string(),
|
||||||
|
source_session_id: None,
|
||||||
|
source_message_id: None,
|
||||||
|
source_message_seq: None,
|
||||||
|
source_channel_name: None,
|
||||||
|
source_chat_id: None,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
let plan = build_memory_maintenance_plan(
|
let plan = build_memory_maintenance_plan(
|
||||||
&store.list_memories_for_scope("user", scope_key).unwrap(),
|
&store.list_memories_for_scope("user", scope_key).unwrap(),
|
||||||
);
|
);
|
||||||
|
assert_eq!(plan.candidates.len(), 7); // 7 条候选记忆
|
||||||
|
|
||||||
let output = MemoryOrganizationOutput {
|
let output = MemoryOrganizationOutput {
|
||||||
merges: vec![MemoryMaintenanceMerge {
|
merges: vec![MemoryMaintenanceMerge {
|
||||||
source_ids: vec![work.id.clone(), role.id.clone()],
|
source_ids: vec![work.id.clone(), role.id.clone()],
|
||||||
@ -1653,13 +1687,25 @@ mod tests {
|
|||||||
low_value_ids: vec![noise.id.clone()],
|
low_value_ids: vec![noise.id.clone()],
|
||||||
};
|
};
|
||||||
|
|
||||||
apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap();
|
// 使用默认配置进行验证
|
||||||
|
apply_memory_maintenance_output(
|
||||||
|
&store,
|
||||||
|
scope_key,
|
||||||
|
&plan,
|
||||||
|
&output,
|
||||||
|
crate::config::MemoryMaintenanceConfig::default().max_merge_ratio,
|
||||||
|
crate::config::MemoryMaintenanceConfig::default().min_memories_to_keep,
|
||||||
|
crate::config::MemoryMaintenanceConfig::default().max_merge_per_group,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
|
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
|
||||||
assert_eq!(all_memories.len(), 1);
|
// 过滤掉 _meta 记录
|
||||||
assert_eq!(all_memories[0].namespace, "profile");
|
let user_memories: Vec<_> = all_memories.iter().filter(|m| m.namespace != "_meta").collect();
|
||||||
assert_eq!(all_memories[0].memory_key, "work");
|
// 合并 2 条为 1 条,删除 1 条,7 - 2 + 1 = 6 条(加上 _meta 记录)
|
||||||
assert_eq!(all_memories[0].content, "用户主要在做AI产品设计与实现");
|
assert_eq!(user_memories.len(), 6);
|
||||||
|
// 验证合并后的记忆存在
|
||||||
|
assert!(user_memories.iter().any(|m| m.namespace == "profile" && m.memory_key == "work"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user