PicoBot/src/tools/memory.rs

263 lines
8.0 KiB
Rust

use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use crate::memory::{MemoryCategory, MemoryManager};
// ── MemoryStoreTool ──────────────────────────────────────────────────
/// Tool exposing the `memory_store` operation: writes a keyed entry through
/// [`MemoryManager::store`].
pub struct MemoryStoreTool {
    /// Shared handle to the memory backend.
    memory: Arc<MemoryManager>,
}

impl MemoryStoreTool {
    /// Builds the tool around a shared [`MemoryManager`] handle.
    pub fn new(memory: Arc<MemoryManager>) -> Self {
        MemoryStoreTool { memory }
    }
}
/// `memory_store`: persists `content` under `key` with an optional category and
/// importance score.
#[async_trait]
impl Tool for MemoryStoreTool {
    fn name(&self) -> &str {
        "memory_store"
    }

    fn description(&self) -> &str {
        "Store a fact, preference, or insight into long-term memory. \
         Use this when the user shares important information you should remember. \
         Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
         and the full content to remember."
    }

    fn read_only(&self) -> bool {
        false
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "key": {
                    "type": "string",
                    "description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key."
                },
                "content": {
                    "type": "string",
                    "description": "The full content of the memory entry."
                },
                "category": {
                    "type": "string",
                    "enum": ["knowledge", "timeline"],
                    "description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
                },
                "importance": {
                    "type": "number",
                    "description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
                }
            },
            "required": ["key", "content"]
        })
    }

    /// Stores the entry described by `args`.
    ///
    /// # Errors
    /// - `key` or `content` is missing or not a string.
    /// - `key` is empty or whitespace-only.
    /// - `category` is a string that [`MemoryCategory::from_str`] does not recognise.
    /// - The underlying [`MemoryManager::store`] call fails.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let key = args
            .get("key")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
        // A blank key is not a usable identifier for later recall/forget.
        if key.trim().is_empty() {
            return Err(anyhow::anyhow!("Parameter 'key' must not be empty"));
        }

        let content = args
            .get("content")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;

        // Omitted (or null) category defaults to `knowledge`. An unrecognised
        // string is reported instead of being silently filed as `knowledge`.
        let category = match args.get("category").and_then(|v| v.as_str()) {
            None => MemoryCategory::Knowledge,
            Some(raw) => MemoryCategory::from_str(raw).ok_or_else(|| {
                anyhow::anyhow!(
                    "Invalid category '{}': expected 'knowledge' or 'timeline'",
                    raw
                )
            })?,
        };

        // The schema documents a 0.0-1.0 range; clamp so out-of-range values
        // supplied by the caller are not persisted as-is.
        let importance = args
            .get("importance")
            .and_then(|v| v.as_f64())
            .map(|v| v.clamp(0.0, 1.0));

        self.memory
            .store(key, content, category, None, importance)
            .await?;

        Ok(ToolResult {
            success: true,
            output: format!("Memory stored: {}", key),
            error: None,
        })
    }
}
// ── MemoryRecallTool ─────────────────────────────────────────────────
/// Tool exposing the `memory_recall` operation: searches long-term memory
/// through [`MemoryManager`], by keyword or by time range.
pub struct MemoryRecallTool {
    /// Shared handle to the memory backend.
    memory: Arc<MemoryManager>,
}

impl MemoryRecallTool {
    /// Builds the tool around a shared [`MemoryManager`] handle.
    pub fn new(memory: Arc<MemoryManager>) -> Self {
        MemoryRecallTool { memory }
    }
}
/// `memory_recall`: keyword search, or time-range listing when `since`/`until`
/// is supplied, optionally filtered by category.
#[async_trait]
impl Tool for MemoryRecallTool {
    fn name(&self) -> &str {
        "memory_recall"
    }

    fn description(&self) -> &str {
        "Search and retrieve entries from long-term memory. \
         Use this to recall previously stored facts, preferences, or conversation history. \
         Supports keyword search and optional time-range filtering."
    }

