feat: add context_window_tokens to model configuration and update related logic
- Introduced context_window_tokens in ModelConfig and LLMProviderConfig structs.
- Updated context window estimation logic in ContextCompressor to use context_window_tokens.
- Modified tests to accommodate new context_window_tokens field.
- Refactored memory maintenance logic into a new memory_maintenance.rs file for better organization.
- Ensured backward compatibility by providing default values where necessary.

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
b2c8d76820
commit
fa3354db9c
@ -134,7 +134,7 @@ PicoBot 会在 ~/.picobot/agent/AGENT.md 维护一份持久化 Agent 画像文
|
|||||||
1. 系统先对当前活动历史做一个近似 token 估算。
|
1. 系统先对当前活动历史做一个近似 token 估算。
|
||||||
估算规则不是调用 tokenizer,而是按“约每 4 个字符约等于 1 token,并再乘以 1.2 安全系数”计算。
|
估算规则不是调用 tokenizer,而是按“约每 4 个字符约等于 1 token,并再乘以 1.2 安全系数”计算。
|
||||||
2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。
|
2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。
|
||||||
这里的上下文窗口来自 agent 对应模型配置里的 token_limit。
|
这里的上下文窗口来自 agent 对应模型配置里的 context_window_tokens;未配置时按 128000 估算。
|
||||||
3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。
|
3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。
|
||||||
当前默认会完整保留最近 3 个 user turn。
|
当前默认会完整保留最近 3 个 user turn。
|
||||||
4. 一旦满足条件,压缩器会先按 user 消息切分 turn,再确定“旧历史”和“最近保留段”的分界点。
|
4. 一旦满足条件,压缩器会先按 user 消息切分 turn,再确定“旧历史”和“最近保留段”的分界点。
|
||||||
|
|||||||
@ -60,6 +60,7 @@ pub struct ContextCompressor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ContextCompressor {
|
impl ContextCompressor {
|
||||||
|
#[cfg(test)]
|
||||||
fn summary_char_budget_for_context_window(context_window: usize) -> usize {
|
fn summary_char_budget_for_context_window(context_window: usize) -> usize {
|
||||||
const SUMMARY_RATIO: f64 = 0.1;
|
const SUMMARY_RATIO: f64 = 0.1;
|
||||||
const CHARS_PER_TOKEN: f64 = 2.5;
|
const CHARS_PER_TOKEN: f64 = 2.5;
|
||||||
|
|||||||
@ -159,6 +159,8 @@ pub struct ModelConfig {
|
|||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub context_window_tokens: Option<u32>,
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub extra: HashMap<String, serde_json::Value>,
|
pub extra: HashMap<String, serde_json::Value>,
|
||||||
}
|
}
|
||||||
@ -526,6 +528,7 @@ pub struct LLMProviderConfig {
|
|||||||
pub model_id: String,
|
pub model_id: String,
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
|
pub context_window_tokens: Option<u32>,
|
||||||
pub model_extra: HashMap<String, serde_json::Value>,
|
pub model_extra: HashMap<String, serde_json::Value>,
|
||||||
pub max_tool_iterations: usize,
|
pub max_tool_iterations: usize,
|
||||||
pub tool_result_max_chars: usize,
|
pub tool_result_max_chars: usize,
|
||||||
@ -534,7 +537,7 @@ pub struct LLMProviderConfig {
|
|||||||
|
|
||||||
impl LLMProviderConfig {
|
impl LLMProviderConfig {
|
||||||
pub fn context_window_tokens(&self) -> usize {
|
pub fn context_window_tokens(&self) -> usize {
|
||||||
self.max_tokens
|
self.context_window_tokens
|
||||||
.map(|value| value as usize)
|
.map(|value| value as usize)
|
||||||
.unwrap_or(128_000)
|
.unwrap_or(128_000)
|
||||||
}
|
}
|
||||||
@ -614,6 +617,7 @@ impl Config {
|
|||||||
model_id: model.model_id.clone(),
|
model_id: model.model_id.clone(),
|
||||||
temperature: model.temperature,
|
temperature: model.temperature,
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
|
context_window_tokens: model.context_window_tokens,
|
||||||
model_extra: model.extra.clone(),
|
model_extra: model.extra.clone(),
|
||||||
max_tool_iterations: agent.max_tool_iterations,
|
max_tool_iterations: agent.max_tool_iterations,
|
||||||
tool_result_max_chars: agent.tool_result_max_chars,
|
tool_result_max_chars: agent.tool_result_max_chars,
|
||||||
@ -1056,7 +1060,44 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_provider_config_summary_budget_scales_with_model_max_tokens() {
|
fn test_provider_config_summary_budget_scales_with_context_window_tokens() {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus",
|
||||||
|
"context_window_tokens": 4096
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(provider_config.context_window_tokens(), 4096);
|
||||||
|
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_provider_config_max_tokens_does_not_change_context_window() {
|
||||||
let file = tempfile::NamedTempFile::new().unwrap();
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
std::fs::write(
|
std::fs::write(
|
||||||
file.path(),
|
file.path(),
|
||||||
@ -1088,8 +1129,9 @@ mod tests {
|
|||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
let provider_config = config.get_provider_config("default").unwrap();
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
|
||||||
assert_eq!(provider_config.context_window_tokens(), 4096);
|
assert_eq!(provider_config.max_tokens, Some(4096));
|
||||||
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
|
assert_eq!(provider_config.context_window_tokens(), 128_000);
|
||||||
|
assert_eq!(provider_config.context_summary_char_budget(), 32_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -181,6 +181,7 @@ mod tests {
|
|||||||
model_id: model_id.to_string(),
|
model_id: model_id.to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(32),
|
max_tokens: Some(32),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 1,
|
max_tool_iterations: 1,
|
||||||
tool_result_max_chars: 20_000,
|
tool_result_max_chars: 20_000,
|
||||||
|
|||||||
517
src/gateway/memory_maintenance.rs
Normal file
517
src/gateway/memory_maintenance.rs
Normal file
@ -0,0 +1,517 @@
|
|||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::agent::AgentError;
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
use crate::providers::{ChatCompletionRequest, Message, create_provider};
|
||||||
|
use crate::storage::{MemoryRecord, SessionStore};
|
||||||
|
|
||||||
|
use super::prompt::upsert_managed_agent_memory_summary;
|
||||||
|
|
||||||
|
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
||||||
|
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum MemoryMaintenanceCategory {
|
||||||
|
UserFacts,
|
||||||
|
Preferences,
|
||||||
|
BehaviorPatterns,
|
||||||
|
Other,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceCandidate {
|
||||||
|
pub(crate) id: String,
|
||||||
|
pub(crate) namespace: String,
|
||||||
|
pub(crate) key: String,
|
||||||
|
pub(crate) content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenancePlan {
|
||||||
|
pub(crate) user_facts: Vec<MemoryMaintenanceCandidate>,
|
||||||
|
pub(crate) preferences: Vec<MemoryMaintenanceCandidate>,
|
||||||
|
pub(crate) behavior_patterns: Vec<MemoryMaintenanceCandidate>,
|
||||||
|
pub(crate) others: Vec<MemoryMaintenanceCandidate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceMerge {
|
||||||
|
pub(crate) source_ids: Vec<String>,
|
||||||
|
pub(crate) namespace: String,
|
||||||
|
pub(crate) memory_key: String,
|
||||||
|
pub(crate) content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceConflict {
|
||||||
|
pub(crate) source_ids: Vec<String>,
|
||||||
|
pub(crate) note: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceModelOutput {
|
||||||
|
pub(crate) user_facts: Vec<String>,
|
||||||
|
pub(crate) preferences: Vec<String>,
|
||||||
|
pub(crate) behavior_patterns: Vec<String>,
|
||||||
|
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
|
||||||
|
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
|
||||||
|
pub(crate) low_value_ids: Vec<String>,
|
||||||
|
pub(crate) managed_markdown: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceScopeResult {
|
||||||
|
pub(crate) scope_key: String,
|
||||||
|
pub(crate) output: MemoryMaintenanceModelOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct MemoryMaintenanceService {
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryMaintenanceService {
|
||||||
|
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
provider_config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn build_plan_for_scope(
|
||||||
|
&self,
|
||||||
|
scope_key: &str,
|
||||||
|
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||||
|
let memories = self
|
||||||
|
.store
|
||||||
|
.list_memories_for_scope("user", scope_key)
|
||||||
|
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
|
||||||
|
|
||||||
|
if memories.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(build_memory_maintenance_plan(&memories)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn summarize_for_scope(
|
||||||
|
&self,
|
||||||
|
scope_key: &str,
|
||||||
|
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||||
|
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
self.summarize_plan(scope_key, &plan).await.map(Some)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn summarize_plan(
|
||||||
|
&self,
|
||||||
|
scope_key: &str,
|
||||||
|
plan: &MemoryMaintenancePlan,
|
||||||
|
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
|
||||||
|
let provider = create_provider(self.provider_config.clone()).map_err(|err| {
|
||||||
|
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
|
||||||
|
Message::user(
|
||||||
|
serde_json::to_string_pretty(&serde_json::json!({
|
||||||
|
"scope_key": scope_key,
|
||||||
|
"candidates": plan,
|
||||||
|
}))
|
||||||
|
.unwrap_or_else(|_| "{}".to_string()),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(1200),
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut last_error = None;
|
||||||
|
let mut response = None;
|
||||||
|
|
||||||
|
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.map(Some)
|
||||||
|
.chain(std::iter::once(None))
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
match provider.chat(request.clone()).await {
|
||||||
|
Ok(success) => {
|
||||||
|
response = Some(success);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
let error_text = err.to_string();
|
||||||
|
let should_retry =
|
||||||
|
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||||||
|
last_error = Some(error_text.clone());
|
||||||
|
|
||||||
|
if should_retry {
|
||||||
|
tracing::warn!(
|
||||||
|
scope_key = %scope_key,
|
||||||
|
attempt = attempt + 1,
|
||||||
|
retry_in_ms = delay_ms.unwrap_or_default(),
|
||||||
|
error = %error_text,
|
||||||
|
"Memory maintenance model request failed, retrying"
|
||||||
|
);
|
||||||
|
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||||||
|
.await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Err(AgentError::Other(format!(
|
||||||
|
"memory maintenance model error: {}",
|
||||||
|
error_text
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = response.ok_or_else(|| {
|
||||||
|
AgentError::Other(format!(
|
||||||
|
"memory maintenance model error: {}",
|
||||||
|
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let raw_content = strip_json_code_fence(&response.content);
|
||||||
|
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
|
||||||
|
|
||||||
|
let output: MemoryMaintenanceModelOutput =
|
||||||
|
serde_json::from_str(json_candidate).map_err(|err| {
|
||||||
|
tracing::error!(
|
||||||
|
scope_key = %scope_key,
|
||||||
|
error = %err,
|
||||||
|
raw_len = raw_content.len(),
|
||||||
|
raw_preview = %preview_text(raw_content, 400),
|
||||||
|
json_candidate_len = json_candidate.len(),
|
||||||
|
json_candidate_preview = %preview_text(json_candidate, 400),
|
||||||
|
"Memory maintenance JSON decode failed"
|
||||||
|
);
|
||||||
|
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn run_for_scope(
|
||||||
|
&self,
|
||||||
|
scope_key: &str,
|
||||||
|
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||||
|
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = self.summarize_plan(scope_key, &plan).await?;
|
||||||
|
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
|
||||||
|
|
||||||
|
Ok(Some(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn run_for_all_scopes(
|
||||||
|
&self,
|
||||||
|
updated_since: Option<i64>,
|
||||||
|
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||||
|
let scope_keys = if let Some(cutoff) = updated_since {
|
||||||
|
self.store
|
||||||
|
.list_memory_scope_keys_updated_since("user", cutoff)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!(
|
||||||
|
"list memory scope keys updated since error: {}",
|
||||||
|
err
|
||||||
|
))
|
||||||
|
})?
|
||||||
|
} else {
|
||||||
|
self.store.list_memory_scope_keys("user").map_err(|err| {
|
||||||
|
AgentError::Other(format!("list memory scope keys error: {}", err))
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
for scope_key in scope_keys {
|
||||||
|
let Some(output) = self.run_for_scope(&scope_key).await? else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
results.push(MemoryMaintenanceScopeResult { scope_key, output });
|
||||||
|
}
|
||||||
|
|
||||||
|
let combined_markdown = combine_managed_memory_markdown(
|
||||||
|
&results
|
||||||
|
.iter()
|
||||||
|
.map(|result| result.output.managed_markdown.clone())
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if !combined_markdown.is_empty() {
|
||||||
|
upsert_managed_agent_memory_summary(&combined_markdown)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||||||
|
let mut plan = MemoryMaintenancePlan::default();
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
|
||||||
|
for memory in memories {
|
||||||
|
let normalized_content = memory.content.trim();
|
||||||
|
if normalized_content.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let dedupe_key = format!(
|
||||||
|
"{}\u{1f}{}\u{1f}{}",
|
||||||
|
memory.namespace.trim().to_ascii_lowercase(),
|
||||||
|
memory.memory_key.trim().to_ascii_lowercase(),
|
||||||
|
normalized_content
|
||||||
|
);
|
||||||
|
if !seen.insert(dedupe_key) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let candidate = MemoryMaintenanceCandidate {
|
||||||
|
id: memory.id.clone(),
|
||||||
|
namespace: memory.namespace.clone(),
|
||||||
|
key: memory.memory_key.clone(),
|
||||||
|
content: normalized_content.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
match memory_maintenance_category(&memory.namespace) {
|
||||||
|
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
|
||||||
|
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
|
||||||
|
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
|
||||||
|
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
plan
|
||||||
|
}
|
||||||
|
|
||||||
|
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
|
||||||
|
match namespace.trim().to_ascii_lowercase().as_str() {
|
||||||
|
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
|
||||||
|
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
|
||||||
|
"patterns" | "behavior" | "habits" | "workflow" => {
|
||||||
|
MemoryMaintenanceCategory::BehaviorPatterns
|
||||||
|
}
|
||||||
|
_ => MemoryMaintenanceCategory::Other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reports whether a provider error message looks transient and worth retrying.
///
/// The message is ASCII-lowercased and searched for a fixed list of substrings
/// covering request-send failures, HTTP 504 and the various timeout wordings.
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
    const TRANSIENT_MARKERS: [&str; 6] = [
        "error sending request for url",
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];

    let haystack = error.to_ascii_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|marker| haystack.contains(*marker))
}
|
||||||
|
|
||||||
|
/// Removes a surrounding Markdown code fence from a model reply.
///
/// The input is trimmed first. If it opens with "```json" or a bare "```" and
/// also ends with "```", the trimmed text between the fences is returned.
/// When the opening fence has no matching closing fence, or there is no fence
/// at all, the trimmed input is returned unchanged.
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
    // Longest opener first so "```json" is not mistaken for a bare fence.
    const OPENING_FENCES: [&str; 2] = ["```json", "```"];

    let trimmed = content.trim();
    for fence in OPENING_FENCES.iter() {
        if let Some(after_fence) = trimmed.strip_prefix(*fence) {
            return match after_fence.strip_suffix("```") {
                Some(inner) => inner.trim(),
                None => trimmed,
            };
        }
    }

    trimmed
}
|
||||||
|
|
||||||
|
/// Returns the first brace-balanced `{ ... }` span found in `content`.
///
/// Braces inside double-quoted strings are ignored, and a backslash inside a
/// string escapes the character after it. A `}` seen while no object is open is
/// skipped. Returns `None` when no object is ever closed.
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
    let mut object_start: Option<usize> = None;
    let mut open_braces = 0usize;
    let mut inside_string = false;
    let mut skip_next = false;

    // Every delimiter of interest is ASCII, and ASCII bytes never occur inside a
    // multi-byte UTF-8 sequence, so scanning bytes yields valid slice offsets.
    for (offset, byte) in content.bytes().enumerate() {
        if inside_string {
            if skip_next {
                skip_next = false;
            } else if byte == b'\\' {
                skip_next = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }

        match byte {
            b'"' => inside_string = true,
            b'{' => {
                object_start.get_or_insert(offset);
                open_braces += 1;
            }
            b'}' if open_braces > 0 => {
                open_braces -= 1;
                if open_braces == 0 {
                    let begin = object_start?;
                    return Some(content[begin..=offset].trim());
                }
            }
            _ => {}
        }
    }

    None
}
|
||||||
|
|
||||||
|
/// Joins per-scope Markdown sections into one document, separated by blank lines.
///
/// Chunks are trimmed and blank ones dropped. A chunk is left out when its set
/// of non-blank, trimmed lines is a strict subset of another chunk's set, or
/// when an identical chunk has already been kept. Surviving chunks keep their
/// input order.
pub(crate) fn combine_managed_memory_markdown(chunks: &[String]) -> String {
    let normalized_chunks: Vec<&str> = chunks
        .iter()
        .map(|chunk| chunk.trim())
        .filter(|chunk| !chunk.is_empty())
        .collect();

    // Build every chunk's line set once, rather than once per pairwise comparison.
    let line_sets: Vec<HashSet<&str>> = normalized_chunks
        .iter()
        .map(|chunk| {
            chunk
                .lines()
                .map(str::trim)
                .filter(|line| !line.is_empty())
                .collect()
        })
        .collect();

    let mut combined: Vec<&str> = Vec::new();
    for (index, (chunk, lines)) in normalized_chunks.iter().zip(&line_sets).enumerate() {
        let is_subset_of_other = line_sets.iter().enumerate().any(|(other_index, other)| {
            other_index != index && lines.len() < other.len() && lines.is_subset(other)
        });

        if !is_subset_of_other && !combined.contains(chunk) {
            combined.push(*chunk);
        }
    }

    combined.join("\n\n")
}
|
||||||
|
|
||||||
|
pub(crate) fn apply_memory_maintenance_output(
|
||||||
|
store: &SessionStore,
|
||||||
|
scope_key: &str,
|
||||||
|
plan: &MemoryMaintenancePlan,
|
||||||
|
output: &MemoryMaintenanceModelOutput,
|
||||||
|
) -> Result<(), AgentError> {
|
||||||
|
let all_candidates = plan
|
||||||
|
.user_facts
|
||||||
|
.iter()
|
||||||
|
.chain(plan.preferences.iter())
|
||||||
|
.chain(plan.behavior_patterns.iter())
|
||||||
|
.chain(plan.others.iter())
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let candidates_by_id = all_candidates
|
||||||
|
.iter()
|
||||||
|
.map(|candidate| (candidate.id.as_str(), candidate))
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
|
||||||
|
let mut deleted_ids = HashSet::new();
|
||||||
|
|
||||||
|
for merge in &output.merges {
|
||||||
|
if merge.source_ids.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let source_candidates = merge
|
||||||
|
.source_ids
|
||||||
|
.iter()
|
||||||
|
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
if source_candidates.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let existing_target_id = source_candidates
|
||||||
|
.iter()
|
||||||
|
.find(|candidate| {
|
||||||
|
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
|
||||||
|
})
|
||||||
|
.map(|candidate| candidate.id.clone());
|
||||||
|
|
||||||
|
store
|
||||||
|
.put_memory(&crate::storage::MemoryUpsert {
|
||||||
|
scope_kind: "user".to_string(),
|
||||||
|
scope_key: scope_key.to_string(),
|
||||||
|
namespace: merge.namespace.trim().to_string(),
|
||||||
|
memory_key: merge.memory_key.trim().to_string(),
|
||||||
|
content: merge.content.trim().to_string(),
|
||||||
|
source_type: "memory_maintenance".to_string(),
|
||||||
|
source_session_id: None,
|
||||||
|
source_message_id: None,
|
||||||
|
source_message_seq: None,
|
||||||
|
source_channel_name: None,
|
||||||
|
source_chat_id: None,
|
||||||
|
})
|
||||||
|
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
|
||||||
|
|
||||||
|
for candidate in source_candidates {
|
||||||
|
if existing_target_id
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|target_id| target_id == &candidate.id)
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if deleted_ids.insert(candidate.id.clone()) {
|
||||||
|
store
|
||||||
|
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!("delete merged source memory error: {}", err))
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for memory_id in &output.low_value_ids {
|
||||||
|
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
|
||||||
|
if deleted_ids.insert(candidate.id.clone()) {
|
||||||
|
store
|
||||||
|
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!("delete low value memory error: {}", err))
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Produces a single-line preview of `content` for log output.
///
/// At most `max_chars` characters are kept; when the text is longer, "..." is
/// appended. Newlines are then rendered as the two characters `\n`.
fn preview_text(content: &str, max_chars: usize) -> String {
    // The byte offset of character number `max_chars` exists only when the text
    // is longer than the limit.
    let cut_offset = content
        .char_indices()
        .nth(max_chars)
        .map(|(offset, _)| offset);

    let preview = match cut_offset {
        Some(offset) => format!("{}...", &content[..offset]),
        None => content.to_owned(),
    };

    preview.replace('\n', "\\n")
}
|
||||||
@ -1,5 +1,6 @@
|
|||||||
pub mod execution;
|
pub mod execution;
|
||||||
pub mod http;
|
pub mod http;
|
||||||
|
pub mod memory_maintenance;
|
||||||
pub mod processor;
|
pub mod processor;
|
||||||
pub mod prompt;
|
pub mod prompt;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
|
|||||||
@ -5,7 +5,6 @@ use crate::bus::{
|
|||||||
};
|
};
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::providers::{ChatCompletionRequest, Message, create_provider};
|
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||||
use crate::tools::{
|
use crate::tools::{
|
||||||
@ -14,7 +13,6 @@ use crate::tools::{
|
|||||||
TimeTool, ToolContext, ToolRegistry, WebFetchTool,
|
TimeTool, ToolContext, ToolRegistry, WebFetchTool,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
@ -25,183 +23,16 @@ use super::execution::{
|
|||||||
AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt,
|
AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt,
|
||||||
select_provider_config, should_display_message_to_user,
|
select_provider_config, should_display_message_to_user,
|
||||||
};
|
};
|
||||||
use super::prompt::{load_agent_prompt, upsert_managed_agent_memory_summary};
|
#[cfg(test)]
|
||||||
|
use super::memory_maintenance::{
|
||||||
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
|
||||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error,
|
||||||
|
strip_json_code_fence,
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
};
|
||||||
enum MemoryMaintenanceCategory {
|
use super::memory_maintenance::{
|
||||||
UserFacts,
|
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
|
||||||
Preferences,
|
};
|
||||||
BehaviorPatterns,
|
use super::prompt::load_agent_prompt;
|
||||||
Other,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub(crate) struct MemoryMaintenanceCandidate {
|
|
||||||
id: String,
|
|
||||||
namespace: String,
|
|
||||||
key: String,
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub(crate) struct MemoryMaintenancePlan {
|
|
||||||
user_facts: Vec<MemoryMaintenanceCandidate>,
|
|
||||||
preferences: Vec<MemoryMaintenanceCandidate>,
|
|
||||||
behavior_patterns: Vec<MemoryMaintenanceCandidate>,
|
|
||||||
others: Vec<MemoryMaintenanceCandidate>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub(crate) struct MemoryMaintenanceMerge {
|
|
||||||
pub(crate) source_ids: Vec<String>,
|
|
||||||
pub(crate) namespace: String,
|
|
||||||
pub(crate) memory_key: String,
|
|
||||||
pub(crate) content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub(crate) struct MemoryMaintenanceConflict {
|
|
||||||
pub(crate) source_ids: Vec<String>,
|
|
||||||
pub(crate) note: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub(crate) struct MemoryMaintenanceModelOutput {
|
|
||||||
pub(crate) user_facts: Vec<String>,
|
|
||||||
pub(crate) preferences: Vec<String>,
|
|
||||||
pub(crate) behavior_patterns: Vec<String>,
|
|
||||||
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
|
|
||||||
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
|
|
||||||
pub(crate) low_value_ids: Vec<String>,
|
|
||||||
pub(crate) managed_markdown: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub(crate) struct MemoryMaintenanceScopeResult {
|
|
||||||
pub(crate) scope_key: String,
|
|
||||||
pub(crate) output: MemoryMaintenanceModelOutput,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_memory_maintenance_plan(
|
|
||||||
memories: &[crate::storage::MemoryRecord],
|
|
||||||
) -> MemoryMaintenancePlan {
|
|
||||||
let mut plan = MemoryMaintenancePlan::default();
|
|
||||||
let mut seen = HashSet::new();
|
|
||||||
|
|
||||||
for memory in memories {
|
|
||||||
let normalized_content = memory.content.trim();
|
|
||||||
if normalized_content.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let dedupe_key = format!(
|
|
||||||
"{}\u{1f}{}\u{1f}{}",
|
|
||||||
memory.namespace.trim().to_ascii_lowercase(),
|
|
||||||
memory.memory_key.trim().to_ascii_lowercase(),
|
|
||||||
normalized_content
|
|
||||||
);
|
|
||||||
if !seen.insert(dedupe_key) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let candidate = MemoryMaintenanceCandidate {
|
|
||||||
id: memory.id.clone(),
|
|
||||||
namespace: memory.namespace.clone(),
|
|
||||||
key: memory.memory_key.clone(),
|
|
||||||
content: normalized_content.to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
match memory_maintenance_category(&memory.namespace) {
|
|
||||||
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
|
|
||||||
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
|
|
||||||
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
|
|
||||||
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
plan
|
|
||||||
}
|
|
||||||
|
|
||||||
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
|
|
||||||
match namespace.trim().to_ascii_lowercase().as_str() {
|
|
||||||
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
|
|
||||||
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
|
|
||||||
"patterns" | "behavior" | "habits" | "workflow" => {
|
|
||||||
MemoryMaintenanceCategory::BehaviorPatterns
|
|
||||||
}
|
|
||||||
_ => MemoryMaintenanceCategory::Other,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Reports whether an error message looks like a transient failure.
///
/// The message is ASCII-lowercased and searched for any of a fixed set of
/// markers (connection failures, HTTP 504, and timeout wording).
fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
    const TRANSIENT_MARKERS: [&str; 6] = [
        "error sending request for url",
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];

    let lowered = error.to_ascii_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
}
|
|
||||||
|
|
||||||
/// Removes a surrounding Markdown code fence from `content`, if there is one.
///
/// The input is trimmed first. When it opens with three backticks, the fence
/// (plus an optional `json` language tag in any ASCII case) is removed, a
/// trailing three-backtick fence is removed when present, and the remainder is
/// trimmed again. Input without an opening fence is returned trimmed.
///
/// An opening fence without a closing one is still stripped, so a truncated
/// reply does not keep its fence marker in front of the payload.
fn strip_json_code_fence(content: &str) -> &str {
    const FENCE: &str = "```";
    const JSON_TAG: &str = "json";

    let trimmed = content.trim();
    let Some(after_open) = trimmed.strip_prefix(FENCE) else {
        return trimmed;
    };

    // `get` returns `None` when the tag-length prefix is missing or would
    // split a multi-byte character, so the slice below cannot panic.
    let body = match after_open.get(..JSON_TAG.len()) {
        Some(tag) if tag.eq_ignore_ascii_case(JSON_TAG) => &after_open[JSON_TAG.len()..],
        _ => after_open,
    };

    body.strip_suffix(FENCE).unwrap_or(body).trim()
}
|
|
||||||
|
|
||||||
/// Returns the first balanced `{ ... }` span found in `content`.
///
/// Braces are counted only outside double-quoted strings, and a backslash
/// inside a string escapes the character after it. A `}` seen while no object
/// is open is ignored. Returns `None` when no object is opened or the first
/// opened object is never closed.
fn extract_json_object(content: &str) -> Option<&str> {
    let mut object_start: Option<usize> = None;
    let mut open_braces = 0usize;
    let mut inside_string = false;
    let mut skip_next = false;

    for (offset, current) in content.char_indices() {
        if inside_string {
            if skip_next {
                // This character was escaped by the preceding backslash.
                skip_next = false;
            } else if current == '\\' {
                skip_next = true;
            } else if current == '"' {
                inside_string = false;
            }
            continue;
        }

        if current == '"' {
            inside_string = true;
        } else if current == '{' {
            // Remember only the outermost opening brace.
            object_start = object_start.or(Some(offset));
            open_braces += 1;
        } else if current == '}' && open_braces > 0 {
            open_braces -= 1;
            if open_braces == 0 {
                let begin = object_start?;
                let finish = offset + current.len_utf8();
                return Some(content[begin..finish].trim());
            }
        }
    }

    None
}
|
|
||||||
|
|
||||||
fn preview_text(content: &str, max_chars: usize) -> String {
|
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||||
let mut preview = content.chars().take(max_chars).collect::<String>();
|
let mut preview = content.chars().take(max_chars).collect::<String>();
|
||||||
@ -225,138 +56,6 @@ fn enrich_user_content_with_media_refs(
|
|||||||
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
|
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Merges per-scope markdown chunks into one document.
///
/// Each chunk is trimmed and empty chunks are discarded. A chunk is left out
/// when its set of non-empty, trimmed lines is a strict subset of another
/// chunk's line set, or when an identical chunk was already kept. The
/// surviving chunks are joined, in input order, with a blank line between them.
fn combine_managed_memory_markdown(chunks: &[String]) -> String {
    let normalized: Vec<&str> = chunks
        .iter()
        .map(|chunk| chunk.trim())
        .filter(|chunk| !chunk.is_empty())
        .collect();

    // Build every chunk's line set once instead of once per pairwise comparison.
    let line_sets: Vec<HashSet<&str>> = normalized
        .iter()
        .map(|chunk| {
            chunk
                .lines()
                .map(str::trim)
                .filter(|line| !line.is_empty())
                .collect()
        })
        .collect();

    let mut combined: Vec<&str> = Vec::with_capacity(normalized.len());
    for (index, (chunk, lines)) in normalized.iter().zip(&line_sets).enumerate() {
        let covered_by_other = line_sets.iter().enumerate().any(|(other_index, other)| {
            other_index != index && lines.len() < other.len() && lines.is_subset(other)
        });

        if !covered_by_other && !combined.contains(chunk) {
            combined.push(*chunk);
        }
    }

    combined.join("\n\n")
}
|
|
||||||
|
|
||||||
fn apply_memory_maintenance_output(
|
|
||||||
store: &SessionStore,
|
|
||||||
scope_key: &str,
|
|
||||||
plan: &MemoryMaintenancePlan,
|
|
||||||
output: &MemoryMaintenanceModelOutput,
|
|
||||||
) -> Result<(), AgentError> {
|
|
||||||
let all_candidates = plan
|
|
||||||
.user_facts
|
|
||||||
.iter()
|
|
||||||
.chain(plan.preferences.iter())
|
|
||||||
.chain(plan.behavior_patterns.iter())
|
|
||||||
.chain(plan.others.iter())
|
|
||||||
.cloned()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let candidates_by_id = all_candidates
|
|
||||||
.iter()
|
|
||||||
.map(|candidate| (candidate.id.as_str(), candidate))
|
|
||||||
.collect::<HashMap<_, _>>();
|
|
||||||
|
|
||||||
let mut deleted_ids = HashSet::new();
|
|
||||||
|
|
||||||
for merge in &output.merges {
|
|
||||||
if merge.source_ids.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let source_candidates = merge
|
|
||||||
.source_ids
|
|
||||||
.iter()
|
|
||||||
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
if source_candidates.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let existing_target_id = source_candidates
|
|
||||||
.iter()
|
|
||||||
.find(|candidate| {
|
|
||||||
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
|
|
||||||
})
|
|
||||||
.map(|candidate| candidate.id.clone());
|
|
||||||
|
|
||||||
store
|
|
||||||
.put_memory(&crate::storage::MemoryUpsert {
|
|
||||||
scope_kind: "user".to_string(),
|
|
||||||
scope_key: scope_key.to_string(),
|
|
||||||
namespace: merge.namespace.trim().to_string(),
|
|
||||||
memory_key: merge.memory_key.trim().to_string(),
|
|
||||||
content: merge.content.trim().to_string(),
|
|
||||||
source_type: "memory_maintenance".to_string(),
|
|
||||||
source_session_id: None,
|
|
||||||
source_message_id: None,
|
|
||||||
source_message_seq: None,
|
|
||||||
source_channel_name: None,
|
|
||||||
source_chat_id: None,
|
|
||||||
})
|
|
||||||
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
|
|
||||||
|
|
||||||
for candidate in source_candidates {
|
|
||||||
if existing_target_id
|
|
||||||
.as_ref()
|
|
||||||
.is_some_and(|target_id| target_id == &candidate.id)
|
|
||||||
{
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if deleted_ids.insert(candidate.id.clone()) {
|
|
||||||
store
|
|
||||||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
|
||||||
.map_err(|err| {
|
|
||||||
AgentError::Other(format!("delete merged source memory error: {}", err))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for memory_id in &output.low_value_ids {
|
|
||||||
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
|
|
||||||
if deleted_ids.insert(candidate.id.clone()) {
|
|
||||||
store
|
|
||||||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
|
||||||
.map_err(|err| {
|
|
||||||
AgentError::Other(format!("delete low value memory error: {}", err))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
@ -609,6 +308,7 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
|
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
|
||||||
self.latest_user_message(chat_id)
|
self.latest_user_message(chat_id)
|
||||||
.map(|message| message.id.as_str())
|
.map(|message| message.id.as_str())
|
||||||
@ -619,6 +319,7 @@ impl Session {
|
|||||||
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
|
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
|
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
|
||||||
self.latest_user_message_id(chat_id)
|
self.latest_user_message_id(chat_id)
|
||||||
.map(|current_id| current_id == message_id)
|
.map(|current_id| current_id == message_id)
|
||||||
@ -1014,194 +715,30 @@ impl SessionManager {
|
|||||||
self.skills.clone()
|
self.skills.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn build_memory_maintenance_plan_for_scope(
|
|
||||||
&self,
|
|
||||||
scope_key: &str,
|
|
||||||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
|
||||||
let memories = self
|
|
||||||
.store
|
|
||||||
.list_memories_for_scope("user", scope_key)
|
|
||||||
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
|
|
||||||
|
|
||||||
if memories.is_empty() {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Some(build_memory_maintenance_plan(&memories)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn upsert_managed_agent_memory_summary(
|
|
||||||
&self,
|
|
||||||
markdown_body: &str,
|
|
||||||
) -> Result<(), AgentError> {
|
|
||||||
upsert_managed_agent_memory_summary(markdown_body)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(not(test), allow(dead_code))]
|
#[cfg_attr(not(test), allow(dead_code))]
|
||||||
pub(crate) async fn summarize_memory_maintenance_for_scope(
|
pub(crate) async fn summarize_memory_maintenance_for_scope(
|
||||||
&self,
|
&self,
|
||||||
scope_key: &str,
|
scope_key: &str,
|
||||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||||
let Some(plan) = self.build_memory_maintenance_plan_for_scope(scope_key)? else {
|
self.memory_maintenance_service()?
|
||||||
return Ok(None);
|
.summarize_for_scope(scope_key)
|
||||||
};
|
|
||||||
|
|
||||||
self.summarize_memory_maintenance_plan(scope_key, &plan)
|
|
||||||
.await
|
.await
|
||||||
.map(Some)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn summarize_memory_maintenance_plan(
|
|
||||||
&self,
|
|
||||||
scope_key: &str,
|
|
||||||
plan: &MemoryMaintenancePlan,
|
|
||||||
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
|
|
||||||
let provider_config = self.provider_config_for_agent(None)?;
|
|
||||||
let provider = create_provider(provider_config).map_err(|err| {
|
|
||||||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages: vec![
|
|
||||||
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
|
|
||||||
Message::user(
|
|
||||||
serde_json::to_string_pretty(&serde_json::json!({
|
|
||||||
"scope_key": scope_key,
|
|
||||||
"candidates": plan,
|
|
||||||
}))
|
|
||||||
.unwrap_or_else(|_| "{}".to_string()),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(1200),
|
|
||||||
tools: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut last_error = None;
|
|
||||||
let mut response = None;
|
|
||||||
|
|
||||||
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
|
||||||
.iter()
|
|
||||||
.copied()
|
|
||||||
.map(Some)
|
|
||||||
.chain(std::iter::once(None))
|
|
||||||
.enumerate()
|
|
||||||
{
|
|
||||||
match provider.chat(request.clone()).await {
|
|
||||||
Ok(success) => {
|
|
||||||
response = Some(success);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
let error_text = err.to_string();
|
|
||||||
let should_retry =
|
|
||||||
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
|
||||||
last_error = Some(error_text.clone());
|
|
||||||
|
|
||||||
if should_retry {
|
|
||||||
tracing::warn!(
|
|
||||||
scope_key = %scope_key,
|
|
||||||
attempt = attempt + 1,
|
|
||||||
retry_in_ms = delay_ms.unwrap_or_default(),
|
|
||||||
error = %error_text,
|
|
||||||
"Memory maintenance model request failed, retrying"
|
|
||||||
);
|
|
||||||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
|
||||||
.await;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Err(AgentError::Other(format!(
|
|
||||||
"memory maintenance model error: {}",
|
|
||||||
error_text
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = response.ok_or_else(|| {
|
|
||||||
AgentError::Other(format!(
|
|
||||||
"memory maintenance model error: {}",
|
|
||||||
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let raw_content = strip_json_code_fence(&response.content);
|
|
||||||
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
|
|
||||||
|
|
||||||
let output: MemoryMaintenanceModelOutput =
|
|
||||||
serde_json::from_str(json_candidate).map_err(|err| {
|
|
||||||
tracing::error!(
|
|
||||||
scope_key = %scope_key,
|
|
||||||
error = %err,
|
|
||||||
raw_len = raw_content.len(),
|
|
||||||
raw_preview = %preview_text(raw_content, 400),
|
|
||||||
json_candidate_len = json_candidate.len(),
|
|
||||||
json_candidate_preview = %preview_text(json_candidate, 400),
|
|
||||||
"Memory maintenance JSON decode failed"
|
|
||||||
);
|
|
||||||
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(output)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn run_memory_maintenance_for_scope(
|
|
||||||
&self,
|
|
||||||
scope_key: &str,
|
|
||||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
|
||||||
let Some(plan) = self.build_memory_maintenance_plan_for_scope(scope_key)? else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
|
|
||||||
let output = self
|
|
||||||
.summarize_memory_maintenance_plan(scope_key, &plan)
|
|
||||||
.await?;
|
|
||||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
|
|
||||||
|
|
||||||
Ok(Some(output))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn run_memory_maintenance_for_all_scopes(
|
pub(crate) async fn run_memory_maintenance_for_all_scopes(
|
||||||
&self,
|
&self,
|
||||||
updated_since: Option<i64>,
|
updated_since: Option<i64>,
|
||||||
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||||
let scope_keys = if let Some(cutoff) = updated_since {
|
self.memory_maintenance_service()?
|
||||||
self.store
|
.run_for_all_scopes(updated_since)
|
||||||
.list_memory_scope_keys_updated_since("user", cutoff)
|
.await
|
||||||
.map_err(|err| {
|
}
|
||||||
AgentError::Other(format!(
|
|
||||||
"list memory scope keys updated since error: {}",
|
|
||||||
err
|
|
||||||
))
|
|
||||||
})?
|
|
||||||
} else {
|
|
||||||
self.store.list_memory_scope_keys("user").map_err(|err| {
|
|
||||||
AgentError::Other(format!("list memory scope keys error: {}", err))
|
|
||||||
})?
|
|
||||||
};
|
|
||||||
let mut results = Vec::new();
|
|
||||||
|
|
||||||
for scope_key in scope_keys {
|
fn memory_maintenance_service(&self) -> Result<MemoryMaintenanceService, AgentError> {
|
||||||
let Some(output) = self.run_memory_maintenance_for_scope(&scope_key).await? else {
|
Ok(MemoryMaintenanceService::new(
|
||||||
continue;
|
self.store.clone(),
|
||||||
};
|
self.provider_config_for_agent(None)?,
|
||||||
|
))
|
||||||
results.push(MemoryMaintenanceScopeResult { scope_key, output });
|
|
||||||
}
|
|
||||||
|
|
||||||
let combined_markdown = combine_managed_memory_markdown(
|
|
||||||
&results
|
|
||||||
.iter()
|
|
||||||
.map(|result| result.output.managed_markdown.clone())
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if !combined_markdown.is_empty() {
|
|
||||||
self.upsert_managed_agent_memory_summary(&combined_markdown)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(results)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn provider_config_for_agent(
|
pub fn provider_config_for_agent(
|
||||||
@ -1572,6 +1109,7 @@ mod tests {
|
|||||||
model_id: "test-model".to_string(),
|
model_id: "test-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(32),
|
max_tokens: Some(32),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 1,
|
max_tool_iterations: 1,
|
||||||
tool_result_max_chars: 20_000,
|
tool_result_max_chars: 20_000,
|
||||||
@ -1833,6 +1371,7 @@ mod tests {
|
|||||||
model_id: "timeout-model".to_string(),
|
model_id: "timeout-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(32),
|
max_tokens: Some(32),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 1,
|
max_tool_iterations: 1,
|
||||||
llm_timeout_secs: 30,
|
llm_timeout_secs: 30,
|
||||||
@ -1872,6 +1411,7 @@ mod tests {
|
|||||||
model_id: "default-model".to_string(),
|
model_id: "default-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(32),
|
max_tokens: Some(32),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 1,
|
max_tool_iterations: 1,
|
||||||
llm_timeout_secs: 30,
|
llm_timeout_secs: 30,
|
||||||
@ -1943,6 +1483,7 @@ mod tests {
|
|||||||
model_id: "default-model".to_string(),
|
model_id: "default-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(32),
|
max_tokens: Some(32),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 1,
|
max_tool_iterations: 1,
|
||||||
llm_timeout_secs: 30,
|
llm_timeout_secs: 30,
|
||||||
@ -2020,6 +1561,7 @@ mod tests {
|
|||||||
model_id: "maintenance-model".to_string(),
|
model_id: "maintenance-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(256),
|
max_tokens: Some(256),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::from([(
|
model_extra: HashMap::from([(
|
||||||
"mock_response_content".to_string(),
|
"mock_response_content".to_string(),
|
||||||
json!(mock_response_content),
|
json!(mock_response_content),
|
||||||
@ -2120,6 +1662,7 @@ mod tests {
|
|||||||
model_id: "maintenance-model".to_string(),
|
model_id: "maintenance-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(256),
|
max_tokens: Some(256),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::from([(
|
model_extra: HashMap::from([(
|
||||||
"mock_response_content".to_string(),
|
"mock_response_content".to_string(),
|
||||||
json!(mock_response_content),
|
json!(mock_response_content),
|
||||||
@ -2182,6 +1725,7 @@ mod tests {
|
|||||||
model_id: "maintenance-model".to_string(),
|
model_id: "maintenance-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(256),
|
max_tokens: Some(256),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::from([(
|
model_extra: HashMap::from([(
|
||||||
"mock_response_content".to_string(),
|
"mock_response_content".to_string(),
|
||||||
json!(mock_response_content),
|
json!(mock_response_content),
|
||||||
@ -2241,6 +1785,7 @@ mod tests {
|
|||||||
model_id: "maintenance-model".to_string(),
|
model_id: "maintenance-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(256),
|
max_tokens: Some(256),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 1,
|
max_tool_iterations: 1,
|
||||||
llm_timeout_secs: 30,
|
llm_timeout_secs: 30,
|
||||||
|
|||||||
@ -980,6 +980,7 @@ mod tests {
|
|||||||
model_id: "test-model".to_string(),
|
model_id: "test-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 4,
|
max_tool_iterations: 4,
|
||||||
tool_result_max_chars: 20_000,
|
tool_result_max_chars: 20_000,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ fn load_config() -> Option<LLMProviderConfig> {
|
|||||||
model_id: openai_model,
|
model_id: openai_model,
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 20,
|
max_tool_iterations: 20,
|
||||||
tool_result_max_chars: 20_000,
|
tool_result_max_chars: 20_000,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
|||||||
model_id: openai_model,
|
model_id: openai_model,
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 20,
|
max_tool_iterations: 20,
|
||||||
tool_result_max_chars: 20_000,
|
tool_result_max_chars: 20_000,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user