PicoBot/src/gateway/memory_maintenance.rs

659 lines
22 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
use crate::storage::{MemoryRecord, SessionStore};
use super::prompt::upsert_managed_agent_memory_summary;
/// System prompt for maintenance step 1 (memory organization), embedded at compile time.
const MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT: &str =
include_str!("memory_maintenance_step1_system_prompt.md");
/// System prompt for maintenance step 2 (summary generation), embedded at compile time.
const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
include_str!("memory_maintenance_step2_system_prompt.md");
/// Delays, in milliseconds, slept before each retry of a recoverable model request failure.
/// Two entries mean at most two retries, i.e. three attempts in total.
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
/// One stored memory offered to the organization model as a maintenance candidate.
///
/// Built from a `MemoryRecord` by `build_memory_maintenance_plan`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate {
/// Identifier copied from the source record; the model output refers back to it
/// through `source_ids` and `low_value_ids`.
pub(crate) id: String,
/// Namespace of the source record.
pub(crate) namespace: String,
/// The source record's `memory_key`.
pub(crate) key: String,
/// Record content with surrounding whitespace trimmed.
pub(crate) content: String,
}
/// The set of candidates for one scope that is serialized into the step-1 model request.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenancePlan {
/// De-duplicated, non-empty memories of the scope, in the order they were listed.
pub(crate) candidates: Vec<MemoryMaintenanceCandidate>,
}
/// A merge instruction decoded from the step-1 model response.
///
/// `apply_memory_maintenance_output` upserts `content` under `namespace` / `memory_key`
/// and deletes the source candidates that were folded into it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceMerge {
/// Candidate ids being merged; ids that are not part of the plan are ignored when applied.
pub(crate) source_ids: Vec<String>,
/// Namespace the merged memory is written to.
pub(crate) namespace: String,
/// Key the merged memory is written to.
pub(crate) memory_key: String,
/// Content of the merged memory.
pub(crate) content: String,
}
/// A conflict reported by the step-1 model response.
///
/// Within this file conflicts are only decoded and returned to the caller; nothing in
/// `apply_memory_maintenance_output` acts on them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceConflict {
/// Ids of the memories involved (presumably candidate ids — confirm against the prompt).
pub(crate) source_ids: Vec<String>,
/// Free-text note from the model describing the conflict.
pub(crate) note: String,
}
/// The full step-1 result, deserialized from the JSON object in the model response.
///
/// All three fields are required by the derived `Deserialize` impl.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryOrganizationOutput {
/// Merge instructions to apply to the store.
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
/// Conflicts reported by the model; returned to the caller but not applied here.
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
/// Candidate ids whose memories are deleted when the output is applied.
pub(crate) low_value_ids: Vec<String>,
}
/// Payload serialized into the user message of the step-2 (summary) model request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemorySummaryInput {
/// The memories that remain after organization, one entry per stored record.
pub(crate) organized_memories: Vec<OrganizedMemory>,
}
/// A single memory as presented to the summary model; copied field-for-field from a
/// `MemoryRecord` in `generate_summary`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct OrganizedMemory {
/// Namespace of the stored record.
pub(crate) namespace: String,
/// Key of the stored record.
pub(crate) memory_key: String,
/// Content of the stored record.
pub(crate) content: String,
}
/// Outcome of a maintenance run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceScopeResult {
/// The scope that was processed, or the literal `"all"` for `run_for_all_scopes`.
pub(crate) scope_key: String,
/// The organization output that was applied (concatenated across scopes for `"all"`).
pub(crate) output: MemoryOrganizationOutput,
/// The tagged summary text, or an empty string when no memories remained.
pub(crate) managed_markdown: String,
}
/// Runs LLM-assisted maintenance (organize, apply, summarize) over stored user memories.
pub(crate) struct MemoryMaintenanceService {
// Store the memories are listed from and written back to.
store: Arc<SessionStore>,
// Provider settings; a maintenance-specific timeout is taken from here.
provider_config: LLMProviderConfig,
}
impl MemoryMaintenanceService {
/// Creates a maintenance service that reads and writes memories through `store` and
/// talks to the model described by `provider_config`.
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
    Self { store, provider_config }
}
/// Creates the provider used for memory maintenance.
///
/// Every setting is copied from the service's provider config, except that
/// `memory_maintenance_timeout_secs` is used as the LLM timeout.
fn create_maintenance_provider(
    &self,
) -> Result<Box<dyn crate::providers::LLMProvider>, crate::providers::ProviderError> {
    let cfg = &self.provider_config;
    create_provider(ProviderRuntimeConfig {
        provider_type: cfg.provider_type.clone(),
        name: cfg.name.clone(),
        base_url: cfg.base_url.clone(),
        api_key: cfg.api_key.clone(),
        extra_headers: cfg.extra_headers.clone(),
        // Maintenance requests get their own timeout.
        llm_timeout_secs: cfg.memory_maintenance_timeout_secs,
        model_id: cfg.model_id.clone(),
        temperature: cfg.temperature,
        max_tokens: cfg.max_tokens,
        model_extra: cfg.model_extra.clone(),
    })
}
/// Lists the `"user"` memories stored under `scope_key` and turns them into a plan.
///
/// Returns `Ok(None)` when the scope has no memories at all.
///
/// # Errors
/// Returns `AgentError::Other` when the store cannot list the scope's memories.
pub(crate) fn build_plan_for_scope(
    &self,
    scope_key: &str,
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
    let records = self
        .store
        .list_memories_for_scope("user", scope_key)
        .map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
    let plan = (!records.is_empty()).then(|| build_memory_maintenance_plan(&records));
    Ok(plan)
}
/// Builds the plan for `scope_key` and asks the model to organize it, without applying
/// the result to the store.
///
/// Returns `Ok(None)` when the scope has no memories.
pub(crate) async fn organize_for_scope(
    &self,
    scope_key: &str,
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
    match self.build_plan_for_scope(scope_key)? {
        Some(plan) => {
            let output = self.organize_plan(scope_key, &plan).await?;
            Ok(Some(output))
        }
        None => Ok(None),
    }
}
/// Sends `plan` to the model with the step-1 (organization) prompt and decodes the reply.
///
/// The request is attempted once, plus one retry per entry of
/// `MEMORY_MAINTENANCE_RETRY_DELAYS_MS`. Only failures accepted by
/// `is_recoverable_maintenance_llm_error` are retried; a JSON decode failure is not.
///
/// # Errors
/// Returns `AgentError::Other` when the provider cannot be created, when the request
/// fails with a non-recoverable error or exhausts its retries, or when the reply does
/// not decode into a `MemoryOrganizationOutput`.
async fn organize_plan(
    &self,
    scope_key: &str,
    plan: &MemoryMaintenancePlan,
) -> Result<MemoryOrganizationOutput, AgentError> {
    let provider = self.create_maintenance_provider().map_err(|err| {
        AgentError::Other(format!("create maintenance provider error: {}", err))
    })?;

    // NOTE(review): `plan` itself serializes as an object with a `candidates` array, so
    // the payload nests `candidates.candidates`; presumably the step-1 prompt expects
    // this shape — confirm before changing it.
    let payload = serde_json::json!({
        "scope_key": scope_key,
        "candidates": plan,
    });
    let request = ChatCompletionRequest {
        messages: vec![
            Message::system(MEMORY_MAINTENANCE_STEP1_SYSTEM_PROMPT),
            Message::user(
                serde_json::to_string_pretty(&payload).unwrap_or_else(|_| "{}".to_string()),
            ),
        ],
        temperature: Some(0.0),
        max_tokens: Some(1200),
        tools: None,
    };

    // Each failed attempt consumes one delay; once the delays run out (or the error is
    // not recoverable) the failure is returned, so the loop can only exit with a response.
    let mut retry_delays = MEMORY_MAINTENANCE_RETRY_DELAYS_MS.iter().copied();
    let mut attempt = 0usize;
    let response = loop {
        attempt += 1;
        let error_text = match provider.chat(request.clone()).await {
            Ok(success) => break success,
            Err(err) => err.to_string(),
        };
        match retry_delays.next() {
            Some(delay_ms) if is_recoverable_maintenance_llm_error(&error_text) => {
                tracing::warn!(
                    scope_key = %scope_key,
                    attempt = attempt,
                    retry_in_ms = delay_ms,
                    error = %error_text,
                    "Memory organization model request failed, retrying"
                );
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            }
            _ => {
                return Err(AgentError::Other(format!(
                    "memory organization model error: {}",
                    error_text
                )));
            }
        }
    };

    // Models often wrap JSON in a code fence and/or surround it with prose; peel both
    // before decoding, falling back to the unfenced text if no balanced object is found.
    let raw_content = strip_json_code_fence(&response.content);
    let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
    serde_json::from_str::<MemoryOrganizationOutput>(json_candidate).map_err(|err| {
        tracing::error!(
            scope_key = %scope_key,
            error = %err,
            raw_len = raw_content.len(),
            raw_preview = %preview_text(raw_content, 400),
            json_candidate_len = json_candidate.len(),
            json_candidate_preview = %preview_text(json_candidate, 400),
            "Memory maintenance JSON decode failed"
        );
        AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
    })
}
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) async fn run_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
return Ok(None);
};
// 步骤1整理记忆不生成摘要
let organize_output = self.organize_plan(scope_key, &plan).await?;
// 应用整理结果merge和delete
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
// 步骤2从数据库重新读取剩余的记忆
let remaining_memories = self
.store
.list_memories_for_scope("user", scope_key)
.map_err(|err| {
AgentError::Other(format!("list remaining memories error: {}", err))
})?;
// 步骤2生成摘要
let managed_markdown = if remaining_memories.is_empty() {
String::new()
} else {
self.generate_summary(scope_key, &remaining_memories).await?
};
Ok(Some(MemoryMaintenanceScopeResult {
scope_key: scope_key.to_string(),
output: organize_output,
managed_markdown,
}))
}
/// Asks the model, using the step-2 (summary) prompt, to summarize `remaining_memories`
/// and returns the reply enclosed in the managed-summary marker comments.
///
/// The request is attempted once, plus one retry per entry of
/// `MEMORY_MAINTENANCE_RETRY_DELAYS_MS`; only failures accepted by
/// `is_recoverable_maintenance_llm_error` are retried. `scope_key` is used for log
/// context only.
///
/// # Errors
/// Returns `AgentError::Other` when the provider cannot be created or the request fails
/// with a non-recoverable error or exhausts its retries.
async fn generate_summary(
    &self,
    scope_key: &str,
    remaining_memories: &[MemoryRecord],
) -> Result<String, AgentError> {
    const SUMMARY_START_TAG: &str = "<!-- memory_summary_start -->";
    const SUMMARY_END_TAG: &str = "<!-- memory_summary_end -->";

    let provider = self.create_maintenance_provider().map_err(|err| {
        AgentError::Other(format!("create maintenance provider error: {}", err))
    })?;
    let input = MemorySummaryInput {
        organized_memories: remaining_memories
            .iter()
            .map(|m| OrganizedMemory {
                namespace: m.namespace.clone(),
                memory_key: m.memory_key.clone(),
                content: m.content.clone(),
            })
            .collect(),
    };
    let request = ChatCompletionRequest {
        messages: vec![
            Message::system(MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT),
            Message::user(
                serde_json::to_string_pretty(&serde_json::json!(input))
                    .unwrap_or_else(|_| "{}".to_string()),
            ),
        ],
        temperature: Some(0.0),
        max_tokens: Some(1000),
        tools: None,
    };

    // Each failed attempt consumes one delay; once the delays run out (or the error is
    // not recoverable) the failure is returned, so the loop can only exit with a response.
    let mut retry_delays = MEMORY_MAINTENANCE_RETRY_DELAYS_MS.iter().copied();
    let mut attempt = 0usize;
    let response = loop {
        attempt += 1;
        let error_text = match provider.chat(request.clone()).await {
            Ok(success) => break success,
            Err(err) => err.to_string(),
        };
        match retry_delays.next() {
            Some(delay_ms) if is_recoverable_maintenance_llm_error(&error_text) => {
                tracing::warn!(
                    scope_key = %scope_key,
                    attempt = attempt,
                    retry_in_ms = delay_ms,
                    error = %error_text,
                    "Memory summary model request failed, retrying"
                );
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            }
            _ => {
                return Err(AgentError::Other(format!(
                    "memory summary model error: {}",
                    error_text
                )));
            }
        }
    };

    // Guarantee that BOTH markers are present. A reply cut off by `max_tokens` can carry
    // the start marker without the end marker, so each one is added only if it is missing.
    let content = response.content.trim();
    let tagged_content = match (
        content.contains(SUMMARY_START_TAG),
        content.contains(SUMMARY_END_TAG),
    ) {
        (true, true) => content.to_string(),
        (true, false) => format!("{}\n{}", content, SUMMARY_END_TAG),
        (false, true) => format!("{}\n{}", SUMMARY_START_TAG, content),
        (false, false) => format!(
            "{}\n{}\n{}",
            SUMMARY_START_TAG, content, SUMMARY_END_TAG
        ),
    };
    Ok(tagged_content)
}
/// Runs maintenance across every `"user"` scope and produces one combined summary.
///
/// Each scope is organized and applied in turn; a scope whose model request fails with a
/// recoverable error is logged and skipped, any other error aborts the run. The memories
/// remaining in the successfully organized scopes are then summarized together, and a
/// non-empty summary is stored via `upsert_managed_agent_memory_summary`.
///
/// Returns `Ok(None)` when there are no scopes or no scope was organized.
pub(crate) async fn run_for_all_scopes(
    &self,
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
    let scope_keys = self.store.list_memory_scope_keys("user").map_err(|err| {
        AgentError::Other(format!("list memory scope keys error: {}", err))
    })?;
    if scope_keys.is_empty() {
        return Ok(None);
    }

    // Step 1: organize scope by scope (merge/delete only, no summary yet).
    let mut organized = Vec::new();
    for scope_key in &scope_keys {
        let outcome = self.run_organize_for_scope(scope_key).await;
        match outcome {
            Ok(Some(output)) => organized.push((scope_key.clone(), output)),
            Ok(None) => {}
            Err(error) if is_recoverable_maintenance_scope_error(&error) => {
                tracing::warn!(
                    scope_key = %scope_key,
                    error = %error,
                    "Memory maintenance skipped scope after recoverable model failure"
                );
            }
            Err(error) => return Err(error),
        }
    }
    if organized.is_empty() {
        return Ok(None);
    }

    // Step 2: gather what is left in every scope that was organized.
    // NOTE(review): scopes skipped above are not included here, so their memories are
    // absent from the combined summary — confirm that this is intended.
    let mut remaining = Vec::new();
    for (scope_key, _) in &organized {
        let records = self
            .store
            .list_memories_for_scope("user", scope_key)
            .map_err(|err| {
                AgentError::Other(format!(
                    "list remaining memories for scope {} error: {}",
                    scope_key, err
                ))
            })?;
        remaining.extend(records);
    }

    // Step 3: one summary for everything, persisted only when there is something to say.
    let mut managed_markdown = String::new();
    if !remaining.is_empty() {
        managed_markdown = self.generate_summary("all", &remaining).await?;
    }
    if !managed_markdown.is_empty() {
        upsert_managed_agent_memory_summary(&managed_markdown)?;
    }

    // Concatenate the per-scope outputs, preserving scope order within each list.
    let mut combined = MemoryOrganizationOutput {
        merges: Vec::new(),
        conflicts: Vec::new(),
        low_value_ids: Vec::new(),
    };
    for (_, output) in organized {
        combined.merges.extend(output.merges);
        combined.conflicts.extend(output.conflicts);
        combined.low_value_ids.extend(output.low_value_ids);
    }

    Ok(Some(MemoryMaintenanceScopeResult {
        scope_key: "all".to_string(),
        output: combined,
        managed_markdown,
    }))
}
/// Runs only the organization half of maintenance for one scope: organize the plan and
/// apply the result to the store, without generating a summary.
///
/// Returns `Ok(None)` when the scope has no memories.
async fn run_organize_for_scope(
    &self,
    scope_key: &str,
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
    match self.build_plan_for_scope(scope_key)? {
        None => Ok(None),
        Some(plan) => {
            // Organize first, then persist the merges and deletions the model asked for.
            let output = self.organize_plan(scope_key, &plan).await?;
            apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
            Ok(Some(output))
        }
    }
}
}
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
let mut plan = MemoryMaintenancePlan::default();
let mut seen = HashSet::new();
for memory in memories {
let normalized_content = memory.content.trim();
if normalized_content.is_empty() {
continue;
}
let dedupe_key = format!(
"{}\u{1f}{}\u{1f}{}",
memory.namespace.trim().to_ascii_lowercase(),
memory.memory_key.trim().to_ascii_lowercase(),
normalized_content
);
if !seen.insert(dedupe_key) {
continue;
}
let candidate = MemoryMaintenanceCandidate {
id: memory.id.clone(),
namespace: memory.namespace.clone(),
key: memory.memory_key.clone(),
content: normalized_content.to_string(),
};
plan.candidates.push(candidate);
}
plan
}
/// Reports whether a provider error message looks transient and therefore worth retrying.
///
/// The match is a case-insensitive (ASCII) substring search for any of a fixed set of
/// network and timeout markers.
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
    const RECOVERABLE_MARKERS: [&str; 6] = [
        "error sending request for url",
        "504",
        "gateway timeout",
        "stream timeout",
        "timed out",
        "timeout",
    ];
    let lowered = error.to_ascii_lowercase();
    RECOVERABLE_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
}
/// Applies the transient-failure check to an `AgentError` by inspecting its rendered
/// message.
fn is_recoverable_maintenance_scope_error(error: &AgentError) -> bool {
    let rendered = error.to_string();
    is_recoverable_maintenance_llm_error(&rendered)
}
/// Removes a surrounding Markdown code fence (three backticks, optionally followed by
/// `json`) from `content`.
///
/// The input is trimmed first. The fence is removed only when both the opening and the
/// closing fence are present, in which case the trimmed interior is returned; otherwise
/// the trimmed input is returned unchanged.
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
    let trimmed = content.trim();
    // Try the tagged opener first so that the `json` tag is consumed along with it.
    let after_opening = trimmed
        .strip_prefix("```json")
        .or_else(|| trimmed.strip_prefix("```"));
    match after_opening.and_then(|rest| rest.strip_suffix("```")) {
        Some(inner) => inner.trim(),
        None => trimmed,
    }
}
/// Returns the first balanced `{ ... }` span found in `content`, or `None` if there is
/// none.
///
/// Braces inside double-quoted strings are ignored, and a backslash inside a string
/// escapes the following character. Quote tracking starts at the beginning of the input,
/// not at the first brace. A closing brace seen before any opening brace is skipped.
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
    #[derive(Clone, Copy)]
    enum Scan {
        Plain,
        InString,
        Escaped,
    }

    // Every delimiter of interest is ASCII, and ASCII bytes never occur inside a
    // multi-byte UTF-8 sequence, so scanning bytes is equivalent to scanning chars and
    // the recorded offsets are always valid slice boundaries.
    let mut state = Scan::Plain;
    let mut object_start: Option<usize> = None;
    let mut depth = 0usize;
    for (pos, byte) in content.bytes().enumerate() {
        state = match (state, byte) {
            (Scan::Escaped, _) => Scan::InString,
            (Scan::InString, b'\\') => Scan::Escaped,
            (Scan::InString, b'"') => Scan::Plain,
            (Scan::InString, _) => Scan::InString,
            (Scan::Plain, b'"') => Scan::InString,
            (Scan::Plain, b'{') => {
                object_start = object_start.or(Some(pos));
                depth += 1;
                Scan::Plain
            }
            (Scan::Plain, b'}') if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    let start = object_start?;
                    return Some(content[start..=pos].trim());
                }
                Scan::Plain
            }
            (Scan::Plain, _) => Scan::Plain,
        };
    }
    None
}
pub(crate) fn apply_memory_maintenance_output(
store: &SessionStore,
scope_key: &str,
plan: &MemoryMaintenancePlan,
output: &MemoryOrganizationOutput,
) -> Result<(), AgentError> {
let all_candidates = plan.candidates.clone();
let candidates_by_id = all_candidates
.iter()
.map(|candidate| (candidate.id.as_str(), candidate))
.collect::<HashMap<_, _>>();
let mut deleted_ids = HashSet::new();
for merge in &output.merges {
if merge.source_ids.is_empty() {
continue;
}
let source_candidates = merge
.source_ids
.iter()
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
.collect::<Vec<_>>();
if source_candidates.is_empty() {
continue;
}
let existing_target_id = source_candidates
.iter()
.find(|candidate| {
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
})
.map(|candidate| candidate.id.clone());
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: merge.namespace.trim().to_string(),
memory_key: merge.memory_key.trim().to_string(),
content: merge.content.trim().to_string(),
source_type: "memory_maintenance".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
for candidate in source_candidates {
if existing_target_id
.as_ref()
.is_some_and(|target_id| target_id == &candidate.id)
{
continue;
}
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete merged source memory error: {}", err))
})?;
}
}
}
for memory_id in &output.low_value_ids {
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete low value memory error: {}", err))
})?;
}
}
}
Ok(())
}
/// Builds a single-line preview of `content` for log output.
///
/// At most `max_chars` characters are kept; `...` is appended when the input was longer.
/// Newlines in the result are replaced by the two-character sequence `\n`.
fn preview_text(content: &str, max_chars: usize) -> String {
    let mut chars = content.chars();
    // `by_ref` leaves the iterator positioned just past the preview, so one more `next`
    // tells us whether anything was cut off without walking the rest of the input.
    let mut preview: String = chars.by_ref().take(max_chars).collect();
    if chars.next().is_some() {
        preview.push_str("...");
    }
    preview.replace('\n', "\\n")
}
// Unit-test module for this file; it is empty in this view of the source.
#[cfg(test)]
mod tests {}