1046 lines
37 KiB
Rust
1046 lines
37 KiB
Rust
use std::sync::Arc;
|
|
|
|
use crate::bus::ChatMessage;
|
|
use crate::memory::MemoryManager;
|
|
use crate::providers::{ChatCompletionRequest, LLMProvider, Message};
|
|
|
|
use crate::agent::AgentError;
|
|
|
|
/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin.
|
|
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
|
let raw: usize = messages
|
|
.iter()
|
|
.map(|m| m.content.len().div_ceil(4) + 4)
|
|
.sum();
|
|
(raw as f64 * 1.2) as usize
|
|
}
|
|
|
|
/// Return the first run of ASCII digits within the first `max_len` bytes of
/// `s`, or `None` if that window contains no digit.
///
/// The window is shrunk to the nearest UTF-8 character boundary at or below
/// `max_len`, so a multi-byte character straddling the limit cannot cause a
/// slicing panic. A number crossing the window edge is cut at the edge.
///
/// Used by `parse_context_limit_from_error` to find token limits in error messages.
fn find_number_nearby(s: &str, max_len: usize) -> Option<&str> {
    // Slicing a `str` off a char boundary panics; index 0 is always a
    // boundary, so this loop terminates.
    let mut window_end = s.len().min(max_len);
    while !s.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let window = &s[..window_end];

    let start = window.find(|c: char| c.is_ascii_digit())?;
    let digits_len = window[start..]
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(window.len() - start);
    Some(&window[start..start + digits_len])
}
|
|
|
|
/// Tuning knobs for the context compressor.
///
/// Counts are message counts. The "chars" limits are compared against
/// `String::len()`, so they are measured in UTF-8 bytes.
#[derive(Debug, Clone)]
pub struct ContextCompressionConfig {
    /// Number of leading messages (e.g. the system prompt) kept at the start of the history.
    pub protect_first_n: usize,
    /// Number of trailing messages treated as recent context; fast trim only
    /// truncates (never replaces) their tool results, and hard truncation keeps them.
    pub protect_last_n: usize,
    /// Upper bound on LLM summarization passes per compression run.
    pub max_passes: u32,
    /// Summary length target; also caps the transcript handed to the summarizer.
    pub summary_max_chars: usize,
    /// Tool results longer than this are trimmed; recent ones keep this many leading bytes.
    pub tool_result_trim_chars: usize,
}

