feat: add context_window_tokens to model configuration and update related logic

- Introduced context_window_tokens in ModelConfig and LLMProviderConfig structs.
- Updated context window estimation logic in ContextCompressor to use context_window_tokens.
- Modified tests to accommodate new context_window_tokens field.
- Refactored memory maintenance logic into a new memory_maintenance.rs file for better organization.
- Ensured backward compatibility by providing default values where necessary.

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
ooodc 2026-04-28 11:29:06 +08:00
parent b2c8d76820
commit fa3354db9c
10 changed files with 601 additions and 491 deletions

View File

@ -134,7 +134,7 @@ PicoBot 会在 ~/.picobot/agent/AGENT.md 维护一份持久化 Agent 画像文
1. 系统先对当前活动历史做一个近似 token 估算。 1. 系统先对当前活动历史做一个近似 token 估算。
估算规则不是调用 tokenizer而是按“约每 4 个字符约等于 1 token并再乘以 1.2 安全系数”计算。 估算规则不是调用 tokenizer而是按“约每 4 个字符约等于 1 token并再乘以 1.2 安全系数”计算。
2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。 2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。
这里的上下文窗口来自 agent 对应模型配置里的 token_limit。 这里的上下文窗口来自 agent 对应模型配置里的 context_window_tokens未配置时按 128000 估算
3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。 3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。
当前默认会完整保留最近 3 个 user turn。 当前默认会完整保留最近 3 个 user turn。
4. 一旦满足条件,压缩器会先按 user 消息切分 turn再确定“旧历史”和“最近保留段”的分界点。 4. 一旦满足条件,压缩器会先按 user 消息切分 turn再确定“旧历史”和“最近保留段”的分界点。

View File

