feat: add context_window_tokens to model configuration and update related logic
- Introduced context_window_tokens in ModelConfig and LLMProviderConfig structs. - Updated context window estimation logic in ContextCompressor to use context_window_tokens. - Modified tests to accommodate new context_window_tokens field. - Refactored memory maintenance logic into a new memory_maintenance.rs file for better organization. - Ensured backward compatibility by providing default values where necessary. Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
b2c8d76820
commit
fa3354db9c
@ -134,7 +134,7 @@ PicoBot 会在 ~/.picobot/agent/AGENT.md 维护一份持久化 Agent 画像文
|
||||
1. 系统先对当前活动历史做一个近似 token 估算。
|
||||
估算规则不是调用 tokenizer,而是按“约每 4 个字符约等于 1 token,并再乘以 1.2 安全系数”计算。
|
||||
2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。
|
||||
这里的上下文窗口来自 agent 对应模型配置里的 token_limit。
|
||||
这里的上下文窗口来自 agent 对应模型配置里的 context_window_tokens;未配置时按 128000 估算。
|
||||
3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。
|
||||
当前默认会完整保留最近 3 个 user turn。
|
||||
4. 一旦满足条件,压缩器会先按 user 消息切分 turn,再确定“旧历史”和“最近保留段”的分界点。
|
||||
|
||||
@ -60,6 +60,7 @@ pub struct ContextCompressor {
|
||||
}
|
||||
|
||||
impl ContextCompressor {
|
||||
#[cfg(test)]
|
||||
fn summary_char_budget_for_context_window(context_window: usize) -> usize {
|
||||
const SUMMARY_RATIO: f64 = 0.1;
|
||||
const CHARS_PER_TOKEN: f64 = 2.5;
|
||||
|
||||
@ -159,6 +159,8 @@ pub struct ModelConfig {
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(default)]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub context_window_tokens: Option<u32>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
@ -526,6 +528,7 @@ pub struct LLMProviderConfig {
|
||||
pub model_id: String,
|
||||
pub temperature: Option<f32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub context_window_tokens: Option<u32>,
|
||||
pub model_extra: HashMap<String, serde_json::Value>,
|
||||
pub max_tool_iterations: usize,
|
||||
pub tool_result_max_chars: usize,
|
||||
@ -534,7 +537,7 @@ pub struct LLMProviderConfig {
|
||||
|
||||
impl LLMProviderConfig {
|
||||
pub fn context_window_tokens(&self) -> usize {
|
||||
self.max_tokens
|
||||
self.context_window_tokens
|
||||
.map(|value| value as usize)
|
||||
.unwrap_or(128_000)
|
||||
}
|
||||
@ -614,6 +617,7 @@ impl Config {
|
||||
model_id: model.model_id.clone(),
|
||||
temperature: model.temperature,
|
||||
max_tokens: model.max_tokens,
|
||||
context_window_tokens: model.context_window_tokens,
|
||||
model_extra: model.extra.clone(),
|
||||
max_tool_iterations: agent.max_tool_iterations,
|
||||
tool_result_max_chars: agent.tool_result_max_chars,
|
||||
@ -1056,7 +1060,44 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_config_summary_budget_scales_with_model_max_tokens() {
|
||||
fn test_provider_config_summary_budget_scales_with_context_window_tokens() {
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
r#"{
|
||||
"providers": {
|
||||
"aliyun": {
|
||||
"type": "openai",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"api_key": "test-key",
|
||||
"extra_headers": {}
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"qwen-plus": {
|
||||
"model_id": "qwen-plus",
|
||||
"context_window_tokens": 4096
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"default": {
|
||||
"provider": "aliyun",
|
||||
"model": "qwen-plus"
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
let provider_config = config.get_provider_config("default").unwrap();
|
||||
|
||||
assert_eq!(provider_config.context_window_tokens(), 4096);
|
||||
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_config_max_tokens_does_not_change_context_window() {
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
@ -1088,8 +1129,9 @@ mod tests {
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
let provider_config = config.get_provider_config("default").unwrap();
|
||||
|
||||
assert_eq!(provider_config.context_window_tokens(), 4096);
|
||||
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
|
||||
assert_eq!(provider_config.max_tokens, Some(4096));
|
||||
assert_eq!(provider_config.context_window_tokens(), 128_000);
|
||||
assert_eq!(provider_config.context_summary_char_budget(), 32_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -181,6 +181,7 @@ mod tests {
|
||||
model_id: model_id.to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
|
||||
517
src/gateway/memory_maintenance.rs
Normal file
517
src/gateway/memory_maintenance.rs
Normal file
@ -0,0 +1,517 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{ChatCompletionRequest, Message, create_provider};
|
||||
use crate::storage::{MemoryRecord, SessionStore};
|
||||
|
||||
use super::prompt::upsert_managed_agent_memory_summary;
|
||||
|
||||
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||
|
||||
/// Bucket a stored memory is sorted into when a maintenance plan is built.
///
/// The bucket is chosen from the memory's namespace by
/// `memory_maintenance_category`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum MemoryMaintenanceCategory {
|
||||
// Namespaces `profile`, `facts`, `identity`.
UserFacts,
|
||||
// Namespaces `preferences`, `style`, `likes`.
Preferences,
|
||||
// Namespaces `patterns`, `behavior`, `habits`, `workflow`.
BehaviorPatterns,
|
||||
// Any namespace not listed above.
Other,
|
||||
}
|
||||
|
||||
/// One stored memory offered to the maintenance model as a candidate.
///
/// Built from a `MemoryRecord` by `build_memory_maintenance_plan`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceCandidate {
|
||||
// Copied from the record's `id`; model output refers back to candidates by this value.
pub(crate) id: String,
|
||||
// Copied from the record's `namespace` as stored (not trimmed).
pub(crate) namespace: String,
|
||||
// Copied from the record's `memory_key` as stored (not trimmed).
pub(crate) key: String,
|
||||
// The record's content with surrounding whitespace trimmed.
pub(crate) content: String,
|
||||
}
|
||||
|
||||
/// Candidates for one scope, grouped by maintenance category.
///
/// Serialized as the `candidates` value of the request sent to the
/// maintenance model.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenancePlan {
|
||||
// Candidates whose namespace maps to `MemoryMaintenanceCategory::UserFacts`.
pub(crate) user_facts: Vec<MemoryMaintenanceCandidate>,
|
||||
// Candidates whose namespace maps to `MemoryMaintenanceCategory::Preferences`.
pub(crate) preferences: Vec<MemoryMaintenanceCandidate>,
|
||||
// Candidates whose namespace maps to `MemoryMaintenanceCategory::BehaviorPatterns`.
pub(crate) behavior_patterns: Vec<MemoryMaintenanceCandidate>,
|
||||
// Candidates from every other namespace.
pub(crate) others: Vec<MemoryMaintenanceCandidate>,
|
||||
}
|
||||
|
||||
/// A merge instruction decoded from the maintenance model's reply.
///
/// Applied by `apply_memory_maintenance_output`: the merged content is
/// written under `namespace`/`memory_key` and the source memories are deleted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceMerge {
|
||||
// Candidate ids to merge; ids that match no plan candidate are ignored when applied.
pub(crate) source_ids: Vec<String>,
|
||||
// Target namespace; trimmed before it is written.
pub(crate) namespace: String,
|
||||
// Target memory key; trimmed before it is written.
pub(crate) memory_key: String,
|
||||
// Merged content; trimmed before it is written.
pub(crate) content: String,
|
||||
}
|
||||
|
||||
/// A conflict entry decoded from the maintenance model's reply.
///
/// NOTE(review): nothing in this module reads these entries after decoding;
/// presumably they are informational for callers — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceConflict {
|
||||
// Presumably the ids of the candidates that disagree — confirm against the system prompt.
pub(crate) source_ids: Vec<String>,
|
||||
// Free-text note supplied by the model.
pub(crate) note: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceModelOutput {
|
||||
pub(crate) user_facts: Vec<String>,
|
||||
pub(crate) preferences: Vec<String>,
|
||||
pub(crate) behavior_patterns: Vec<String>,
|
||||
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
|
||||
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
|
||||
pub(crate) low_value_ids: Vec<String>,
|
||||
pub(crate) managed_markdown: String,
|
||||
}
|
||||
|
||||
/// Maintenance result for one scope, as collected by
/// `MemoryMaintenanceService::run_for_all_scopes`.
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceScopeResult {
|
||||
// Scope key the maintenance pass ran for.
pub(crate) scope_key: String,
|
||||
// Decoded model output for that scope.
pub(crate) output: MemoryMaintenanceModelOutput,
|
||||
}
|
||||
|
||||
/// Runs memory maintenance: loads a scope's memories, asks a model to
/// consolidate them, and applies the result to the store.
pub(crate) struct MemoryMaintenanceService {
|
||||
// Store the memories are listed from and written back to.
store: Arc<SessionStore>,
|
||||
// Configuration passed to `create_provider` for each summarization request.
provider_config: LLMProviderConfig,
|
||||
}
|
||||
|
||||
impl MemoryMaintenanceService {
|
||||
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||
Self {
|
||||
store,
|
||||
provider_config,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_plan_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||
let memories = self
|
||||
.store
|
||||
.list_memories_for_scope("user", scope_key)
|
||||
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
|
||||
|
||||
if memories.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(build_memory_maintenance_plan(&memories)))
|
||||
}
|
||||
|
||||
pub(crate) async fn summarize_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
self.summarize_plan(scope_key, &plan).await.map(Some)
|
||||
}
|
||||
|
||||
async fn summarize_plan(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
|
||||
let provider = create_provider(self.provider_config.clone()).map_err(|err| {
|
||||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||||
})?;
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![
|
||||
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
|
||||
Message::user(
|
||||
serde_json::to_string_pretty(&serde_json::json!({
|
||||
"scope_key": scope_key,
|
||||
"candidates": plan,
|
||||
}))
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
],
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(1200),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let mut last_error = None;
|
||||
let mut response = None;
|
||||
|
||||
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||||
.iter()
|
||||
.copied()
|
||||
.map(Some)
|
||||
.chain(std::iter::once(None))
|
||||
.enumerate()
|
||||
{
|
||||
match provider.chat(request.clone()).await {
|
||||
Ok(success) => {
|
||||
response = Some(success);
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
let error_text = err.to_string();
|
||||
let should_retry =
|
||||
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||||
last_error = Some(error_text.clone());
|
||||
|
||||
if should_retry {
|
||||
tracing::warn!(
|
||||
scope_key = %scope_key,
|
||||
attempt = attempt + 1,
|
||||
retry_in_ms = delay_ms.unwrap_or_default(),
|
||||
error = %error_text,
|
||||
"Memory maintenance model request failed, retrying"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(AgentError::Other(format!(
|
||||
"memory maintenance model error: {}",
|
||||
error_text
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = response.ok_or_else(|| {
|
||||
AgentError::Other(format!(
|
||||
"memory maintenance model error: {}",
|
||||
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
||||
))
|
||||
})?;
|
||||
|
||||
let raw_content = strip_json_code_fence(&response.content);
|
||||
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
|
||||
|
||||
let output: MemoryMaintenanceModelOutput =
|
||||
serde_json::from_str(json_candidate).map_err(|err| {
|
||||
tracing::error!(
|
||||
scope_key = %scope_key,
|
||||
error = %err,
|
||||
raw_len = raw_content.len(),
|
||||
raw_preview = %preview_text(raw_content, 400),
|
||||
json_candidate_len = json_candidate.len(),
|
||||
json_candidate_preview = %preview_text(json_candidate, 400),
|
||||
"Memory maintenance JSON decode failed"
|
||||
);
|
||||
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
|
||||
})?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub(crate) async fn run_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let output = self.summarize_plan(scope_key, &plan).await?;
|
||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
|
||||
|
||||
Ok(Some(output))
|
||||
}
|
||||
|
||||
pub(crate) async fn run_for_all_scopes(
|
||||
&self,
|
||||
updated_since: Option<i64>,
|
||||
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||
let scope_keys = if let Some(cutoff) = updated_since {
|
||||
self.store
|
||||
.list_memory_scope_keys_updated_since("user", cutoff)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!(
|
||||
"list memory scope keys updated since error: {}",
|
||||
err
|
||||
))
|
||||
})?
|
||||
} else {
|
||||
self.store.list_memory_scope_keys("user").map_err(|err| {
|
||||
AgentError::Other(format!("list memory scope keys error: {}", err))
|
||||
})?
|
||||
};
|
||||
let mut results = Vec::new();
|
||||
|
||||
for scope_key in scope_keys {
|
||||
let Some(output) = self.run_for_scope(&scope_key).await? else {
|
||||
continue;
|
||||
};
|
||||
|
||||
results.push(MemoryMaintenanceScopeResult { scope_key, output });
|
||||
}
|
||||
|
||||
let combined_markdown = combine_managed_memory_markdown(
|
||||
&results
|
||||
.iter()
|
||||
.map(|result| result.output.managed_markdown.clone())
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
if !combined_markdown.is_empty() {
|
||||
upsert_managed_agent_memory_summary(&combined_markdown)?;
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||||
let mut plan = MemoryMaintenancePlan::default();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
for memory in memories {
|
||||
let normalized_content = memory.content.trim();
|
||||
if normalized_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let dedupe_key = format!(
|
||||
"{}\u{1f}{}\u{1f}{}",
|
||||
memory.namespace.trim().to_ascii_lowercase(),
|
||||
memory.memory_key.trim().to_ascii_lowercase(),
|
||||
normalized_content
|
||||
);
|
||||
if !seen.insert(dedupe_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate = MemoryMaintenanceCandidate {
|
||||
id: memory.id.clone(),
|
||||
namespace: memory.namespace.clone(),
|
||||
key: memory.memory_key.clone(),
|
||||
content: normalized_content.to_string(),
|
||||
};
|
||||
|
||||
match memory_maintenance_category(&memory.namespace) {
|
||||
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
|
||||
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
|
||||
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
|
||||
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
|
||||
}
|
||||
}
|
||||
|
||||
plan
|
||||
}
|
||||
|
||||
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
|
||||
match namespace.trim().to_ascii_lowercase().as_str() {
|
||||
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
|
||||
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
|
||||
"patterns" | "behavior" | "habits" | "workflow" => {
|
||||
MemoryMaintenanceCategory::BehaviorPatterns
|
||||
}
|
||||
_ => MemoryMaintenanceCategory::Other,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports whether a provider error message looks transient and worth retrying.
///
/// The message is matched ASCII-case-insensitively against a fixed list of
/// substrings covering request-send failures, HTTP 504 and timeouts.
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
    const TRANSIENT_MARKERS: [&str; 6] = [
        "error sending request for url",
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];

    let haystack = error.to_ascii_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|marker| haystack.contains(*marker))
}
|
||||
|
||||
/// Removes a surrounding Markdown code fence from a model reply.
///
/// The input is trimmed first. If it opens with "```json" or "```" and closes
/// with "```", the trimmed text between the fences is returned. In every
/// other case, including an opening fence with no closing one, the trimmed
/// input is returned unchanged.
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
    let trimmed = content.trim();

    let after_opening_fence = trimmed
        .strip_prefix("```json")
        .or_else(|| trimmed.strip_prefix("```"));

    match after_opening_fence.and_then(|rest| rest.strip_suffix("```")) {
        Some(inner) => inner.trim(),
        None => trimmed,
    }
}
|
||||
|
||||
/// Returns the first balanced `{ ... }` object found in `content`, or `None`
/// when there is none.
///
/// Braces inside double-quoted strings are ignored, and backslash escapes
/// inside strings are honoured. If the scan over the whole text finds
/// nothing, for example because an unbalanced `"` in prose before the object
/// left the scanner in string mode, the scan is retried from the first `{`.
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
    scan_balanced_json_object(content).or_else(|| {
        // A brace at offset 0 would just repeat the scan that already failed.
        let first_brace = content.find('{').filter(|&index| index > 0)?;
        scan_balanced_json_object(&content[first_brace..])
    })
}

/// Scans `content` from its start and returns the first balanced object,
/// tracking double-quoted strings so braces inside them are not counted.
fn scan_balanced_json_object(content: &str) -> Option<&str> {
    let mut start = None;
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (index, ch) in content.char_indices() {
        if in_string {
            match ch {
                // The character after a backslash is consumed verbatim.
                _ if escaped => escaped = false,
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => {
                if start.is_none() {
                    start = Some(index);
                }
                depth += 1;
            }
            // A closing brace with nothing open is stray text; skip it.
            '}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    let begin = start?;
                    return Some(&content[begin..index + ch.len_utf8()]);
                }
            }
            _ => {}
        }
    }

    None
}
|
||||
|
||||
/// Joins per-scope markdown summaries into one document.
///
/// Chunks are trimmed and blank ones dropped. A chunk is left out when its
/// set of non-empty trimmed lines is a strictly smaller subset of another
/// chunk's line set, or when an identical chunk was already kept. The
/// remaining chunks are joined with a blank line, in their original order.
pub(crate) fn combine_managed_memory_markdown(chunks: &[String]) -> String {
    let normalized_chunks: Vec<&str> = chunks
        .iter()
        .map(|chunk| chunk.trim())
        .filter(|chunk| !chunk.is_empty())
        .collect();

    // Build each chunk's line set once instead of once per pair of chunks.
    let line_sets: Vec<HashSet<&str>> = normalized_chunks
        .iter()
        .map(|chunk| {
            chunk
                .lines()
                .map(str::trim)
                .filter(|line| !line.is_empty())
                .collect()
        })
        .collect();

    let mut combined: Vec<&str> = Vec::new();
    for (&chunk, lines) in normalized_chunks.iter().zip(&line_sets) {
        // The strict size check also rules out comparing a chunk with itself.
        let is_covered_by_other = line_sets
            .iter()
            .any(|other| lines.len() < other.len() && lines.is_subset(other));

        if !is_covered_by_other && !combined.contains(&chunk) {
            combined.push(chunk);
        }
    }

    combined.join("\n\n")
}
|
||||
|
||||
pub(crate) fn apply_memory_maintenance_output(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
output: &MemoryMaintenanceModelOutput,
|
||||
) -> Result<(), AgentError> {
|
||||
let all_candidates = plan
|
||||
.user_facts
|
||||
.iter()
|
||||
.chain(plan.preferences.iter())
|
||||
.chain(plan.behavior_patterns.iter())
|
||||
.chain(plan.others.iter())
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let candidates_by_id = all_candidates
|
||||
.iter()
|
||||
.map(|candidate| (candidate.id.as_str(), candidate))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let mut deleted_ids = HashSet::new();
|
||||
|
||||
for merge in &output.merges {
|
||||
if merge.source_ids.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let source_candidates = merge
|
||||
.source_ids
|
||||
.iter()
|
||||
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
|
||||
.collect::<Vec<_>>();
|
||||
if source_candidates.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let existing_target_id = source_candidates
|
||||
.iter()
|
||||
.find(|candidate| {
|
||||
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
|
||||
})
|
||||
.map(|candidate| candidate.id.clone());
|
||||
|
||||
store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
namespace: merge.namespace.trim().to_string(),
|
||||
memory_key: merge.memory_key.trim().to_string(),
|
||||
content: merge.content.trim().to_string(),
|
||||
source_type: "memory_maintenance".to_string(),
|
||||
source_session_id: None,
|
||||
source_message_id: None,
|
||||
source_message_seq: None,
|
||||
source_channel_name: None,
|
||||
source_chat_id: None,
|
||||
})
|
||||
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
|
||||
|
||||
for candidate in source_candidates {
|
||||
if existing_target_id
|
||||
.as_ref()
|
||||
.is_some_and(|target_id| target_id == &candidate.id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if deleted_ids.insert(candidate.id.clone()) {
|
||||
store
|
||||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("delete merged source memory error: {}", err))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for memory_id in &output.low_value_ids {
|
||||
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
|
||||
if deleted_ids.insert(candidate.id.clone()) {
|
||||
store
|
||||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("delete low value memory error: {}", err))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Builds a single-line preview of `content` for log output.
///
/// Keeps at most `max_chars` characters, appends `...` when anything was cut
/// off, and replaces each newline with a literal `\n`.
fn preview_text(content: &str, max_chars: usize) -> String {
    // `nth(max_chars)` yields the first character past the limit, if any, so
    // one pass both detects truncation and finds the byte offset to cut at.
    let preview = match content.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &content[..cut]),
        None => content.to_string(),
    };

    preview.replace('\n', "\\n")
}
|
||||
@ -1,5 +1,6 @@
|
||||
pub mod execution;
|
||||
pub mod http;
|
||||
pub mod memory_maintenance;
|
||||
pub mod processor;
|
||||
pub mod prompt;
|
||||
pub mod session;
|
||||
|
||||
@ -5,7 +5,6 @@ use crate::bus::{
|
||||
};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::providers::{ChatCompletionRequest, Message, create_provider};
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||
use crate::tools::{
|
||||
@ -14,7 +13,6 @@ use crate::tools::{
|
||||
TimeTool, ToolContext, ToolRegistry, WebFetchTool,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
@ -25,183 +23,16 @@ use super::execution::{
|
||||
AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt,
|
||||
select_provider_config, should_display_message_to_user,
|
||||
};
|
||||
use super::prompt::{load_agent_prompt, upsert_managed_agent_memory_summary};
|
||||
|
||||
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum MemoryMaintenanceCategory {
|
||||
UserFacts,
|
||||
Preferences,
|
||||
BehaviorPatterns,
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceCandidate {
|
||||
id: String,
|
||||
namespace: String,
|
||||
key: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenancePlan {
|
||||
user_facts: Vec<MemoryMaintenanceCandidate>,
|
||||
preferences: Vec<MemoryMaintenanceCandidate>,
|
||||
behavior_patterns: Vec<MemoryMaintenanceCandidate>,
|
||||
others: Vec<MemoryMaintenanceCandidate>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceMerge {
|
||||
pub(crate) source_ids: Vec<String>,
|
||||
pub(crate) namespace: String,
|
||||
pub(crate) memory_key: String,
|
||||
pub(crate) content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceConflict {
|
||||
pub(crate) source_ids: Vec<String>,
|
||||
pub(crate) note: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceModelOutput {
|
||||
pub(crate) user_facts: Vec<String>,
|
||||
pub(crate) preferences: Vec<String>,
|
||||
pub(crate) behavior_patterns: Vec<String>,
|
||||
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
|
||||
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
|
||||
pub(crate) low_value_ids: Vec<String>,
|
||||
pub(crate) managed_markdown: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceScopeResult {
|
||||
pub(crate) scope_key: String,
|
||||
pub(crate) output: MemoryMaintenanceModelOutput,
|
||||
}
|
||||
|
||||
fn build_memory_maintenance_plan(
|
||||
memories: &[crate::storage::MemoryRecord],
|
||||
) -> MemoryMaintenancePlan {
|
||||
let mut plan = MemoryMaintenancePlan::default();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
for memory in memories {
|
||||
let normalized_content = memory.content.trim();
|
||||
if normalized_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let dedupe_key = format!(
|
||||
"{}\u{1f}{}\u{1f}{}",
|
||||
memory.namespace.trim().to_ascii_lowercase(),
|
||||
memory.memory_key.trim().to_ascii_lowercase(),
|
||||
normalized_content
|
||||
);
|
||||
if !seen.insert(dedupe_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate = MemoryMaintenanceCandidate {
|
||||
id: memory.id.clone(),
|
||||
namespace: memory.namespace.clone(),
|
||||
key: memory.memory_key.clone(),
|
||||
content: normalized_content.to_string(),
|
||||
#[cfg(test)]
|
||||
use super::memory_maintenance::{
|
||||
MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
|
||||
combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error,
|
||||
strip_json_code_fence,
|
||||
};
|
||||
|
||||
match memory_maintenance_category(&memory.namespace) {
|
||||
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
|
||||
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
|
||||
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
|
||||
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
|
||||
}
|
||||
}
|
||||
|
||||
plan
|
||||
}
|
||||
|
||||
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
|
||||
match namespace.trim().to_ascii_lowercase().as_str() {
|
||||
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
|
||||
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
|
||||
"patterns" | "behavior" | "habits" | "workflow" => {
|
||||
MemoryMaintenanceCategory::BehaviorPatterns
|
||||
}
|
||||
_ => MemoryMaintenanceCategory::Other,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|
||||
let normalized = error.to_ascii_lowercase();
|
||||
normalized.contains("error sending request for url")
|
||||
|| normalized.contains("504")
|
||||
|| normalized.contains("gateway timeout")
|
||||
|| normalized.contains("stream timeout")
|
||||
|| normalized.contains("timed out")
|
||||
|| normalized.contains("timeout")
|
||||
}
|
||||
|
||||
fn strip_json_code_fence(content: &str) -> &str {
|
||||
let trimmed = content.trim();
|
||||
if let Some(rest) = trimmed.strip_prefix("```json") {
|
||||
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
|
||||
}
|
||||
if let Some(rest) = trimmed.strip_prefix("```") {
|
||||
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
|
||||
}
|
||||
trimmed
|
||||
}
|
||||
|
||||
fn extract_json_object(content: &str) -> Option<&str> {
|
||||
let mut start = None;
|
||||
let mut depth = 0usize;
|
||||
let mut in_string = false;
|
||||
let mut escaped = false;
|
||||
|
||||
for (index, ch) in content.char_indices() {
|
||||
if in_string {
|
||||
if escaped {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
match ch {
|
||||
'\\' => escaped = true,
|
||||
'"' => in_string = false,
|
||||
_ => {}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'"' => in_string = true,
|
||||
'{' => {
|
||||
if start.is_none() {
|
||||
start = Some(index);
|
||||
}
|
||||
depth += 1;
|
||||
}
|
||||
'}' => {
|
||||
if depth == 0 {
|
||||
continue;
|
||||
}
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
let start = start?;
|
||||
let end = index + ch.len_utf8();
|
||||
return Some(content[start..end].trim());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
use super::memory_maintenance::{
|
||||
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
|
||||
};
|
||||
use super::prompt::load_agent_prompt;
|
||||
|
||||
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
let mut preview = content.chars().take(max_chars).collect::<String>();
|
||||
@ -225,138 +56,6 @@ fn enrich_user_content_with_media_refs(
|
||||
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
|
||||
}
|
||||
|
||||
fn combine_managed_memory_markdown(chunks: &[String]) -> String {
|
||||
let normalized_chunks = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.trim())
|
||||
.filter(|chunk| !chunk.is_empty())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut combined = Vec::new();
|
||||
for (index, chunk) in normalized_chunks.iter().enumerate() {
|
||||
let chunk_lines = chunk
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|line| !line.is_empty())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let is_subset_of_other =
|
||||
normalized_chunks
|
||||
.iter()
|
||||
.enumerate()
|
||||
.any(|(other_index, other)| {
|
||||
if index == other_index {
|
||||
return false;
|
||||
}
|
||||
|
||||
let other_lines = other
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|line| !line.is_empty())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines)
|
||||
});
|
||||
|
||||
if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) {
|
||||
combined.push((*chunk).to_string());
|
||||
}
|
||||
}
|
||||
|
||||
combined.join("\n\n")
|
||||
}
|
||||
|
||||
fn apply_memory_maintenance_output(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
output: &MemoryMaintenanceModelOutput,
|
||||
) -> Result<(), AgentError> {
|
||||
let all_candidates = plan
|
||||
.user_facts
|
||||
.iter()
|
||||
.chain(plan.preferences.iter())
|
||||
.chain(plan.behavior_patterns.iter())
|
||||
.chain(plan.others.iter())
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let candidates_by_id = all_candidates
|
||||
.iter()
|
||||
.map(|candidate| (candidate.id.as_str(), candidate))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let mut deleted_ids = HashSet::new();
|
||||
|
||||
for merge in &output.merges {
|
||||
if merge.source_ids.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let source_candidates = merge
|
||||
.source_ids
|
||||
.iter()
|
||||
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
|
||||
.collect::<Vec<_>>();
|
||||
if source_candidates.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let existing_target_id = source_candidates
|
||||
.iter()
|
||||
.find(|candidate| {
|
||||
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
|
||||
})
|
||||
.map(|candidate| candidate.id.clone());
|
||||
|
||||
store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
namespace: merge.namespace.trim().to_string(),
|
||||
memory_key: merge.memory_key.trim().to_string(),
|
||||
content: merge.content.trim().to_string(),
|
||||
source_type: "memory_maintenance".to_string(),
|
||||
source_session_id: None,
|
||||
source_message_id: None,
|
||||
source_message_seq: None,
|
||||
source_channel_name: None,
|
||||
source_chat_id: None,
|
||||
})
|
||||
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
|
||||
|
||||
for candidate in source_candidates {
|
||||
if existing_target_id
|
||||
.as_ref()
|
||||
.is_some_and(|target_id| target_id == &candidate.id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if deleted_ids.insert(candidate.id.clone()) {
|
||||
store
|
||||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("delete merged source memory error: {}", err))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for memory_id in &output.low_value_ids {
|
||||
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
|
||||
if deleted_ids.insert(candidate.id.clone()) {
|
||||
store
|
||||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("delete low value memory error: {}", err))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||
pub struct Session {
|
||||
@ -609,6 +308,7 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
|
||||
self.latest_user_message(chat_id)
|
||||
.map(|message| message.id.as_str())
|
||||
@ -619,6 +319,7 @@ impl Session {
|
||||
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
|
||||
self.latest_user_message_id(chat_id)
|
||||
.map(|current_id| current_id == message_id)
|
||||
@ -1014,194 +715,30 @@ impl SessionManager {
|
||||
self.skills.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn build_memory_maintenance_plan_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||
let memories = self
|
||||
.store
|
||||
.list_memories_for_scope("user", scope_key)
|
||||
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
|
||||
|
||||
if memories.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(build_memory_maintenance_plan(&memories)))
|
||||
}
|
||||
|
||||
pub(crate) fn upsert_managed_agent_memory_summary(
|
||||
&self,
|
||||
markdown_body: &str,
|
||||
) -> Result<(), AgentError> {
|
||||
upsert_managed_agent_memory_summary(markdown_body)
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
pub(crate) async fn summarize_memory_maintenance_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||
let Some(plan) = self.build_memory_maintenance_plan_for_scope(scope_key)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
self.summarize_memory_maintenance_plan(scope_key, &plan)
|
||||
self.memory_maintenance_service()?
|
||||
.summarize_for_scope(scope_key)
|
||||
.await
|
||||
.map(Some)
|
||||
}
|
||||
|
||||
async fn summarize_memory_maintenance_plan(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
|
||||
let provider_config = self.provider_config_for_agent(None)?;
|
||||
let provider = create_provider(provider_config).map_err(|err| {
|
||||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||||
})?;
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![
|
||||
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
|
||||
Message::user(
|
||||
serde_json::to_string_pretty(&serde_json::json!({
|
||||
"scope_key": scope_key,
|
||||
"candidates": plan,
|
||||
}))
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
],
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(1200),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let mut last_error = None;
|
||||
let mut response = None;
|
||||
|
||||
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||||
.iter()
|
||||
.copied()
|
||||
.map(Some)
|
||||
.chain(std::iter::once(None))
|
||||
.enumerate()
|
||||
{
|
||||
match provider.chat(request.clone()).await {
|
||||
Ok(success) => {
|
||||
response = Some(success);
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
let error_text = err.to_string();
|
||||
let should_retry =
|
||||
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||||
last_error = Some(error_text.clone());
|
||||
|
||||
if should_retry {
|
||||
tracing::warn!(
|
||||
scope_key = %scope_key,
|
||||
attempt = attempt + 1,
|
||||
retry_in_ms = delay_ms.unwrap_or_default(),
|
||||
error = %error_text,
|
||||
"Memory maintenance model request failed, retrying"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(AgentError::Other(format!(
|
||||
"memory maintenance model error: {}",
|
||||
error_text
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = response.ok_or_else(|| {
|
||||
AgentError::Other(format!(
|
||||
"memory maintenance model error: {}",
|
||||
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
||||
))
|
||||
})?;
|
||||
|
||||
let raw_content = strip_json_code_fence(&response.content);
|
||||
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
|
||||
|
||||
let output: MemoryMaintenanceModelOutput =
|
||||
serde_json::from_str(json_candidate).map_err(|err| {
|
||||
tracing::error!(
|
||||
scope_key = %scope_key,
|
||||
error = %err,
|
||||
raw_len = raw_content.len(),
|
||||
raw_preview = %preview_text(raw_content, 400),
|
||||
json_candidate_len = json_candidate.len(),
|
||||
json_candidate_preview = %preview_text(json_candidate, 400),
|
||||
"Memory maintenance JSON decode failed"
|
||||
);
|
||||
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
|
||||
})?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub(crate) async fn run_memory_maintenance_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||
let Some(plan) = self.build_memory_maintenance_plan_for_scope(scope_key)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let output = self
|
||||
.summarize_memory_maintenance_plan(scope_key, &plan)
|
||||
.await?;
|
||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
|
||||
|
||||
Ok(Some(output))
|
||||
}
|
||||
|
||||
pub(crate) async fn run_memory_maintenance_for_all_scopes(
|
||||
&self,
|
||||
updated_since: Option<i64>,
|
||||
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||
let scope_keys = if let Some(cutoff) = updated_since {
|
||||
self.store
|
||||
.list_memory_scope_keys_updated_since("user", cutoff)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!(
|
||||
"list memory scope keys updated since error: {}",
|
||||
err
|
||||
self.memory_maintenance_service()?
|
||||
.run_for_all_scopes(updated_since)
|
||||
.await
|
||||
}
|
||||
|
||||
fn memory_maintenance_service(&self) -> Result<MemoryMaintenanceService, AgentError> {
|
||||
Ok(MemoryMaintenanceService::new(
|
||||
self.store.clone(),
|
||||
self.provider_config_for_agent(None)?,
|
||||
))
|
||||
})?
|
||||
} else {
|
||||
self.store.list_memory_scope_keys("user").map_err(|err| {
|
||||
AgentError::Other(format!("list memory scope keys error: {}", err))
|
||||
})?
|
||||
};
|
||||
let mut results = Vec::new();
|
||||
|
||||
for scope_key in scope_keys {
|
||||
let Some(output) = self.run_memory_maintenance_for_scope(&scope_key).await? else {
|
||||
continue;
|
||||
};
|
||||
|
||||
results.push(MemoryMaintenanceScopeResult { scope_key, output });
|
||||
}
|
||||
|
||||
let combined_markdown = combine_managed_memory_markdown(
|
||||
&results
|
||||
.iter()
|
||||
.map(|result| result.output.managed_markdown.clone())
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
if !combined_markdown.is_empty() {
|
||||
self.upsert_managed_agent_memory_summary(&combined_markdown)?;
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
pub fn provider_config_for_agent(
|
||||
@ -1572,6 +1109,7 @@ mod tests {
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
@ -1833,6 +1371,7 @@ mod tests {
|
||||
model_id: "timeout-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
@ -1872,6 +1411,7 @@ mod tests {
|
||||
model_id: "default-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
@ -1943,6 +1483,7 @@ mod tests {
|
||||
model_id: "default-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
@ -2020,6 +1561,7 @@ mod tests {
|
||||
model_id: "maintenance-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(256),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::from([(
|
||||
"mock_response_content".to_string(),
|
||||
json!(mock_response_content),
|
||||
@ -2120,6 +1662,7 @@ mod tests {
|
||||
model_id: "maintenance-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(256),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::from([(
|
||||
"mock_response_content".to_string(),
|
||||
json!(mock_response_content),
|
||||
@ -2182,6 +1725,7 @@ mod tests {
|
||||
model_id: "maintenance-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(256),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::from([(
|
||||
"mock_response_content".to_string(),
|
||||
json!(mock_response_content),
|
||||
@ -2241,6 +1785,7 @@ mod tests {
|
||||
model_id: "maintenance-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(256),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
|
||||
@ -980,6 +980,7 @@ mod tests {
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: None,
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 4,
|
||||
tool_result_max_chars: 20_000,
|
||||
|
||||
@ -23,6 +23,7 @@ fn load_config() -> Option<LLMProviderConfig> {
|
||||
model_id: openai_model,
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(100),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 20,
|
||||
tool_result_max_chars: 20_000,
|
||||
|
||||
@ -23,6 +23,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||
model_id: openai_model,
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(100),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 20,
|
||||
tool_result_max_chars: 20_000,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user