impl Default for ContextCompressionConfig {
    /// Keep the first message and the last four, allow three passes, and use
    /// 4000-byte summaries and 2000-byte tool results.
    fn default() -> Self {
        let (protect_first_n, protect_last_n) = (1, 4);
        ContextCompressionConfig {
            protect_first_n,
            protect_last_n,
            max_passes: 3,
            summary_max_chars: 4000,
            tool_result_trim_chars: 2000,
        }
    }
}
|
|
|
|
/// Context compressor that reduces message history when it exceeds token limits.
|
|
pub struct ContextCompressor {
|
|
config: ContextCompressionConfig,
|
|
context_window: usize,
|
|
/// Threshold ratio to trigger compression (70% of context window)
|
|
threshold_ratio: f64,
|
|
/// Shared LLM provider for summarization
|
|
provider: Arc<dyn LLMProvider>,
|
|
/// Memory manager handle. Compressed context summaries are persisted
|
|
/// as timeline memory entries.
|
|
memory: Arc<MemoryManager>,
|
|
/// Current session ID for timeline memory writes.
|
|
session_id: Option<String>,
|
|
/// Message count sent in the last LLM call (used to split known/new history).
|
|
last_sent_message_count: Option<usize>,
|
|
/// Real total_tokens from the last API response.
|
|
last_api_total_tokens: Option<u32>,
|
|
}
|
|
|
|
/// Result of context compression.
|
|
pub struct CompressionResult {
|
|
pub history: Vec<ChatMessage>,
|
|
pub created_timelines: bool,
|
|
}
|
|
|
|
/// Snapshot of the compressor's token budget, for diagnostics.
///
/// Derives `Debug` so the snapshot can be logged directly, and `Copy`/`Eq`
/// because it is a small plain-value record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenInfo {
    /// Model context window, in tokens.
    pub context_window: usize,
    /// Token count above which compression is triggered.
    pub threshold: usize,
    /// Estimated tokens of the inspected history (anchored on the cached API
    /// count when one is available).
    pub estimated_tokens: usize,
    /// `total_tokens` from the last recorded API response, if any.
    pub last_api_tokens: Option<u32>,
    /// Whether a cached API token count is currently in use.
    pub cache_active: bool,
}
|
|
|
|
impl ContextCompressor {
|
|
/// Create a new compressor with the given provider, context window size, and memory manager.
|
|
pub fn new(
|
|
provider: Arc<dyn LLMProvider>,
|
|
context_window: usize,
|
|
memory: Arc<MemoryManager>,
|
|
) -> Self {
|
|
Self {
|
|
config: ContextCompressionConfig::default(),
|
|
context_window,
|
|
threshold_ratio: 0.7,
|
|
provider,
|
|
memory,
|
|
session_id: None,
|
|
last_sent_message_count: None,
|
|
last_api_total_tokens: None,
|
|
}
|
|
}
|
|
|
|
/// Create with custom configuration.
|
|
pub fn with_config(
|
|
provider: Arc<dyn LLMProvider>,
|
|
context_window: usize,
|
|
config: ContextCompressionConfig,
|
|
memory: Arc<MemoryManager>,
|
|
) -> Self {
|
|
Self {
|
|
config,
|
|
context_window,
|
|
threshold_ratio: 0.7,
|
|
provider,
|
|
memory,
|
|
session_id: None,
|
|
last_sent_message_count: None,
|
|
last_api_total_tokens: None,
|
|
}
|
|
}
|
|
|
|
/// Set the current session ID for timeline writes.
|
|
pub fn set_session_id(&mut self, id: Option<String>) {
|
|
self.session_id = id;
|
|
}
|
|
|
|
/// Update the context window size (e.g., after parsing actual limit from LLM error).
|
|
pub fn set_context_window(&mut self, window: usize) {
|
|
self.context_window = window;
|
|
}
|
|
|
|
/// Record the API's reported token usage from the last completed turn.
|
|
/// `msg_count`: number of messages sent to LLM in that call.
|
|
/// `tokens`: `total_tokens` from the API response.
|
|
pub fn set_last_api_info(&mut self, msg_count: usize, tokens: Option<u32>) {
|
|
self.last_sent_message_count = Some(msg_count);
|
|
self.last_api_total_tokens = tokens;
|
|
}
|
|
|
|
/// Invalidate the cached API token info — called after compression modifies messages.
|
|
fn invalidate_token_cache(&mut self) {
|
|
self.last_sent_message_count = None;
|
|
self.last_api_total_tokens = None;
|
|
}
|
|
|
|
/// Hybrid token estimation: API-reported tokens for known history +
|
|
/// char/4 estimate for new messages since last API call.
|
|
fn token_estimate_with_history(&self, messages: &[ChatMessage]) -> usize {
|
|
match (self.last_api_total_tokens, self.last_sent_message_count) {
|
|
(Some(known), Some(known_count)) if messages.len() > known_count => {
|
|
let delta = &messages[known_count..];
|
|
known as usize + estimate_tokens(delta)
|
|
}
|
|
(Some(known), _) => known as usize,
|
|
_ => estimate_tokens(messages),
|
|
}
|
|
}
|
|
|
|
/// Always true — memory is always available (memory system is always on).
|
|
pub fn has_memory(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
/// Get a snapshot of the current token budget state for diagnostics.
|
|
pub fn token_info(&self, messages: &[ChatMessage]) -> TokenInfo {
|
|
TokenInfo {
|
|
context_window: self.context_window,
|
|
threshold: self.threshold(),
|
|
estimated_tokens: self.token_estimate_with_history(messages),
|
|
last_api_tokens: self.last_api_total_tokens,
|
|
cache_active: self.last_api_total_tokens.is_some(),
|
|
}
|
|
}
|
|
|
|
/// Get the compression threshold in tokens.
|
|
pub fn threshold(&self) -> usize {
|
|
(self.context_window as f64 * self.threshold_ratio) as usize
|
|
}
|
|
|
|
/// Fast-path: trim oversized tool results without LLM call.
|
|
/// Old tool results (outside of `protect_tail` zone) are replaced with a
|
|
/// concise placeholder; recent results are truncated to `tool_result_trim_chars`.
|
|
/// Returns the number of messages modified.
|
|
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage], protect_tail: usize) -> usize {
|
|
let limit = self.config.tool_result_trim_chars;
|
|
let tail_start = messages.len().saturating_sub(protect_tail);
|
|
let mut modified = 0;
|
|
|
|
for (i, msg) in messages.iter_mut().enumerate() {
|
|
if msg.role != "tool" || msg.content.len() <= limit {
|
|
continue;
|
|
}
|
|
if i < tail_start {
|
|
let tool_name = msg.tool_name.as_deref().unwrap_or("unknown");
|
|
let chars = msg.content.len();
|
|
msg.content = format!(
|
|
"[Tool output ({}) — {} chars, omitted from context]",
|
|
tool_name, chars
|
|
);
|
|
} else {
|
|
let removed = msg.content.len() - limit;
|
|
msg.content = format!(
|
|
"{}...\n\n[Output truncated - {} characters removed]",
|
|
&msg.content[..msg.content.ceil_char_boundary(limit)],
|
|
removed
|
|
);
|
|
}
|
|
modified += 1;
|
|
}
|
|
|
|
modified
|
|
}
|
|
|
|
/// Repair tool call chains after compression.
|
|
/// Phase 1: remove orphan tool results whose declaring tool_calls are missing.
|
|
/// Phase 2: strip tool_calls from assistants whose results are missing.
|
|
pub fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) {
|
|
let mut declared: std::collections::HashSet<String> = std::collections::HashSet::new();
|
|
let mut i = 0;
|
|
while i < messages.len() {
|
|
if messages[i].role == "assistant" {
|
|
if let Some(ref tool_calls) = messages[i].tool_calls {
|
|
for tc in tool_calls {
|
|
declared.insert(tc.id.clone());
|
|
}
|
|
}
|
|
} else if messages[i].role == "tool"
|
|
&& let Some(ref tid) = messages[i].tool_call_id
|
|
&& !declared.contains(tid.as_str())
|
|
{
|
|
messages.remove(i);
|
|
continue;
|
|
}
|
|
i += 1;
|
|
}
|
|
|
|
let broken: Vec<usize> = messages
|
|
.iter()
|
|
.enumerate()
|
|
.filter_map(|(idx, msg)| {
|
|
if msg.role == "assistant"
|
|
&& let Some(ref tcs) = msg.tool_calls
|
|
&& !tcs.is_empty()
|
|
{
|
|
let all_present = tcs.iter().all(|tc| {
|
|
messages.iter().any(|m| {
|
|
m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
|
|
})
|
|
});
|
|
if !all_present { Some(idx) } else { None }
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
for idx in broken {
|
|
let msg = &mut messages[idx];
|
|
let tcs = msg.tool_calls.take().unwrap_or_default();
|
|
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
|
msg.content = format!(
|
|
"{}\n\n[Tool calls ({}) — results are no longer available]",
|
|
msg.content,
|
|
names.join(", ")
|
|
);
|
|
}
|
|
}
|
|
|
|
/// Main entry point - compresses history if over threshold.
|
|
pub async fn compress_if_needed(
|
|
&mut self,
|
|
mut history: Vec<ChatMessage>,
|
|
) -> Result<CompressionResult, AgentError> {
|
|
// Check if compression is needed
|
|
let tokens = self.token_estimate_with_history(&history);
|
|
if tokens <= self.threshold() {
|
|
return Ok(CompressionResult {
|
|
history,
|
|
created_timelines: false,
|
|
});
|
|
}
|
|
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(
|
|
tokens = tokens,
|
|
threshold = self.threshold(),
|
|
msg_count = history.len(),
|
|
"Starting context compression"
|
|
);
|
|
|
|
// Fast trim pass first — modify history in place
|
|
let trimmed = self.fast_trim_tool_results(&mut history, self.config.protect_last_n);
|
|
let tokens_after = self.token_estimate_with_history(&history);
|
|
if trimmed > 0 {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(
|
|
trimmed_messages = trimmed,
|
|
tokens_after = tokens_after,
|
|
"Fast trim completed"
|
|
);
|
|
}
|
|
if tokens_after <= self.threshold() {
|
|
self.invalidate_token_cache();
|
|
return Ok(CompressionResult {
|
|
history,
|
|
created_timelines: false,
|
|
});
|
|
}
|
|
|
|
// LLM summarization pass
|
|
let mut current_history = history;
|
|
let mut created_timelines = false;
|
|
for pass in 0..self.config.max_passes {
|
|
let tokens = self.token_estimate_with_history(¤t_history);
|
|
if tokens <= self.threshold() {
|
|
break;
|
|
}
|
|
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
|
|
|
|
match self.compress_once(¤t_history).await {
|
|
Ok(Some(compressed)) => {
|
|
current_history = compressed;
|
|
created_timelines = true;
|
|
}
|
|
Ok(None) => {
|
|
// No more compressible content
|
|
break;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Compression pass failed, using current history");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Hard safety net: if still dangerously high after all passes,
|
|
// fall back to head+tail truncation so the LLM call doesn't overflow.
|
|
let final_tokens = self.token_estimate_with_history(¤t_history);
|
|
let danger_threshold = (self.context_window as f64 * 0.9) as usize;
|
|
if final_tokens > danger_threshold
|
|
&& current_history.len() > self.config.protect_first_n + self.config.protect_last_n
|
|
{
|
|
let mut tail_start = current_history.len() - self.config.protect_last_n;
|
|
|
|
// Align tail_start backwards to preserve tool chain boundaries:
|
|
// if an assistant with tool_calls has results spanning the cut,
|
|
// include the assistant in the tail.
|
|
if tail_start > 0 && tail_start < current_history.len() {
|
|
let mut scan = tail_start.saturating_sub(1);
|
|
loop {
|
|
let m = ¤t_history[scan];
|
|
if m.role == "assistant" {
|
|
if let Some(tcs) = &m.tool_calls
|
|
&& !tcs.is_empty()
|
|
{
|
|
let has_post = current_history[scan + 1..]
|
|
.iter()
|
|
.filter(|r| r.role == "tool")
|
|
.any(|r| {
|
|
tcs.iter()
|
|
.any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
|
|
});
|
|
if has_post {
|
|
tail_start = scan;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
if scan == 0 {
|
|
break;
|
|
}
|
|
scan -= 1;
|
|
}
|
|
}
|
|
|
|
// Skip orphan tool messages at the new head-tail boundary
|
|
while tail_start < current_history.len() && current_history[tail_start].role == "tool" {
|
|
tail_start += 1;
|
|
}
|
|
|
|
let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec();
|
|
let tail: Vec<_> = current_history[tail_start..].to_vec();
|
|
let dropped = current_history.len() - self.config.protect_first_n - tail.len();
|
|
|
|
let mut truncated = head;
|
|
truncated.push(ChatMessage::user(format!(
|
|
"[Context truncation — {} earlier messages dropped due to token limit]\n\
|
|
Previous context could not be fully compressed. Continuing with most recent context.",
|
|
dropped
|
|
)));
|
|
truncated.extend(tail);
|
|
|
|
// Strip tool_calls from any assistant in the head whose results
|
|
// were dropped (previously in the middle section).
|
|
for msg in &mut truncated[..self.config.protect_first_n] {
|
|
if msg.role == "assistant" {
|
|
if let Some(ref tcs) = msg.tool_calls
|
|
&& !tcs.is_empty()
|
|
{
|
|
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
|
msg.content = format!(
|
|
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
|
|
msg.content,
|
|
names.join(", ")
|
|
);
|
|
msg.tool_calls = None;
|
|
}
|
|
}
|
|
}
|
|
|
|
Self::repair_tool_pairs(&mut truncated);
|
|
|
|
tracing::warn!(
|
|
final_tokens = final_tokens,
|
|
danger = danger_threshold,
|
|
dropped_msgs = dropped,
|
|
"Hard truncation fallback applied"
|
|
);
|
|
|
|
current_history = truncated;
|
|
}
|
|
|
|
if created_timelines {
|
|
self.invalidate_token_cache();
|
|
}
|
|
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(
|
|
final_tokens = self.token_estimate_with_history(¤t_history),
|
|
final_msg_count = current_history.len(),
|
|
"Context compression completed"
|
|
);
|
|
|
|
Ok(CompressionResult {
|
|
history: current_history,
|
|
created_timelines,
|
|
})
|
|
}
|
|
|
|
/// Try to extract the actual context token limit from an LLM error message.
|
|
/// Recognizes patterns from OpenAI, Anthropic, and llama.cpp-style errors.
|
|
pub fn parse_context_limit_from_error(msg: &str) -> Option<usize> {
|
|
let lower = msg.to_lowercase();
|
|
|
|
// Common patterns: "maximum context length is 128000", "context window of 131072",
|
|
// "128000 token context", "available context size (8448 tokens)", "> 128000 maximum"
|
|
let markers = [
|
|
"maximum context length",
|
|
"context window",
|
|
"context length",
|
|
"available context size",
|
|
];
|
|
|
|
for marker in &markers {
|
|
if let Some(pos) = lower.find(marker) {
|
|
let after = &lower[pos + marker.len()..];
|
|
// Look for a number in the vicinity (up to 10 chars after marker)
|
|
if let Some(num_str) = find_number_nearby(after, 50)
|
|
&& let Ok(n) = num_str.parse::<usize>()
|
|
&& (1024..=10_000_000).contains(&n)
|
|
{
|
|
return Some(n);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Also try: "XXXX token context" or "XXXX limit"
|
|
if let Some(num_str) = find_number_nearby(&lower, lower.len())
|
|
&& let Ok(n) = num_str.parse::<usize>()
|
|
&& (1024..=10_000_000).contains(&n)
|
|
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
|
|
{
|
|
return Some(n);
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
/// Single compression pass - summarize middle messages between user turns.
|
|
/// Returns Some(compressed) if compression happened, None if nothing to compress.
|
|
async fn compress_once(
|
|
&self,
|
|
history: &[ChatMessage],
|
|
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
|
|
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
|
|
return Ok(None);
|
|
}
|
|
|
|
// Find user message indices (excluding protected first messages)
|
|
let user_indices: Vec<usize> = history
|
|
.iter()
|
|
.enumerate()
|
|
.skip(self.config.protect_first_n)
|
|
.filter(|(_, m)| m.role == "user")
|
|
.map(|(i, _)| i)
|
|
.collect();
|
|
|
|
// Need at least one user message and content between users to compress
|
|
if user_indices.len() < 2 {
|
|
return Ok(None);
|
|
}
|
|
|
|
// Build segments: user -> (assistant turns) -> next user
|
|
// We'll summarize the assistant turns between consecutive user messages
|
|
let mut new_messages = history[..user_indices[0]].to_vec();
|
|
|
|
for i in 0..user_indices.len() - 1 {
|
|
let user_idx = user_indices[i];
|
|
let next_user_idx = user_indices[i + 1];
|
|
|
|
new_messages.push(history[user_idx].clone());
|
|
|
|
// Check if there's assistant content between these two user messages
|
|
let between_start = user_idx + 1;
|
|
let between_end = next_user_idx;
|
|
|
|
if between_start < between_end {
|
|
let between = &history[between_start..between_end];
|
|
let summary = self.summarize_segment(between).await?;
|
|
|
|
// Persist compressed summary as timeline memory entry
|
|
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
|
let timeline_content = format!(
|
|
"[{}] Compressed {} conversation segments:\n{}",
|
|
ts,
|
|
between.len(),
|
|
summary
|
|
);
|
|
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
|
let mm = self.memory.clone();
|
|
let sid = self.session_id.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = mm
|
|
.store(
|
|
&key,
|
|
&timeline_content,
|
|
crate::memory::MemoryCategory::Timeline,
|
|
sid.as_deref(),
|
|
Some(0.3),
|
|
)
|
|
.await
|
|
{
|
|
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
|
}
|
|
});
|
|
|
|
// Add summary as a special user message
|
|
new_messages.push(ChatMessage::user(format!(
|
|
"[Context Summary]\n\n{}",
|
|
summary
|
|
)));
|
|
}
|
|
}
|
|
|
|
// Add last user and everything after (protected)
|
|
let last_user_idx = user_indices[user_indices.len() - 1];
|
|
for i in last_user_idx..history.len() {
|
|
new_messages.push(history[i].clone());
|
|
}
|
|
|
|
// Remove orphan tool results whose declaring tool_calls were compressed away
|
|
Self::repair_tool_pairs(&mut new_messages);
|
|
|
|
// If nothing changed, return None
|
|
if new_messages.len() == history.len() {
|
|
return Ok(None);
|
|
}
|
|
|
|
Ok(Some(new_messages))
|
|
}
|
|
|
|
/// Summarize a segment of messages using LLM.
|
|
async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
|
|
if messages.is_empty() {
|
|
return Ok(String::new());
|
|
}
|
|
|
|
// Build transcript for summarization
|
|
let transcript = messages
|
|
.iter()
|
|
.map(|m| {
|
|
let role = match m.role.as_str() {
|
|
"assistant" => "Assistant",
|
|
"tool" => "Tool",
|
|
_ => m.role.as_str(),
|
|
};
|
|
let name = m
|
|
.tool_name
|
|
.as_ref()
|
|
.map(|n| format!(" ({})", n))
|
|
.unwrap_or_default();
|
|
format!("{}: {}{}", role, m.content, name)
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join("\n\n");
|
|
|
|
// Truncate transcript if too long
|
|
let transcript = if transcript.len() > self.config.summary_max_chars {
|
|
format!(
|
|
"{}...\n\n[Transcript truncated - {} characters removed]",
|
|
&transcript[..transcript.ceil_char_boundary(self.config.summary_max_chars)],
|
|
transcript.len() - self.config.summary_max_chars
|
|
)
|
|
} else {
|
|
transcript
|
|
};
|
|
|
|
let prompt = format!(
|
|
r#"You are a conversation compaction engine. Summarize the following conversation segment.
|
|
|
|
PRESERVE:
|
|
- All identifiers (UUIDs, hashes, file paths, URLs)
|
|
- Actions taken (tool calls, file operations, commands)
|
|
- Key information obtained (results, data, errors)
|
|
- Decisions and user preferences
|
|
- Current task status
|
|
|
|
OMIT:
|
|
- Verbose tool output (keep key results only)
|
|
- Repeated greetings or filler
|
|
|
|
Be concise, aim for {} characters or less.
|
|
|
|
---
|
|
|
|
{}
|
|
|
|
"#,
|
|
self.config.summary_max_chars, transcript
|
|
);
|
|
|
|
let request = ChatCompletionRequest {
|
|
messages: vec![
|
|
Message::system("You are a helpful assistant."),
|
|
Message::user(&prompt),
|
|
],
|
|
temperature: Some(0.3),
|
|
max_tokens: Some(1000),
|
|
tools: None,
|
|
};
|
|
|
|
match (*self.provider).chat(request).await {
|
|
Ok(response) => Ok(response.content),
|
|
Err(e) => {
|
|
// Fallback: just truncate the transcript
|
|
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
|
|
Ok(transcript[..transcript.ceil_char_boundary(2000)].to_string())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryManager;
    use crate::providers::ChatCompletionResponse;
    use crate::providers::Usage;
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::sync::OnceLock;

    /// Provider for tests that must never summarize: `chat()` panics, so any
    /// unexpected LLM call fails the test loudly.
    struct MockProvider;

    #[async_trait]
    impl LLMProvider for MockProvider {
        async fn chat(
            &self,
            _request: ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
            panic!("MockProvider.chat() called - not expected in test")
        }

        fn ptype(&self) -> &str {
            "mock"
        }

        fn name(&self) -> &str {
            "mock"
        }

        fn model_id(&self) -> &str {
            "mock"
        }
    }

    fn mock_provider() -> Arc<dyn LLMProvider> {
        Arc::new(MockProvider)
    }

    /// Provider whose `chat()` always answers "[summarized]" with zero usage.
    /// Used where compression is expected to reach `summarize_segment`.
    struct MockSummarizer;

    #[async_trait]
    impl LLMProvider for MockSummarizer {
        async fn chat(
            &self,
            _request: ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
            Ok(ChatCompletionResponse {
                id: "mock".into(),
                model: "mock".into(),
                content: "[summarized]".into(),
                reasoning_content: None,
                tool_calls: vec![],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                    cached_tokens: None,
                    cache_read_input_tokens: None,
                    cache_creation_input_tokens: None,
                },
            })
        }

        fn ptype(&self) -> &str {
            "mock"
        }
        fn name(&self) -> &str {
            "mock"
        }
        fn model_id(&self) -> &str {
            "mock"
        }
    }

    fn mock_summarizer() -> Arc<dyn LLMProvider> {
        Arc::new(MockSummarizer)
    }

    /// Process-wide `MemoryManager` for synchronous tests.
    ///
    /// It builds a throwaway Tokio runtime to open the storage at a temp `.db`
    /// path. Only call it from non-async tests: Tokio's `Runtime::block_on`
    /// panics inside an async context. The temp file is not removed.
    fn test_memory_manager() -> Arc<MemoryManager> {
        static MM: OnceLock<Arc<MemoryManager>> = OnceLock::new();
        MM.get_or_init(|| {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let tmp = std::env::temp_dir()
                    .join(format!("picobot_ctx_test_{}.db", std::process::id()));
                let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
                Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
            })
        })
        .clone()
    }

    #[test]
    fn test_estimate_tokens() {
        let messages = vec![
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
            ChatMessage::user("How are you?"),
        ];

        let tokens = estimate_tokens(&messages);
        // "Hello" (5 bytes)         -> ceil(5/4)+4  = 2+4 = 6
        // "Hi there!" (9 bytes)     -> ceil(9/4)+4  = 3+4 = 7
        // "How are you?" (12 bytes) -> ceil(12/4)+4 = 3+4 = 7
        // raw = 20, with the 1.2x margin = 24
        assert!(
            tokens > 18 && tokens < 30,
            "Expected ~24 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn test_fast_trim() {
        let config = ContextCompressionConfig {
            tool_result_trim_chars: 50,
            ..Default::default()
        };
        let compressor =
            ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());

        let mut messages = vec![
            ChatMessage::user("Hello"),
            ChatMessage::tool("call1", "bash", &"x".repeat(200)),
        ];

        // protect_tail=2 puts the tool result in the tail: it is cut to 50
        // bytes plus a short truncation marker.
        let modified = compressor.fast_trim_tool_results(&mut messages, 2);
        assert_eq!(modified, 1);
        assert!(messages[1].content.len() < 100);
    }

    #[test]
    fn test_threshold() {
        let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager());
        assert_eq!(compressor.threshold(), 89_600);
    }

    #[tokio::test]
    async fn test_compress_if_needed_fast_trims_tool_results() {
        // context_window=200 -> threshold=140 (0.7 ratio).
        // user "Hi" (5 raw) + tool of 3000 x's (754 raw) -> ~910 > 140 -> compression.
        // Fast trim to 50 bytes should bring the estimate well under the threshold.
        let tmp = std::env::temp_dir().join(format!("picobot_ctx_trim_{}.db", std::process::id()));
        let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
        let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));

        let config = ContextCompressionConfig {
            tool_result_trim_chars: 50,
            protect_first_n: 0,
            protect_last_n: 10,
            max_passes: 0,
            ..Default::default()
        };
        let mut compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);

        let messages = vec![
            ChatMessage::user("Hi"),
            ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
        ];

        let result = compressor
            .compress_if_needed(messages)
            .await
            .unwrap()
            .history;

        let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
        assert!(
            tool_msg.content.len() < 3000,
            "tool result should be trimmed, got {} chars",
            tool_msg.content.len()
        );
        assert!(
            tool_msg.content.contains("[Output truncated"),
            "trim marker missing from: {}",
            tool_msg.content
        );

        let _ = std::fs::remove_file(&tmp);
    }

    #[tokio::test]
    async fn test_compress_once_no_duplicate_and_no_lost_user() {
        // Regression test for two boundary bugs in compress_once:
        // - the first user message after the protected head was emitted twice;
        // - a final user message at the very end of the history was dropped.
        //
        // context_window=200 -> threshold=140. tool_result_trim_chars is set
        // above the tool sizes so fast trim is a no-op and compression must
        // go through compress_once (LLM summarization via MockSummarizer).
        let tmp =
            std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
        let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
        let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));

        let config = ContextCompressionConfig {
            tool_result_trim_chars: 10_000, // > 3000: keep fast trim out of the way
            protect_first_n: 1,             // system prompt is protected
            protect_last_n: 2,
            max_passes: 1,
            ..Default::default()
        };
        let mut compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);

        // 9 messages; user indices after the protected head = [1, 3, 6, 8],
        // and the last message is a user message.
        let big = "x".repeat(3000);
        let messages = vec![
            ChatMessage::system("You are a helper."), // 0: protected
            ChatMessage::user("Q1"),                  // 1: first user after head
            ChatMessage::tool("t1", "bash", &big),    // 2
            ChatMessage::user("Q2"),                  // 3
            ChatMessage::assistant("thinking"),       // 4
            ChatMessage::tool("t2", "bash", &big),    // 5
            ChatMessage::user("Q3"),                  // 6
            ChatMessage::assistant("thinking"),       // 7
            ChatMessage::user("Q4"),                  // 8: last message, a user
        ];

        let result = compressor
            .compress_if_needed(messages)
            .await
            .unwrap()
            .history;

        // "Q1" must appear exactly once (not duplicated)
        let q1_count = result
            .iter()
            .filter(|m| m.role == "user" && m.content == "Q1")
            .count();
        assert_eq!(
            q1_count, 1,
            "Q1 should appear exactly once, got {}",
            q1_count
        );

        // "Q4" must NOT be lost
        let q4_count = result
            .iter()
            .filter(|m| m.role == "user" && m.content == "Q4")
            .count();
        assert_eq!(
            q4_count, 1,
            "Q4 should appear exactly once (not lost), got {}",
            q4_count
        );

        let _ = std::fs::remove_file(&tmp);
    }

    #[tokio::test]
    async fn test_compress_hard_truncation_fallback() {
        // When LLM compression fails (or max_passes=0) and tokens are still
        // above 90% of context_window, a head+tail truncation kicks in.
        let tmp = std::env::temp_dir().join(format!("picobot_ctx_trunc_{}.db", std::process::id()));
        let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
        let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));

        let config = ContextCompressionConfig {
            tool_result_trim_chars: 500, // trim reduces but not enough
            protect_first_n: 1,
            protect_last_n: 2,
            max_passes: 0, // no LLM summarization -> will exceed danger
            ..Default::default()
        };
        // context_window=100, danger_threshold=90.
        // Fast trim replaces t1/t2 (outside the 2-message tail) by ~57-byte
        // placeholders (19 raw each) and cuts t3 to 500 bytes + marker
        // (~142 raw). With users/system (~5 raw each): ~200 raw * 1.2 ~= 240 > 90.
        let mut compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);

        let big = "x".repeat(3000);
        let messages = vec![
            ChatMessage::system("sys"),
            ChatMessage::user("Q1"),
            ChatMessage::tool("t1", "bash", &big),
            ChatMessage::user("Q2"),
            ChatMessage::tool("t2", "bash", &big),
            ChatMessage::user("Q3"),
            ChatMessage::tool("t3", "bash", &big),
        ];

        let result = compressor
            .compress_if_needed(messages)
            .await
            .unwrap()
            .history;

        // After hard truncation: head (1) + notice (1) + tail (Q3, t3) = at most 4.
        // NOTE(review): t3 may additionally be removed by repair_tool_pairs as
        // an orphan if ChatMessage::tool sets tool_call_id.
        assert!(
            result.len() < 7,
            "expected truncation reduction, got {} messages",
            result.len()
        );

        // Truncation notice should be present
        let has_notice = result
            .iter()
            .any(|m| m.content.contains("Context truncation"));
        assert!(has_notice, "hard truncation notice missing");

        let _ = std::fs::remove_file(&tmp);
    }

    #[test]
    fn test_repair_tool_pairs_removes_orphans() {
        use crate::providers::ToolCall;

        // Simulate compressed output: a summary replaced assistant(tool_call: tc1),
        // leaving tool(tc1) as an orphan. The legitimate tool(tc2) must be kept.
        let mut messages = vec![
            ChatMessage::user("Q1"),
            ChatMessage::user("[Context Summary]\n\nsummary of previous turn"),
            ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
            ChatMessage::assistant("done"),                    // declares tc2
            ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
        ];
        // Set tool_call_id on tool messages and tool_calls on assistant
        messages[2].tool_call_id = Some("tc1".into());
        messages[4].tool_call_id = Some("tc2".into());
        messages[3].tool_calls = Some(vec![ToolCall {
            id: "tc2".into(),
            name: "bash".into(),
            arguments: serde_json::json!({"cmd": "echo ok"}),
        }]);

        ContextCompressor::repair_tool_pairs(&mut messages);

        // orphan should be removed; legitimate should stay
        assert_eq!(messages.len(), 4);
        assert!(
            messages
                .iter()
                .all(|m| m.tool_call_id != Some("tc1".into()))
        );
        assert!(
            messages
                .iter()
                .any(|m| m.tool_call_id == Some("tc2".into()))
        );
    }

    #[test]
    fn test_parse_context_limit_from_error() {
        // OpenAI: "maximum context length is 128000"
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error(
                "This model's maximum context length is 128000 tokens."
            ),
            Some(128000)
        );

        // Anthropic: "context window of 200000"
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error(
                "Your request exceeds the context window of 200000."
            ),
            Some(200000)
        );

        // llama.cpp: "available context size (8448 tokens)"
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error(
                "context size exceeded, available context size (8448 tokens)"
            ),
            Some(8448)
        );

        // Non-context error should return None
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error("Internal server error"),
            None
        );

        // Numbers too small should be rejected
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error("context length is 500"),
            None
        );
    }
}
|