From 9cda2ab8d5a752d2fa9dafdf791c15f380c12554 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Wed, 22 Apr 2026 11:37:15 +0800 Subject: [PATCH] =?UTF-8?q?feat(memory=5Fmanage):=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E8=AE=B0=E5=BF=86=E7=AE=A1=E7=90=86=E5=B7=A5=E5=85=B7=E6=8F=8F?= =?UTF-8?q?=E8=BF=B0=EF=BC=8C=E5=A2=9E=E5=BC=BA=E6=90=9C=E7=B4=A2=E5=85=B3?= =?UTF-8?q?=E9=94=AE=E8=AF=8D=E7=9A=84=E5=8F=8C=E8=AF=AD=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=92=8C=E5=86=85=E5=AE=B9=E5=8C=B9=E9=85=8D=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + src/agent/agent_loop.rs | 335 ++++++++++++++++++++++++++++++++++++- src/gateway/session.rs | 312 +++++++++++++++++++++++++++++++++- src/storage/mod.rs | 150 +++++++++++++++++ src/tools/memory_manage.rs | 4 +- 5 files changed, 794 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 38601bf..fd25c41 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ AGENTS.md CLAUDE.md Cargo.lock .playwright-cli/ +.venv \ No newline at end of file diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index eac4bb9..3f6ead5 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -9,6 +9,7 @@ use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Mess use crate::skills::SkillRuntime; use crate::storage::SessionStore; use crate::tools::{ToolContext, ToolRegistry}; +use serde::Deserialize; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; use std::io::Read; @@ -20,6 +21,12 @@ use std::time::Instant; const MAX_TOOL_RESULT_CHARS: usize = 16_000; /// Minimum characters to keep when truncating const TRUNCATION_SUFFIX_LEN: usize = 200; +const MEMORY_AUTOSAVE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用 memory_manage 工具维护长期记忆。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions。若需要写入,优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。搜索记忆时,优先使用 memory_manage(action='search'),并尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。"; +const MEMORY_EXTRACTION_SYSTEM_PROMPT: &str = "你负责从一小段对话中提取值得长期记忆的信息。只返回 JSON 数组,不要输出解释或 Markdown。数组元素格式为 {\"namespace\": string, \"key\": string, \"content\": string}。只有在内容属于长期偏好、稳定事实、用户纠正、持续任务上下文或明确决策时才输出;否则返回 []。namespace 只能使用 preferences、profile、tasks、decisions。key 用简短稳定的 snake_case 英文标识。不要保存一次性工具结果、临时列表、敏感信息或猜测。最多输出 2 条。"; +const MEMORY_EXTRACTION_REASONING_EFFORT: &str = "none"; +const MEMORY_EXTRACTION_MAX_TOKENS: u32 = 192; +const MEMORY_EXTRACTION_CONTEXT_MESSAGES: usize = 4; +const MEMORY_EXTRACTION_MAX_CANDIDATES: usize = 2; /// Build content blocks from text and media paths fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { @@ -221,6 +228,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message { /// AgentLoop - Stateless agent that processes messages with tool calling support. /// History is managed externally by SessionManager. pub struct AgentLoop { + provider_config: LLMProviderConfig, provider: Box, tools: Arc, skills: Arc, @@ -238,6 +246,19 @@ pub struct AgentProcessResult { pub emitted_messages: Vec, } +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +struct MemoryCandidate { + namespace: String, + key: String, + content: String, +} + +#[derive(Debug, Deserialize)] +struct MemoryCandidateEnvelope { + #[serde(default)] + memories: Vec, +} + #[async_trait] pub trait EmittedMessageHandler: Send + Sync + 'static { async fn handle(&self, message: ChatMessage); @@ -246,10 +267,11 @@ pub trait EmittedMessageHandler: Send + Sync + 'static { impl AgentLoop { pub fn new(provider_config: LLMProviderConfig) -> Result { let max_iterations = provider_config.max_tool_iterations; - let provider = create_provider(provider_config) + let provider = create_provider(provider_config.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { + provider_config, provider, tools: Arc::new(ToolRegistry::new()), skills: Arc::new(SkillRuntime::default()), @@ -264,10 +286,11 @@ impl AgentLoop { pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { let max_iterations = provider_config.max_tool_iterations; - let provider = create_provider(provider_config) + let provider = create_provider(provider_config.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { + provider_config, provider, tools, skills: Arc::new(SkillRuntime::default()), @@ -286,10 +309,11 @@ impl AgentLoop { skills: Arc, ) -> Result { let max_iterations = provider_config.max_tool_iterations; - let provider = create_provider(provider_config) + let provider = create_provider(provider_config.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { + provider_config, provider, tools, skills, @@ -342,6 +366,7 @@ impl AgentLoop { // Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut emitted_messages = Vec::new(); + let mut memory_write_occurred = false; for iteration in 0..self.max_iterations { #[cfg(debug_assertions)] @@ -352,6 +377,7 @@ impl AgentLoop { if let Some(skill_prompt) = self.skills.system_index_prompt() { messages_for_llm.push(Message::system(skill_prompt)); } + messages_for_llm.push(Message::system(MEMORY_AUTOSAVE_SYSTEM_PROMPT)); messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message)); // Build request @@ -386,6 +412,9 @@ impl AgentLoop { // If no tool calls, this is the final response if response.tool_calls.is_empty() { let assistant_message = ChatMessage::assistant(response.content); + if !memory_write_occurred { + self.maybe_extract_and_store_memories(&messages, &assistant_message).await?; + } emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, @@ -407,6 +436,14 @@ impl AgentLoop { // Execute tools and add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; + if response + .tool_calls + .iter() + .zip(tool_results.iter()) + .any(|(tool_call, result)| did_successfully_write_memory(tool_call, result)) + { + memory_write_occurred = true; + } for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { // Log function call with name and arguments @@ -470,6 +507,7 @@ impl AgentLoop { if let Some(skill_prompt) = self.skills.system_index_prompt() { messages_for_llm.push(Message::system(skill_prompt)); } + messages_for_llm.push(Message::system(MEMORY_AUTOSAVE_SYSTEM_PROMPT)); messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message)); let request = ChatCompletionRequest { @@ -482,6 +520,9 @@ impl AgentLoop { match (*self.provider).chat(request).await { Ok(response) => { let assistant_message = ChatMessage::assistant(response.content); + if !memory_write_occurred { + self.maybe_extract_and_store_memories(&messages, &assistant_message).await?; + } emitted_messages.push(assistant_message.clone()); Ok(AgentProcessResult { final_response: assistant_message, @@ -698,6 +739,227 @@ impl AgentLoop { } } +fn did_successfully_write_memory(tool_call: &ToolCall, result: &ToolExecutionOutcome) -> bool { + if !result.success || tool_call.name != "memory_manage" { + return false; + } + + matches!( + tool_call.arguments.get("action").and_then(|value| value.as_str()), + Some("put") | Some("update") + ) +} + +fn lightweight_provider_config(provider_config: &LLMProviderConfig) -> LLMProviderConfig { + let mut config = provider_config.clone(); + if config.provider_type == "openai" { + config.model_extra.insert( + "reasoning_effort".to_string(), + serde_json::Value::String(MEMORY_EXTRACTION_REASONING_EFFORT.to_string()), + ); + } + config +} + +fn build_memory_extraction_context(messages: &[ChatMessage], final_response: &ChatMessage) -> String { + let recent_messages: Vec<&ChatMessage> = messages + .iter() + .rev() + .filter(|message| matches!(message.role.as_str(), "user" | "assistant")) + .take(MEMORY_EXTRACTION_CONTEXT_MESSAGES) + .collect::>() + .into_iter() + .rev() + .collect(); + + let mut lines = vec!["请判断这轮对话中是否有值得长期记忆的信息:".to_string()]; + for message in recent_messages { + let content = message.content.trim(); + if content.is_empty() { + continue; + } + + lines.push(format!( + "{}: {}", + message.role, + truncate_text_for_memory_extraction(content) + )); + } + + let final_content = final_response.content.trim(); + if !final_content.is_empty() { + lines.push(format!( + "assistant_final: {}", + truncate_text_for_memory_extraction(final_content) + )); + } + + lines.join("\n") +} + +fn truncate_text_for_memory_extraction(text: &str) -> String { + if text.chars().count() <= 240 { + return text.to_string(); + } + + let prefix: String = text.chars().take(240).collect(); + format!("{}...", prefix) +} + +fn parse_memory_candidates(raw: &str) -> Vec { + let trimmed = raw.trim(); + let json_payload = trimmed + .strip_prefix("```json") + .and_then(|value| value.strip_suffix("```")) + .map(str::trim) + .or_else(|| { + trimmed + .strip_prefix("```") + .and_then(|value| value.strip_suffix("```")) + .map(str::trim) + }) + .unwrap_or(trimmed); + + if let Ok(candidates) = serde_json::from_str::>(json_payload) { + return normalize_memory_candidates(candidates); + } + + if let Ok(envelope) = serde_json::from_str::(json_payload) { + return normalize_memory_candidates(envelope.memories); + } + + Vec::new() +} + +fn normalize_memory_candidates(candidates: Vec) -> Vec { + let mut normalized = Vec::new(); + + for candidate in candidates { + let Some(namespace) = canonical_memory_namespace(&candidate.namespace) else { + continue; + }; + let key = canonical_memory_key(&candidate.key); + let content = candidate.content.trim().to_string(); + if key.is_empty() || content.len() < 8 { + continue; + } + + normalized.push(MemoryCandidate { + namespace, + key, + content, + }); + + if normalized.len() >= MEMORY_EXTRACTION_MAX_CANDIDATES { + break; + } + } + + normalized +} + +fn canonical_memory_namespace(namespace: &str) -> Option { + match namespace.trim().to_lowercase().as_str() { + "preferences" | "preference" => Some("preferences".to_string()), + "profile" | "profiles" => Some("profile".to_string()), + "tasks" | "task" => Some("tasks".to_string()), + "decisions" | "decision" => Some("decisions".to_string()), + _ => None, + } +} + +fn canonical_memory_key(key: &str) -> String { + let mut normalized = String::new(); + let mut prev_is_sep = false; + + for ch in key.trim().chars() { + let mapped = if ch.is_ascii_alphanumeric() { + prev_is_sep = false; + Some(ch.to_ascii_lowercase()) + } else if matches!(ch, '_' | '-' | '/' | ' ') { + if prev_is_sep { + None + } else { + prev_is_sep = true; + Some('_') + } + } else { + None + }; + + if let Some(ch) = mapped { + normalized.push(ch); + } + } + + normalized.trim_matches('_').to_string() +} + +impl AgentLoop { + async fn maybe_extract_and_store_memories( + &self, + messages: &[ChatMessage], + final_response: &ChatMessage, + ) -> Result { + let Some(tool) = self.tools.get("memory_manage") else { + return Ok(0); + }; + if self.tool_context.channel_name.is_none() || self.tool_context.sender_id.is_none() { + return Ok(0); + } + + let context = build_memory_extraction_context(messages, final_response); + let provider = create_provider(lightweight_provider_config(&self.provider_config)) + .map_err(|err| AgentError::ProviderCreation(err.to_string()))?; + let request = ChatCompletionRequest { + messages: vec![ + Message::system(MEMORY_EXTRACTION_SYSTEM_PROMPT), + Message::user(context), + ], + temperature: Some(0.0), + max_tokens: Some(MEMORY_EXTRACTION_MAX_TOKENS), + tools: None, + }; + + let response = provider + .chat(request) + .await + .map_err(|err| AgentError::LlmError(err.to_string()))?; + let candidates = parse_memory_candidates(&response.content); + if candidates.is_empty() { + tracing::debug!("No auto-save memory candidates extracted from final response"); + return Ok(0); + } + + tracing::info!(candidate_count = candidates.len(), candidates = ?candidates, "Auto-save memory candidates extracted"); + + let mut saved = 0; + for candidate in candidates { + let result = tool + .execute_with_context( + &self.tool_context, + serde_json::json!({ + "action": "put", + "namespace": candidate.namespace, + "key": candidate.key, + "content": candidate.content, + }), + ) + .await + .map_err(|err| AgentError::Other(format!("auto-save memory tool error: {}", err)))?; + + if result.success { + saved += 1; + } else { + tracing::warn!(error = result.error.as_deref().unwrap_or("unknown"), "Auto-save memory write failed"); + } + } + + tracing::info!(saved_count = saved, "Auto-save memory extraction completed"); + Ok(saved) + } +} + #[cfg(test)] mod tests { use super::*; @@ -773,6 +1035,73 @@ mod tests { assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1"); assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); } + + #[test] + fn test_did_successfully_write_memory_only_accepts_successful_put_or_update() { + let tool_call = ToolCall { + id: "call_1".to_string(), + name: "memory_manage".to_string(), + arguments: serde_json::json!({ "action": "put" }), + }; + + assert!(did_successfully_write_memory( + &tool_call, + &ToolExecutionOutcome::success("ok".to_string()) + )); + + let failed = ToolExecutionOutcome::failure("err".to_string(), Some("boom".to_string())); + assert!(!did_successfully_write_memory(&tool_call, &failed)); + + let search_call = ToolCall { + id: "call_2".to_string(), + name: "memory_manage".to_string(), + arguments: serde_json::json!({ "action": "search" }), + }; + assert!(!did_successfully_write_memory( + &search_call, + &ToolExecutionOutcome::success("ok".to_string()) + )); + } + + #[test] + fn test_parse_memory_candidates_normalizes_and_limits_results() { + let raw = r#"```json + {"memories": [ + {"namespace": "preference", "key": "Email Folder Preference", "content": "用户提到邮件时默认查看代收邮箱而不是收件箱。"}, + {"namespace": "decision", "key": "mailbox strategy", "content": "后续默认先查看代收邮箱。"}, + {"namespace": "tasks", "key": "short", "content": "太短"} + ]} + ```"#; + + let candidates = parse_memory_candidates(raw); + assert_eq!(candidates.len(), 2); + assert_eq!(candidates[0].namespace, "preferences"); + assert_eq!(candidates[0].key, "email_folder_preference"); + assert_eq!(candidates[1].namespace, "decisions"); + assert_eq!(candidates[1].key, "mailbox_strategy"); + } + + #[test] + fn test_build_memory_extraction_context_uses_recent_user_and_assistant_messages() { + let messages = vec![ + ChatMessage::system("system"), + ChatMessage::tool("call-1", "calculator", "2"), + ChatMessage::user("first user"), + ChatMessage::assistant("first assistant"), + ChatMessage::user("second user"), + ChatMessage::assistant("second assistant"), + ChatMessage::user("third user"), + ]; + let final_response = ChatMessage::assistant("final assistant"); + + let context = build_memory_extraction_context(&messages, &final_response); + assert!(context.contains("user: first user") == false); + assert!(context.contains("user: second user")); + assert!(context.contains("assistant: second assistant")); + assert!(context.contains("user: third user")); + assert!(context.contains("assistant_final: final assistant")); + assert!(!context.contains("tool:")); + } } #[derive(Debug)] diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 4008e20..756dd2c 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fs; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -8,9 +8,10 @@ use uuid::Uuid; use crate::bus::{ChatMessage, MessageBus, OutboundMessage}; use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler}; +use crate::providers::{create_provider, ChatCompletionRequest, Message}; use crate::protocol::WsOutbound; use crate::skills::SkillRuntime; -use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; +use crate::storage::{MemoryRecord, SessionRecord, SessionStore, persistent_session_id}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, MemoryManageTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry, @@ -18,6 +19,10 @@ use crate::tools::{ }; const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot,一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n"; +const MEMORY_KEYWORD_SYSTEM_PROMPT: &str = "你负责为长期记忆检索生成关键词。根据给定的最近会话,仅输出 JSON 数组字符串。关键词必须同时覆盖中文和英文:优先为同一主题同时给出中文关键词和对应英文关键词。关键词必须是短词语,优先使用最容易命中记忆的核心检索词,不要输出完整句子、解释或长描述。必要时优先保留实体名、产品名、偏好名、snake_case key 风格短词。数组元素总数控制在 2 到 6 个简短关键词或短语。不要输出解释,不要输出 Markdown。"; +const RELATED_MEMORY_SYSTEM_PROMPT_PREFIX: &str = "找到相关的记忆。你必须优先参考这些记忆,并在后续推理中把它们当作当前会话的补充上下文;若与用户本轮明确要求冲突,以用户本轮要求为准。"; +const MEMORY_KEYWORD_REASONING_EFFORT: &str = "none"; +const MEMORY_KEYWORD_MAX_CHARS: usize = 32; /// Session 按 channel 隔离,每个 channel 一个 Session /// History 按 chat_id 隔离,由 Session 统一管理 @@ -622,7 +627,42 @@ impl SessionManager { session_guard.append_persisted_message(chat_id, user_message)?; // 获取完整历史 - let history = session_guard.get_or_create_history(chat_id).clone(); + let mut history = session_guard.get_or_create_history(chat_id).clone(); + tracing::info!( + channel = %channel_name, + chat_id = %chat_id, + sender_id = %sender_id, + history_len = history.len(), + "Starting synchronous related memory search" + ); + if let Some(memory_prompt) = build_related_memory_prompt( + session_guard.provider_config().clone(), + self.store.clone(), + channel_name.to_string(), + sender_id.to_string(), + chat_id.to_string(), + history.clone(), + ) + .await? + { + tracing::info!( + channel = %channel_name, + chat_id = %chat_id, + sender_id = %sender_id, + prompt_len = memory_prompt.len(), + "Injecting related memory system prompt before agent processing" + ); + let memory_message = ChatMessage::system(memory_prompt); + session_guard.append_persisted_message(chat_id, memory_message.clone())?; + history.push(memory_message); + } else { + tracing::info!( + channel = %channel_name, + chat_id = %chat_id, + sender_id = %sender_id, + "No related memory prompt generated before agent processing" + ); + } // 压缩历史(如果需要) let history = session_guard.compressor @@ -678,6 +718,233 @@ impl SessionManager { } } +async fn build_related_memory_prompt( + provider_config: LLMProviderConfig, + store: Arc, + channel_name: String, + sender_id: String, + chat_id: String, + history: Vec, +) -> Result, AgentError> { + let keywords = generate_memory_search_keywords(provider_config, &history).await?; + tracing::info!( + channel = %channel_name, + chat_id = %chat_id, + sender_id = %sender_id, + keyword_count = keywords.len(), + keywords = ?keywords, + "Generated memory search keywords" + ); + if keywords.is_empty() { + return Ok(None); + } + + let memories = search_related_memories( + store, + &channel_name, + &sender_id, + &chat_id, + &keywords, + ) + .await?; + + if memories.is_empty() { + tracing::info!( + channel = %channel_name, + chat_id = %chat_id, + sender_id = %sender_id, + keyword_count = keywords.len(), + "Related memory search returned no matches" + ); + return Ok(None); + } + + tracing::info!( + channel = %channel_name, + chat_id = %chat_id, + sender_id = %sender_id, + keyword_count = keywords.len(), + memory_count = memories.len(), + "Related memory search produced matches" + ); + + Ok(Some(format_related_memory_system_prompt(&keywords, &memories))) +} + +async fn generate_memory_search_keywords( + mut provider_config: LLMProviderConfig, + history: &[ChatMessage], +) -> Result, AgentError> { + if provider_config.provider_type == "openai" { + provider_config.model_extra.insert( + "reasoning_effort".to_string(), + serde_json::Value::String(MEMORY_KEYWORD_REASONING_EFFORT.to_string()), + ); + } + + let provider = create_provider(provider_config) + .map_err(|err| AgentError::ProviderCreation(err.to_string()))?; + + let request = ChatCompletionRequest { + messages: vec![ + Message::system(MEMORY_KEYWORD_SYSTEM_PROMPT), + Message::user(build_memory_keyword_context(history)), + ], + temperature: Some(0.0), + max_tokens: Some(128), + tools: None, + }; + + let response = provider + .chat(request) + .await + .map_err(|err| AgentError::LlmError(err.to_string()))?; + + Ok(parse_memory_keywords(&response.content)) +} + +fn build_memory_keyword_context(history: &[ChatMessage]) -> String { + let recent_messages: Vec<&ChatMessage> = history + .iter() + .rev() + .filter(|message| message.role != "system") + .take(8) + .collect::>() + .into_iter() + .rev() + .collect(); + + let mut lines = vec!["请基于以下最近会话生成长期记忆搜索关键词:".to_string()]; + for message in recent_messages { + let content = message.content.trim(); + if content.is_empty() { + continue; + } + + let compact = if content.chars().count() > 240 { + let prefix: String = content.chars().take(240).collect(); + format!("{}...", prefix) + } else { + content.to_string() + }; + lines.push(format!("{}: {}", message.role, compact)); + } + + lines.join("\n") +} + +fn parse_memory_keywords(raw: &str) -> Vec { + if let Ok(keywords) = serde_json::from_str::>(raw) { + return normalize_keywords(keywords); + } + + normalize_keywords( + raw.split(|ch| matches!(ch, '\n' | ',' | ',' | ';' | ';')) + .map(str::trim) + .filter(|part| !part.is_empty()) + .map(ToOwned::to_owned) + .collect(), + ) +} + +fn normalize_keywords(keywords: Vec) -> Vec { + let mut seen = HashSet::new(); + let mut normalized = Vec::new(); + + for keyword in keywords { + let candidate = keyword + .trim() + .trim_matches('"') + .trim_matches('[') + .trim_matches(']') + .trim() + .to_string(); + let candidate = compact_memory_keyword(&candidate); + if candidate.is_empty() { + continue; + } + + let key = candidate.to_lowercase(); + if seen.insert(key) { + normalized.push(candidate); + } + + if normalized.len() >= 6 { + break; + } + } + + normalized +} + +fn compact_memory_keyword(candidate: &str) -> String { + let compact = candidate + .split_whitespace() + .next() + .unwrap_or(candidate) + .trim_matches(|ch: char| matches!(ch, '"' | '\'' | '[' | ']' | '。' | ',' | ',' | ';' | ';' | ':' | ':')) + .trim(); + + if compact.is_empty() { + return String::new(); + } + + compact.chars().take(MEMORY_KEYWORD_MAX_CHARS).collect() +} + +async fn search_related_memories( + store: Arc, + channel_name: &str, + sender_id: &str, + chat_id: &str, + keywords: &[String], +) -> Result, AgentError> { + tracing::debug!( + channel = %channel_name, + chat_id = %chat_id, + sender_id = %sender_id, + keyword_count = keywords.len(), + keywords = ?keywords, + "Searching related memories with a single batched FTS query" + ); + + let scope_key = format!("{}:{}", channel_name, sender_id); + let merged = store + .search_memories_any("user", &scope_key, keywords, None, 10) + .map_err(|err| AgentError::Other(format!("batched memory search error: {}", err)))?; + + tracing::info!( + channel = %channel_name, + chat_id = %chat_id, + sender_id = %sender_id, + keyword_count = keywords.len(), + deduped_memory_count = merged.len(), + "Finished related memory search aggregation" + ); + + Ok(merged) +} + +fn format_related_memory_system_prompt(keywords: &[String], memories: &[MemoryRecord]) -> String { + let mut lines = vec![ + RELATED_MEMORY_SYSTEM_PROMPT_PREFIX.to_string(), + format!("检索关键词: {}", keywords.join(" / ")), + "相关记忆: ".to_string(), + ]; + + for (index, memory) in memories.iter().take(8).enumerate() { + lines.push(format!( + "{}. [{} / {}] {}", + index + 1, + memory.namespace, + memory.memory_key, + memory.content.replace('\n', " ").trim() + )); + } + + lines.join("\n") +} + #[cfg(test)] mod tests { use super::*; @@ -893,4 +1160,43 @@ mod tests { assert_eq!(history.len(), 1); assert_eq!(history[0].role, "system"); } + + #[test] + fn test_parse_memory_keywords_handles_json_and_dedup() { + let keywords = parse_memory_keywords("[\"Rust\", \"偏好\", \"rust\", \"自动化\"]"); + assert_eq!(keywords, vec!["Rust", "偏好", "自动化"]); + + let fallback = parse_memory_keywords("Rust, 偏好\n自动化"); + assert_eq!(fallback, vec!["Rust", "偏好", "自动化"]); + + let compacted = parse_memory_keywords("[\"用户 身份 信息 长描述\", \"email_folder_preference details\"]"); + assert_eq!(compacted, vec!["用户", "email_folder_preference"]); + } + + #[test] + fn test_format_related_memory_system_prompt_includes_prefix_keywords_and_memory_lines() { + let prompt = format_related_memory_system_prompt( + &["Rust 偏好".to_string(), "审批项目".to_string()], + &[MemoryRecord { + id: "memory-1".to_string(), + scope_kind: "user".to_string(), + scope_key: "feishu:user-1".to_string(), + namespace: "profile".to_string(), + memory_key: "language".to_string(), + content: "用户偏好 Rust 和自动化工具".to_string(), + source_type: "message".to_string(), + source_session_id: Some("feishu:chat-1".to_string()), + source_message_id: Some("msg-1".to_string()), + source_message_seq: Some(1), + source_channel_name: Some("feishu".to_string()), + source_chat_id: Some("chat-1".to_string()), + created_at: 1, + updated_at: 1, + }], + ); + + assert!(prompt.contains("找到相关的记忆")); + assert!(prompt.contains("Rust 偏好 / 审批项目")); + assert!(prompt.contains("[profile / language] 用户偏好 Rust 和自动化工具")); + } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index dab0816..459148e 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -774,6 +774,67 @@ impl SessionStore { Ok(memories) } + pub fn search_memories_any( + &self, + scope_kind: &str, + scope_key: &str, + queries: &[String], + namespace: Option<&str>, + limit: usize, + ) -> Result, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let limit = limit.max(1) as i64; + let query = quote_fts_or_query(queries); + if query.is_empty() { + return Ok(Vec::new()); + } + + let mut memories = Vec::new(); + + if let Some(namespace) = namespace { + let mut stmt = conn.prepare( + " + SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content, + m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq, + m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at + FROM memories_fts f + JOIN memories m ON m.rowid = f.rowid + WHERE memories_fts MATCH ?1 + AND m.scope_kind = ?2 + AND m.scope_key = ?3 + AND m.namespace = ?4 + ORDER BY bm25(memories_fts), m.updated_at DESC + LIMIT ?5 + ", + )?; + let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?; + for row in rows { + memories.push(row?); + } + } else { + let mut stmt = conn.prepare( + " + SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content, + m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq, + m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at + FROM memories_fts f + JOIN memories m ON m.rowid = f.rowid + WHERE memories_fts MATCH ?1 + AND m.scope_kind = ?2 + AND m.scope_key = ?3 + ORDER BY bm25(memories_fts), m.updated_at DESC + LIMIT ?4 + ", + )?; + let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?; + for row in rows { + memories.push(row?); + } + } + + Ok(memories) + } + pub fn load_messages(&self, session_id: &str) -> Result, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let cutoff_seq = active_reset_cutoff(&conn, session_id)?; @@ -991,6 +1052,16 @@ fn quote_fts_query(query: &str) -> String { format!("\"{}\"", query.replace('"', "\"\"")) } +fn quote_fts_or_query(queries: &[String]) -> String { + queries + .iter() + .map(|query| query.trim()) + .filter(|query| !query.is_empty()) + .map(quote_fts_query) + .collect::>() + .join(" OR ") +} + #[cfg(test)] mod tests { use super::*; @@ -1352,4 +1423,83 @@ mod tests { .unwrap(); assert!(hits_after_delete.is_empty()); } + + #[test] + fn test_memory_search_matches_memory_key_field() { + let store = SessionStore::in_memory().unwrap(); + + store + .put_memory(&MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: "feishu:user-1".to_string(), + namespace: "preferences".to_string(), + memory_key: "email_folder_preference".to_string(), + content: "用户提到邮件时默认查看代收邮箱。".to_string(), + source_type: "message".to_string(), + source_session_id: Some("feishu:chat-8".to_string()), + source_message_id: Some("msg-8".to_string()), + source_message_seq: Some(8), + source_channel_name: Some("feishu".to_string()), + source_chat_id: Some("chat-8".to_string()), + }) + .unwrap(); + + let hits = store + .search_memories("user", "feishu:user-1", "email_folder_preference", None, 10) + .unwrap(); + + assert_eq!(hits.len(), 1); + assert_eq!(hits[0].memory_key, "email_folder_preference"); + } + + #[test] + fn test_search_memories_any_matches_multiple_keywords_once() { + let store = SessionStore::in_memory().unwrap(); + + store + .put_memory(&MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: "feishu:user-1".to_string(), + namespace: "preferences".to_string(), + memory_key: "editor".to_string(), + content: "Prefers rust-analyzer and cargo test output".to_string(), + source_type: "message".to_string(), + source_session_id: Some("feishu:chat-2".to_string()), + source_message_id: Some("msg-2".to_string()), + source_message_seq: Some(3), + source_channel_name: Some("feishu".to_string()), + source_chat_id: Some("chat-2".to_string()), + }) + .unwrap(); + + store + .put_memory(&MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: "feishu:user-1".to_string(), + namespace: "tasks".to_string(), + memory_key: "quality".to_string(), + content: "Tracks clippy warnings before release".to_string(), + source_type: "message".to_string(), + source_session_id: Some("feishu:chat-3".to_string()), + source_message_id: Some("msg-3".to_string()), + source_message_seq: Some(4), + source_channel_name: Some("feishu".to_string()), + source_chat_id: Some("chat-3".to_string()), + }) + .unwrap(); + + let hits = store + .search_memories_any( + "user", + "feishu:user-1", + &["rust-analyzer".to_string(), "clippy".to_string()], + None, + 10, + ) + .unwrap(); + + assert_eq!(hits.len(), 2); + assert!(hits.iter().any(|memory| memory.memory_key == "editor")); + assert!(hits.iter().any(|memory| memory.memory_key == "quality")); + } } \ No newline at end of file diff --git a/src/tools/memory_manage.rs b/src/tools/memory_manage.rs index 94eceba..a436061 100644 --- a/src/tools/memory_manage.rs +++ b/src/tools/memory_manage.rs @@ -23,7 +23,7 @@ impl Tool for MemoryManageTool { } fn description(&self) -> &str { - "Manage user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Use search first when looking for user preferences, historical facts, prior decisions, or previously stored information by keyword. Memories are scoped to the current channel and sender, and record the originating session/message when available." + "Manage user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Use search first when looking for user preferences, historical facts, prior decisions, or previously stored information. Search matches namespace, memory_key, and content. When searching, prefer bilingual queries that include both Chinese and English aliases, and include likely snake_case key terms when known. Memories are scoped to the current channel and sender, and record the originating session/message when available." } fn parameters_schema(&self) -> serde_json::Value { @@ -41,7 +41,7 @@ impl Tool for MemoryManageTool { }, "query": { "type": "string", - "description": "Keyword query for full-text memory search, such as a preference, fact, name, topic, or prior decision" + "description": "Keyword query for full-text memory search across namespace, memory_key, and content. Prefer concise bilingual keywords when possible, for example Chinese plus English aliases and likely snake_case key terms." }, "key": { "type": "string",