diff --git a/config.json b/config.json index 0f78f6f..580b0c9 100644 --- a/config.json +++ b/config.json @@ -18,7 +18,6 @@ "provider": "default", "model": "default", "tool_result_max_chars": 20000, - "context_summary_max_chars": 20000, "context_tool_result_trim_chars": 2000 } }, diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index 7464b8a..a4e79cb 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -5,7 +5,7 @@ use crate::bus::{ SYSTEM_CONTEXT_SCHEDULED_PROMPT, }; use crate::config::LLMProviderConfig; -use crate::providers::{create_provider, ChatCompletionRequest, Message}; +use crate::providers::{create_provider, ChatCompletionRequest, LLMProvider, Message}; use crate::text::{char_count, take_prefix_chars}; use crate::agent::AgentError; @@ -62,6 +62,194 @@ pub struct ContextCompressor { } impl ContextCompressor { + fn summary_char_budget_for_context_window(context_window: usize) -> usize { + const SUMMARY_RATIO: f64 = 0.1; + const CHARS_PER_TOKEN: f64 = 2.5; + const MIN_SUMMARY_CHARS: usize = 1_500; + const MAX_SUMMARY_CHARS: usize = 50_000; + + ((context_window as f64 * SUMMARY_RATIO * CHARS_PER_TOKEN) as usize) + .clamp(MIN_SUMMARY_CHARS, MAX_SUMMARY_CHARS) + } + + fn format_transcript_entry(message: &ChatMessage) -> String { + let role = match message.role.as_str() { + "assistant" => "Assistant", + "tool" => "Tool", + _ => message.role.as_str(), + }; + let name = message + .tool_name + .as_ref() + .map(|n| format!(" ({})", n)) + .unwrap_or_default(); + format!("{}: {}{}", role, message.content, name) + } + + fn build_transcript(messages: &[ChatMessage]) -> String { + messages + .iter() + .map(Self::format_transcript_entry) + .collect::>() + .join("\n\n") + } + + fn split_text_chunks(text: &str, max_chars: usize) -> Vec { + if text.is_empty() { + return Vec::new(); + } + + let chunk_size = max_chars.max(1); + let chars: Vec = text.chars().collect(); + chars + .chunks(chunk_size) + .map(|chunk| chunk.iter().collect()) + .collect() + } + + fn chunk_messages_for_summary(messages: &[ChatMessage], max_chars: usize) -> Vec { + if messages.is_empty() { + return Vec::new(); + } + + let chunk_limit = max_chars.max(1); + let mut chunks = Vec::new(); + let mut current = String::new(); + + for entry in messages.iter().map(Self::format_transcript_entry) { + let separator = if current.is_empty() { "" } else { "\n\n" }; + let candidate = format!("{}{}{}", current, separator, entry); + + if !current.is_empty() && char_count(&candidate) > chunk_limit { + chunks.push(current); + current = String::new(); + } + + if char_count(&entry) > chunk_limit { + if !current.is_empty() { + chunks.push(current); + current = String::new(); + } + chunks.extend(Self::split_text_chunks(&entry, chunk_limit)); + continue; + } + + if current.is_empty() { + current = entry; + } else { + current.push_str("\n\n"); + current.push_str(&entry); + } + } + + if !current.is_empty() { + chunks.push(current); + } + + chunks + } + + fn build_summary_prompt(transcript: &str, target_chars: usize) -> String { + format!( + r#"You are a conversation compaction engine. Summarize the following conversation segment. + +PRESERVE: +- Each user question or request in full or as a near-verbatim restatement +- All identifiers (UUIDs, hashes, file paths, URLs) +- Actions taken (tool calls, file operations, commands) +- Key information obtained (results, data, errors) +- Decisions and user preferences +- Current task status + +OMIT: +- Reproducing full tool output verbatim unless it is essential +- Repeated greetings or filler + +Do not assume tool content was pre-trimmed. You may receive long tool outputs; keep the important results, errors, and artifacts. + +Be concise, aim for {} characters or less. + +--- + +{} + +"#, + target_chars, transcript + ) + } + + async fn summarize_transcript( + &self, + provider: &dyn LLMProvider, + transcript: &str, + target_chars: usize, + ) -> Result { + let request = ChatCompletionRequest { + messages: vec![ + Message::system("You are a helpful assistant."), + Message::user(Self::build_summary_prompt(transcript, target_chars)), + ], + temperature: Some(0.3), + max_tokens: Some(1000), + tools: None, + }; + + let response = provider + .chat(request) + .await + .map_err(|e| AgentError::LlmError(e.to_string()))?; + Ok(response.content) + } + + async fn summarize_chunked_transcript( + &self, + provider: &dyn LLMProvider, + messages: &[ChatMessage], + transcript: &str, + ) -> Result { + let target_chars = self.config.summary_max_chars.max(1); + let mut layer = Self::chunk_messages_for_summary(messages, target_chars); + + if layer.is_empty() { + layer.push(transcript.to_string()); + } + + for _ in 0..6 { + if layer.len() == 1 && char_count(&layer[0]) <= target_chars { + return self + .summarize_transcript(provider, &layer[0], target_chars) + .await; + } + + let per_chunk_target = (target_chars / layer.len().max(1)).max(500).min(target_chars); + let mut summaries = Vec::with_capacity(layer.len()); + for chunk in &layer { + summaries.push( + self.summarize_transcript(provider, chunk, per_chunk_target) + .await?, + ); + } + + if summaries.len() == 1 { + let summary = summaries.pop().unwrap_or_default(); + if char_count(&summary) <= target_chars { + return Ok(summary); + } + layer = Self::split_text_chunks(&summary, target_chars); + continue; + } + + let merged = summaries.join("\n\n"); + if char_count(&merged) <= target_chars { + return self.summarize_transcript(provider, &merged, target_chars).await; + } + + layer = Self::split_text_chunks(&merged, target_chars); + } + + Ok(take_prefix_chars(transcript, target_chars)) + } + /// Create a new compressor with the given context window size. pub fn new(context_window: usize) -> Self { Self { @@ -73,9 +261,9 @@ impl ContextCompressor { pub fn from_provider_config(provider_config: &LLMProviderConfig) -> Self { Self::with_config( - provider_config.token_limit, + provider_config.context_window_tokens(), ContextCompressionConfig { - summary_max_chars: provider_config.context_summary_max_chars, + summary_max_chars: provider_config.context_summary_char_budget(), ..ContextCompressionConfig::default() }, ) @@ -236,77 +424,21 @@ impl ContextCompressor { return Ok(String::new()); } - // Build transcript for summarization - let transcript = messages - .iter() - .map(|m| { - let role = match m.role.as_str() { - "assistant" => "Assistant", - "tool" => "Tool", - _ => m.role.as_str(), - }; - let name = m.tool_name - .as_ref() - .map(|n| format!(" ({})", n)) - .unwrap_or_default(); - format!("{}: {}{}", role, m.content, name) - }) - .collect::>() - .join("\n\n"); - - // Truncate transcript if too long - let transcript = if char_count(&transcript) > self.config.summary_max_chars { - format!( - "{}...\n\n[Transcript truncated - {} characters removed]", - take_prefix_chars(&transcript, self.config.summary_max_chars), - char_count(&transcript).saturating_sub(self.config.summary_max_chars) - ) - } else { - transcript - }; - - let prompt = format!( - r#"You are a conversation compaction engine. Summarize the following conversation segment. - -PRESERVE: -- Each user question or request in full or as a near-verbatim restatement -- All identifiers (UUIDs, hashes, file paths, URLs) -- Actions taken (tool calls, file operations, commands) -- Key information obtained (results, data, errors) -- Decisions and user preferences -- Current task status - -OMIT: -- Reproducing full tool output verbatim unless it is essential -- Repeated greetings or filler - -Do not assume tool content was pre-trimmed. You may receive long tool outputs; keep the important results, errors, and artifacts. - -Be concise, aim for {} characters or less. - ---- - -{} - -"#, - self.config.summary_max_chars, transcript - ); - - // Create provider and call LLM let provider = create_provider(provider_config.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; + let transcript = Self::build_transcript(messages); - let request = ChatCompletionRequest { - messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], - temperature: Some(0.3), - max_tokens: Some(1000), - tools: None, + let result = if char_count(&transcript) <= self.config.summary_max_chars { + self.summarize_transcript(provider.as_ref(), &transcript, self.config.summary_max_chars) + .await + } else { + self.summarize_chunked_transcript(provider.as_ref(), messages, &transcript) + .await }; - match provider.chat(request).await { - Ok(response) => Ok(response.content), + match result { + Ok(summary) => Ok(summary), Err(e) => { - // Fallback: just truncate the transcript tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript"); Ok(take_prefix_chars(&transcript, self.config.summary_max_chars)) } @@ -384,4 +516,39 @@ mod tests { let compressor = ContextCompressor::new(128_000); assert_eq!(compressor.threshold(), 64_000); } + + #[test] + fn test_summary_char_budget_for_context_window_scales_and_clamps() { + assert_eq!(ContextCompressor::summary_char_budget_for_context_window(4_096), 1_500); + assert_eq!(ContextCompressor::summary_char_budget_for_context_window(65_536), 16_384); + assert_eq!(ContextCompressor::summary_char_budget_for_context_window(128_000), 32_000); + assert_eq!(ContextCompressor::summary_char_budget_for_context_window(400_000), 50_000); + } + + #[test] + fn test_chunk_messages_for_summary_keeps_message_boundaries_when_possible() { + let messages = vec![ + ChatMessage::user("alpha"), + ChatMessage::assistant("beta"), + ChatMessage::user("gamma"), + ]; + + let chunks = ContextCompressor::chunk_messages_for_summary(&messages, 30); + + assert_eq!(chunks.len(), 2); + assert!(chunks.iter().all(|chunk| char_count(chunk) <= 30)); + assert_eq!(chunks[0], "user: alpha\n\nAssistant: beta"); + assert_eq!(chunks[1], "user: gamma"); + } + + #[test] + fn test_chunk_messages_for_summary_splits_oversized_message() { + let messages = vec![ChatMessage::user(&"x".repeat(25))]; + + let chunks = ContextCompressor::chunk_messages_for_summary(&messages, 10); + + assert!(chunks.len() > 1); + assert!(chunks.iter().all(|chunk| char_count(chunk) <= 10)); + assert_eq!(chunks.concat(), "user: xxxxxxxxxxxxxxxxxxxxxxxxx"); + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index fec8ca5..6a890e0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -162,12 +162,8 @@ pub struct AgentConfig { pub model: String, #[serde(default = "default_max_tool_iterations")] pub max_tool_iterations: usize, - #[serde(default = "default_token_limit")] - pub token_limit: usize, #[serde(default = "default_tool_result_max_chars")] pub tool_result_max_chars: usize, - #[serde(default = "default_context_summary_max_chars")] - pub context_summary_max_chars: usize, #[serde(default = "default_context_tool_result_trim_chars")] pub context_tool_result_trim_chars: usize, } @@ -176,18 +172,10 @@ fn default_max_tool_iterations() -> usize { 100 } -fn default_token_limit() -> usize { - 128_000 -} - fn default_tool_result_max_chars() -> usize { 20_000 } -fn default_context_summary_max_chars() -> usize { - 20_000 -} - fn default_context_tool_result_trim_chars() -> usize { 2_000 } @@ -524,12 +512,28 @@ pub struct LLMProviderConfig { pub max_tokens: Option, pub model_extra: HashMap, pub max_tool_iterations: usize, - pub token_limit: usize, pub tool_result_max_chars: usize, - pub context_summary_max_chars: usize, pub context_tool_result_trim_chars: usize, } +impl LLMProviderConfig { + pub fn context_window_tokens(&self) -> usize { + self.max_tokens + .map(|value| value as usize) + .unwrap_or(128_000) + } + + pub fn context_summary_char_budget(&self) -> usize { + const SUMMARY_RATIO: f64 = 0.1; + const CHARS_PER_TOKEN: f64 = 2.5; + const MIN_SUMMARY_CHARS: usize = 1_500; + const MAX_SUMMARY_CHARS: usize = 50_000; + + ((self.context_window_tokens() as f64 * SUMMARY_RATIO * CHARS_PER_TOKEN) as usize) + .clamp(MIN_SUMMARY_CHARS, MAX_SUMMARY_CHARS) + } +} + fn get_default_config_path() -> PathBuf { let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); home.join(".picobot").join("config.json") @@ -590,9 +594,7 @@ impl Config { max_tokens: model.max_tokens, model_extra: model.extra.clone(), max_tool_iterations: agent.max_tool_iterations, - token_limit: agent.token_limit, tool_result_max_chars: agent.tool_result_max_chars, - context_summary_max_chars: agent.context_summary_max_chars, context_tool_result_trim_chars: agent.context_tool_result_trim_chars, }) } @@ -730,10 +732,11 @@ mod tests { assert_eq!(provider_config.name, "aliyun"); assert_eq!(provider_config.model_id, "qwen-plus"); assert_eq!(provider_config.temperature, Some(0.0)); + assert_eq!(provider_config.max_tokens, None); assert_eq!(provider_config.llm_timeout_secs, 120); assert_eq!(provider_config.tool_result_max_chars, 20_000); - assert_eq!(provider_config.context_summary_max_chars, 20_000); assert_eq!(provider_config.context_tool_result_trim_chars, 2_000); + assert_eq!(provider_config.context_summary_char_budget(), 32_000); } #[test] @@ -957,7 +960,6 @@ mod tests { let config = Config::load(file.path().to_str().unwrap()).unwrap(); assert_eq!(config.agents["default"].max_tool_iterations, 100); assert_eq!(config.agents["default"].tool_result_max_chars, 20_000); - assert_eq!(config.agents["default"].context_summary_max_chars, 20_000); assert_eq!(config.agents["default"].context_tool_result_trim_chars, 2_000); } @@ -985,7 +987,6 @@ mod tests { "provider": "aliyun", "model": "qwen-plus", "tool_result_max_chars": 1234, - "context_summary_max_chars": 2345, "context_tool_result_trim_chars": 3456 } } @@ -998,11 +999,47 @@ mod tests { let provider_config = config.get_provider_config("default").unwrap(); assert_eq!(agent.tool_result_max_chars, 1234); - assert_eq!(agent.context_summary_max_chars, 2345); assert_eq!(agent.context_tool_result_trim_chars, 3456); assert_eq!(provider_config.tool_result_max_chars, 1234); - assert_eq!(provider_config.context_summary_max_chars, 2345); assert_eq!(provider_config.context_tool_result_trim_chars, 3456); + assert_eq!(provider_config.context_summary_char_budget(), 32_000); + } + + #[test] + fn test_provider_config_summary_budget_scales_with_model_max_tokens() { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus", + "max_tokens": 4096 + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + } +}"#, + ) + .unwrap(); + + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let provider_config = config.get_provider_config("default").unwrap(); + + assert_eq!(provider_config.context_window_tokens(), 4096); + assert_eq!(provider_config.context_summary_char_budget(), 1_500); } #[test] diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 0adb229..6b24a68 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1618,9 +1618,7 @@ mod tests { max_tokens: Some(32), model_extra: HashMap::new(), max_tool_iterations: 1, - token_limit: 4096, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, } } @@ -1638,9 +1636,7 @@ mod tests { max_tokens: Some(32), model_extra: HashMap::new(), max_tool_iterations: 1, - token_limit: 4096, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, } } @@ -1846,10 +1842,8 @@ mod tests { max_tokens: Some(32), model_extra: HashMap::new(), max_tool_iterations: 1, - token_limit: 4096, llm_timeout_secs: 30, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -1886,10 +1880,8 @@ mod tests { max_tokens: Some(32), model_extra: HashMap::new(), max_tool_iterations: 1, - token_limit: 4096, llm_timeout_secs: 30, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let planner_provider = LLMProviderConfig { @@ -1972,10 +1964,8 @@ mod tests { json!(mock_response_content), )]), max_tool_iterations: 1, - token_limit: 4096, llm_timeout_secs: 30, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -2063,10 +2053,8 @@ mod tests { json!(mock_response_content), )]), max_tool_iterations: 1, - token_limit: 4096, llm_timeout_secs: 30, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; @@ -2125,10 +2113,8 @@ mod tests { json!(mock_response_content), )]), max_tool_iterations: 1, - token_limit: 4096, llm_timeout_secs: 30, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index ac90fd9..6b634ef 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -848,10 +848,8 @@ mod tests { temperature: Some(0.0), max_tokens: None, model_extra: HashMap::new(), - token_limit: 4096, max_tool_iterations: 4, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let session_manager = SessionManager::new( @@ -900,10 +898,8 @@ mod tests { temperature: Some(0.0), max_tokens: None, model_extra: HashMap::new(), - token_limit: 4096, max_tool_iterations: 4, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let session_manager = SessionManager::new(