Compare commits

..

2 Commits

23 changed files with 660 additions and 436 deletions

View File

@ -36,3 +36,4 @@ textwrap = "0.16"
chrono = "0.4" chrono = "0.4"
hostname = "0.3" hostname = "0.3"
sqlx = { version = "0.8", features = ["sqlite", "macros", "chrono", "runtime-tokio"] } sqlx = { version = "0.8", features = ["sqlite", "macros", "chrono", "runtime-tokio"] }
jieba-rs = "0.9"

View File

@ -63,5 +63,14 @@
"reaction_emoji": "Typing" "reaction_emoji": "Typing"
} }
}, },
"memory": {
"consolidation_provider": null,
"consolidation_model": null,
"recall_limit": 5,
"consolidation_turn_threshold": 10,
"idle_consolidation_minutes": 10,
"timeline_retention_days": 90,
"max_failures_before_degrade": 3
},
"workspace_dir": "~/.picobot/workspace" "workspace_dir": "~/.picobot/workspace"
} }

View File

@ -73,7 +73,7 @@ fn truncate_tool_result(output: &str) -> String {
// Even after removing suffix, still too long - take from beginning // Even after removing suffix, still too long - take from beginning
format!( format!(
"{}...\n\n[Output truncated - {} characters removed]", "{}...\n\n[Output truncated - {} characters removed]",
&output[..MAX_TOOL_RESULT_CHARS - 100], &output[..output.ceil_char_boundary(MAX_TOOL_RESULT_CHARS - 100)],
output.len() - MAX_TOOL_RESULT_CHARS + 100 output.len() - MAX_TOOL_RESULT_CHARS + 100
) )
} else { } else {
@ -81,7 +81,7 @@ fn truncate_tool_result(output: &str) -> String {
format!( format!(
"...\n\n[Output truncated - {} characters removed]\n\n{}", "...\n\n[Output truncated - {} characters removed]\n\n{}",
truncated_start_len, truncated_start_len,
&output[truncated_start_len..] &output[output.floor_char_boundary(truncated_start_len)..]
) )
} }
} }

View File

