记忆工具合并文件,改成记忆系统总是开启。
This commit is contained in:
parent
c602a0695d
commit
2617558a27
@ -63,5 +63,14 @@
|
||||
"reaction_emoji": "Typing"
|
||||
}
|
||||
},
|
||||
"memory": {
|
||||
"consolidation_provider": null,
|
||||
"consolidation_model": null,
|
||||
"recall_limit": 5,
|
||||
"consolidation_turn_threshold": 10,
|
||||
"idle_consolidation_minutes": 10,
|
||||
"timeline_retention_days": 90,
|
||||
"max_failures_before_degrade": 3
|
||||
},
|
||||
"workspace_dir": "~/.picobot/workspace"
|
||||
}
|
||||
|
||||
@ -50,22 +50,26 @@ pub struct ContextCompressor {
|
||||
threshold_ratio: f64,
|
||||
/// Shared LLM provider for summarization
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
/// Memory manager handle (optional). When set, compressed
|
||||
/// context summaries are persisted as timeline memory entries.
|
||||
memory: Option<Arc<MemoryManager>>,
|
||||
/// Memory manager handle. Compressed context summaries are persisted
|
||||
/// as timeline memory entries.
|
||||
memory: Arc<MemoryManager>,
|
||||
/// Current session ID for timeline memory writes.
|
||||
session_id: Option<String>,
|
||||
}
|
||||
|
||||
impl ContextCompressor {
|
||||
/// Create a new compressor with the given provider and context window size.
|
||||
pub fn new(provider: Arc<dyn LLMProvider>, context_window: usize) -> Self {
|
||||
/// Create a new compressor with the given provider, context window size, and memory manager.
|
||||
pub fn new(
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
context_window: usize,
|
||||
memory: Arc<MemoryManager>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config: ContextCompressionConfig::default(),
|
||||
context_window,
|
||||
threshold_ratio: 0.5,
|
||||
provider,
|
||||
memory: None,
|
||||
memory,
|
||||
session_id: None,
|
||||
}
|
||||
}
|
||||
@ -75,23 +79,18 @@ impl ContextCompressor {
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
context_window: usize,
|
||||
config: ContextCompressionConfig,
|
||||
memory: Arc<MemoryManager>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
context_window,
|
||||
threshold_ratio: 0.5,
|
||||
provider,
|
||||
memory: None,
|
||||
memory,
|
||||
session_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attach a memory manager to persist compressed summaries.
|
||||
pub fn with_memory(mut self, memory: Arc<MemoryManager>) -> Self {
|
||||
self.memory = Some(memory);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the current session ID for timeline writes.
|
||||
pub fn set_session_id(&mut self, id: Option<String>) {
|
||||
self.session_id = id;
|
||||
@ -240,25 +239,23 @@ impl ContextCompressor {
|
||||
let summary = self.summarize_segment(between).await?;
|
||||
|
||||
// Persist compressed summary as timeline memory entry
|
||||
if let Some(ref mm) = self.memory {
|
||||
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
||||
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
||||
ts, between.len(), summary);
|
||||
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
||||
let mm_clone = mm.clone();
|
||||
let sid = self.session_id.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = mm_clone.store(
|
||||
&key,
|
||||
&timeline_content,
|
||||
crate::memory::MemoryCategory::Timeline,
|
||||
sid.as_deref(),
|
||||
Some(0.3),
|
||||
).await {
|
||||
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
||||
}
|
||||
});
|
||||
}
|
||||
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
||||
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
||||
ts, between.len(), summary);
|
||||
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
||||
let mm = self.memory.clone();
|
||||
let sid = self.session_id.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = mm.store(
|
||||
&key,
|
||||
&timeline_content,
|
||||
crate::memory::MemoryCategory::Timeline,
|
||||
sid.as_deref(),
|
||||
Some(0.3),
|
||||
).await {
|
||||
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
||||
}
|
||||
});
|
||||
|
||||
// Add summary as a special user message
|
||||
new_messages.push(ChatMessage::user(format!(
|
||||
@ -370,6 +367,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::providers::ChatCompletionResponse;
|
||||
use async_trait::async_trait;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// Mock provider for testing - panics if actually used for LLM calls
|
||||
struct MockProvider;
|
||||
@ -400,6 +398,18 @@ mod tests {
|
||||
Arc::new(MockProvider)
|
||||
}
|
||||
|
||||
fn test_memory_manager() -> Arc<MemoryManager> {
|
||||
static MM: OnceLock<Arc<MemoryManager>> = OnceLock::new();
|
||||
MM.get_or_init(|| {
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
rt.block_on(async {
|
||||
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
|
||||
})
|
||||
}).clone()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens() {
|
||||
let messages = vec![
|
||||
@ -422,7 +432,7 @@ mod tests {
|
||||
tool_result_trim_chars: 50,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config);
|
||||
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
||||
|
||||
let mut messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
@ -436,7 +446,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_threshold() {
|
||||
let compressor = ContextCompressor::new(mock_provider(), 128_000);
|
||||
let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager());
|
||||
assert_eq!(compressor.threshold(), 64_000);
|
||||
}
|
||||
}
|
||||
|
||||
@ -220,15 +220,14 @@ impl Default for ClientConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MemoryConfig {
|
||||
/// Master switch for the memory system.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Provider name for consolidation LLM calls (key in `providers`).
|
||||
/// If not set, falls back to the main agent's provider.
|
||||
#[serde(default)]
|
||||
pub consolidation_provider: Option<String>,
|
||||
/// Model name for consolidation LLM calls (key in `models`).
|
||||
/// If not set, falls back to the main agent's model.
|
||||
#[serde(default)]
|
||||
pub consolidation_model: Option<String>,
|
||||
/// Max knowledge entries injected into system prompt per turn.
|
||||
@ -248,8 +247,34 @@ pub struct MemoryConfig {
|
||||
pub max_failures_before_degrade: usize,
|
||||
}
|
||||
|
||||
impl Default for MemoryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
consolidation_provider: None,
|
||||
consolidation_model: None,
|
||||
recall_limit: 5,
|
||||
consolidation_turn_threshold: 10,
|
||||
idle_consolidation_minutes: 10,
|
||||
timeline_retention_days: 90,
|
||||
max_failures_before_degrade: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryConfig {
|
||||
/// Resolve consolidation provider name, falling back to the main agent's provider.
|
||||
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
|
||||
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
|
||||
}
|
||||
|
||||
/// Resolve consolidation model name, falling back to the main agent's model.
|
||||
pub fn resolve_consolidation_model(&self, default: &str) -> String {
|
||||
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_recall_limit() -> usize { 5 }
|
||||
fn default_consolidation_turn_threshold() -> usize { 3 }
|
||||
fn default_consolidation_turn_threshold() -> usize { 10 }
|
||||
fn default_idle_consolidation_minutes() -> u64 { 10 }
|
||||
fn default_timeline_retention_days() -> u64 { 90 }
|
||||
fn default_max_failures_before_degrade() -> usize { 3 }
|
||||
|
||||
@ -56,14 +56,23 @@ impl GatewayState {
|
||||
);
|
||||
tracing::info!("Session storage: {}", db_path.display());
|
||||
|
||||
// Initialize MemoryManager if memory system is enabled
|
||||
let memory_manager = if config.memory.enabled {
|
||||
let mm = Arc::new(MemoryManager::new(storage.clone()));
|
||||
tracing::info!("Memory system enabled");
|
||||
Some(mm)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Resolve consolidation provider/model with fallback to main agent config
|
||||
let consolidation_provider = config
|
||||
.memory
|
||||
.resolve_consolidation_provider(&provider_config.name);
|
||||
let consolidation_model = config
|
||||
.memory
|
||||
.resolve_consolidation_model(&provider_config.model_id);
|
||||
let memory_manager = Arc::new(MemoryManager::new(
|
||||
storage.clone(),
|
||||
consolidation_provider,
|
||||
consolidation_model,
|
||||
));
|
||||
tracing::info!(
|
||||
consolidation_provider = %memory_manager.consolidation_provider,
|
||||
consolidation_model = %memory_manager.consolidation_model,
|
||||
"Memory system initialized"
|
||||
);
|
||||
|
||||
// Create MessageBus first (shared by SessionManager and ChannelManager)
|
||||
let bus = MessageBus::new(100);
|
||||
|
||||
@ -11,11 +11,21 @@ pub use types::{ConsolidationFact, ConsolidationResult, MemoryCategory, MemoryEn
|
||||
#[derive(Clone)]
|
||||
pub struct MemoryManager {
|
||||
storage: Arc<Storage>,
|
||||
pub consolidation_provider: String,
|
||||
pub consolidation_model: String,
|
||||
}
|
||||
|
||||
impl MemoryManager {
|
||||
pub fn new(storage: Arc<Storage>) -> Self {
|
||||
Self { storage }
|
||||
pub fn new(
|
||||
storage: Arc<Storage>,
|
||||
consolidation_provider: String,
|
||||
consolidation_model: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
storage,
|
||||
consolidation_provider,
|
||||
consolidation_model,
|
||||
}
|
||||
}
|
||||
|
||||
/// Store or update a memory entry. Generates timestamp and UUID.
|
||||
@ -87,7 +97,7 @@ mod tests {
|
||||
let dir = tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
||||
let mm = Arc::new(MemoryManager::new(storage));
|
||||
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
|
||||
(mm, dir)
|
||||
}
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ pub struct Session {
|
||||
/// Timestamp (Unix ms) of the last consolidation.
|
||||
/// Messages before this time have been compressed into memory.
|
||||
pub last_consolidated_at: Option<i64>,
|
||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@ -69,7 +69,7 @@ impl Session {
|
||||
storage: Option<StdArc<Storage>>,
|
||||
routing_info: String,
|
||||
title: String,
|
||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let mut provider_box = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||
@ -83,11 +83,8 @@ impl Session {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config);
|
||||
if let Some(ref mm) = memory_manager {
|
||||
compressor = compressor.with_memory(mm.clone());
|
||||
compressor.set_session_id(Some(id.to_string()));
|
||||
}
|
||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
|
||||
compressor.set_session_id(Some(id.to_string()));
|
||||
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
|
||||
@ -117,7 +114,7 @@ impl Session {
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
storage: StdArc<Storage>,
|
||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let session_meta = storage.get_session(&id.to_string()).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
||||
@ -135,11 +132,8 @@ impl Session {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config);
|
||||
if let Some(ref mm) = memory_manager {
|
||||
compressor = compressor.with_memory(mm.clone());
|
||||
compressor.set_session_id(Some(id.to_string()));
|
||||
}
|
||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
|
||||
compressor.set_session_id(Some(id.to_string()));
|
||||
|
||||
// Convert MessageMeta to ChatMessage
|
||||
// Clear tool_call_id/tool_name — they're not valid across API sessions
|
||||
@ -584,7 +578,7 @@ pub struct SessionManager {
|
||||
storage: Arc<Storage>,
|
||||
bus: Arc<MessageBus>,
|
||||
current_source_session: Arc<Mutex<Option<String>>>,
|
||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
}
|
||||
|
||||
struct SessionManagerInner {
|
||||
@ -672,7 +666,7 @@ impl SessionManager {
|
||||
provider_config: LLMProviderConfig,
|
||||
storage: Arc<Storage>,
|
||||
bus: Arc<MessageBus>,
|
||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let skills_loader = SkillsLoader::new();
|
||||
skills_loader.load_skills();
|
||||
@ -1252,22 +1246,18 @@ impl SessionManager {
|
||||
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||||
|
||||
// Fetch memory context
|
||||
let memory_context = if let Some(ref mm) = self.memory_manager {
|
||||
match mm.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await {
|
||||
Ok(entries) if !entries.is_empty() => {
|
||||
Some(entries.iter()
|
||||
.map(|e| format!("- {}: {}", e.key, e.content))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n"))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to fetch memory context");
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
let memory_context = match self.memory_manager.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await {
|
||||
Ok(entries) if !entries.is_empty() => {
|
||||
Some(entries.iter()
|
||||
.map(|e| format!("- {}: {}", e.key, e.content))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n"))
|
||||
}
|
||||
} else {
|
||||
None
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to fetch memory context");
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Build combined system prompt and inject at position 0
|
||||
|
||||
262
src/tools/memory.rs
Normal file
262
src/tools/memory.rs
Normal file
@ -0,0 +1,262 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::memory::{MemoryCategory, MemoryManager};
|
||||
|
||||
// ── MemoryStoreTool ──────────────────────────────────────────────────
|
||||
|
||||
pub struct MemoryStoreTool {
|
||||
memory: Arc<MemoryManager>,
|
||||
}
|
||||
|
||||
impl MemoryStoreTool {
|
||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryStoreTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_store"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Store a fact, preference, or insight into long-term memory. \
|
||||
Use this when the user shares important information you should remember. \
|
||||
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
|
||||
and the full content to remember."
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The full content of the memory entry."
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["knowledge", "timeline"],
|
||||
"description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
|
||||
},
|
||||
"importance": {
|
||||
"type": "number",
|
||||
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
|
||||
}
|
||||
},
|
||||
"required": ["key", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
||||
|
||||
let content = args
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
|
||||
|
||||
let category = args
|
||||
.get("category")
|
||||
.and_then(|v| v.as_str())
|
||||
.and_then(MemoryCategory::from_str)
|
||||
.unwrap_or(MemoryCategory::Knowledge);
|
||||
|
||||
let importance = args.get("importance").and_then(|v| v.as_f64());
|
||||
|
||||
self.memory
|
||||
.store(key, content, category, None, importance)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Memory stored: {}", key),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── MemoryRecallTool ─────────────────────────────────────────────────
|
||||
|
||||
pub struct MemoryRecallTool {
|
||||
memory: Arc<MemoryManager>,
|
||||
}
|
||||
|
||||
impl MemoryRecallTool {
|
||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryRecallTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_recall"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search and retrieve entries from long-term memory. \
|
||||
Use this to recall previously stored facts, preferences, or conversation history. \
|
||||
Supports keyword search and optional time-range filtering."
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query — keywords to match against memory keys and content."
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["knowledge", "timeline"],
|
||||
"description": "Filter by memory category. Omit to search all categories."
|
||||
},
|
||||
"since": {
|
||||
"type": "integer",
|
||||
"description": "Start of time range (Unix milliseconds)."
|
||||
},
|
||||
"until": {
|
||||
"type": "integer",
|
||||
"description": "End of time range (Unix milliseconds)."
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default 10)."
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let query = args
|
||||
.get("query")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
|
||||
|
||||
let category = args
|
||||
.get("category")
|
||||
.and_then(|v| v.as_str())
|
||||
.and_then(MemoryCategory::from_str);
|
||||
|
||||
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
|
||||
|
||||
let entries = if args.get("since").is_some() || args.get("until").is_some() {
|
||||
let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0);
|
||||
let until = args
|
||||
.get("until")
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||
self.memory
|
||||
.recall_by_time(since, until, limit, category)
|
||||
.await?
|
||||
} else {
|
||||
self.memory.recall(query, limit, category).await?
|
||||
};
|
||||
|
||||
if entries.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No matching memories found.".to_string(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let formatted = entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
format!(
|
||||
"- {} [{}] [importance: {:.1}]: {}",
|
||||
e.key,
|
||||
e.category.as_str(),
|
||||
e.importance,
|
||||
e.content
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Found {} memories:\n{}", entries.len(), formatted),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── MemoryForgetTool ─────────────────────────────────────────────────
|
||||
|
||||
pub struct MemoryForgetTool {
|
||||
memory: Arc<MemoryManager>,
|
||||
}
|
||||
|
||||
impl MemoryForgetTool {
|
||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryForgetTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_forget"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Delete a memory entry by its key. Use this when information is outdated, \
|
||||
incorrect, or the user asks to forget something."
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "The key of the memory entry to delete."
|
||||
}
|
||||
},
|
||||
"required": ["key"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
||||
|
||||
self.memory.forget(key).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Memory deleted: {}", key),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,60 +0,0 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::memory::MemoryManager;
|
||||
|
||||
pub struct MemoryForgetTool {
|
||||
memory: Arc<MemoryManager>,
|
||||
}
|
||||
|
||||
impl MemoryForgetTool {
|
||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryForgetTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_forget"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Delete a memory entry by its key. Use this when information is outdated, \
|
||||
incorrect, or the user asks to forget something."
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "The key of the memory entry to delete."
|
||||
}
|
||||
},
|
||||
"required": ["key"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
||||
|
||||
self.memory.forget(key).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Memory deleted: {}", key),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,118 +0,0 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::memory::{MemoryCategory, MemoryManager};
|
||||
|
||||
pub struct MemoryRecallTool {
|
||||
memory: Arc<MemoryManager>,
|
||||
}
|
||||
|
||||
impl MemoryRecallTool {
|
||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryRecallTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_recall"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search and retrieve entries from long-term memory. \
|
||||
Use this to recall previously stored facts, preferences, or conversation history. \
|
||||
Supports keyword search and optional time-range filtering."
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query — keywords to match against memory keys and content."
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["knowledge", "timeline"],
|
||||
"description": "Filter by memory category. Omit to search all categories."
|
||||
},
|
||||
"since": {
|
||||
"type": "integer",
|
||||
"description": "Start of time range (Unix milliseconds)."
|
||||
},
|
||||
"until": {
|
||||
"type": "integer",
|
||||
"description": "End of time range (Unix milliseconds)."
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default 10)."
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let query = args
|
||||
.get("query")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
|
||||
|
||||
let category = args
|
||||
.get("category")
|
||||
.and_then(|v| v.as_str())
|
||||
.and_then(MemoryCategory::from_str);
|
||||
|
||||
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
|
||||
|
||||
let entries = if args.get("since").is_some() || args.get("until").is_some() {
|
||||
let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0);
|
||||
let until = args
|
||||
.get("until")
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||
self.memory
|
||||
.recall_by_time(since, until, limit, category)
|
||||
.await?
|
||||
} else {
|
||||
self.memory.recall(query, limit, category).await?
|
||||
};
|
||||
|
||||
if entries.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No matching memories found.".to_string(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let formatted = entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
format!(
|
||||
"- {} [{}] [importance: {:.1}]: {}",
|
||||
e.key,
|
||||
e.category.as_str(),
|
||||
e.importance,
|
||||
e.content
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Found {} memories:\n{}", entries.len(), formatted),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,90 +0,0 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::memory::{MemoryCategory, MemoryManager};
|
||||
|
||||
pub struct MemoryStoreTool {
|
||||
memory: Arc<MemoryManager>,
|
||||
}
|
||||
|
||||
impl MemoryStoreTool {
|
||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryStoreTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_store"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Store a fact, preference, or insight into long-term memory. \
|
||||
Use this when the user shares important information you should remember. \
|
||||
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
|
||||
and the full content to remember."
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The full content of the memory entry."
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["knowledge", "timeline"],
|
||||
"description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
|
||||
},
|
||||
"importance": {
|
||||
"type": "number",
|
||||
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
|
||||
}
|
||||
},
|
||||
"required": ["key", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
||||
|
||||
let content = args
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
|
||||
|
||||
let category = args
|
||||
.get("category")
|
||||
.and_then(|v| v.as_str())
|
||||
.and_then(MemoryCategory::from_str)
|
||||
.unwrap_or(MemoryCategory::Knowledge);
|
||||
|
||||
let importance = args.get("importance").and_then(|v| v.as_f64());
|
||||
|
||||
self.memory
|
||||
.store(key, content, category, None, importance)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Memory stored: {}", key),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -7,9 +7,7 @@ pub mod file_read;
|
||||
pub mod file_write;
|
||||
pub mod get_skill;
|
||||
pub mod http_request;
|
||||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
pub mod memory;
|
||||
pub mod registry;
|
||||
pub mod schema;
|
||||
pub mod send_message;
|
||||
@ -24,9 +22,7 @@ pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use get_skill::GetSkillTool;
|
||||
pub use http_request::HttpRequestTool;
|
||||
pub use memory_forget::MemoryForgetTool;
|
||||
pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool};
|
||||
pub use registry::ToolRegistry;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use send_message::SendMessageTool;
|
||||
@ -42,7 +38,7 @@ use crate::skills::SkillsLoader;
|
||||
/// once the available channel names are known.
|
||||
pub fn create_default_tools(
|
||||
skills_loader: Arc<SkillsLoader>,
|
||||
memory: Option<Arc<MemoryManager>>,
|
||||
memory: Arc<MemoryManager>,
|
||||
) -> ToolRegistry {
|
||||
let registry = ToolRegistry::new();
|
||||
registry.register(CalculatorTool::new());
|
||||
@ -59,12 +55,9 @@ pub fn create_default_tools(
|
||||
registry.register(WebFetchTool::new(50_000, 30));
|
||||
registry.register(GetSkillTool::new(skills_loader));
|
||||
|
||||
// Register memory tools if memory system is available
|
||||
if let Some(mm) = memory {
|
||||
registry.register(MemoryStoreTool::new(mm.clone()));
|
||||
registry.register(MemoryRecallTool::new(mm.clone()));
|
||||
registry.register(MemoryForgetTool::new(mm.clone()));
|
||||
}
|
||||
registry.register(MemoryStoreTool::new(memory.clone()));
|
||||
registry.register(MemoryRecallTool::new(memory.clone()));
|
||||
registry.register(MemoryForgetTool::new(memory.clone()));
|
||||
|
||||
registry
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user