diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index a164489..2d21028 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -8,21 +8,54 @@ use crate::text::{char_count, take_prefix_chars}; use crate::agent::{AgentError, AgentRuntimeConfig}; -const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4; const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2; +const CJK_CHARS_PER_TOKEN: f64 = 2.0; +const OTHER_CHARS_PER_TOKEN: f64 = 4.0; +const JSON_OVERHEAD_PER_MESSAGE: usize = 50; -/// Token estimation using JSON serialization (matches actual request size) +/// Check if a character is CJK (Chinese, Japanese, Korean) +fn is_cjk_char(c: char) -> bool { + matches!(c, + '\u{4E00}'..='\u{9FFF}' | // CJK Unified Ideographs + '\u{3040}'..='\u{309F}' | // Hiragana + '\u{30A0}'..='\u{30FF}' | // Katakana + '\u{AC00}'..='\u{D7AF}' | // Korean Hangul + '\u{3400}'..='\u{4DBF}' | // CJK Extension A + '\u{20000}'..='\u{2A6DF}' // CJK Extension B + ) +} + +/// Token estimation using weighted character counting based on language pub fn estimate_tokens(messages: &[ChatMessage]) -> usize { - // Serialize to JSON to match actual request format sent to LLM - let serialized_len = serde_json::to_string(messages) - .map(|s| s.len()) - .unwrap_or_else(|_| { - // Fallback: use content length if serialization fails - messages.iter().map(|m| m.content.len()).sum() - }); + let mut cjk_count = 0usize; + let mut other_count = 0usize; + let mut media_refs_count = 0usize; - // Apply safety margin for token estimation - ((serialized_len / TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64 + for msg in messages { + // Count content characters + for c in msg.content.chars() { + if is_cjk_char(c) { + cjk_count += 1; + } else { + other_count += 1; + } + } + // Count media references + media_refs_count += msg.media_refs.len(); + } + + // Weighted token calculation: CJK chars need more tokens per character + let content_tokens = (cjk_count as f64 / CJK_CHARS_PER_TOKEN) 
+ + (other_count as f64 / OTHER_CHARS_PER_TOKEN); + + // JSON serialization overhead for message structure (fields, brackets, etc.) + let json_overhead = messages.len() * JSON_OVERHEAD_PER_MESSAGE; + + // Media references add to JSON size (each path is a string in the array) + let media_overhead = media_refs_count * 20; // NOTE(review): ~20 chars of JSON per media ref, but added to the token sum undivided - confirm intended unit + + // Apply safety multiplier + ((content_tokens + json_overhead as f64 + media_overhead as f64) * TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize } @@ -490,17 +523,77 @@ mod tests { ]; let tokens = estimate_tokens(&messages); - // JSON serialization includes: id, role, content, timestamp, etc. - // With 3 messages, the JSON overhead is significant - // Serialized JSON is typically 300-500 chars for 3 simple messages - // 500 / 4 * 1.2 = ~150 tokens + // Content: "Hello" (5) + "Hi there!" (9) + "How are you?" (12) = 26 chars + // English: 26 / 4 = 6.5 tokens for content + // JSON overhead: 3 * 50 = 150 + // Total before multiplier: 156.5 + // After 1.2x: ~188 tokens assert!( - tokens > 50 && tokens < 300, - "Expected ~100-200 tokens (JSON overhead), got {}", + tokens > 100 && tokens < 300, + "Expected ~150-250 tokens for English content, got {}", tokens ); } + #[test] + fn test_estimate_tokens_chinese_content() { + let messages = vec![ + ChatMessage::user("你好,这是一个中文测试"), + ChatMessage::assistant("这是一个中文回复消息"), + ]; + + let tokens = estimate_tokens(&messages); + // Content: ~20 CJK chars, CJK uses 2 chars/token = ~10 tokens for content + // JSON overhead: 2 * 50 = 100 + // Total before multiplier: ~110 + // After 1.2x: ~132 tokens + assert!( + tokens > 80 && tokens < 200, + "Expected ~100-180 tokens for Chinese content, got {}", + tokens + ); + } + + #[test] + fn test_estimate_tokens_mixed_content() { + let messages = vec![ + ChatMessage::user("Hello 世界 this is 测试"), + ]; + + let tokens = estimate_tokens(&messages); + // Content: 15 non-CJK chars (letters and spaces) + 4 CJK chars + // Non-CJK: 15 / 4 = 3.75, CJK: 4 / 2 = 2, content 
tokens = 5.75 + // JSON overhead: 1 * 50 = 50 + // Total before multiplier: 55.75 + // After 1.2x: ~66 tokens + assert!( + tokens > 40 && tokens < 120, + "Expected ~50-100 tokens for mixed content, got {}", + tokens + ); + } + + #[test] + fn test_chinese_tokens_higher_than_english() { + // Use more characters to make the content difference significant + // compared to JSON overhead (50 tokens per message) + let english = vec![ChatMessage::user(&"abcdefghij".repeat(20))]; // 200 English chars + let chinese = vec![ChatMessage::user(&"这是一个测试消息字".repeat(20))]; // 180 CJK chars (9 chars * 20) + + let english_tokens = estimate_tokens(&english); + let chinese_tokens = estimate_tokens(&chinese); + + // 200 English chars: 200/4 = 50 content tokens + // 180 CJK chars: 180/2 = 90 content tokens + // With same JSON overhead, Chinese should use ~1.4x tokens + assert!( + chinese_tokens > english_tokens * 130 / 100, // At least 1.3x + "Chinese (200 chars) should use more tokens than English (200 chars): EN={} CN={}", + english_tokens, + chinese_tokens + ); + } + #[test] fn test_estimate_tokens_includes_image_media_refs() { let temp_dir = tempfile::tempdir().unwrap(); @@ -519,7 +612,9 @@ mod tests { #[test] fn test_should_compress() { let compressor = ContextCompressor::new(20); - let messages = vec![ChatMessage::user(&"x".repeat(200))]; + // Need more content to trigger compression with new weighted calculation + // 400 English chars / 4 = 100 content tokens, plus per-message overhead (200 chars gave only 50) + let messages = vec![ChatMessage::user(&"x".repeat(400))]; assert!(compressor.should_compress(&messages)); } diff --git a/src/command/handlers/get_current.rs b/src/command/handlers/get_current.rs index 23ff79a..9b44fa0 100644 --- a/src/command/handlers/get_current.rs +++ b/src/command/handlers/get_current.rs @@ -1,4 +1,5 @@ use crate::agent::context_compressor::estimate_tokens; +use crate::agent::{SystemPromptContext, SystemPromptProvider}; use crate::command::context::CommandContext; use 
crate::command::handler::{CommandHandler, CommandMetadata}; use crate::command::handlers::get_messages_from_session; @@ -13,6 +14,7 @@ use std::sync::Arc; pub struct GetCurrentSessionCommandHandler { store: Arc, session_manager: Option, + system_prompt_provider: Option>, } impl GetCurrentSessionCommandHandler { @@ -20,6 +22,7 @@ impl GetCurrentSessionCommandHandler { Self { store, session_manager: None, + system_prompt_provider: None, } } @@ -27,6 +30,11 @@ impl GetCurrentSessionCommandHandler { self.session_manager = Some(session_manager); self } + + pub fn with_system_prompt_provider(mut self, provider: Arc) -> Self { + self.system_prompt_provider = Some(provider); + self + } } #[async_trait] @@ -79,17 +87,41 @@ async fn handle_get_current_session( ).await?; let actual_message_count = messages.len(); - let estimated_tokens = estimate_tokens(&messages); + let message_tokens = estimate_tokens(&messages); + + // Calculate system prompt tokens if provider is available + let system_prompt_tokens = if let Some(ref provider) = handler.system_prompt_provider { + let user_message_count = messages.iter().filter(|m| m.role == "user").count(); + let system_prompt_context = SystemPromptContext { + session_id: ctx.session_id.clone(), + chat_id: chat_id.to_string(), + user_message_count, + }; + + provider.build(&system_prompt_context) + .map(|sp| { + use crate::bus::ChatMessage; + let system_msg = ChatMessage::system(&sp.content); + estimate_tokens(&[system_msg]) + }) + .unwrap_or(0) + } else { + 0 + }; + + let total_tokens = system_prompt_tokens + message_tokens; let last_active = format_time_ago(topic.last_active_at); let created_at = format_time_ago(topic.created_at); let message = format!( - "Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{}\n Created: {}\n Last Active: {}", + "Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{} (系统提示词: ~{}, 用户消息: ~{})\n Created: {}\n Last Active: {}", topic.id, topic.title, actual_message_count, 
- estimated_tokens, + total_tokens, + system_prompt_tokens, + message_tokens, created_at, last_active ); @@ -99,7 +131,9 @@ async fn handle_get_current_session( .with_metadata("topic_id", &topic.id) .with_metadata("title", &topic.title) .with_metadata("message_count", &actual_message_count.to_string()) - .with_metadata("estimated_tokens", &estimated_tokens.to_string())) + .with_metadata("estimated_tokens", &total_tokens.to_string()) + .with_metadata("system_prompt_tokens", &system_prompt_tokens.to_string()) + .with_metadata("message_tokens", &message_tokens.to_string())) } fn format_time_ago(timestamp_ms: i64) -> String { diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 430b2c0..5424268 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -55,25 +55,29 @@ impl InboundProcessor { .with_session_manager(session_manager.clone()); command_router.register(Box::new(switch_handler)); - // 注册 get_current 处理器 - command_router.register(Box::new(GetCurrentSessionCommandHandler::new( - store.clone(), - ).with_session_manager(session_manager.clone()))); - - // 注册 load_session 处理器 - command_router.register(Box::new(LoadSessionCommandHandler::new(store.clone()))); - - // 注册 save_session 处理器 + // 创建 system_prompt_provider(用于 save_session, save_topic, get_current) let skills = session_manager.skills(); let prompt_repository = session_manager.store().clone(); let system_prompt_provider: Arc = Arc::new(CompositeSystemPromptProvider::new(vec![ Box::new(AgentPromptProvider::new( - 0, // save_session 不需要 reinject 逻辑 + 0, // 不需要 reinject 逻辑 provider_config.clone(), prompt_repository, )), Box::new(SkillPromptProvider::new(skills)), ])); + + // 注册 get_current 处理器 + command_router.register(Box::new( + GetCurrentSessionCommandHandler::new(store.clone()) + .with_session_manager(session_manager.clone()) + .with_system_prompt_provider(system_prompt_provider.clone()) + )); + + // 注册 load_session 处理器 + 
command_router.register(Box::new(LoadSessionCommandHandler::new(store.clone()))); + + // 注册 save_session 处理器 command_router.register(Box::new(SaveSessionCommandHandler::new( store.clone(), system_prompt_provider.clone(),