Compare commits
2 Commits
c602a0695d
...
5d62141658
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d62141658 | |||
| 2617558a27 |
@ -36,3 +36,4 @@ textwrap = "0.16"
|
|||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
hostname = "0.3"
|
hostname = "0.3"
|
||||||
sqlx = { version = "0.8", features = ["sqlite", "macros", "chrono", "runtime-tokio"] }
|
sqlx = { version = "0.8", features = ["sqlite", "macros", "chrono", "runtime-tokio"] }
|
||||||
|
jieba-rs = "0.9"
|
||||||
|
|||||||
@ -63,5 +63,14 @@
|
|||||||
"reaction_emoji": "Typing"
|
"reaction_emoji": "Typing"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"memory": {
|
||||||
|
"consolidation_provider": null,
|
||||||
|
"consolidation_model": null,
|
||||||
|
"recall_limit": 5,
|
||||||
|
"consolidation_turn_threshold": 10,
|
||||||
|
"idle_consolidation_minutes": 10,
|
||||||
|
"timeline_retention_days": 90,
|
||||||
|
"max_failures_before_degrade": 3
|
||||||
|
},
|
||||||
"workspace_dir": "~/.picobot/workspace"
|
"workspace_dir": "~/.picobot/workspace"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -73,7 +73,7 @@ fn truncate_tool_result(output: &str) -> String {
|
|||||||
// Even after removing suffix, still too long - take from beginning
|
// Even after removing suffix, still too long - take from beginning
|
||||||
format!(
|
format!(
|
||||||
"{}...\n\n[Output truncated - {} characters removed]",
|
"{}...\n\n[Output truncated - {} characters removed]",
|
||||||
&output[..MAX_TOOL_RESULT_CHARS - 100],
|
&output[..output.ceil_char_boundary(MAX_TOOL_RESULT_CHARS - 100)],
|
||||||
output.len() - MAX_TOOL_RESULT_CHARS + 100
|
output.len() - MAX_TOOL_RESULT_CHARS + 100
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@ -81,7 +81,7 @@ fn truncate_tool_result(output: &str) -> String {
|
|||||||
format!(
|
format!(
|
||||||
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
||||||
truncated_start_len,
|
truncated_start_len,
|
||||||
&output[truncated_start_len..]
|
&output[output.floor_char_boundary(truncated_start_len)..]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -50,22 +50,26 @@ pub struct ContextCompressor {
|
|||||||
threshold_ratio: f64,
|
threshold_ratio: f64,
|
||||||
/// Shared LLM provider for summarization
|
/// Shared LLM provider for summarization
|
||||||
provider: Arc<dyn LLMProvider>,
|
provider: Arc<dyn LLMProvider>,
|
||||||
/// Memory manager handle (optional). When set, compressed
|
/// Memory manager handle. Compressed context summaries are persisted
|
||||||
/// context summaries are persisted as timeline memory entries.
|
/// as timeline memory entries.
|
||||||
memory: Option<Arc<MemoryManager>>,
|
memory: Arc<MemoryManager>,
|
||||||
/// Current session ID for timeline memory writes.
|
/// Current session ID for timeline memory writes.
|
||||||
session_id: Option<String>,
|
session_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ContextCompressor {
|
impl ContextCompressor {
|
||||||
/// Create a new compressor with the given provider and context window size.
|
/// Create a new compressor with the given provider, context window size, and memory manager.
|
||||||
pub fn new(provider: Arc<dyn LLMProvider>, context_window: usize) -> Self {
|
pub fn new(
|
||||||
|
provider: Arc<dyn LLMProvider>,
|
||||||
|
context_window: usize,
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
config: ContextCompressionConfig::default(),
|
config: ContextCompressionConfig::default(),
|
||||||
context_window,
|
context_window,
|
||||||
threshold_ratio: 0.5,
|
threshold_ratio: 0.5,
|
||||||
provider,
|
provider,
|
||||||
memory: None,
|
memory,
|
||||||
session_id: None,
|
session_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -75,23 +79,18 @@ impl ContextCompressor {
|
|||||||
provider: Arc<dyn LLMProvider>,
|
provider: Arc<dyn LLMProvider>,
|
||||||
context_window: usize,
|
context_window: usize,
|
||||||
config: ContextCompressionConfig,
|
config: ContextCompressionConfig,
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
config,
|
config,
|
||||||
context_window,
|
context_window,
|
||||||
threshold_ratio: 0.5,
|
threshold_ratio: 0.5,
|
||||||
provider,
|
provider,
|
||||||
memory: None,
|
memory,
|
||||||
session_id: None,
|
session_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Attach a memory manager to persist compressed summaries.
|
|
||||||
pub fn with_memory(mut self, memory: Arc<MemoryManager>) -> Self {
|
|
||||||
self.memory = Some(memory);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the current session ID for timeline writes.
|
/// Set the current session ID for timeline writes.
|
||||||
pub fn set_session_id(&mut self, id: Option<String>) {
|
pub fn set_session_id(&mut self, id: Option<String>) {
|
||||||
self.session_id = id;
|
self.session_id = id;
|
||||||
@ -113,7 +112,7 @@ impl ContextCompressor {
|
|||||||
let removed = msg.content.len() - limit;
|
let removed = msg.content.len() - limit;
|
||||||
msg.content = format!(
|
msg.content = format!(
|
||||||
"{}...\n\n[Output truncated - {} characters removed]",
|
"{}...\n\n[Output truncated - {} characters removed]",
|
||||||
&msg.content[..limit.min(msg.content.len())],
|
&msg.content[..msg.content.ceil_char_boundary(limit)],
|
||||||
removed
|
removed
|
||||||
);
|
);
|
||||||
modified += 1;
|
modified += 1;
|
||||||
@ -240,25 +239,23 @@ impl ContextCompressor {
|
|||||||
let summary = self.summarize_segment(between).await?;
|
let summary = self.summarize_segment(between).await?;
|
||||||
|
|
||||||
// Persist compressed summary as timeline memory entry
|
// Persist compressed summary as timeline memory entry
|
||||||
if let Some(ref mm) = self.memory {
|
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
||||||
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
||||||
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
ts, between.len(), summary);
|
||||||
ts, between.len(), summary);
|
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
||||||
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
let mm = self.memory.clone();
|
||||||
let mm_clone = mm.clone();
|
let sid = self.session_id.clone();
|
||||||
let sid = self.session_id.clone();
|
tokio::spawn(async move {
|
||||||
tokio::spawn(async move {
|
if let Err(e) = mm.store(
|
||||||
if let Err(e) = mm_clone.store(
|
&key,
|
||||||
&key,
|
&timeline_content,
|
||||||
&timeline_content,
|
crate::memory::MemoryCategory::Timeline,
|
||||||
crate::memory::MemoryCategory::Timeline,
|
sid.as_deref(),
|
||||||
sid.as_deref(),
|
Some(0.3),
|
||||||
Some(0.3),
|
).await {
|
||||||
).await {
|
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
||||||
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
}
|
||||||
}
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add summary as a special user message
|
// Add summary as a special user message
|
||||||
new_messages.push(ChatMessage::user(format!(
|
new_messages.push(ChatMessage::user(format!(
|
||||||
@ -316,7 +313,7 @@ impl ContextCompressor {
|
|||||||
let transcript = if transcript.len() > self.config.summary_max_chars {
|
let transcript = if transcript.len() > self.config.summary_max_chars {
|
||||||
format!(
|
format!(
|
||||||
"{}...\n\n[Transcript truncated - {} characters removed]",
|
"{}...\n\n[Transcript truncated - {} characters removed]",
|
||||||
&transcript[..self.config.summary_max_chars],
|
&transcript[..transcript.ceil_char_boundary(self.config.summary_max_chars)],
|
||||||
transcript.len() - self.config.summary_max_chars
|
transcript.len() - self.config.summary_max_chars
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@ -359,7 +356,7 @@ Be concise, aim for {} characters or less.
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Fallback: just truncate the transcript
|
// Fallback: just truncate the transcript
|
||||||
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
|
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
|
||||||
Ok(transcript[..transcript.len().min(2000)].to_string())
|
Ok(transcript[..transcript.ceil_char_boundary(2000)].to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -370,6 +367,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::providers::ChatCompletionResponse;
|
use crate::providers::ChatCompletionResponse;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use std::sync::OnceLock;
|
||||||
|
|
||||||
/// Mock provider for testing - panics if actually used for LLM calls
|
/// Mock provider for testing - panics if actually used for LLM calls
|
||||||
struct MockProvider;
|
struct MockProvider;
|
||||||
@ -400,6 +398,18 @@ mod tests {
|
|||||||
Arc::new(MockProvider)
|
Arc::new(MockProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn test_memory_manager() -> Arc<MemoryManager> {
|
||||||
|
static MM: OnceLock<Arc<MemoryManager>> = OnceLock::new();
|
||||||
|
MM.get_or_init(|| {
|
||||||
|
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||||
|
rt.block_on(async {
|
||||||
|
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
||||||
|
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||||
|
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
|
||||||
|
})
|
||||||
|
}).clone()
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_estimate_tokens() {
|
fn test_estimate_tokens() {
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
@ -422,7 +432,7 @@ mod tests {
|
|||||||
tool_result_trim_chars: 50,
|
tool_result_trim_chars: 50,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config);
|
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
||||||
|
|
||||||
let mut messages = vec![
|
let mut messages = vec![
|
||||||
ChatMessage::user("Hello"),
|
ChatMessage::user("Hello"),
|
||||||
@ -436,7 +446,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_threshold() {
|
fn test_threshold() {
|
||||||
let compressor = ContextCompressor::new(mock_provider(), 128_000);
|
let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager());
|
||||||
assert_eq!(compressor.threshold(), 64_000);
|
assert_eq!(compressor.threshold(), 64_000);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -243,10 +243,9 @@ impl PromptSection for CrossChannelSection {
|
|||||||
- dialog_id: 对话标识,同一 chat 下可以有多个 dialog
|
- dialog_id: 对话标识,同一 chat 下可以有多个 dialog
|
||||||
|
|
||||||
{}### 跨会话消息
|
{}### 跨会话消息
|
||||||
对话历史中可能出现带有 `[message from X to Y]` 前缀的 assistant 消息,
|
对话历史中可能出现带有 `[message from X]` 前缀的 assistant 消息,
|
||||||
表示此消息由 send_message 工具从别处发送过来。
|
表示此消息由 send_message 工具从别处发送过来。
|
||||||
- X: 来源标识,可能是会话 ID、工具名或其他标识字符串;未指定时为 "unknown"
|
- X: 来源标识,可能是会话 ID、工具名或其他标识字符串;未指定时为 "unknown"
|
||||||
- Y: 目标会话的完整 session ID (<channel>:<chat_id>:<dialog_id>)
|
|
||||||
|
|
||||||
收到此类消息时一般不需要主动处理,只需知晓。如果用户问及相关信息,
|
收到此类消息时一般不需要主动处理,只需知晓。如果用户问及相关信息,
|
||||||
可以尝试从来源处获取更多详情。
|
可以尝试从来源处获取更多详情。
|
||||||
|
|||||||
@ -778,7 +778,7 @@ impl FeishuChannel {
|
|||||||
|
|
||||||
let payload_content = if msg_type == "text" {
|
let payload_content = if msg_type == "text" {
|
||||||
let truncated = if content.len() > MAX_TEXT_LENGTH {
|
let truncated = if content.len() > MAX_TEXT_LENGTH {
|
||||||
format!("{}...\n\n[Content truncated due to length limit]", &content[..MAX_TEXT_LENGTH])
|
format!("{}...\n\n[Content truncated due to length limit]", &content[..content.ceil_char_boundary(MAX_TEXT_LENGTH)])
|
||||||
} else {
|
} else {
|
||||||
content.to_string()
|
content.to_string()
|
||||||
};
|
};
|
||||||
@ -788,7 +788,7 @@ impl FeishuChannel {
|
|||||||
// But we still need to check length
|
// But we still need to check length
|
||||||
if content.len() > MAX_TEXT_LENGTH {
|
if content.len() > MAX_TEXT_LENGTH {
|
||||||
// Fallback to truncated text for post as well
|
// Fallback to truncated text for post as well
|
||||||
serde_json::json!({ "text": format!("{}...\n\n[Content truncated due to length limit]", &content[..MAX_TEXT_LENGTH]) }).to_string()
|
serde_json::json!({ "text": format!("{}...\n\n[Content truncated due to length limit]", &content[..content.ceil_char_boundary(MAX_TEXT_LENGTH)]) }).to_string()
|
||||||
} else {
|
} else {
|
||||||
content.to_string()
|
content.to_string()
|
||||||
}
|
}
|
||||||
@ -2136,7 +2136,7 @@ impl Channel for FeishuChannel {
|
|||||||
if !msg.content.is_empty() {
|
if !msg.content.is_empty() {
|
||||||
const MAX_TEXT_LENGTH: usize = 60_000;
|
const MAX_TEXT_LENGTH: usize = 60_000;
|
||||||
let truncated_text = if msg.content.len() > MAX_TEXT_LENGTH {
|
let truncated_text = if msg.content.len() > MAX_TEXT_LENGTH {
|
||||||
format!("{}...\n\n[Content truncated due to length limit]", &msg.content[..MAX_TEXT_LENGTH])
|
format!("{}...\n\n[Content truncated due to length limit]", &msg.content[..msg.content.ceil_char_boundary(MAX_TEXT_LENGTH)])
|
||||||
} else {
|
} else {
|
||||||
msg.content.clone()
|
msg.content.clone()
|
||||||
};
|
};
|
||||||
|
|||||||
@ -220,15 +220,14 @@ impl Default for ClientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct MemoryConfig {
|
pub struct MemoryConfig {
|
||||||
/// Master switch for the memory system.
|
|
||||||
#[serde(default)]
|
|
||||||
pub enabled: bool,
|
|
||||||
/// Provider name for consolidation LLM calls (key in `providers`).
|
/// Provider name for consolidation LLM calls (key in `providers`).
|
||||||
|
/// If not set, falls back to the main agent's provider.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub consolidation_provider: Option<String>,
|
pub consolidation_provider: Option<String>,
|
||||||
/// Model name for consolidation LLM calls (key in `models`).
|
/// Model name for consolidation LLM calls (key in `models`).
|
||||||
|
/// If not set, falls back to the main agent's model.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub consolidation_model: Option<String>,
|
pub consolidation_model: Option<String>,
|
||||||
/// Max knowledge entries injected into system prompt per turn.
|
/// Max knowledge entries injected into system prompt per turn.
|
||||||
@ -248,8 +247,34 @@ pub struct MemoryConfig {
|
|||||||
pub max_failures_before_degrade: usize,
|
pub max_failures_before_degrade: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for MemoryConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
consolidation_provider: None,
|
||||||
|
consolidation_model: None,
|
||||||
|
recall_limit: 5,
|
||||||
|
consolidation_turn_threshold: 10,
|
||||||
|
idle_consolidation_minutes: 10,
|
||||||
|
timeline_retention_days: 90,
|
||||||
|
max_failures_before_degrade: 3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryConfig {
|
||||||
|
/// Resolve consolidation provider name, falling back to the main agent's provider.
|
||||||
|
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
|
||||||
|
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve consolidation model name, falling back to the main agent's model.
|
||||||
|
pub fn resolve_consolidation_model(&self, default: &str) -> String {
|
||||||
|
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn default_recall_limit() -> usize { 5 }
|
fn default_recall_limit() -> usize { 5 }
|
||||||
fn default_consolidation_turn_threshold() -> usize { 3 }
|
fn default_consolidation_turn_threshold() -> usize { 10 }
|
||||||
fn default_idle_consolidation_minutes() -> u64 { 10 }
|
fn default_idle_consolidation_minutes() -> u64 { 10 }
|
||||||
fn default_timeline_retention_days() -> u64 { 90 }
|
fn default_timeline_retention_days() -> u64 { 90 }
|
||||||
fn default_max_failures_before_degrade() -> usize { 3 }
|
fn default_max_failures_before_degrade() -> usize { 3 }
|
||||||
|
|||||||
@ -56,14 +56,23 @@ impl GatewayState {
|
|||||||
);
|
);
|
||||||
tracing::info!("Session storage: {}", db_path.display());
|
tracing::info!("Session storage: {}", db_path.display());
|
||||||
|
|
||||||
// Initialize MemoryManager if memory system is enabled
|
// Resolve consolidation provider/model with fallback to main agent config
|
||||||
let memory_manager = if config.memory.enabled {
|
let consolidation_provider = config
|
||||||
let mm = Arc::new(MemoryManager::new(storage.clone()));
|
.memory
|
||||||
tracing::info!("Memory system enabled");
|
.resolve_consolidation_provider(&provider_config.name);
|
||||||
Some(mm)
|
let consolidation_model = config
|
||||||
} else {
|
.memory
|
||||||
None
|
.resolve_consolidation_model(&provider_config.model_id);
|
||||||
};
|
let memory_manager = Arc::new(MemoryManager::new(
|
||||||
|
storage.clone(),
|
||||||
|
consolidation_provider,
|
||||||
|
consolidation_model,
|
||||||
|
));
|
||||||
|
tracing::info!(
|
||||||
|
consolidation_provider = %memory_manager.consolidation_provider,
|
||||||
|
consolidation_model = %memory_manager.consolidation_model,
|
||||||
|
"Memory system initialized"
|
||||||
|
);
|
||||||
|
|
||||||
// Create MessageBus first (shared by SessionManager and ChannelManager)
|
// Create MessageBus first (shared by SessionManager and ChannelManager)
|
||||||
let bus = MessageBus::new(100);
|
let bus = MessageBus::new(100);
|
||||||
|
|||||||
@ -11,11 +11,21 @@ pub use types::{ConsolidationFact, ConsolidationResult, MemoryCategory, MemoryEn
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct MemoryManager {
|
pub struct MemoryManager {
|
||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
|
pub consolidation_provider: String,
|
||||||
|
pub consolidation_model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MemoryManager {
|
impl MemoryManager {
|
||||||
pub fn new(storage: Arc<Storage>) -> Self {
|
pub fn new(
|
||||||
Self { storage }
|
storage: Arc<Storage>,
|
||||||
|
consolidation_provider: String,
|
||||||
|
consolidation_model: String,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
storage,
|
||||||
|
consolidation_provider,
|
||||||
|
consolidation_model,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Store or update a memory entry. Generates timestamp and UUID.
|
/// Store or update a memory entry. Generates timestamp and UUID.
|
||||||
@ -58,11 +68,12 @@ impl MemoryManager {
|
|||||||
&self,
|
&self,
|
||||||
since: i64,
|
since: i64,
|
||||||
until: i64,
|
until: i64,
|
||||||
|
query: Option<&str>,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
category: Option<MemoryCategory>,
|
category: Option<MemoryCategory>,
|
||||||
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
|
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
|
||||||
self.storage
|
self.storage
|
||||||
.search_memories_by_time(since, until, category.as_ref(), limit)
|
.search_memories_by_time(since, until, query, category.as_ref(), limit)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +98,7 @@ mod tests {
|
|||||||
let dir = tempdir().unwrap();
|
let dir = tempdir().unwrap();
|
||||||
let db_path = dir.path().join("test.db");
|
let db_path = dir.path().join("test.db");
|
||||||
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
||||||
let mm = Arc::new(MemoryManager::new(storage));
|
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
|
||||||
(mm, dir)
|
(mm, dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -202,7 +202,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
} else {
|
} else {
|
||||||
let mut blocks = convert_content_blocks(&m.content);
|
let mut blocks = convert_content_blocks(&m.content);
|
||||||
// Append tool_use blocks from assistant messages with tool calls
|
// Append tool_use blocks from assistant messages with tool calls
|
||||||
if let Some(ref tool_calls) = m.tool_calls {
|
if let Some(tool_calls) = m.tool_calls.as_ref().filter(|c| !c.is_empty()) {
|
||||||
for tc in tool_calls {
|
for tc in tool_calls {
|
||||||
blocks.push(serde_json::json!({
|
blocks.push(serde_json::json!({
|
||||||
"type": "tool_use",
|
"type": "tool_use",
|
||||||
|
|||||||
@ -77,7 +77,7 @@ impl OpenAIProvider {
|
|||||||
"tool_call_id": m.tool_call_id,
|
"tool_call_id": m.tool_call_id,
|
||||||
"name": m.name,
|
"name": m.name,
|
||||||
})
|
})
|
||||||
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
} else if m.role == "assistant" && m.tool_calls.as_ref().map_or(false, |c| !c.is_empty()) {
|
||||||
json!({
|
json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": convert_content_blocks(&m.content),
|
"content": convert_content_blocks(&m.content),
|
||||||
|
|||||||
@ -147,7 +147,7 @@ impl Scheduler {
|
|||||||
let _ = self.bus.publish_outbound(outbound).await;
|
let _ = self.bus.publish_outbound(outbound).await;
|
||||||
|
|
||||||
let output_truncated = if output.len() > 8000 {
|
let output_truncated = if output.len() > 8000 {
|
||||||
format!("{}...[truncated]", &output[..8000])
|
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
|
||||||
} else {
|
} else {
|
||||||
output.clone()
|
output.clone()
|
||||||
};
|
};
|
||||||
|
|||||||
@ -58,7 +58,7 @@ pub struct Session {
|
|||||||
/// Timestamp (Unix ms) of the last consolidation.
|
/// Timestamp (Unix ms) of the last consolidation.
|
||||||
/// Messages before this time have been compressed into memory.
|
/// Messages before this time have been compressed into memory.
|
||||||
pub last_consolidated_at: Option<i64>,
|
pub last_consolidated_at: Option<i64>,
|
||||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
@ -69,7 +69,7 @@ impl Session {
|
|||||||
storage: Option<StdArc<Storage>>,
|
storage: Option<StdArc<Storage>>,
|
||||||
routing_info: String,
|
routing_info: String,
|
||||||
title: String,
|
title: String,
|
||||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let mut provider_box = create_provider(provider_config.clone())
|
let mut provider_box = create_provider(provider_config.clone())
|
||||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||||
@ -83,11 +83,8 @@ impl Session {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config);
|
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
|
||||||
if let Some(ref mm) = memory_manager {
|
compressor.set_session_id(Some(id.to_string()));
|
||||||
compressor = compressor.with_memory(mm.clone());
|
|
||||||
compressor.set_session_id(Some(id.to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
|
||||||
@ -117,7 +114,7 @@ impl Session {
|
|||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
storage: StdArc<Storage>,
|
storage: StdArc<Storage>,
|
||||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let session_meta = storage.get_session(&id.to_string()).await
|
let session_meta = storage.get_session(&id.to_string()).await
|
||||||
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
||||||
@ -135,28 +132,28 @@ impl Session {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config);
|
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
|
||||||
if let Some(ref mm) = memory_manager {
|
compressor.set_session_id(Some(id.to_string()));
|
||||||
compressor = compressor.with_memory(mm.clone());
|
|
||||||
compressor.set_session_id(Some(id.to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert MessageMeta to ChatMessage
|
// Convert MessageMeta to ChatMessage, then repair damaged tool call chains
|
||||||
// Clear tool_call_id/tool_name — they're not valid across API sessions
|
let mut chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
|
||||||
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
|
|
||||||
ChatMessage {
|
ChatMessage {
|
||||||
id: m.id,
|
id: m.id,
|
||||||
role: m.role,
|
role: m.role,
|
||||||
content: m.content,
|
content: m.content,
|
||||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||||
timestamp: m.created_at,
|
timestamp: m.created_at,
|
||||||
tool_call_id: None,
|
tool_call_id: m.tool_call_id,
|
||||||
tool_name: None,
|
tool_name: m.tool_name,
|
||||||
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
|
tool_calls: m.tool_calls
|
||||||
|
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||||
|
.filter(|v| !v.is_empty()),
|
||||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||||
}
|
}
|
||||||
}).collect();
|
}).collect();
|
||||||
|
|
||||||
|
repair_tool_call_chains(&mut chat_messages);
|
||||||
|
|
||||||
let seq_counter = chat_messages.len() as i64 + 1;
|
let seq_counter = chat_messages.len() as i64 + 1;
|
||||||
let total_message_count = chat_messages.len() as i64;
|
let total_message_count = chat_messages.len() as i64;
|
||||||
|
|
||||||
@ -211,7 +208,7 @@ impl Session {
|
|||||||
},
|
},
|
||||||
tool_call_id: message.tool_call_id.clone(),
|
tool_call_id: message.tool_call_id.clone(),
|
||||||
tool_name: message.tool_name.clone(),
|
tool_name: message.tool_name.clone(),
|
||||||
tool_calls: message.tool_calls.as_ref().map(|tc| serde_json::to_string(tc).unwrap_or_default()),
|
tool_calls: message.tool_calls.as_ref().and_then(|tc| serde_json::to_string(tc).ok()),
|
||||||
source: message.source.as_ref().map(|s| serde_json::to_string(s).unwrap_or_default()),
|
source: message.source.as_ref().map(|s| serde_json::to_string(s).unwrap_or_default()),
|
||||||
created_at: now,
|
created_at: now,
|
||||||
};
|
};
|
||||||
@ -574,6 +571,67 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Repair damaged tool call chains after restoring from storage.
|
||||||
|
/// Handles cases where the gateway crashed mid-loop, leaving assistant
|
||||||
|
/// tool_calls without corresponding tool result messages.
|
||||||
|
fn repair_tool_call_chains(messages: &mut Vec<ChatMessage>) {
|
||||||
|
let mut i = 0;
|
||||||
|
while i < messages.len() {
|
||||||
|
let calls = match &messages[i].tool_calls {
|
||||||
|
Some(calls) if !calls.is_empty() => calls.clone(),
|
||||||
|
_ => {
|
||||||
|
i += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if messages[i].role != "assistant" {
|
||||||
|
i += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect expected tool call IDs
|
||||||
|
let expected_ids: std::collections::HashSet<&str> = calls.iter().map(|c| c.id.as_str()).collect();
|
||||||
|
let expected_count = expected_ids.len();
|
||||||
|
|
||||||
|
// Check following messages for matching tool results (same tool_call_id)
|
||||||
|
let mut found = 0;
|
||||||
|
let mut j = i + 1;
|
||||||
|
while j < messages.len() && found < expected_count {
|
||||||
|
if messages[j].role == "tool" {
|
||||||
|
if let Some(ref tc_id) = messages[j].tool_call_id {
|
||||||
|
if expected_ids.contains(tc_id.as_str()) {
|
||||||
|
found += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if messages[j].role == "user" || messages[j].role == "assistant" {
|
||||||
|
// Next user/assistant message — stop scanning, chain is broken
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
j += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if found < expected_count {
|
||||||
|
// Incomplete chain: remove tool_calls and add interruption note
|
||||||
|
tracing::warn!(
|
||||||
|
found,
|
||||||
|
expected = expected_count,
|
||||||
|
"Repairing incomplete tool call chain — gateway restart likely interrupted execution"
|
||||||
|
);
|
||||||
|
let old_content = std::mem::take(&mut messages[i].content);
|
||||||
|
messages[i].content = format!(
|
||||||
|
"{}\n\n[Tool calls ({}): {} — execution interrupted by gateway restart]",
|
||||||
|
old_content,
|
||||||
|
expected_count,
|
||||||
|
calls.iter().map(|c| c.name.as_str()).collect::<Vec<_>>().join(", ")
|
||||||
|
);
|
||||||
|
messages[i].tool_calls = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SessionManager {
|
pub struct SessionManager {
|
||||||
@ -584,7 +642,7 @@ pub struct SessionManager {
|
|||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
bus: Arc<MessageBus>,
|
bus: Arc<MessageBus>,
|
||||||
current_source_session: Arc<Mutex<Option<String>>>,
|
current_source_session: Arc<Mutex<Option<String>>>,
|
||||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SessionManagerInner {
|
struct SessionManagerInner {
|
||||||
@ -672,7 +730,7 @@ impl SessionManager {
|
|||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
bus: Arc<MessageBus>,
|
bus: Arc<MessageBus>,
|
||||||
memory_manager: Option<Arc<crate::memory::MemoryManager>>,
|
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let skills_loader = SkillsLoader::new();
|
let skills_loader = SkillsLoader::new();
|
||||||
skills_loader.load_skills();
|
skills_loader.load_skills();
|
||||||
@ -1252,22 +1310,18 @@ impl SessionManager {
|
|||||||
let skills_prompt = self.skills_loader.build_skills_prompt();
|
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||||||
|
|
||||||
// Fetch memory context
|
// Fetch memory context
|
||||||
let memory_context = if let Some(ref mm) = self.memory_manager {
|
let memory_context = match self.memory_manager.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await {
|
||||||
match mm.recall(&content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await {
|
Ok(entries) if !entries.is_empty() => {
|
||||||
Ok(entries) if !entries.is_empty() => {
|
Some(entries.iter()
|
||||||
Some(entries.iter()
|
.map(|e| format!("- {}: {}", e.key, e.content))
|
||||||
.map(|e| format!("- {}: {}", e.key, e.content))
|
.collect::<Vec<_>>()
|
||||||
.collect::<Vec<_>>()
|
.join("\n"))
|
||||||
.join("\n"))
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(error = %e, "Failed to fetch memory context");
|
|
||||||
None
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
}
|
}
|
||||||
} else {
|
Err(e) => {
|
||||||
None
|
tracing::warn!(error = %e, "Failed to fetch memory context");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build combined system prompt and inject at position 0
|
// Build combined system prompt and inject at position 0
|
||||||
@ -1409,11 +1463,9 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let raw_response = result.final_response.content;
|
let raw_response = result.final_response.content;
|
||||||
|
|
||||||
let target_id = unified_id.to_string();
|
|
||||||
let prefix = format!(
|
let prefix = format!(
|
||||||
"[message from cron:{}({}) to {}]\n",
|
"[message from cron:{}({})]\n",
|
||||||
job_name, job_id, target_id
|
job_name, job_id
|
||||||
);
|
);
|
||||||
let prefixed_response = format!("{}{}", prefix, raw_response);
|
let prefixed_response = format!("{}{}", prefix, raw_response);
|
||||||
|
|
||||||
@ -1491,11 +1543,10 @@ impl OutboundMessenger for SessionManager {
|
|||||||
(sid, session)
|
(sid, session)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build message prefix: [message from <origin> to <channel:chat_id:dialog_id>]
|
// Build message prefix: [message from <origin>]
|
||||||
let target_id = target_sid.to_string();
|
|
||||||
let origin = source.from_session.as_deref().unwrap_or("unknown");
|
let origin = source.from_session.as_deref().unwrap_or("unknown");
|
||||||
let origin_id = source.from_session.clone();
|
let origin_id = source.from_session.clone();
|
||||||
let prefix = format!("[message from {} to {}] ", origin, target_id);
|
let prefix = format!("[message from {}] ", origin);
|
||||||
let marked_content = format!("{}\n{}", prefix, content);
|
let marked_content = format!("{}\n{}", prefix, content);
|
||||||
|
|
||||||
// Write source-tagged assistant message to target session history
|
// Write source-tagged assistant message to target session history
|
||||||
|
|||||||
@ -1,9 +1,17 @@
|
|||||||
use sqlx::Row;
|
use sqlx::Row;
|
||||||
|
use std::sync::OnceLock;
|
||||||
|
|
||||||
|
use jieba_rs::Jieba;
|
||||||
|
|
||||||
use crate::memory::{MemoryCategory, MemoryEntry};
|
use crate::memory::{MemoryCategory, MemoryEntry};
|
||||||
|
|
||||||
use super::StorageError;
|
use super::StorageError;
|
||||||
|
|
||||||
|
fn jieba() -> &'static Jieba {
|
||||||
|
static INSTANCE: OnceLock<Jieba> = OnceLock::new();
|
||||||
|
INSTANCE.get_or_init(Jieba::new)
|
||||||
|
}
|
||||||
|
|
||||||
impl super::Storage {
|
impl super::Storage {
|
||||||
/// Store or update a memory entry (upsert by key).
|
/// Store or update a memory entry (upsert by key).
|
||||||
pub async fn upsert_memory(&self, entry: &MemoryEntry) -> Result<(), StorageError> {
|
pub async fn upsert_memory(&self, entry: &MemoryEntry) -> Result<(), StorageError> {
|
||||||
@ -50,9 +58,11 @@ impl super::Storage {
|
|||||||
category: Option<&MemoryCategory>,
|
category: Option<&MemoryCategory>,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
// Build FTS5 query: wrap each word in quotes and join with OR
|
// Build FTS5 query: segment with jieba, wrap each term in quotes, join with OR
|
||||||
let fts_query = query
|
let fts_query = jieba()
|
||||||
.split_whitespace()
|
.cut(query, true)
|
||||||
|
.into_iter()
|
||||||
|
.filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
|
||||||
.map(|w| format!("\"{}\"", w.replace('"', "")))
|
.map(|w| format!("\"{}\"", w.replace('"', "")))
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(" OR ");
|
.join(" OR ");
|
||||||
@ -80,39 +90,59 @@ impl super::Storage {
|
|||||||
|
|
||||||
let mut entries = parse_memory_rows(&rows)?;
|
let mut entries = parse_memory_rows(&rows)?;
|
||||||
|
|
||||||
// Fallback to LIKE if FTS5 returned nothing
|
// Fallback to term-based LIKE query if FTS5 returned nothing
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
let like_pattern = format!("%{}%", query.replace('%', "").replace('_', ""));
|
let terms: Vec<String> = jieba()
|
||||||
let rows = sqlx::query(
|
.cut(query, true)
|
||||||
r#"
|
.into_iter()
|
||||||
SELECT id, key, content, category, importance,
|
.filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
|
||||||
session_id, created_at, updated_at
|
.map(|w| w.replace('%', "").replace('_', ""))
|
||||||
FROM memories
|
.collect();
|
||||||
WHERE (key LIKE ? OR content LIKE ?)
|
|
||||||
AND (? IS NULL OR category = ?)
|
|
||||||
ORDER BY importance DESC, updated_at DESC
|
|
||||||
LIMIT ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&like_pattern)
|
|
||||||
.bind(&like_pattern)
|
|
||||||
.bind(category_filter)
|
|
||||||
.bind(category_filter)
|
|
||||||
.bind(limit as i64)
|
|
||||||
.fetch_all(self.pool())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
entries = parse_memory_rows(&rows)?;
|
if !terms.is_empty() {
|
||||||
|
let like_clauses = terms
|
||||||
|
.iter()
|
||||||
|
.map(|_| "(key LIKE ? OR content LIKE ?)")
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" OR ");
|
||||||
|
|
||||||
|
let sql = format!(
|
||||||
|
r#"
|
||||||
|
SELECT id, key, content, category, importance,
|
||||||
|
session_id, created_at, updated_at
|
||||||
|
FROM memories
|
||||||
|
WHERE ({})
|
||||||
|
AND (? IS NULL OR category = ?)
|
||||||
|
ORDER BY importance DESC, updated_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
like_clauses
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut query_builder = sqlx::query(&sql);
|
||||||
|
for term in &terms {
|
||||||
|
let pattern = format!("%{}%", term);
|
||||||
|
query_builder = query_builder.bind(pattern.clone()).bind(pattern);
|
||||||
|
}
|
||||||
|
query_builder = query_builder
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(limit as i64);
|
||||||
|
|
||||||
|
let rows = query_builder.fetch_all(self.pool()).await?;
|
||||||
|
entries = parse_memory_rows(&rows)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(entries)
|
Ok(entries)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieve memories within a time range.
|
/// Retrieve memories within a time range, optionally filtered by keyword query.
|
||||||
pub async fn search_memories_by_time(
|
pub async fn search_memories_by_time(
|
||||||
&self,
|
&self,
|
||||||
since: i64,
|
since: i64,
|
||||||
until: i64,
|
until: i64,
|
||||||
|
query: Option<&str>,
|
||||||
category: Option<&MemoryCategory>,
|
category: Option<&MemoryCategory>,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
@ -124,24 +154,71 @@ impl super::Storage {
|
|||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
.to_rfc3339();
|
.to_rfc3339();
|
||||||
|
|
||||||
let rows = sqlx::query(
|
let rows = if let Some(q) = query {
|
||||||
r#"
|
let terms: Vec<String> = jieba()
|
||||||
SELECT id, key, content, category, importance,
|
.cut(q, true)
|
||||||
session_id, created_at, updated_at
|
.into_iter()
|
||||||
FROM memories
|
.filter(|w| w.len() > 1 || w.bytes().any(|b| b > 127))
|
||||||
WHERE created_at >= ? AND created_at <= ?
|
.map(|w| w.replace('%', "").replace('_', ""))
|
||||||
AND (? IS NULL OR category = ?)
|
.collect();
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT ?
|
if terms.is_empty() {
|
||||||
"#,
|
return Ok(Vec::new());
|
||||||
)
|
}
|
||||||
.bind(&since_dt)
|
|
||||||
.bind(&until_dt)
|
let like_clauses = terms
|
||||||
.bind(category_filter)
|
.iter()
|
||||||
.bind(category_filter)
|
.map(|_| "(key LIKE ? OR content LIKE ?)")
|
||||||
.bind(limit as i64)
|
.collect::<Vec<_>>()
|
||||||
.fetch_all(self.pool())
|
.join(" OR ");
|
||||||
.await?;
|
|
||||||
|
let sql = format!(
|
||||||
|
r#"
|
||||||
|
SELECT id, key, content, category, importance,
|
||||||
|
session_id, created_at, updated_at
|
||||||
|
FROM memories
|
||||||
|
WHERE ({})
|
||||||
|
AND created_at >= ? AND created_at <= ?
|
||||||
|
AND (? IS NULL OR category = ?)
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
like_clauses
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut query_builder = sqlx::query(&sql);
|
||||||
|
for term in &terms {
|
||||||
|
let pattern = format!("%{}%", term);
|
||||||
|
query_builder = query_builder.bind(pattern.clone()).bind(pattern);
|
||||||
|
}
|
||||||
|
query_builder = query_builder
|
||||||
|
.bind(&since_dt)
|
||||||
|
.bind(&until_dt)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(limit as i64);
|
||||||
|
|
||||||
|
query_builder.fetch_all(self.pool()).await?
|
||||||
|
} else {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, key, content, category, importance,
|
||||||
|
session_id, created_at, updated_at
|
||||||
|
FROM memories
|
||||||
|
WHERE created_at >= ? AND created_at <= ?
|
||||||
|
AND (? IS NULL OR category = ?)
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&since_dt)
|
||||||
|
.bind(&until_dt)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(category_filter)
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
parse_memory_rows(&rows)
|
parse_memory_rows(&rows)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -127,6 +127,48 @@ impl Storage {
|
|||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Triggers to keep FTS5 index in sync with memories table
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
|
||||||
|
INSERT INTO memory_fts(rowid, key, content) VALUES (new.rowid, new.key, new.content);
|
||||||
|
END
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
|
||||||
|
INSERT INTO memory_fts(memory_fts, rowid, key, content)
|
||||||
|
VALUES ('delete', old.rowid, old.key, old.content);
|
||||||
|
END
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
|
||||||
|
INSERT INTO memory_fts(memory_fts, rowid, key, content)
|
||||||
|
VALUES ('delete', old.rowid, old.key, old.content);
|
||||||
|
INSERT INTO memory_fts(rowid, key, content)
|
||||||
|
VALUES (new.rowid, new.key, new.content);
|
||||||
|
END
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Rebuild FTS5 index for any existing records
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')",
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
// Migration: add last_consolidated_at column if not exists
|
// Migration: add last_consolidated_at column if not exists
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
|
|||||||
@ -68,9 +68,9 @@ impl BashTool {
|
|||||||
let half = MAX_OUTPUT_CHARS / 2;
|
let half = MAX_OUTPUT_CHARS / 2;
|
||||||
format!(
|
format!(
|
||||||
"{}...\n\n(... {} chars truncated ...)\n\n{}",
|
"{}...\n\n(... {} chars truncated ...)\n\n{}",
|
||||||
&output[..half],
|
&output[..output.ceil_char_boundary(half)],
|
||||||
output.len() - MAX_OUTPUT_CHARS,
|
output.len() - MAX_OUTPUT_CHARS,
|
||||||
&output[output.len() - half..]
|
&output[output.floor_char_boundary(output.len() - half)..]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -101,7 +101,7 @@ impl HttpRequestTool {
|
|||||||
if text.len() > self.max_response_size {
|
if text.len() > self.max_response_size {
|
||||||
format!(
|
format!(
|
||||||
"{}\n\n... [Response truncated due to size limit] ...",
|
"{}\n\n... [Response truncated due to size limit] ...",
|
||||||
&text[..self.max_response_size]
|
&text[..text.ceil_char_boundary(self.max_response_size)]
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
text.to_string()
|
text.to_string()
|
||||||
|
|||||||
265
src/tools/memory.rs
Normal file
265
src/tools/memory.rs
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
use super::traits::{Tool, ToolResult};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::memory::{MemoryCategory, MemoryManager};
|
||||||
|
|
||||||
|
// ── MemoryStoreTool ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct MemoryStoreTool {
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryStoreTool {
|
||||||
|
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||||
|
Self { memory }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MemoryStoreTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"memory_store"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Store a fact, preference, or insight into long-term memory. \
|
||||||
|
Use this when the user shares important information you should remember. \
|
||||||
|
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
|
||||||
|
and the full content to remember."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key."
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The full content of the memory entry."
|
||||||
|
},
|
||||||
|
"category": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["knowledge", "timeline"],
|
||||||
|
"description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
|
||||||
|
},
|
||||||
|
"importance": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["key", "content"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let key = args
|
||||||
|
.get("key")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
||||||
|
|
||||||
|
let content = args
|
||||||
|
.get("content")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
|
||||||
|
|
||||||
|
let category = args
|
||||||
|
.get("category")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(MemoryCategory::from_str)
|
||||||
|
.unwrap_or(MemoryCategory::Knowledge);
|
||||||
|
|
||||||
|
let importance = args.get("importance").and_then(|v| v.as_f64());
|
||||||
|
|
||||||
|
self.memory
|
||||||
|
.store(key, content, category, None, importance)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Memory stored: {}", key),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── MemoryRecallTool ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct MemoryRecallTool {
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryRecallTool {
|
||||||
|
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||||
|
Self { memory }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MemoryRecallTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"memory_recall"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Search and retrieve entries from long-term memory using keyword matching. \
|
||||||
|
Use this to recall previously stored facts, preferences, or conversation history. \
|
||||||
|
IMPORTANT: query must be a space-separated list of RELEVANT KEYWORDS (not a question or sentence). \
|
||||||
|
Use multiple synonymous or related terms to increase recall. \
|
||||||
|
Example: instead of 'what is the user location', use 'user location address city residence'. \
|
||||||
|
Supports optional time-range filtering via since/until (Unix ms)."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Space-separated KEYWORDS for memory search (NOT a natural language question). Use multiple related terms for better recall, e.g. 'address city location residence'."
|
||||||
|
},
|
||||||
|
"category": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["knowledge", "timeline"],
|
||||||
|
"description": "Filter by memory category. Omit to search all categories."
|
||||||
|
},
|
||||||
|
"since": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Start of time range (Unix milliseconds)."
|
||||||
|
},
|
||||||
|
"until": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "End of time range (Unix milliseconds)."
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Max results to return (default 10)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let query = args
|
||||||
|
.get("query")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
|
||||||
|
|
||||||
|
let category = args
|
||||||
|
.get("category")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(MemoryCategory::from_str);
|
||||||
|
|
||||||
|
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
|
||||||
|
|
||||||
|
let entries = if args.get("since").is_some() || args.get("until").is_some() {
|
||||||
|
let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0);
|
||||||
|
let until = args
|
||||||
|
.get("until")
|
||||||
|
.and_then(|v| v.as_i64())
|
||||||
|
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||||
|
self.memory
|
||||||
|
.recall_by_time(since, until, Some(query), limit, category)
|
||||||
|
.await?
|
||||||
|
} else {
|
||||||
|
self.memory.recall(query, limit, category).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
if entries.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: "No matching memories found.".to_string(),
|
||||||
|
error: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let formatted = entries
|
||||||
|
.iter()
|
||||||
|
.map(|e| {
|
||||||
|
format!(
|
||||||
|
"- {} [{}] [importance: {:.1}]: {}",
|
||||||
|
e.key,
|
||||||
|
e.category.as_str(),
|
||||||
|
e.importance,
|
||||||
|
e.content
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Found {} memories:\n{}", entries.len(), formatted),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── MemoryForgetTool ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct MemoryForgetTool {
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryForgetTool {
|
||||||
|
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||||
|
Self { memory }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MemoryForgetTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"memory_forget"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Delete a memory entry by its key. Use this when information is outdated, \
|
||||||
|
incorrect, or the user asks to forget something."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The key of the memory entry to delete."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["key"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let key = args
|
||||||
|
.get("key")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
||||||
|
|
||||||
|
self.memory.forget(key).await?;
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Memory deleted: {}", key),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,60 +0,0 @@
|
|||||||
use super::traits::{Tool, ToolResult};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::json;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crate::memory::MemoryManager;
|
|
||||||
|
|
||||||
pub struct MemoryForgetTool {
|
|
||||||
memory: Arc<MemoryManager>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MemoryForgetTool {
|
|
||||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
|
||||||
Self { memory }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for MemoryForgetTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"memory_forget"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Delete a memory entry by its key. Use this when information is outdated, \
|
|
||||||
incorrect, or the user asks to forget something."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_only(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"key": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The key of the memory entry to delete."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["key"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let key = args
|
|
||||||
.get("key")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
|
||||||
|
|
||||||
self.memory.forget(key).await?;
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: format!("Memory deleted: {}", key),
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,118 +0,0 @@
|
|||||||
use super::traits::{Tool, ToolResult};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::json;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crate::memory::{MemoryCategory, MemoryManager};
|
|
||||||
|
|
||||||
pub struct MemoryRecallTool {
|
|
||||||
memory: Arc<MemoryManager>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MemoryRecallTool {
|
|
||||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
|
||||||
Self { memory }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for MemoryRecallTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"memory_recall"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Search and retrieve entries from long-term memory. \
|
|
||||||
Use this to recall previously stored facts, preferences, or conversation history. \
|
|
||||||
Supports keyword search and optional time-range filtering."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_only(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Search query — keywords to match against memory keys and content."
|
|
||||||
},
|
|
||||||
"category": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["knowledge", "timeline"],
|
|
||||||
"description": "Filter by memory category. Omit to search all categories."
|
|
||||||
},
|
|
||||||
"since": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Start of time range (Unix milliseconds)."
|
|
||||||
},
|
|
||||||
"until": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "End of time range (Unix milliseconds)."
|
|
||||||
},
|
|
||||||
"limit": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Max results to return (default 10)."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["query"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let query = args
|
|
||||||
.get("query")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
|
|
||||||
|
|
||||||
let category = args
|
|
||||||
.get("category")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.and_then(MemoryCategory::from_str);
|
|
||||||
|
|
||||||
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
|
|
||||||
|
|
||||||
let entries = if args.get("since").is_some() || args.get("until").is_some() {
|
|
||||||
let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0);
|
|
||||||
let until = args
|
|
||||||
.get("until")
|
|
||||||
.and_then(|v| v.as_i64())
|
|
||||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
|
||||||
self.memory
|
|
||||||
.recall_by_time(since, until, limit, category)
|
|
||||||
.await?
|
|
||||||
} else {
|
|
||||||
self.memory.recall(query, limit, category).await?
|
|
||||||
};
|
|
||||||
|
|
||||||
if entries.is_empty() {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: "No matching memories found.".to_string(),
|
|
||||||
error: None,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let formatted = entries
|
|
||||||
.iter()
|
|
||||||
.map(|e| {
|
|
||||||
format!(
|
|
||||||
"- {} [{}] [importance: {:.1}]: {}",
|
|
||||||
e.key,
|
|
||||||
e.category.as_str(),
|
|
||||||
e.importance,
|
|
||||||
e.content
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: format!("Found {} memories:\n{}", entries.len(), formatted),
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,90 +0,0 @@
|
|||||||
use super::traits::{Tool, ToolResult};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::json;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crate::memory::{MemoryCategory, MemoryManager};
|
|
||||||
|
|
||||||
pub struct MemoryStoreTool {
|
|
||||||
memory: Arc<MemoryManager>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MemoryStoreTool {
|
|
||||||
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
|
||||||
Self { memory }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for MemoryStoreTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"memory_store"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Store a fact, preference, or insight into long-term memory. \
|
|
||||||
Use this when the user shares important information you should remember. \
|
|
||||||
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
|
|
||||||
and the full content to remember."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_only(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"key": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Semantic identifier for this memory (e.g., 'user_language_pref'). Unique key."
|
|
||||||
},
|
|
||||||
"content": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The full content of the memory entry."
|
|
||||||
},
|
|
||||||
"category": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["knowledge", "timeline"],
|
|
||||||
"description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
|
|
||||||
},
|
|
||||||
"importance": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["key", "content"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let key = args
|
|
||||||
.get("key")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: key"))?;
|
|
||||||
|
|
||||||
let content = args
|
|
||||||
.get("content")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
|
|
||||||
|
|
||||||
let category = args
|
|
||||||
.get("category")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.and_then(MemoryCategory::from_str)
|
|
||||||
.unwrap_or(MemoryCategory::Knowledge);
|
|
||||||
|
|
||||||
let importance = args.get("importance").and_then(|v| v.as_f64());
|
|
||||||
|
|
||||||
self.memory
|
|
||||||
.store(key, content, category, None, importance)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: format!("Memory stored: {}", key),
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -7,9 +7,7 @@ pub mod file_read;
|
|||||||
pub mod file_write;
|
pub mod file_write;
|
||||||
pub mod get_skill;
|
pub mod get_skill;
|
||||||
pub mod http_request;
|
pub mod http_request;
|
||||||
pub mod memory_forget;
|
pub mod memory;
|
||||||
pub mod memory_recall;
|
|
||||||
pub mod memory_store;
|
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod schema;
|
pub mod schema;
|
||||||
pub mod send_message;
|
pub mod send_message;
|
||||||
@ -24,9 +22,7 @@ pub use file_read::FileReadTool;
|
|||||||
pub use file_write::FileWriteTool;
|
pub use file_write::FileWriteTool;
|
||||||
pub use get_skill::GetSkillTool;
|
pub use get_skill::GetSkillTool;
|
||||||
pub use http_request::HttpRequestTool;
|
pub use http_request::HttpRequestTool;
|
||||||
pub use memory_forget::MemoryForgetTool;
|
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool};
|
||||||
pub use memory_recall::MemoryRecallTool;
|
|
||||||
pub use memory_store::MemoryStoreTool;
|
|
||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
pub use send_message::SendMessageTool;
|
pub use send_message::SendMessageTool;
|
||||||
@ -42,7 +38,7 @@ use crate::skills::SkillsLoader;
|
|||||||
/// once the available channel names are known.
|
/// once the available channel names are known.
|
||||||
pub fn create_default_tools(
|
pub fn create_default_tools(
|
||||||
skills_loader: Arc<SkillsLoader>,
|
skills_loader: Arc<SkillsLoader>,
|
||||||
memory: Option<Arc<MemoryManager>>,
|
memory: Arc<MemoryManager>,
|
||||||
) -> ToolRegistry {
|
) -> ToolRegistry {
|
||||||
let registry = ToolRegistry::new();
|
let registry = ToolRegistry::new();
|
||||||
registry.register(CalculatorTool::new());
|
registry.register(CalculatorTool::new());
|
||||||
@ -59,12 +55,9 @@ pub fn create_default_tools(
|
|||||||
registry.register(WebFetchTool::new(50_000, 30));
|
registry.register(WebFetchTool::new(50_000, 30));
|
||||||
registry.register(GetSkillTool::new(skills_loader));
|
registry.register(GetSkillTool::new(skills_loader));
|
||||||
|
|
||||||
// Register memory tools if memory system is available
|
registry.register(MemoryStoreTool::new(memory.clone()));
|
||||||
if let Some(mm) = memory {
|
registry.register(MemoryRecallTool::new(memory.clone()));
|
||||||
registry.register(MemoryStoreTool::new(mm.clone()));
|
registry.register(MemoryForgetTool::new(memory.clone()));
|
||||||
registry.register(MemoryRecallTool::new(mm.clone()));
|
|
||||||
registry.register(MemoryForgetTool::new(mm.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
registry
|
registry
|
||||||
}
|
}
|
||||||
|
|||||||
@ -53,7 +53,7 @@ impl WebFetchTool {
|
|||||||
if text.len() > self.max_response_size {
|
if text.len() > self.max_response_size {
|
||||||
format!(
|
format!(
|
||||||
"{}\n\n... [Response truncated due to size limit] ...",
|
"{}\n\n... [Response truncated due to size limit] ...",
|
||||||
&text[..self.max_response_size]
|
&text[..text.ceil_char_boundary(self.max_response_size)]
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
text.to_string()
|
text.to_string()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user