diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 9699050..042b843 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -52,18 +52,21 @@ impl MemoryManager { } /// Search memories by keyword query. Returns entries sorted by relevance. + /// When `session_id` is provided, results are filtered to that session. pub async fn recall( &self, query: &str, limit: usize, category: Option, + session_id: Option<&str>, ) -> Result, crate::storage::StorageError> { self.storage - .search_memories(query, category.as_ref(), limit) + .search_memories(query, category.as_ref(), session_id, limit) .await } /// Search memories by time range (Unix milliseconds). + /// When `session_id` is provided, results are filtered to that session. pub async fn recall_by_time( &self, since: i64, @@ -71,9 +74,10 @@ impl MemoryManager { query: Option<&str>, limit: usize, category: Option, + session_id: Option<&str>, ) -> Result, crate::storage::StorageError> { self.storage - .search_memories_by_time(since, until, query, category.as_ref(), limit) + .search_memories_by_time(since, until, query, category.as_ref(), session_id, limit) .await } @@ -84,7 +88,7 @@ impl MemoryManager { /// Check if the memory system has any entries (for testing/health check). pub async fn is_empty(&self) -> Result { - self.recall("*", 1, None).await.map(|r| r.is_empty()) + self.recall("*", 1, None, None).await.map(|r| r.is_empty()) } } @@ -116,7 +120,7 @@ mod tests { .await .unwrap(); - let results = mm.recall("test memory", 10, None).await.unwrap(); + let results = mm.recall("test memory", 10, None, None).await.unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].key, "test_key"); assert_eq!(results[0].content, "This is a test memory"); @@ -146,7 +150,7 @@ mod tests { .await .unwrap(); - let results = mm.recall("updated", 10, None).await.unwrap(); + let results = mm.recall("updated", 10, None, None).await.unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].content, "updated"); } @@ -166,7 +170,7 @@ mod tests { .unwrap(); mm.forget("to_delete").await.unwrap(); - let results = mm.recall("deleted", 10, None).await.unwrap(); + let results = mm.recall("deleted", 10, None, None).await.unwrap(); assert!(results.is_empty()); } @@ -194,17 +198,60 @@ mod tests { .unwrap(); let know_results = mm - .recall("content", 10, Some(MemoryCategory::Knowledge)) + .recall("content", 10, Some(MemoryCategory::Knowledge), None) .await .unwrap(); assert_eq!(know_results.len(), 1); assert_eq!(know_results[0].key, "knowledge_1"); let time_results = mm - .recall("content", 10, Some(MemoryCategory::Timeline)) + .recall("content", 10, Some(MemoryCategory::Timeline), None) .await .unwrap(); assert_eq!(time_results.len(), 1); assert_eq!(time_results[0].key, "timeline_1"); } + + #[tokio::test] + async fn test_session_id_filter() { + let (mm, _dir) = setup_memory_manager().await; + + // Store a timeline entry for session A + mm.store( + "tl_a", + "summary from session A", + MemoryCategory::Timeline, + Some("chan:chat:dialog_a"), + Some(0.5), + ) + .await + .unwrap(); + + // Store a timeline entry for session B + mm.store( + "tl_b", + "summary from session B", + MemoryCategory::Timeline, + Some("chan:chat:dialog_b"), + Some(0.5), + ) + .await + .unwrap(); + + // Recall without session_id — should get both + let all = mm + .recall("summary", 10, Some(MemoryCategory::Timeline), None) + .await + .unwrap(); + assert_eq!(all.len(), 2); + + // Recall scoped to session A — should get only tl_a + let scoped = mm + .recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a")) + .await + .unwrap(); + assert_eq!(scoped.len(), 1); + assert_eq!(scoped[0].key, "tl_a"); + assert_eq!(scoped[0].session_id.as_deref(), Some("chan:chat:dialog_a")); + } } diff --git a/src/session/session.rs b/src/session/session.rs index 3c8e0db..e4ff787 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -1322,7 +1322,7 @@ impl SessionManager { let skills_prompt = self.skills_loader.build_skills_prompt(); // Fetch memory context - let memory_context = match self.memory_manager.recall(content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await { + let memory_context = match self.memory_manager.recall(content, 5, Some(crate::memory::MemoryCategory::Knowledge), None).await { Ok(entries) if !entries.is_empty() => { Some(entries.iter() .map(|e| format!("- {}: {}", e.key, e.content)) diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 046ec30..49ea476 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -56,6 +56,7 @@ impl super::Storage { &self, query: &str, category: Option<&MemoryCategory>, + session_id: Option<&str>, limit: usize, ) -> Result, StorageError> { // Build FTS5 query: segment with jieba, wrap each term in quotes, join with OR @@ -76,7 +77,7 @@ impl super::Storage { m.session_id, m.created_at, m.updated_at FROM memory_fts f JOIN memories m ON f.rowid = m.rowid - WHERE memory_fts MATCH ? AND (? IS NULL OR m.category = ?) + WHERE memory_fts MATCH ? AND (? IS NULL OR m.category = ?) AND (? IS NULL OR m.session_id = ?) ORDER BY rank LIMIT ? "#, @@ -84,6 +85,8 @@ impl super::Storage { .bind(&fts_query) .bind(category_filter) .bind(category_filter) + .bind(session_id) + .bind(session_id) .bind(limit as i64) .fetch_all(self.pool()) .await?; @@ -113,6 +116,7 @@ impl super::Storage { FROM memories WHERE ({}) AND (? IS NULL OR category = ?) + AND (? IS NULL OR session_id = ?) ORDER BY importance DESC, updated_at DESC LIMIT ? "#, @@ -127,6 +131,8 @@ impl super::Storage { query_builder = query_builder .bind(category_filter) .bind(category_filter) + .bind(session_id) + .bind(session_id) .bind(limit as i64); let rows = query_builder.fetch_all(self.pool()).await?; @@ -144,6 +150,7 @@ impl super::Storage { until: i64, query: Option<&str>, category: Option<&MemoryCategory>, + session_id: Option<&str>, limit: usize, ) -> Result, StorageError> { let category_filter = category.map(|c| c.as_str()); @@ -180,6 +187,7 @@ impl super::Storage { WHERE ({}) AND created_at >= ? AND created_at <= ? AND (? IS NULL OR category = ?) + AND (? IS NULL OR session_id = ?) ORDER BY created_at DESC LIMIT ? "#, @@ -196,6 +204,8 @@ impl super::Storage { .bind(&until_dt) .bind(category_filter) .bind(category_filter) + .bind(session_id) + .bind(session_id) .bind(limit as i64); query_builder.fetch_all(self.pool()).await? @@ -207,6 +217,7 @@ impl super::Storage { FROM memories WHERE created_at >= ? AND created_at <= ? AND (? IS NULL OR category = ?) + AND (? IS NULL OR session_id = ?) ORDER BY created_at DESC LIMIT ? "#, @@ -215,6 +226,8 @@ impl super::Storage { .bind(&until_dt) .bind(category_filter) .bind(category_filter) + .bind(session_id) + .bind(session_id) .bind(limit as i64) .fetch_all(self.pool()) .await? diff --git a/src/tools/memory.rs b/src/tools/memory.rs index c34ad22..0966c77 100644 --- a/src/tools/memory.rs +++ b/src/tools/memory.rs @@ -24,7 +24,7 @@ impl Tool for MemoryStoreTool { } fn description(&self) -> &str { - "Store a fact, preference, or insight into long-term memory. \ + "Store a fact, preference, or insight into long-term knowledge memory. \ Use this when the user shares important information you should remember. \ Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \ and the full content to remember." @@ -46,11 +46,6 @@ impl Tool for MemoryStoreTool { "type": "string", "description": "The full content of the memory entry." }, - "category": { - "type": "string", - "enum": ["knowledge", "timeline"], - "description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries." - }, "importance": { "type": "number", "description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info." @@ -71,16 +66,10 @@ impl Tool for MemoryStoreTool { .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?; - let category = args - .get("category") - .and_then(|v| v.as_str()) - .and_then(MemoryCategory::from_str) - .unwrap_or(MemoryCategory::Knowledge); - let importance = args.get("importance").and_then(|v| v.as_f64()); self.memory - .store(key, content, category, None, importance) + .store(key, content, MemoryCategory::Knowledge, None, importance) .await?; Ok(ToolResult { @@ -110,8 +99,8 @@ impl Tool for MemoryRecallTool { } fn description(&self) -> &str { - "Search and retrieve entries from long-term memory using keyword matching. \ - Use this to recall previously stored facts, preferences, or conversation history. \ + "Search and retrieve entries from long-term knowledge memory using keyword matching. \ + Use this to recall previously stored facts, preferences, or insights. \ IMPORTANT: query must be a space-separated list of RELEVANT KEYWORDS (not a question or sentence). \ Use multiple synonymous or related terms to increase recall. \ Example: instead of 'what is the user location', use 'user location address city residence'. \ @@ -130,11 +119,6 @@ impl Tool for MemoryRecallTool { "type": "string", "description": "Space-separated KEYWORDS for memory search (NOT a natural language question). Use multiple related terms for better recall, e.g. 'address city location residence'." }, - "category": { - "type": "string", - "enum": ["knowledge", "timeline"], - "description": "Filter by memory category. Omit to search all categories." - }, "since": { "type": "integer", "description": "Start of time range (Unix milliseconds)." @@ -158,11 +142,6 @@ impl Tool for MemoryRecallTool { .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?; - let category = args - .get("category") - .and_then(|v| v.as_str()) - .and_then(MemoryCategory::from_str); - let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize; let entries = if args.get("since").is_some() || args.get("until").is_some() { @@ -172,10 +151,10 @@ impl Tool for MemoryRecallTool { .and_then(|v| v.as_i64()) .unwrap_or(chrono::Utc::now().timestamp_millis()); self.memory - .recall_by_time(since, until, Some(query), limit, category) + .recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None) .await? } else { - self.memory.recall(query, limit, category).await? + self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await? }; if entries.is_empty() { @@ -189,10 +168,12 @@ impl Tool for MemoryRecallTool { let formatted = entries .iter() .map(|e| { + let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default(); format!( - "- {} [{}] [importance: {:.1}]: {}", + "- {} [{}]{} [importance: {:.1}]: {}", e.key, e.category.as_str(), + session, e.importance, e.content ) @@ -208,6 +189,119 @@ impl Tool for MemoryRecallTool { } } +// ── TimelineRecallTool ──────────────────────────────────────────────── + +pub struct TimelineRecallTool { + memory: Arc, +} + +impl TimelineRecallTool { + pub fn new(memory: Arc) -> Self { + Self { memory } + } +} + +#[async_trait] +impl Tool for TimelineRecallTool { + fn name(&self) -> &str { + "timeline_recall" + } + + fn description(&self) -> &str { + "Search and retrieve conversation summaries from timeline memory. \ + Use this to recall what was discussed in past sessions or earlier in the current session. \ + Optionally filter by session_id to scope to a specific conversation. \ + IMPORTANT: query must be a space-separated list of RELEVANT KEYWORDS (not a question or sentence)." + } + + fn read_only(&self) -> bool { + true + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Space-separated KEYWORDS for timeline search (NOT a natural language question). Use multiple related terms for better recall." + }, + "session_id": { + "type": "string", + "description": "Filter to a specific session (format: channel:chat_id:dialog_id). Omit to search across all sessions." + }, + "since": { + "type": "integer", + "description": "Start of time range (Unix milliseconds)." + }, + "until": { + "type": "integer", + "description": "End of time range (Unix milliseconds)." + }, + "limit": { + "type": "integer", + "description": "Max results to return (default 10)." + } + }, + "required": ["query"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let query = args + .get("query") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?; + + let session_id = args.get("session_id").and_then(|v| v.as_str()); + + let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize; + + let entries = if args.get("since").is_some() || args.get("until").is_some() { + let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0); + let until = args + .get("until") + .and_then(|v| v.as_i64()) + .unwrap_or(chrono::Utc::now().timestamp_millis()); + self.memory + .recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id) + .await? + } else { + self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await? + }; + + if entries.is_empty() { + return Ok(ToolResult { + success: true, + output: "No matching timeline entries found.".to_string(), + error: None, + }); + } + + let formatted = entries + .iter() + .map(|e| { + let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default(); + format!( + "- {} [{}]{} [importance: {:.1}]: {}", + e.key, + e.category.as_str(), + session, + e.importance, + e.content + ) + }) + .collect::>() + .join("\n"); + + Ok(ToolResult { + success: true, + output: format!("Found {} timeline entries:\n{}", entries.len(), formatted), + error: None, + }) + } +} + // ── MemoryForgetTool ───────────────────────────────────────────────── pub struct MemoryForgetTool { diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 59dd62c..6f40f83 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -22,7 +22,7 @@ pub use file_read::FileReadTool; pub use file_write::FileWriteTool; pub use get_skill::GetSkillTool; pub use http_request::HttpRequestTool; -pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool}; +pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool}; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr}; pub use send_message::SendMessageTool; @@ -57,6 +57,7 @@ pub fn create_default_tools( registry.register(MemoryStoreTool::new(memory.clone())); registry.register(MemoryRecallTool::new(memory.clone())); + registry.register(TimelineRecallTool::new(memory.clone())); registry.register(MemoryForgetTool::new(memory.clone())); registry