@ -60,6 +60,7 @@ pub struct ContextCompressor {
} }
impl ContextCompressor { impl ContextCompressor {
#[cfg(test)]
fn summary_char_budget_for_context_window(context_window: usize) -> usize { fn summary_char_budget_for_context_window(context_window: usize) -> usize {
const SUMMARY_RATIO: f64 = 0.1; const SUMMARY_RATIO: f64 = 0.1;
const CHARS_PER_TOKEN: f64 = 2.5; const CHARS_PER_TOKEN: f64 = 2.5;

View File

@ -159,6 +159,8 @@ pub struct ModelConfig {
pub temperature: Option<f32>, pub temperature: Option<f32>,
#[serde(default)] #[serde(default)]
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
#[serde(default)]
pub context_window_tokens: Option<u32>,
#[serde(flatten)] #[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>, pub extra: HashMap<String, serde_json::Value>,
} }
@ -526,6 +528,7 @@ pub struct LLMProviderConfig {
pub model_id: String, pub model_id: String,
pub temperature: Option<f32>, pub temperature: Option<f32>,
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
pub context_window_tokens: Option<u32>,
pub model_extra: HashMap<String, serde_json::Value>, pub model_extra: HashMap<String, serde_json::Value>,
pub max_tool_iterations: usize, pub max_tool_iterations: usize,
pub tool_result_max_chars: usize, pub tool_result_max_chars: usize,
@ -534,7 +537,7 @@ pub struct LLMProviderConfig {
impl LLMProviderConfig { impl LLMProviderConfig {
pub fn context_window_tokens(&self) -> usize { pub fn context_window_tokens(&self) -> usize {
self.max_tokens self.context_window_tokens
.map(|value| value as usize) .map(|value| value as usize)
.unwrap_or(128_000) .unwrap_or(128_000)
} }
@ -614,6 +617,7 @@ impl Config {
model_id: model.model_id.clone(), model_id: model.model_id.clone(),
temperature: model.temperature, temperature: model.temperature,
max_tokens: model.max_tokens, max_tokens: model.max_tokens,
context_window_tokens: model.context_window_tokens,
model_extra: model.extra.clone(), model_extra: model.extra.clone(),
max_tool_iterations: agent.max_tool_iterations, max_tool_iterations: agent.max_tool_iterations,
tool_result_max_chars: agent.tool_result_max_chars, tool_result_max_chars: agent.tool_result_max_chars,
@ -1056,7 +1060,44 @@ mod tests {
} }
#[test] #[test]
fn test_provider_config_summary_budget_scales_with_model_max_tokens() { fn test_provider_config_summary_budget_scales_with_context_window_tokens() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus",
"context_window_tokens": 4096
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap();
assert_eq!(provider_config.context_window_tokens(), 4096);
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
}
#[test]
fn test_provider_config_max_tokens_does_not_change_context_window() {
let file = tempfile::NamedTempFile::new().unwrap(); let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write( std::fs::write(
file.path(), file.path(),
@ -1088,8 +1129,9 @@ mod tests {
let config = Config::load(file.path().to_str().unwrap()).unwrap(); let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap(); let provider_config = config.get_provider_config("default").unwrap();
assert_eq!(provider_config.context_window_tokens(), 4096); assert_eq!(provider_config.max_tokens, Some(4096));
assert_eq!(provider_config.context_summary_char_budget(), 1_500); assert_eq!(provider_config.context_window_tokens(), 128_000);
assert_eq!(provider_config.context_summary_char_budget(), 32_000);
} }
#[test] #[test]

View File

@ -181,6 +181,7 @@ mod tests {
model_id: model_id.to_string(), model_id: model_id.to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(32), max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 1, max_tool_iterations: 1,
tool_result_max_chars: 20_000, tool_result_max_chars: 20_000,

View File

@ -0,0 +1,517 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::providers::{ChatCompletionRequest, Message, create_provider};
use crate::storage::{MemoryRecord, SessionStore};
use super::prompt::upsert_managed_agent_memory_summary;
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryMaintenanceCategory {
UserFacts,
Preferences,
BehaviorPatterns,
Other,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate {
pub(crate) id: String,
pub(crate) namespace: String,
pub(crate) key: String,
pub(crate) content: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenancePlan {
pub(crate) user_facts: Vec<MemoryMaintenanceCandidate>,
pub(crate) preferences: Vec<MemoryMaintenanceCandidate>,
pub(crate) behavior_patterns: Vec<MemoryMaintenanceCandidate>,
pub(crate) others: Vec<MemoryMaintenanceCandidate>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceMerge {
pub(crate) source_ids: Vec<String>,
pub(crate) namespace: String,
pub(crate) memory_key: String,
pub(crate) content: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceConflict {
pub(crate) source_ids: Vec<String>,
pub(crate) note: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceModelOutput {
pub(crate) user_facts: Vec<String>,
pub(crate) preferences: Vec<String>,
pub(crate) behavior_patterns: Vec<String>,
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
pub(crate) low_value_ids: Vec<String>,
pub(crate) managed_markdown: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceScopeResult {
pub(crate) scope_key: String,
pub(crate) output: MemoryMaintenanceModelOutput,
}
pub(crate) struct MemoryMaintenanceService {
store: Arc<SessionStore>,
provider_config: LLMProviderConfig,
}
impl MemoryMaintenanceService {
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
Self {
store,
provider_config,
}
}
pub(crate) fn build_plan_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
let memories = self
.store
.list_memories_for_scope("user", scope_key)
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
if memories.is_empty() {
return Ok(None);
}
Ok(Some(build_memory_maintenance_plan(&memories)))
}
pub(crate) async fn summarize_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
return Ok(None);
};
self.summarize_plan(scope_key, &plan).await.map(Some)
}
async fn summarize_plan(
&self,
scope_key: &str,
plan: &MemoryMaintenancePlan,
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
let provider = create_provider(self.provider_config.clone()).map_err(|err| {
AgentError::Other(format!("create maintenance provider error: {}", err))
})?;
let request = ChatCompletionRequest {
messages: vec![
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
Message::user(
serde_json::to_string_pretty(&serde_json::json!({
"scope_key": scope_key,
"candidates": plan,
}))
.unwrap_or_else(|_| "{}".to_string()),
),
],
temperature: Some(0.0),
max_tokens: Some(1200),
tools: None,
};
let mut last_error = None;
let mut response = None;
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
.iter()
.copied()
.map(Some)
.chain(std::iter::once(None))
.enumerate()
{
match provider.chat(request.clone()).await {
Ok(success) => {
response = Some(success);
break;
}
Err(err) => {
let error_text = err.to_string();
let should_retry =
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
last_error = Some(error_text.clone());
if should_retry {
tracing::warn!(
scope_key = %scope_key,
attempt = attempt + 1,
retry_in_ms = delay_ms.unwrap_or_default(),
error = %error_text,
"Memory maintenance model request failed, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
.await;
continue;
}
return Err(AgentError::Other(format!(
"memory maintenance model error: {}",
error_text
)));
}
}
}
let response = response.ok_or_else(|| {
AgentError::Other(format!(
"memory maintenance model error: {}",
last_error.unwrap_or_else(|| "unknown provider error".to_string())
))
})?;
let raw_content = strip_json_code_fence(&response.content);
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
let output: MemoryMaintenanceModelOutput =
serde_json::from_str(json_candidate).map_err(|err| {
tracing::error!(
scope_key = %scope_key,
error = %err,
raw_len = raw_content.len(),
raw_preview = %preview_text(raw_content, 400),
json_candidate_len = json_candidate.len(),
json_candidate_preview = %preview_text(json_candidate, 400),
"Memory maintenance JSON decode failed"
);
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
})?;
Ok(output)
}
pub(crate) async fn run_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
return Ok(None);
};
let output = self.summarize_plan(scope_key, &plan).await?;
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
Ok(Some(output))
}
pub(crate) async fn run_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
let scope_keys = if let Some(cutoff) = updated_since {
self.store
.list_memory_scope_keys_updated_since("user", cutoff)
.map_err(|err| {
AgentError::Other(format!(
"list memory scope keys updated since error: {}",
err
))
})?
} else {
self.store.list_memory_scope_keys("user").map_err(|err| {
AgentError::Other(format!("list memory scope keys error: {}", err))
})?
};
let mut results = Vec::new();
for scope_key in scope_keys {
let Some(output) = self.run_for_scope(&scope_key).await? else {
continue;
};
results.push(MemoryMaintenanceScopeResult { scope_key, output });
}
let combined_markdown = combine_managed_memory_markdown(
&results
.iter()
.map(|result| result.output.managed_markdown.clone())
.collect::<Vec<_>>(),
);
if !combined_markdown.is_empty() {
upsert_managed_agent_memory_summary(&combined_markdown)?;
}
Ok(results)
}
}
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
let mut plan = MemoryMaintenancePlan::default();
let mut seen = HashSet::new();
for memory in memories {
let normalized_content = memory.content.trim();
if normalized_content.is_empty() {
continue;
}
let dedupe_key = format!(
"{}\u{1f}{}\u{1f}{}",
memory.namespace.trim().to_ascii_lowercase(),
memory.memory_key.trim().to_ascii_lowercase(),
normalized_content
);
if !seen.insert(dedupe_key) {
continue;
}
let candidate = MemoryMaintenanceCandidate {
id: memory.id.clone(),
namespace: memory.namespace.clone(),
key: memory.memory_key.clone(),
content: normalized_content.to_string(),
};
match memory_maintenance_category(&memory.namespace) {
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
}
}
plan
}
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
match namespace.trim().to_ascii_lowercase().as_str() {
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
"patterns" | "behavior" | "habits" | "workflow" => {
MemoryMaintenanceCategory::BehaviorPatterns
}
_ => MemoryMaintenanceCategory::Other,
}
}
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
let normalized = error.to_ascii_lowercase();
normalized.contains("error sending request for url")
|| normalized.contains("504")
|| normalized.contains("gateway timeout")
|| normalized.contains("stream timeout")
|| normalized.contains("timed out")
|| normalized.contains("timeout")
}
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
let trimmed = content.trim();
if let Some(rest) = trimmed.strip_prefix("```json") {
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
}
if let Some(rest) = trimmed.strip_prefix("```") {
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
}
trimmed
}
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
let mut start = None;
let mut depth = 0usize;
let mut in_string = false;
let mut escaped = false;
for (index, ch) in content.char_indices() {
if in_string {
if escaped {
escaped = false;
continue;
}
match ch {
'\\' => escaped = true,
'"' => in_string = false,
_ => {}
}
continue;
}
match ch {
'"' => in_string = true,
'{' => {
if start.is_none() {
start = Some(index);
}
depth += 1;
}
'}' => {
if depth == 0 {
continue;
}
depth -= 1;
if depth == 0 {
let start = start?;
let end = index + ch.len_utf8();
return Some(content[start..end].trim());
}
}
_ => {}
}
}
None
}
pub(crate) fn combine_managed_memory_markdown(chunks: &[String]) -> String {
let normalized_chunks = chunks
.iter()
.map(|chunk| chunk.trim())
.filter(|chunk| !chunk.is_empty())
.collect::<Vec<_>>();
let mut combined = Vec::new();
for (index, chunk) in normalized_chunks.iter().enumerate() {
let chunk_lines = chunk
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
.collect::<HashSet<_>>();
let is_subset_of_other =
normalized_chunks
.iter()
.enumerate()
.any(|(other_index, other)| {
if index == other_index {
return false;
}
let other_lines = other
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
.collect::<HashSet<_>>();
chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines)
});
if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) {
combined.push((*chunk).to_string());
}
}
combined.join("\n\n")
}
pub(crate) fn apply_memory_maintenance_output(
store: &SessionStore,
scope_key: &str,
plan: &MemoryMaintenancePlan,
output: &MemoryMaintenanceModelOutput,
) -> Result<(), AgentError> {
let all_candidates = plan
.user_facts
.iter()
.chain(plan.preferences.iter())
.chain(plan.behavior_patterns.iter())
.chain(plan.others.iter())
.cloned()
.collect::<Vec<_>>();
let candidates_by_id = all_candidates
.iter()
.map(|candidate| (candidate.id.as_str(), candidate))
.collect::<HashMap<_, _>>();
let mut deleted_ids = HashSet::new();
for merge in &output.merges {
if merge.source_ids.is_empty() {
continue;
}
let source_candidates = merge
.source_ids
.iter()
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
.collect::<Vec<_>>();
if source_candidates.is_empty() {
continue;
}
let existing_target_id = source_candidates
.iter()
.find(|candidate| {
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
})
.map(|candidate| candidate.id.clone());
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: merge.namespace.trim().to_string(),
memory_key: merge.memory_key.trim().to_string(),
content: merge.content.trim().to_string(),
source_type: "memory_maintenance".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
for candidate in source_candidates {
if existing_target_id
.as_ref()
.is_some_and(|target_id| target_id == &candidate.id)
{
continue;
}
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete merged source memory error: {}", err))
})?;
}
}
}
for memory_id in &output.low_value_ids {
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete low value memory error: {}", err))
})?;
}
}
}
Ok(())
}
fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>();
if content.chars().count() > max_chars {
preview.push_str("...");
}
preview.replace('\n', "\\n")
}

