From 3df628bd28a97651d8678d534a272f94675bbb50 Mon Sep 17 00:00:00 2001 From: xiaoski Date: Fri, 29 May 2026 16:23:35 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=85/info=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/agent_loop.rs | 21 ++-- src/agent/context_compressor.rs | 186 ++++++++++++++++++++++++++++---- src/session/session.rs | 56 +++++++++- 3 files changed, 233 insertions(+), 30 deletions(-) diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 63ab55c..61356b8 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -228,6 +228,7 @@ pub struct AgentLoop { pub struct AgentProcessResult { pub final_response: ChatMessage, pub emitted_messages: Vec, + pub total_tokens: Option, } impl AgentLoop { @@ -340,8 +341,9 @@ impl AgentLoop { } /// Preemptive trim: truncate old tool results in-place when history is - /// approaching the context window limit. Only trims tool messages with - /// content > TRIM_CHARS, preserving the most recent KEEP messages. + /// approaching the context window limit. Old results (outside of `keep_recent` + /// zone) are replaced with a short placeholder; recent results are truncated + /// to `max_chars`. fn preemptive_trim_old_tool_results( &self, messages: &mut [ChatMessage], @@ -358,11 +360,11 @@ impl AgentLoop { if messages[i].content.len() <= max_chars { continue; } - let removed = messages[i].content.len() - max_chars; + let tool_name = messages[i].tool_name.as_deref().unwrap_or("unknown"); + let chars = messages[i].content.len(); messages[i].content = format!( - "{}...\n\n[Output truncated - {} characters removed]", - &messages[i].content[..messages[i].content.ceil_char_boundary(max_chars)], - removed + "[Tool output ({}) — {} chars, omitted from context]", + tool_name, chars ); modified += 1; } @@ -413,6 +415,7 @@ impl AgentLoop { // Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut emitted_messages = Vec::new(); + let mut accumulated_tokens: u32 = 0; for iteration in 0..self.max_iterations { #[cfg(debug_assertions)] @@ -466,6 +469,8 @@ impl AgentLoop { AgentError::LlmError(e.to_string()) })?; + accumulated_tokens += response.usage.total_tokens; + #[cfg(debug_assertions)] tracing::debug!( iteration, @@ -482,6 +487,7 @@ impl AgentLoop { return Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, + total_tokens: Some(accumulated_tokens), }); } @@ -584,12 +590,14 @@ impl AgentLoop { match (*self.provider).chat(request).await { Ok(response) => { + accumulated_tokens += response.usage.total_tokens; let mut assistant_message = ChatMessage::assistant(response.content); assistant_message.reasoning_content = response.reasoning_content; emitted_messages.push(assistant_message.clone()); Ok(AgentProcessResult { final_response: assistant_message, emitted_messages, + total_tokens: Some(accumulated_tokens), }) } Err(e) => { @@ -602,6 +610,7 @@ impl AgentLoop { Ok(AgentProcessResult { final_response: final_message, emitted_messages, + total_tokens: if accumulated_tokens > 0 { Some(accumulated_tokens) } else { None }, }) } } diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index 578895d..6c20b61 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -68,6 +68,10 @@ pub struct ContextCompressor { memory: Arc, /// Current session ID for timeline memory writes. session_id: Option, + /// Message count sent in the last LLM call (used to split known/new history). + last_sent_message_count: Option, + /// Real total_tokens from the last API response. + last_api_total_tokens: Option, } /// Result of context compression. @@ -76,6 +80,15 @@ pub struct CompressionResult { pub created_timelines: bool, } +/// Token budget state snapshot for diagnostics. +pub struct TokenInfo { + pub context_window: usize, + pub threshold: usize, + pub estimated_tokens: usize, + pub last_api_tokens: Option, + pub cache_active: bool, +} + impl ContextCompressor { /// Create a new compressor with the given provider, context window size, and memory manager. pub fn new( @@ -90,6 +103,8 @@ impl ContextCompressor { provider, memory, session_id: None, + last_sent_message_count: None, + last_api_total_tokens: None, } } @@ -107,6 +122,8 @@ impl ContextCompressor { provider, memory, session_id: None, + last_sent_message_count: None, + last_api_total_tokens: None, } } @@ -120,39 +137,91 @@ impl ContextCompressor { self.context_window = window; } + /// Record the API's reported token usage from the last completed turn. + /// `msg_count`: number of messages sent to LLM in that call. + /// `tokens`: `total_tokens` from the API response. + pub fn set_last_api_info(&mut self, msg_count: usize, tokens: Option) { + self.last_sent_message_count = Some(msg_count); + self.last_api_total_tokens = tokens; + } + + /// Invalidate the cached API token info — called after compression modifies messages. + fn invalidate_token_cache(&mut self) { + self.last_sent_message_count = None; + self.last_api_total_tokens = None; + } + + /// Hybrid token estimation: API-reported tokens for known history + + /// char/4 estimate for new messages since last API call. + fn token_estimate_with_history(&self, messages: &[ChatMessage]) -> usize { + match (self.last_api_total_tokens, self.last_sent_message_count) { + (Some(known), Some(known_count)) if messages.len() > known_count => { + let delta = &messages[known_count..]; + known as usize + estimate_tokens(delta) + } + (Some(known), _) => known as usize, + _ => estimate_tokens(messages), + } + } + /// Always true — memory is always available (memory system is always on). pub fn has_memory(&self) -> bool { true } + /// Get a snapshot of the current token budget state for diagnostics. + pub fn token_info(&self, messages: &[ChatMessage]) -> TokenInfo { + TokenInfo { + context_window: self.context_window, + threshold: self.threshold(), + estimated_tokens: self.token_estimate_with_history(messages), + last_api_tokens: self.last_api_total_tokens, + cache_active: self.last_api_total_tokens.is_some(), + } + } + /// Get the compression threshold in tokens. pub fn threshold(&self) -> usize { (self.context_window as f64 * self.threshold_ratio) as usize } /// Fast-path: trim oversized tool results without LLM call. + /// Old tool results (outside of `protect_tail` zone) are replaced with a + /// concise placeholder; recent results are truncated to `tool_result_trim_chars`. /// Returns the number of messages modified. - fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize { + fn fast_trim_tool_results(&self, messages: &mut [ChatMessage], protect_tail: usize) -> usize { let limit = self.config.tool_result_trim_chars; + let tail_start = messages.len().saturating_sub(protect_tail); let mut modified = 0; - for msg in messages.iter_mut() { - if msg.role == "tool" && msg.content.len() > limit { + for (i, msg) in messages.iter_mut().enumerate() { + if msg.role != "tool" || msg.content.len() <= limit { + continue; + } + if i < tail_start { + let tool_name = msg.tool_name.as_deref().unwrap_or("unknown"); + let chars = msg.content.len(); + msg.content = format!( + "[Tool output ({}) — {} chars, omitted from context]", + tool_name, chars + ); + } else { let removed = msg.content.len() - limit; msg.content = format!( "{}...\n\n[Output truncated - {} characters removed]", &msg.content[..msg.content.ceil_char_boundary(limit)], removed ); - modified += 1; } + modified += 1; } modified } - /// Remove orphan tool results whose declaring tool_calls have been compressed away. - /// Scans for tool messages with no preceding assistant tool_call, and removes them. + /// Repair tool call chains after compression. + /// Phase 1: remove orphan tool results whose declaring tool_calls are missing. + /// Phase 2: strip tool_calls from assistants whose results are missing. pub fn repair_tool_pairs(messages: &mut Vec) { let mut declared: std::collections::HashSet = std::collections::HashSet::new(); let mut i = 0; @@ -171,15 +240,40 @@ impl ContextCompressor { } i += 1; } + + let broken: Vec = messages.iter().enumerate() + .filter_map(|(idx, msg)| { + if msg.role == "assistant" + && let Some(ref tcs) = msg.tool_calls + && !tcs.is_empty() { + let all_present = tcs.iter().all(|tc| { + messages.iter().any(|m| { + m.role == "tool" + && m.tool_call_id.as_deref() == Some(tc.id.as_str()) + }) + }); + if !all_present { Some(idx) } else { None } + } else { None } + }).collect(); + + for idx in broken { + let msg = &mut messages[idx]; + let tcs = msg.tool_calls.take().unwrap_or_default(); + let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect(); + msg.content = format!( + "{}\n\n[Tool calls ({}) — results are no longer available]", + msg.content, names.join(", ") + ); + } } /// Main entry point - compresses history if over threshold. pub async fn compress_if_needed( - &self, + &mut self, mut history: Vec, ) -> Result { // Check if compression is needed - let tokens = estimate_tokens(&history); + let tokens = self.token_estimate_with_history(&history); if tokens <= self.threshold() { return Ok(CompressionResult { history, created_timelines: false }); } @@ -193,8 +287,8 @@ impl ContextCompressor { ); // Fast trim pass first — modify history in place - let trimmed = self.fast_trim_tool_results(&mut history); - let tokens_after = estimate_tokens(&history); + let trimmed = self.fast_trim_tool_results(&mut history, self.config.protect_last_n); + let tokens_after = self.token_estimate_with_history(&history); if trimmed > 0 { #[cfg(debug_assertions)] tracing::debug!( @@ -204,6 +298,7 @@ impl ContextCompressor { ); } if tokens_after <= self.threshold() { + self.invalidate_token_cache(); return Ok(CompressionResult { history, created_timelines: false }); } @@ -211,7 +306,7 @@ impl ContextCompressor { let mut current_history = history; let mut created_timelines = false; for pass in 0..self.config.max_passes { - let tokens = estimate_tokens(¤t_history); + let tokens = self.token_estimate_with_history(¤t_history); if tokens <= self.threshold() { break; } @@ -241,15 +336,46 @@ impl ContextCompressor { // Hard safety net: if still dangerously high after all passes, // fall back to head+tail truncation so the LLM call doesn't overflow. - let final_tokens = estimate_tokens(¤t_history); + let final_tokens = self.token_estimate_with_history(¤t_history); let danger_threshold = (self.context_window as f64 * 0.9) as usize; if final_tokens > danger_threshold && current_history.len() > self.config.protect_first_n + self.config.protect_last_n { + let mut tail_start = current_history.len() - self.config.protect_last_n; + + // Align tail_start backwards to preserve tool chain boundaries: + // if an assistant with tool_calls has results spanning the cut, + // include the assistant in the tail. + if tail_start > 0 && tail_start < current_history.len() { + let mut scan = tail_start.saturating_sub(1); + loop { + let m = ¤t_history[scan]; + if m.role == "assistant" { + if let Some(tcs) = &m.tool_calls + && !tcs.is_empty() { + let has_post = current_history[scan + 1..] + .iter() + .filter(|r| r.role == "tool") + .any(|r| tcs.iter().any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))); + if has_post { + tail_start = scan; + } + } + break; + } + if scan == 0 { break; } + scan -= 1; + } + } + + // Skip orphan tool messages at the new head-tail boundary + while tail_start < current_history.len() && current_history[tail_start].role == "tool" { + tail_start += 1; + } + let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec(); - let tail_start = current_history.len() - self.config.protect_last_n; let tail: Vec<_> = current_history[tail_start..].to_vec(); - let dropped = current_history.len() - self.config.protect_first_n - self.config.protect_last_n; + let dropped = current_history.len() - self.config.protect_first_n - tail.len(); let mut truncated = head; truncated.push(ChatMessage::user(format!( @@ -259,6 +385,24 @@ impl ContextCompressor { ))); truncated.extend(tail); + // Strip tool_calls from any assistant in the head whose results + // were dropped (previously in the middle section). + for msg in &mut truncated[..self.config.protect_first_n] { + if msg.role == "assistant" { + if let Some(ref tcs) = msg.tool_calls + && !tcs.is_empty() { + let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect(); + msg.content = format!( + "{}\n\n[Tool calls ({}) — results dropped during truncation]", + msg.content, names.join(", ") + ); + msg.tool_calls = None; + } + } + } + + Self::repair_tool_pairs(&mut truncated); + tracing::warn!( final_tokens = final_tokens, danger = danger_threshold, @@ -269,9 +413,13 @@ impl ContextCompressor { current_history = truncated; } + if created_timelines { + self.invalidate_token_cache(); + } + #[cfg(debug_assertions)] tracing::debug!( - final_tokens = estimate_tokens(¤t_history), + final_tokens = self.token_estimate_with_history(¤t_history), final_msg_count = current_history.len(), "Context compression completed" ); @@ -592,7 +740,7 @@ mod tests { ChatMessage::tool("call1", "bash", &"x".repeat(200)), ]; - let modified = compressor.fast_trim_tool_results(&mut messages); + let modified = compressor.fast_trim_tool_results(&mut messages, 2); assert_eq!(modified, 1); assert!(messages[1].content.len() < 100); } @@ -619,7 +767,7 @@ mod tests { max_passes: 0, ..Default::default() }; - let compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm); + let mut compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm); let messages = vec![ ChatMessage::user("Hi"), @@ -661,7 +809,7 @@ mod tests { max_passes: 1, ..Default::default() }; - let compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm); + let mut compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm); // History: 9 messages, last message is user Q4. // user_indices (skip 1) = [1, 3, 6, 8] @@ -711,7 +859,7 @@ mod tests { // context_window=100, danger_threshold=90. // Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387. // Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90. - let compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm); + let mut compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm); let big = "x".repeat(3000); let messages = vec![ diff --git a/src/session/session.rs b/src/session/session.rs index c2d971a..3b863ec 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -167,6 +167,7 @@ impl Session { compressor.set_session_id(Some(id.to_string())); let mut chat_messages: Vec = Vec::new(); + let mut restored_compressed_at = session_meta.last_compressed_message_at; if let Some(after_ts) = session_meta.last_compressed_message_at { // Load last 4 timelines to detect if there are more than 3 @@ -223,7 +224,7 @@ impl Session { repair_tool_call_chains(&mut tail_msgs); chat_messages.extend(tail_msgs); } else { - // No prior compression — load all messages (existing behavior) + // No prior compression — load all messages let messages = storage.load_messages(&id.to_string(), 0).await .map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?; @@ -247,6 +248,16 @@ impl Session { repair_tool_call_chains(&mut chat_messages); } + // Compress loaded history if it exceeds budget + if !chat_messages.is_empty() { + let result = compressor.compress_if_needed(chat_messages).await + .map_err(|e| AgentError::Other(format!("compression during restore: {}", e)))?; + if result.created_timelines { + restored_compressed_at = Some(chrono::Utc::now().timestamp_millis()); + } + chat_messages = result.history; + } + // seq_counter from actual DB max let max_seq = storage .get_max_message_seq(&id.to_string()) @@ -271,7 +282,7 @@ impl Session { storage: Some(storage), routing_info: session_meta.routing_info.unwrap_or_default(), last_consolidated_at: session_meta.last_consolidated_at, - last_compressed_message_at: session_meta.last_compressed_message_at, + last_compressed_message_at: restored_compressed_at, memory_manager, agent_tx: None, current_cancel: None, @@ -1018,7 +1029,8 @@ impl SessionManager { if let Some(sid) = current_session_id { let session = self.get_or_create_session(sid).await?; let session_guard = session.lock().await; - let message_count = session_guard.get_history().len(); + let history = session_guard.get_history(); + let message_count = history.len(); let session_id_str = session_guard.session_id(); let title = &session_guard.title; let model_name = &session_guard.provider_config.name; @@ -1028,9 +1040,41 @@ impl SessionManager { let last_active_at = chrono::DateTime::from_timestamp_millis(session_guard.last_active_at) .map(|dt| dt.with_timezone(&chrono::Local).format("%Y-%m-%d %H:%M:%S").to_string()) .unwrap_or_default(); + let token_info = session_guard.compressor.token_info(history); + let cache_info = if token_info.cache_active { + format!("API精确: {} tokens", token_info.last_api_tokens.unwrap_or(0)) + } else { + "无API精确缓存".to_string() + }; + let threshold_pct = if token_info.context_window > 0 { + (token_info.threshold as f64 / token_info.context_window as f64 * 100.0) as usize + } else { 0 }; + let usage_pct = if token_info.context_window > 0 { + (token_info.estimated_tokens as f64 / token_info.context_window as f64 * 100.0).min(100.0) as usize + } else { 0 }; + let usage_bar = if token_info.context_window > 0 { + format!("{}/{} tokens ({}%)", token_info.estimated_tokens, token_info.context_window, usage_pct) + } else { + "未设置".to_string() + }; + let compression_status = if token_info.estimated_tokens > token_info.threshold { + "[即将压缩]" + } else { + "[正常]" + }; + let ctx_info = format!( + "[窗口] {} [阈值] {}/{} ({}) [状态] {} {}", + usage_bar, + token_info.threshold, + token_info.context_window, + threshold_pct, + compression_status, + cache_info, + ); Ok((None, format!( - "对话标题: {}\nSession ID: {}\n模型: {}\n用户消息: {} / 总消息: {}\n创建时间: {}\n最后活跃: {}", - title, session_id_str, model_name, session_guard.message_count, message_count, created_at, last_active_at + "对话标题: {}\nSession ID: {}\n模型: {}\n用户消息: {} / 总消息: {}\n创建时间: {}\n最后活跃: {}\n\n上下文: {}", + title, session_id_str, model_name, session_guard.message_count, message_count, + created_at, last_active_at, ctx_info, ))) } else { Ok((None, "No active session.".to_string())) @@ -1777,6 +1821,8 @@ fn spawn_agent_worker( { tracing::warn!("failed to generate title: {}", e); } + let sent_count = guard.messages.len(); + guard.compressor.set_last_api_info(sent_count, result.total_tokens); result.final_response.content };