From c602a0695d3009221ca157341d9d8836052bd8f1 Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Thu, 7 May 2026 23:32:59 +0800 Subject: [PATCH] feat: add memory system with FTS5 search and context compression integration --- src/agent/agent_loop.rs | 2 +- src/agent/context_compressor.rs | 42 +++++++ src/agent/system_prompt.rs | 68 ++++++++++- src/config/mod.rs | 36 ++++++ src/gateway/mod.rs | 18 ++- src/lib.rs | 1 + src/memory/mod.rs | 199 ++++++++++++++++++++++++++++++++ src/memory/types.rs | 90 +++++++++++++++ src/session/session.rs | 66 +++++++++-- src/storage/memory.rs | 183 +++++++++++++++++++++++++++++ src/storage/mod.rs | 68 +++++++++-- src/storage/session.rs | 1 + src/tools/chat_manager.rs | 2 + src/tools/memory_forget.rs | 60 ++++++++++ src/tools/memory_recall.rs | 118 +++++++++++++++++++ src/tools/memory_store.rs | 90 +++++++++++++++ src/tools/mod.rs | 20 +++- 17 files changed, 1044 insertions(+), 20 deletions(-) create mode 100644 src/memory/mod.rs create mode 100644 src/memory/types.rs create mode 100644 src/storage/memory.rs create mode 100644 src/tools/memory_forget.rs create mode 100644 src/tools/memory_recall.rs create mode 100644 src/tools/memory_store.rs diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 42a0d2c..9dc6d3a 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -341,7 +341,7 @@ impl AgentLoop { // Build and inject system prompt if not present let has_system = messages.first().map_or(false, |m| m.role == "system"); if !has_system { - let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None); + let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None); #[cfg(debug_assertions)] tracing::debug!("System prompt injected:\n{}", system_prompt); messages.insert(0, ChatMessage::system(system_prompt)); diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index f27d9dc..4417720 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use crate::bus::ChatMessage; +use crate::memory::MemoryManager; use crate::providers::{ChatCompletionRequest, LLMProvider, Message}; use crate::agent::AgentError; @@ -49,6 +50,11 @@ pub struct ContextCompressor { threshold_ratio: f64, /// Shared LLM provider for summarization provider: Arc, + /// Memory manager handle (optional). When set, compressed + /// context summaries are persisted as timeline memory entries. + memory: Option>, + /// Current session ID for timeline memory writes. + session_id: Option, } impl ContextCompressor { @@ -59,6 +65,8 @@ impl ContextCompressor { context_window, threshold_ratio: 0.5, provider, + memory: None, + session_id: None, } } @@ -73,9 +81,22 @@ impl ContextCompressor { context_window, threshold_ratio: 0.5, provider, + memory: None, + session_id: None, } } + /// Attach a memory manager to persist compressed summaries. + pub fn with_memory(mut self, memory: Arc) -> Self { + self.memory = Some(memory); + self + } + + /// Set the current session ID for timeline writes. + pub fn set_session_id(&mut self, id: Option) { + self.session_id = id; + } + /// Get the compression threshold in tokens. fn threshold(&self) -> usize { (self.context_window as f64 * self.threshold_ratio) as usize @@ -218,6 +239,27 @@ impl ContextCompressor { let between = &history[between_start..between_end]; let summary = self.summarize_segment(between).await?; + // Persist compressed summary as timeline memory entry + if let Some(ref mm) = self.memory { + let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); + let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", + ts, between.len(), summary); + let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); + let mm_clone = mm.clone(); + let sid = self.session_id.clone(); + tokio::spawn(async move { + if let Err(e) = mm_clone.store( + &key, + &timeline_content, + crate::memory::MemoryCategory::Timeline, + sid.as_deref(), + Some(0.3), + ).await { + tracing::warn!(error = %e, "Failed to store compressed context as timeline"); + } + }); + } + // Add summary as a special user message new_messages.push(ChatMessage::user(format!( "[Context Summary]\n\n{}", diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs index de63f93..78c2a51 100644 --- a/src/agent/system_prompt.rs +++ b/src/agent/system_prompt.rs @@ -19,6 +19,8 @@ pub struct PromptContext<'a> { pub model_name: &'a str, pub tools: &'a ToolRegistry, pub session_id: Option<&'a str>, + /// Pre-fetched memory context string to inject. + pub memory_context: Option<&'a str>, } /// Trait for system prompt sections. @@ -43,6 +45,7 @@ impl SystemPromptBuilder { Box::new(SafetySection), Box::new(WorkspaceSection), Box::new(UserProfileSection), + Box::new(MemorySection), Box::new(DateTimeSection), Box::new(RuntimeSection), Box::new(CrossChannelSection), @@ -284,6 +287,24 @@ impl PromptSection for RuntimeSection { } } +/// Injects relevant knowledge memories into the system prompt. +pub struct MemorySection; + +impl PromptSection for MemorySection { + fn name(&self) -> &str { + "memory" + } + + fn build(&self, ctx: &PromptContext<'_>) -> String { + match ctx.memory_context { + Some(context) if !context.is_empty() => { + format!("## 记忆上下文\n\n{}", context) + } + _ => String::new(), + } + } +} + // === Helper Functions === /// Get user config directory (~/.picobot/). @@ -321,12 +342,19 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option) -> String { +pub fn build_system_prompt( + workspace_dir: &Path, + model_name: &str, + tools: &ToolRegistry, + session_id: Option<&str>, + memory_context: Option<&str>, +) -> String { let ctx = PromptContext { workspace_dir, model_name, tools, session_id, + memory_context, }; SystemPromptBuilder::with_defaults().build(&ctx) } @@ -346,6 +374,7 @@ mod tests { model_name: "test-model", tools: &tools, session_id: None, + memory_context: None, }; let prompt = SystemPromptBuilder::with_defaults().build(&ctx); @@ -375,9 +404,44 @@ mod tests { let temp_dir = std::env::temp_dir(); let tools = ToolRegistry::new(); - let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None); + let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None, None); assert!(!prompt.is_empty()); assert!(prompt.contains("test-model")); } + + #[test] + fn test_memory_section_with_context() { + let temp_dir = std::env::temp_dir(); + let tools = ToolRegistry::new(); + + let ctx = PromptContext { + workspace_dir: &temp_dir, + model_name: "test", + tools: &tools, + session_id: None, + memory_context: Some("- user_pref: Prefers Rust"), + }; + + let prompt = SystemPromptBuilder::with_defaults().build(&ctx); + assert!(prompt.contains("## 记忆上下文")); + assert!(prompt.contains("Prefers Rust")); + } + + #[test] + fn test_memory_section_without_context() { + let temp_dir = std::env::temp_dir(); + let tools = ToolRegistry::new(); + + let ctx = PromptContext { + workspace_dir: &temp_dir, + model_name: "test", + tools: &tools, + session_id: None, + memory_context: None, + }; + + let prompt = SystemPromptBuilder::with_defaults().build(&ctx); + assert!(!prompt.contains("## 记忆上下文")); + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 9d965cf..88dfc23 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -49,6 +49,8 @@ pub struct Config { pub client: ClientConfig, #[serde(default)] pub channels: HashMap, + #[serde(default)] + pub memory: MemoryConfig, #[serde(default = "default_workspace_dir")] pub workspace_dir: String, } @@ -218,6 +220,40 @@ impl Default for ClientConfig { } } +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct MemoryConfig { + /// Master switch for the memory system. + #[serde(default)] + pub enabled: bool, + /// Provider name for consolidation LLM calls (key in `providers`). + #[serde(default)] + pub consolidation_provider: Option, + /// Model name for consolidation LLM calls (key in `models`). + #[serde(default)] + pub consolidation_model: Option, + /// Max knowledge entries injected into system prompt per turn. + #[serde(default = "default_recall_limit")] + pub recall_limit: usize, + /// Number of turns without consolidation before forcing one. + #[serde(default = "default_consolidation_turn_threshold")] + pub consolidation_turn_threshold: usize, + /// Idle minutes before triggering consolidation (for async channels). + #[serde(default = "default_idle_consolidation_minutes")] + pub idle_consolidation_minutes: u64, + /// Days before timeline entries are auto-cleaned. + #[serde(default = "default_timeline_retention_days")] + pub timeline_retention_days: u64, + /// Consecutive consolidation failures before degrading to raw archive. + #[serde(default = "default_max_failures_before_degrade")] + pub max_failures_before_degrade: usize, +} + +fn default_recall_limit() -> usize { 5 } +fn default_consolidation_turn_threshold() -> usize { 3 } +fn default_idle_consolidation_minutes() -> u64 { 10 } +fn default_timeline_retention_days() -> u64 { 90 } +fn default_max_failures_before_degrade() -> usize { 3 } + #[derive(Debug, Clone)] pub struct LLMProviderConfig { pub provider_type: String, diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e4ea877..66ef39e 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,6 +10,7 @@ use crate::channels::{ChannelManager, CliChatChannel}; use crate::channels::base::{Channel, ChannelError}; use crate::config::{Config, expand_path, ensure_workspace_dir}; use crate::logging; +use crate::memory::MemoryManager; use crate::session::SessionManager; use crate::scheduler::Scheduler; @@ -55,11 +56,26 @@ impl GatewayState { ); tracing::info!("Session storage: {}", db_path.display()); + // Initialize MemoryManager if memory system is enabled + let memory_manager = if config.memory.enabled { + let mm = Arc::new(MemoryManager::new(storage.clone())); + tracing::info!("Memory system enabled"); + Some(mm) + } else { + None + }; + // Create MessageBus first (shared by SessionManager and ChannelManager) let bus = MessageBus::new(100); // Create SessionManager with bus injection - let session_manager = SessionManager::new(session_ttl_hours, provider_config.clone(), storage.clone(), bus.clone())?; + let session_manager = SessionManager::new( + session_ttl_hours, + provider_config.clone(), + storage.clone(), + bus.clone(), + memory_manager, + )?; let session_manager = Arc::new(session_manager); // Start background cleanup task (default 60 minutes) diff --git a/src/lib.rs b/src/lib.rs index 218460a..21c069a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ pub mod client; pub mod protocol; pub mod channels; pub mod logging; +pub mod memory; pub mod observability; pub mod scheduler; pub mod skills; diff --git a/src/memory/mod.rs b/src/memory/mod.rs new file mode 100644 index 0000000..651e6f9 --- /dev/null +++ b/src/memory/mod.rs @@ -0,0 +1,199 @@ +pub mod types; + +use std::sync::Arc; +use uuid::Uuid; + +use crate::storage::Storage; +pub use types::{ConsolidationFact, ConsolidationResult, MemoryCategory, MemoryEntry}; + +/// MemoryManager provides high-level memory operations. +/// Wraps the Storage SQLite layer with semantic methods. +#[derive(Clone)] +pub struct MemoryManager { + storage: Arc, +} + +impl MemoryManager { + pub fn new(storage: Arc) -> Self { + Self { storage } + } + + /// Store or update a memory entry. Generates timestamp and UUID. + pub async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + importance: Option, + ) -> Result<(), crate::storage::StorageError> { + let now = chrono::Utc::now().to_rfc3339(); + let entry = MemoryEntry { + id: Uuid::new_v4().to_string(), + key: key.to_string(), + content: content.to_string(), + category, + importance: importance.unwrap_or(0.5), + session_id: session_id.map(|s| s.to_string()), + created_at: now.clone(), + updated_at: now, + }; + self.storage.upsert_memory(&entry).await + } + + /// Search memories by keyword query. Returns entries sorted by relevance. + pub async fn recall( + &self, + query: &str, + limit: usize, + category: Option, + ) -> Result, crate::storage::StorageError> { + self.storage + .search_memories(query, category.as_ref(), limit) + .await + } + + /// Search memories by time range (Unix milliseconds). + pub async fn recall_by_time( + &self, + since: i64, + until: i64, + limit: usize, + category: Option, + ) -> Result, crate::storage::StorageError> { + self.storage + .search_memories_by_time(since, until, category.as_ref(), limit) + .await + } + + /// Delete a memory entry by key. + pub async fn forget(&self, key: &str) -> Result<(), crate::storage::StorageError> { + self.storage.delete_memory(key).await + } + + /// Check if the memory system has any entries (for testing/health check). + pub async fn is_empty(&self) -> Result { + self.recall("*", 1, None).await.map(|r| r.is_empty()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + use tempfile::tempdir; + + async fn setup_memory_manager() -> (Arc, tempfile::TempDir) { + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test.db"); + let storage = Arc::new(Storage::new(&db_path).await.unwrap()); + let mm = Arc::new(MemoryManager::new(storage)); + (mm, dir) + } + + #[tokio::test] + async fn test_store_and_recall() { + let (mm, _dir) = setup_memory_manager().await; + + mm.store( + "test_key", + "This is a test memory", + MemoryCategory::Knowledge, + None, + Some(0.8), + ) + .await + .unwrap(); + + let results = mm.recall("test memory", 10, None).await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "test_key"); + assert_eq!(results[0].content, "This is a test memory"); + assert!((results[0].importance - 0.8).abs() < 0.01); + } + + #[tokio::test] + async fn test_upsert_overwrites() { + let (mm, _dir) = setup_memory_manager().await; + + mm.store( + "dup_key", + "original", + MemoryCategory::Knowledge, + None, + None, + ) + .await + .unwrap(); + mm.store( + "dup_key", + "updated", + MemoryCategory::Knowledge, + None, + Some(0.9), + ) + .await + .unwrap(); + + let results = mm.recall("updated", 10, None).await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].content, "updated"); + } + + #[tokio::test] + async fn test_forget() { + let (mm, _dir) = setup_memory_manager().await; + + mm.store( + "to_delete", + "will be deleted", + MemoryCategory::Knowledge, + None, + None, + ) + .await + .unwrap(); + mm.forget("to_delete").await.unwrap(); + + let results = mm.recall("deleted", 10, None).await.unwrap(); + assert!(results.is_empty()); + } + + #[tokio::test] + async fn test_category_filter() { + let (mm, _dir) = setup_memory_manager().await; + + mm.store( + "knowledge_1", + "fact content", + MemoryCategory::Knowledge, + None, + None, + ) + .await + .unwrap(); + mm.store( + "timeline_1", + "summary content", + MemoryCategory::Timeline, + None, + None, + ) + .await + .unwrap(); + + let know_results = mm + .recall("content", 10, Some(MemoryCategory::Knowledge)) + .await + .unwrap(); + assert_eq!(know_results.len(), 1); + assert_eq!(know_results[0].key, "knowledge_1"); + + let time_results = mm + .recall("content", 10, Some(MemoryCategory::Timeline)) + .await + .unwrap(); + assert_eq!(time_results.len(), 1); + assert_eq!(time_results[0].key, "timeline_1"); + } +} diff --git a/src/memory/types.rs b/src/memory/types.rs new file mode 100644 index 0000000..3987d21 --- /dev/null +++ b/src/memory/types.rs @@ -0,0 +1,90 @@ +use serde::{Deserialize, Serialize}; + +/// Memory categories for the memory system. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum MemoryCategory { + /// Long-term facts, preferences, patterns, insights (merged fact+insight). + Knowledge, + /// Conversation summaries produced by context compression. + Timeline, +} + +impl MemoryCategory { + pub fn as_str(&self) -> &str { + match self { + Self::Knowledge => "knowledge", + Self::Timeline => "timeline", + } + } + + pub fn from_str(s: &str) -> Option { + match s { + "knowledge" => Some(Self::Knowledge), + "timeline" => Some(Self::Timeline), + _ => None, + } + } +} + +/// A single memory entry stored in SQLite. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryEntry { + pub id: String, + /// Semantic identifier (e.g. "user_prefers_rust"). + pub key: String, + /// The memory content. + pub content: String, + /// Category: knowledge or timeline. + pub category: MemoryCategory, + /// Importance score 0.0–1.0 (default 0.5). + pub importance: f64, + /// Associated session ID (optional). + pub session_id: Option, + /// RFC 3339 creation timestamp. + pub created_at: String, + /// RFC 3339 last update timestamp. + pub updated_at: String, +} + +/// Result from an LLM consolidation call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConsolidationResult { + /// New or updated knowledge entries extracted from conversation. + pub facts: Vec, + /// Summary entry for timeline (formatted as "[YYYY-MM-DD HH:MM] text..."). + pub timeline: Option, + /// Keys of existing memories that should be invalidated. + pub invalidations: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConsolidationFact { + pub key: String, + pub content: String, + pub importance: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_memory_category_as_str() { + assert_eq!(MemoryCategory::Knowledge.as_str(), "knowledge"); + assert_eq!(MemoryCategory::Timeline.as_str(), "timeline"); + } + + #[test] + fn test_memory_category_from_str() { + assert_eq!( + MemoryCategory::from_str("knowledge"), + Some(MemoryCategory::Knowledge) + ); + assert_eq!( + MemoryCategory::from_str("timeline"), + Some(MemoryCategory::Timeline) + ); + assert_eq!(MemoryCategory::from_str("invalid"), None); + } +} diff --git a/src/session/session.rs b/src/session/session.rs index fa88f32..af77300 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -55,6 +55,10 @@ pub struct Session { storage: Option>, routing_info: String, + /// Timestamp (Unix ms) of the last consolidation. + /// Messages before this time have been compressed into memory. + pub last_consolidated_at: Option, + memory_manager: Option>, } impl Session { @@ -65,6 +69,7 @@ impl Session { storage: Option>, routing_info: String, title: String, + memory_manager: Option>, ) -> Result { let mut provider_box = create_provider(provider_config.clone()) .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; @@ -78,6 +83,12 @@ impl Session { ..Default::default() }; + let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config); + if let Some(ref mm) = memory_manager { + compressor = compressor.with_memory(mm.clone()); + compressor.set_session_id(Some(id.to_string())); + } + let now = chrono::Utc::now().timestamp_millis(); Ok(Self { @@ -92,9 +103,11 @@ impl Session { provider_config: provider_config.clone(), provider: provider.clone(), tools, - compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config), + compressor, storage, routing_info, + last_consolidated_at: None, + memory_manager, }) } @@ -104,6 +117,7 @@ impl Session { provider_config: LLMProviderConfig, tools: Arc, storage: StdArc, + memory_manager: Option>, ) -> Result { let session_meta = storage.get_session(&id.to_string()).await .map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?; @@ -121,6 +135,12 @@ impl Session { ..Default::default() }; + let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config); + if let Some(ref mm) = memory_manager { + compressor = compressor.with_memory(mm.clone()); + compressor.set_session_id(Some(id.to_string())); + } + // Convert MessageMeta to ChatMessage // Clear tool_call_id/tool_name — they're not valid across API sessions let chat_messages: Vec = messages.into_iter().map(|m| { @@ -152,9 +172,11 @@ impl Session { provider_config: provider_config.clone(), provider: provider.clone(), tools, - compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config), + compressor, storage: Some(storage), routing_info: session_meta.routing_info.unwrap_or_default(), + last_consolidated_at: session_meta.last_consolidated_at, + memory_manager, }) } @@ -282,6 +304,7 @@ impl Session { Some(self.routing_info.clone()) }, deleted_at: None, + last_consolidated_at: self.last_consolidated_at, }; storage.upsert_session(&meta).await?; } @@ -372,13 +395,14 @@ impl Session { Ok(self.create_agent()?.with_notify(notify_tx)) } - /// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills) - pub fn build_system_prompt(&self, skills_prompt: &str) -> String { + /// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills + memory) + pub fn build_system_prompt(&self, skills_prompt: &str, memory_context: Option<&str>) -> String { let base_prompt = build_system_prompt( &self.provider_config.workspace_dir, &self.provider_config.model_id, &self.tools, Some(&self.id.to_string()), + memory_context, ); if skills_prompt.trim().is_empty() { @@ -560,6 +584,7 @@ pub struct SessionManager { storage: Arc, bus: Arc, current_source_session: Arc>>, + memory_manager: Option>, } struct SessionManagerInner { @@ -647,12 +672,13 @@ impl SessionManager { provider_config: LLMProviderConfig, storage: Arc, bus: Arc, + memory_manager: Option>, ) -> Result { let skills_loader = SkillsLoader::new(); skills_loader.load_skills(); let skills_loader = Arc::new(skills_loader); - let tools = Arc::new(create_default_tools(skills_loader.clone())); + let tools = Arc::new(create_default_tools(skills_loader.clone(), memory_manager.clone())); Ok(Self { inner: Arc::new(Mutex::new(SessionManagerInner { @@ -667,6 +693,7 @@ impl SessionManager { storage, bus, current_source_session: Arc::new(Mutex::new(None)), + memory_manager, }) } @@ -800,7 +827,7 @@ impl SessionManager { // Build the same system prompt that would be injected to the model let skills_prompt = self.skills_loader.build_skills_prompt(); - let system_prompt = session_guard.build_system_prompt(&skills_prompt); + let system_prompt = session_guard.build_system_prompt(&skills_prompt, None); let filepath = session_guard.dump_to_file(&system_prompt) .map_err(|e| AgentError::Other(format!("Failed to save dump: {}", e)))?; @@ -879,6 +906,7 @@ impl SessionManager { message_count: 0, routing_info: if routing_info.is_empty() { None } else { Some(routing_info.clone()) }, deleted_at: None, + last_consolidated_at: None, }; self.storage.upsert_session(&meta).await .map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?; @@ -890,6 +918,7 @@ impl SessionManager { Some(self.storage.clone()), routing_info, title.clone(), + self.memory_manager.clone(), ).await?; let arc = Arc::new(Mutex::new(session)); @@ -921,6 +950,7 @@ impl SessionManager { self.provider_config.clone(), self.tools.clone(), self.storage.clone(), + self.memory_manager.clone(), ).await?; let arc = Arc::new(Mutex::new(session)); @@ -944,6 +974,7 @@ impl SessionManager { Some(self.storage.clone()), String::new(), format!("新对话"), + self.memory_manager.clone(), ).await?; let arc = Arc::new(Mutex::new(session)); @@ -1220,9 +1251,28 @@ impl SessionManager { // Build skills prompt let skills_prompt = self.skills_loader.build_skills_prompt(); + // Fetch memory context + let memory_context = if let Some(ref mm) = self.memory_manager { + match mm.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await { + Ok(entries) if !entries.is_empty() => { + Some(entries.iter() + .map(|e| format!("- {}: {}", e.key, e.content)) + .collect::>() + .join("\n")) + } + Err(e) => { + tracing::warn!(error = %e, "Failed to fetch memory context"); + None + } + _ => None, + } + } else { + None + }; + // Build combined system prompt and inject at position 0 // This ensures AgentLoop.process() sees a system message and doesn't inject its own - let system_prompt = session_guard.build_system_prompt(&skills_prompt); + let system_prompt = session_guard.build_system_prompt(&skills_prompt, memory_context.as_deref()); history.insert(0, ChatMessage::system(system_prompt)); let history = session_guard.compressor @@ -1324,7 +1374,7 @@ impl SessionManager { let mut history = session_guard.get_history().to_vec(); let skills_prompt = self.skills_loader.build_skills_prompt(); - let system_prompt = session_guard.build_system_prompt(&skills_prompt); + let system_prompt = session_guard.build_system_prompt(&skills_prompt, None); let cron_context = format!( "\n\n## 定时任务执行\n\n\ 你正在执行定时任务「{}」({})。\n\ diff --git a/src/storage/memory.rs b/src/storage/memory.rs new file mode 100644 index 0000000..0ac41b5 --- /dev/null +++ b/src/storage/memory.rs @@ -0,0 +1,183 @@ +use sqlx::Row; + +use crate::memory::{MemoryCategory, MemoryEntry}; + +use super::StorageError; + +impl super::Storage { + /// Store or update a memory entry (upsert by key). + pub async fn upsert_memory(&self, entry: &MemoryEntry) -> Result<(), StorageError> { + let category_str = entry.category.as_str(); + sqlx::query( + r#" + INSERT INTO memories (id, key, content, category, importance, session_id, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(key) DO UPDATE SET + content = excluded.content, + category = excluded.category, + importance = excluded.importance, + session_id = excluded.session_id, + updated_at = excluded.updated_at + "#, + ) + .bind(&entry.id) + .bind(&entry.key) + .bind(&entry.content) + .bind(category_str) + .bind(entry.importance) + .bind(&entry.session_id) + .bind(&entry.created_at) + .bind(&entry.updated_at) + .execute(self.pool()) + .await?; + Ok(()) + } + + /// Delete a memory entry by key. + pub async fn delete_memory(&self, key: &str) -> Result<(), StorageError> { + sqlx::query("DELETE FROM memories WHERE key = ?") + .bind(key) + .execute(self.pool()) + .await?; + Ok(()) + } + + /// Search memories by keyword using FTS5. + /// Falls back to LIKE query if FTS5 returns no results. + pub async fn search_memories( + &self, + query: &str, + category: Option<&MemoryCategory>, + limit: usize, + ) -> Result, StorageError> { + // Build FTS5 query: wrap each word in quotes and join with OR + let fts_query = query + .split_whitespace() + .map(|w| format!("\"{}\"", w.replace('"', ""))) + .collect::>() + .join(" OR "); + + let category_filter = category.map(|c| c.as_str()); + + // Try FTS5 first + let rows = sqlx::query( + r#" + SELECT m.id, m.key, m.content, m.category, m.importance, + m.session_id, m.created_at, m.updated_at + FROM memory_fts f + JOIN memories m ON f.rowid = m.rowid + WHERE memory_fts MATCH ? AND (? IS NULL OR m.category = ?) + ORDER BY rank + LIMIT ? + "#, + ) + .bind(&fts_query) + .bind(category_filter) + .bind(category_filter) + .bind(limit as i64) + .fetch_all(self.pool()) + .await?; + + let mut entries = parse_memory_rows(&rows)?; + + // Fallback to LIKE if FTS5 returned nothing + if entries.is_empty() { + let like_pattern = format!("%{}%", query.replace('%', "").replace('_', "")); + let rows = sqlx::query( + r#" + SELECT id, key, content, category, importance, + session_id, created_at, updated_at + FROM memories + WHERE (key LIKE ? OR content LIKE ?) + AND (? IS NULL OR category = ?) + ORDER BY importance DESC, updated_at DESC + LIMIT ? + "#, + ) + .bind(&like_pattern) + .bind(&like_pattern) + .bind(category_filter) + .bind(category_filter) + .bind(limit as i64) + .fetch_all(self.pool()) + .await?; + + entries = parse_memory_rows(&rows)?; + } + + Ok(entries) + } + + /// Retrieve memories within a time range. + pub async fn search_memories_by_time( + &self, + since: i64, + until: i64, + category: Option<&MemoryCategory>, + limit: usize, + ) -> Result, StorageError> { + let category_filter = category.map(|c| c.as_str()); + let since_dt = chrono::DateTime::from_timestamp_millis(since) + .unwrap_or_default() + .to_rfc3339(); + let until_dt = chrono::DateTime::from_timestamp_millis(until) + .unwrap_or_default() + .to_rfc3339(); + + let rows = sqlx::query( + r#" + SELECT id, key, content, category, importance, + session_id, created_at, updated_at + FROM memories + WHERE created_at >= ? AND created_at <= ? + AND (? IS NULL OR category = ?) + ORDER BY created_at DESC + LIMIT ? + "#, + ) + .bind(&since_dt) + .bind(&until_dt) + .bind(category_filter) + .bind(category_filter) + .bind(limit as i64) + .fetch_all(self.pool()) + .await?; + + parse_memory_rows(&rows) + } + + /// Delete old timeline entries beyond retention period. + pub async fn cleanup_old_timelines(&self, retention_days: u64) -> Result { + let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64); + let cutoff_str = cutoff.to_rfc3339(); + + let result = sqlx::query( + "DELETE FROM memories WHERE category = 'timeline' AND created_at < ?", + ) + .bind(&cutoff_str) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected()) + } +} + +fn parse_memory_rows( + rows: &[sqlx::sqlite::SqliteRow], +) -> Result, StorageError> { + rows.iter() + .map(|row| { + Ok(MemoryEntry { + id: row.try_get("id")?, + key: row.try_get("key")?, + content: row.try_get("content")?, + category: MemoryCategory::from_str(&row.try_get::("category")?) + .unwrap_or(MemoryCategory::Knowledge), + importance: row.try_get::("importance")?, + session_id: row.try_get::, _>("session_id")?, + created_at: row.try_get("created_at")?, + updated_at: row.try_get("updated_at")?, + }) + }) + .collect() +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs index e581bf8..69f98f2 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,4 +1,5 @@ pub mod error; +pub mod memory; pub mod message; pub mod scheduler; pub mod session; @@ -40,6 +41,7 @@ impl Storage { message_count INTEGER DEFAULT 0, routing_info TEXT, deleted_at INTEGER, + last_consolidated_at INTEGER, UNIQUE(channel, chat_id, dialog_id) ) "#, @@ -94,6 +96,47 @@ impl Storage { .await .ok(); + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + content TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'knowledge', + importance REAL NOT NULL DEFAULT 0.5, + session_id TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + "#, + ) + .execute(&self.pool) + .await?; + + // FTS5 virtual table for full-text search on memories + sqlx::query( + r#" + CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5( + key, + content, + content=memories, + content_rowid=rowid + ) + "#, + ) + .execute(&self.pool) + .await?; + + // Migration: add last_consolidated_at column if not exists + sqlx::query( + r#" + ALTER TABLE sessions ADD COLUMN last_consolidated_at INTEGER + "#, + ) + .execute(&self.pool) + .await + .ok(); + sqlx::query( r#" CREATE TABLE IF NOT EXISTS llm_calls ( @@ -229,14 +272,15 @@ impl Storage { pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> { sqlx::query( r#" - INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET title = excluded.title, last_active_at = excluded.last_active_at, message_count = excluded.message_count, routing_info = excluded.routing_info, - deleted_at = excluded.deleted_at + deleted_at = excluded.deleted_at, + last_consolidated_at = excluded.last_consolidated_at "#, ) .bind(&meta.id) @@ -249,6 +293,7 @@ impl Storage { .bind(meta.message_count) .bind(&meta.routing_info) .bind(meta.deleted_at) + .bind(meta.last_consolidated_at) .execute(self.pool()) .await?; @@ -258,7 +303,7 @@ impl Storage { pub async fn get_session(&self, id: &str) -> Result { let row = sqlx::query( r#" - SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at + SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at FROM sessions WHERE id = ? AND deleted_at IS NULL "#, ) @@ -278,6 +323,7 @@ impl Storage { message_count: row.get("message_count"), routing_info: row.get("routing_info"), deleted_at: row.get("deleted_at"), + last_consolidated_at: row.get("last_consolidated_at"), }) } @@ -289,7 +335,7 @@ impl Storage { ) -> Result, StorageError> { let rows = sqlx::query( r#" - SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at + SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at FROM sessions WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL ORDER BY last_active_at DESC @@ -315,6 +361,7 @@ impl Storage { message_count: row.get("message_count"), routing_info: row.get("routing_info"), deleted_at: row.get("deleted_at"), + last_consolidated_at: row.get("last_consolidated_at"), }) .collect()) } @@ -362,7 +409,7 @@ impl Storage { let cutoff = chrono::Utc::now().timestamp_millis() - ttl_millis; let row = sqlx::query( r#" - SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at + SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at FROM sessions WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND last_active_at > ? ORDER BY last_active_at DESC @@ -387,6 +434,7 @@ impl Storage { message_count: row.get("message_count"), routing_info: row.get("routing_info"), deleted_at: row.get("deleted_at"), + last_consolidated_at: row.get("last_consolidated_at"), })), None => Ok(None), } @@ -471,7 +519,7 @@ impl Storage { ) -> Result, StorageError> { let rows = sqlx::query( r#" - SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at + SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at FROM sessions WHERE deleted_at IS NULL ORDER BY last_active_at DESC @@ -495,6 +543,7 @@ impl Storage { message_count: row.get("message_count"), routing_info: row.get("routing_info"), deleted_at: row.get("deleted_at"), + last_consolidated_at: row.get("last_consolidated_at"), }) .collect()) } @@ -599,6 +648,7 @@ mod tests { message_count: 0, routing_info: Some(r#"{"type":"cli"}"#.to_string()), deleted_at: None, + last_consolidated_at: None, }; storage.upsert_session(&meta).await.unwrap(); @@ -633,6 +683,7 @@ mod tests { message_count: i, routing_info: None, deleted_at: None, + last_consolidated_at: None, }; storage.upsert_session(&meta).await.unwrap(); } @@ -658,6 +709,7 @@ mod tests { message_count: 0, routing_info: None, deleted_at: None, + last_consolidated_at: None, }; storage.upsert_session(&meta).await.unwrap(); @@ -683,6 +735,7 @@ mod tests { message_count: 0, routing_info: None, deleted_at: None, + last_consolidated_at: None, }; storage.upsert_session(&session_meta).await.unwrap(); @@ -723,6 +776,7 @@ mod tests { message_count: 0, routing_info: None, deleted_at: None, + last_consolidated_at: None, }; storage.upsert_session(&meta).await.unwrap(); diff --git a/src/storage/session.rs b/src/storage/session.rs index bb3eb15..72ae6d2 100644 --- a/src/storage/session.rs +++ b/src/storage/session.rs @@ -12,4 +12,5 @@ pub struct SessionMeta { pub message_count: i64, pub routing_info: Option, pub deleted_at: Option, + pub last_consolidated_at: Option, } diff --git a/src/tools/chat_manager.rs b/src/tools/chat_manager.rs index 4a5bf28..2c18604 100644 --- a/src/tools/chat_manager.rs +++ b/src/tools/chat_manager.rs @@ -263,6 +263,7 @@ mod tests { message_count: i * 5, routing_info: None, deleted_at: None, + last_consolidated_at: None, }; storage.upsert_session(&meta).await.unwrap(); } @@ -296,6 +297,7 @@ mod tests { message_count: 3, routing_info: None, deleted_at: None, + last_consolidated_at: None, }; storage.upsert_session(&meta).await.unwrap(); diff --git a/src/tools/memory_forget.rs b/src/tools/memory_forget.rs new file mode 100644 index 0000000..0233587 --- /dev/null +++ b/src/tools/memory_forget.rs @@ -0,0 +1,60 @@ +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +use crate::memory::MemoryManager; + +pub struct MemoryForgetTool { + memory: Arc, +} + +impl MemoryForgetTool { + pub fn new(memory: Arc) -> Self { + Self { memory } + } +} + +#[async_trait] +impl Tool for MemoryForgetTool { + fn name(&self) -> &str { + "memory_forget" + } + + fn description(&self) -> &str { + "Delete a memory entry by its key. Use this when information is outdated, \ + incorrect, or the user asks to forget something." + } + + fn read_only(&self) -> bool { + false + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "The key of the memory entry to delete." + } + }, + "required": ["key"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let key = args + .get("key") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?; + + self.memory.forget(key).await?; + + Ok(ToolResult { + success: true, + output: format!("Memory deleted: {}", key), + error: None, + }) + } +} diff --git a/src/tools/memory_recall.rs b/src/tools/memory_recall.rs new file mode 100644 index 0000000..771dee1 --- /dev/null +++ b/src/tools/memory_recall.rs @@ -0,0 +1,118 @@ +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +use crate::memory::{MemoryCategory, MemoryManager}; + +pub struct MemoryRecallTool { + memory: Arc, +} + +impl MemoryRecallTool { + pub fn new(memory: Arc) -> Self { + Self { memory } + } +} + +#[async_trait] +impl Tool for MemoryRecallTool { + fn name(&self) -> &str { + "memory_recall" + } + + fn description(&self) -> &str { + "Search and retrieve entries from long-term memory. \ + Use this to recall previously stored facts, preferences, or conversation history. \ + Supports keyword search and optional time-range filtering." + } + + fn read_only(&self) -> bool { + true + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query — keywords to match against memory keys and content." + }, + "category": { + "type": "string", + "enum": ["knowledge", "timeline"], + "description": "Filter by memory category. Omit to search all categories." + }, + "since": { + "type": "integer", + "description": "Start of time range (Unix milliseconds)." + }, + "until": { + "type": "integer", + "description": "End of time range (Unix milliseconds)." + }, + "limit": { + "type": "integer", + "description": "Max results to return (default 10)." + } + }, + "required": ["query"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let query = args + .get("query") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?; + + let category = args + .get("category") + .and_then(|v| v.as_str()) + .and_then(MemoryCategory::from_str); + + let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize; + + let entries = if args.get("since").is_some() || args.get("until").is_some() { + let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0); + let until = args + .get("until") + .and_then(|v| v.as_i64()) + .unwrap_or(chrono::Utc::now().timestamp_millis()); + self.memory + .recall_by_time(since, until, limit, category) + .await? + } else { + self.memory.recall(query, limit, category).await? + }; + + if entries.is_empty() { + return Ok(ToolResult { + success: true, + output: "No matching memories found.".to_string(), + error: None, + }); + } + + let formatted = entries + .iter() + .map(|e| { + format!( + "- {} [{}] [importance: {:.1}]: {}", + e.key, + e.category.as_str(), + e.importance, + e.content + ) + }) + .collect::>() + .join("\n"); + + Ok(ToolResult { + success: true, + output: format!("Found {} memories:\n{}", entries.len(), formatted), + error: None, + }) + } +} diff --git a/src/tools/memory_store.rs b/src/tools/memory_store.rs new file mode 100644 index 0000000..22aa5ee --- /dev/null +++ b/src/tools/memory_store.rs @@ -0,0 +1,90 @@ +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +use crate::memory::{MemoryCategory, MemoryManager}; + +pub struct MemoryStoreTool { + memory: Arc, +} + +impl MemoryStoreTool { + pub fn new(memory: Arc) -> Self { + Self { memory } + } +} + +#[async_trait] +impl Tool for MemoryStoreTool { + fn name(&self) -> &str { + "memory_store" + } + + fn description(&self) -> &str { + "Store a fact, preference, or insight into long-term memory. \ + Use this when the user shares important information you should remember. \ + Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \ + and the full content to remember." + } + + fn read_only(&self) -> bool { + false + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key." + }, + "content": { + "type": "string", + "description": "The full content of the memory entry." + }, + "category": { + "type": "string", + "enum": ["knowledge", "timeline"], + "description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries." + }, + "importance": { + "type": "number", + "description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info." + } + }, + "required": ["key", "content"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let key = args + .get("key") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?; + + let content = args + .get("content") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?; + + let category = args + .get("category") + .and_then(|v| v.as_str()) + .and_then(MemoryCategory::from_str) + .unwrap_or(MemoryCategory::Knowledge); + + let importance = args.get("importance").and_then(|v| v.as_f64()); + + self.memory + .store(key, content, category, None, importance) + .await?; + + Ok(ToolResult { + success: true, + output: format!("Memory stored: {}", key), + error: None, + }) + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index b08adf9..bd1b66c 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -7,6 +7,9 @@ pub mod file_read; pub mod file_write; pub mod get_skill; pub mod http_request; +pub mod memory_forget; +pub mod memory_recall; +pub mod memory_store; pub mod registry; pub mod schema; pub mod send_message; @@ -21,6 +24,9 @@ pub use file_read::FileReadTool; pub use file_write::FileWriteTool; pub use get_skill::GetSkillTool; pub use http_request::HttpRequestTool; +pub use memory_forget::MemoryForgetTool; +pub use memory_recall::MemoryRecallTool; +pub use memory_store::MemoryStoreTool; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr}; pub use send_message::SendMessageTool; @@ -28,12 +34,16 @@ pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use web_fetch::WebFetchTool; use std::sync::Arc; +use crate::memory::MemoryManager; use crate::skills::SkillsLoader; /// Create the base tool registry (without send_message). /// `send_message` tool is registered later via `SessionManager::register_outbound_tool()` /// once the available channel names are known. -pub fn create_default_tools(skills_loader: Arc) -> ToolRegistry { +pub fn create_default_tools( + skills_loader: Arc, + memory: Option>, +) -> ToolRegistry { let registry = ToolRegistry::new(); registry.register(CalculatorTool::new()); registry.register(FileReadTool::new()); @@ -48,5 +58,13 @@ pub fn create_default_tools(skills_loader: Arc) -> ToolRegistry { )); registry.register(WebFetchTool::new(50_000, 30)); registry.register(GetSkillTool::new(skills_loader)); + + // Register memory tools if memory system is available + if let Some(mm) = memory { + registry.register(MemoryStoreTool::new(mm.clone())); + registry.register(MemoryRecallTool::new(mm.clone())); + registry.register(MemoryForgetTool::new(mm.clone())); + } + registry }