PicoBot/src/storage/memory.rs
xiaoxixi cb1140e9be refactor(memory): Timeline 按 session 隔离,拆分知识/摘要检索工具
- storage/memory: search_memories 和 search_memories_by_time 增加 session_id 过滤参数
- memory/manager: recall/recall_by_time 透传 session_id
- tools: MemoryStoreTool/MemoryRecallTool 锁定 Knowledge 类别,移除 category 参数
- tools: 新增 TimelineRecallTool 用于检索会话摘要,支持可选 session_id 过滤
- tools: 输出格式化增加 session 信息显示
- tests: 新增 test_session_id_filter 验证会话级过滤
2026-05-10 13:35:21 +08:00

274 lines
9.1 KiB
Rust

use sqlx::Row;
use std::sync::OnceLock;
use jieba_rs::Jieba;
use crate::memory::{MemoryCategory, MemoryEntry};
use super::StorageError;
/// Returns the shared jieba segmenter, constructing it on first use.
///
/// The instance is kept in a function-local `OnceLock`, so every caller
/// receives a reference to the same `Jieba` value and `Jieba::new` is only
/// executed for the first successful initialisation.
fn jieba() -> &'static Jieba {
    static SEGMENTER: OnceLock<Jieba> = OnceLock::new();
    SEGMENTER.get_or_init(Jieba::new)
}
impl super::Storage {
/// Inserts `entry` into the `memories` table, or overwrites the row that
/// already uses the same `key`.
///
/// On a key conflict the statement replaces `content`, `category`,
/// `importance`, `session_id` and `updated_at` with the incoming values.
/// `id` and `created_at` are not in the update list, so the existing row
/// keeps the values it was first inserted with.
///
/// # Errors
/// Returns the converted error if executing the statement fails.
pub async fn upsert_memory(&self, entry: &MemoryEntry) -> Result<(), StorageError> {
    const UPSERT_SQL: &str = r#"
INSERT INTO memories (id, key, content, category, importance, session_id, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(key) DO UPDATE SET
content = excluded.content,
category = excluded.category,
importance = excluded.importance,
session_id = excluded.session_id,
updated_at = excluded.updated_at
"#;

    // Bind order mirrors the column list of the INSERT above.
    let statement = sqlx::query(UPSERT_SQL)
        .bind(&entry.id)
        .bind(&entry.key)
        .bind(&entry.content)
        .bind(entry.category.as_str())
        .bind(entry.importance)
        .bind(&entry.session_id)
        .bind(&entry.created_at)
        .bind(&entry.updated_at);

    statement.execute(self.pool()).await?;
    Ok(())
}
/// Removes the row of `memories` whose `key` equals the given value.
///
/// The affected-row count is not inspected, so the call returns `Ok(())`
/// whenever the statement itself executes successfully.
///
/// # Errors
/// Returns the converted error if executing the statement fails.
pub async fn delete_memory(&self, key: &str) -> Result<(), StorageError> {
    let removal = sqlx::query("DELETE FROM memories WHERE key = ?").bind(key);
    match removal.execute(self.pool()).await {
        Ok(_) => Ok(()),
        Err(failure) => Err(failure.into()),
    }
}
/// Search memories by keyword using FTS5.
/// Falls back to LIKE query if FTS5 returns no results.
pub async fn search_memories(
&self,
query: &str,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryEntry>, StorageError> {
// Build FTS5 query: segment with jieba, wrap each term in quotes, join with OR
let fts_query = jieba()
.cut(query, true)
.into_iter()
.filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
.map(|w| format!("\"{}\"", w.replace('"', "")))
.collect::<Vec<_>>()
.join(" OR ");
let category_filter = category.map(|c| c.as_str());
// Try FTS5 first
let rows = sqlx::query(
r#"
SELECT m.id, m.key, m.content, m.category, m.importance,
m.session_id, m.created_at, m.updated_at
FROM memory_fts f
JOIN memories m ON f.rowid = m.rowid
WHERE memory_fts MATCH ? AND (? IS NULL OR m.category = ?) AND (? IS NULL OR m.session_id = ?)
ORDER BY rank
LIMIT ?
"#,
)
.bind(&fts_query)
.bind(category_filter)
.bind(category_filter)
.bind(session_id)
.bind(session_id)
.bind(limit as i64)
.fetch_all(self.pool())
.await?;
let mut entries = parse_memory_rows(&rows)?;
// Fallback to term-based LIKE query if FTS5 returned nothing
if entries.is_empty() {
let terms: Vec<String> = jieba()
.cut(query, true)
.into_iter()
.filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
.map(|w| w.replace(['%', '_'], ""))
.collect();
if !terms.is_empty() {
let like_clauses = terms
.iter()
.map(|_| "(key LIKE ? OR content LIKE ?)")
.collect::<Vec<_>>()
.join(" OR ");
let sql = format!(
r#"
SELECT id, key, content, category, importance,
session_id, created_at, updated_at
FROM memories
WHERE ({})
AND (? IS NULL OR category = ?)
AND (? IS NULL OR session_id = ?)
ORDER BY importance DESC, updated_at DESC
LIMIT ?
"#,
like_clauses
);
let mut query_builder = sqlx::query(&sql);
for term in &terms {
let pattern = format!("%{}%", term);
query_builder = query_builder.bind(pattern.clone()).bind(pattern);
}
query_builder = query_builder
.bind(category_filter)
.bind(category_filter)
.bind(session_id)
.bind(session_id)
.bind(limit as i64);
let rows = query_builder.fetch_all(self.pool()).await?;
entries = parse_memory_rows(&rows)?;
}
}
Ok(entries)
}
/// Retrieve memories within a time range, optionally filtered by keyword query.
pub async fn search_memories_by_time(
&self,
since: i64,
until: i64,
query: Option<&str>,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryEntry>, StorageError> {
let category_filter = category.map(|c| c.as_str());
let since_dt = chrono::DateTime::from_timestamp_millis(since)
.unwrap_or_default()
.to_rfc3339();
let until_dt = chrono::DateTime::from_timestamp_millis(until)
.unwrap_or_default()
.to_rfc3339();
let rows = if let Some(q) = query {
let terms: Vec<String> = jieba()
.cut(q, true)
.into_iter()
.filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
.map(|w| w.replace(['%', '_'], ""))
.collect();
if terms.is_empty() {
return Ok(Vec::new());
}
let like_clauses = terms
.iter()
.map(|_| "(key LIKE ? OR content LIKE ?)")
.collect::<Vec<_>>()
.join(" OR ");
let sql = format!(
r#"
SELECT id, key, content, category, importance,
session_id, created_at, updated_at
FROM memories
WHERE ({})
AND created_at >= ? AND created_at <= ?
AND (? IS NULL OR category = ?)
AND (? IS NULL OR session_id = ?)
ORDER BY created_at DESC
LIMIT ?
"#,
like_clauses
);
let mut query_builder = sqlx::query(&sql);
for term in &terms {
let pattern = format!("%{}%", term);
query_builder = query_builder.bind(pattern.clone()).bind(pattern);
}
query_builder = query_builder
.bind(&since_dt)
.bind(&until_dt)
.bind(category_filter)
.bind(category_filter)
.bind(session_id)
.bind(session_id)
.bind(limit as i64);
query_builder.fetch_all(self.pool()).await?
} else {
sqlx::query(
r#"
SELECT id, key, content, category, importance,
session_id, created_at, updated_at
FROM memories
WHERE created_at >= ? AND created_at <= ?
AND (? IS NULL OR category = ?)
AND (? IS NULL OR session_id = ?)
ORDER BY created_at DESC
LIMIT ?
"#,
)
.bind(&since_dt)
.bind(&until_dt)
.bind(category_filter)
.bind(category_filter)
.bind(session_id)
.bind(session_id)
.bind(limit as i64)
.fetch_all(self.pool())
.await?
};
parse_memory_rows(&rows)
}
/// Deletes `timeline` memories created more than `retention_days` days ago.
///
/// The cutoff is "now (UTC) minus `retention_days` days", formatted as
/// RFC 3339 and compared against `created_at`. The subtraction is checked:
/// if the retention period is so long that the cutoff cannot be represented,
/// no row can be older than it, so nothing is deleted and `Ok(0)` is returned.
/// This avoids both the panic of an overflowing date subtraction and the
/// wrap-around of casting a huge `u64` to a negative `i64`, which would move
/// the cutoff into the future and delete every timeline row.
///
/// # Returns
/// The number of rows deleted.
///
/// # Errors
/// Returns the converted error if executing the statement fails.
pub async fn cleanup_old_timelines(&self, retention_days: u64) -> Result<u64, StorageError> {
    let Some(cutoff) = chrono::Utc::now().checked_sub_days(chrono::Days::new(retention_days))
    else {
        return Ok(0);
    };
    let cutoff_str = cutoff.to_rfc3339();

    let result = sqlx::query(
        "DELETE FROM memories WHERE category = 'timeline' AND created_at < ?",
    )
    .bind(&cutoff_str)
    .execute(self.pool())
    .await?;

    Ok(result.rows_affected())
}
}
/// Decodes a slice of SQLite rows into `MemoryEntry` values.
///
/// Columns are read by name in the order `id`, `key`, `content`, `category`,
/// `importance`, `session_id`, `created_at`, `updated_at`. A `category`
/// string that `MemoryCategory::from_str` does not accept is mapped to
/// `MemoryCategory::Knowledge` instead of failing the row.
///
/// # Errors
/// Stops at the first column that cannot be read and returns that error,
/// converted to `StorageError`.
fn parse_memory_rows(
    rows: &[sqlx::sqlite::SqliteRow],
) -> Result<Vec<MemoryEntry>, StorageError> {
    let mut parsed = Vec::with_capacity(rows.len());
    for record in rows {
        parsed.push(MemoryEntry {
            id: record.try_get("id")?,
            key: record.try_get("key")?,
            content: record.try_get("content")?,
            category: {
                let label: String = record.try_get("category")?;
                MemoryCategory::from_str(&label).unwrap_or(MemoryCategory::Knowledge)
            },
            importance: record.try_get::<f64, _>("importance")?,
            session_id: record.try_get::<Option<String>, _>("session_id")?,
            created_at: record.try_get("created_at")?,
            updated_at: record.try_get("updated_at")?,
        });
    }
    Ok(parsed)
}