From fa3354db9ccb9844bbf155410911a4f9e00e653c Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Tue, 28 Apr 2026 11:29:06 +0800 Subject: [PATCH] feat: add context_window_tokens to model configuration and update related logic - Introduced context_window_tokens in ModelConfig and LLMProviderConfig structs. - Updated context window estimation logic in ContextCompressor to use context_window_tokens. - Modified tests to accommodate new context_window_tokens field. - Refactored memory maintenance logic into a new memory_maintenance.rs file for better organization. - Ensured backward compatibility by providing default values where necessary. Co-authored-by: Copilot --- README.md | 2 +- src/agent/context_compressor.rs | 1 + src/config/mod.rs | 50 ++- src/gateway/execution.rs | 1 + src/gateway/memory_maintenance.rs | 517 ++++++++++++++++++++++++++++++ src/gateway/mod.rs | 1 + src/gateway/session.rs | 517 ++---------------------------- src/scheduler/mod.rs | 1 + tests/test_integration.rs | 1 + tests/test_tool_calling.rs | 1 + 10 files changed, 601 insertions(+), 491 deletions(-) create mode 100644 src/gateway/memory_maintenance.rs diff --git a/README.md b/README.md index b0c918b..2fe5e6a 100644 --- a/README.md +++ b/README.md @@ -134,7 +134,7 @@ PicoBot 会在 ~/.picobot/agent/AGENT.md 维护一份持久化 Agent 画像文 1. 系统先对当前活动历史做一个近似 token 估算。 估算规则不是调用 tokenizer,而是按“约每 4 个字符约等于 1 token,并再乘以 1.2 安全系数”计算。 2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。 - 这里的上下文窗口来自 agent 对应模型配置里的 token_limit。 + 这里的上下文窗口来自 agent 对应模型配置里的 context_window_tokens;未配置时按 128000 估算。 3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。 当前默认会完整保留最近 3 个 user turn。 4. 
一旦满足条件,压缩器会先按 user 消息切分 turn,再确定“旧历史”和“最近保留段”的分界点。 diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index a495ed8..4a40596 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -60,6 +60,7 @@ pub struct ContextCompressor { } impl ContextCompressor { + #[cfg(test)] fn summary_char_budget_for_context_window(context_window: usize) -> usize { const SUMMARY_RATIO: f64 = 0.1; const CHARS_PER_TOKEN: f64 = 2.5; diff --git a/src/config/mod.rs b/src/config/mod.rs index d44f3a0..488f2b7 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -159,6 +159,8 @@ pub struct ModelConfig { pub temperature: Option, #[serde(default)] pub max_tokens: Option, + #[serde(default)] + pub context_window_tokens: Option, #[serde(flatten)] pub extra: HashMap, } @@ -526,6 +528,7 @@ pub struct LLMProviderConfig { pub model_id: String, pub temperature: Option, pub max_tokens: Option, + pub context_window_tokens: Option, pub model_extra: HashMap, pub max_tool_iterations: usize, pub tool_result_max_chars: usize, @@ -534,7 +537,7 @@ pub struct LLMProviderConfig { impl LLMProviderConfig { pub fn context_window_tokens(&self) -> usize { - self.max_tokens + self.context_window_tokens .map(|value| value as usize) .unwrap_or(128_000) } @@ -614,6 +617,7 @@ impl Config { model_id: model.model_id.clone(), temperature: model.temperature, max_tokens: model.max_tokens, + context_window_tokens: model.context_window_tokens, model_extra: model.extra.clone(), max_tool_iterations: agent.max_tool_iterations, tool_result_max_chars: agent.tool_result_max_chars, @@ -1056,7 +1060,44 @@ mod tests { } #[test] - fn test_provider_config_summary_budget_scales_with_model_max_tokens() { + fn test_provider_config_summary_budget_scales_with_context_window_tokens() { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": 
"test-key", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus", + "context_window_tokens": 4096 + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + } +}"#, + ) + .unwrap(); + + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let provider_config = config.get_provider_config("default").unwrap(); + + assert_eq!(provider_config.context_window_tokens(), 4096); + assert_eq!(provider_config.context_summary_char_budget(), 1_500); + } + + #[test] + fn test_provider_config_max_tokens_does_not_change_context_window() { let file = tempfile::NamedTempFile::new().unwrap(); std::fs::write( file.path(), @@ -1088,8 +1129,9 @@ mod tests { let config = Config::load(file.path().to_str().unwrap()).unwrap(); let provider_config = config.get_provider_config("default").unwrap(); - assert_eq!(provider_config.context_window_tokens(), 4096); - assert_eq!(provider_config.context_summary_char_budget(), 1_500); + assert_eq!(provider_config.max_tokens, Some(4096)); + assert_eq!(provider_config.context_window_tokens(), 128_000); + assert_eq!(provider_config.context_summary_char_budget(), 32_000); } #[test] diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index 3221012..a7c1964 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -181,6 +181,7 @@ mod tests { model_id: model_id.to_string(), temperature: Some(0.0), max_tokens: Some(32), + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, tool_result_max_chars: 20_000, diff --git a/src/gateway/memory_maintenance.rs b/src/gateway/memory_maintenance.rs new file mode 100644 index 0000000..78c9508 --- /dev/null +++ b/src/gateway/memory_maintenance.rs @@ -0,0 +1,517 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use std::time::Duration; + +use serde::{Deserialize, Serialize}; + +use crate::agent::AgentError; +use crate::config::LLMProviderConfig; +use 
crate::providers::{ChatCompletionRequest, Message, create_provider}; +use crate::storage::{MemoryRecord, SessionStore}; + +use super::prompt::upsert_managed_agent_memory_summary; + +const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md"); +const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MemoryMaintenanceCategory { + UserFacts, + Preferences, + BehaviorPatterns, + Other, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct MemoryMaintenanceCandidate { + pub(crate) id: String, + pub(crate) namespace: String, + pub(crate) key: String, + pub(crate) content: String, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct MemoryMaintenancePlan { + pub(crate) user_facts: Vec, + pub(crate) preferences: Vec, + pub(crate) behavior_patterns: Vec, + pub(crate) others: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct MemoryMaintenanceMerge { + pub(crate) source_ids: Vec, + pub(crate) namespace: String, + pub(crate) memory_key: String, + pub(crate) content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct MemoryMaintenanceConflict { + pub(crate) source_ids: Vec, + pub(crate) note: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct MemoryMaintenanceModelOutput { + pub(crate) user_facts: Vec, + pub(crate) preferences: Vec, + pub(crate) behavior_patterns: Vec, + pub(crate) merges: Vec, + pub(crate) conflicts: Vec, + pub(crate) low_value_ids: Vec, + pub(crate) managed_markdown: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct MemoryMaintenanceScopeResult { + pub(crate) scope_key: String, + pub(crate) output: MemoryMaintenanceModelOutput, +} + +pub(crate) struct MemoryMaintenanceService { + store: Arc, + provider_config: 
LLMProviderConfig, +} + +impl MemoryMaintenanceService { + pub(crate) fn new(store: Arc, provider_config: LLMProviderConfig) -> Self { + Self { + store, + provider_config, + } + } + + pub(crate) fn build_plan_for_scope( + &self, + scope_key: &str, + ) -> Result, AgentError> { + let memories = self + .store + .list_memories_for_scope("user", scope_key) + .map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?; + + if memories.is_empty() { + return Ok(None); + } + + Ok(Some(build_memory_maintenance_plan(&memories))) + } + + pub(crate) async fn summarize_for_scope( + &self, + scope_key: &str, + ) -> Result, AgentError> { + let Some(plan) = self.build_plan_for_scope(scope_key)? else { + return Ok(None); + }; + + self.summarize_plan(scope_key, &plan).await.map(Some) + } + + async fn summarize_plan( + &self, + scope_key: &str, + plan: &MemoryMaintenancePlan, + ) -> Result { + let provider = create_provider(self.provider_config.clone()).map_err(|err| { + AgentError::Other(format!("create maintenance provider error: {}", err)) + })?; + + let request = ChatCompletionRequest { + messages: vec![ + Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT), + Message::user( + serde_json::to_string_pretty(&serde_json::json!({ + "scope_key": scope_key, + "candidates": plan, + })) + .unwrap_or_else(|_| "{}".to_string()), + ), + ], + temperature: Some(0.0), + max_tokens: Some(1200), + tools: None, + }; + + let mut last_error = None; + let mut response = None; + + for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS + .iter() + .copied() + .map(Some) + .chain(std::iter::once(None)) + .enumerate() + { + match provider.chat(request.clone()).await { + Ok(success) => { + response = Some(success); + break; + } + Err(err) => { + let error_text = err.to_string(); + let should_retry = + delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text); + last_error = Some(error_text.clone()); + + if should_retry { + tracing::warn!( + scope_key = 
%scope_key, + attempt = attempt + 1, + retry_in_ms = delay_ms.unwrap_or_default(), + error = %error_text, + "Memory maintenance model request failed, retrying" + ); + tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())) + .await; + continue; + } + + return Err(AgentError::Other(format!( + "memory maintenance model error: {}", + error_text + ))); + } + } + } + + let response = response.ok_or_else(|| { + AgentError::Other(format!( + "memory maintenance model error: {}", + last_error.unwrap_or_else(|| "unknown provider error".to_string()) + )) + })?; + + let raw_content = strip_json_code_fence(&response.content); + let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content); + + let output: MemoryMaintenanceModelOutput = + serde_json::from_str(json_candidate).map_err(|err| { + tracing::error!( + scope_key = %scope_key, + error = %err, + raw_len = raw_content.len(), + raw_preview = %preview_text(raw_content, 400), + json_candidate_len = json_candidate.len(), + json_candidate_preview = %preview_text(json_candidate, 400), + "Memory maintenance JSON decode failed" + ); + AgentError::Other(format!("memory maintenance JSON decode error: {}", err)) + })?; + + Ok(output) + } + + pub(crate) async fn run_for_scope( + &self, + scope_key: &str, + ) -> Result, AgentError> { + let Some(plan) = self.build_plan_for_scope(scope_key)? else { + return Ok(None); + }; + + let output = self.summarize_plan(scope_key, &plan).await?; + apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?; + + Ok(Some(output)) + } + + pub(crate) async fn run_for_all_scopes( + &self, + updated_since: Option, + ) -> Result, AgentError> { + let scope_keys = if let Some(cutoff) = updated_since { + self.store + .list_memory_scope_keys_updated_since("user", cutoff) + .map_err(|err| { + AgentError::Other(format!( + "list memory scope keys updated since error: {}", + err + )) + })? 
+ } else { + self.store.list_memory_scope_keys("user").map_err(|err| { + AgentError::Other(format!("list memory scope keys error: {}", err)) + })? + }; + let mut results = Vec::new(); + + for scope_key in scope_keys { + let Some(output) = self.run_for_scope(&scope_key).await? else { + continue; + }; + + results.push(MemoryMaintenanceScopeResult { scope_key, output }); + } + + let combined_markdown = combine_managed_memory_markdown( + &results + .iter() + .map(|result| result.output.managed_markdown.clone()) + .collect::>(), + ); + + if !combined_markdown.is_empty() { + upsert_managed_agent_memory_summary(&combined_markdown)?; + } + + Ok(results) + } +} + +pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan { + let mut plan = MemoryMaintenancePlan::default(); + let mut seen = HashSet::new(); + + for memory in memories { + let normalized_content = memory.content.trim(); + if normalized_content.is_empty() { + continue; + } + + let dedupe_key = format!( + "{}\u{1f}{}\u{1f}{}", + memory.namespace.trim().to_ascii_lowercase(), + memory.memory_key.trim().to_ascii_lowercase(), + normalized_content + ); + if !seen.insert(dedupe_key) { + continue; + } + + let candidate = MemoryMaintenanceCandidate { + id: memory.id.clone(), + namespace: memory.namespace.clone(), + key: memory.memory_key.clone(), + content: normalized_content.to_string(), + }; + + match memory_maintenance_category(&memory.namespace) { + MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate), + MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate), + MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate), + MemoryMaintenanceCategory::Other => plan.others.push(candidate), + } + } + + plan +} + +fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory { + match namespace.trim().to_ascii_lowercase().as_str() { + "profile" | "facts" | "identity" => 
MemoryMaintenanceCategory::UserFacts, + "preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences, + "patterns" | "behavior" | "habits" | "workflow" => { + MemoryMaintenanceCategory::BehaviorPatterns + } + _ => MemoryMaintenanceCategory::Other, + } +} + +pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool { + let normalized = error.to_ascii_lowercase(); + normalized.contains("error sending request for url") + || normalized.contains("504") + || normalized.contains("gateway timeout") + || normalized.contains("stream timeout") + || normalized.contains("timed out") + || normalized.contains("timeout") +} + +pub(crate) fn strip_json_code_fence(content: &str) -> &str { + let trimmed = content.trim(); + if let Some(rest) = trimmed.strip_prefix("```json") { + return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed); + } + if let Some(rest) = trimmed.strip_prefix("```") { + return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed); + } + trimmed +} + +pub(crate) fn extract_json_object(content: &str) -> Option<&str> { + let mut start = None; + let mut depth = 0usize; + let mut in_string = false; + let mut escaped = false; + + for (index, ch) in content.char_indices() { + if in_string { + if escaped { + escaped = false; + continue; + } + match ch { + '\\' => escaped = true, + '"' => in_string = false, + _ => {} + } + continue; + } + + match ch { + '"' => in_string = true, + '{' => { + if start.is_none() { + start = Some(index); + } + depth += 1; + } + '}' => { + if depth == 0 { + continue; + } + depth -= 1; + if depth == 0 { + let start = start?; + let end = index + ch.len_utf8(); + return Some(content[start..end].trim()); + } + } + _ => {} + } + } + + None +} + +pub(crate) fn combine_managed_memory_markdown(chunks: &[String]) -> String { + let normalized_chunks = chunks + .iter() + .map(|chunk| chunk.trim()) + .filter(|chunk| !chunk.is_empty()) + .collect::>(); + + let mut combined = Vec::new(); + for (index, chunk) in 
normalized_chunks.iter().enumerate() { + let chunk_lines = chunk + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .collect::>(); + + let is_subset_of_other = + normalized_chunks + .iter() + .enumerate() + .any(|(other_index, other)| { + if index == other_index { + return false; + } + + let other_lines = other + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .collect::>(); + + chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines) + }); + + if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) { + combined.push((*chunk).to_string()); + } + } + + combined.join("\n\n") +} + +pub(crate) fn apply_memory_maintenance_output( + store: &SessionStore, + scope_key: &str, + plan: &MemoryMaintenancePlan, + output: &MemoryMaintenanceModelOutput, +) -> Result<(), AgentError> { + let all_candidates = plan + .user_facts + .iter() + .chain(plan.preferences.iter()) + .chain(plan.behavior_patterns.iter()) + .chain(plan.others.iter()) + .cloned() + .collect::>(); + + let candidates_by_id = all_candidates + .iter() + .map(|candidate| (candidate.id.as_str(), candidate)) + .collect::>(); + + let mut deleted_ids = HashSet::new(); + + for merge in &output.merges { + if merge.source_ids.is_empty() { + continue; + } + + let source_candidates = merge + .source_ids + .iter() + .filter_map(|id| candidates_by_id.get(id.as_str()).copied()) + .collect::>(); + if source_candidates.is_empty() { + continue; + } + + let existing_target_id = source_candidates + .iter() + .find(|candidate| { + candidate.namespace == merge.namespace && candidate.key == merge.memory_key + }) + .map(|candidate| candidate.id.clone()); + + store + .put_memory(&crate::storage::MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: scope_key.to_string(), + namespace: merge.namespace.trim().to_string(), + memory_key: merge.memory_key.trim().to_string(), + content: merge.content.trim().to_string(), + source_type: 
"memory_maintenance".to_string(), + source_session_id: None, + source_message_id: None, + source_message_seq: None, + source_channel_name: None, + source_chat_id: None, + }) + .map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?; + + for candidate in source_candidates { + if existing_target_id + .as_ref() + .is_some_and(|target_id| target_id == &candidate.id) + { + continue; + } + if deleted_ids.insert(candidate.id.clone()) { + store + .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) + .map_err(|err| { + AgentError::Other(format!("delete merged source memory error: {}", err)) + })?; + } + } + } + + for memory_id in &output.low_value_ids { + if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) { + if deleted_ids.insert(candidate.id.clone()) { + store + .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) + .map_err(|err| { + AgentError::Other(format!("delete low value memory error: {}", err)) + })?; + } + } + } + + Ok(()) +} + +fn preview_text(content: &str, max_chars: usize) -> String { + let mut preview = content.chars().take(max_chars).collect::(); + if content.chars().count() > max_chars { + preview.push_str("..."); + } + preview.replace('\n', "\\n") +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 9a41f17..28a29af 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,5 +1,6 @@ pub mod execution; pub mod http; +pub mod memory_maintenance; pub mod processor; pub mod prompt; pub mod session; diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 6a0bbe8..3f8828a 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -5,7 +5,6 @@ use crate::bus::{ }; use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; -use crate::providers::{ChatCompletionRequest, Message, create_provider}; use crate::skills::SkillRuntime; use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::tools::{ @@ -14,7 +13,6 @@ 
use crate::tools::{ TimeTool, ToolContext, ToolRegistry, WebFetchTool, }; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -25,183 +23,16 @@ use super::execution::{ AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt, select_provider_config, should_display_message_to_user, }; -use super::prompt::{load_agent_prompt, upsert_managed_agent_memory_summary}; - -const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md"); -const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000]; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum MemoryMaintenanceCategory { - UserFacts, - Preferences, - BehaviorPatterns, - Other, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub(crate) struct MemoryMaintenanceCandidate { - id: String, - namespace: String, - key: String, - content: String, -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] -pub(crate) struct MemoryMaintenancePlan { - user_facts: Vec, - preferences: Vec, - behavior_patterns: Vec, - others: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub(crate) struct MemoryMaintenanceMerge { - pub(crate) source_ids: Vec, - pub(crate) namespace: String, - pub(crate) memory_key: String, - pub(crate) content: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub(crate) struct MemoryMaintenanceConflict { - pub(crate) source_ids: Vec, - pub(crate) note: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub(crate) struct MemoryMaintenanceModelOutput { - pub(crate) user_facts: Vec, - pub(crate) preferences: Vec, - pub(crate) behavior_patterns: Vec, - pub(crate) merges: Vec, - pub(crate) conflicts: Vec, - pub(crate) low_value_ids: Vec, - pub(crate) managed_markdown: String, -} - -#[derive(Debug, Clone, 
PartialEq, Eq)] -pub(crate) struct MemoryMaintenanceScopeResult { - pub(crate) scope_key: String, - pub(crate) output: MemoryMaintenanceModelOutput, -} - -fn build_memory_maintenance_plan( - memories: &[crate::storage::MemoryRecord], -) -> MemoryMaintenancePlan { - let mut plan = MemoryMaintenancePlan::default(); - let mut seen = HashSet::new(); - - for memory in memories { - let normalized_content = memory.content.trim(); - if normalized_content.is_empty() { - continue; - } - - let dedupe_key = format!( - "{}\u{1f}{}\u{1f}{}", - memory.namespace.trim().to_ascii_lowercase(), - memory.memory_key.trim().to_ascii_lowercase(), - normalized_content - ); - if !seen.insert(dedupe_key) { - continue; - } - - let candidate = MemoryMaintenanceCandidate { - id: memory.id.clone(), - namespace: memory.namespace.clone(), - key: memory.memory_key.clone(), - content: normalized_content.to_string(), - }; - - match memory_maintenance_category(&memory.namespace) { - MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate), - MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate), - MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate), - MemoryMaintenanceCategory::Other => plan.others.push(candidate), - } - } - - plan -} - -fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory { - match namespace.trim().to_ascii_lowercase().as_str() { - "profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts, - "preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences, - "patterns" | "behavior" | "habits" | "workflow" => { - MemoryMaintenanceCategory::BehaviorPatterns - } - _ => MemoryMaintenanceCategory::Other, - } -} - -fn is_recoverable_maintenance_llm_error(error: &str) -> bool { - let normalized = error.to_ascii_lowercase(); - normalized.contains("error sending request for url") - || normalized.contains("504") - || normalized.contains("gateway timeout") - || 
normalized.contains("stream timeout") - || normalized.contains("timed out") - || normalized.contains("timeout") -} - -fn strip_json_code_fence(content: &str) -> &str { - let trimmed = content.trim(); - if let Some(rest) = trimmed.strip_prefix("```json") { - return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed); - } - if let Some(rest) = trimmed.strip_prefix("```") { - return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed); - } - trimmed -} - -fn extract_json_object(content: &str) -> Option<&str> { - let mut start = None; - let mut depth = 0usize; - let mut in_string = false; - let mut escaped = false; - - for (index, ch) in content.char_indices() { - if in_string { - if escaped { - escaped = false; - continue; - } - match ch { - '\\' => escaped = true, - '"' => in_string = false, - _ => {} - } - continue; - } - - match ch { - '"' => in_string = true, - '{' => { - if start.is_none() { - start = Some(index); - } - depth += 1; - } - '}' => { - if depth == 0 { - continue; - } - depth -= 1; - if depth == 0 { - let start = start?; - let end = index + ch.len_utf8(); - return Some(content[start..end].trim()); - } - } - _ => {} - } - } - - None -} +#[cfg(test)] +use super::memory_maintenance::{ + MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan, + combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error, + strip_json_code_fence, +}; +use super::memory_maintenance::{ + MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService, +}; +use super::prompt::load_agent_prompt; fn preview_text(content: &str, max_chars: usize) -> String { let mut preview = content.chars().take(max_chars).collect::(); @@ -225,138 +56,6 @@ fn enrich_user_content_with_media_refs( Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}")) } -fn combine_managed_memory_markdown(chunks: &[String]) -> String { - let normalized_chunks = chunks - .iter() - .map(|chunk| chunk.trim()) - 
.filter(|chunk| !chunk.is_empty()) - .collect::>(); - - let mut combined = Vec::new(); - for (index, chunk) in normalized_chunks.iter().enumerate() { - let chunk_lines = chunk - .lines() - .map(str::trim) - .filter(|line| !line.is_empty()) - .collect::>(); - - let is_subset_of_other = - normalized_chunks - .iter() - .enumerate() - .any(|(other_index, other)| { - if index == other_index { - return false; - } - - let other_lines = other - .lines() - .map(str::trim) - .filter(|line| !line.is_empty()) - .collect::>(); - - chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines) - }); - - if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) { - combined.push((*chunk).to_string()); - } - } - - combined.join("\n\n") -} - -fn apply_memory_maintenance_output( - store: &SessionStore, - scope_key: &str, - plan: &MemoryMaintenancePlan, - output: &MemoryMaintenanceModelOutput, -) -> Result<(), AgentError> { - let all_candidates = plan - .user_facts - .iter() - .chain(plan.preferences.iter()) - .chain(plan.behavior_patterns.iter()) - .chain(plan.others.iter()) - .cloned() - .collect::>(); - - let candidates_by_id = all_candidates - .iter() - .map(|candidate| (candidate.id.as_str(), candidate)) - .collect::>(); - - let mut deleted_ids = HashSet::new(); - - for merge in &output.merges { - if merge.source_ids.is_empty() { - continue; - } - - let source_candidates = merge - .source_ids - .iter() - .filter_map(|id| candidates_by_id.get(id.as_str()).copied()) - .collect::>(); - if source_candidates.is_empty() { - continue; - } - - let existing_target_id = source_candidates - .iter() - .find(|candidate| { - candidate.namespace == merge.namespace && candidate.key == merge.memory_key - }) - .map(|candidate| candidate.id.clone()); - - store - .put_memory(&crate::storage::MemoryUpsert { - scope_kind: "user".to_string(), - scope_key: scope_key.to_string(), - namespace: merge.namespace.trim().to_string(), - memory_key: 
merge.memory_key.trim().to_string(), - content: merge.content.trim().to_string(), - source_type: "memory_maintenance".to_string(), - source_session_id: None, - source_message_id: None, - source_message_seq: None, - source_channel_name: None, - source_chat_id: None, - }) - .map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?; - - for candidate in source_candidates { - if existing_target_id - .as_ref() - .is_some_and(|target_id| target_id == &candidate.id) - { - continue; - } - if deleted_ids.insert(candidate.id.clone()) { - store - .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) - .map_err(|err| { - AgentError::Other(format!("delete merged source memory error: {}", err)) - })?; - } - } - } - - for memory_id in &output.low_value_ids { - if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) { - if deleted_ids.insert(candidate.id.clone()) { - store - .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) - .map_err(|err| { - AgentError::Other(format!("delete low value memory error: {}", err)) - })?; - } - } - } - - Ok(()) -} - /// Session 按 channel 隔离,每个 channel 一个 Session /// History 按 chat_id 隔离,由 Session 统一管理 pub struct Session { @@ -609,6 +308,7 @@ impl Session { } } + #[cfg(test)] fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> { self.latest_user_message(chat_id) .map(|message| message.id.as_str()) @@ -619,6 +319,7 @@ impl Session { .and_then(|history| history.iter().rev().find(|message| message.role == "user")) } + #[cfg(test)] fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool { self.latest_user_message_id(chat_id) .map(|current_id| current_id == message_id) @@ -1014,194 +715,30 @@ impl SessionManager { self.skills.clone() } - pub(crate) fn build_memory_maintenance_plan_for_scope( - &self, - scope_key: &str, - ) -> Result, AgentError> { - let memories = self - .store - .list_memories_for_scope("user", scope_key) - .map_err(|err| 
AgentError::Other(format!("list memories for scope error: {}", err)))?; - - if memories.is_empty() { - return Ok(None); - } - - Ok(Some(build_memory_maintenance_plan(&memories))) - } - - pub(crate) fn upsert_managed_agent_memory_summary( - &self, - markdown_body: &str, - ) -> Result<(), AgentError> { - upsert_managed_agent_memory_summary(markdown_body) - } - #[cfg_attr(not(test), allow(dead_code))] pub(crate) async fn summarize_memory_maintenance_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { - let Some(plan) = self.build_memory_maintenance_plan_for_scope(scope_key)? else { - return Ok(None); - }; - - self.summarize_memory_maintenance_plan(scope_key, &plan) + self.memory_maintenance_service()? + .summarize_for_scope(scope_key) .await - .map(Some) - } - - async fn summarize_memory_maintenance_plan( - &self, - scope_key: &str, - plan: &MemoryMaintenancePlan, - ) -> Result { - let provider_config = self.provider_config_for_agent(None)?; - let provider = create_provider(provider_config).map_err(|err| { - AgentError::Other(format!("create maintenance provider error: {}", err)) - })?; - - let request = ChatCompletionRequest { - messages: vec![ - Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT), - Message::user( - serde_json::to_string_pretty(&serde_json::json!({ - "scope_key": scope_key, - "candidates": plan, - })) - .unwrap_or_else(|_| "{}".to_string()), - ), - ], - temperature: Some(0.0), - max_tokens: Some(1200), - tools: None, - }; - - let mut last_error = None; - let mut response = None; - - for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS - .iter() - .copied() - .map(Some) - .chain(std::iter::once(None)) - .enumerate() - { - match provider.chat(request.clone()).await { - Ok(success) => { - response = Some(success); - break; - } - Err(err) => { - let error_text = err.to_string(); - let should_retry = - delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text); - last_error = Some(error_text.clone()); - - if should_retry { - 
tracing::warn!( - scope_key = %scope_key, - attempt = attempt + 1, - retry_in_ms = delay_ms.unwrap_or_default(), - error = %error_text, - "Memory maintenance model request failed, retrying" - ); - tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())) - .await; - continue; - } - - return Err(AgentError::Other(format!( - "memory maintenance model error: {}", - error_text - ))); - } - } - } - - let response = response.ok_or_else(|| { - AgentError::Other(format!( - "memory maintenance model error: {}", - last_error.unwrap_or_else(|| "unknown provider error".to_string()) - )) - })?; - - let raw_content = strip_json_code_fence(&response.content); - let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content); - - let output: MemoryMaintenanceModelOutput = - serde_json::from_str(json_candidate).map_err(|err| { - tracing::error!( - scope_key = %scope_key, - error = %err, - raw_len = raw_content.len(), - raw_preview = %preview_text(raw_content, 400), - json_candidate_len = json_candidate.len(), - json_candidate_preview = %preview_text(json_candidate, 400), - "Memory maintenance JSON decode failed" - ); - AgentError::Other(format!("memory maintenance JSON decode error: {}", err)) - })?; - - Ok(output) - } - - pub(crate) async fn run_memory_maintenance_for_scope( - &self, - scope_key: &str, - ) -> Result, AgentError> { - let Some(plan) = self.build_memory_maintenance_plan_for_scope(scope_key)? 
else { - return Ok(None); - }; - - let output = self - .summarize_memory_maintenance_plan(scope_key, &plan) - .await?; - apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?; - - Ok(Some(output)) } pub(crate) async fn run_memory_maintenance_for_all_scopes( &self, updated_since: Option, ) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> { - let scope_keys = if let Some(cutoff) = updated_since { - self.store - .list_memory_scope_keys_updated_since("user", cutoff) - .map_err(|err| { - AgentError::Other(format!( - "list memory scope keys updated since error: {}", - err - )) - })? - } else { - self.store.list_memory_scope_keys("user").map_err(|err| { - AgentError::Other(format!("list memory scope keys error: {}", err)) - })? - }; - let mut results = Vec::new(); + self.memory_maintenance_service()? + .run_for_all_scopes(updated_since) + .await + } - for scope_key in scope_keys { - let Some(output) = self.run_memory_maintenance_for_scope(&scope_key).await? else { - continue; - }; - - results.push(MemoryMaintenanceScopeResult { scope_key, output }); - } - - let combined_markdown = combine_managed_memory_markdown( - &results - .iter() - .map(|result| result.output.managed_markdown.clone()) - .collect::<Vec<_>>(), - ); - - if !combined_markdown.is_empty() { - self.upsert_managed_agent_memory_summary(&combined_markdown)?; - } - - Ok(results) + fn memory_maintenance_service(&self) -> Result<MemoryMaintenanceService, AgentError> { + Ok(MemoryMaintenanceService::new( + self.store.clone(), + self.provider_config_for_agent(None)?, + )) } pub fn provider_config_for_agent( @@ -1572,6 +1109,7 @@ mod tests { model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, tool_result_max_chars: 20_000, @@ -1833,6 +1371,7 @@ mod tests { model_id: "timeout-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, @@ 
-1872,6 +1411,7 @@ mod tests { model_id: "default-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, @@ -1943,6 +1483,7 @@ mod tests { model_id: "default-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, @@ -2020,6 +1561,7 @@ mod tests { model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), + context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), @@ -2120,6 +1662,7 @@ mod tests { model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), + context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), @@ -2182,6 +1725,7 @@ mod tests { model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), + context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), @@ -2241,6 +1785,7 @@ mod tests { model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index ff6a1b0..1472704 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -980,6 +980,7 @@ mod tests { model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: None, + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 4, tool_result_max_chars: 20_000, diff --git a/tests/test_integration.rs b/tests/test_integration.rs index 620ace5..8b0892a 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -23,6 +23,7 @@ fn load_config() -> 
Option { model_id: openai_model, temperature: Some(0.0), max_tokens: Some(100), + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 20, tool_result_max_chars: 20_000, diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index c79e3bc..7084cc1 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -23,6 +23,7 @@ fn load_openai_config() -> Option { model_id: openai_model, temperature: Some(0.0), max_tokens: Some(100), + context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 20, tool_result_max_chars: 20_000,