use std::sync::Arc; use crate::bus::ChatMessage; use crate::memory::MemoryManager; use crate::providers::{ChatCompletionRequest, LLMProvider, Message}; use crate::agent::AgentError; /// Token estimation using ~4 chars/token heuristic with 1.2x safety margin. pub fn estimate_tokens(messages: &[ChatMessage]) -> usize { let raw: usize = messages .iter() .map(|m| m.content.len().div_ceil(4) + 4) .sum(); (raw as f64 * 1.2) as usize } /// Extract the first number found within `max_len` characters of the start of `s`. /// Used by `parse_context_limit_from_error` to find token limits in error messages. fn find_number_nearby(s: &str, max_len: usize) -> Option<&str> { let end = s.len().min(max_len); let slice = &s[..end]; let start = slice.find(|c: char| c.is_ascii_digit())?; let end = slice[start..] .find(|c: char| !c.is_ascii_digit()) .map(|p| start + p) .unwrap_or(end); Some(&slice[start..end]) } /// Configuration for context compression. #[derive(Debug, Clone)] pub struct ContextCompressionConfig { /// Protect first N messages (system prompt, etc.) pub protect_first_n: usize, /// Protect last N messages (recent context) pub protect_last_n: usize, /// Maximum compression passes pub max_passes: u32, /// Maximum characters in summary pub summary_max_chars: usize, /// Characters to keep when trimming tool results pub tool_result_trim_chars: usize, } impl Default for ContextCompressionConfig { fn default() -> Self { Self { protect_first_n: 1, protect_last_n: 4, max_passes: 3, summary_max_chars: 4000, tool_result_trim_chars: 2000, } } } /// Context compressor that reduces message history when it exceeds token limits. pub struct ContextCompressor { config: ContextCompressionConfig, context_window: usize, /// Threshold ratio to trigger compression (50% of context window) threshold_ratio: f64, /// Shared LLM provider for summarization provider: Arc, /// Memory manager handle. Compressed context summaries are persisted /// as timeline memory entries. 
memory: Arc, /// Current session ID for timeline memory writes. session_id: Option, } impl ContextCompressor { /// Create a new compressor with the given provider, context window size, and memory manager. pub fn new( provider: Arc, context_window: usize, memory: Arc, ) -> Self { Self { config: ContextCompressionConfig::default(), context_window, threshold_ratio: 0.5, provider, memory, session_id: None, } } /// Create with custom configuration. pub fn with_config( provider: Arc, context_window: usize, config: ContextCompressionConfig, memory: Arc, ) -> Self { Self { config, context_window, threshold_ratio: 0.5, provider, memory, session_id: None, } } /// Set the current session ID for timeline writes. pub fn set_session_id(&mut self, id: Option) { self.session_id = id; } /// Update the context window size (e.g., after parsing actual limit from LLM error). pub fn set_context_window(&mut self, window: usize) { self.context_window = window; } /// Always true — memory is always available (memory system is always on). pub fn has_memory(&self) -> bool { true } /// Get the compression threshold in tokens. pub fn threshold(&self) -> usize { (self.context_window as f64 * self.threshold_ratio) as usize } /// Fast-path: trim oversized tool results without LLM call. /// Returns the number of messages modified. fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize { let limit = self.config.tool_result_trim_chars; let mut modified = 0; for msg in messages.iter_mut() { if msg.role == "tool" && msg.content.len() > limit { let removed = msg.content.len() - limit; msg.content = format!( "{}...\n\n[Output truncated - {} characters removed]", &msg.content[..msg.content.ceil_char_boundary(limit)], removed ); modified += 1; } } modified } /// Remove orphan tool results whose declaring tool_calls have been compressed away. /// Scans for tool messages with no preceding assistant tool_call, and removes them. 
pub fn repair_tool_pairs(messages: &mut Vec) { let mut declared: std::collections::HashSet = std::collections::HashSet::new(); let mut i = 0; while i < messages.len() { if messages[i].role == "assistant" { if let Some(ref tool_calls) = messages[i].tool_calls { for tc in tool_calls { declared.insert(tc.id.clone()); } } } else if messages[i].role == "tool" { if let Some(ref tid) = messages[i].tool_call_id { if !declared.contains(tid.as_str()) { messages.remove(i); continue; } } } i += 1; } } /// Main entry point - compresses history if over threshold. pub async fn compress_if_needed( &self, mut history: Vec, ) -> Result, AgentError> { // Check if compression is needed let tokens = estimate_tokens(&history); if tokens <= self.threshold() { return Ok(history); } #[cfg(debug_assertions)] tracing::debug!( tokens = tokens, threshold = self.threshold(), msg_count = history.len(), "Starting context compression" ); // Fast trim pass first — modify history in place let trimmed = self.fast_trim_tool_results(&mut history); let tokens_after = estimate_tokens(&history); if trimmed > 0 { #[cfg(debug_assertions)] tracing::debug!( trimmed_messages = trimmed, tokens_after = tokens_after, "Fast trim completed" ); } if tokens_after <= self.threshold() { return Ok(history); } // LLM summarization pass let mut current_history = history; for pass in 0..self.config.max_passes { let tokens = estimate_tokens(¤t_history); if tokens <= self.threshold() { break; } #[cfg(debug_assertions)] tracing::debug!( pass = pass + 1, tokens = tokens, "Compression pass" ); match self.compress_once(¤t_history).await { Ok(Some(compressed)) => { current_history = compressed; } Ok(None) => { // No more compressible content break; } Err(e) => { tracing::warn!(error = %e, "Compression pass failed, using current history"); break; } } } // Hard safety net: if still dangerously high after all passes, // fall back to head+tail truncation so the LLM call doesn't overflow. 
let final_tokens = estimate_tokens(¤t_history); let danger_threshold = (self.context_window as f64 * 0.9) as usize; if final_tokens > danger_threshold && current_history.len() > self.config.protect_first_n + self.config.protect_last_n { let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec(); let tail_start = current_history.len() - self.config.protect_last_n; let tail: Vec<_> = current_history[tail_start..].to_vec(); let dropped = current_history.len() - self.config.protect_first_n - self.config.protect_last_n; let mut truncated = head; truncated.push(ChatMessage::user(format!( "[Context truncation — {} earlier messages dropped due to token limit]\n\ Previous context could not be fully compressed. Continuing with most recent context.", dropped ))); truncated.extend(tail); tracing::warn!( final_tokens = final_tokens, danger = danger_threshold, dropped_msgs = dropped, "Hard truncation fallback applied" ); current_history = truncated; } #[cfg(debug_assertions)] tracing::debug!( final_tokens = estimate_tokens(¤t_history), final_msg_count = current_history.len(), "Context compression completed" ); Ok(current_history) } /// Try to extract the actual context token limit from an LLM error message. /// Recognizes patterns from OpenAI, Anthropic, and llama.cpp-style errors. 
pub fn parse_context_limit_from_error(msg: &str) -> Option { let lower = msg.to_lowercase(); // Common patterns: "maximum context length is 128000", "context window of 131072", // "128000 token context", "available context size (8448 tokens)", "> 128000 maximum" let markers = [ "maximum context length", "context window", "context length", "available context size", ]; for marker in &markers { if let Some(pos) = lower.find(marker) { let after = &lower[pos + marker.len()..]; // Look for a number in the vicinity (up to 10 chars after marker) if let Some(num_str) = find_number_nearby(after, 50) { if let Ok(n) = num_str.parse::() { if (1024..=10_000_000).contains(&n) { return Some(n); } } } } } // Also try: "XXXX token context" or "XXXX limit" if let Some(num_str) = find_number_nearby(&lower, lower.len()) { if let Ok(n) = num_str.parse::() { if (1024..=10_000_000).contains(&n) && (lower.contains("token") || lower.contains("context") || lower.contains("limit")) { return Some(n); } } } None } /// Single compression pass - summarize middle messages between user turns. /// Returns Some(compressed) if compression happened, None if nothing to compress. 
async fn compress_once( &self, history: &[ChatMessage], ) -> Result>, AgentError> { if history.len() <= self.config.protect_first_n + self.config.protect_last_n { return Ok(None); } // Find user message indices (excluding protected first messages) let user_indices: Vec = history .iter() .enumerate() .skip(self.config.protect_first_n) .filter(|(_, m)| m.role == "user") .map(|(i, _)| i) .collect(); // Need at least one user message and content between users to compress if user_indices.len() < 2 { return Ok(None); } // Build segments: user -> (assistant turns) -> next user // We'll summarize the assistant turns between consecutive user messages let mut new_messages = history[..user_indices[0]].to_vec(); for i in 0..user_indices.len() - 1 { let user_idx = user_indices[i]; let next_user_idx = user_indices[i + 1]; new_messages.push(history[user_idx].clone()); // Check if there's assistant content between these two user messages let between_start = user_idx + 1; let between_end = next_user_idx; if between_start < between_end { let between = &history[between_start..between_end]; let summary = self.summarize_segment(between).await?; // Persist compressed summary as timeline memory entry let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", ts, between.len(), summary); let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); let mm = self.memory.clone(); let sid = self.session_id.clone(); tokio::spawn(async move { if let Err(e) = mm.store( &key, &timeline_content, crate::memory::MemoryCategory::Timeline, sid.as_deref(), Some(0.3), ).await { tracing::warn!(error = %e, "Failed to store compressed context as timeline"); } }); // Add summary as a special user message new_messages.push(ChatMessage::user(format!( "[Context Summary]\n\n{}", summary ))); } } // Add last user and everything after (protected) let last_user_idx = user_indices[user_indices.len() - 1]; for i in 
last_user_idx..history.len() { new_messages.push(history[i].clone()); } // Remove orphan tool results whose declaring tool_calls were compressed away Self::repair_tool_pairs(&mut new_messages); // If nothing changed, return None if new_messages.len() == history.len() { return Ok(None); } Ok(Some(new_messages)) } /// Summarize a segment of messages using LLM. async fn summarize_segment( &self, messages: &[ChatMessage], ) -> Result { if messages.is_empty() { return Ok(String::new()); } // Build transcript for summarization let transcript = messages .iter() .map(|m| { let role = match m.role.as_str() { "assistant" => "Assistant", "tool" => "Tool", _ => m.role.as_str(), }; let name = m.tool_name .as_ref() .map(|n| format!(" ({})", n)) .unwrap_or_default(); format!("{}: {}{}", role, m.content, name) }) .collect::>() .join("\n\n"); // Truncate transcript if too long let transcript = if transcript.len() > self.config.summary_max_chars { format!( "{}...\n\n[Transcript truncated - {} characters removed]", &transcript[..transcript.ceil_char_boundary(self.config.summary_max_chars)], transcript.len() - self.config.summary_max_chars ) } else { transcript }; let prompt = format!( r#"You are a conversation compaction engine. Summarize the following conversation segment. PRESERVE: - All identifiers (UUIDs, hashes, file paths, URLs) - Actions taken (tool calls, file operations, commands) - Key information obtained (results, data, errors) - Decisions and user preferences - Current task status OMIT: - Verbose tool output (keep key results only) - Repeated greetings or filler Be concise, aim for {} characters or less. 
--- {} "#, self.config.summary_max_chars, transcript ); let request = ChatCompletionRequest { messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], temperature: Some(0.3), max_tokens: Some(1000), tools: None, }; match (*self.provider).chat(request).await { Ok(response) => Ok(response.content), Err(e) => { // Fallback: just truncate the transcript tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript"); Ok(transcript[..transcript.ceil_char_boundary(2000)].to_string()) } } } } #[cfg(test)] mod tests { use super::*; use crate::memory::MemoryManager; use crate::providers::ChatCompletionResponse; use crate::providers::Usage; use async_trait::async_trait; use std::sync::Arc; use std::sync::OnceLock; /// Mock provider for testing - panics if actually used for LLM calls struct MockProvider; #[async_trait] impl LLMProvider for MockProvider { async fn chat( &self, _request: ChatCompletionRequest, ) -> Result> { panic!("MockProvider.chat() called - not expected in test") } fn ptype(&self) -> &str { "mock" } fn name(&self) -> &str { "mock" } fn model_id(&self) -> &str { "mock" } } fn mock_provider() -> Arc { Arc::new(MockProvider) } /// Mock summarizer that returns a simple summary — used when compress_once /// needs to call the LLM for summarization. 
    struct MockSummarizer;

    #[async_trait]
    impl LLMProvider for MockSummarizer {
        // Every chat call succeeds with the fixed content "[summarized]" and zero usage.
        async fn chat(
            &self,
            _request: ChatCompletionRequest,
        ) -> Result> {
            Ok(ChatCompletionResponse {
                id: "mock".into(),
                model: "mock".into(),
                content: "[summarized]".into(),
                tool_calls: vec![],
                usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
            })
        }

        fn ptype(&self) -> &str {
            "mock"
        }

        fn name(&self) -> &str {
            "mock"
        }

        fn model_id(&self) -> &str {
            "mock"
        }
    }

    /// Provider handle backed by `MockSummarizer`.
    fn mock_summarizer() -> Arc {
        Arc::new(MockSummarizer)
    }

    /// Lazily create one `MemoryManager` shared by the synchronous tests.
    ///
    /// The backing DB lives in the temp dir, is named by process id, and is not removed.
    fn test_memory_manager() -> Arc {
        // NOTE(review): `rt` is dropped when the init closure returns. If Storage relies on
        // tasks spawned on that runtime, later use could fail. The sync tests here never
        // touch memory. Calling this from inside an async runtime would presumably panic
        // in `block_on` — confirm before using it from `#[tokio::test]`.
        static MM: OnceLock> = OnceLock::new();
        MM.get_or_init(|| {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
                let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
                Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
            })
        }).clone()
    }

    #[test]
    fn test_estimate_tokens() {
        let messages = vec![
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
            ChatMessage::user("How are you?"),
        ];
        let tokens = estimate_tokens(&messages);
        // "Hello" (5) -> ceil(5/4)+4 = 2+4 = 6
        // "Hi there!" (9) -> ceil(9/4)+4 = 3+4 = 7
        // "How are you?" (12) -> ceil(12/4)+4 = 3+4 = 7
        // raw = 20, with 1.2x = 24
        assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
    }

    #[test]
    fn test_fast_trim() {
        // A 200-byte tool result with a 50-byte limit must be cut down (plus marker).
        let config = ContextCompressionConfig {
            tool_result_trim_chars: 50,
            ..Default::default()
        };
        let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
        let mut messages = vec![
            ChatMessage::user("Hello"),
            ChatMessage::tool("call1", "bash", &"x".repeat(200)),
        ];
        let modified = compressor.fast_trim_tool_results(&mut messages);
        assert_eq!(modified, 1);
        assert!(messages[1].content.len() < 100);
    }

    #[test]
    fn test_threshold() {
        // Default ratio is 0.5: 128_000 * 0.5 = 64_000.
        let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager());
        assert_eq!(compressor.threshold(), 64_000);
    }

    #[tokio::test]
    async fn test_compress_if_needed_fast_trims_tool_results() {
        // context_window=200 → threshold=100.
        // user "Hi" (5 raw), tool(3000 x's) → ~760 raw*1.2=912 > 100 → triggers compression.
        // fast_trim to 50 chars should bring tokens well under 100.
        let tmp = std::env::temp_dir().join(format!("picobot_ctx_trim_{}.db", std::process::id()));
        let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
        let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
        let config = ContextCompressionConfig {
            tool_result_trim_chars: 50,
            protect_first_n: 0,
            protect_last_n: 10,
            max_passes: 0,
            ..Default::default()
        };
        let compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
        let messages = vec![
            ChatMessage::user("Hi"),
            ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
        ];
        let result = compressor.compress_if_needed(messages).await.unwrap();
        let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
        assert!(
            tool_msg.content.len() < 3000,
            "tool result should be trimmed, got {} chars",
            tool_msg.content.len()
        );
        assert!(
            tool_msg.content.contains("[Output truncated"),
            "trim marker missing from: {}",
            tool_msg.content
        );
        let _ = std::fs::remove_file(&tmp);
    }

    #[tokio::test]
    async fn test_compress_once_no_duplicate_and_no_lost_user() {
        // Regression test for two boundary bugs in compress_once:
        // - B2A: first user message duplicated when protect_first_n > 0
        // - B2B: last user message lost when it is the final history message
        //
        // context_window=200 → threshold=100. Large tool outputs force LLM summarization.
        let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
        let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
        let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
        let config = ContextCompressionConfig {
            tool_result_trim_chars: 2000,
            protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
            protect_last_n: 2,
            max_passes: 1,
            ..Default::default()
        };
        let compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
        // History: 9 messages, last message is user Q4.
        // user_indices (skip 1) = [1, 3, 6, 8]
        // In the buggy version:
        // B2A: init history[..=1] includes Q1, then loop i=0 pushes Q1 again → duplicate
        // B2B: last_user_idx=8, 8 < 8 → false → Q4 not pushed → lost
        let big = "x".repeat(3000);
        let messages = vec![
            ChatMessage::system("You are a helper."), // 0: protected
            ChatMessage::user("Q1"),                  // 1: first user
            ChatMessage::tool("t1", "bash", &big),    // 2
            ChatMessage::user("Q2"),                  // 3
            ChatMessage::assistant("thinking"),       // 4
            ChatMessage::tool("t2", "bash", &big),    // 5
            ChatMessage::user("Q3"),                  // 6
            ChatMessage::assistant("thinking"),       // 7
            ChatMessage::user("Q4"),                  // 8: LAST, is user → B2B triggers
        ];
        let result = compressor.compress_if_needed(messages).await.unwrap();
        // B2A: "Q1" must appear exactly once
        let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
        assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count);
        // B2B: "Q4" must NOT be lost
        let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count();
        assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count);
        let _ = std::fs::remove_file(&tmp);
    }

    #[tokio::test]
    async fn test_compress_hard_truncation_fallback() {
        // When LLM compression fails (or max_passes=0) and tokens are still
        // above 90% of context_window, a head+tail truncation kicks in.
        let tmp = std::env::temp_dir().join(format!("picobot_ctx_trunc_{}.db", std::process::id()));
        let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
        let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
        let config = ContextCompressionConfig {
            tool_result_trim_chars: 500, // trim reduces but not enough
            protect_first_n: 1,
            protect_last_n: 2,
            max_passes: 0, // no LLM summarization → will exceed danger
            ..Default::default()
        };
        // context_window=100, danger_threshold=90.
        // Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387.
        // Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90.
        let compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
        let big = "x".repeat(3000);
        let messages = vec![
            ChatMessage::system("sys"),
            ChatMessage::user("Q1"),
            ChatMessage::tool("t1", "bash", &big),
            ChatMessage::user("Q2"),
            ChatMessage::tool("t2", "bash", &big),
            ChatMessage::user("Q3"),
            ChatMessage::tool("t3", "bash", &big),
        ];
        let result = compressor.compress_if_needed(messages).await.unwrap();
        // After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
        assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
        // Truncation notice should be present
        let has_notice = result.iter().any(|m| m.content.contains("Context truncation"));
        assert!(has_notice, "hard truncation notice missing");
        let _ = std::fs::remove_file(&tmp);
    }

    #[test]
    fn test_repair_tool_pairs_removes_orphans() {
        use crate::providers::ToolCall;
        // Simulate compressed output: summary replaced assistant(tool_call: tc1),
        // leaving tool(tc1) as an orphan. Legitimate tool(tc2) should be kept.
        let mut messages = vec![
            ChatMessage::user("Q1"),
            ChatMessage::user("[Context Summary]\n\nsummary of previous turn"),
            ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
            ChatMessage::assistant("done"), // declares tc2
            ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
        ];
        // Set tool_call_id on tool messages and tool_calls on assistant
        messages[2].tool_call_id = Some("tc1".into());
        messages[4].tool_call_id = Some("tc2".into());
        messages[3].tool_calls = Some(vec![ToolCall {
            id: "tc2".into(),
            name: "bash".into(),
            arguments: serde_json::json!({"cmd": "echo ok"}),
        }]);
        ContextCompressor::repair_tool_pairs(&mut messages);
        // orphan should be removed; legitimate should stay
        assert_eq!(messages.len(), 4);
        assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into())));
        assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into())));
    }

    #[test]
    fn test_parse_context_limit_from_error() {
        // OpenAI: "maximum context length is 128000"
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error(
                "This model's maximum context length is 128000 tokens."
            ),
            Some(128000)
        );
        // "context window of 200000" wording
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error(
                "Your request exceeds the context window of 200000."
            ),
            Some(200000)
        );
        // llama.cpp: "available context size (8448 tokens)"
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error(
                "context size exceeded, available context size (8448 tokens)"
            ),
            Some(8448)
        );
        // Non-context error should return None
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error("Internal server error"),
            None
        );
        // Numbers too small (below 1024) should be rejected
        assert_eq!(
            ContextCompressor::parse_context_limit_from_error("context length is 500"),
            None
        );
    }
}