feat: 更新 GetCurrentSessionCommandHandler,添加系统提示词提供者支持并优化令牌估算逻辑
This commit is contained in:
parent
2d5b6168cc
commit
90228a4d49
@ -8,21 +8,54 @@ use crate::text::{char_count, take_prefix_chars};
|
||||
|
||||
use crate::agent::{AgentError, AgentRuntimeConfig};
|
||||
|
||||
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
|
||||
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
|
||||
const CJK_CHARS_PER_TOKEN: f64 = 2.0;
|
||||
const OTHER_CHARS_PER_TOKEN: f64 = 4.0;
|
||||
const JSON_OVERHEAD_PER_MESSAGE: usize = 50;
|
||||
|
||||
/// Token estimation using JSON serialization (matches actual request size)
|
||||
/// Check if a character is CJK (Chinese, Japanese, Korean)
|
||||
fn is_cjk_char(c: char) -> bool {
|
||||
matches!(c,
|
||||
'\u{4E00}'..='\u{9FFF}' | // CJK Unified Ideographs
|
||||
'\u{3040}'..='\u{309F}' | // Hiragana
|
||||
'\u{30A0}'..='\u{30FF}' | // Katakana
|
||||
'\u{AC00}'..='\u{D7AF}' | // Korean Hangul
|
||||
'\u{3400}'..='\u{4DBF}' | // CJK Extension A
|
||||
'\u{20000}'..='\u{2A6DF}' // CJK Extension B
|
||||
)
|
||||
}
|
||||
|
||||
/// Token estimation using weighted character counting based on language
|
||||
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
||||
// Serialize to JSON to match actual request format sent to LLM
|
||||
let serialized_len = serde_json::to_string(messages)
|
||||
.map(|s| s.len())
|
||||
.unwrap_or_else(|_| {
|
||||
// Fallback: use content length if serialization fails
|
||||
messages.iter().map(|m| m.content.len()).sum()
|
||||
});
|
||||
let mut cjk_count = 0usize;
|
||||
let mut other_count = 0usize;
|
||||
let mut media_refs_count = 0usize;
|
||||
|
||||
// Apply safety margin for token estimation
|
||||
((serialized_len / TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64
|
||||
for msg in messages {
|
||||
// Count content characters
|
||||
for c in msg.content.chars() {
|
||||
if is_cjk_char(c) {
|
||||
cjk_count += 1;
|
||||
} else {
|
||||
other_count += 1;
|
||||
}
|
||||
}
|
||||
// Count media references
|
||||
media_refs_count += msg.media_refs.len();
|
||||
}
|
||||
|
||||
// Weighted token calculation: CJK chars need more tokens per character
|
||||
let content_tokens = (cjk_count as f64 / CJK_CHARS_PER_TOKEN)
|
||||
+ (other_count as f64 / OTHER_CHARS_PER_TOKEN);
|
||||
|
||||
// JSON serialization overhead for message structure (fields, brackets, etc.)
|
||||
let json_overhead = messages.len() * JSON_OVERHEAD_PER_MESSAGE;
|
||||
|
||||
// Media references add to JSON size (each path is a string in the array)
|
||||
let media_overhead = media_refs_count * 20; // Each media ref adds ~20 chars to JSON
|
||||
|
||||
// Apply safety multiplier
|
||||
((content_tokens + json_overhead as f64 + media_overhead as f64)
|
||||
* TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
|
||||
}
|
||||
|
||||
@ -490,17 +523,77 @@ mod tests {
|
||||
];
|
||||
|
||||
let tokens = estimate_tokens(&messages);
|
||||
// JSON serialization includes: id, role, content, timestamp, etc.
|
||||
// With 3 messages, the JSON overhead is significant
|
||||
// Serialized JSON is typically 300-500 chars for 3 simple messages
|
||||
// 500 / 4 * 1.2 = ~150 tokens
|
||||
// Content: "Hello" (5) + "Hi there!" (9) + "How are you?" (12) = 26 chars
|
||||
// English: 26 / 4 = 6.5 tokens for content
|
||||
// JSON overhead: 3 * 50 = 150
|
||||
// Total before multiplier: 156.5
|
||||
// After 1.2x: ~188 tokens
|
||||
assert!(
|
||||
tokens > 50 && tokens < 300,
|
||||
"Expected ~100-200 tokens (JSON overhead), got {}",
|
||||
tokens > 100 && tokens < 300,
|
||||
"Expected ~150-250 tokens for English content, got {}",
|
||||
tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens_chinese_content() {
|
||||
let messages = vec![
|
||||
ChatMessage::user("你好,这是一个中文测试"),
|
||||
ChatMessage::assistant("这是一个中文回复消息"),
|
||||
];
|
||||
|
||||
let tokens = estimate_tokens(&messages);
|
||||
// Content: ~20 CJK chars, CJK uses 2 chars/token = ~10 tokens for content
|
||||
// JSON overhead: 2 * 50 = 100
|
||||
// Total before multiplier: ~110
|
||||
// After 1.2x: ~132 tokens
|
||||
assert!(
|
||||
tokens > 80 && tokens < 200,
|
||||
"Expected ~100-180 tokens for Chinese content, got {}",
|
||||
tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens_mixed_content() {
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello 世界 this is 测试"),
|
||||
];
|
||||
|
||||
let tokens = estimate_tokens(&messages);
|
||||
// Content: 18 English chars + 4 CJK chars
|
||||
// English: 18 / 4 = 4.5, CJK: 4 / 2 = 2, content tokens = 6.5
|
||||
// JSON overhead: 1 * 50 = 50
|
||||
// Total before multiplier: 56.5
|
||||
// After 1.2x: ~68 tokens
|
||||
assert!(
|
||||
tokens > 40 && tokens < 120,
|
||||
"Expected ~50-100 tokens for mixed content, got {}",
|
||||
tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chinese_tokens_higher_than_english() {
|
||||
// Use more characters to make the content difference significant
|
||||
// compared to JSON overhead (50 tokens per message)
|
||||
let english = vec![ChatMessage::user(&"abcdefghij".repeat(20))]; // 200 English chars
|
||||
let chinese = vec![ChatMessage::user(&"这是一个测试消息字".repeat(20))]; // 200 CJK chars (10 chars * 20)
|
||||
|
||||
let english_tokens = estimate_tokens(&english);
|
||||
let chinese_tokens = estimate_tokens(&chinese);
|
||||
|
||||
// 200 English chars: 200/4 = 50 content tokens
|
||||
// 200 CJK chars: 200/2 = 100 content tokens
|
||||
// With same JSON overhead, Chinese should use ~1.5x tokens
|
||||
assert!(
|
||||
chinese_tokens > english_tokens * 130 / 100, // At least 1.3x
|
||||
"Chinese (200 chars) should use more tokens than English (200 chars): EN={} CN={}",
|
||||
english_tokens,
|
||||
chinese_tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens_includes_image_media_refs() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
@ -519,7 +612,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_should_compress() {
|
||||
let compressor = ContextCompressor::new(20);
|
||||
let messages = vec![ChatMessage::user(&"x".repeat(200))];
|
||||
// Need more content to trigger compression with new weighted calculation
|
||||
// 200 English chars / 4 = 50 tokens, plus overhead
|
||||
let messages = vec![ChatMessage::user(&"x".repeat(400))];
|
||||
assert!(compressor.should_compress(&messages));
|
||||
}
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
use crate::agent::context_compressor::estimate_tokens;
|
||||
use crate::agent::{SystemPromptContext, SystemPromptProvider};
|
||||
use crate::command::context::CommandContext;
|
||||
use crate::command::handler::{CommandHandler, CommandMetadata};
|
||||
use crate::command::handlers::get_messages_from_session;
|
||||
@ -13,6 +14,7 @@ use std::sync::Arc;
|
||||
pub struct GetCurrentSessionCommandHandler {
|
||||
store: Arc<SessionStore>,
|
||||
session_manager: Option<SessionManager>,
|
||||
system_prompt_provider: Option<Arc<dyn SystemPromptProvider>>,
|
||||
}
|
||||
|
||||
impl GetCurrentSessionCommandHandler {
|
||||
@ -20,6 +22,7 @@ impl GetCurrentSessionCommandHandler {
|
||||
Self {
|
||||
store,
|
||||
session_manager: None,
|
||||
system_prompt_provider: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,6 +30,11 @@ impl GetCurrentSessionCommandHandler {
|
||||
self.session_manager = Some(session_manager);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_system_prompt_provider(mut self, provider: Arc<dyn SystemPromptProvider>) -> Self {
|
||||
self.system_prompt_provider = Some(provider);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@ -79,17 +87,41 @@ async fn handle_get_current_session(
|
||||
).await?;
|
||||
|
||||
let actual_message_count = messages.len();
|
||||
let estimated_tokens = estimate_tokens(&messages);
|
||||
let message_tokens = estimate_tokens(&messages);
|
||||
|
||||
// Calculate system prompt tokens if provider is available
|
||||
let system_prompt_tokens = if let Some(ref provider) = handler.system_prompt_provider {
|
||||
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
|
||||
let system_prompt_context = SystemPromptContext {
|
||||
session_id: ctx.session_id.clone(),
|
||||
chat_id: chat_id.to_string(),
|
||||
user_message_count,
|
||||
};
|
||||
|
||||
provider.build(&system_prompt_context)
|
||||
.map(|sp| {
|
||||
use crate::bus::ChatMessage;
|
||||
let system_msg = ChatMessage::system(&sp.content);
|
||||
estimate_tokens(&[system_msg])
|
||||
})
|
||||
.unwrap_or(0)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let total_tokens = system_prompt_tokens + message_tokens;
|
||||
|
||||
let last_active = format_time_ago(topic.last_active_at);
|
||||
let created_at = format_time_ago(topic.created_at);
|
||||
|
||||
let message = format!(
|
||||
"Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{}\n Created: {}\n Last Active: {}",
|
||||
"Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{} (系统提示词: ~{}, 用户消息: ~{})\n Created: {}\n Last Active: {}",
|
||||
topic.id,
|
||||
topic.title,
|
||||
actual_message_count,
|
||||
estimated_tokens,
|
||||
total_tokens,
|
||||
system_prompt_tokens,
|
||||
message_tokens,
|
||||
created_at,
|
||||
last_active
|
||||
);
|
||||
@ -99,7 +131,9 @@ async fn handle_get_current_session(
|
||||
.with_metadata("topic_id", &topic.id)
|
||||
.with_metadata("title", &topic.title)
|
||||
.with_metadata("message_count", &actual_message_count.to_string())
|
||||
.with_metadata("estimated_tokens", &estimated_tokens.to_string()))
|
||||
.with_metadata("estimated_tokens", &total_tokens.to_string())
|
||||
.with_metadata("system_prompt_tokens", &system_prompt_tokens.to_string())
|
||||
.with_metadata("message_tokens", &message_tokens.to_string()))
|
||||
}
|
||||
|
||||
fn format_time_ago(timestamp_ms: i64) -> String {
|
||||
|
||||
@ -55,25 +55,29 @@ impl InboundProcessor {
|
||||
.with_session_manager(session_manager.clone());
|
||||
command_router.register(Box::new(switch_handler));
|
||||
|
||||
// 注册 get_current 处理器
|
||||
command_router.register(Box::new(GetCurrentSessionCommandHandler::new(
|
||||
store.clone(),
|
||||
).with_session_manager(session_manager.clone())));
|
||||
|
||||
// 注册 load_session 处理器
|
||||
command_router.register(Box::new(LoadSessionCommandHandler::new(store.clone())));
|
||||
|
||||
// 注册 save_session 处理器
|
||||
// 创建 system_prompt_provider(用于 save_session, save_topic, get_current)
|
||||
let skills = session_manager.skills();
|
||||
let prompt_repository = session_manager.store().clone();
|
||||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||||
Box::new(AgentPromptProvider::new(
|
||||
0, // save_session 不需要 reinject 逻辑
|
||||
0, // 不需要 reinject 逻辑
|
||||
provider_config.clone(),
|
||||
prompt_repository,
|
||||
)),
|
||||
Box::new(SkillPromptProvider::new(skills)),
|
||||
]));
|
||||
|
||||
// 注册 get_current 处理器
|
||||
command_router.register(Box::new(
|
||||
GetCurrentSessionCommandHandler::new(store.clone())
|
||||
.with_session_manager(session_manager.clone())
|
||||
.with_system_prompt_provider(system_prompt_provider.clone())
|
||||
));
|
||||
|
||||
// 注册 load_session 处理器
|
||||
command_router.register(Box::new(LoadSessionCommandHandler::new(store.clone())));
|
||||
|
||||
// 注册 save_session 处理器
|
||||
command_router.register(Box::new(SaveSessionCommandHandler::new(
|
||||
store.clone(),
|
||||
system_prompt_provider.clone(),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user