Compare commits
2 Commits
488e10dceb
...
cb1140e9be
| Author | SHA1 | Date | |
|---|---|---|---|
| cb1140e9be | |||
| f9ae4b2c69 |
@ -1,3 +1,4 @@
|
|||||||
|
use crate::agent::context_compressor::estimate_tokens;
|
||||||
use crate::agent::system_prompt::build_system_prompt;
|
use crate::agent::system_prompt::build_system_prompt;
|
||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
@ -226,6 +227,7 @@ pub struct AgentLoop {
|
|||||||
max_iterations: usize,
|
max_iterations: usize,
|
||||||
workspace_dir: PathBuf,
|
workspace_dir: PathBuf,
|
||||||
model_name: String,
|
model_name: String,
|
||||||
|
context_window: usize,
|
||||||
notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
|
notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,6 +251,7 @@ impl AgentLoop {
|
|||||||
tools: Arc::new(ToolRegistry::new()),
|
tools: Arc::new(ToolRegistry::new()),
|
||||||
observer: None,
|
observer: None,
|
||||||
notify_tx: None,
|
notify_tx: None,
|
||||||
|
context_window: 0,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
@ -268,6 +271,7 @@ impl AgentLoop {
|
|||||||
tools,
|
tools,
|
||||||
observer: None,
|
observer: None,
|
||||||
notify_tx: None,
|
notify_tx: None,
|
||||||
|
context_window: 0,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
@ -281,6 +285,7 @@ impl AgentLoop {
|
|||||||
tools: Arc::new(ToolRegistry::new()),
|
tools: Arc::new(ToolRegistry::new()),
|
||||||
observer: None,
|
observer: None,
|
||||||
notify_tx: None,
|
notify_tx: None,
|
||||||
|
context_window: 0,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
@ -300,12 +305,19 @@ impl AgentLoop {
|
|||||||
tools,
|
tools,
|
||||||
observer: None,
|
observer: None,
|
||||||
notify_tx: None,
|
notify_tx: None,
|
||||||
|
context_window: 0,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
workspace_dir,
|
workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the context window size for preemptive trimming.
|
||||||
|
pub fn with_context_window(mut self, window: usize) -> Self {
|
||||||
|
self.context_window = window;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the workspace directory.
|
/// Set the workspace directory.
|
||||||
pub fn with_workspace_dir(mut self, dir: PathBuf) -> Self {
|
pub fn with_workspace_dir(mut self, dir: PathBuf) -> Self {
|
||||||
self.workspace_dir = dir;
|
self.workspace_dir = dir;
|
||||||
@ -323,6 +335,36 @@ impl AgentLoop {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Preemptive trim: truncate old tool results in-place when history is
|
||||||
|
/// approaching the context window limit. Only trims tool messages with
|
||||||
|
/// content > TRIM_CHARS, preserving the most recent KEEP messages.
|
||||||
|
fn preemptive_trim_old_tool_results(
|
||||||
|
&self,
|
||||||
|
messages: &mut [ChatMessage],
|
||||||
|
max_chars: usize,
|
||||||
|
keep_recent: usize,
|
||||||
|
) -> usize {
|
||||||
|
let end = messages.len().saturating_sub(keep_recent);
|
||||||
|
let start = 1; // protect system message at [0] if present
|
||||||
|
let mut modified = 0;
|
||||||
|
for i in start..end {
|
||||||
|
if messages[i].role != "tool" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if messages[i].content.len() <= max_chars {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let removed = messages[i].content.len() - max_chars;
|
||||||
|
messages[i].content = format!(
|
||||||
|
"{}...\n\n[Output truncated - {} characters removed]",
|
||||||
|
&messages[i].content[..messages[i].content.ceil_char_boundary(max_chars)],
|
||||||
|
removed
|
||||||
|
);
|
||||||
|
modified += 1;
|
||||||
|
}
|
||||||
|
modified
|
||||||
|
}
|
||||||
|
|
||||||
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
||||||
&self.tools
|
&self.tools
|
||||||
}
|
}
|
||||||
@ -355,6 +397,27 @@ impl AgentLoop {
|
|||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(iteration, "Agent iteration started");
|
tracing::debug!(iteration, "Agent iteration started");
|
||||||
|
|
||||||
|
// Preemptive context check: trim old tool results if token estimate
|
||||||
|
// exceeds 80% of context window to prevent mid-loop overflow.
|
||||||
|
if self.context_window > 0 {
|
||||||
|
let estimated = estimate_tokens(&messages);
|
||||||
|
let danger = (self.context_window as f64 * 0.8) as usize;
|
||||||
|
if estimated > danger {
|
||||||
|
let trimmed = self.preemptive_trim_old_tool_results(
|
||||||
|
&mut messages, 2000, 4,
|
||||||
|
);
|
||||||
|
if trimmed > 0 {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(
|
||||||
|
estimated,
|
||||||
|
danger,
|
||||||
|
trimmed_msgs = trimmed,
|
||||||
|
"Preemptive tool-result trim applied in loop"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Convert messages to LLM format
|
// Convert messages to LLM format
|
||||||
let messages_for_llm: Vec<Message> = messages
|
let messages_for_llm: Vec<Message> = messages
|
||||||
.iter()
|
.iter()
|
||||||
|
|||||||
@ -15,6 +15,19 @@ pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
|||||||
(raw as f64 * 1.2) as usize
|
(raw as f64 * 1.2) as usize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extract the first number found within `max_len` characters of the start of `s`.
|
||||||
|
/// Used by `parse_context_limit_from_error` to find token limits in error messages.
|
||||||
|
fn find_number_nearby(s: &str, max_len: usize) -> Option<&str> {
|
||||||
|
let end = s.len().min(max_len);
|
||||||
|
let slice = &s[..end];
|
||||||
|
let start = slice.find(|c: char| c.is_ascii_digit())?;
|
||||||
|
let end = slice[start..]
|
||||||
|
.find(|c: char| !c.is_ascii_digit())
|
||||||
|
.map(|p| start + p)
|
||||||
|
.unwrap_or(end);
|
||||||
|
Some(&slice[start..end])
|
||||||
|
}
|
||||||
|
|
||||||
/// Configuration for context compression.
|
/// Configuration for context compression.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ContextCompressionConfig {
|
pub struct ContextCompressionConfig {
|
||||||
@ -96,13 +109,18 @@ impl ContextCompressor {
|
|||||||
self.session_id = id;
|
self.session_id = id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Update the context window size (e.g., after parsing actual limit from LLM error).
|
||||||
|
pub fn set_context_window(&mut self, window: usize) {
|
||||||
|
self.context_window = window;
|
||||||
|
}
|
||||||
|
|
||||||
/// Always true — memory is always available (memory system is always on).
|
/// Always true — memory is always available (memory system is always on).
|
||||||
pub fn has_memory(&self) -> bool {
|
pub fn has_memory(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the compression threshold in tokens.
|
/// Get the compression threshold in tokens.
|
||||||
fn threshold(&self) -> usize {
|
pub fn threshold(&self) -> usize {
|
||||||
(self.context_window as f64 * self.threshold_ratio) as usize
|
(self.context_window as f64 * self.threshold_ratio) as usize
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -127,10 +145,34 @@ impl ContextCompressor {
|
|||||||
modified
|
modified
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Remove orphan tool results whose declaring tool_calls have been compressed away.
|
||||||
|
/// Scans for tool messages with no preceding assistant tool_call, and removes them.
|
||||||
|
pub fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) {
|
||||||
|
let mut declared: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||||
|
let mut i = 0;
|
||||||
|
while i < messages.len() {
|
||||||
|
if messages[i].role == "assistant" {
|
||||||
|
if let Some(ref tool_calls) = messages[i].tool_calls {
|
||||||
|
for tc in tool_calls {
|
||||||
|
declared.insert(tc.id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if messages[i].role == "tool" {
|
||||||
|
if let Some(ref tid) = messages[i].tool_call_id {
|
||||||
|
if !declared.contains(tid.as_str()) {
|
||||||
|
messages.remove(i);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Main entry point - compresses history if over threshold.
|
/// Main entry point - compresses history if over threshold.
|
||||||
pub async fn compress_if_needed(
|
pub async fn compress_if_needed(
|
||||||
&self,
|
&self,
|
||||||
history: Vec<ChatMessage>,
|
mut history: Vec<ChatMessage>,
|
||||||
) -> Result<Vec<ChatMessage>, AgentError> {
|
) -> Result<Vec<ChatMessage>, AgentError> {
|
||||||
// Check if compression is needed
|
// Check if compression is needed
|
||||||
let tokens = estimate_tokens(&history);
|
let tokens = estimate_tokens(&history);
|
||||||
@ -146,19 +188,19 @@ impl ContextCompressor {
|
|||||||
"Starting context compression"
|
"Starting context compression"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Fast trim pass first
|
// Fast trim pass first — modify history in place
|
||||||
let trimmed = self.fast_trim_tool_results(&mut history.clone());
|
let trimmed = self.fast_trim_tool_results(&mut history);
|
||||||
|
let tokens_after = estimate_tokens(&history);
|
||||||
if trimmed > 0 {
|
if trimmed > 0 {
|
||||||
let tokens_after = estimate_tokens(&history);
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
trimmed_messages = trimmed,
|
trimmed_messages = trimmed,
|
||||||
tokens_after = tokens_after,
|
tokens_after = tokens_after,
|
||||||
"Fast trim completed"
|
"Fast trim completed"
|
||||||
);
|
);
|
||||||
if tokens_after <= self.threshold() {
|
}
|
||||||
return Ok(history);
|
if tokens_after <= self.threshold() {
|
||||||
}
|
return Ok(history);
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLM summarization pass
|
// LLM summarization pass
|
||||||
@ -191,6 +233,36 @@ impl ContextCompressor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hard safety net: if still dangerously high after all passes,
|
||||||
|
// fall back to head+tail truncation so the LLM call doesn't overflow.
|
||||||
|
let final_tokens = estimate_tokens(¤t_history);
|
||||||
|
let danger_threshold = (self.context_window as f64 * 0.9) as usize;
|
||||||
|
if final_tokens > danger_threshold
|
||||||
|
&& current_history.len() > self.config.protect_first_n + self.config.protect_last_n
|
||||||
|
{
|
||||||
|
let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec();
|
||||||
|
let tail_start = current_history.len() - self.config.protect_last_n;
|
||||||
|
let tail: Vec<_> = current_history[tail_start..].to_vec();
|
||||||
|
let dropped = current_history.len() - self.config.protect_first_n - self.config.protect_last_n;
|
||||||
|
|
||||||
|
let mut truncated = head;
|
||||||
|
truncated.push(ChatMessage::user(format!(
|
||||||
|
"[Context truncation — {} earlier messages dropped due to token limit]\n\
|
||||||
|
Previous context could not be fully compressed. Continuing with most recent context.",
|
||||||
|
dropped
|
||||||
|
)));
|
||||||
|
truncated.extend(tail);
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
final_tokens = final_tokens,
|
||||||
|
danger = danger_threshold,
|
||||||
|
dropped_msgs = dropped,
|
||||||
|
"Hard truncation fallback applied"
|
||||||
|
);
|
||||||
|
|
||||||
|
current_history = truncated;
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
final_tokens = estimate_tokens(¤t_history),
|
final_tokens = estimate_tokens(¤t_history),
|
||||||
@ -201,6 +273,48 @@ impl ContextCompressor {
|
|||||||
Ok(current_history)
|
Ok(current_history)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Try to extract the actual context token limit from an LLM error message.
|
||||||
|
/// Recognizes patterns from OpenAI, Anthropic, and llama.cpp-style errors.
|
||||||
|
pub fn parse_context_limit_from_error(msg: &str) -> Option<usize> {
|
||||||
|
let lower = msg.to_lowercase();
|
||||||
|
|
||||||
|
// Common patterns: "maximum context length is 128000", "context window of 131072",
|
||||||
|
// "128000 token context", "available context size (8448 tokens)", "> 128000 maximum"
|
||||||
|
let markers = [
|
||||||
|
"maximum context length",
|
||||||
|
"context window",
|
||||||
|
"context length",
|
||||||
|
"available context size",
|
||||||
|
];
|
||||||
|
|
||||||
|
for marker in &markers {
|
||||||
|
if let Some(pos) = lower.find(marker) {
|
||||||
|
let after = &lower[pos + marker.len()..];
|
||||||
|
// Look for a number in the vicinity (up to 10 chars after marker)
|
||||||
|
if let Some(num_str) = find_number_nearby(after, 50) {
|
||||||
|
if let Ok(n) = num_str.parse::<usize>() {
|
||||||
|
if (1024..=10_000_000).contains(&n) {
|
||||||
|
return Some(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also try: "XXXX token context" or "XXXX limit"
|
||||||
|
if let Some(num_str) = find_number_nearby(&lower, lower.len()) {
|
||||||
|
if let Ok(n) = num_str.parse::<usize>() {
|
||||||
|
if (1024..=10_000_000).contains(&n)
|
||||||
|
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
|
||||||
|
{
|
||||||
|
return Some(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
/// Single compression pass - summarize middle messages between user turns.
|
/// Single compression pass - summarize middle messages between user turns.
|
||||||
/// Returns Some(compressed) if compression happened, None if nothing to compress.
|
/// Returns Some(compressed) if compression happened, None if nothing to compress.
|
||||||
async fn compress_once(
|
async fn compress_once(
|
||||||
@ -227,7 +341,7 @@ impl ContextCompressor {
|
|||||||
|
|
||||||
// Build segments: user -> (assistant turns) -> next user
|
// Build segments: user -> (assistant turns) -> next user
|
||||||
// We'll summarize the assistant turns between consecutive user messages
|
// We'll summarize the assistant turns between consecutive user messages
|
||||||
let mut new_messages = history[..=user_indices[0]].to_vec();
|
let mut new_messages = history[..user_indices[0]].to_vec();
|
||||||
|
|
||||||
for i in 0..user_indices.len() - 1 {
|
for i in 0..user_indices.len() - 1 {
|
||||||
let user_idx = user_indices[i];
|
let user_idx = user_indices[i];
|
||||||
@ -272,13 +386,13 @@ impl ContextCompressor {
|
|||||||
|
|
||||||
// Add last user and everything after (protected)
|
// Add last user and everything after (protected)
|
||||||
let last_user_idx = user_indices[user_indices.len() - 1];
|
let last_user_idx = user_indices[user_indices.len() - 1];
|
||||||
if last_user_idx < history.len() - 1 {
|
for i in last_user_idx..history.len() {
|
||||||
// Add everything from last user onwards (protected)
|
new_messages.push(history[i].clone());
|
||||||
for i in last_user_idx..history.len() {
|
|
||||||
new_messages.push(history[i].clone());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove orphan tool results whose declaring tool_calls were compressed away
|
||||||
|
Self::repair_tool_pairs(&mut new_messages);
|
||||||
|
|
||||||
// If nothing changed, return None
|
// If nothing changed, return None
|
||||||
if new_messages.len() == history.len() {
|
if new_messages.len() == history.len() {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
@ -370,8 +484,11 @@ Be concise, aim for {} characters or less.
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::memory::MemoryManager;
|
||||||
use crate::providers::ChatCompletionResponse;
|
use crate::providers::ChatCompletionResponse;
|
||||||
|
use crate::providers::Usage;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::sync::OnceLock;
|
use std::sync::OnceLock;
|
||||||
|
|
||||||
/// Mock provider for testing - panics if actually used for LLM calls
|
/// Mock provider for testing - panics if actually used for LLM calls
|
||||||
@ -403,6 +520,34 @@ mod tests {
|
|||||||
Arc::new(MockProvider)
|
Arc::new(MockProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Mock summarizer that returns a simple summary — used when compress_once
|
||||||
|
/// needs to call the LLM for summarization.
|
||||||
|
struct MockSummarizer;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LLMProvider for MockSummarizer {
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
_request: ChatCompletionRequest,
|
||||||
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
Ok(ChatCompletionResponse {
|
||||||
|
id: "mock".into(),
|
||||||
|
model: "mock".into(),
|
||||||
|
content: "[summarized]".into(),
|
||||||
|
tool_calls: vec![],
|
||||||
|
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptype(&self) -> &str { "mock" }
|
||||||
|
fn name(&self) -> &str { "mock" }
|
||||||
|
fn model_id(&self) -> &str { "mock" }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mock_summarizer() -> Arc<dyn LLMProvider> {
|
||||||
|
Arc::new(MockSummarizer)
|
||||||
|
}
|
||||||
|
|
||||||
fn test_memory_manager() -> Arc<MemoryManager> {
|
fn test_memory_manager() -> Arc<MemoryManager> {
|
||||||
static MM: OnceLock<Arc<MemoryManager>> = OnceLock::new();
|
static MM: OnceLock<Arc<MemoryManager>> = OnceLock::new();
|
||||||
MM.get_or_init(|| {
|
MM.get_or_init(|| {
|
||||||
@ -454,4 +599,206 @@ mod tests {
|
|||||||
let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager());
|
let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager());
|
||||||
assert_eq!(compressor.threshold(), 64_000);
|
assert_eq!(compressor.threshold(), 64_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_compress_if_needed_fast_trims_tool_results() {
|
||||||
|
// context_window=200 → threshold=100.
|
||||||
|
// user "Hi" (~6 raw), tool(3000 x's) → ~760 raw*1.2=912 > 100 → triggers compression.
|
||||||
|
// fast_trim to 50 chars should bring tokens well under 100.
|
||||||
|
let tmp = std::env::temp_dir().join(format!("picobot_ctx_trim_{}.db", std::process::id()));
|
||||||
|
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||||
|
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
||||||
|
|
||||||
|
let config = ContextCompressionConfig {
|
||||||
|
tool_result_trim_chars: 50,
|
||||||
|
protect_first_n: 0,
|
||||||
|
protect_last_n: 10,
|
||||||
|
max_passes: 0,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
|
||||||
|
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::user("Hi"),
|
||||||
|
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = compressor.compress_if_needed(messages).await.unwrap();
|
||||||
|
|
||||||
|
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
|
||||||
|
assert!(
|
||||||
|
tool_msg.content.len() < 3000,
|
||||||
|
"tool result should be trimmed, got {} chars",
|
||||||
|
tool_msg.content.len()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
tool_msg.content.contains("[Output truncated"),
|
||||||
|
"trim marker missing from: {}",
|
||||||
|
tool_msg.content
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(&tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_compress_once_no_duplicate_and_no_lost_user() {
|
||||||
|
// Verifies two boundary bugs in compress_once:
|
||||||
|
// - B2A (L230): first user message duplicated when protect_first_n > 0
|
||||||
|
// - B2B (L275): last user message lost when it is the final history message
|
||||||
|
//
|
||||||
|
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
|
||||||
|
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
||||||
|
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||||
|
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
||||||
|
|
||||||
|
let config = ContextCompressionConfig {
|
||||||
|
tool_result_trim_chars: 2000,
|
||||||
|
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
|
||||||
|
protect_last_n: 2,
|
||||||
|
max_passes: 1,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
|
||||||
|
|
||||||
|
// History: 9 messages, last message is user Q4.
|
||||||
|
// user_indices (skip 1) = [1, 3, 6, 8]
|
||||||
|
// B2A: init history[..=1] includes Q1, then loop i=0 pushes Q1 again → duplicate
|
||||||
|
// B2B: last_user_idx=8, 8 < 8 → false → Q4 not pushed → lost
|
||||||
|
let big = "x".repeat(3000);
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::system("You are a helper."), // 0: protected
|
||||||
|
ChatMessage::user("Q1"), // 1: first user
|
||||||
|
ChatMessage::tool("t1", "bash", &big), // 2
|
||||||
|
ChatMessage::user("Q2"), // 3
|
||||||
|
ChatMessage::assistant("thinking"), // 4
|
||||||
|
ChatMessage::tool("t2", "bash", &big), // 5
|
||||||
|
ChatMessage::user("Q3"), // 6
|
||||||
|
ChatMessage::assistant("thinking"), // 7
|
||||||
|
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = compressor.compress_if_needed(messages).await.unwrap();
|
||||||
|
|
||||||
|
// B2A: "Q1" must appear exactly once
|
||||||
|
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
|
||||||
|
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count);
|
||||||
|
|
||||||
|
// B2B: "Q4" must NOT be lost
|
||||||
|
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count();
|
||||||
|
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count);
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(&tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_compress_hard_truncation_fallback() {
|
||||||
|
// When LLM compression fails (or max_passes=0) and tokens are still
|
||||||
|
// above 90% of context_window, a head+tail truncation kicks in.
|
||||||
|
let tmp = std::env::temp_dir().join(format!("picobot_ctx_trunc_{}.db", std::process::id()));
|
||||||
|
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||||
|
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
||||||
|
|
||||||
|
let config = ContextCompressionConfig {
|
||||||
|
tool_result_trim_chars: 500, // trim reduces but not enough
|
||||||
|
protect_first_n: 1,
|
||||||
|
protect_last_n: 2,
|
||||||
|
max_passes: 0, // no LLM summarization → will exceed danger
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
// context_window=100, danger_threshold=90.
|
||||||
|
// Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387.
|
||||||
|
// Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90.
|
||||||
|
let compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
|
||||||
|
|
||||||
|
let big = "x".repeat(3000);
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::system("sys"),
|
||||||
|
ChatMessage::user("Q1"),
|
||||||
|
ChatMessage::tool("t1", "bash", &big),
|
||||||
|
ChatMessage::user("Q2"),
|
||||||
|
ChatMessage::tool("t2", "bash", &big),
|
||||||
|
ChatMessage::user("Q3"),
|
||||||
|
ChatMessage::tool("t3", "bash", &big),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = compressor.compress_if_needed(messages).await.unwrap();
|
||||||
|
|
||||||
|
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
|
||||||
|
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
|
||||||
|
|
||||||
|
// Truncation notice should be present
|
||||||
|
let has_notice = result.iter().any(|m| m.content.contains("Context truncation"));
|
||||||
|
assert!(has_notice, "hard truncation notice missing");
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(&tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_repair_tool_pairs_removes_orphans() {
|
||||||
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
|
// Simulate compressed output: summary replaced assistant(tool_call: tc1),
|
||||||
|
// leaving tool(tc1) as an orphan. Legitimate tool(tc2) should be kept.
|
||||||
|
let mut messages = vec![
|
||||||
|
ChatMessage::user("Q1"),
|
||||||
|
ChatMessage::user("[Context Summary]\n\nsummary of previous turn"),
|
||||||
|
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
|
||||||
|
ChatMessage::assistant("done"), // declares tc2
|
||||||
|
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
|
||||||
|
];
|
||||||
|
// Set tool_call_id on tool messages and tool_calls on assistant
|
||||||
|
messages[2].tool_call_id = Some("tc1".into());
|
||||||
|
messages[4].tool_call_id = Some("tc2".into());
|
||||||
|
messages[3].tool_calls = Some(vec![ToolCall {
|
||||||
|
id: "tc2".into(),
|
||||||
|
name: "bash".into(),
|
||||||
|
arguments: serde_json::json!({"cmd": "echo ok"}),
|
||||||
|
}]);
|
||||||
|
|
||||||
|
ContextCompressor::repair_tool_pairs(&mut messages);
|
||||||
|
|
||||||
|
// orphan should be removed; legitimate should stay
|
||||||
|
assert_eq!(messages.len(), 4);
|
||||||
|
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into())));
|
||||||
|
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_context_limit_from_error() {
|
||||||
|
// OpenAI: "maximum context length is 128000"
|
||||||
|
assert_eq!(
|
||||||
|
ContextCompressor::parse_context_limit_from_error(
|
||||||
|
"This model's maximum context length is 128000 tokens."
|
||||||
|
),
|
||||||
|
Some(128000)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Anthropic: "context window of 200000"
|
||||||
|
assert_eq!(
|
||||||
|
ContextCompressor::parse_context_limit_from_error(
|
||||||
|
"Your request exceeds the context window of 200000."
|
||||||
|
),
|
||||||
|
Some(200000)
|
||||||
|
);
|
||||||
|
|
||||||
|
// llama.cpp: "available context size (8448 tokens)"
|
||||||
|
assert_eq!(
|
||||||
|
ContextCompressor::parse_context_limit_from_error(
|
||||||
|
"context size exceeded, available context size (8448 tokens)"
|
||||||
|
),
|
||||||
|
Some(8448)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Non-context error should return None
|
||||||
|
assert_eq!(
|
||||||
|
ContextCompressor::parse_context_limit_from_error("Internal server error"),
|
||||||
|
None
|
||||||
|
);
|
||||||
|
|
||||||
|
// Numbers too small should be rejected
|
||||||
|
assert_eq!(
|
||||||
|
ContextCompressor::parse_context_limit_from_error("context length is 500"),
|
||||||
|
None
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,5 +3,5 @@ pub mod context_compressor;
|
|||||||
pub mod system_prompt;
|
pub mod system_prompt;
|
||||||
|
|
||||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
||||||
pub use context_compressor::ContextCompressor;
|
pub use context_compressor::{ContextCompressor, estimate_tokens};
|
||||||
pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder};
|
pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder};
|
||||||
|
|||||||
@ -52,18 +52,21 @@ impl MemoryManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Search memories by keyword query. Returns entries sorted by relevance.
|
/// Search memories by keyword query. Returns entries sorted by relevance.
|
||||||
|
/// When `session_id` is provided, results are filtered to that session.
|
||||||
pub async fn recall(
|
pub async fn recall(
|
||||||
&self,
|
&self,
|
||||||
query: &str,
|
query: &str,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
category: Option<MemoryCategory>,
|
category: Option<MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
|
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
|
||||||
self.storage
|
self.storage
|
||||||
.search_memories(query, category.as_ref(), limit)
|
.search_memories(query, category.as_ref(), session_id, limit)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Search memories by time range (Unix milliseconds).
|
/// Search memories by time range (Unix milliseconds).
|
||||||
|
/// When `session_id` is provided, results are filtered to that session.
|
||||||
pub async fn recall_by_time(
|
pub async fn recall_by_time(
|
||||||
&self,
|
&self,
|
||||||
since: i64,
|
since: i64,
|
||||||
@ -71,9 +74,10 @@ impl MemoryManager {
|
|||||||
query: Option<&str>,
|
query: Option<&str>,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
category: Option<MemoryCategory>,
|
category: Option<MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
|
) -> Result<Vec<MemoryEntry>, crate::storage::StorageError> {
|
||||||
self.storage
|
self.storage
|
||||||
.search_memories_by_time(since, until, query, category.as_ref(), limit)
|
.search_memories_by_time(since, until, query, category.as_ref(), session_id, limit)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +88,7 @@ impl MemoryManager {
|
|||||||
|
|
||||||
/// Check if the memory system has any entries (for testing/health check).
|
/// Check if the memory system has any entries (for testing/health check).
|
||||||
pub async fn is_empty(&self) -> Result<bool, crate::storage::StorageError> {
|
pub async fn is_empty(&self) -> Result<bool, crate::storage::StorageError> {
|
||||||
self.recall("*", 1, None).await.map(|r| r.is_empty())
|
self.recall("*", 1, None, None).await.map(|r| r.is_empty())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,7 +120,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mm.recall("test memory", 10, None).await.unwrap();
|
let results = mm.recall("test memory", 10, None, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 1);
|
assert_eq!(results.len(), 1);
|
||||||
assert_eq!(results[0].key, "test_key");
|
assert_eq!(results[0].key, "test_key");
|
||||||
assert_eq!(results[0].content, "This is a test memory");
|
assert_eq!(results[0].content, "This is a test memory");
|
||||||
@ -146,7 +150,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mm.recall("updated", 10, None).await.unwrap();
|
let results = mm.recall("updated", 10, None, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 1);
|
assert_eq!(results.len(), 1);
|
||||||
assert_eq!(results[0].content, "updated");
|
assert_eq!(results[0].content, "updated");
|
||||||
}
|
}
|
||||||
@ -166,7 +170,7 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
mm.forget("to_delete").await.unwrap();
|
mm.forget("to_delete").await.unwrap();
|
||||||
|
|
||||||
let results = mm.recall("deleted", 10, None).await.unwrap();
|
let results = mm.recall("deleted", 10, None, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,17 +198,60 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let know_results = mm
|
let know_results = mm
|
||||||
.recall("content", 10, Some(MemoryCategory::Knowledge))
|
.recall("content", 10, Some(MemoryCategory::Knowledge), None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(know_results.len(), 1);
|
assert_eq!(know_results.len(), 1);
|
||||||
assert_eq!(know_results[0].key, "knowledge_1");
|
assert_eq!(know_results[0].key, "knowledge_1");
|
||||||
|
|
||||||
let time_results = mm
|
let time_results = mm
|
||||||
.recall("content", 10, Some(MemoryCategory::Timeline))
|
.recall("content", 10, Some(MemoryCategory::Timeline), None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(time_results.len(), 1);
|
assert_eq!(time_results.len(), 1);
|
||||||
assert_eq!(time_results[0].key, "timeline_1");
|
assert_eq!(time_results[0].key, "timeline_1");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_session_id_filter() {
|
||||||
|
let (mm, _dir) = setup_memory_manager().await;
|
||||||
|
|
||||||
|
// Store a timeline entry for session A
|
||||||
|
mm.store(
|
||||||
|
"tl_a",
|
||||||
|
"summary from session A",
|
||||||
|
MemoryCategory::Timeline,
|
||||||
|
Some("chan:chat:dialog_a"),
|
||||||
|
Some(0.5),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Store a timeline entry for session B
|
||||||
|
mm.store(
|
||||||
|
"tl_b",
|
||||||
|
"summary from session B",
|
||||||
|
MemoryCategory::Timeline,
|
||||||
|
Some("chan:chat:dialog_b"),
|
||||||
|
Some(0.5),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Recall without session_id — should get both
|
||||||
|
let all = mm
|
||||||
|
.recall("summary", 10, Some(MemoryCategory::Timeline), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(all.len(), 2);
|
||||||
|
|
||||||
|
// Recall scoped to session A — should get only tl_a
|
||||||
|
let scoped = mm
|
||||||
|
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(scoped.len(), 1);
|
||||||
|
assert_eq!(scoped[0].key, "tl_a");
|
||||||
|
assert_eq!(scoped[0].session_id.as_deref(), Some("chan:chat:dialog_a"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -21,6 +21,18 @@ use crate::config::LLMProviderConfig;
|
|||||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||||
use crate::agent::system_prompt::build_system_prompt;
|
use crate::agent::system_prompt::build_system_prompt;
|
||||||
use crate::agent::context_compressor::ContextCompressionConfig;
|
use crate::agent::context_compressor::ContextCompressionConfig;
|
||||||
|
|
||||||
|
/// Check if an LLM error message indicates a context window overflow.
|
||||||
|
fn is_context_overflow_error(msg: &str) -> bool {
|
||||||
|
let lower = msg.to_lowercase();
|
||||||
|
lower.contains("context length")
|
||||||
|
|| lower.contains("context window")
|
||||||
|
|| lower.contains("maximum context")
|
||||||
|
|| lower.contains("too many tokens")
|
||||||
|
|| lower.contains("token limit exceeded")
|
||||||
|
|| lower.contains("prompt is too long")
|
||||||
|
|| lower.contains("input is too long")
|
||||||
|
}
|
||||||
use crate::providers::{create_provider, LLMProvider};
|
use crate::providers::{create_provider, LLMProvider};
|
||||||
use crate::session::session_id::UnifiedSessionId;
|
use crate::session::session_id::UnifiedSessionId;
|
||||||
use crate::session::events::DialogInfo;
|
use crate::session::events::DialogInfo;
|
||||||
@ -372,6 +384,11 @@ impl Session {
|
|||||||
&self.compressor
|
&self.compressor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the compressor's current threshold for diagnostics/fallback.
|
||||||
|
pub fn compressor_threshold(&self) -> usize {
|
||||||
|
self.compressor.threshold()
|
||||||
|
}
|
||||||
|
|
||||||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||||||
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
||||||
Ok(AgentLoop::with_provider_and_tools(
|
Ok(AgentLoop::with_provider_and_tools(
|
||||||
@ -380,7 +397,7 @@ impl Session {
|
|||||||
self.provider_config.max_tool_iterations,
|
self.provider_config.max_tool_iterations,
|
||||||
self.provider_config.model_id.clone(),
|
self.provider_config.model_id.clone(),
|
||||||
self.provider_config.workspace_dir.clone(),
|
self.provider_config.workspace_dir.clone(),
|
||||||
))
|
).with_context_window(self.provider_config.token_limit))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 创建一个附通知通道的 AgentLoop 实例
|
/// 创建一个附通知通道的 AgentLoop 实例
|
||||||
@ -1305,7 +1322,7 @@ impl SessionManager {
|
|||||||
let skills_prompt = self.skills_loader.build_skills_prompt();
|
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||||||
|
|
||||||
// Fetch memory context
|
// Fetch memory context
|
||||||
let memory_context = match self.memory_manager.recall(content, 5, Some(crate::memory::MemoryCategory::Knowledge)).await {
|
let memory_context = match self.memory_manager.recall(content, 5, Some(crate::memory::MemoryCategory::Knowledge), None).await {
|
||||||
Ok(entries) if !entries.is_empty() => {
|
Ok(entries) if !entries.is_empty() => {
|
||||||
Some(entries.iter()
|
Some(entries.iter()
|
||||||
.map(|e| format!("- {}: {}", e.key, e.content))
|
.map(|e| format!("- {}: {}", e.key, e.content))
|
||||||
@ -1319,15 +1336,17 @@ impl SessionManager {
|
|||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build combined system prompt and inject at position 0
|
// Build combined system prompt and inject at position 0 AFTER compression.
|
||||||
// This ensures AgentLoop.process() sees a system message and doesn't inject its own
|
// This ensures AgentLoop.process() sees a system message without it participating
|
||||||
|
// in context compression (system prompt is dynamic and should not be persisted).
|
||||||
let system_prompt = session_guard.build_system_prompt(&skills_prompt, memory_context.as_deref());
|
let system_prompt = session_guard.build_system_prompt(&skills_prompt, memory_context.as_deref());
|
||||||
history.insert(0, ChatMessage::system(system_prompt));
|
|
||||||
|
|
||||||
let history = session_guard.compressor
|
let mut history = session_guard.compressor
|
||||||
.compress_if_needed(history)
|
.compress_if_needed(history)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
history.insert(0, ChatMessage::system(system_prompt.clone()));
|
||||||
|
|
||||||
// Advance consolidation pointer — future compressions skip already-processed messages
|
// Advance consolidation pointer — future compressions skip already-processed messages
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
session_guard.last_consolidated_at = Some(now);
|
session_guard.last_consolidated_at = Some(now);
|
||||||
@ -1336,7 +1355,28 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let agent = session_guard.create_agent_with_notify(notify_tx)?;
|
let agent = session_guard.create_agent_with_notify(notify_tx)?;
|
||||||
let result = agent.process(history).await?;
|
|
||||||
|
// Try LLM call; on context overflow, re-compress with tighter window and retry once.
|
||||||
|
let result = match agent.process(history).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(AgentError::LlmError(ref msg))
|
||||||
|
if is_context_overflow_error(msg) =>
|
||||||
|
{
|
||||||
|
let new_window = crate::agent::ContextCompressor::parse_context_limit_from_error(msg)
|
||||||
|
.unwrap_or(session_guard.compressor_threshold());
|
||||||
|
tracing::warn!(
|
||||||
|
new_window,
|
||||||
|
error = %msg,
|
||||||
|
"Context overflow in handle_message — retrying with tighter window"
|
||||||
|
);
|
||||||
|
session_guard.compressor.set_context_window(new_window);
|
||||||
|
let raw = session_guard.get_history().to_vec();
|
||||||
|
let mut retry = session_guard.compressor.compress_if_needed(raw).await?;
|
||||||
|
retry.insert(0, ChatMessage::system(system_prompt));
|
||||||
|
agent.process(retry).await?
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
for msg in result.emitted_messages {
|
for msg in result.emitted_messages {
|
||||||
session_guard.add_message(msg, true).await
|
session_guard.add_message(msg, true).await
|
||||||
@ -1443,12 +1483,15 @@ impl SessionManager {
|
|||||||
job_name, job_id, channel, chat_id
|
job_name, job_id, channel, chat_id
|
||||||
);
|
);
|
||||||
let full_system_prompt = format!("{}{}", system_prompt, cron_context);
|
let full_system_prompt = format!("{}{}", system_prompt, cron_context);
|
||||||
history.insert(0, ChatMessage::system(full_system_prompt));
|
|
||||||
|
|
||||||
let history = session_guard.compressor
|
// Inject system prompt AFTER compression so it doesn't participate
|
||||||
|
// in context compression (system prompt is dynamic and should not be persisted).
|
||||||
|
let mut history = session_guard.compressor
|
||||||
.compress_if_needed(history)
|
.compress_if_needed(history)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
history.insert(0, ChatMessage::system(full_system_prompt));
|
||||||
|
|
||||||
let agent = session_guard.create_agent_with_notify(notify_tx)?;
|
let agent = session_guard.create_agent_with_notify(notify_tx)?;
|
||||||
let result = agent.process(history).await?;
|
let result = agent.process(history).await?;
|
||||||
|
|
||||||
|
|||||||
@ -56,6 +56,7 @@ impl super::Storage {
|
|||||||
&self,
|
&self,
|
||||||
query: &str,
|
query: &str,
|
||||||
category: Option<&MemoryCategory>,
|
category: Option<&MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
// Build FTS5 query: segment with jieba, wrap each term in quotes, join with OR
|
// Build FTS5 query: segment with jieba, wrap each term in quotes, join with OR
|
||||||
@ -76,7 +77,7 @@ impl super::Storage {
|
|||||||
m.session_id, m.created_at, m.updated_at
|
m.session_id, m.created_at, m.updated_at
|
||||||
FROM memory_fts f
|
FROM memory_fts f
|
||||||
JOIN memories m ON f.rowid = m.rowid
|
JOIN memories m ON f.rowid = m.rowid
|
||||||
WHERE memory_fts MATCH ? AND (? IS NULL OR m.category = ?)
|
WHERE memory_fts MATCH ? AND (? IS NULL OR m.category = ?) AND (? IS NULL OR m.session_id = ?)
|
||||||
ORDER BY rank
|
ORDER BY rank
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"#,
|
"#,
|
||||||
@ -84,6 +85,8 @@ impl super::Storage {
|
|||||||
.bind(&fts_query)
|
.bind(&fts_query)
|
||||||
.bind(category_filter)
|
.bind(category_filter)
|
||||||
.bind(category_filter)
|
.bind(category_filter)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(session_id)
|
||||||
.bind(limit as i64)
|
.bind(limit as i64)
|
||||||
.fetch_all(self.pool())
|
.fetch_all(self.pool())
|
||||||
.await?;
|
.await?;
|
||||||
@ -113,6 +116,7 @@ impl super::Storage {
|
|||||||
FROM memories
|
FROM memories
|
||||||
WHERE ({})
|
WHERE ({})
|
||||||
AND (? IS NULL OR category = ?)
|
AND (? IS NULL OR category = ?)
|
||||||
|
AND (? IS NULL OR session_id = ?)
|
||||||
ORDER BY importance DESC, updated_at DESC
|
ORDER BY importance DESC, updated_at DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"#,
|
"#,
|
||||||
@ -127,6 +131,8 @@ impl super::Storage {
|
|||||||
query_builder = query_builder
|
query_builder = query_builder
|
||||||
.bind(category_filter)
|
.bind(category_filter)
|
||||||
.bind(category_filter)
|
.bind(category_filter)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(session_id)
|
||||||
.bind(limit as i64);
|
.bind(limit as i64);
|
||||||
|
|
||||||
let rows = query_builder.fetch_all(self.pool()).await?;
|
let rows = query_builder.fetch_all(self.pool()).await?;
|
||||||
@ -144,6 +150,7 @@ impl super::Storage {
|
|||||||
until: i64,
|
until: i64,
|
||||||
query: Option<&str>,
|
query: Option<&str>,
|
||||||
category: Option<&MemoryCategory>,
|
category: Option<&MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
let category_filter = category.map(|c| c.as_str());
|
let category_filter = category.map(|c| c.as_str());
|
||||||
@ -180,6 +187,7 @@ impl super::Storage {
|
|||||||
WHERE ({})
|
WHERE ({})
|
||||||
AND created_at >= ? AND created_at <= ?
|
AND created_at >= ? AND created_at <= ?
|
||||||
AND (? IS NULL OR category = ?)
|
AND (? IS NULL OR category = ?)
|
||||||
|
AND (? IS NULL OR session_id = ?)
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"#,
|
"#,
|
||||||
@ -196,6 +204,8 @@ impl super::Storage {
|
|||||||
.bind(&until_dt)
|
.bind(&until_dt)
|
||||||
.bind(category_filter)
|
.bind(category_filter)
|
||||||
.bind(category_filter)
|
.bind(category_filter)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(session_id)
|
||||||
.bind(limit as i64);
|
.bind(limit as i64);
|
||||||
|
|
||||||
query_builder.fetch_all(self.pool()).await?
|
query_builder.fetch_all(self.pool()).await?
|
||||||
@ -207,6 +217,7 @@ impl super::Storage {
|
|||||||
FROM memories
|
FROM memories
|
||||||
WHERE created_at >= ? AND created_at <= ?
|
WHERE created_at >= ? AND created_at <= ?
|
||||||
AND (? IS NULL OR category = ?)
|
AND (? IS NULL OR category = ?)
|
||||||
|
AND (? IS NULL OR session_id = ?)
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"#,
|
"#,
|
||||||
@ -215,6 +226,8 @@ impl super::Storage {
|
|||||||
.bind(&until_dt)
|
.bind(&until_dt)
|
||||||
.bind(category_filter)
|
.bind(category_filter)
|
||||||
.bind(category_filter)
|
.bind(category_filter)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(session_id)
|
||||||
.bind(limit as i64)
|
.bind(limit as i64)
|
||||||
.fetch_all(self.pool())
|
.fetch_all(self.pool())
|
||||||
.await?
|
.await?
|
||||||
|
|||||||
@ -24,7 +24,7 @@ impl Tool for MemoryStoreTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
"Store a fact, preference, or insight into long-term memory. \
|
"Store a fact, preference, or insight into long-term knowledge memory. \
|
||||||
Use this when the user shares important information you should remember. \
|
Use this when the user shares important information you should remember. \
|
||||||
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
|
Provide a descriptive key (e.g., 'user_prefers_python', 'project_auth_approach') \
|
||||||
and the full content to remember."
|
and the full content to remember."
|
||||||
@ -46,11 +46,6 @@ impl Tool for MemoryStoreTool {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The full content of the memory entry."
|
"description": "The full content of the memory entry."
|
||||||
},
|
},
|
||||||
"category": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["knowledge", "timeline"],
|
|
||||||
"description": "Memory category. Use 'knowledge' for facts/preferences/insights, 'timeline' for conversation summaries."
|
|
||||||
},
|
|
||||||
"importance": {
|
"importance": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
|
"description": "Importance score 0.0-1.0. Higher = more important. Use 0.8+ for critical facts, 0.5 for general info."
|
||||||
@ -71,16 +66,10 @@ impl Tool for MemoryStoreTool {
|
|||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: content"))?;
|
||||||
|
|
||||||
let category = args
|
|
||||||
.get("category")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.and_then(MemoryCategory::from_str)
|
|
||||||
.unwrap_or(MemoryCategory::Knowledge);
|
|
||||||
|
|
||||||
let importance = args.get("importance").and_then(|v| v.as_f64());
|
let importance = args.get("importance").and_then(|v| v.as_f64());
|
||||||
|
|
||||||
self.memory
|
self.memory
|
||||||
.store(key, content, category, None, importance)
|
.store(key, content, MemoryCategory::Knowledge, None, importance)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
@ -110,8 +99,8 @@ impl Tool for MemoryRecallTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
"Search and retrieve entries from long-term memory using keyword matching. \
|
"Search and retrieve entries from long-term knowledge memory using keyword matching. \
|
||||||
Use this to recall previously stored facts, preferences, or conversation history. \
|
Use this to recall previously stored facts, preferences, or insights. \
|
||||||
IMPORTANT: query must be a space-separated list of RELEVANT KEYWORDS (not a question or sentence). \
|
IMPORTANT: query must be a space-separated list of RELEVANT KEYWORDS (not a question or sentence). \
|
||||||
Use multiple synonymous or related terms to increase recall. \
|
Use multiple synonymous or related terms to increase recall. \
|
||||||
Example: instead of 'what is the user location', use 'user location address city residence'. \
|
Example: instead of 'what is the user location', use 'user location address city residence'. \
|
||||||
@ -130,11 +119,6 @@ impl Tool for MemoryRecallTool {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Space-separated KEYWORDS for memory search (NOT a natural language question). Use multiple related terms for better recall, e.g. 'address city location residence'."
|
"description": "Space-separated KEYWORDS for memory search (NOT a natural language question). Use multiple related terms for better recall, e.g. 'address city location residence'."
|
||||||
},
|
},
|
||||||
"category": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["knowledge", "timeline"],
|
|
||||||
"description": "Filter by memory category. Omit to search all categories."
|
|
||||||
},
|
|
||||||
"since": {
|
"since": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Start of time range (Unix milliseconds)."
|
"description": "Start of time range (Unix milliseconds)."
|
||||||
@ -158,11 +142,6 @@ impl Tool for MemoryRecallTool {
|
|||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
|
||||||
|
|
||||||
let category = args
|
|
||||||
.get("category")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.and_then(MemoryCategory::from_str);
|
|
||||||
|
|
||||||
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
|
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
|
||||||
|
|
||||||
let entries = if args.get("since").is_some() || args.get("until").is_some() {
|
let entries = if args.get("since").is_some() || args.get("until").is_some() {
|
||||||
@ -172,10 +151,10 @@ impl Tool for MemoryRecallTool {
|
|||||||
.and_then(|v| v.as_i64())
|
.and_then(|v| v.as_i64())
|
||||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||||
self.memory
|
self.memory
|
||||||
.recall_by_time(since, until, Some(query), limit, category)
|
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None)
|
||||||
.await?
|
.await?
|
||||||
} else {
|
} else {
|
||||||
self.memory.recall(query, limit, category).await?
|
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await?
|
||||||
};
|
};
|
||||||
|
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
@ -189,10 +168,12 @@ impl Tool for MemoryRecallTool {
|
|||||||
let formatted = entries
|
let formatted = entries
|
||||||
.iter()
|
.iter()
|
||||||
.map(|e| {
|
.map(|e| {
|
||||||
|
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
||||||
format!(
|
format!(
|
||||||
"- {} [{}] [importance: {:.1}]: {}",
|
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||||
e.key,
|
e.key,
|
||||||
e.category.as_str(),
|
e.category.as_str(),
|
||||||
|
session,
|
||||||
e.importance,
|
e.importance,
|
||||||
e.content
|
e.content
|
||||||
)
|
)
|
||||||
@ -208,6 +189,119 @@ impl Tool for MemoryRecallTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── TimelineRecallTool ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct TimelineRecallTool {
|
||||||
|
memory: Arc<MemoryManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TimelineRecallTool {
|
||||||
|
pub fn new(memory: Arc<MemoryManager>) -> Self {
|
||||||
|
Self { memory }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for TimelineRecallTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"timeline_recall"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Search and retrieve conversation summaries from timeline memory. \
|
||||||
|
Use this to recall what was discussed in past sessions or earlier in the current session. \
|
||||||
|
Optionally filter by session_id to scope to a specific conversation. \
|
||||||
|
IMPORTANT: query must be a space-separated list of RELEVANT KEYWORDS (not a question or sentence)."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Space-separated KEYWORDS for timeline search (NOT a natural language question). Use multiple related terms for better recall."
|
||||||
|
},
|
||||||
|
"session_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Filter to a specific session (format: channel:chat_id:dialog_id). Omit to search across all sessions."
|
||||||
|
},
|
||||||
|
"since": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Start of time range (Unix milliseconds)."
|
||||||
|
},
|
||||||
|
"until": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "End of time range (Unix milliseconds)."
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Max results to return (default 10)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let query = args
|
||||||
|
.get("query")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?;
|
||||||
|
|
||||||
|
let session_id = args.get("session_id").and_then(|v| v.as_str());
|
||||||
|
|
||||||
|
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
|
||||||
|
|
||||||
|
let entries = if args.get("since").is_some() || args.get("until").is_some() {
|
||||||
|
let since = args.get("since").and_then(|v| v.as_i64()).unwrap_or(0);
|
||||||
|
let until = args
|
||||||
|
.get("until")
|
||||||
|
.and_then(|v| v.as_i64())
|
||||||
|
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||||
|
self.memory
|
||||||
|
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id)
|
||||||
|
.await?
|
||||||
|
} else {
|
||||||
|
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
if entries.is_empty() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: "No matching timeline entries found.".to_string(),
|
||||||
|
error: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let formatted = entries
|
||||||
|
.iter()
|
||||||
|
.map(|e| {
|
||||||
|
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
||||||
|
format!(
|
||||||
|
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||||
|
e.key,
|
||||||
|
e.category.as_str(),
|
||||||
|
session,
|
||||||
|
e.importance,
|
||||||
|
e.content
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("Found {} timeline entries:\n{}", entries.len(), formatted),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── MemoryForgetTool ─────────────────────────────────────────────────
|
// ── MemoryForgetTool ─────────────────────────────────────────────────
|
||||||
|
|
||||||
pub struct MemoryForgetTool {
|
pub struct MemoryForgetTool {
|
||||||
|
|||||||
@ -22,7 +22,7 @@ pub use file_read::FileReadTool;
|
|||||||
pub use file_write::FileWriteTool;
|
pub use file_write::FileWriteTool;
|
||||||
pub use get_skill::GetSkillTool;
|
pub use get_skill::GetSkillTool;
|
||||||
pub use http_request::HttpRequestTool;
|
pub use http_request::HttpRequestTool;
|
||||||
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool};
|
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
|
||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
pub use send_message::SendMessageTool;
|
pub use send_message::SendMessageTool;
|
||||||
@ -57,6 +57,7 @@ pub fn create_default_tools(
|
|||||||
|
|
||||||
registry.register(MemoryStoreTool::new(memory.clone()));
|
registry.register(MemoryStoreTool::new(memory.clone()));
|
||||||
registry.register(MemoryRecallTool::new(memory.clone()));
|
registry.register(MemoryRecallTool::new(memory.clone()));
|
||||||
|
registry.register(TimelineRecallTool::new(memory.clone()));
|
||||||
registry.register(MemoryForgetTool::new(memory.clone()));
|
registry.register(MemoryForgetTool::new(memory.clone()));
|
||||||
|
|
||||||
registry
|
registry
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user