View File

@ -1,5 +1,6 @@
pub mod execution; pub mod execution;
pub mod http; pub mod http;
pub mod memory_maintenance;
pub mod processor; pub mod processor;
pub mod prompt; pub mod prompt;
pub mod session; pub mod session;

View File

@ -5,7 +5,6 @@ use crate::bus::{
}; };
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use crate::providers::{ChatCompletionRequest, Message, create_provider};
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::tools::{ use crate::tools::{
@ -14,7 +13,6 @@ use crate::tools::{
TimeTool, ToolContext, ToolRegistry, WebFetchTool, TimeTool, ToolContext, ToolRegistry, WebFetchTool,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -25,183 +23,16 @@ use super::execution::{
AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt, AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt,
select_provider_config, should_display_message_to_user, select_provider_config, should_display_message_to_user,
}; };
use super::prompt::{load_agent_prompt, upsert_managed_agent_memory_summary}; #[cfg(test)]
use super::memory_maintenance::{
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md"); MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000]; combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error,
strip_json_code_fence,
#[derive(Debug, Clone, Copy, PartialEq, Eq)] };
enum MemoryMaintenanceCategory { use super::memory_maintenance::{
UserFacts, MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
Preferences, };
BehaviorPatterns, use super::prompt::load_agent_prompt;
Other,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate {
id: String,
namespace: String,
key: String,
content: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenancePlan {
user_facts: Vec<MemoryMaintenanceCandidate>,
preferences: Vec<MemoryMaintenanceCandidate>,
behavior_patterns: Vec<MemoryMaintenanceCandidate>,
others: Vec<MemoryMaintenanceCandidate>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceMerge {
pub(crate) source_ids: Vec<String>,
pub(crate) namespace: String,
pub(crate) memory_key: String,
pub(crate) content: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceConflict {
pub(crate) source_ids: Vec<String>,
pub(crate) note: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceModelOutput {
pub(crate) user_facts: Vec<String>,
pub(crate) preferences: Vec<String>,
pub(crate) behavior_patterns: Vec<String>,
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
pub(crate) low_value_ids: Vec<String>,
pub(crate) managed_markdown: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceScopeResult {
pub(crate) scope_key: String,
pub(crate) output: MemoryMaintenanceModelOutput,
}
fn build_memory_maintenance_plan(
memories: &[crate::storage::MemoryRecord],
) -> MemoryMaintenancePlan {
let mut plan = MemoryMaintenancePlan::default();
let mut seen = HashSet::new();
for memory in memories {
let normalized_content = memory.content.trim();
if normalized_content.is_empty() {
continue;
}
let dedupe_key = format!(
"{}\u{1f}{}\u{1f}{}",
memory.namespace.trim().to_ascii_lowercase(),
memory.memory_key.trim().to_ascii_lowercase(),
normalized_content
);
if !seen.insert(dedupe_key) {
continue;
}
let candidate = MemoryMaintenanceCandidate {
id: memory.id.clone(),
namespace: memory.namespace.clone(),
key: memory.memory_key.clone(),
content: normalized_content.to_string(),
};
match memory_maintenance_category(&memory.namespace) {
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
}
}
plan
}
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
match namespace.trim().to_ascii_lowercase().as_str() {
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
"patterns" | "behavior" | "habits" | "workflow" => {
MemoryMaintenanceCategory::BehaviorPatterns
}
_ => MemoryMaintenanceCategory::Other,
}
}
fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
let normalized = error.to_ascii_lowercase();
normalized.contains("error sending request for url")
|| normalized.contains("504")
|| normalized.contains("gateway timeout")
|| normalized.contains("stream timeout")
|| normalized.contains("timed out")
|| normalized.contains("timeout")
}
fn strip_json_code_fence(content: &str) -> &str {
let trimmed = content.trim();
if let Some(rest) = trimmed.strip_prefix("```json") {
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
}
if let Some(rest) = trimmed.strip_prefix("```") {
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
}
trimmed
}
fn extract_json_object(content: &str) -> Option<&str> {
let mut start = None;
let mut depth = 0usize;
let mut in_string = false;
let mut escaped = false;
for (index, ch) in content.char_indices() {
if in_string {
if escaped {
escaped = false;
continue;
}
match ch {
'\\' => escaped = true,
'"' => in_string = false,
_ => {}
}
continue;
}
match ch {
'"' => in_string = true,
'{' => {
if start.is_none() {
start = Some(index);
}
depth += 1;
}
'}' => {
if depth == 0 {
continue;
}
depth -= 1;
if depth == 0 {
let start = start?;
let end = index + ch.len_utf8();
return Some(content[start..end].trim());
}
}
_ => {}
}
}
None
}
fn preview_text(content: &str, max_chars: usize) -> String { fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>(); let mut preview = content.chars().take(max_chars).collect::<String>();
@ -225,138 +56,6 @@ fn enrich_user_content_with_media_refs(
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}")) Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
} }
fn combine_managed_memory_markdown(chunks: &[String]) -> String {
let normalized_chunks = chunks
.iter()
.map(|chunk| chunk.trim())
.filter(|chunk| !chunk.is_empty())
.collect::<Vec<_>>();
let mut combined = Vec::new();
for (index, chunk) in normalized_chunks.iter().enumerate() {
let chunk_lines = chunk
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
.collect::<HashSet<_>>();
let is_subset_of_other =
normalized_chunks
.iter()
.enumerate()
.any(|(other_index, other)| {
if index == other_index {
return false;
}
let other_lines = other
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
.collect::<HashSet<_>>();
chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines)
});
if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) {
combined.push((*chunk).to_string());
}
}
combined.join("\n\n")
}
fn apply_memory_maintenance_output(
store: &SessionStore,
scope_key: &str,
plan: &MemoryMaintenancePlan,
output: &MemoryMaintenanceModelOutput,
) -> Result<(), AgentError> {
let all_candidates = plan
.user_facts
.iter()
.chain(plan.preferences.iter())
.chain(plan.behavior_patterns.iter())
.chain(plan.others.iter())
.cloned()
.collect::<Vec<_>>();
let candidates_by_id = all_candidates
.iter()
.map(|candidate| (candidate.id.as_str(), candidate))
.collect::<HashMap<_, _>>();
let mut deleted_ids = HashSet::new();
for merge in &output.merges {
if merge.source_ids.is_empty() {
continue;
}
let source_candidates = merge
.source_ids
.iter()
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
.collect::<Vec<_>>();
if source_candidates.is_empty() {
continue;
}
let existing_target_id = source_candidates
.iter()
.find(|candidate| {
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
})
.map(|candidate| candidate.id.clone());
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: merge.namespace.trim().to_string(),
memory_key: merge.memory_key.trim().to_string(),
content: merge.content.trim().to_string(),
source_type: "memory_maintenance".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
for candidate in source_candidates {
if existing_target_id
.as_ref()
.is_some_and(|target_id| target_id == &candidate.id)
{
continue;
}
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete merged source memory error: {}", err))
})?;
}
}
}
for memory_id in &output.low_value_ids {
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete low value memory error: {}", err))
})?;
}
}
}
Ok(())
}
/// Session 按 channel 隔离,每个 channel 一个 Session /// Session 按 channel 隔离,每个 channel 一个 Session
/// History 按 chat_id 隔离,由 Session 统一管理 /// History 按 chat_id 隔离,由 Session 统一管理
pub struct Session { pub struct Session {
@ -609,6 +308,7 @@ impl Session {
} }
} }
#[cfg(test)]
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> { fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
self.latest_user_message(chat_id) self.latest_user_message(chat_id)
.map(|message| message.id.as_str()) .map(|message| message.id.as_str())
@ -619,6 +319,7 @@ impl Session {
.and_then(|history| history.iter().rev().find(|message| message.role == "user")) .and_then(|history| history.iter().rev().find(|message| message.role == "user"))
} }
#[cfg(test)]
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool { fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
self.latest_user_message_id(chat_id) self.latest_user_message_id(chat_id)
.map(|current_id| current_id == message_id) .map(|current_id| current_id == message_id)
@ -1014,194 +715,30 @@ impl SessionManager {
self.skills.clone() self.skills.clone()
} }
pub(crate) fn build_memory_maintenance_plan_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
let memories = self
.store
.list_memories_for_scope("user", scope_key)
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
if memories.is_empty() {
return Ok(None);
}
Ok(Some(build_memory_maintenance_plan(&memories)))
}
pub(crate) fn upsert_managed_agent_memory_summary(
&self,
markdown_body: &str,
) -> Result<(), AgentError> {
upsert_managed_agent_memory_summary(markdown_body)
}
#[cfg_attr(not(test), allow(dead_code))] #[cfg_attr(not(test), allow(dead_code))]
pub(crate) async fn summarize_memory_maintenance_for_scope( pub(crate) async fn summarize_memory_maintenance_for_scope(
&self, &self,
scope_key: &str, scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> { ) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
let Some(plan) = self.build_memory_maintenance_plan_for_scope(scope_key)? else { self.memory_maintenance_service()?
return Ok(None); .summarize_for_scope(scope_key)
};
self.summarize_memory_maintenance_plan(scope_key, &plan)
.await .await
.map(Some)
}
async fn summarize_memory_maintenance_plan(
&self,
scope_key: &str,
plan: &MemoryMaintenancePlan,
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
let provider_config = self.provider_config_for_agent(None)?;
let provider = create_provider(provider_config).map_err(|err| {
AgentError::Other(format!("create maintenance provider error: {}", err))
})?;
let request = ChatCompletionRequest {
messages: vec![
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
Message::user(
serde_json::to_string_pretty(&serde_json::json!({
"scope_key": scope_key,
"candidates": plan,
}))
.unwrap_or_else(|_| "{}".to_string()),
),
],
temperature: Some(0.0),
max_tokens: Some(1200),
tools: None,
};
let mut last_error = None;
let mut response = None;
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
.iter()
.copied()
.map(Some)
.chain(std::iter::once(None))
.enumerate()
{
match provider.chat(request.clone()).await {
Ok(success) => {
response = Some(success);
break;
}
Err(err) => {
let error_text = err.to_string();
let should_retry =
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
last_error = Some(error_text.clone());
if should_retry {
tracing::warn!(
scope_key = %scope_key,
attempt = attempt + 1,
retry_in_ms = delay_ms.unwrap_or_default(),
error = %error_text,
"Memory maintenance model request failed, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
.await;
continue;
}
return Err(AgentError::Other(format!(
"memory maintenance model error: {}",
error_text
)));
}
}
}
let response = response.ok_or_else(|| {
AgentError::Other(format!(
"memory maintenance model error: {}",
last_error.unwrap_or_else(|| "unknown provider error".to_string())
))
})?;
let raw_content = strip_json_code_fence(&response.content);
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
let output: MemoryMaintenanceModelOutput =
serde_json::from_str(json_candidate).map_err(|err| {
tracing::error!(
scope_key = %scope_key,
error = %err,
raw_len = raw_content.len(),
raw_preview = %preview_text(raw_content, 400),
json_candidate_len = json_candidate.len(),
json_candidate_preview = %preview_text(json_candidate, 400),
"Memory maintenance JSON decode failed"
);
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
})?;
Ok(output)
}
pub(crate) async fn run_memory_maintenance_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
let Some(plan) = self.build_memory_maintenance_plan_for_scope(scope_key)? else {
return Ok(None);
};
let output = self
.summarize_memory_maintenance_plan(scope_key, &plan)
.await?;
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
Ok(Some(output))
} }
pub(crate) async fn run_memory_maintenance_for_all_scopes( pub(crate) async fn run_memory_maintenance_for_all_scopes(
&self, &self,
updated_since: Option<i64>, updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> { ) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
let scope_keys = if let Some(cutoff) = updated_since { self.memory_maintenance_service()?
self.store .run_for_all_scopes(updated_since)
.list_memory_scope_keys_updated_since("user", cutoff) .await
.map_err(|err| { }
AgentError::Other(format!(
"list memory scope keys updated since error: {}", fn memory_maintenance_service(&self) -> Result<MemoryMaintenanceService, AgentError> {
err Ok(MemoryMaintenanceService::new(
self.store.clone(),
self.provider_config_for_agent(None)?,
)) ))
})?
} else {
self.store.list_memory_scope_keys("user").map_err(|err| {
AgentError::Other(format!("list memory scope keys error: {}", err))
})?
};
let mut results = Vec::new();
for scope_key in scope_keys {
let Some(output) = self.run_memory_maintenance_for_scope(&scope_key).await? else {
continue;
};
results.push(MemoryMaintenanceScopeResult { scope_key, output });
}
let combined_markdown = combine_managed_memory_markdown(
&results
.iter()
.map(|result| result.output.managed_markdown.clone())
.collect::<Vec<_>>(),
);
if !combined_markdown.is_empty() {
self.upsert_managed_agent_memory_summary(&combined_markdown)?;
}
Ok(results)
} }
pub fn provider_config_for_agent( pub fn provider_config_for_agent(
@ -1572,6 +1109,7 @@ mod tests {
model_id: "test-model".to_string(), model_id: "test-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(32), max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 1, max_tool_iterations: 1,
tool_result_max_chars: 20_000, tool_result_max_chars: 20_000,
@ -1833,6 +1371,7 @@ mod tests {
model_id: "timeout-model".to_string(), model_id: "timeout-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(32), max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 1, max_tool_iterations: 1,
llm_timeout_secs: 30, llm_timeout_secs: 30,
@ -1872,6 +1411,7 @@ mod tests {
model_id: "default-model".to_string(), model_id: "default-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(32), max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 1, max_tool_iterations: 1,
llm_timeout_secs: 30, llm_timeout_secs: 30,
@ -1943,6 +1483,7 @@ mod tests {
model_id: "default-model".to_string(), model_id: "default-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(32), max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 1, max_tool_iterations: 1,
llm_timeout_secs: 30, llm_timeout_secs: 30,
@ -2020,6 +1561,7 @@ mod tests {
model_id: "maintenance-model".to_string(), model_id: "maintenance-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(256), max_tokens: Some(256),
context_window_tokens: None,
model_extra: HashMap::from([( model_extra: HashMap::from([(
"mock_response_content".to_string(), "mock_response_content".to_string(),
json!(mock_response_content), json!(mock_response_content),
@ -2120,6 +1662,7 @@ mod tests {
model_id: "maintenance-model".to_string(), model_id: "maintenance-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(256), max_tokens: Some(256),
context_window_tokens: None,
model_extra: HashMap::from([( model_extra: HashMap::from([(
"mock_response_content".to_string(), "mock_response_content".to_string(),
json!(mock_response_content), json!(mock_response_content),
@ -2182,6 +1725,7 @@ mod tests {
model_id: "maintenance-model".to_string(), model_id: "maintenance-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(256), max_tokens: Some(256),
context_window_tokens: None,
model_extra: HashMap::from([( model_extra: HashMap::from([(
"mock_response_content".to_string(), "mock_response_content".to_string(),
json!(mock_response_content), json!(mock_response_content),
@ -2241,6 +1785,7 @@ mod tests {
model_id: "maintenance-model".to_string(), model_id: "maintenance-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(256), max_tokens: Some(256),
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 1, max_tool_iterations: 1,
llm_timeout_secs: 30, llm_timeout_secs: 30,

View File

@ -980,6 +980,7 @@ mod tests {
model_id: "test-model".to_string(), model_id: "test-model".to_string(),
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: None, max_tokens: None,
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 4, max_tool_iterations: 4,
tool_result_max_chars: 20_000, tool_result_max_chars: 20_000,

View File

@ -23,6 +23,7 @@ fn load_config() -> Option<LLMProviderConfig> {
model_id: openai_model, model_id: openai_model,
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(100), max_tokens: Some(100),
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 20, max_tool_iterations: 20,
tool_result_max_chars: 20_000, tool_result_max_chars: 20_000,

View File

@ -23,6 +23,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
model_id: openai_model, model_id: openai_model,
temperature: Some(0.0), temperature: Some(0.0),
max_tokens: Some(100), max_tokens: Some(100),
context_window_tokens: None,
model_extra: HashMap::new(), model_extra: HashMap::new(),
max_tool_iterations: 20, max_tool_iterations: 20,
tool_result_max_chars: 20_000, tool_result_max_chars: 20_000,