diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index ffbedec..edddb38 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1128,6 +1128,7 @@ mod tests { api_key: "test-key".to_string(), extra_headers: std::collections::HashMap::new(), llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), diff --git a/src/config/mod.rs b/src/config/mod.rs index f08a129..09dbaf7 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -248,6 +248,8 @@ pub struct ProviderConfig { pub extra_headers: HashMap, #[serde(default = "default_llm_timeout_secs")] pub llm_timeout_secs: u64, + #[serde(default = "default_memory_maintenance_timeout_secs")] + pub memory_maintenance_timeout_secs: u64, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -291,6 +293,10 @@ fn default_llm_timeout_secs() -> u64 { 120 } +fn default_memory_maintenance_timeout_secs() -> u64 { + 600 +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct GatewayConfig { #[serde(default = "default_gateway_host")] @@ -623,6 +629,7 @@ pub struct LLMProviderConfig { pub api_key: String, pub extra_headers: HashMap, pub llm_timeout_secs: u64, + pub memory_maintenance_timeout_secs: u64, pub model_id: String, pub temperature: Option, pub max_tokens: Option, @@ -712,6 +719,7 @@ impl Config { api_key: provider.api_key.clone(), extra_headers: provider.extra_headers.clone(), llm_timeout_secs: provider.llm_timeout_secs, + memory_maintenance_timeout_secs: provider.memory_maintenance_timeout_secs, model_id: model.model_id.clone(), temperature: model.temperature, max_tokens: model.max_tokens, diff --git a/src/gateway/command.rs b/src/gateway/command.rs index 9fad8e9..7100b7b 100644 --- a/src/gateway/command.rs +++ b/src/gateway/command.rs @@ -50,6 +50,7 @@ mod tests { api_key: "test-key".to_string(), extra_headers: HashMap::new(), llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), diff --git a/src/gateway/memory_maintenance.rs b/src/gateway/memory_maintenance.rs index a333f90..5f0c8af 100644 --- a/src/gateway/memory_maintenance.rs +++ b/src/gateway/memory_maintenance.rs @@ -4,9 +4,9 @@ use std::time::Duration; use serde::{Deserialize, Serialize}; -use crate::agent::{AgentError, AgentRuntimeConfig}; +use crate::agent::AgentError; use crate::config::LLMProviderConfig; -use crate::providers::{ChatCompletionRequest, Message, create_provider}; +use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider}; use crate::storage::{MemoryRecord, SessionStore}; use super::prompt::upsert_managed_agent_memory_summary; @@ -17,14 +17,6 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_step2_system_prompt.md"); const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000]; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum MemoryMaintenanceCategory { - UserFacts, - Preferences, - BehaviorPatterns, - Other, -} - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct MemoryMaintenanceCandidate { pub(crate) id: String, @@ -35,10 +27,7 @@ pub(crate) struct MemoryMaintenanceCandidate { #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct MemoryMaintenancePlan { - pub(crate) user_facts: Vec, - pub(crate) preferences: Vec, - pub(crate) behavior_patterns: Vec, - pub(crate) others: Vec, + pub(crate) candidates: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -94,6 +83,26 @@ impl MemoryMaintenanceService { } } + /// 创建记忆整理专用的 provider,使用 memory_maintenance_timeout_secs 作为超时时间 + fn create_maintenance_provider( + &self, + ) -> Result, crate::providers::ProviderError> { + let config = &self.provider_config; + let runtime_config = ProviderRuntimeConfig { + provider_type: config.provider_type.clone(), + name: config.name.clone(), + base_url: config.base_url.clone(), + api_key: config.api_key.clone(), + extra_headers: config.extra_headers.clone(), + llm_timeout_secs: config.memory_maintenance_timeout_secs, + model_id: config.model_id.clone(), + temperature: config.temperature, + max_tokens: config.max_tokens, + model_extra: config.model_extra.clone(), + }; + create_provider(runtime_config) + } + pub(crate) fn build_plan_for_scope( &self, scope_key: &str, @@ -126,8 +135,7 @@ impl MemoryMaintenanceService { scope_key: &str, plan: &MemoryMaintenancePlan, ) -> Result { - let runtime_config = AgentRuntimeConfig::from(self.provider_config.clone()); - let provider = create_provider(runtime_config.provider).map_err(|err| { + let provider = self.create_maintenance_provider().map_err(|err| { AgentError::Other(format!("create maintenance provider error: {}", err)) })?; @@ -258,8 +266,7 @@ impl MemoryMaintenanceService { scope_key: &str, remaining_memories: &[MemoryRecord], ) -> Result { - let runtime_config = AgentRuntimeConfig::from(self.provider_config.clone()); - let provider = create_provider(runtime_config.provider).map_err(|err| { + let provider = self.create_maintenance_provider().map_err(|err| { AgentError::Other(format!("create maintenance provider error: {}", err)) })?; @@ -478,28 +485,12 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory content: normalized_content.to_string(), }; - match memory_maintenance_category(&memory.namespace) { - MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate), - MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate), - MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate), - MemoryMaintenanceCategory::Other => plan.others.push(candidate), - } + plan.candidates.push(candidate); } plan } -fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory { - match namespace.trim().to_ascii_lowercase().as_str() { - "profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts, - "preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences, - "patterns" | "behavior" | "habits" | "workflow" => { - MemoryMaintenanceCategory::BehaviorPatterns - } - _ => MemoryMaintenanceCategory::Other, - } -} - pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool { let normalized = error.to_ascii_lowercase(); normalized.contains("error sending request for url") @@ -577,14 +568,7 @@ pub(crate) fn apply_memory_maintenance_output( plan: &MemoryMaintenancePlan, output: &MemoryOrganizationOutput, ) -> Result<(), AgentError> { - let all_candidates = plan - .user_facts - .iter() - .chain(plan.preferences.iter()) - .chain(plan.behavior_patterns.iter()) - .chain(plan.others.iter()) - .cloned() - .collect::>(); + let all_candidates = plan.candidates.clone(); let candidates_by_id = all_candidates .iter() @@ -671,6 +655,4 @@ fn preview_text(content: &str, max_chars: usize) -> String { } #[cfg(test)] -mod tests { - use super::*; -} +mod tests {} diff --git a/src/gateway/memory_maintenance_step1_system_prompt.md b/src/gateway/memory_maintenance_step1_system_prompt.md index fb4277c..e0d28ef 100644 --- a/src/gateway/memory_maintenance_step1_system_prompt.md +++ b/src/gateway/memory_maintenance_step1_system_prompt.md @@ -19,14 +19,22 @@ - merges:对象数组。每个对象必须包含 source_ids、namespace、memory_key、content。 - source_ids: 字符串数组,要合并的源记忆ID列表 - - namespace: 目标命名空间 - - memory_key: 目标记忆键 + - namespace: 目标命名空间(可以自由决定,不限于固定分类) + - memory_key: 目标记忆键(可以自由决定) - content: 合并后的内容 - conflicts:对象数组。每个对象必须包含 source_ids、note。 - source_ids: 冲突的记忆ID列表 - note: 冲突说明 - low_value_ids:需要删除的低价值候选记忆 ID 数组 +组织原则(由你自主决定): + +- 根据记忆的语义内容自然分组,不必拘泥于预定义分类 +- 相似的、互补的记忆可以合并 +- 过期、重复、过细的记忆可以标记为低价值 +- namespace 和 memory_key 的命名应当简洁、有意义 +- 可以自由创建新的 namespace 来组织相关记忆 + 额外约束: - 只能引用输入里出现过的候选 id。 diff --git a/src/gateway/provider_config_service.rs b/src/gateway/provider_config_service.rs index 757a913..b2ff563 100644 --- a/src/gateway/provider_config_service.rs +++ b/src/gateway/provider_config_service.rs @@ -51,6 +51,7 @@ mod tests { api_key: "test-key".to_string(), extra_headers: HashMap::new(), llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, model_id: model_id.to_string(), temperature: Some(0.0), max_tokens: Some(32), diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 725d688..5a3f6c9 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -520,6 +520,7 @@ mod tests { api_key: "test-key".to_string(), extra_headers: HashMap::new(), llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), @@ -786,6 +787,7 @@ mod tests { model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -827,6 +829,7 @@ mod tests { model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -900,6 +903,7 @@ mod tests { model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -982,6 +986,7 @@ mod tests { )]), max_tool_iterations: 1, llm_timeout_secs: 30, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -1065,6 +1070,7 @@ mod tests { model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 1, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -1113,7 +1119,7 @@ mod tests { assert!(error.contains("provider=maintenance-provider")); assert!(error.contains("model=maintenance-model")); assert!(error.contains("url=https://example.invalid/v1/chat/completions")); - assert!(error.contains("timeout_secs=1")); + assert!(error.contains("timeout_secs=600")); assert!(error.contains("error sending request for url")); } @@ -1147,6 +1153,7 @@ mod tests { )]), max_tool_iterations: 1, llm_timeout_secs: 30, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -1211,6 +1218,7 @@ mod tests { )]), max_tool_iterations: 1, llm_timeout_secs: 30, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -1284,6 +1292,7 @@ mod tests { )]), max_tool_iterations: 1, llm_timeout_secs: 30, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -1344,6 +1353,7 @@ mod tests { model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 1, + memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -1749,12 +1759,12 @@ mod tests { ]; let plan = build_memory_maintenance_plan(&memories); - assert_eq!(plan.user_facts.len(), 1); - assert_eq!(plan.preferences.len(), 1); - assert_eq!(plan.behavior_patterns.len(), 1); - assert!(plan.others.is_empty()); - assert_eq!(plan.user_facts[0].content, "用户在做AI产品"); - assert_eq!(plan.preferences[0].content, "偏好简洁表达"); - assert_eq!(plan.behavior_patterns[0].content, "习惯先问方案再要代码"); + // 去重后应该有3条(第1、2条重复) + assert_eq!(plan.candidates.len(), 3); + // 验证内容包含所有唯一的记忆 + let contents: Vec = plan.candidates.iter().map(|c| c.content.clone()).collect(); + assert!(contents.contains(&"用户在做AI产品".to_string())); + assert!(contents.contains(&"偏好简洁表达".to_string())); + assert!(contents.contains(&"习惯先问方案再要代码".to_string())); } } diff --git a/tests/test_integration.rs b/tests/test_integration.rs index f27358f..b428dba 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -35,6 +35,7 @@ fn load_config() -> Option { api_key: openai_api_key, extra_headers: HashMap::new(), llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, model_id: openai_model, temperature: Some(0.0), max_tokens: Some(100), diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index bfe7b9c..653d1c6 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -37,6 +37,7 @@ fn load_openai_config() -> Option { api_key: openai_api_key, extra_headers: HashMap::new(), llm_timeout_secs: 120, + memory_maintenance_timeout_secs: 600, model_id: openai_model, temperature: Some(0.0), max_tokens: Some(100),