From 2617558a27bed558ae973014b3854de4a1745f8b Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Fri, 8 May 2026 10:28:34 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AE=B0=E5=BF=86=E5=B7=A5=E5=85=B7=E5=90=88?= =?UTF-8?q?=E5=B9=B6=E6=96=87=E4=BB=B6=EF=BC=8C=E6=94=B9=E6=88=90=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E7=B3=BB=E7=BB=9F=E6=80=BB=E6=98=AF=E5=BC=80=E5=90=AF?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.example.json | 9 ++ src/agent/context_compressor.rs | 78 +++++----- src/config/mod.rs | 35 ++++- src/gateway/mod.rs | 25 ++- src/memory/mod.rs | 16 +- src/session/session.rs | 50 +++--- src/tools/memory.rs | 262 ++++++++++++++++++++++++++++++++ src/tools/memory_forget.rs | 60 -------- src/tools/memory_recall.rs | 118 -------------- src/tools/memory_store.rs | 90 ----------- src/tools/mod.rs | 19 +-- 11 files changed, 401 insertions(+), 361 deletions(-) create mode 100644 src/tools/memory.rs delete mode 100644 src/tools/memory_forget.rs delete mode 100644 src/tools/memory_recall.rs delete mode 100644 src/tools/memory_store.rs diff --git a/config.example.json b/config.example.json index 43ef441..25c62be 100644 --- a/config.example.json +++ b/config.example.json @@ -63,5 +63,14 @@ "reaction_emoji": "Typing" } }, + "memory": { + "consolidation_provider": null, + "consolidation_model": null, + "recall_limit": 5, + "consolidation_turn_threshold": 10, + "idle_consolidation_minutes": 10, + "timeline_retention_days": 90, + "max_failures_before_degrade": 3 + }, "workspace_dir": "~/.picobot/workspace" } diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index 4417720..66a6703 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -50,22 +50,26 @@ pub struct ContextCompressor { threshold_ratio: f64, /// Shared LLM provider for summarization provider: Arc, - /// Memory manager handle (optional). When set, compressed - /// context summaries are persisted as timeline memory entries. - memory: Option>, + /// Memory manager handle. Compressed context summaries are persisted + /// as timeline memory entries. + memory: Arc, /// Current session ID for timeline memory writes. session_id: Option, } impl ContextCompressor { - /// Create a new compressor with the given provider and context window size. - pub fn new(provider: Arc, context_window: usize) -> Self { + /// Create a new compressor with the given provider, context window size, and memory manager. + pub fn new( + provider: Arc, + context_window: usize, + memory: Arc, + ) -> Self { Self { config: ContextCompressionConfig::default(), context_window, threshold_ratio: 0.5, provider, - memory: None, + memory, session_id: None, } } @@ -75,23 +79,18 @@ impl ContextCompressor { provider: Arc, context_window: usize, config: ContextCompressionConfig, + memory: Arc, ) -> Self { Self { config, context_window, threshold_ratio: 0.5, provider, - memory: None, + memory, session_id: None, } } - /// Attach a memory manager to persist compressed summaries. - pub fn with_memory(mut self, memory: Arc) -> Self { - self.memory = Some(memory); - self - } - /// Set the current session ID for timeline writes. pub fn set_session_id(&mut self, id: Option) { self.session_id = id; @@ -240,25 +239,23 @@ impl ContextCompressor { let summary = self.summarize_segment(between).await?; // Persist compressed summary as timeline memory entry - if let Some(ref mm) = self.memory { - let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); - let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", - ts, between.len(), summary); - let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); - let mm_clone = mm.clone(); - let sid = self.session_id.clone(); - tokio::spawn(async move { - if let Err(e) = mm_clone.store( - &key, - &timeline_content, - crate::memory::MemoryCategory::Timeline, - sid.as_deref(), - Some(0.3), - ).await { - tracing::warn!(error = %e, "Failed to store compressed context as timeline"); - } - }); - } + let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); + let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", + ts, between.len(), summary); + let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); + let mm = self.memory.clone(); + let sid = self.session_id.clone(); + tokio::spawn(async move { + if let Err(e) = mm.store( + &key, + &timeline_content, + crate::memory::MemoryCategory::Timeline, + sid.as_deref(), + Some(0.3), + ).await { + tracing::warn!(error = %e, "Failed to store compressed context as timeline"); + } + }); // Add summary as a special user message new_messages.push(ChatMessage::user(format!( @@ -370,6 +367,7 @@ mod tests { use super::*; use crate::providers::ChatCompletionResponse; use async_trait::async_trait; + use std::sync::OnceLock; /// Mock provider for testing - panics if actually used for LLM calls struct MockProvider; @@ -400,6 +398,18 @@ mod tests { Arc::new(MockProvider) } + fn test_memory_manager() -> Arc { + static MM: OnceLock> = OnceLock::new(); + MM.get_or_init(|| { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id())); + let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap()); + Arc::new(MemoryManager::new(storage, "test".into(), "test".into())) + }) + }).clone() + } + #[test] fn test_estimate_tokens() { let messages = vec![ @@ -422,7 +432,7 @@ mod tests { tool_result_trim_chars: 50, ..Default::default() }; - let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config); + let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager()); let mut messages = vec![ ChatMessage::user("Hello"), @@ -436,7 +446,7 @@ mod tests { #[test] fn test_threshold() { - let compressor = ContextCompressor::new(mock_provider(), 128_000); + let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager()); assert_eq!(compressor.threshold(), 64_000); } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 88dfc23..c9bb2ab 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -220,15 +220,14 @@ impl Default for ClientConfig { } } -#[derive(Debug, Clone, Default, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct MemoryConfig { - /// Master switch for the memory system. - #[serde(default)] - pub enabled: bool, /// Provider name for consolidation LLM calls (key in `providers`). + /// If not set, falls back to the main agent's provider. #[serde(default)] pub consolidation_provider: Option, /// Model name for consolidation LLM calls (key in `models`). + /// If not set, falls back to the main agent's model. #[serde(default)] pub consolidation_model: Option, /// Max knowledge entries injected into system prompt per turn. @@ -248,8 +247,34 @@ pub struct MemoryConfig { pub max_failures_before_degrade: usize, } +impl Default for MemoryConfig { + fn default() -> Self { + Self { + consolidation_provider: None, + consolidation_model: None, + recall_limit: 5, + consolidation_turn_threshold: 10, + idle_consolidation_minutes: 10, + timeline_retention_days: 90, + max_failures_before_degrade: 3, + } + } +} + +impl MemoryConfig { + /// Resolve consolidation provider name, falling back to the main agent's provider. + pub fn resolve_consolidation_provider(&self, default: &str) -> String { + self.consolidation_provider.clone().unwrap_or_else(|| default.to_string()) + } + + /// Resolve consolidation model name, falling back to the main agent's model. + pub fn resolve_consolidation_model(&self, default: &str) -> String { + self.consolidation_model.clone().unwrap_or_else(|| default.to_string()) + } +} + fn default_recall_limit() -> usize { 5 } -fn default_consolidation_turn_threshold() -> usize { 3 } +fn default_consolidation_turn_threshold() -> usize { 10 } fn default_idle_consolidation_minutes() -> u64 { 10 } fn default_timeline_retention_days() -> u64 { 90 } fn default_max_failures_before_degrade() -> usize { 3 } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 66ef39e..b6104f8 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -56,14 +56,23 @@ impl GatewayState { ); tracing::info!("Session storage: {}", db_path.display()); - // Initialize MemoryManager if memory system is enabled - let memory_manager = if config.memory.enabled { - let mm = Arc::new(MemoryManager::new(storage.clone())); - tracing::info!("Memory system enabled"); - Some(mm) - } else { - None - }; + // Resolve consolidation provider/model with fallback to main agent config + let consolidation_provider = config + .memory + .resolve_consolidation_provider(&provider_config.name); + let consolidation_model = config + .memory + .resolve_consolidation_model(&provider_config.model_id); + let memory_manager = Arc::new(MemoryManager::new( + storage.clone(), + consolidation_provider, + consolidation_model, + )); + tracing::info!( + consolidation_provider = %memory_manager.consolidation_provider, + consolidation_model = %memory_manager.consolidation_model, + "Memory system initialized" + ); // Create MessageBus first (shared by SessionManager and ChannelManager) let bus = MessageBus::new(100); diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 651e6f9..e68883e 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -11,11 +11,21 @@ pub use types::{ConsolidationFact, ConsolidationResult, MemoryCategory, MemoryEn #[derive(Clone)] pub struct MemoryManager { storage: Arc, + pub consolidation_provider: String, + pub consolidation_model: String, } impl MemoryManager { - pub fn new(storage: Arc) -> Self { - Self { storage } + pub fn new( + storage: Arc, + consolidation_provider: String, + consolidation_model: String, + ) -> Self { + Self { + storage, + consolidation_provider, + consolidation_model, + } } /// Store or update a memory entry. Generates timestamp and UUID. @@ -87,7 +97,7 @@ mod tests { let dir = tempdir().unwrap(); let db_path = dir.path().join("test.db"); let storage = Arc::new(Storage::new(&db_path).await.unwrap()); - let mm = Arc::new(MemoryManager::new(storage)); + let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into())); (mm, dir) } diff --git a/src/session/session.rs b/src/session/session.rs index af77300..7ca56e6 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -58,7 +58,7 @@ pub struct Session { /// Timestamp (Unix ms) of the last consolidation. /// Messages before this time have been compressed into memory. pub last_consolidated_at: Option, - memory_manager: Option>, + memory_manager: Arc, } impl Session { @@ -69,7 +69,7 @@ impl Session { storage: Option>, routing_info: String, title: String, - memory_manager: Option>, + memory_manager: Arc, ) -> Result { let mut provider_box = create_provider(provider_config.clone()) .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; @@ -83,11 +83,8 @@ impl Session { ..Default::default() }; - let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config); - if let Some(ref mm) = memory_manager { - compressor = compressor.with_memory(mm.clone()); - compressor.set_session_id(Some(id.to_string())); - } + let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone()); + compressor.set_session_id(Some(id.to_string())); let now = chrono::Utc::now().timestamp_millis(); @@ -117,7 +114,7 @@ impl Session { provider_config: LLMProviderConfig, tools: Arc, storage: StdArc, - memory_manager: Option>, + memory_manager: Arc, ) -> Result { let session_meta = storage.get_session(&id.to_string()).await .map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?; @@ -135,11 +132,8 @@ impl Session { ..Default::default() }; - let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config); - if let Some(ref mm) = memory_manager { - compressor = compressor.with_memory(mm.clone()); - compressor.set_session_id(Some(id.to_string())); - } + let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone()); + compressor.set_session_id(Some(id.to_string())); // Convert MessageMeta to ChatMessage // Clear tool_call_id/tool_name — they're not valid across API sessions @@ -584,7 +578,7 @@ pub struct SessionManager { storage: Arc, bus: Arc, current_source_session: Arc>>, - memory_manager: Option>, + memory_manager: Arc, } struct SessionManagerInner { @@ -672,7 +666,7 @@ impl SessionManager { provider_config: LLMProviderConfig, storage: Arc, bus: Arc, - memory_manager: Option>, + memory_manager: Arc, ) -> Result { let skills_loader = SkillsLoader::new(); skills_loader.load_skills(); @@ -1252,22 +1246,18 @@ impl SessionManager { let skills_prompt = self.skills_loader.build_skills_prompt(); // Fetch memory context - let memory_context = if let Some(ref mm) = self.memory_manager { - match mm.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await { - Ok(entries) if !entries.is_empty() => { - Some(entries.iter() - .map(|e| format!("- {}: {}", e.key, e.content)) - .collect::>() - .join("\n")) - } - Err(e) => { - tracing::warn!(error = %e, "Failed to fetch memory context"); - None - } - _ => None, + let memory_context = match self.memory_manager.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await { + Ok(entries) if !entries.is_empty() => { + Some(entries.iter() + .map(|e| format!("- {}: {}", e.key, e.content)) + .collect::>() + .join("\n")) } - } else { - None + Err(e) => { + tracing::warn!(error = %e, "Failed to fetch memory context"); + None + } + _ => None, }; // Build combined system prompt and inject at position 0 diff --git a/src/tools/memory.rs b/src/tools/memory.rs new file mode 100644 index 0000000..795c975 --- /dev/null +++ b/src/tools/memory.rs @@ -0,0 +1,262 @@ +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +use crate::memory::{MemoryCategory, MemoryManager}; + +// ── MemoryStoreTool ────────────────────────────────────────────────── + +pub struct MemoryStoreTool { + memory: Arc, +} + +impl MemoryStoreTool { + pub fn new(memory: Arc) -> Self { + Self { memory } + } +} + +#[async_trait] +impl Tool for MemoryStoreTool { + fn name(&self) -> &str { + "memory_store" + } + + fn description(&self) -> &str { + "Store a fact, preference, or insight into long-term memory. \ + Use this when the user shares important information you should remember. \ + Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \ + and the full content to remember." + } + + fn read_only(&self) -> bool { + false + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key." + }, + "content": { + "type": "string", + "description": "The full content of the memory entry." + }, + "category": { + "type": "string", + "enum": ["knowledge", "timeline"], + "description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries." + }, + "importance": { + "type": "number", + "description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info." + } + }, + "required": ["key", "content"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let key = args + .get("key") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?; + + let content = args + .get("content") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?; + + let category = args + .get("category") + .and_then(|v| v.as_str()) + .and_then(MemoryCategory::from_str) + .unwrap_or(MemoryCategory::Knowledge); + + let importance = args.get("importance").and_then(|v| v.as_f64()); + + self.memory + .store(key, content, category, None, importance) + .await?; + + Ok(ToolResult { + success: true, + output: format!("Memory stored: {}", key), + error: None, + }) + } +} + +// ── MemoryRecallTool ───────────────────────────────────────────────── + +pub struct MemoryRecallTool { + memory: Arc, +} + +impl MemoryRecallTool { + pub fn new(memory: Arc) -> Self { + Self { memory } + } +} + +#[async_trait] +impl Tool for MemoryRecallTool { + fn name(&self) -> &str { + "memory_recall" + } + + fn description(&self) -> &str { + "Search and retrieve entries from long-term memory. \ + Use this to recall previously stored facts, preferences, or conversation history. \ + Supports keyword search and optional time-range filtering." + } + + fn read_only(&self) -> bool { + true + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query — keywords to match against memory keys and content." + }, + "category": { + "type": "string", + "enum": ["knowledge", "timeline"], + "description": "Filter by memory category. Omit to search all categories." + }, + "since": { + "type": "integer", + "description": "Start of time range (Unix milliseconds)." + }, + "until": { + "type": "integer", + "description": "End of time range (Unix milliseconds)." + }, + "limit": { + "type": "integer", + "description": "Max results to return (default 10)." + } + }, + "required": ["query"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let query = args + .get("query") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?; + + let category = args + .get("category") + .and_then(|v| v.as_str()) + .and_then(MemoryCategory::from_str); + + let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize; + + let entries = if args.get("since").is_some() || args.get("until").is_some() { + let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0); + let until = args + .get("until") + .and_then(|v| v.as_i64()) + .unwrap_or(chrono::Utc::now().timestamp_millis()); + self.memory + .recall_by_time(since, until, limit, category) + .await? + } else { + self.memory.recall(query, limit, category).await? + }; + + if entries.is_empty() { + return Ok(ToolResult { + success: true, + output: "No matching memories found.".to_string(), + error: None, + }); + } + + let formatted = entries + .iter() + .map(|e| { + format!( + "- {} [{}] [importance: {:.1}]: {}", + e.key, + e.category.as_str(), + e.importance, + e.content + ) + }) + .collect::>() + .join("\n"); + + Ok(ToolResult { + success: true, + output: format!("Found {} memories:\n{}", entries.len(), formatted), + error: None, + }) + } +} + +// ── MemoryForgetTool ───────────────────────────────────────────────── + +pub struct MemoryForgetTool { + memory: Arc, +} + +impl MemoryForgetTool { + pub fn new(memory: Arc) -> Self { + Self { memory } + } +} + +#[async_trait] +impl Tool for MemoryForgetTool { + fn name(&self) -> &str { + "memory_forget" + } + + fn description(&self) -> &str { + "Delete a memory entry by its key. Use this when information is outdated, \ + incorrect, or the user asks to forget something." + } + + fn read_only(&self) -> bool { + false + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "The key of the memory entry to delete." + } + }, + "required": ["key"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let key = args + .get("key") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?; + + self.memory.forget(key).await?; + + Ok(ToolResult { + success: true, + output: format!("Memory deleted: {}", key), + error: None, + }) + } +} diff --git a/src/tools/memory_forget.rs b/src/tools/memory_forget.rs deleted file mode 100644 index 0233587..0000000 --- a/src/tools/memory_forget.rs +++ /dev/null @@ -1,60 +0,0 @@ -use super::traits::{Tool, ToolResult}; -use async_trait::async_trait; -use serde_json::json; -use std::sync::Arc; - -use crate::memory::MemoryManager; - -pub struct MemoryForgetTool { - memory: Arc, -} - -impl MemoryForgetTool { - pub fn new(memory: Arc) -> Self { - Self { memory } - } -} - -#[async_trait] -impl Tool for MemoryForgetTool { - fn name(&self) -> &str { - "memory_forget" - } - - fn description(&self) -> &str { - "Delete a memory entry by its key. Use this when information is outdated, \ - incorrect, or the user asks to forget something." - } - - fn read_only(&self) -> bool { - false - } - - fn parameters_schema(&self) -> serde_json::Value { - json!({ - "type": "object", - "properties": { - "key": { - "type": "string", - "description": "The key of the memory entry to delete." - } - }, - "required": ["key"] - }) - } - - async fn execute(&self, args: serde_json::Value) -> anyhow::Result { - let key = args - .get("key") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?; - - self.memory.forget(key).await?; - - Ok(ToolResult { - success: true, - output: format!("Memory deleted: {}", key), - error: None, - }) - } -} diff --git a/src/tools/memory_recall.rs b/src/tools/memory_recall.rs deleted file mode 100644 index 771dee1..0000000 --- a/src/tools/memory_recall.rs +++ /dev/null @@ -1,118 +0,0 @@ -use super::traits::{Tool, ToolResult}; -use async_trait::async_trait; -use serde_json::json; -use std::sync::Arc; - -use crate::memory::{MemoryCategory, MemoryManager}; - -pub struct MemoryRecallTool { - memory: Arc, -} - -impl MemoryRecallTool { - pub fn new(memory: Arc) -> Self { - Self { memory } - } -} - -#[async_trait] -impl Tool for MemoryRecallTool { - fn name(&self) -> &str { - "memory_recall" - } - - fn description(&self) -> &str { - "Search and retrieve entries from long-term memory. \ - Use this to recall previously stored facts, preferences, or conversation history. \ - Supports keyword search and optional time-range filtering." - } - - fn read_only(&self) -> bool { - true - } - - fn parameters_schema(&self) -> serde_json::Value { - json!({ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query — keywords to match against memory keys and content." - }, - "category": { - "type": "string", - "enum": ["knowledge", "timeline"], - "description": "Filter by memory category. Omit to search all categories." - }, - "since": { - "type": "integer", - "description": "Start of time range (Unix milliseconds)." - }, - "until": { - "type": "integer", - "description": "End of time range (Unix milliseconds)." - }, - "limit": { - "type": "integer", - "description": "Max results to return (default 10)." - } - }, - "required": ["query"] - }) - } - - async fn execute(&self, args: serde_json::Value) -> anyhow::Result { - let query = args - .get("query") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?; - - let category = args - .get("category") - .and_then(|v| v.as_str()) - .and_then(MemoryCategory::from_str); - - let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize; - - let entries = if args.get("since").is_some() || args.get("until").is_some() { - let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0); - let until = args - .get("until") - .and_then(|v| v.as_i64()) - .unwrap_or(chrono::Utc::now().timestamp_millis()); - self.memory - .recall_by_time(since, until, limit, category) - .await? - } else { - self.memory.recall(query, limit, category).await? - }; - - if entries.is_empty() { - return Ok(ToolResult { - success: true, - output: "No matching memories found.".to_string(), - error: None, - }); - } - - let formatted = entries - .iter() - .map(|e| { - format!( - "- {} [{}] [importance: {:.1}]: {}", - e.key, - e.category.as_str(), - e.importance, - e.content - ) - }) - .collect::>() - .join("\n"); - - Ok(ToolResult { - success: true, - output: format!("Found {} memories:\n{}", entries.len(), formatted), - error: None, - }) - } -} diff --git a/src/tools/memory_store.rs b/src/tools/memory_store.rs deleted file mode 100644 index 22aa5ee..0000000 --- a/src/tools/memory_store.rs +++ /dev/null @@ -1,90 +0,0 @@ -use super::traits::{Tool, ToolResult}; -use async_trait::async_trait; -use serde_json::json; -use std::sync::Arc; - -use crate::memory::{MemoryCategory, MemoryManager}; - -pub struct MemoryStoreTool { - memory: Arc, -} - -impl MemoryStoreTool { - pub fn new(memory: Arc) -> Self { - Self { memory } - } -} - -#[async_trait] -impl Tool for MemoryStoreTool { - fn name(&self) -> &str { - "memory_store" - } - - fn description(&self) -> &str { - "Store a fact, preference, or insight into long-term memory. \ - Use this when the user shares important information you should remember. \ - Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \ - and the full content to remember." - } - - fn read_only(&self) -> bool { - false - } - - fn parameters_schema(&self) -> serde_json::Value { - json!({ - "type": "object", - "properties": { - "key": { - "type": "string", - "description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key." - }, - "content": { - "type": "string", - "description": "The full content of the memory entry." - }, - "category": { - "type": "string", - "enum": ["knowledge", "timeline"], - "description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries." - }, - "importance": { - "type": "number", - "description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info." - } - }, - "required": ["key", "content"] - }) - } - - async fn execute(&self, args: serde_json::Value) -> anyhow::Result { - let key = args - .get("key") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?; - - let content = args - .get("content") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?; - - let category = args - .get("category") - .and_then(|v| v.as_str()) - .and_then(MemoryCategory::from_str) - .unwrap_or(MemoryCategory::Knowledge); - - let importance = args.get("importance").and_then(|v| v.as_f64()); - - self.memory - .store(key, content, category, None, importance) - .await?; - - Ok(ToolResult { - success: true, - output: format!("Memory stored: {}", key), - error: None, - }) - } -} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index bd1b66c..59dd62c 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -7,9 +7,7 @@ pub mod file_read; pub mod file_write; pub mod get_skill; pub mod http_request; -pub mod memory_forget; -pub mod memory_recall; -pub mod memory_store; +pub mod memory; pub mod registry; pub mod schema; pub mod send_message; @@ -24,9 +22,7 @@ pub use file_read::FileReadTool; pub use file_write::FileWriteTool; pub use get_skill::GetSkillTool; pub use http_request::HttpRequestTool; -pub use memory_forget::MemoryForgetTool; -pub use memory_recall::MemoryRecallTool; -pub use memory_store::MemoryStoreTool; +pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool}; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr}; pub use send_message::SendMessageTool; @@ -42,7 +38,7 @@ use crate::skills::SkillsLoader; /// once the available channel names are known. pub fn create_default_tools( skills_loader: Arc, - memory: Option>, + memory: Arc, ) -> ToolRegistry { let registry = ToolRegistry::new(); registry.register(CalculatorTool::new()); @@ -59,12 +55,9 @@ pub fn create_default_tools( registry.register(WebFetchTool::new(50_000, 30)); registry.register(GetSkillTool::new(skills_loader)); - // Register memory tools if memory system is available - if let Some(mm) = memory { - registry.register(MemoryStoreTool::new(mm.clone())); - registry.register(MemoryRecallTool::new(mm.clone())); - registry.register(MemoryForgetTool::new(mm.clone())); - } + registry.register(MemoryStoreTool::new(memory.clone())); + registry.register(MemoryRecallTool::new(memory.clone())); + registry.register(MemoryForgetTool::new(memory.clone())); registry }