use std::sync::Arc; use crate::bus::ChatMessage; use crate::memory::MemoryManager; use crate::providers::{ChatCompletionRequest, LLMProvider, Message}; use crate::agent::AgentError; /// Token estimation using ~4 chars/token heuristic with 1.2x safety margin. pub fn estimate_tokens(messages: &[ChatMessage]) -> usize { let raw: usize = messages .iter() .map(|m| m.content.len().div_ceil(4) + 4) .sum(); (raw as f64 * 1.2) as usize } /// Configuration for context compression. #[derive(Debug, Clone)] pub struct ContextCompressionConfig { /// Protect first N messages (system prompt, etc.) pub protect_first_n: usize, /// Protect last N messages (recent context) pub protect_last_n: usize, /// Maximum compression passes pub max_passes: u32, /// Maximum characters in summary pub summary_max_chars: usize, /// Characters to keep when trimming tool results pub tool_result_trim_chars: usize, } impl Default for ContextCompressionConfig { fn default() -> Self { Self { protect_first_n: 1, protect_last_n: 4, max_passes: 3, summary_max_chars: 4000, tool_result_trim_chars: 2000, } } } /// Context compressor that reduces message history when it exceeds token limits. pub struct ContextCompressor { config: ContextCompressionConfig, context_window: usize, /// Threshold ratio to trigger compression (50% of context window) threshold_ratio: f64, /// Shared LLM provider for summarization provider: Arc, /// Memory manager handle. Compressed context summaries are persisted /// as timeline memory entries. memory: Arc, /// Current session ID for timeline memory writes. session_id: Option, } impl ContextCompressor { /// Create a new compressor with the given provider, context window size, and memory manager. pub fn new( provider: Arc, context_window: usize, memory: Arc, ) -> Self { Self { config: ContextCompressionConfig::default(), context_window, threshold_ratio: 0.5, provider, memory, session_id: None, } } /// Create with custom configuration. 
///
    /// Behaves like [`ContextCompressor::new`] but uses the supplied `config`
    /// instead of [`ContextCompressionConfig::default`]. The trigger threshold
    /// ratio is fixed at 0.5 of `context_window` and no session ID is set.
    pub fn with_config(
        provider: Arc<dyn LLMProvider>,
        context_window: usize,
        config: ContextCompressionConfig,
        memory: Arc<MemoryManager>,
    ) -> Self {
        Self {
            config,
            context_window,
            threshold_ratio: 0.5,
            provider,
            memory,
            session_id: None,
        }
    }

    /// Set the current session ID for timeline writes.
    pub fn set_session_id(&mut self, id: Option<String>) {
        self.session_id = id;
    }

    /// Get the compression threshold in tokens.
    ///
    /// Computed as `context_window * threshold_ratio`, truncated to `usize`.
    fn threshold(&self) -> usize {
        (self.context_window as f64 * self.threshold_ratio) as usize
    }

    /// Fast-path: trim oversized tool results without LLM call.
    /// Returns the number of messages modified.
    ///
    /// Only messages whose role is `"tool"` and whose content is longer than
    /// `tool_result_trim_chars` bytes are touched. The content is cut at the
    /// first char boundary at or after the limit and a truncation notice is
    /// appended. Messages that already carry the notice at the cut position
    /// are skipped, so calling this repeatedly on the same history neither
    /// re-trims them nor counts them as modified.
    fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize {
        // Leading part of the appended notice; also used to recognise content
        // that an earlier call has already trimmed.
        const NOTICE_PREFIX: &str = "...\n\n[Output truncated - ";

        let limit = self.config.tool_result_trim_chars;
        let mut modified = 0;

        for msg in messages.iter_mut() {
            if msg.role != "tool" || msg.content.len() <= limit {
                continue;
            }

            // Never split a multi-byte character: round the cut up.
            let cut = msg.content.ceil_char_boundary(limit);
            let tail = &msg.content[cut..];

            // Nothing would be removed (limit fell inside the final char), or
            // the remainder is just the notice from a previous trim.
            if tail.is_empty() || tail.starts_with(NOTICE_PREFIX) {
                continue;
            }

            // Report what is actually dropped, which can be less than
            // `len - limit` when the cut was rounded up.
            let removed = tail.len();
            msg.content = format!(
                "{}{}{} characters removed]",
                &msg.content[..cut],
                NOTICE_PREFIX,
                removed
            );
            modified += 1;
        }

        modified
    }

    /// Main entry point - compresses history if over threshold.