@ -50,22 +50,26 @@ pub struct ContextCompressor {
threshold_ratio: f64, threshold_ratio: f64,
/// Shared LLM provider for summarization /// Shared LLM provider for summarization
provider: Arc<dyn LLMProvider>, provider: Arc<dyn LLMProvider>,
/// Memory manager handle (optional). When set, compressed /// Memory manager handle. Compressed context summaries are persisted
/// context summaries are persisted as timeline memory entries. /// as timeline memory entries.
memory: Option<Arc<MemoryManager>>, memory: Arc<MemoryManager>,
/// Current session ID for timeline memory writes. /// Current session ID for timeline memory writes.
session_id: Option<String>, session_id: Option<String>,
} }
impl ContextCompressor { impl ContextCompressor {
/// Create a new compressor with the given provider and context window size. /// Create a new compressor with the given provider, context window size, and memory manager.
pub fn new(provider: Arc<dyn LLMProvider>, context_window: usize) -> Self { pub fn new(
provider: Arc<dyn LLMProvider>,
context_window: usize,
memory: Arc<MemoryManager>,
) -> Self {
Self { Self {
config: ContextCompressionConfig::default(), config: ContextCompressionConfig::default(),
context_window, context_window,
threshold_ratio: 0.5, threshold_ratio: 0.5,
provider, provider,
memory: None, memory,
session_id: None, session_id: None,
} }
} }
@ -75,23 +79,18 @@ impl ContextCompressor {
provider: Arc<dyn LLMProvider>, provider: Arc<dyn LLMProvider>,
context_window: usize, context_window: usize,
config: ContextCompressionConfig, config: ContextCompressionConfig,
memory: Arc<MemoryManager>,
) -> Self { ) -> Self {
Self { Self {
config, config,
context_window, context_window,
threshold_ratio: 0.5, threshold_ratio: 0.5,
provider, provider,
memory: None, memory,
session_id: None, session_id: None,
} }
} }
/// Attach a memory manager to persist compressed summaries.
pub fn with_memory(mut self, memory: Arc<MemoryManager>) -> Self {
self.memory = Some(memory);
self
}
/// Set the current session ID for timeline writes. /// Set the current session ID for timeline writes.
pub fn set_session_id(&mut self, id: Option<String>) { pub fn set_session_id(&mut self, id: Option<String>) {
self.session_id = id; self.session_id = id;
@ -113,7 +112,7 @@ impl ContextCompressor {
let removed = msg.content.len() - limit; let removed = msg.content.len() - limit;
msg.content = format!( msg.content = format!(
"{}...\n\n[Output truncated - {} characters removed]", "{}...\n\n[Output truncated - {} characters removed]",
&msg.content[..limit.min(msg.content.len())], &msg.content[..msg.content.ceil_char_boundary(limit)],
removed removed
); );
modified += 1; modified += 1;
@ -240,25 +239,23 @@ impl ContextCompressor {
let summary = self.summarize_segment(between).await?; let summary = self.summarize_segment(between).await?;
// Persist compressed summary as timeline memory entry // Persist compressed summary as timeline memory entry
if let Some(ref mm) = self.memory { let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", ts, between.len(), summary);
ts, between.len(), summary); let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); let mm = self.memory.clone();
let mm_clone = mm.clone(); let sid = self.session_id.clone();
let sid = self.session_id.clone(); tokio::spawn(async move {
tokio::spawn(async move { if let Err(e) = mm.store(
if let Err(e) = mm_clone.store( &key,
&key, &timeline_content,
&timeline_content, crate::memory::MemoryCategory::Timeline,
crate::memory::MemoryCategory::Timeline, sid.as_deref(),
sid.as_deref(), Some(0.3),
Some(0.3), ).await {
).await { tracing::warn!(error = %e, "Failed to store compressed context as timeline");
tracing::warn!(error = %e, "Failed to store compressed context as timeline"); }
} });
});
}
// Add summary as a special user message // Add summary as a special user message
new_messages.push(ChatMessage::user(format!( new_messages.push(ChatMessage::user(format!(
@ -316,7 +313,7 @@ impl ContextCompressor {
let transcript = if transcript.len() > self.config.summary_max_chars { let transcript = if transcript.len() > self.config.summary_max_chars {
format!( format!(
"{}...\n\n[Transcript truncated - {} characters removed]", "{}...\n\n[Transcript truncated - {} characters removed]",
&transcript[..self.config.summary_max_chars], &transcript[..transcript.ceil_char_boundary(self.config.summary_max_chars)],
transcript.len() - self.config.summary_max_chars transcript.len() - self.config.summary_max_chars
) )
} else { } else {
@ -359,7 +356,7 @@ Be concise, aim for {} characters or less.
Err(e) => { Err(e) => {
// Fallback: just truncate the transcript // Fallback: just truncate the transcript
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript"); tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
Ok(transcript[..transcript.len().min(2000)].to_string()) Ok(transcript[..transcript.ceil_char_boundary(2000)].to_string())
} }
} }
} }
@ -370,6 +367,7 @@ mod tests {
use super::*; use super::*;
use crate::providers::ChatCompletionResponse; use crate::providers::ChatCompletionResponse;
use async_trait::async_trait; use async_trait::async_trait;
use std::sync::OnceLock;
/// Mock provider for testing - panics if actually used for LLM calls /// Mock provider for testing - panics if actually used for LLM calls
struct MockProvider; struct MockProvider;
@ -400,6 +398,18 @@ mod tests {
Arc::new(MockProvider) Arc::new(MockProvider)
} }
fn test_memory_manager() -> Arc<MemoryManager> {
static MM: OnceLock<Arc<MemoryManager>> = OnceLock::new();
MM.get_or_init(|| {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
})
}).clone()
}
#[test] #[test]
fn test_estimate_tokens() { fn test_estimate_tokens() {
let messages = vec![ let messages = vec![
@ -422,7 +432,7 @@ mod tests {
tool_result_trim_chars: 50, tool_result_trim_chars: 50,
..Default::default() ..Default::default()
}; };
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config); let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
let mut messages = vec![ let mut messages = vec![
ChatMessage::user("Hello"), ChatMessage::user("Hello"),
@ -436,7 +446,7 @@ mod tests {
#[test] #[test]
fn test_threshold() { fn test_threshold() {
let compressor = ContextCompressor::new(mock_provider(), 128_000); let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager());
assert_eq!(compressor.threshold(), 64_000); assert_eq!(compressor.threshold(), 64_000);
} }
} }

View File

@ -243,10 +243,9 @@ impl PromptSection for CrossChannelSection {
- dialog_id: chat dialog - dialog_id: chat dialog
{}### {}###
`[message from X to Y]` assistant `[message from X]` assistant
send_message send_message
- X: ID "unknown" - X: ID "unknown"
- Y: session ID (<channel>:<chat_id>:<dialog_id>)

View File

@ -778,7 +778,7 @@ impl FeishuChannel {
let payload_content = if msg_type == "text" { let payload_content = if msg_type == "text" {
let truncated = if content.len() > MAX_TEXT_LENGTH { let truncated = if content.len() > MAX_TEXT_LENGTH {
format!("{}...\n\n[Content truncated due to length limit]", &content[..MAX_TEXT_LENGTH]) format!("{}...\n\n[Content truncated due to length limit]", &content[..content.ceil_char_boundary(MAX_TEXT_LENGTH)])
} else { } else {
content.to_string() content.to_string()
}; };
@ -788,7 +788,7 @@ impl FeishuChannel {
// But we still need to check length // But we still need to check length
if content.len() > MAX_TEXT_LENGTH { if content.len() > MAX_TEXT_LENGTH {
// Fallback to truncated text for post as well // Fallback to truncated text for post as well
serde_json::json!({ "text": format!("{}...\n\n[Content truncated due to length limit]", &content[..MAX_TEXT_LENGTH]) }).to_string() serde_json::json!({ "text": format!("{}...\n\n[Content truncated due to length limit]", &content[..content.ceil_char_boundary(MAX_TEXT_LENGTH)]) }).to_string()
} else { } else {
content.to_string() content.to_string()
} }
@ -2136,7 +2136,7 @@ impl Channel for FeishuChannel {
if !msg.content.is_empty() { if !msg.content.is_empty() {
const MAX_TEXT_LENGTH: usize = 60_000; const MAX_TEXT_LENGTH: usize = 60_000;
let truncated_text = if msg.content.len() > MAX_TEXT_LENGTH { let truncated_text = if msg.content.len() > MAX_TEXT_LENGTH {
format!("{}...\n\n[Content truncated due to length limit]", &msg.content[..MAX_TEXT_LENGTH]) format!("{}...\n\n[Content truncated due to length limit]", &msg.content[..msg.content.ceil_char_boundary(MAX_TEXT_LENGTH)])
} else { } else {
msg.content.clone() msg.content.clone()
}; };

View File

@ -220,15 +220,14 @@ impl Default for ClientConfig {
} }
} }
#[derive(Debug, Clone, Default, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MemoryConfig { pub struct MemoryConfig {
/// Master switch for the memory system.
#[serde(default)]
pub enabled: bool,
/// Provider name for consolidation LLM calls (key in `providers`). /// Provider name for consolidation LLM calls (key in `providers`).
/// If not set, falls back to the main agent's provider.
#[serde(default)] #[serde(default)]
pub consolidation_provider: Option<String>, pub consolidation_provider: Option<String>,
/// Model name for consolidation LLM calls (key in `models`). /// Model name for consolidation LLM calls (key in `models`).
/// If not set, falls back to the main agent's model.
#[serde(default)] #[serde(default)]
pub consolidation_model: Option<String>, pub consolidation_model: Option<String>,
/// Max knowledge entries injected into system prompt per turn. /// Max knowledge entries injected into system prompt per turn.
@ -248,8 +247,34 @@ pub struct MemoryConfig {
pub max_failures_before_degrade: usize, pub max_failures_before_degrade: usize,
} }
impl Default for MemoryConfig {
fn default() -> Self {
Self {
consolidation_provider: None,
consolidation_model: None,
recall_limit: 5,
consolidation_turn_threshold: 10,
idle_consolidation_minutes: 10,
timeline_retention_days: 90,
max_failures_before_degrade: 3,
}
}
}
impl MemoryConfig {
/// Resolve consolidation provider name, falling back to the main agent's provider.
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
}
/// Resolve consolidation model name, falling back to the main agent's model.
pub fn resolve_consolidation_model(&self, default: &str) -> String {
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
}
}
fn default_recall_limit() -> usize { 5 } fn default_recall_limit() -> usize { 5 }
fn default_consolidation_turn_threshold() -> usize { 3 } fn default_consolidation_turn_threshold() -> usize { 10 }
fn default_idle_consolidation_minutes() -> u64 { 10 } fn default_idle_consolidation_minutes() -> u64 { 10 }
fn default_timeline_retention_days() -> u64 { 90 } fn default_timeline_retention_days() -> u64 { 90 }
fn default_max_failures_before_degrade() -> usize { 3 } fn default_max_failures_before_degrade() -> usize { 3 }

View File

@ -56,14 +56,23 @@ impl GatewayState {
); );
tracing::info!("Session storage: {}", db_path.display()); tracing::info!("Session storage: {}", db_path.display());
// Initialize MemoryManager if memory system is enabled // Resolve consolidation provider/model with fallback to main agent config
let memory_manager = if config.memory.enabled { let consolidation_provider = config
let mm = Arc::new(MemoryManager::new(storage.clone())); .memory
tracing::info!("Memory system enabled"); .resolve_consolidation_provider(&provider_config.name);
Some(mm) let consolidation_model = config
} else { .memory
None .resolve_consolidation_model(&provider_config.model_id);
}; let memory_manager = Arc::new(MemoryManager::new(
storage.clone(),
consolidation_provider,
consolidation_model,
));
tracing::info!(
consolidation_provider = %memory_manager.consolidation_provider,
consolidation_model = %memory_manager.consolidation_model,
"Memory system initialized"
);
// Create MessageBus first (shared by SessionManager and ChannelManager) // Create MessageBus first (shared by SessionManager and ChannelManager)
let bus = MessageBus::new(100); let bus = MessageBus::new(100);

View File

@ -11,11 +11,21 @@ pub use types::{ConsolidationFact, ConsolidationResult, MemoryCategory, MemoryEn
#[derive(Clone)] #[derive(Clone)]
pub struct MemoryManager { pub struct MemoryManager {
storage: Arc<Storage>, storage: Arc<Storage>,
pub consolidation_provider: String,
pub consolidation_model: String,
} }
impl MemoryManager { impl MemoryManager {
pub fn new(storage: Arc<Storage>) -> Self { pub fn new(
Self { storage } storage: Arc<Storage>,
consolidation_provider: String,
consolidation_model: String,
) -> Self {
Self {
storage,
consolidation_provider,
consolidation_model,
}
} }
/// Store or update a memory entry. Generates timestamp and UUID. /// Store or update a memory entry. Generates timestamp and UUID.
@ -58,11 +68,12 @@ impl MemoryManager {
&self, &self,
since: i64, since: i64,
until: i64, until: i64,
query: Option<&str>,
limit: usize, limit: usize,
category: Option<MemoryCategory>, category: Option<MemoryCategory>,
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> { ) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
self.storage self.storage
.search_memories_by_time(since, until, category.as_ref(), limit) .search_memories_by_time(since, until, query, category.as_ref(), limit)
.await .await
} }
@ -87,7 +98,7 @@ mod tests {
let dir = tempdir().unwrap(); let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db"); let db_path = dir.path().join("test.db");
let storage = Arc::new(Storage::new(&db_path).await.unwrap()); let storage = Arc::new(Storage::new(&db_path).await.unwrap());
let mm = Arc::new(MemoryManager::new(storage)); let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
(mm, dir) (mm, dir)
} }

View File

@ -202,7 +202,7 @@ impl LLMProvider for AnthropicProvider {
} else { } else {
let mut blocks = convert_content_blocks(&m.content); let mut blocks = convert_content_blocks(&m.content);
// Append tool_use blocks from assistant messages with tool calls // Append tool_use blocks from assistant messages with tool calls
if let Some(ref tool_calls) = m.tool_calls { if let Some(tool_calls) = m.tool_calls.as_ref().filter(|c| !c.is_empty()) {
for tc in tool_calls { for tc in tool_calls {
blocks.push(serde_json::json!({ blocks.push(serde_json::json!({
"type": "tool_use", "type": "tool_use",

View File

@ -77,7 +77,7 @@ impl OpenAIProvider {
"tool_call_id": m.tool_call_id, "tool_call_id": m.tool_call_id,
"name": m.name, "name": m.name,
}) })
} else if m.role == "assistant" && m.tool_calls.is_some() { } else if m.role == "assistant" && m.tool_calls.as_ref().map_or(false, |c| !c.is_empty()) {
json!({ json!({
"role": m.role, "role": m.role,
"content": convert_content_blocks(&m.content), "content": convert_content_blocks(&m.content),

View File

@ -147,7 +147,7 @@ impl Scheduler {
let _ = self.bus.publish_outbound(outbound).await; let _ = self.bus.publish_outbound(outbound).await;
let output_truncated = if output.len() > 8000 { let output_truncated = if output.len() > 8000 {
format!("{}...[truncated]", &output[..8000]) format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
} else { } else {
output.clone() output.clone()
}; };

View File

@ -58,7 +58,7 @@ pub struct Session {
/// Timestamp (Unix ms) of the last consolidation. /// Timestamp (Unix ms) of the last consolidation.
/// Messages before this time have been compressed into memory. /// Messages before this time have been compressed into memory.
pub last_consolidated_at: Option<i64>, pub last_consolidated_at: Option<i64>,
memory_manager: Option<Arc<crate::memory::MemoryManager>>, memory_manager: Arc<crate::memory::MemoryManager>,
} }
impl Session { impl Session {
@ -69,7 +69,7 @@ impl Session {
storage: Option<StdArc<Storage>>, storage: Option<StdArc<Storage>>,
routing_info: String, routing_info: String,
title: String, title: String,
memory_manager: Option<Arc<crate::memory::MemoryManager>>, memory_manager: Arc<crate::memory::MemoryManager>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let mut provider_box = create_provider(provider_config.clone()) let mut provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
@ -83,11 +83,8 @@ impl Session {
..Default::default() ..Default::default()
}; };
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config); let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
if let Some(ref mm) = memory_manager { compressor.set_session_id(Some(id.to_string()));
compressor = compressor.with_memory(mm.clone());
compressor.set_session_id(Some(id.to_string()));
}
let now = chrono::Utc::now().timestamp_millis(); let now = chrono::Utc::now().timestamp_millis();
@ -117,7 +114,7 @@ impl Session {
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
storage: StdArc<Storage>, storage: StdArc<Storage>,
memory_manager: Option<Arc<crate::memory::MemoryManager>>, memory_manager: Arc<crate::memory::MemoryManager>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let session_meta = storage.get_session(&id.to_string()).await let session_meta = storage.get_session(&id.to_string()).await
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?; .map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
@ -135,28 +132,28 @@ impl Session {
..Default::default() ..Default::default()
}; };
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config); let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
if let Some(ref mm) = memory_manager { compressor.set_session_id(Some(id.to_string()));
compressor = compressor.with_memory(mm.clone());
compressor.set_session_id(Some(id.to_string()));
}
// Convert MessageMeta to ChatMessage // Convert MessageMeta to ChatMessage, then repair damaged tool call chains
// Clear tool_call_id/tool_name — they're not valid across API sessions let mut chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
ChatMessage { ChatMessage {
id: m.id, id: m.id,
role: m.role, role: m.role,
content: m.content, content: m.content,
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(), media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
timestamp: m.created_at, timestamp: m.created_at,
tool_call_id: None, tool_call_id: m.tool_call_id,
tool_name: None, tool_name: m.tool_name,
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()), tool_calls: m.tool_calls
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
.filter(|v| !v.is_empty()),
source: m.source.and_then(|s| serde_json::from_str(&s).ok()), source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
} }
}).collect(); }).collect();
repair_tool_call_chains(&mut chat_messages);
let seq_counter = chat_messages.len() as i64 + 1; let seq_counter = chat_messages.len() as i64 + 1;
let total_message_count = chat_messages.len() as i64; let total_message_count = chat_messages.len() as i64;
@ -211,7 +208,7 @@ impl Session {
}, },
tool_call_id: message.tool_call_id.clone(), tool_call_id: message.tool_call_id.clone(),
tool_name: message.tool_name.clone(), tool_name: message.tool_name.clone(),
tool_calls: message.tool_calls.as_ref().map(|tc| serde_json::to_string(tc).unwrap_or_default()), tool_calls: message.tool_calls.as_ref().and_then(|tc| serde_json::to_string(tc).ok()),
source: message.source.as_ref().map(|s| serde_json::to_string(s).unwrap_or_default()), source: message.source.as_ref().map(|s| serde_json::to_string(s).unwrap_or_default()),
created_at: now, created_at: now,
}; };
@ -574,6 +571,67 @@ impl Session {
} }
} }
/// Repair damaged tool call chains after restoring from storage.
/// Handles cases where the gateway crashed mid-loop, leaving assistant
/// tool_calls without corresponding tool result messages.
fn repair_tool_call_chains(messages: &mut Vec<ChatMessage>) {
let mut i = 0;
while i < messages.len() {
let calls = match &messages[i].tool_calls {
Some(calls) if !calls.is_empty() => calls.clone(),
_ => {
i += 1;
continue;
}
};
if messages[i].role != "assistant" {
i += 1;
continue;
}
// Collect expected tool call IDs
let expected_ids: std::collections::HashSet<&str> = calls.iter().map(|c| c.id.as_str()).collect();
let expected_count = expected_ids.len();
// Check following messages for matching tool results (same tool_call_id)
let mut found = 0;
let mut j = i + 1;
while j < messages.len() && found < expected_count {
if messages[j].role == "tool" {
if let Some(ref tc_id) = messages[j].tool_call_id {
if expected_ids.contains(tc_id.as_str()) {
found += 1;
}
}
} else if messages[j].role == "user" || messages[j].role == "assistant" {
// Next user/assistant message — stop scanning, chain is broken
break;
}
j += 1;
}
if found < expected_count {
// Incomplete chain: remove tool_calls and add interruption note
tracing::warn!(
found,
expected = expected_count,
"Repairing incomplete tool call chain — gateway restart likely interrupted execution"
);
let old_content = std::mem::take(&mut messages[i].content);
messages[i].content = format!(
"{}\n\n[Tool calls ({}): {} — execution interrupted by gateway restart]",
old_content,
expected_count,
calls.iter().map(|c| c.name.as_str()).collect::<Vec<_>>().join(", ")
);
messages[i].tool_calls = None;
}
i += 1;
}
}
/// SessionManager 管理所有 Session按 channel_name 路由 /// SessionManager 管理所有 Session按 channel_name 路由
#[derive(Clone)] #[derive(Clone)]
pub struct SessionManager { pub struct SessionManager {
@ -584,7 +642,7 @@ pub struct SessionManager {
storage: Arc<Storage>, storage: Arc<Storage>,
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
current_source_session: Arc<Mutex<Option<String>>>, current_source_session: Arc<Mutex<Option<String>>>,
memory_manager: Option<Arc<crate::memory::MemoryManager>>, memory_manager: Arc<crate::memory::MemoryManager>,
} }
struct SessionManagerInner { struct SessionManagerInner {
@ -672,7 +730,7 @@ impl SessionManager {
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
storage: Arc<Storage>, storage: Arc<Storage>,
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
memory_manager: Option<Arc<crate::memory::MemoryManager>>, memory_manager: Arc<crate::memory::MemoryManager>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let skills_loader = SkillsLoader::new(); let skills_loader = SkillsLoader::new();
skills_loader.load_skills(); skills_loader.load_skills();
@ -1252,22 +1310,18 @@ impl SessionManager {
let skills_prompt = self.skills_loader.build_skills_prompt(); let skills_prompt = self.skills_loader.build_skills_prompt();
// Fetch memory context // Fetch memory context
let memory_context = if let Some(ref mm) = self.memory_manager { let memory_context = match self.memory_manager.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await {
match mm.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await { Ok(entries) if !entries.is_empty() => {
Ok(entries) if !entries.is_empty() => { Some(entries.iter()
Some(entries.iter() .map(|e| format!("- {}: {}", e.key, e.content))
.map(|e| format!("- {}: {}", e.key, e.content)) .collect::<Vec<_>>()
.collect::<Vec<_>>() .join("\n"))
.join("\n"))
}
Err(e) => {
tracing::warn!(error = %e, "Failed to fetch memory context");
None
}
_ => None,
} }
} else { Err(e) => {
None tracing::warn!(error = %e, "Failed to fetch memory context");
None
}
_ => None,
}; };
// Build combined system prompt and inject at position 0 // Build combined system prompt and inject at position 0
@ -1409,11 +1463,9 @@ impl SessionManager {
} }
let raw_response = result.final_response.content; let raw_response = result.final_response.content;
let target_id = unified_id.to_string();
let prefix = format!( let prefix = format!(
"[message from cron:{}({}) to {}]\n", "[message from cron:{}({})]\n",
job_name, job_id, target_id job_name, job_id
); );
let prefixed_response = format!("{}{}", prefix, raw_response); let prefixed_response = format!("{}{}", prefix, raw_response);
@ -1491,11 +1543,10 @@ impl OutboundMessenger for SessionManager {
(sid, session) (sid, session)
}; };
// Build message prefix: [message from <origin> to <channel:chat_id:dialog_id>] // Build message prefix: [message from <origin>]
let target_id = target_sid.to_string();
let origin = source.from_session.as_deref().unwrap_or("unknown"); let origin = source.from_session.as_deref().unwrap_or("unknown");
let origin_id = source.from_session.clone(); let origin_id = source.from_session.clone();
let prefix = format!("[message from {} to {}] ", origin, target_id); let prefix = format!("[message from {}] ", origin);
let marked_content = format!("{}\n{}", prefix, content); let marked_content = format!("{}\n{}", prefix, content);
// Write source-tagged assistant message to target session history // Write source-tagged assistant message to target session history

View File

@ -1,9 +1,17 @@
use sqlx::Row; use sqlx::Row;
use std::sync::OnceLock;
use jieba_rs::Jieba;
use crate::memory::{MemoryCategory, MemoryEntry}; use crate::memory::{MemoryCategory, MemoryEntry};
use super::StorageError; use super::StorageError;
fn jieba() -> &'static Jieba {
static INSTANCE: OnceLock<Jieba> = OnceLock::new();
INSTANCE.get_or_init(Jieba::new)
}
impl super::Storage { impl super::Storage {
/// Store or update a memory entry (upsert by key). /// Store or update a memory entry (upsert by key).
pub async fn upsert_memory(&self, entry: &MemoryEntry) -> Result<(), StorageError> { pub async fn upsert_memory(&self, entry: &MemoryEntry) -> Result<(), StorageError> {
@ -50,9 +58,11 @@ impl super::Storage {
category: Option<&MemoryCategory>, category: Option<&MemoryCategory>,
limit: usize, limit: usize,
) -> Result<Vec<MemoryEntry>, StorageError> { ) -> Result<Vec<MemoryEntry>, StorageError> {
// Build FTS5 query: wrap each word in quotes and join with OR // Build FTS5 query: segment with jieba, wrap each term in quotes, join with OR
let fts_query = query let fts_query = jieba()
.split_whitespace() .cut(query, true)
.into_iter()
.filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
.map(|w| format!("\"{}\"", w.replace('"', ""))) .map(|w| format!("\"{}\"", w.replace('"', "")))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(" OR "); .join(" OR ");
@ -80,39 +90,59 @@ impl super::Storage {
let mut entries = parse_memory_rows(&rows)?; let mut entries = parse_memory_rows(&rows)?;
// Fallback to LIKE if FTS5 returned nothing // Fallback to term-based LIKE query if FTS5 returned nothing
if entries.is_empty() { if entries.is_empty() {
let like_pattern = format!("%{}%", query.replace('%', "").replace('_', "")); let terms: Vec<String> = jieba()
let rows = sqlx::query( .cut(query, true)
r#" .into_iter()
SELECT id, key, content, category, importance, .filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
session_id, created_at, updated_at .map(|w| w.replace('%', "").replace('_', ""))
FROM memories .collect();
WHERE (key LIKE ? OR content LIKE ?)
AND (? IS NULL OR category = ?)
ORDER BY importance DESC, updated_at DESC
LIMIT ?
"#,
)
.bind(&like_pattern)
.bind(&like_pattern)
.bind(category_filter)
.bind(category_filter)
.bind(limit as i64)
.fetch_all(self.pool())
.await?;
entries = parse_memory_rows(&rows)?; if !terms.is_empty() {
let like_clauses = terms
.iter()
.map(|_| "(key LIKE ? OR content LIKE ?)")
.collect::<Vec<_>>()
.join(" OR ");
let sql = format!(
r#"
SELECT id, key, content, category, importance,
session_id, created_at, updated_at
FROM memories
WHERE ({})
AND (? IS NULL OR category = ?)
ORDER BY importance DESC, updated_at DESC
LIMIT ?
"#,
like_clauses
);
let mut query_builder = sqlx::query(&sql);
for term in &terms {
let pattern = format!("%{}%", term);
query_builder = query_builder.bind(pattern.clone()).bind(pattern);
}
query_builder = query_builder
.bind(category_filter)
.bind(category_filter)
.bind(limit as i64);
let rows = query_builder.fetch_all(self.pool()).await?;
entries = parse_memory_rows(&rows)?;
}
} }
Ok(entries) Ok(entries)
} }
/// Retrieve memories within a time range. /// Retrieve memories within a time range, optionally filtered by keyword query.
pub async fn search_memories_by_time( pub async fn search_memories_by_time(
&self, &self,
since: i64, since: i64,
until: i64, until: i64,
query: Option<&str>,
category: Option<&MemoryCategory>, category: Option<&MemoryCategory>,
limit: usize, limit: usize,
) -> Result<Vec<MemoryEntry>, StorageError> { ) -> Result<Vec<MemoryEntry>, StorageError> {
@ -124,24 +154,71 @@ impl super::Storage {
.unwrap_or_default() .unwrap_or_default()
.to_rfc3339(); .to_rfc3339();
let rows = sqlx::query( let rows = if let Some(q) = query {
r#" let terms: Vec<String> = jieba()
SELECT id, key, content, category, importance, .cut(q, true)
session_id, created_at, updated_at .into_iter()
FROM memories .filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
WHERE created_at >= ? AND created_at <= ? .map(|w| w.replace('%', "").replace('_', ""))
AND (? IS NULL OR category = ?) .collect();
ORDER BY created_at DESC
LIMIT ? if terms.is_empty() {
"#, return Ok(Vec::new());
) }
.bind(&since_dt)
.bind(&until_dt) let like_clauses = terms
.bind(category_filter) .iter()
.bind(category_filter) .map(|_| "(key LIKE ? OR content LIKE ?)")
.bind(limit as i64) .collect::<Vec<_>>()
.fetch_all(self.pool()) .join(" OR ");
.await?;
let sql = format!(
r#"
SELECT id, key, content, category, importance,
session_id, created_at, updated_at
FROM memories
WHERE ({})
AND created_at >= ? AND created_at <= ?
AND (? IS NULL OR category = ?)
ORDER BY created_at DESC
LIMIT ?
"#,
like_clauses
);
let mut query_builder = sqlx::query(&sql);
for term in &terms {
let pattern = format!("%{}%", term);
query_builder = query_builder.bind(pattern.clone()).bind(pattern);
}
query_builder = query_builder
.bind(&since_dt)
.bind(&until_dt)
.bind(category_filter)
.bind(category_filter)
.bind(limit as i64);
query_builder.fetch_all(self.pool()).await?
} else {
sqlx::query(
r#"
SELECT id, key, content, category, importance,
session_id, created_at, updated_at
FROM memories
WHERE created_at >= ? AND created_at <= ?
AND (? IS NULL OR category = ?)
ORDER BY created_at DESC
LIMIT ?
"#,
)
.bind(&since_dt)
.bind(&until_dt)
.bind(category_filter)
.bind(category_filter)
.bind(limit as i64)
.fetch_all(self.pool())
.await?
};
parse_memory_rows(&rows) parse_memory_rows(&rows)
} }

View File

@ -127,6 +127,48 @@ impl Storage {
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
// Triggers to keep FTS5 index in sync with memories table
sqlx::query(
r#"
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
INSERT INTO memory_fts(rowid, key, content) VALUES (new.rowid, new.key, new.content);
END
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
INSERT INTO memory_fts(memory_fts, rowid, key, content)
VALUES ('delete', old.rowid, old.key, old.content);
END
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
INSERT INTO memory_fts(memory_fts, rowid, key, content)
VALUES ('delete', old.rowid, old.key, old.content);
INSERT INTO memory_fts(rowid, key, content)
VALUES (new.rowid, new.key, new.content);
END
"#,
)
.execute(&self.pool)
.await?;
// Rebuild FTS5 index for any existing records
sqlx::query(
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')",
)
.execute(&self.pool)
.await?;
// Migration: add last_consolidated_at column if not exists // Migration: add last_consolidated_at column if not exists
sqlx::query( sqlx::query(
r#" r#"

View File

@ -68,9 +68,9 @@ impl BashTool {
let half = MAX_OUTPUT_CHARS / 2; let half = MAX_OUTPUT_CHARS / 2;
format!( format!(
"{}...\n\n(... {} chars truncated ...)\n\n{}", "{}...\n\n(... {} chars truncated ...)\n\n{}",
&output[..half], &output[..output.ceil_char_boundary(half)],
output.len() - MAX_OUTPUT_CHARS, output.len() - MAX_OUTPUT_CHARS,
&output[output.len() - half..] &output[output.floor_char_boundary(output.len() - half)..]
) )
} }
} }

View File

@ -101,7 +101,7 @@ impl HttpRequestTool {
if text.len() > self.max_response_size { if text.len() > self.max_response_size {
format!( format!(
"{}\n\n... [Response truncated due to size limit] ...", "{}\n\n... [Response truncated due to size limit] ...",
&text[..self.max_response_size] &text[..text.ceil_char_boundary(self.max_response_size)]
) )
} else { } else {
text.to_string() text.to_string()

265
src/tools/memory.rs Normal file
View File

@ -0,0 +1,265 @@
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use crate::memory::{MemoryCategory, MemoryManager};
// ── MemoryStoreTool ──────────────────────────────────────────────────
pub struct MemoryStoreTool {
memory: Arc<MemoryManager>,
}
impl MemoryStoreTool {
pub fn new(memory: Arc<MemoryManager>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryStoreTool {
fn name(&self) -> &str {
"memory_store"
}
fn description(&self) -> &str {
"Store a fact, preference, or insight into long-term memory. \
Use this when the user shares important information you should remember. \
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
and the full content to remember."
}
fn read_only(&self) -> bool {
false
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key."
},
"content": {
"type": "string",
"description": "The full content of the memory entry."
},
"category": {
"type": "string",
"enum": ["knowledge", "timeline"],
"description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
},
"importance": {
"type": "number",
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
}
},
"required": ["key", "content"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
let content = args
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
let category = args
.get("category")
.and_then(|v| v.as_str())
.and_then(MemoryCategory::from_str)
.unwrap_or(MemoryCategory::Knowledge);
let importance = args.get("importance").and_then(|v| v.as_f64());
self.memory
.store(key, content, category, None, importance)
.await?;
Ok(ToolResult {
success: true,
output: format!("Memory stored: {}", key),
error: None,
})
}
}
// ── MemoryRecallTool ─────────────────────────────────────────────────
pub struct MemoryRecallTool {
memory: Arc<MemoryManager>,
}
impl MemoryRecallTool {
pub fn new(memory: Arc<MemoryManager>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryRecallTool {
fn name(&self) -> &str {
"memory_recall"
}
fn description(&self) -> &str {
"Search and retrieve entries from long-term memory using keyword matching. \
Use this to recall previously stored facts, preferences, or conversation history. \
IMPORTANT: query must be a space-separated list of RELEVANT KEYWORDS (not a question or sentence). \
Use multiple synonymous or related terms to increase recall. \
Example: instead of 'what is the user location', use 'user location address city residence'. \
Supports optional time-range filtering via since/until (Unix ms)."
}
fn read_only(&self) -> bool {
true
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Space-separated KEYWORDS for memory search (NOT a natural language question). Use multiple related terms for better recall, e.g. 'address city location residence'."
},
"category": {
"type": "string",
"enum": ["knowledge", "timeline"],
"description": "Filter by memory category. Omit to search all categories."
},
"since": {
"type": "integer",
"description": "Start of time range (Unix milliseconds)."
},
"until": {
"type": "integer",
"description": "End of time range (Unix milliseconds)."
},
"limit": {
"type": "integer",
"description": "Max results to return (default 10)."
}
},
"required": ["query"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let query = args
.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
let category = args
.get("category")
.and_then(|v| v.as_str())
.and_then(MemoryCategory::from_str);
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
let entries = if args.get("since").is_some() || args.get("until").is_some() {
let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0);
let until = args
.get("until")
.and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory
.recall_by_time(since, until, Some(query), limit, category)
.await?
} else {
self.memory.recall(query, limit, category).await?
};
if entries.is_empty() {
return Ok(ToolResult {
success: true,
output: "No matching memories found.".to_string(),
error: None,
});
}
let formatted = entries
.iter()
.map(|e| {
format!(
"- {} [{}] [importance: {:.1}]: {}",
e.key,
e.category.as_str(),
e.importance,
e.content
)
})
.collect::<Vec<_>>()
.join("\n");
Ok(ToolResult {
success: true,
output: format!("Found {} memories:\n{}", entries.len(), formatted),
error: None,
})
}
}
// ── MemoryForgetTool ─────────────────────────────────────────────────
pub struct MemoryForgetTool {
memory: Arc<MemoryManager>,
}
impl MemoryForgetTool {
pub fn new(memory: Arc<MemoryManager>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryForgetTool {
fn name(&self) -> &str {
"memory_forget"
}
fn description(&self) -> &str {
"Delete a memory entry by its key. Use this when information is outdated, \
incorrect, or the user asks to forget something."
}
fn read_only(&self) -> bool {
false
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "The key of the memory entry to delete."
}
},
"required": ["key"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
self.memory.forget(key).await?;
Ok(ToolResult {
success: true,
output: format!("Memory deleted: {}", key),
error: None,
})
}
}

View File

@ -1,60 +0,0 @@
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use crate::memory::MemoryManager;
pub struct MemoryForgetTool {
memory: Arc<MemoryManager>,
}
impl MemoryForgetTool {
pub fn new(memory: Arc<MemoryManager>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryForgetTool {
fn name(&self) -> &str {
"memory_forget"
}
fn description(&self) -> &str {
"Delete a memory entry by its key. Use this when information is outdated, \
incorrect, or the user asks to forget something."
}
fn read_only(&self) -> bool {
false
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "The key of the memory entry to delete."
}
},
"required": ["key"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
self.memory.forget(key).await?;
Ok(ToolResult {
success: true,
output: format!("Memory deleted: {}", key),
error: None,
})
}
}

View File

@ -1,118 +0,0 @@
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use crate::memory::{MemoryCategory, MemoryManager};
pub struct MemoryRecallTool {
memory: Arc<MemoryManager>,
}
impl MemoryRecallTool {
pub fn new(memory: Arc<MemoryManager>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryRecallTool {
fn name(&self) -> &str {
"memory_recall"
}
fn description(&self) -> &str {
"Search and retrieve entries from long-term memory. \
Use this to recall previously stored facts, preferences, or conversation history. \
Supports keyword search and optional time-range filtering."
}
fn read_only(&self) -> bool {
true
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query — keywords to match against memory keys and content."
},
"category": {
"type": "string",
"enum": ["knowledge", "timeline"],
"description": "Filter by memory category. Omit to search all categories."
},
"since": {
"type": "integer",
"description": "Start of time range (Unix milliseconds)."
},
"until": {
"type": "integer",
"description": "End of time range (Unix milliseconds)."
},
"limit": {
"type": "integer",
"description": "Max results to return (default 10)."
}
},
"required": ["query"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let query = args
.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
let category = args
.get("category")
.and_then(|v| v.as_str())
.and_then(MemoryCategory::from_str);
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
let entries = if args.get("since").is_some() || args.get("until").is_some() {
let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0);
let until = args
.get("until")
.and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory
.recall_by_time(since, until, limit, category)
.await?
} else {
self.memory.recall(query, limit, category).await?
};
if entries.is_empty() {
return Ok(ToolResult {
success: true,
output: "No matching memories found.".to_string(),
error: None,
});
}
let formatted = entries
.iter()
.map(|e| {
format!(
"- {} [{}] [importance: {:.1}]: {}",
e.key,
e.category.as_str(),
e.importance,
e.content
)
})
.collect::<Vec<_>>()
.join("\n");
Ok(ToolResult {
success: true,
output: format!("Found {} memories:\n{}", entries.len(), formatted),
error: None,
})
}
}

View File

@ -1,90 +0,0 @@
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use crate::memory::{MemoryCategory, MemoryManager};
pub struct MemoryStoreTool {
memory: Arc<MemoryManager>,
}
impl MemoryStoreTool {
pub fn new(memory: Arc<MemoryManager>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryStoreTool {
fn name(&self) -> &str {
"memory_store"
}
fn description(&self) -> &str {
"Store a fact, preference, or insight into long-term memory. \
Use this when the user shares important information you should remember. \
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
and the full content to remember."
}
fn read_only(&self) -> bool {
false
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key."
},
"content": {
"type": "string",
"description": "The full content of the memory entry."
},
"category": {
"type": "string",
"enum": ["knowledge", "timeline"],
"description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
},
"importance": {
"type": "number",
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
}
},
"required": ["key", "content"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
let content = args
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
let category = args
.get("category")
.and_then(|v| v.as_str())
.and_then(MemoryCategory::from_str)
.unwrap_or(MemoryCategory::Knowledge);
let importance = args.get("importance").and_then(|v| v.as_f64());
self.memory
.store(key, content, category, None, importance)
.await?;
Ok(ToolResult {
success: true,
output: format!("Memory stored: {}", key),
error: None,
})
}
}

View File

@ -7,9 +7,7 @@ pub mod file_read;
pub mod file_write; pub mod file_write;
pub mod get_skill; pub mod get_skill;
pub mod http_request; pub mod http_request;
pub mod memory_forget; pub mod memory;
pub mod memory_recall;
pub mod memory_store;
pub mod registry; pub mod registry;
pub mod schema; pub mod schema;
pub mod send_message; pub mod send_message;
@ -24,9 +22,7 @@ pub use file_read::FileReadTool;
pub use file_write::FileWriteTool; pub use file_write::FileWriteTool;
pub use get_skill::GetSkillTool; pub use get_skill::GetSkillTool;
pub use http_request::HttpRequestTool; pub use http_request::HttpRequestTool;
pub use memory_forget::MemoryForgetTool; pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool};
pub use memory_recall::MemoryRecallTool;
pub use memory_store::MemoryStoreTool;
pub use registry::ToolRegistry; pub use registry::ToolRegistry;
pub use schema::{CleaningStrategy, SchemaCleanr}; pub use schema::{CleaningStrategy, SchemaCleanr};
pub use send_message::SendMessageTool; pub use send_message::SendMessageTool;
@ -42,7 +38,7 @@ use crate::skills::SkillsLoader;
/// once the available channel names are known. /// once the available channel names are known.
pub fn create_default_tools( pub fn create_default_tools(
skills_loader: Arc<SkillsLoader>, skills_loader: Arc<SkillsLoader>,
memory: Option<Arc<MemoryManager>>, memory: Arc<MemoryManager>,
) -> ToolRegistry { ) -> ToolRegistry {
let registry = ToolRegistry::new(); let registry = ToolRegistry::new();
registry.register(CalculatorTool::new()); registry.register(CalculatorTool::new());
@ -59,12 +55,9 @@ pub fn create_default_tools(
registry.register(WebFetchTool::new(50_000, 30)); registry.register(WebFetchTool::new(50_000, 30));
registry.register(GetSkillTool::new(skills_loader)); registry.register(GetSkillTool::new(skills_loader));
// Register memory tools if memory system is available registry.register(MemoryStoreTool::new(memory.clone()));
if let Some(mm) = memory { registry.register(MemoryRecallTool::new(memory.clone()));
registry.register(MemoryStoreTool::new(mm.clone())); registry.register(MemoryForgetTool::new(memory.clone()));
registry.register(MemoryRecallTool::new(mm.clone()));
registry.register(MemoryForgetTool::new(mm.clone()));
}
registry registry
} }

View File

@ -53,7 +53,7 @@ impl WebFetchTool {
if text.len() > self.max_response_size { if text.len() > self.max_response_size {
format!( format!(
"{}\n\n... [Response truncated due to size limit] ...", "{}\n\n... [Response truncated due to size limit] ...",
&text[..self.max_response_size] &text[..text.ceil_char_boundary(self.max_response_size)]
) )
} else { } else {
text.to_string() text.to_string()