//! Memory maintenance: LLM-assisted organization (merge / prune) of stored user
//! memories and generation of the managed memory summary.
use std::collections::{HashMap, HashSet};
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
|
||
use serde::{Deserialize, Serialize};
|
||
|
||
use crate::agent::AgentError;
|
||
use crate::config::LLMProviderConfig;
|
||
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||
use crate::storage::{MemoryRecord, SessionStore};
|
||
|
||
use super::prompt::upsert_managed_agent_memory_summary;
|
||
|
||
const MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT: &str =
|
||
include_str!("memory_maintenance_step1_system_prompt.md");
|
||
const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
|
||
include_str!("memory_maintenance_step2_system_prompt.md");
|
||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenanceCandidate {
|
||
pub(crate) id: String,
|
||
pub(crate) namespace: String,
|
||
pub(crate) key: String,
|
||
pub(crate) content: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenancePlan {
|
||
pub(crate) candidates: Vec<MemoryMaintenanceCandidate>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenanceMerge {
|
||
pub(crate) source_ids: Vec<String>,
|
||
pub(crate) namespace: String,
|
||
pub(crate) memory_key: String,
|
||
pub(crate) content: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenanceConflict {
|
||
pub(crate) source_ids: Vec<String>,
|
||
pub(crate) note: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemoryOrganizationOutput {
|
||
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
|
||
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
|
||
pub(crate) low_value_ids: Vec<String>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct MemorySummaryInput {
|
||
pub(crate) organized_memories: Vec<OrganizedMemory>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||
pub(crate) struct OrganizedMemory {
|
||
pub(crate) namespace: String,
|
||
pub(crate) memory_key: String,
|
||
pub(crate) content: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||
pub(crate) struct MemoryMaintenanceScopeResult {
|
||
pub(crate) scope_key: String,
|
||
pub(crate) output: MemoryOrganizationOutput,
|
||
pub(crate) managed_markdown: String,
|
||
}
|
||
|
||
pub(crate) struct MemoryMaintenanceService {
|
||
store: Arc<SessionStore>,
|
||
provider_config: LLMProviderConfig,
|
||
}
|
||
|
||
impl MemoryMaintenanceService {
|
||
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||
Self {
|
||
store,
|
||
provider_config,
|
||
}
|
||
}
|
||
|
||
/// 创建记忆整理专用的 provider,使用 memory_maintenance_timeout_secs 作为超时时间
|
||
fn create_maintenance_provider(
|
||
&self,
|
||
) -> Result<Box<dyn crate::providers::LLMProvider>, crate::providers::ProviderError> {
|
||
let config = &self.provider_config;
|
||
let runtime_config = ProviderRuntimeConfig {
|
||
provider_type: config.provider_type.clone(),
|
||
name: config.name.clone(),
|
||
base_url: config.base_url.clone(),
|
||
api_key: config.api_key.clone(),
|
||
extra_headers: config.extra_headers.clone(),
|
||
llm_timeout_secs: config.memory_maintenance_timeout_secs,
|
||
model_id: config.model_id.clone(),
|
||
temperature: config.temperature,
|
||
max_tokens: config.max_tokens,
|
||
model_extra: config.model_extra.clone(),
|
||
};
|
||
create_provider(runtime_config)
|
||
}
|
||
|
||
pub(crate) fn build_plan_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||
let memories = self
|
||
.store
|
||
.list_memories_for_scope("user", scope_key)
|
||
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
|
||
|
||
if memories.is_empty() {
|
||
return Ok(None);
|
||
}
|
||
|
||
Ok(Some(build_memory_maintenance_plan(&memories)))
|
||
}
|
||
|
||
pub(crate) async fn organize_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
|
||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||
return Ok(None);
|
||
};
|
||
|
||
self.organize_plan(scope_key, &plan).await.map(Some)
|
||
}
|
||
|
||
async fn organize_plan(
|
||
&self,
|
||
scope_key: &str,
|
||
plan: &MemoryMaintenancePlan,
|
||
) -> Result<MemoryOrganizationOutput, AgentError> {
|
||
let provider = self.create_maintenance_provider().map_err(|err| {
|
||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||
})?;
|
||
|
||
let request = ChatCompletionRequest {
|
||
messages: vec![
|
||
Message::system(MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT),
|
||
Message::user(
|
||
serde_json::to_string_pretty(&serde_json::json!({
|
||
"scope_key": scope_key,
|
||
"candidates": plan,
|
||
}))
|
||
.unwrap_or_else(|_| "{}".to_string()),
|
||
),
|
||
],
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(1200),
|
||
tools: None,
|
||
};
|
||
|
||
let mut last_error = None;
|
||
let mut response = None;
|
||
|
||
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||
.iter()
|
||
.copied()
|
||
.map(Some)
|
||
.chain(std::iter::once(None))
|
||
.enumerate()
|
||
{
|
||
match provider.chat(request.clone()).await {
|
||
Ok(success) => {
|
||
response = Some(success);
|
||
break;
|
||
}
|
||
Err(err) => {
|
||
let error_text = err.to_string();
|
||
let should_retry =
|
||
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||
last_error = Some(error_text.clone());
|
||
|
||
if should_retry {
|
||
tracing::warn!(
|
||
scope_key = %scope_key,
|
||
attempt = attempt + 1,
|
||
retry_in_ms = delay_ms.unwrap_or_default(),
|
||
error = %error_text,
|
||
"Memory organization model request failed, retrying"
|
||
);
|
||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||
.await;
|
||
continue;
|
||
}
|
||
|
||
return Err(AgentError::Other(format!(
|
||
"memory organization model error: {}",
|
||
error_text
|
||
)));
|
||
}
|
||
}
|
||
}
|
||
|
||
let response = response.ok_or_else(|| {
|
||
AgentError::Other(format!(
|
||
"memory organization model error: {}",
|
||
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
||
))
|
||
})?;
|
||
|
||
let raw_content = strip_json_code_fence(&response.content);
|
||
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
|
||
|
||
let output: MemoryOrganizationOutput =
|
||
serde_json::from_str(json_candidate).map_err(|err| {
|
||
tracing::error!(
|
||
scope_key = %scope_key,
|
||
error = %err,
|
||
raw_len = raw_content.len(),
|
||
raw_preview = %preview_text(raw_content, 400),
|
||
json_candidate_len = json_candidate.len(),
|
||
json_candidate_preview = %preview_text(json_candidate, 400),
|
||
"Memory maintenance JSON decode failed"
|
||
);
|
||
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
|
||
})?;
|
||
|
||
Ok(output)
|
||
}
|
||
|
||
#[cfg_attr(not(test), allow(dead_code))]
|
||
pub(crate) async fn run_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
|
||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||
return Ok(None);
|
||
};
|
||
|
||
// 步骤1:整理记忆(不生成摘要)
|
||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||
|
||
// 应用整理结果(merge和delete)
|
||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
||
|
||
// 步骤2:从数据库重新读取剩余的记忆
|
||
let remaining_memories = self
|
||
.store
|
||
.list_memories_for_scope("user", scope_key)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!("list remaining memories error: {}", err))
|
||
})?;
|
||
|
||
// 步骤2:生成摘要
|
||
let managed_markdown = if remaining_memories.is_empty() {
|
||
String::new()
|
||
} else {
|
||
self.generate_summary(scope_key, &remaining_memories).await?
|
||
};
|
||
|
||
Ok(Some(MemoryMaintenanceScopeResult {
|
||
scope_key: scope_key.to_string(),
|
||
output: organize_output,
|
||
managed_markdown,
|
||
}))
|
||
}
|
||
|
||
async fn generate_summary(
|
||
&self,
|
||
scope_key: &str,
|
||
remaining_memories: &[MemoryRecord],
|
||
) -> Result<String, AgentError> {
|
||
let provider = self.create_maintenance_provider().map_err(|err| {
|
||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||
})?;
|
||
|
||
let input = MemorySummaryInput {
|
||
organized_memories: remaining_memories
|
||
.iter()
|
||
.map(|m| OrganizedMemory {
|
||
namespace: m.namespace.clone(),
|
||
memory_key: m.memory_key.clone(),
|
||
content: m.content.clone(),
|
||
})
|
||
.collect(),
|
||
};
|
||
|
||
let request = ChatCompletionRequest {
|
||
messages: vec![
|
||
Message::system(MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT),
|
||
Message::user(
|
||
serde_json::to_string_pretty(&serde_json::json!(input))
|
||
.unwrap_or_else(|_| "{}".to_string()),
|
||
),
|
||
],
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(1000),
|
||
tools: None,
|
||
};
|
||
|
||
let mut last_error = None;
|
||
let mut response = None;
|
||
|
||
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||
.iter()
|
||
.copied()
|
||
.map(Some)
|
||
.chain(std::iter::once(None))
|
||
.enumerate()
|
||
{
|
||
match provider.chat(request.clone()).await {
|
||
Ok(success) => {
|
||
response = Some(success);
|
||
break;
|
||
}
|
||
Err(err) => {
|
||
let error_text = err.to_string();
|
||
let should_retry =
|
||
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||
last_error = Some(error_text.clone());
|
||
|
||
if should_retry {
|
||
tracing::warn!(
|
||
scope_key = %scope_key,
|
||
attempt = attempt + 1,
|
||
retry_in_ms = delay_ms.unwrap_or_default(),
|
||
error = %error_text,
|
||
"Memory summary model request failed, retrying"
|
||
);
|
||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||
.await;
|
||
continue;
|
||
}
|
||
|
||
return Err(AgentError::Other(format!(
|
||
"memory summary model error: {}",
|
||
error_text
|
||
)));
|
||
}
|
||
}
|
||
}
|
||
|
||
let response = response.ok_or_else(|| {
|
||
AgentError::Other(format!(
|
||
"memory summary model error: {}",
|
||
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
||
))
|
||
})?;
|
||
|
||
let content = response.content.trim();
|
||
// 确保响应包含标签,如无则自动添加
|
||
let tagged_content = if content.contains("<!-- memory_summary_start -->") {
|
||
content.to_string()
|
||
} else {
|
||
format!(
|
||
"<!-- memory_summary_start -->\n{}\n<!-- memory_summary_end -->",
|
||
content
|
||
)
|
||
};
|
||
|
||
Ok(tagged_content)
|
||
}
|
||
|
||
pub(crate) async fn run_for_all_scopes(
|
||
&self,
|
||
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
|
||
let scope_keys = self.store.list_memory_scope_keys("user").map_err(|err| {
|
||
AgentError::Other(format!("list memory scope keys error: {}", err))
|
||
})?;
|
||
|
||
if scope_keys.is_empty() {
|
||
return Ok(None);
|
||
}
|
||
|
||
// 步骤1:逐个 scope 整理记忆(merge/delete),但不生成摘要
|
||
let mut all_outputs = Vec::new();
|
||
for scope_key in &scope_keys {
|
||
match self.run_organize_for_scope(scope_key).await {
|
||
Ok(Some(output)) => all_outputs.push((scope_key.clone(), output)),
|
||
Ok(None) => continue,
|
||
Err(error) if is_recoverable_maintenance_scope_error(&error) => {
|
||
tracing::warn!(
|
||
scope_key = %scope_key,
|
||
error = %error,
|
||
"Memory maintenance skipped scope after recoverable model failure"
|
||
);
|
||
continue;
|
||
}
|
||
Err(error) => return Err(error),
|
||
}
|
||
}
|
||
|
||
if all_outputs.is_empty() {
|
||
return Ok(None);
|
||
}
|
||
|
||
// 步骤2:收集所有 scope 整理后的剩余记忆
|
||
let mut all_remaining_memories = Vec::new();
|
||
for (scope_key, _) in &all_outputs {
|
||
let memories = self
|
||
.store
|
||
.list_memories_for_scope("user", scope_key)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!(
|
||
"list remaining memories for scope {} error: {}",
|
||
scope_key, err
|
||
))
|
||
})?;
|
||
all_remaining_memories.extend(memories);
|
||
}
|
||
|
||
// 步骤3:统一生成一个摘要
|
||
let managed_markdown = if all_remaining_memories.is_empty() {
|
||
String::new()
|
||
} else {
|
||
self.generate_summary("all", &all_remaining_memories).await?
|
||
};
|
||
|
||
if !managed_markdown.is_empty() {
|
||
upsert_managed_agent_memory_summary(&managed_markdown)?;
|
||
}
|
||
|
||
// 合并所有输出用于返回
|
||
let combined_output = MemoryOrganizationOutput {
|
||
merges: all_outputs
|
||
.iter()
|
||
.flat_map(|(_, o)| o.merges.clone())
|
||
.collect(),
|
||
conflicts: all_outputs
|
||
.iter()
|
||
.flat_map(|(_, o)| o.conflicts.clone())
|
||
.collect(),
|
||
low_value_ids: all_outputs
|
||
.iter()
|
||
.flat_map(|(_, o)| o.low_value_ids.clone())
|
||
.collect(),
|
||
};
|
||
|
||
Ok(Some(MemoryMaintenanceScopeResult {
|
||
scope_key: "all".to_string(),
|
||
output: combined_output,
|
||
managed_markdown,
|
||
}))
|
||
}
|
||
|
||
/// 仅执行整理步骤(organize + apply),不生成摘要
|
||
async fn run_organize_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
|
||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||
return Ok(None);
|
||
};
|
||
|
||
// 步骤1:整理记忆
|
||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||
|
||
// 应用整理结果(merge和delete)
|
||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
||
|
||
Ok(Some(organize_output))
|
||
}
|
||
}
|
||
|
||
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||
let mut plan = MemoryMaintenancePlan::default();
|
||
let mut seen = HashSet::new();
|
||
|
||
for memory in memories {
|
||
let normalized_content = memory.content.trim();
|
||
if normalized_content.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let dedupe_key = format!(
|
||
"{}\u{1f}{}\u{1f}{}",
|
||
memory.namespace.trim().to_ascii_lowercase(),
|
||
memory.memory_key.trim().to_ascii_lowercase(),
|
||
normalized_content
|
||
);
|
||
if !seen.insert(dedupe_key) {
|
||
continue;
|
||
}
|
||
|
||
let candidate = MemoryMaintenanceCandidate {
|
||
id: memory.id.clone(),
|
||
namespace: memory.namespace.clone(),
|
||
key: memory.memory_key.clone(),
|
||
content: normalized_content.to_string(),
|
||
};
|
||
|
||
plan.candidates.push(candidate);
|
||
}
|
||
|
||
plan
|
||
}
|
||
|
||
/// Reports whether a provider error message looks transient, i.e. worth retrying.
///
/// The check is a case-insensitive substring match for a failed HTTP send, a
/// `504` status, or any kind of timeout.
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
    // "timeout" also covers "gateway timeout" and "stream timeout".
    // NOTE(review): "504" matches anywhere in the text, not only as a status code.
    const TRANSIENT_MARKERS: [&str; 4] = [
        "error sending request for url",
        "504",
        "timed out",
        "timeout",
    ];

    let normalized = error.to_ascii_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|marker| normalized.contains(*marker))
}
|
||
|
||
fn is_recoverable_maintenance_scope_error(error: &AgentError) -> bool {
|
||
is_recoverable_maintenance_llm_error(&error.to_string())
|
||
}
|
||
|
||
/// Removes a surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) from `content`.
///
/// The input is trimmed first. When it opens with a fence and also ends with
/// one, the trimmed text between the fences is returned. In every other case —
/// no fence, or an opening fence without a closing one — the trimmed input is
/// returned unchanged.
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
    let trimmed = content.trim();

    let Some(after_fence) = trimmed.strip_prefix("```") else {
        return trimmed;
    };
    // An optional `json` language tag directly after the opening fence.
    let body = after_fence.strip_prefix("json").unwrap_or(after_fence);

    body.strip_suffix("```").map(str::trim).unwrap_or(trimmed)
}
|
||
|
||
/// Returns the first balanced `{ ... }` object found in `content`, if any.
///
/// Braces inside double-quoted strings are ignored, and backslash escapes inside
/// strings are honoured. Quoted text in front of the object is skipped as a
/// string too. If that scan finds nothing — typically because a stray `"` in
/// leading prose put the scanner in the wrong state — the text is rescanned
/// starting at each `{` in turn and the first balanced object is returned.
///
/// The fallback rescans the tail once per `{`, which is fine for the short model
/// responses this is used on.
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
    scan_balanced_object(content).or_else(|| {
        content
            .match_indices('{')
            .find_map(|(brace, _)| scan_balanced_object(&content[brace..]))
    })
}

