feat: add memory system with FTS5 search and context compression integration
This commit is contained in:
parent
618016ac43
commit
c602a0695d
@ -341,7 +341,7 @@ impl AgentLoop {
|
|||||||
// Build and inject system prompt if not present
|
// Build and inject system prompt if not present
|
||||||
let has_system = messages.first().map_or(false, |m| m.role == "system");
|
let has_system = messages.first().map_or(false, |m| m.role == "system");
|
||||||
if !has_system {
|
if !has_system {
|
||||||
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None);
|
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None);
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
||||||
messages.insert(0, ChatMessage::system(system_prompt));
|
messages.insert(0, ChatMessage::system(system_prompt));
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
|
use crate::memory::MemoryManager;
|
||||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message};
|
use crate::providers::{ChatCompletionRequest, LLMProvider, Message};
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
@ -49,6 +50,11 @@ pub struct ContextCompressor {
|
|||||||
threshold_ratio: f64,
|
threshold_ratio: f64,
|
||||||
/// Shared LLM provider for summarization
|
/// Shared LLM provider for summarization
|
||||||
provider: Arc<dyn LLMProvider>,
|
provider: Arc<dyn LLMProvider>,
|
||||||
|
/// Memory manager handle (optional). When set, compressed
|
||||||
|
/// context summaries are persisted as timeline memory entries.
|
||||||
|
memory: Option<Arc<MemoryManager>>,
|
||||||
|
/// Current session ID for timeline memory writes.
|
||||||
|
session_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ContextCompressor {
|
impl ContextCompressor {
|
||||||
@ -59,6 +65,8 @@ impl ContextCompressor {
|
|||||||
context_window,
|
context_window,
|
||||||
threshold_ratio: 0.5,
|
threshold_ratio: 0.5,
|
||||||
provider,
|
provider,
|
||||||
|
memory: None,
|
||||||
|
session_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,9 +81,22 @@ impl ContextCompressor {
|
|||||||
context_window,
|
context_window,
|
||||||
threshold_ratio: 0.5,
|
threshold_ratio: 0.5,
|
||||||
provider,
|
provider,
|
||||||
|
memory: None,
|
||||||
|
session_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Attach a memory manager to persist compressed summaries.
|
||||||
|
pub fn with_memory(mut self, memory: Arc<MemoryManager>) -> Self {
|
||||||
|
self.memory = Some(memory);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the current session ID for timeline writes.
|
||||||
|
pub fn set_session_id(&mut self, id: Option<String>) {
|
||||||
|
self.session_id = id;
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the compression threshold in tokens.
|
/// Get the compression threshold in tokens.
|
||||||
fn threshold(&self) -> usize {
|
fn threshold(&self) -> usize {
|
||||||
(self.context_window as f64 * self.threshold_ratio) as usize
|
(self.context_window as f64 * self.threshold_ratio) as usize
|
||||||
@ -218,6 +239,27 @@ impl ContextCompressor {
|
|||||||
let between = &history[between_start..between_end];
|
let between = &history[between_start..between_end];
|
||||||
let summary = self.summarize_segment(between).await?;
|
let summary = self.summarize_segment(between).await?;
|
||||||
|
|
||||||
|
// Persist compressed summary as timeline memory entry
|
||||||
|
if let Some(ref mm) = self.memory {
|
||||||
|
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
||||||
|
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
||||||
|
ts, between.len(), summary);
|
||||||
|
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
||||||
|
let mm_clone = mm.clone();
|
||||||
|
let sid = self.session_id.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = mm_clone.store(
|
||||||
|
&key,
|
||||||
|
&timeline_content,
|
||||||
|
crate::memory::MemoryCategory::Timeline,
|
||||||
|
sid.as_deref(),
|
||||||
|
Some(0.3),
|
||||||
|
).await {
|
||||||
|
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Add summary as a special user message
|
// Add summary as a special user message
|
||||||
new_messages.push(ChatMessage::user(format!(
|
new_messages.push(ChatMessage::user(format!(
|
||||||
"[Context Summary]\n\n{}",
|
"[Context Summary]\n\n{}",
|
||||||
|
|||||||
@ -19,6 +19,8 @@ pub struct PromptContext<'a> {
|
|||||||
pub model_name: &'a str,
|
pub model_name: &'a str,
|
||||||
pub tools: &'a ToolRegistry,
|
pub tools: &'a ToolRegistry,
|
||||||
pub session_id: Option<&'a str>,
|
pub session_id: Option<&'a str>,
|
||||||
|
/// Pre-fetched memory context string to inject.
|
||||||
|
pub memory_context: Option<&'a str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for system prompt sections.
|
/// Trait for system prompt sections.
|
||||||
@ -43,6 +45,7 @@ impl SystemPromptBuilder {
|
|||||||
Box::new(SafetySection),
|
Box::new(SafetySection),
|
||||||
Box::new(WorkspaceSection),
|
Box::new(WorkspaceSection),
|
||||||
Box::new(UserProfileSection),
|
Box::new(UserProfileSection),
|
||||||
|
Box::new(MemorySection),
|
||||||
Box::new(DateTimeSection),
|
Box::new(DateTimeSection),
|
||||||
Box::new(RuntimeSection),
|
Box::new(RuntimeSection),
|
||||||
Box::new(CrossChannelSection),
|
Box::new(CrossChannelSection),
|
||||||
@ -284,6 +287,24 @@ impl PromptSection for RuntimeSection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Injects relevant knowledge memories into the system prompt.
|
||||||
|
pub struct MemorySection;
|
||||||
|
|
||||||
|
impl PromptSection for MemorySection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"memory"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
||||||
|
match ctx.memory_context {
|
||||||
|
Some(context) if !context.is_empty() => {
|
||||||
|
format!("## 记忆上下文\n\n{}", context)
|
||||||
|
}
|
||||||
|
_ => String::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// === Helper Functions ===
|
// === Helper Functions ===
|
||||||
|
|
||||||
/// Get user config directory (~/.picobot/).
|
/// Get user config directory (~/.picobot/).
|
||||||
@ -321,12 +342,19 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option<St
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Build a complete system prompt with default configuration.
|
/// Build a complete system prompt with default configuration.
|
||||||
pub fn build_system_prompt(workspace_dir: &Path, model_name: &str, tools: &ToolRegistry, session_id: Option<&str>) -> String {
|
pub fn build_system_prompt(
|
||||||
|
workspace_dir: &Path,
|
||||||
|
model_name: &str,
|
||||||
|
tools: &ToolRegistry,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
memory_context: Option<&str>,
|
||||||
|
) -> String {
|
||||||
let ctx = PromptContext {
|
let ctx = PromptContext {
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
tools,
|
tools,
|
||||||
session_id,
|
session_id,
|
||||||
|
memory_context,
|
||||||
};
|
};
|
||||||
SystemPromptBuilder::with_defaults().build(&ctx)
|
SystemPromptBuilder::with_defaults().build(&ctx)
|
||||||
}
|
}
|
||||||
@ -346,6 +374,7 @@ mod tests {
|
|||||||
model_name: "test-model",
|
model_name: "test-model",
|
||||||
tools: &tools,
|
tools: &tools,
|
||||||
session_id: None,
|
session_id: None,
|
||||||
|
memory_context: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
||||||
@ -375,9 +404,44 @@ mod tests {
|
|||||||
let temp_dir = std::env::temp_dir();
|
let temp_dir = std::env::temp_dir();
|
||||||
let tools = ToolRegistry::new();
|
let tools = ToolRegistry::new();
|
||||||
|
|
||||||
let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None);
|
let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None, None);
|
||||||
|
|
||||||
assert!(!prompt.is_empty());
|
assert!(!prompt.is_empty());
|
||||||
assert!(prompt.contains("test-model"));
|
assert!(prompt.contains("test-model"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_memory_section_with_context() {
|
||||||
|
let temp_dir = std::env::temp_dir();
|
||||||
|
let tools = ToolRegistry::new();
|
||||||
|
|
||||||
|
let ctx = PromptContext {
|
||||||
|
workspace_dir: &temp_dir,
|
||||||
|
model_name: "test",
|
||||||
|
tools: &tools,
|
||||||
|
session_id: None,
|
||||||
|
memory_context: Some("- user_pref: Prefers Rust"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
||||||
|
assert!(prompt.contains("## 记忆上下文"));
|
||||||
|
assert!(prompt.contains("Prefers Rust"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_memory_section_without_context() {
|
||||||
|
let temp_dir = std::env::temp_dir();
|
||||||
|
let tools = ToolRegistry::new();
|
||||||
|
|
||||||
|
let ctx = PromptContext {
|
||||||
|
workspace_dir: &temp_dir,
|
||||||
|
model_name: "test",
|
||||||
|
tools: &tools,
|
||||||
|
session_id: None,
|
||||||
|
memory_context: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
||||||
|
assert!(!prompt.contains("## 记忆上下文"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -49,6 +49,8 @@ pub struct Config {
|
|||||||
pub client: ClientConfig,
|
pub client: ClientConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub channels: HashMap<String, FeishuChannelConfig>,
|
pub channels: HashMap<String, FeishuChannelConfig>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub memory: MemoryConfig,
|
||||||
#[serde(default = "default_workspace_dir")]
|
#[serde(default = "default_workspace_dir")]
|
||||||
pub workspace_dir: String,
|
pub workspace_dir: String,
|
||||||
}
|
}
|
||||||
@ -218,6 +220,40 @@ impl Default for ClientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||||
|
pub struct MemoryConfig {
|
||||||
|
/// Master switch for the memory system.
|
||||||
|
#[serde(default)]
|
||||||
|
pub enabled: bool,
|
||||||
|
/// Provider name for consolidation LLM calls (key in `providers`).
|
||||||
|
#[serde(default)]
|
||||||
|
pub consolidation_provider: Option<String>,
|
||||||
|
/// Model name for consolidation LLM calls (key in `models`).
|
||||||
|
#[serde(default)]
|
||||||
|
pub consolidation_model: Option<String>,
|
||||||
|
/// Max knowledge entries injected into system prompt per turn.
|
||||||
|
#[serde(default = "default_recall_limit")]
|
||||||
|
pub recall_limit: usize,
|
||||||
|
/// Number of turns without consolidation before forcing one.
|
||||||
|
#[serde(default = "default_consolidation_turn_threshold")]
|
||||||
|
pub consolidation_turn_threshold: usize,
|
||||||
|
/// Idle minutes before triggering consolidation (for async channels).
|
||||||
|
#[serde(default = "default_idle_consolidation_minutes")]
|
||||||
|
pub idle_consolidation_minutes: u64,
|
||||||
|
/// Days before timeline entries are auto-cleaned.
|
||||||
|
#[serde(default = "default_timeline_retention_days")]
|
||||||
|
pub timeline_retention_days: u64,
|
||||||
|
/// Consecutive consolidation failures before degrading to raw archive.
|
||||||
|
#[serde(default = "default_max_failures_before_degrade")]
|
||||||
|
pub max_failures_before_degrade: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_recall_limit() -> usize { 5 }
|
||||||
|
fn default_consolidation_turn_threshold() -> usize { 3 }
|
||||||
|
fn default_idle_consolidation_minutes() -> u64 { 10 }
|
||||||
|
fn default_timeline_retention_days() -> u64 { 90 }
|
||||||
|
fn default_max_failures_before_degrade() -> usize { 3 }
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LLMProviderConfig {
|
pub struct LLMProviderConfig {
|
||||||
pub provider_type: String,
|
pub provider_type: String,
|
||||||
|
|||||||
@ -10,6 +10,7 @@ use crate::channels::{ChannelManager, CliChatChannel};
|
|||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
use crate::config::{Config, expand_path, ensure_workspace_dir};
|
use crate::config::{Config, expand_path, ensure_workspace_dir};
|
||||||
use crate::logging;
|
use crate::logging;
|
||||||
|
use crate::memory::MemoryManager;
|
||||||
use crate::session::SessionManager;
|
use crate::session::SessionManager;
|
||||||
use crate::scheduler::Scheduler;
|
use crate::scheduler::Scheduler;
|
||||||
|
|
||||||
@ -55,11 +56,26 @@ impl GatewayState {
|
|||||||
);
|
);
|
||||||
tracing::info!("Session storage: {}", db_path.display());
|
tracing::info!("Session storage: {}", db_path.display());
|
||||||
|
|
||||||
|
// Initialize MemoryManager if memory system is enabled
|
||||||
|
let memory_manager = if config.memory.enabled {
|
||||||
|
let mm = Arc::new(MemoryManager::new(storage.clone()));
|
||||||
|
tracing::info!("Memory system enabled");
|
||||||
|
Some(mm)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
// Create MessageBus first (shared by SessionManager and ChannelManager)
|
// Create MessageBus first (shared by SessionManager and ChannelManager)
|
||||||
let bus = MessageBus::new(100);
|
let bus = MessageBus::new(100);
|
||||||
|
|
||||||
// Create SessionManager with bus injection
|
// Create SessionManager with bus injection
|
||||||
let session_manager = SessionManager::new(session_ttl_hours, provider_config.clone(), storage.clone(), bus.clone())?;
|
let session_manager = SessionManager::new(
|
||||||
|
session_ttl_hours,
|
||||||
|
provider_config.clone(),
|
||||||
|
storage.clone(),
|
||||||
|
bus.clone(),
|
||||||
|
memory_manager,
|
||||||
|
)?;
|
||||||
let session_manager = Arc::new(session_manager);
|
let session_manager = Arc::new(session_manager);
|
||||||
|
|
||||||
// Start background cleanup task (default 60 minutes)
|
// Start background cleanup task (default 60 minutes)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ pub mod client;
|
|||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
pub mod channels;
|
pub mod channels;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
|
pub mod memory;
|
||||||
pub mod observability;
|
pub mod observability;
|
||||||
pub mod scheduler;
|
pub mod scheduler;
|
||||||
pub mod skills;
|
pub mod skills;
|
||||||
|
|||||||
199
src/memory/mod.rs
Normal file
199
src/memory/mod.rs
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
pub mod types;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::storage::Storage;
|
||||||
|
pub use types::{ConsolidationFact, ConsolidationResult, MemoryCategory, MemoryEntry};
|
||||||
|
|
||||||
|
/// MemoryManager provides high-level memory operations.
|
||||||
|
/// Wraps the Storage SQLite layer with semantic methods.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MemoryManager {
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryManager {
|
||||||
|
pub fn new(storage: Arc<Storage>) -> Self {
|
||||||
|
Self { storage }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Store or update a memory entry. Generates timestamp and UUID.
|
||||||
|
pub async fn store(
|
||||||
|
&self,
|
||||||
|
key: &str,
|
||||||
|
content: &str,
|
||||||
|
category: MemoryCategory,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
importance: Option<f64>,
|
||||||
|
) -> Result<(), crate::storage::StorageError> {
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let entry = MemoryEntry {
|
||||||
|
id: Uuid::new_v4().to_string(),
|
||||||
|
key: key.to_string(),
|
||||||
|
content: content.to_string(),
|
||||||
|
category,
|
||||||
|
importance: importance.unwrap_or(0.5),
|
||||||
|
session_id: session_id.map(|s| s.to_string()),
|
||||||
|
created_at: now.clone(),
|
||||||
|
updated_at: now,
|
||||||
|
};
|
||||||
|
self.storage.upsert_memory(&entry).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search memories by keyword query. Returns entries sorted by relevance.
|
||||||
|
pub async fn recall(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
category: Option<MemoryCategory>,
|
||||||
|
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
|
||||||
|
self.storage
|
||||||
|
.search_memories(query, category.as_ref(), limit)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search memories by time range (Unix milliseconds).
|
||||||
|
pub async fn recall_by_time(
|
||||||
|
&self,
|
||||||
|
since: i64,
|
||||||
|
until: i64,
|
||||||
|
limit: usize,
|
||||||
|
category: Option<MemoryCategory>,
|
||||||
|
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
|
||||||
|
self.storage
|
||||||
|
.search_memories_by_time(since, until, category.as_ref(), limit)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete a memory entry by key.
|
||||||
|
pub async fn forget(&self, key: &str) -> Result<(), crate::storage::StorageError> {
|
||||||
|
self.storage.delete_memory(key).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the memory system has any entries (for testing/health check).
|
||||||
|
pub async fn is_empty(&self) -> Result<bool, crate::storage::StorageError> {
|
||||||
|
self.recall("*", 1, None).await.map(|r| r.is_empty())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
|
||||||
|
async fn setup_memory_manager() -> (Arc<MemoryManager>, tempfile::TempDir) {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let db_path = dir.path().join("test.db");
|
||||||
|
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
||||||
|
let mm = Arc::new(MemoryManager::new(storage));
|
||||||
|
(mm, dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_store_and_recall() {
|
||||||
|
let (mm, _dir) = setup_memory_manager().await;
|
||||||
|
|
||||||
|
mm.store(
|
||||||
|
"test_key",
|
||||||
|
"This is a test memory",
|
||||||
|
MemoryCategory::Knowledge,
|
||||||
|
None,
|
||||||
|
Some(0.8),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let results = mm.recall("test memory", 10, None).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].key, "test_key");
|
||||||
|
assert_eq!(results[0].content, "This is a test memory");
|
||||||
|
assert!((results[0].importance - 0.8).abs() < 0.01);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_upsert_overwrites() {
|
||||||
|
let (mm, _dir) = setup_memory_manager().await;
|
||||||
|
|
||||||
|
mm.store(
|
||||||
|
"dup_key",
|
||||||
|
"original",
|
||||||
|
MemoryCategory::Knowledge,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mm.store(
|
||||||
|
"dup_key",
|
||||||
|
"updated",
|
||||||
|
MemoryCategory::Knowledge,
|
||||||
|
None,
|
||||||
|
Some(0.9),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let results = mm.recall("updated", 10, None).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].content, "updated");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_forget() {
|
||||||
|
let (mm, _dir) = setup_memory_manager().await;
|
||||||
|
|
||||||
|
mm.store(
|
||||||
|
"to_delete",
|
||||||
|
"will be deleted",
|
||||||
|
MemoryCategory::Knowledge,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mm.forget("to_delete").await.unwrap();
|
||||||
|
|
||||||
|
let results = mm.recall("deleted", 10, None).await.unwrap();
|
||||||
|
assert!(results.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_category_filter() {
|
||||||
|
let (mm, _dir) = setup_memory_manager().await;
|
||||||
|
|
||||||
|
mm.store(
|
||||||
|
"knowledge_1",
|
||||||
|
"fact content",
|
||||||
|
MemoryCategory::Knowledge,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mm.store(
|
||||||
|
"timeline_1",
|
||||||
|
"summary content",
|
||||||
|
MemoryCategory::Timeline,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let know_results = mm
|
||||||
|
.recall("content", 10, Some(MemoryCategory::Knowledge))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(know_results.len(), 1);
|
||||||
|
assert_eq!(know_results[0].key, "knowledge_1");
|
||||||
|
|
||||||
|
let time_results = mm
|
||||||
|
.recall("content", 10, Some(MemoryCategory::Timeline))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(time_results.len(), 1);
|
||||||
|
assert_eq!(time_results[0].key, "timeline_1");
|
||||||
|
}
|
||||||
|
}
|
||||||
90
src/memory/types.rs
Normal file
90
src/memory/types.rs
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Memory categories for the memory system.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum MemoryCategory {
|
||||||
|
/// Long-term facts, preferences, patterns, insights (merged fact+insight).
|
||||||
|
Knowledge,
|
||||||
|
/// Conversation summaries produced by context compression.
|
||||||
|
Timeline,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryCategory {
|
||||||
|
pub fn as_str(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
Self::Knowledge => "knowledge",
|
||||||
|
Self::Timeline => "timeline",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_str(s: &str) -> Option<Self> {
|
||||||
|
match s {
|
||||||
|
"knowledge" => Some(Self::Knowledge),
|
||||||
|
"timeline" => Some(Self::Timeline),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single memory entry stored in SQLite.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MemoryEntry {
|
||||||
|
pub id: String,
|
||||||
|
/// Semantic identifier (e.g. "user_prefers_rust").
|
||||||
|
pub key: String,
|
||||||
|
/// The memory content.
|
||||||
|
pub content: String,
|
||||||
|
/// Category: knowledge or timeline.
|
||||||
|
pub category: MemoryCategory,
|
||||||
|
/// Importance score 0.0–1.0 (default 0.5).
|
||||||
|
pub importance: f64,
|
||||||
|
/// Associated session ID (optional).
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
/// RFC 3339 creation timestamp.
|
||||||
|
pub created_at: String,
|
||||||
|
/// RFC 3339 last update timestamp.
|
||||||
|
pub updated_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result from an LLM consolidation call.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ConsolidationResult {
|
||||||
|
/// New or updated knowledge entries extracted from conversation.
|
||||||
|
pub facts: Vec<ConsolidationFact>,
|
||||||
|
/// Summary entry for timeline (formatted as "[YYYY-MM-DD HH:MM] text...").
|
||||||
|
pub timeline: Option<String>,
|
||||||
|
/// Keys of existing memories that should be invalidated.
|
||||||
|
pub invalidations: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ConsolidationFact {
|
||||||
|
pub key: String,
|
||||||
|
pub content: String,
|
||||||
|
pub importance: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_memory_category_as_str() {
|
||||||
|
assert_eq!(MemoryCategory::Knowledge.as_str(), "knowledge");
|
||||||
|
assert_eq!(MemoryCategory::Timeline.as_str(), "timeline");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_memory_category_from_str() {
|
||||||
|
assert_eq!(
|
||||||
|
MemoryCategory::from_str("knowledge"),
|
||||||
|
Some(MemoryCategory::Knowledge)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
MemoryCategory::from_str("timeline"),
|
||||||
|
Some(MemoryCategory::Timeline)
|
||||||
|
);
|
||||||
|
assert_eq!(MemoryCategory::from_str("invalid"), None);
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -55,6 +55,10 @@ pub struct Session {
|
|||||||
|
|
||||||
storage: Option<StdArc<Storage>>,
|
storage: Option<StdArc<Storage>>,
|
||||||
routing_info: String,
|
routing_info: String,
|
||||||
|
/// Timestamp (Unix ms) of the last consolidation.
|
||||||
|
/// Messages before this time have been compressed into memory.
|
||||||
|
pub last_consolidated_at: Option<i64>,
|
||||||
|
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
@ -65,6 +69,7 @@ impl Session {
|
|||||||
storage: Option<StdArc<Storage>>,
|
storage: Option<StdArc<Storage>>,
|
||||||
routing_info: String,
|
routing_info: String,
|
||||||
title: String,
|
title: String,
|
||||||
|
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let mut provider_box = create_provider(provider_config.clone())
|
let mut provider_box = create_provider(provider_config.clone())
|
||||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||||
@ -78,6 +83,12 @@ impl Session {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config);
|
||||||
|
if let Some(ref mm) = memory_manager {
|
||||||
|
compressor = compressor.with_memory(mm.clone());
|
||||||
|
compressor.set_session_id(Some(id.to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -92,9 +103,11 @@ impl Session {
|
|||||||
provider_config: provider_config.clone(),
|
provider_config: provider_config.clone(),
|
||||||
provider: provider.clone(),
|
provider: provider.clone(),
|
||||||
tools,
|
tools,
|
||||||
compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config),
|
compressor,
|
||||||
storage,
|
storage,
|
||||||
routing_info,
|
routing_info,
|
||||||
|
last_consolidated_at: None,
|
||||||
|
memory_manager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,6 +117,7 @@ impl Session {
|
|||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
storage: StdArc<Storage>,
|
storage: StdArc<Storage>,
|
||||||
|
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let session_meta = storage.get_session(&id.to_string()).await
|
let session_meta = storage.get_session(&id.to_string()).await
|
||||||
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
||||||
@ -121,6 +135,12 @@ impl Session {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config);
|
||||||
|
if let Some(ref mm) = memory_manager {
|
||||||
|
compressor = compressor.with_memory(mm.clone());
|
||||||
|
compressor.set_session_id(Some(id.to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
// Convert MessageMeta to ChatMessage
|
// Convert MessageMeta to ChatMessage
|
||||||
// Clear tool_call_id/tool_name — they're not valid across API sessions
|
// Clear tool_call_id/tool_name — they're not valid across API sessions
|
||||||
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
|
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
|
||||||
@ -152,9 +172,11 @@ impl Session {
|
|||||||
provider_config: provider_config.clone(),
|
provider_config: provider_config.clone(),
|
||||||
provider: provider.clone(),
|
provider: provider.clone(),
|
||||||
tools,
|
tools,
|
||||||
compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config),
|
compressor,
|
||||||
storage: Some(storage),
|
storage: Some(storage),
|
||||||
routing_info: session_meta.routing_info.unwrap_or_default(),
|
routing_info: session_meta.routing_info.unwrap_or_default(),
|
||||||
|
last_consolidated_at: session_meta.last_consolidated_at,
|
||||||
|
memory_manager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -282,6 +304,7 @@ impl Session {
|
|||||||
Some(self.routing_info.clone())
|
Some(self.routing_info.clone())
|
||||||
},
|
},
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: self.last_consolidated_at,
|
||||||
};
|
};
|
||||||
storage.upsert_session(&meta).await?;
|
storage.upsert_session(&meta).await?;
|
||||||
}
|
}
|
||||||
@ -372,13 +395,14 @@ impl Session {
|
|||||||
Ok(self.create_agent()?.with_notify(notify_tx))
|
Ok(self.create_agent()?.with_notify(notify_tx))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills)
|
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills + memory)
|
||||||
pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
|
pub fn build_system_prompt(&self, skills_prompt: &str, memory_context: Option<&str>) -> String {
|
||||||
let base_prompt = build_system_prompt(
|
let base_prompt = build_system_prompt(
|
||||||
&self.provider_config.workspace_dir,
|
&self.provider_config.workspace_dir,
|
||||||
&self.provider_config.model_id,
|
&self.provider_config.model_id,
|
||||||
&self.tools,
|
&self.tools,
|
||||||
Some(&self.id.to_string()),
|
Some(&self.id.to_string()),
|
||||||
|
memory_context,
|
||||||
);
|
);
|
||||||
|
|
||||||
if skills_prompt.trim().is_empty() {
|
if skills_prompt.trim().is_empty() {
|
||||||
@ -560,6 +584,7 @@ pub struct SessionManager {
|
|||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
bus: Arc<MessageBus>,
|
bus: Arc<MessageBus>,
|
||||||
current_source_session: Arc<Mutex<Option<String>>>,
|
current_source_session: Arc<Mutex<Option<String>>>,
|
||||||
|
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SessionManagerInner {
|
struct SessionManagerInner {
|
||||||
@ -647,12 +672,13 @@ impl SessionManager {
|
|||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
bus: Arc<MessageBus>,
|
bus: Arc<MessageBus>,
|
||||||
|
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let skills_loader = SkillsLoader::new();
|
let skills_loader = SkillsLoader::new();
|
||||||
skills_loader.load_skills();
|
skills_loader.load_skills();
|
||||||
let skills_loader = Arc::new(skills_loader);
|
let skills_loader = Arc::new(skills_loader);
|
||||||
|
|
||||||
let tools = Arc::new(create_default_tools(skills_loader.clone()));
|
let tools = Arc::new(create_default_tools(skills_loader.clone(), memory_manager.clone()));
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
inner: Arc::new(Mutex::new(SessionManagerInner {
|
||||||
@ -667,6 +693,7 @@ impl SessionManager {
|
|||||||
storage,
|
storage,
|
||||||
bus,
|
bus,
|
||||||
current_source_session: Arc::new(Mutex::new(None)),
|
current_source_session: Arc::new(Mutex::new(None)),
|
||||||
|
memory_manager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -800,7 +827,7 @@ impl SessionManager {
|
|||||||
|
|
||||||
// Build the same system prompt that would be injected to the model
|
// Build the same system prompt that would be injected to the model
|
||||||
let skills_prompt = self.skills_loader.build_skills_prompt();
|
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||||||
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
|
let system_prompt = session_guard.build_system_prompt(&skills_prompt, None);
|
||||||
|
|
||||||
let filepath = session_guard.dump_to_file(&system_prompt)
|
let filepath = session_guard.dump_to_file(&system_prompt)
|
||||||
.map_err(|e| AgentError::Other(format!("Failed to save dump: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("Failed to save dump: {}", e)))?;
|
||||||
@ -879,6 +906,7 @@ impl SessionManager {
|
|||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: if routing_info.is_empty() { None } else { Some(routing_info.clone()) },
|
routing_info: if routing_info.is_empty() { None } else { Some(routing_info.clone()) },
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
};
|
};
|
||||||
self.storage.upsert_session(&meta).await
|
self.storage.upsert_session(&meta).await
|
||||||
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
|
||||||
@ -890,6 +918,7 @@ impl SessionManager {
|
|||||||
Some(self.storage.clone()),
|
Some(self.storage.clone()),
|
||||||
routing_info,
|
routing_info,
|
||||||
title.clone(),
|
title.clone(),
|
||||||
|
self.memory_manager.clone(),
|
||||||
).await?;
|
).await?;
|
||||||
|
|
||||||
let arc = Arc::new(Mutex::new(session));
|
let arc = Arc::new(Mutex::new(session));
|
||||||
@ -921,6 +950,7 @@ impl SessionManager {
|
|||||||
self.provider_config.clone(),
|
self.provider_config.clone(),
|
||||||
self.tools.clone(),
|
self.tools.clone(),
|
||||||
self.storage.clone(),
|
self.storage.clone(),
|
||||||
|
self.memory_manager.clone(),
|
||||||
).await?;
|
).await?;
|
||||||
|
|
||||||
let arc = Arc::new(Mutex::new(session));
|
let arc = Arc::new(Mutex::new(session));
|
||||||
@ -944,6 +974,7 @@ impl SessionManager {
|
|||||||
Some(self.storage.clone()),
|
Some(self.storage.clone()),
|
||||||
String::new(),
|
String::new(),
|
||||||
format!("新对话"),
|
format!("新对话"),
|
||||||
|
self.memory_manager.clone(),
|
||||||
).await?;
|
).await?;
|
||||||
|
|
||||||
let arc = Arc::new(Mutex::new(session));
|
let arc = Arc::new(Mutex::new(session));
|
||||||
@ -1220,9 +1251,28 @@ impl SessionManager {
|
|||||||
// Build skills prompt
|
// Build skills prompt
|
||||||
let skills_prompt = self.skills_loader.build_skills_prompt();
|
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||||||
|
|
||||||
|
// Fetch memory context
|
||||||
|
let memory_context = if let Some(ref mm) = self.memory_manager {
|
||||||
|
match mm.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await {
|
||||||
|
Ok(entries) if !entries.is_empty() => {
|
||||||
|
Some(entries.iter()
|
||||||
|
.map(|e| format!("- {}: {}", e.key, e.content))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n"))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "Failed to fetch memory context");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
// Build combined system prompt and inject at position 0
|
// Build combined system prompt and inject at position 0
|
||||||
// This ensures AgentLoop.process() sees a system message and doesn't inject its own
|
// This ensures AgentLoop.process() sees a system message and doesn't inject its own
|
||||||
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
|
let system_prompt = session_guard.build_system_prompt(&skills_prompt, memory_context.as_deref());
|
||||||
history.insert(0, ChatMessage::system(system_prompt));
|
history.insert(0, ChatMessage::system(system_prompt));
|
||||||
|
|
||||||
let history = session_guard.compressor
|
let history = session_guard.compressor
|
||||||
@ -1324,7 +1374,7 @@ impl SessionManager {
|
|||||||
let mut history = session_guard.get_history().to_vec();
|
let mut history = session_guard.get_history().to_vec();
|
||||||
|
|
||||||
let skills_prompt = self.skills_loader.build_skills_prompt();
|
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||||||
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
|
let system_prompt = session_guard.build_system_prompt(&skills_prompt, None);
|
||||||
let cron_context = format!(
|
let cron_context = format!(
|
||||||
"\n\n## 定时任务执行\n\n\
|
"\n\n## 定时任务执行\n\n\
|
||||||
你正在执行定时任务「{}」({})。\n\
|
你正在执行定时任务「{}」({})。\n\
|
||||||
|
|||||||
183
src/storage/memory.rs
Normal file
183
src/storage/memory.rs
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
use sqlx::Row;
|
||||||
|
|
||||||
|
use crate::memory::{MemoryCategory, MemoryEntry};
|
||||||
|
|
||||||
|
use super::StorageError;
|
||||||
|
|
||||||
|
impl super::Storage {
|
||||||
|
/// Store or update a memory entry (upsert by key).
|
||||||
|
pub async fn upsert_memory(&self, entry: &MemoryEntry) -> Result<(), StorageError> {
|
||||||
|
let category_str = entry.category.as_str();
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO memories (id, key, content, category, importance, session_id, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(key) DO UPDATE SET
|
||||||
|
content = excluded.content,
|
||||||
|
category = excluded.category,
|
||||||
|
importance = excluded.importance,
|
||||||
|
session_id = excluded.session_id,
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&entry.id)
|
||||||
|
.bind(&entry.key)
|
||||||
|
.bind(&entry.content)
|
||||||
|
.bind(category_str)
|
||||||
|
.bind(entry.importance)
|
||||||
|
.bind(&entry.session_id)
|
||||||
|
.bind(&entry.created_at)
|
||||||
|
.bind(&entry.updated_at)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete a memory entry by key.
|
||||||
|
pub async fn delete_memory(&self, key: &str) -> Result<(), StorageError> {
|
||||||
|
sqlx::query("DELETE FROM memories WHERE key = ?")
|
||||||
|
.bind(key)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search memories by keyword using FTS5.
|
||||||
|
/// Falls back to LIKE query if FTS5 returns no results.
|
||||||
|
pub async fn search_memories(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
limit: usize,
|
||||||
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
|
// Build FTS5 query: wrap each word in quotes and join with OR
|
||||||
|
let fts_query = query
|
||||||
|
.split_whitespace()
|
||||||
|
.map(|w| format!("\"{}\"", w.replace('"', "")))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" OR ");
|
||||||
|
|
||||||
|
let category_filter = category.map(|c| c.as_str());
|
||||||
|
|
||||||
|
// Try FTS5 first
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT m.id, m.key, m.content, m.category, m.importance,
|
||||||
|
m.session_id, m.created_at, m.updated_at
|
||||||
|
FROM memory_fts f
|
||||||
|
JOIN memories m ON f.rowid = m.rowid
|
||||||
|
WHERE memory_fts MATCH ? AND (? IS NULL OR m.category = ?)
|
||||||
|
ORDER BY rank
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&fts_query)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut entries = parse_memory_rows(&rows)?;
|
||||||
|
|
||||||
|
// Fallback to LIKE if FTS5 returned nothing
|
||||||
|
if entries.is_empty() {
|
||||||
|
let like_pattern = format!("%{}%", query.replace('%', "").replace('_', ""));
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, key, content, category, importance,
|
||||||
|
session_id, created_at, updated_at
|
||||||
|
FROM memories
|
||||||
|
WHERE (key LIKE ? OR content LIKE ?)
|
||||||
|
AND (? IS NULL OR category = ?)
|
||||||
|
ORDER BY importance DESC, updated_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&like_pattern)
|
||||||
|
.bind(&like_pattern)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
entries = parse_memory_rows(&rows)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve memories within a time range.
|
||||||
|
pub async fn search_memories_by_time(
|
||||||
|
&self,
|
||||||
|
since: i64,
|
||||||
|
until: i64,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
limit: usize,
|
||||||
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
|
let category_filter = category.map(|c| c.as_str());
|
||||||
|
let since_dt = chrono::DateTime::from_timestamp_millis(since)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.to_rfc3339();
|
||||||
|
let until_dt = chrono::DateTime::from_timestamp_millis(until)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.to_rfc3339();
|
||||||
|
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, key, content, category, importance,
|
||||||
|
session_id, created_at, updated_at
|
||||||
|
FROM memories
|
||||||
|
WHERE created_at >= ? AND created_at <= ?
|
||||||
|
AND (? IS NULL OR category = ?)
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&since_dt)
|
||||||
|
.bind(&until_dt)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
parse_memory_rows(&rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete old timeline entries beyond retention period.
|
||||||
|
pub async fn cleanup_old_timelines(&self, retention_days: u64) -> Result<u64, StorageError> {
|
||||||
|
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
|
||||||
|
let cutoff_str = cutoff.to_rfc3339();
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?",
|
||||||
|
)
|
||||||
|
.bind(&cutoff_str)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(result.rows_affected())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_memory_rows(
|
||||||
|
rows: &[sqlx::sqlite::SqliteRow],
|
||||||
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
|
rows.iter()
|
||||||
|
.map(|row| {
|
||||||
|
Ok(MemoryEntry {
|
||||||
|
id: row.try_get("id")?,
|
||||||
|
key: row.try_get("key")?,
|
||||||
|
content: row.try_get("content")?,
|
||||||
|
category: MemoryCategory::from_str(&row.try_get::<String, _>("category")?)
|
||||||
|
.unwrap_or(MemoryCategory::Knowledge),
|
||||||
|
importance: row.try_get::<f64, _>("importance")?,
|
||||||
|
session_id: row.try_get::<Option<String>, _>("session_id")?,
|
||||||
|
created_at: row.try_get("created_at")?,
|
||||||
|
updated_at: row.try_get("updated_at")?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
@ -1,4 +1,5 @@
|
|||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod memory;
|
||||||
pub mod message;
|
pub mod message;
|
||||||
pub mod scheduler;
|
pub mod scheduler;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
@ -40,6 +41,7 @@ impl Storage {
|
|||||||
message_count INTEGER DEFAULT 0,
|
message_count INTEGER DEFAULT 0,
|
||||||
routing_info TEXT,
|
routing_info TEXT,
|
||||||
deleted_at INTEGER,
|
deleted_at INTEGER,
|
||||||
|
last_consolidated_at INTEGER,
|
||||||
UNIQUE(channel, chat_id, dialog_id)
|
UNIQUE(channel, chat_id, dialog_id)
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
@ -94,6 +96,47 @@ impl Storage {
|
|||||||
.await
|
.await
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS memories (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
key TEXT NOT NULL UNIQUE,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
category TEXT NOT NULL DEFAULT 'knowledge',
|
||||||
|
importance REAL NOT NULL DEFAULT 0.5,
|
||||||
|
session_id TEXT,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// FTS5 virtual table for full-text search on memories
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5(
|
||||||
|
key,
|
||||||
|
content,
|
||||||
|
content=memories,
|
||||||
|
content_rowid=rowid
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Migration: add last_consolidated_at column if not exists
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
ALTER TABLE sessions ADD COLUMN last_consolidated_at INTEGER
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
CREATE TABLE IF NOT EXISTS llm_calls (
|
CREATE TABLE IF NOT EXISTS llm_calls (
|
||||||
@ -229,14 +272,15 @@ impl Storage {
|
|||||||
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> {
|
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> {
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at)
|
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT(id) DO UPDATE SET
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
title = excluded.title,
|
title = excluded.title,
|
||||||
last_active_at = excluded.last_active_at,
|
last_active_at = excluded.last_active_at,
|
||||||
message_count = excluded.message_count,
|
message_count = excluded.message_count,
|
||||||
routing_info = excluded.routing_info,
|
routing_info = excluded.routing_info,
|
||||||
deleted_at = excluded.deleted_at
|
deleted_at = excluded.deleted_at,
|
||||||
|
last_consolidated_at = excluded.last_consolidated_at
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.bind(&meta.id)
|
.bind(&meta.id)
|
||||||
@ -249,6 +293,7 @@ impl Storage {
|
|||||||
.bind(meta.message_count)
|
.bind(meta.message_count)
|
||||||
.bind(&meta.routing_info)
|
.bind(&meta.routing_info)
|
||||||
.bind(meta.deleted_at)
|
.bind(meta.deleted_at)
|
||||||
|
.bind(meta.last_consolidated_at)
|
||||||
.execute(self.pool())
|
.execute(self.pool())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
@ -258,7 +303,7 @@ impl Storage {
|
|||||||
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
||||||
let row = sqlx::query(
|
let row = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at
|
||||||
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
@ -278,6 +323,7 @@ impl Storage {
|
|||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,7 +335,7 @@ impl Storage {
|
|||||||
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
||||||
let rows = sqlx::query(
|
let rows = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
@ -315,6 +361,7 @@ impl Storage {
|
|||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
})
|
})
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
@ -362,7 +409,7 @@ impl Storage {
|
|||||||
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_millis;
|
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_millis;
|
||||||
let row = sqlx::query(
|
let row = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND last_active_at > ?
|
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND last_active_at > ?
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
@ -387,6 +434,7 @@ impl Storage {
|
|||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
})),
|
})),
|
||||||
None => Ok(None),
|
None => Ok(None),
|
||||||
}
|
}
|
||||||
@ -471,7 +519,7 @@ impl Storage {
|
|||||||
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
||||||
let rows = sqlx::query(
|
let rows = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
@ -495,6 +543,7 @@ impl Storage {
|
|||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
})
|
})
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
@ -599,6 +648,7 @@ mod tests {
|
|||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
storage.upsert_session(&meta).await.unwrap();
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
@ -633,6 +683,7 @@ mod tests {
|
|||||||
message_count: i,
|
message_count: i,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
};
|
};
|
||||||
storage.upsert_session(&meta).await.unwrap();
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
}
|
}
|
||||||
@ -658,6 +709,7 @@ mod tests {
|
|||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
storage.upsert_session(&meta).await.unwrap();
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
@ -683,6 +735,7 @@ mod tests {
|
|||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
};
|
};
|
||||||
storage.upsert_session(&session_meta).await.unwrap();
|
storage.upsert_session(&session_meta).await.unwrap();
|
||||||
|
|
||||||
@ -723,6 +776,7 @@ mod tests {
|
|||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
};
|
};
|
||||||
storage.upsert_session(&meta).await.unwrap();
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
|
||||||
|
|||||||
@ -12,4 +12,5 @@ pub struct SessionMeta {
|
|||||||
pub message_count: i64,
|
pub message_count: i64,
|
||||||
pub routing_info: Option<String>,
|
pub routing_info: Option<String>,
|
||||||
pub deleted_at: Option<i64>,
|
pub deleted_at: Option<i64>,
|
||||||
|
pub last_consolidated_at: Option<i64>,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -263,6 +263,7 @@ mod tests {
|
|||||||
message_count: i * 5,
|
message_count: i * 5,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
};
|
};
|
||||||
storage.upsert_session(&meta).await.unwrap();
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
}
|
}
|
||||||
@ -296,6 +297,7 @@ mod tests {
|
|||||||
message_count: 3,
|
message_count: 3,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
};
|
};
|
||||||
storage.upsert_session(&meta).await.unwrap();
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
|
||||||
|
|||||||
60
src/tools/memory_forget.rs
Normal file
60
src/tools/memory_forget.rs
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
use super::traits::{Tool, ToolResult};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::memory::MemoryManager;
|
||||||
|
|
||||||
|
pub struct MemoryForgetTool {
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryForgetTool {
|
||||||
|
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||||
|
Self { memory }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MemoryForgetTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"memory_forget"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Delete a memory entry by its key. Use this when information is outdated, \
|
||||||
|
incorrect, or the user asks to forget something."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The key of the memory entry to delete."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["key"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let key = args
|
||||||
|
.get("key")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
||||||
|
|
||||||
|
self.memory.forget(key).await?;
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Memory deleted: {}", key),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
118
src/tools/memory_recall.rs
Normal file
118
src/tools/memory_recall.rs
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
use super::traits::{Tool, ToolResult};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::memory::{MemoryCategory, MemoryManager};
|
||||||
|
|
||||||
|
pub struct MemoryRecallTool {
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryRecallTool {
|
||||||
|
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||||
|
Self { memory }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MemoryRecallTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"memory_recall"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Search and retrieve entries from long-term memory. \
|
||||||
|
Use this to recall previously stored facts, preferences, or conversation history. \
|
||||||
|
Supports keyword search and optional time-range filtering."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query — keywords to match against memory keys and content."
|
||||||
|
},
|
||||||
|
"category": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["knowledge", "timeline"],
|
||||||
|
"description": "Filter by memory category. Omit to search all categories."
|
||||||
|
},
|
||||||
|
"since": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Start of time range (Unix milliseconds)."
|
||||||
|
},
|
||||||
|
"until": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "End of time range (Unix milliseconds)."
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Max results to return (default 10)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let query = args
|
||||||
|
.get("query")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
|
||||||
|
|
||||||
|
let category = args
|
||||||
|
.get("category")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(MemoryCategory::from_str);
|
||||||
|
|
||||||
|
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
|
||||||
|
|
||||||
|
let entries = if args.get("since").is_some() || args.get("until").is_some() {
|
||||||
|
let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0);
|
||||||
|
let until = args
|
||||||
|
.get("until")
|
||||||
|
.and_then(|v| v.as_i64())
|
||||||
|
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||||
|
self.memory
|
||||||
|
.recall_by_time(since, until, limit, category)
|
||||||
|
.await?
|
||||||
|
} else {
|
||||||
|
self.memory.recall(query, limit, category).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
if entries.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: "No matching memories found.".to_string(),
|
||||||
|
error: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let formatted = entries
|
||||||
|
.iter()
|
||||||
|
.map(|e| {
|
||||||
|
format!(
|
||||||
|
"- {} [{}] [importance: {:.1}]: {}",
|
||||||
|
e.key,
|
||||||
|
e.category.as_str(),
|
||||||
|
e.importance,
|
||||||
|
e.content
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Found {} memories:\n{}", entries.len(), formatted),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
90
src/tools/memory_store.rs
Normal file
90
src/tools/memory_store.rs
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
use super::traits::{Tool, ToolResult};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::memory::{MemoryCategory, MemoryManager};
|
||||||
|
|
||||||
|
pub struct MemoryStoreTool {
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryStoreTool {
|
||||||
|
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||||
|
Self { memory }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MemoryStoreTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"memory_store"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Store a fact, preference, or insight into long-term memory. \
|
||||||
|
Use this when the user shares important information you should remember. \
|
||||||
|
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
|
||||||
|
and the full content to remember."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key."
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The full content of the memory entry."
|
||||||
|
},
|
||||||
|
"category": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["knowledge", "timeline"],
|
||||||
|
"description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
|
||||||
|
},
|
||||||
|
"importance": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["key", "content"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let key = args
|
||||||
|
.get("key")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
||||||
|
|
||||||
|
let content = args
|
||||||
|
.get("content")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
|
||||||
|
|
||||||
|
let category = args
|
||||||
|
.get("category")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(MemoryCategory::from_str)
|
||||||
|
.unwrap_or(MemoryCategory::Knowledge);
|
||||||
|
|
||||||
|
let importance = args.get("importance").and_then(|v| v.as_f64());
|
||||||
|
|
||||||
|
self.memory
|
||||||
|
.store(key, content, category, None, importance)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Memory stored: {}", key),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -7,6 +7,9 @@ pub mod file_read;
|
|||||||
pub mod file_write;
|
pub mod file_write;
|
||||||
pub mod get_skill;
|
pub mod get_skill;
|
||||||
pub mod http_request;
|
pub mod http_request;
|
||||||
|
pub mod memory_forget;
|
||||||
|
pub mod memory_recall;
|
||||||
|
pub mod memory_store;
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod schema;
|
pub mod schema;
|
||||||
pub mod send_message;
|
pub mod send_message;
|
||||||
@ -21,6 +24,9 @@ pub use file_read::FileReadTool;
|
|||||||
pub use file_write::FileWriteTool;
|
pub use file_write::FileWriteTool;
|
||||||
pub use get_skill::GetSkillTool;
|
pub use get_skill::GetSkillTool;
|
||||||
pub use http_request::HttpRequestTool;
|
pub use http_request::HttpRequestTool;
|
||||||
|
pub use memory_forget::MemoryForgetTool;
|
||||||
|
pub use memory_recall::MemoryRecallTool;
|
||||||
|
pub use memory_store::MemoryStoreTool;
|
||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
pub use send_message::SendMessageTool;
|
pub use send_message::SendMessageTool;
|
||||||
@ -28,12 +34,16 @@ pub use traits::{OutboundMessenger, Tool, ToolResult};
|
|||||||
pub use web_fetch::WebFetchTool;
|
pub use web_fetch::WebFetchTool;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use crate::memory::MemoryManager;
|
||||||
use crate::skills::SkillsLoader;
|
use crate::skills::SkillsLoader;
|
||||||
|
|
||||||
/// Create the base tool registry (without send_message).
|
/// Create the base tool registry (without send_message).
|
||||||
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
||||||
/// once the available channel names are known.
|
/// once the available channel names are known.
|
||||||
pub fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
|
pub fn create_default_tools(
|
||||||
|
skills_loader: Arc<SkillsLoader>,
|
||||||
|
memory: Option<Arc<MemoryManager>>,
|
||||||
|
) -> ToolRegistry {
|
||||||
let registry = ToolRegistry::new();
|
let registry = ToolRegistry::new();
|
||||||
registry.register(CalculatorTool::new());
|
registry.register(CalculatorTool::new());
|
||||||
registry.register(FileReadTool::new());
|
registry.register(FileReadTool::new());
|
||||||
@ -48,5 +58,13 @@ pub fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
|
|||||||
));
|
));
|
||||||
registry.register(WebFetchTool::new(50_000, 30));
|
registry.register(WebFetchTool::new(50_000, 30));
|
||||||
registry.register(GetSkillTool::new(skills_loader));
|
registry.register(GetSkillTool::new(skills_loader));
|
||||||
|
|
||||||
|
// Register memory tools if memory system is available
|
||||||
|
if let Some(mm) = memory {
|
||||||
|
registry.register(MemoryStoreTool::new(mm.clone()));
|
||||||
|
registry.register(MemoryRecallTool::new(mm.clone()));
|
||||||
|
registry.register(MemoryForgetTool::new(mm.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
registry
|
registry
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user