feat: 更新令牌估算逻辑,使用 JSON 序列化以匹配实际请求大小,并在当前话题处理器中添加令牌估算

This commit is contained in:
oudecheng 2026-05-19 17:05:04 +08:00
parent 4ec4e2b993
commit 3c2650824c
2 changed files with 29 additions and 31 deletions

View File

@ -11,31 +11,19 @@ use crate::agent::{AgentError, AgentRuntimeConfig};
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4; const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2; const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
/// Token estimation using ~4 chars/token heuristic with 1.2x safety margin. /// Token estimation using JSON serialization (matches actual request size)
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize { pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
let raw: usize = messages // Serialize to JSON to match actual request format sent to LLM
.iter() let serialized_len = serde_json::to_string(messages)
.map(|message| { .map(|s| s.len())
message .unwrap_or_else(|_| {
.content // Fallback: use content length if serialization fails
.len() messages.iter().map(|m| m.content.len()).sum()
.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN) });
+ estimate_image_tokens(&message.media_refs)
+ 4
})
.sum();
(raw as f64 * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
}
fn estimate_image_tokens(media_refs: &[String]) -> usize { // Apply safety margin for token estimation
media_refs ((serialized_len / TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64
.iter() * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
.filter_map(|path| std::fs::metadata(path).ok())
.map(|metadata| {
let base64_chars = metadata.len().saturating_mul(4).div_ceil(3) as usize;
base64_chars.div_ceil(TOKEN_ESTIMATE_CHARS_PER_TOKEN)
})
.sum()
} }
/// Configuration for context compression. /// Configuration for context compression.
@ -502,13 +490,13 @@ mod tests {
]; ];
let tokens = estimate_tokens(&messages); let tokens = estimate_tokens(&messages);
// "Hello" (5) -> ceil(5/4)+4 = 2+4 = 6 // JSON serialization includes: id, role, content, timestamp, etc.
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6 // With 3 messages, the JSON overhead is significant
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7 // Serialized JSON is typically 300-500 chars for 3 simple messages
// raw = 19, with 1.2x = ~23 // 500 / 4 * 1.2 = ~150 tokens
assert!( assert!(
tokens > 18 && tokens < 30, tokens > 50 && tokens < 300,
"Expected ~23 tokens, got {}", "Expected ~100-200 tokens (JSON overhead), got {}",
tokens tokens
); );
} }

View File

@ -1,3 +1,4 @@
use crate::agent::context_compressor::estimate_tokens;
use crate::command::context::CommandContext; use crate::command::context::CommandContext;
use crate::command::handler::{CommandHandler, CommandMetadata}; use crate::command::handler::{CommandHandler, CommandMetadata};
use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::response::{CommandError, CommandResponse, MessageKind};
@ -56,14 +57,22 @@ async fn handle_get_current_session(
.map_err(|e| CommandError::new("GET_TOPIC_ERROR", e.to_string()))? .map_err(|e| CommandError::new("GET_TOPIC_ERROR", e.to_string()))?
.ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", topic_id)))?; .ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", topic_id)))?;
// Load messages and estimate tokens
let messages = handler
.store
.load_messages_for_topic(topic_id)
.map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?;
let estimated_tokens = estimate_tokens(&messages);
let last_active = format_time_ago(topic.last_active_at); let last_active = format_time_ago(topic.last_active_at);
let created_at = format_time_ago(topic.created_at); let created_at = format_time_ago(topic.created_at);
let message = format!( let message = format!(
"Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Created: {}\n Last Active: {}", "Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{}\n Created: {}\n Last Active: {}",
topic.id, topic.id,
topic.title, topic.title,
topic.message_count, topic.message_count,
estimated_tokens,
created_at, created_at,
last_active last_active
); );
@ -72,7 +81,8 @@ async fn handle_get_current_session(
.with_message(MessageKind::Notification, &message) .with_message(MessageKind::Notification, &message)
.with_metadata("topic_id", &topic.id) .with_metadata("topic_id", &topic.id)
.with_metadata("title", &topic.title) .with_metadata("title", &topic.title)
.with_metadata("message_count", &topic.message_count.to_string())) .with_metadata("message_count", &topic.message_count.to_string())
.with_metadata("estimated_tokens", &estimated_tokens.to_string()))
} }
fn format_time_ago(timestamp_ms: i64) -> String { fn format_time_ago(timestamp_ms: i64) -> String {