    fn read_only(&self) -> bool {
        true
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query — keywords to match against memory keys and content."
                },
                "category": {
                    "type": "string",
                    "enum": ["knowledge", "timeline"],
                    "description": "Filter by memory category. Omit to search all categories."
                },
                "since": {
                    "type": "integer",
                    "description": "Start of time range (Unix milliseconds)."
                },
                "until": {
                    "type": "integer",
                    "description": "End of time range (Unix milliseconds)."
                },
                "limit": {
                    "type": "integer",
                    "description": "Max results to return (default 10)."
                }
            },
            "required": ["query"]
        })
    }

    /// Runs the recall described by `args` and formats the hits as a bullet list.
    ///
    /// # Errors
    /// - `query` is missing or not a string.
    /// - `category` is an unrecognised string.
    /// - `since`/`until` is present, non-null, and not an integer.
    /// - `since` is later than `until`.
    /// - The underlying [`MemoryManager`] call fails.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        const DEFAULT_LIMIT: usize = 10;

        /// Reads an optional millisecond timestamp bound. Absent and `null` both
        /// mean "not supplied"; any other non-integer value is an error.
        fn time_bound(args: &serde_json::Value, name: &str) -> anyhow::Result<Option<i64>> {
            match args.get(name) {
                None | Some(serde_json::Value::Null) => Ok(None),
                Some(v) => v.as_i64().map(Some).ok_or_else(|| {
                    anyhow::anyhow!(
                        "Parameter '{}' must be an integer Unix timestamp in milliseconds",
                        name
                    )
                }),
            }
        }

        let query = args
            .get("query")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;

        // Omitted category searches everything. An unrecognised string is
        // reported rather than silently widening the search to all categories.
        let category = args
            .get("category")
            .and_then(|v| v.as_str())
            .map(|raw| {
                MemoryCategory::from_str(raw).ok_or_else(|| {
                    anyhow::anyhow!(
                        "Invalid category '{}': expected 'knowledge' or 'timeline'",
                        raw
                    )
                })
            })
            .transpose()?;

        // `try_from` avoids truncation where usize is narrower than u64. A limit of 0
        // would always yield "No matching memories found.", so it is raised to 1.
        let limit = args
            .get("limit")
            .and_then(|v| v.as_u64())
            .map_or(DEFAULT_LIMIT, |n| usize::try_from(n).unwrap_or(usize::MAX))
            .max(1);

        // Only a real bound selects the time-range path. Previously `"since": null`
        // also selected it and discarded the keyword query.
        let since = time_bound(&args, "since")?;
        let until = time_bound(&args, "until")?;

        let entries = if since.is_some() || until.is_some() {
            let since = since.unwrap_or(0);
            let until = until.unwrap_or_else(|| chrono::Utc::now().timestamp_millis());
            if since > until {
                return Err(anyhow::anyhow!(
                    "Invalid time range: since ({}) is after until ({})",
                    since,
                    until
                ));
            }
            // NOTE(review): `query` is not passed to `recall_by_time`, so keywords are
            // not applied on this path — confirm that is intended.
            self.memory
                .recall_by_time(since, until, limit, category)
                .await?
        } else {
            self.memory.recall(query, limit, category).await?
        };

        if entries.is_empty() {
            return Ok(ToolResult {
                success: true,
                output: "No matching memories found.".to_string(),
                error: None,
            });
        }

        let formatted = entries
            .iter()
            .map(|e| {
                format!(
                    "- {} [{}] [importance: {:.1}]: {}",
                    e.key,
                    e.category.as_str(),
                    e.importance,
                    e.content
                )
            })
            .collect::<Vec<_>>()
            .join("\n");

        Ok(ToolResult {
            success: true,
            output: format!("Found {} memories:\n{}", entries.len(), formatted),
            error: None,
        })
    }
}
// ── MemoryForgetTool ─────────────────────────────────────────────────
/// Tool exposing the `memory_forget` operation: deletes an entry by key through
/// [`MemoryManager::forget`].
pub struct MemoryForgetTool {
    /// Shared handle to the memory backend.
    memory: Arc<MemoryManager>,
}

impl MemoryForgetTool {
    /// Builds the tool around a shared [`MemoryManager`] handle.
    pub fn new(memory: Arc<MemoryManager>) -> Self {
        MemoryForgetTool { memory }
    }
}
/// `memory_forget`: removes the entry stored under `key`.
#[async_trait]
impl Tool for MemoryForgetTool {
    fn name(&self) -> &str {
        "memory_forget"
    }

    fn description(&self) -> &str {
        "Delete a memory entry by its key. Use this when information is outdated, \
         incorrect, or the user asks to forget something."
    }

    fn read_only(&self) -> bool {
        false
    }

    fn parameters_schema(&self) -> serde_json::Value {
        let key_property = json!({
            "type": "string",
            "description": "The key of the memory entry to delete."
        });
        json!({
            "type": "object",
            "properties": { "key": key_property },
            "required": ["key"]
        })
    }

    /// Deletes the entry named by `args["key"]`.
    ///
    /// The success message is produced whenever [`MemoryManager::forget`] returns
    /// `Ok`; its success value (if any) is not inspected here.
    ///
    /// # Errors
    /// - `key` is missing or not a string.
    /// - The underlying `forget` call fails.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let target = match args.get("key").and_then(serde_json::Value::as_str) {
            Some(k) => k,
            None => return Err(anyhow::anyhow!("Missing required parameter: key")),
        };

        self.memory.forget(target).await?;

        let output = format!("Memory deleted: {}", target);
        Ok(ToolResult {
            success: true,
            output,
            error: None,
        })
    }
}