feat: 更新 GetCurrentSessionCommandHandler,添加系统提示词提供者支持并优化令牌估算逻辑

This commit is contained in:
oudecheng 2026-05-19 18:29:32 +08:00
parent 2d5b6168cc
commit 90228a4d49
3 changed files with 165 additions and 32 deletions

View File

@ -8,21 +8,54 @@ use crate::text::{char_count, take_prefix_chars};
use crate::agent::{AgentError, AgentRuntimeConfig};
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
const CJK_CHARS_PER_TOKEN: f64 = 2.0;
const OTHER_CHARS_PER_TOKEN: f64 = 4.0;
const JSON_OVERHEAD_PER_MESSAGE: usize = 50;
/// Token estimation using JSON serialization (matches actual request size)
/// Check if a character is CJK (Chinese, Japanese, Korean)
fn is_cjk_char(c: char) -> bool {
matches!(c,
'\u{4E00}'..='\u{9FFF}' | // CJK Unified Ideographs
'\u{3040}'..='\u{309F}' | // Hiragana
'\u{30A0}'..='\u{30FF}' | // Katakana
'\u{AC00}'..='\u{D7AF}' | // Korean Hangul
'\u{3400}'..='\u{4DBF}' | // CJK Extension A
'\u{20000}'..='\u{2A6DF}' // CJK Extension B
)
}
/// Token estimation using weighted character counting based on language
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
// Serialize to JSON to match actual request format sent to LLM
let serialized_len = serde_json::to_string(messages)
.map(|s| s.len())
.unwrap_or_else(|_| {
// Fallback: use content length if serialization fails
messages.iter().map(|m| m.content.len()).sum()
});
let mut cjk_count = 0usize;
let mut other_count = 0usize;
let mut media_refs_count = 0usize;
// Apply safety margin for token estimation
((serialized_len / TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64
for msg in messages {
// Count content characters
for c in msg.content.chars() {
if is_cjk_char(c) {
cjk_count += 1;
} else {
other_count += 1;
}
}
// Count media references
media_refs_count += msg.media_refs.len();
}
// Weighted token calculation: CJK chars need more tokens per character
let content_tokens = (cjk_count as f64 / CJK_CHARS_PER_TOKEN)
+ (other_count as f64 / OTHER_CHARS_PER_TOKEN);
// JSON serialization overhead for message structure (fields, brackets, etc.)
let json_overhead = messages.len() * JSON_OVERHEAD_PER_MESSAGE;
// Media references add to JSON size (each path is a string in the array)
let media_overhead = media_refs_count * 20; // Each media ref adds ~20 chars to JSON
// Apply safety multiplier
((content_tokens + json_overhead as f64 + media_overhead as f64)
* TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
}
@ -490,17 +523,77 @@ mod tests {
];
let tokens = estimate_tokens(&messages);
// JSON serialization includes: id, role, content, timestamp, etc.
// With 3 messages, the JSON overhead is significant
// Serialized JSON is typically 300-500 chars for 3 simple messages
// 500 / 4 * 1.2 = ~150 tokens
// Content: "Hello" (5) + "Hi there!" (9) + "How are you?" (12) = 26 chars
// English: 26 / 4 = 6.5 tokens for content
// JSON overhead: 3 * 50 = 150
// Total before multiplier: 156.5
// After 1.2x: ~188 tokens
assert!(
tokens > 50 && tokens < 300,
"Expected ~100-200 tokens (JSON overhead), got {}",
tokens > 100 && tokens < 300,
"Expected ~150-250 tokens for English content, got {}",
tokens
);
}
#[test]
fn test_estimate_tokens_chinese_content() {
let messages = vec![
ChatMessage::user("你好,这是一个中文测试"),
ChatMessage::assistant("这是一个中文回复消息"),
];
let tokens = estimate_tokens(&messages);
// Content: ~20 CJK chars, CJK uses 2 chars/token = ~10 tokens for content
// JSON overhead: 2 * 50 = 100
// Total before multiplier: ~110
// After 1.2x: ~132 tokens
assert!(
tokens > 80 && tokens < 200,
"Expected ~100-180 tokens for Chinese content, got {}",
tokens
);
}
#[test]
fn test_estimate_tokens_mixed_content() {
let messages = vec![
ChatMessage::user("Hello 世界 this is 测试"),
];
let tokens = estimate_tokens(&messages);
// Content: 18 English chars + 4 CJK chars
// English: 18 / 4 = 4.5, CJK: 4 / 2 = 2, content tokens = 6.5
// JSON overhead: 1 * 50 = 50
// Total before multiplier: 56.5
// After 1.2x: ~68 tokens
assert!(
tokens > 40 && tokens < 120,
"Expected ~50-100 tokens for mixed content, got {}",
tokens
);
}
#[test]
fn test_chinese_tokens_higher_than_english() {
// Use more characters to make the content difference significant
// compared to JSON overhead (50 tokens per message)
let english = vec![ChatMessage::user(&"abcdefghij".repeat(20))]; // 200 English chars
let chinese = vec![ChatMessage::user(&"这是一个测试消息字".repeat(20))]; // 200 CJK chars (10 chars * 20)
let english_tokens = estimate_tokens(&english);
let chinese_tokens = estimate_tokens(&chinese);
// 200 English chars: 200/4 = 50 content tokens
// 200 CJK chars: 200/2 = 100 content tokens
// With same JSON overhead, Chinese should use ~1.5x tokens
assert!(
chinese_tokens > english_tokens * 130 / 100, // At least 1.3x
"Chinese (200 chars) should use more tokens than English (200 chars): EN={} CN={}",
english_tokens,
chinese_tokens
);
}
#[test]
fn test_estimate_tokens_includes_image_media_refs() {
let temp_dir = tempfile::tempdir().unwrap();
@ -519,7 +612,9 @@ mod tests {
#[test]
fn test_should_compress() {
let compressor = ContextCompressor::new(20);
let messages = vec![ChatMessage::user(&"x".repeat(200))];
// Need more content to trigger compression with new weighted calculation
// 200 English chars / 4 = 50 tokens, plus overhead
let messages = vec![ChatMessage::user(&"x".repeat(400))];
assert!(compressor.should_compress(&messages));
}

View File

@ -1,4 +1,5 @@
use crate::agent::context_compressor::estimate_tokens;
use crate::agent::{SystemPromptContext, SystemPromptProvider};
use crate::command::context::CommandContext;
use crate::command::handler::{CommandHandler, CommandMetadata};
use crate::command::handlers::get_messages_from_session;
@ -13,6 +14,7 @@ use std::sync::Arc;
pub struct GetCurrentSessionCommandHandler {
store: Arc<SessionStore>,
session_manager: Option<SessionManager>,
system_prompt_provider: Option<Arc<dyn SystemPromptProvider>>,
}
impl GetCurrentSessionCommandHandler {
@ -20,6 +22,7 @@ impl GetCurrentSessionCommandHandler {
Self {
store,
session_manager: None,
system_prompt_provider: None,
}
}
@ -27,6 +30,11 @@ impl GetCurrentSessionCommandHandler {
self.session_manager = Some(session_manager);
self
}
pub fn with_system_prompt_provider(mut self, provider: Arc<dyn SystemPromptProvider>) -> Self {
self.system_prompt_provider = Some(provider);
self
}
}
#[async_trait]
@ -79,17 +87,41 @@ async fn handle_get_current_session(
).await?;
let actual_message_count = messages.len();
let estimated_tokens = estimate_tokens(&messages);
let message_tokens = estimate_tokens(&messages);
// Calculate system prompt tokens if provider is available
let system_prompt_tokens = if let Some(ref provider) = handler.system_prompt_provider {
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
let system_prompt_context = SystemPromptContext {
session_id: ctx.session_id.clone(),
chat_id: chat_id.to_string(),
user_message_count,
};
provider.build(&system_prompt_context)
.map(|sp| {
use crate::bus::ChatMessage;
let system_msg = ChatMessage::system(&sp.content);
estimate_tokens(&[system_msg])
})
.unwrap_or(0)
} else {
0
};
let total_tokens = system_prompt_tokens + message_tokens;
let last_active = format_time_ago(topic.last_active_at);
let created_at = format_time_ago(topic.created_at);
let message = format!(
"Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{}\n Created: {}\n Last Active: {}",
"Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{} (系统提示词: ~{}, 用户消息: ~{})\n Created: {}\n Last Active: {}",
topic.id,
topic.title,
actual_message_count,
estimated_tokens,
total_tokens,
system_prompt_tokens,
message_tokens,
created_at,
last_active
);
@ -99,7 +131,9 @@ async fn handle_get_current_session(
.with_metadata("topic_id", &topic.id)
.with_metadata("title", &topic.title)
.with_metadata("message_count", &actual_message_count.to_string())
.with_metadata("estimated_tokens", &estimated_tokens.to_string()))
.with_metadata("estimated_tokens", &total_tokens.to_string())
.with_metadata("system_prompt_tokens", &system_prompt_tokens.to_string())
.with_metadata("message_tokens", &message_tokens.to_string()))
}
fn format_time_ago(timestamp_ms: i64) -> String {

View File

@ -55,25 +55,29 @@ impl InboundProcessor {
.with_session_manager(session_manager.clone());
command_router.register(Box::new(switch_handler));
// 注册 get_current 处理器
command_router.register(Box::new(GetCurrentSessionCommandHandler::new(
store.clone(),
).with_session_manager(session_manager.clone())));
// 注册 load_session 处理器
command_router.register(Box::new(LoadSessionCommandHandler::new(store.clone())));
// 注册 save_session 处理器
// 创建 system_prompt_provider用于 save_session, save_topic, get_current
let skills = session_manager.skills();
let prompt_repository = session_manager.store().clone();
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
Box::new(AgentPromptProvider::new(
0, // save_session 不需要 reinject 逻辑
0, // 不需要 reinject 逻辑
provider_config.clone(),
prompt_repository,
)),
Box::new(SkillPromptProvider::new(skills)),
]));
// 注册 get_current 处理器
command_router.register(Box::new(
GetCurrentSessionCommandHandler::new(store.clone())
.with_session_manager(session_manager.clone())
.with_system_prompt_provider(system_prompt_provider.clone())
));
// 注册 load_session 处理器
command_router.register(Box::new(LoadSessionCommandHandler::new(store.clone())));
// 注册 save_session 处理器
command_router.register(Box::new(SaveSessionCommandHandler::new(
store.clone(),
system_prompt_provider.clone(),