pub async fn compress_if_needed( &self, history: Vec, ) -> Result, AgentError> { // Check if compression is needed let tokens = estimate_tokens(&history); if tokens <= self.threshold() { return Ok(history); } #[cfg(debug_assertions)] tracing::debug!( tokens = tokens, threshold = self.threshold(), msg_count = history.len(), "Starting context compression" ); // Fast trim pass first let trimmed = self.fast_trim_tool_results(&mut history.clone()); if trimmed > 0 { let tokens_after = estimate_tokens(&history); #[cfg(debug_assertions)] tracing::debug!( trimmed_messages = trimmed, tokens_after = tokens_after, "Fast trim completed" ); if tokens_after <= self.threshold() { return Ok(history); } } // LLM summarization pass let mut current_history = history; for pass in 0..self.config.max_passes { let tokens = estimate_tokens(¤t_history); if tokens <= self.threshold() { break; } #[cfg(debug_assertions)] tracing::debug!( pass = pass + 1, tokens = tokens, "Compression pass" ); match self.compress_once(¤t_history).await { Ok(Some(compressed)) => { current_history = compressed; } Ok(None) => { // No more compressible content break; } Err(e) => { tracing::warn!(error = %e, "Compression pass failed, using current history"); break; } } } #[cfg(debug_assertions)] tracing::debug!( final_tokens = estimate_tokens(¤t_history), final_msg_count = current_history.len(), "Context compression completed" ); Ok(current_history) } /// Single compression pass - summarize middle messages between user turns. /// Returns Some(compressed) if compression happened, None if nothing to compress. 
async fn compress_once( &self, history: &[ChatMessage], ) -> Result>, AgentError> { if history.len() <= self.config.protect_first_n + self.config.protect_last_n { return Ok(None); } // Find user message indices (excluding protected first messages) let user_indices: Vec = history .iter() .enumerate() .skip(self.config.protect_first_n) .filter(|(_, m)| m.role == "user") .map(|(i, _)| i) .collect(); // Need at least one user message and content between users to compress if user_indices.len() < 2 { return Ok(None); } // Build segments: user -> (assistant turns) -> next user // We'll summarize the assistant turns between consecutive user messages let mut new_messages = history[..=user_indices[0]].to_vec(); for i in 0..user_indices.len() - 1 { let user_idx = user_indices[i]; let next_user_idx = user_indices[i + 1]; new_messages.push(history[user_idx].clone()); // Check if there's assistant content between these two user messages let between_start = user_idx + 1; let between_end = next_user_idx; if between_start < between_end { let between = &history[between_start..between_end]; let summary = self.summarize_segment(between).await?; // Persist compressed summary as timeline memory entry let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", ts, between.len(), summary); let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); let mm = self.memory.clone(); let sid = self.session_id.clone(); tokio::spawn(async move { if let Err(e) = mm.store( &key, &timeline_content, crate::memory::MemoryCategory::Timeline, sid.as_deref(), Some(0.3), ).await { tracing::warn!(error = %e, "Failed to store compressed context as timeline"); } }); // Add summary as a special user message new_messages.push(ChatMessage::user(format!( "[Context Summary]\n\n{}", summary ))); } } // Add last user and everything after (protected) let last_user_idx = user_indices[user_indices.len() - 1]; if last_user_idx < 
history.len() - 1 { // Add everything from last user onwards (protected) for i in last_user_idx..history.len() { new_messages.push(history[i].clone()); } } // If nothing changed, return None if new_messages.len() == history.len() { return Ok(None); } Ok(Some(new_messages)) } /// Summarize a segment of messages using LLM. async fn summarize_segment( &self, messages: &[ChatMessage], ) -> Result { if messages.is_empty() { return Ok(String::new()); } // Build transcript for summarization let transcript = messages .iter() .map(|m| { let role = match m.role.as_str() { "assistant" => "Assistant", "tool" => "Tool", _ => m.role.as_str(), }; let name = m.tool_name .as_ref() .map(|n| format!(" ({})", n)) .unwrap_or_default(); format!("{}: {}{}", role, m.content, name) }) .collect::>() .join("\n\n"); // Truncate transcript if too long let transcript = if transcript.len() > self.config.summary_max_chars { format!( "{}...\n\n[Transcript truncated - {} characters removed]", &transcript[..transcript.ceil_char_boundary(self.config.summary_max_chars)], transcript.len() - self.config.summary_max_chars ) } else { transcript }; let prompt = format!( r#"You are a conversation compaction engine. Summarize the following conversation segment. PRESERVE: - All identifiers (UUIDs, hashes, file paths, URLs) - Actions taken (tool calls, file operations, commands) - Key information obtained (results, data, errors) - Decisions and user preferences - Current task status OMIT: - Verbose tool output (keep key results only) - Repeated greetings or filler Be concise, aim for {} characters or less. 
--- {} "#, self.config.summary_max_chars, transcript ); let request = ChatCompletionRequest { messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], temperature: Some(0.3), max_tokens: Some(1000), tools: None, }; match (*self.provider).chat(request).await { Ok(response) => Ok(response.content), Err(e) => { // Fallback: just truncate the transcript tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript"); Ok(transcript[..transcript.ceil_char_boundary(2000)].to_string()) } } } } #[cfg(test)] mod tests { use super::*; use crate::providers::ChatCompletionResponse; use async_trait::async_trait; use std::sync::OnceLock; /// Mock provider for testing - panics if actually used for LLM calls struct MockProvider; #[async_trait] impl LLMProvider for MockProvider { async fn chat( &self, _request: ChatCompletionRequest, ) -> Result> { panic!("MockProvider.chat() called - not expected in test") } fn ptype(&self) -> &str { "mock" } fn name(&self) -> &str { "mock" } fn model_id(&self) -> &str { "mock" } } fn mock_provider() -> Arc { Arc::new(MockProvider) } fn test_memory_manager() -> Arc { static MM: OnceLock> = OnceLock::new(); MM.get_or_init(|| { let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async { let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id())); let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap()); Arc::new(MemoryManager::new(storage, "test".into(), "test".into())) }) }).clone() } #[test] fn test_estimate_tokens() { let messages = vec![ ChatMessage::user("Hello"), ChatMessage::assistant("Hi there!"), ChatMessage::user("How are you?"), ]; let tokens = estimate_tokens(&messages); // "Hello" (5) -> ceil(5/4)+4 = 2+4 = 6 // "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6 // "How are you?" 
(11) -> ceil(11/4)+4 = 3+4 = 7 // raw = 19, with 1.2x = ~23 assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens); } #[test] fn test_fast_trim() { let config = ContextCompressionConfig { tool_result_trim_chars: 50, ..Default::default() }; let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager()); let mut messages = vec![ ChatMessage::user("Hello"), ChatMessage::tool("call1", "bash", &"x".repeat(200)), ]; let modified = compressor.fast_trim_tool_results(&mut messages); assert_eq!(modified, 1); assert!(messages[1].content.len() < 100); } #[test] fn test_threshold() { let compressor = ContextCompressor::new(mock_provider(), 128_000, test_memory_manager()); assert_eq!(compressor.threshold(), 64_000); } }