/// Single left-to-right scan for the first balanced, string-aware `{ ... }` span.
fn scan_balanced_object(content: &str) -> Option<&str> {
    let mut start = None;
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (index, ch) in content.char_indices() {
        if in_string {
            if escaped {
                // The character after a backslash never ends the string.
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => {
                if start.is_none() {
                    start = Some(index);
                }
                depth += 1;
            }
            // A closing brace with nothing open is stray text; ignore it.
            '}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    let begin = start?;
                    return Some(&content[begin..index + ch.len_utf8()]);
                }
            }
            _ => {}
        }
    }

    None
}
|
||
|
||
pub(crate) fn apply_memory_maintenance_output(
|
||
store: &SessionStore,
|
||
scope_key: &str,
|
||
plan: &MemoryMaintenancePlan,
|
||
output: &MemoryOrganizationOutput,
|
||
) -> Result<(), AgentError> {
|
||
let all_candidates = plan.candidates.clone();
|
||
|
||
let candidates_by_id = all_candidates
|
||
.iter()
|
||
.map(|candidate| (candidate.id.as_str(), candidate))
|
||
.collect::<HashMap<_, _>>();
|
||
|
||
let mut deleted_ids = HashSet::new();
|
||
|
||
for merge in &output.merges {
|
||
if merge.source_ids.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let source_candidates = merge
|
||
.source_ids
|
||
.iter()
|
||
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
|
||
.collect::<Vec<_>>();
|
||
if source_candidates.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let existing_target_id = source_candidates
|
||
.iter()
|
||
.find(|candidate| {
|
||
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
|
||
})
|
||
.map(|candidate| candidate.id.clone());
|
||
|
||
store
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: scope_key.to_string(),
|
||
namespace: merge.namespace.trim().to_string(),
|
||
memory_key: merge.memory_key.trim().to_string(),
|
||
content: merge.content.trim().to_string(),
|
||
source_type: "memory_maintenance".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
|
||
|
||
for candidate in source_candidates {
|
||
if existing_target_id
|
||
.as_ref()
|
||
.is_some_and(|target_id| target_id == &candidate.id)
|
||
{
|
||
continue;
|
||
}
|
||
if deleted_ids.insert(candidate.id.clone()) {
|
||
store
|
||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!("delete merged source memory error: {}", err))
|
||
})?;
|
||
}
|
||
}
|
||
}
|
||
|
||
for memory_id in &output.low_value_ids {
|
||
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
|
||
if deleted_ids.insert(candidate.id.clone()) {
|
||
store
|
||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!("delete low value memory error: {}", err))
|
||
})?;
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// Returns at most the first `max_chars` characters of `content` for logging.
///
/// `...` is appended when the text was cut, and every newline is rendered as a
/// literal `\n` so the preview stays on one log line.
fn preview_text(content: &str, max_chars: usize) -> String {
    // The character at position `max_chars` exists only when the text is longer
    // than the limit; its byte offset is where the preview is cut.
    let preview = match content.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &content[..cut]),
        None => content.to_string(),
    };
    preview.replace('\n', "\\n")
}
|
||
|
||
// Unit-test module for this file; it currently contains no tests.
#[cfg(test)]
mod tests {}
|