feat: 更新 GetCurrentSessionCommandHandler,添加系统提示词提供者支持并优化令牌估算逻辑
This commit is contained in:
parent
2d5b6168cc
commit
90228a4d49
@ -8,21 +8,54 @@ use crate::text::{char_count, take_prefix_chars};
|
|||||||
|
|
||||||
use crate::agent::{AgentError, AgentRuntimeConfig};
|
use crate::agent::{AgentError, AgentRuntimeConfig};
|
||||||
|
|
||||||
const TOKEN_ESTIMATE_CHARS_PER_TOKEN: usize = 4;
|
|
||||||
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
|
const TOKEN_ESTIMATE_SAFETY_MULTIPLIER: f64 = 1.2;
|
||||||
|
const CJK_CHARS_PER_TOKEN: f64 = 2.0;
|
||||||
|
const OTHER_CHARS_PER_TOKEN: f64 = 4.0;
|
||||||
|
const JSON_OVERHEAD_PER_MESSAGE: usize = 50;
|
||||||
|
|
||||||
/// Token estimation using JSON serialization (matches actual request size)
|
/// Returns `true` when `c` falls in a CJK (Chinese, Japanese, Korean) Unicode block.
///
/// Besides the core ideograph, kana and Hangul blocks, this also recognises
/// CJK punctuation, fullwidth forms, compatibility ideographs and the
/// supplementary ideograph extensions C–F. Real CJK text is full of fullwidth
/// punctuation, and treating it as "other" text would bill it at the cheaper
/// non-CJK rate and under-estimate the token count.
fn is_cjk_char(c: char) -> bool {
    matches!(
        c,
        '\u{3000}'..='\u{303F}'       // CJK Symbols and Punctuation
            | '\u{3040}'..='\u{309F}' // Hiragana
            | '\u{30A0}'..='\u{30FF}' // Katakana
            | '\u{3400}'..='\u{4DBF}' // CJK Extension A
            | '\u{4E00}'..='\u{9FFF}' // CJK Unified Ideographs
            | '\u{AC00}'..='\u{D7AF}' // Korean Hangul syllables
            | '\u{F900}'..='\u{FAFF}' // CJK Compatibility Ideographs
            | '\u{FF00}'..='\u{FFEF}' // Halfwidth and Fullwidth Forms
            | '\u{20000}'..='\u{2A6DF}' // CJK Extension B
            | '\u{2A700}'..='\u{2EBEF}' // CJK Extensions C through F
    )
}
|
||||||
|
|
||||||
|
/// Token estimation using weighted character counting based on language
|
||||||
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
||||||
// Serialize to JSON to match actual request format sent to LLM
|
let mut cjk_count = 0usize;
|
||||||
let serialized_len = serde_json::to_string(messages)
|
let mut other_count = 0usize;
|
||||||
.map(|s| s.len())
|
let mut media_refs_count = 0usize;
|
||||||
.unwrap_or_else(|_| {
|
|
||||||
// Fallback: use content length if serialization fails
|
|
||||||
messages.iter().map(|m| m.content.len()).sum()
|
|
||||||
});
|
|
||||||
|
|
||||||
// Apply safety margin for token estimation
|
for msg in messages {
|
||||||
((serialized_len / TOKEN_ESTIMATE_CHARS_PER_TOKEN) as f64
|
// Count content characters
|
||||||
|
for c in msg.content.chars() {
|
||||||
|
if is_cjk_char(c) {
|
||||||
|
cjk_count += 1;
|
||||||
|
} else {
|
||||||
|
other_count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Count media references
|
||||||
|
media_refs_count += msg.media_refs.len();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Weighted token calculation: CJK chars need more tokens per character
|
||||||
|
let content_tokens = (cjk_count as f64 / CJK_CHARS_PER_TOKEN)
|
||||||
|
+ (other_count as f64 / OTHER_CHARS_PER_TOKEN);
|
||||||
|
|
||||||
|
// JSON serialization overhead for message structure (fields, brackets, etc.)
|
||||||
|
let json_overhead = messages.len() * JSON_OVERHEAD_PER_MESSAGE;
|
||||||
|
|
||||||
|
// Media references add to JSON size (each path is a string in the array)
|
||||||
|
let media_overhead = media_refs_count * 20; // Each media ref adds ~20 chars to JSON
|
||||||
|
|
||||||
|
// Apply safety multiplier
|
||||||
|
((content_tokens + json_overhead as f64 + media_overhead as f64)
|
||||||
* TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
|
* TOKEN_ESTIMATE_SAFETY_MULTIPLIER) as usize
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -490,17 +523,77 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
let tokens = estimate_tokens(&messages);
|
let tokens = estimate_tokens(&messages);
|
||||||
// JSON serialization includes: id, role, content, timestamp, etc.
|
// Content: "Hello" (5) + "Hi there!" (9) + "How are you?" (12) = 26 chars
|
||||||
// With 3 messages, the JSON overhead is significant
|
// English: 26 / 4 = 6.5 tokens for content
|
||||||
// Serialized JSON is typically 300-500 chars for 3 simple messages
|
// JSON overhead: 3 * 50 = 150
|
||||||
// 500 / 4 * 1.2 = ~150 tokens
|
// Total before multiplier: 156.5
|
||||||
|
// After 1.2x: ~188 tokens
|
||||||
assert!(
|
assert!(
|
||||||
tokens > 50 && tokens < 300,
|
tokens > 100 && tokens < 300,
|
||||||
"Expected ~100-200 tokens (JSON overhead), got {}",
|
"Expected ~150-250 tokens for English content, got {}",
|
||||||
tokens
|
tokens
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn test_estimate_tokens_chinese_content() {
    // Two short messages made up almost entirely of CJK characters.
    let conversation = vec![
        ChatMessage::user("你好,这是一个中文测试"),
        ChatMessage::assistant("这是一个中文回复消息"),
    ];

    let tokens = estimate_tokens(&conversation);

    // Budget: ~20 CJK chars at 2 chars/token is ~10 content tokens, plus
    // 2 messages * 50 overhead = ~110, scaled by the 1.2 safety factor = ~132.
    assert!(
        (81..200).contains(&tokens),
        "Expected ~100-180 tokens for Chinese content, got {}",
        tokens
    );
}
|
||||||
|
|
||||||
|
#[test]
fn test_estimate_tokens_mixed_content() {
    // A single message that interleaves Latin and CJK text.
    let conversation = vec![ChatMessage::user("Hello 世界 this is 测试")];

    let tokens = estimate_tokens(&conversation);

    // Budget: 15 non-CJK chars (11 letters + 4 spaces) / 4 = 3.75 and
    // 4 CJK chars / 2 = 2, so 5.75 content tokens; plus 1 * 50 overhead
    // = 55.75, scaled by the 1.2 safety factor = ~67.
    assert!(
        (41..120).contains(&tokens),
        "Expected ~50-100 tokens for mixed content, got {}",
        tokens
    );
}
|
||||||
|
|
||||||
|
#[test]
fn test_chinese_tokens_higher_than_english() {
    // Use long inputs so the content term is significant next to the fixed
    // per-message overhead (50 per message).
    //
    // Both literals are exactly 10 characters, so each input is 200 chars and
    // the comparison is like-for-like.
    let english = vec![ChatMessage::user(&"abcdefghij".repeat(20))]; // 10 chars * 20 = 200
    let chinese = vec![ChatMessage::user(&"这是一个测试消息字符".repeat(20))]; // 10 CJK chars * 20 = 200

    let english_tokens = estimate_tokens(&english);
    let chinese_tokens = estimate_tokens(&chinese);

    // 200 English chars: 200 / 4 = 50 content tokens  -> (50 + 50) * 1.2 = ~120
    // 200 CJK chars:     200 / 2 = 100 content tokens -> (100 + 50) * 1.2 = ~180
    // With the same overhead, Chinese should land at ~1.5x the English estimate.
    assert!(
        chinese_tokens > english_tokens * 130 / 100, // At least 1.3x
        "Chinese (200 chars) should use more tokens than English (200 chars): EN={} CN={}",
        english_tokens,
        chinese_tokens
    );
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_estimate_tokens_includes_image_media_refs() {
|
fn test_estimate_tokens_includes_image_media_refs() {
|
||||||
let temp_dir = tempfile::tempdir().unwrap();
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
@ -519,7 +612,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_should_compress() {
|
fn test_should_compress() {
|
||||||
let compressor = ContextCompressor::new(20);
|
let compressor = ContextCompressor::new(20);
|
||||||
let messages = vec![ChatMessage::user(&"x".repeat(200))];
|
// Need more content to trigger compression with new weighted calculation
|
||||||
|
// 200 English chars / 4 = 50 tokens, plus overhead
|
||||||
|
let messages = vec![ChatMessage::user(&"x".repeat(400))];
|
||||||
assert!(compressor.should_compress(&messages));
|
assert!(compressor.should_compress(&messages));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
use crate::agent::context_compressor::estimate_tokens;
|
use crate::agent::context_compressor::estimate_tokens;
|
||||||
|
use crate::agent::{SystemPromptContext, SystemPromptProvider};
|
||||||
use crate::command::context::CommandContext;
|
use crate::command::context::CommandContext;
|
||||||
use crate::command::handler::{CommandHandler, CommandMetadata};
|
use crate::command::handler::{CommandHandler, CommandMetadata};
|
||||||
use crate::command::handlers::get_messages_from_session;
|
use crate::command::handlers::get_messages_from_session;
|
||||||
@ -13,6 +14,7 @@ use std::sync::Arc;
|
|||||||
pub struct GetCurrentSessionCommandHandler {
|
pub struct GetCurrentSessionCommandHandler {
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
session_manager: Option<SessionManager>,
|
session_manager: Option<SessionManager>,
|
||||||
|
system_prompt_provider: Option<Arc<dyn SystemPromptProvider>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GetCurrentSessionCommandHandler {
|
impl GetCurrentSessionCommandHandler {
|
||||||
@ -20,6 +22,7 @@ impl GetCurrentSessionCommandHandler {
|
|||||||
Self {
|
Self {
|
||||||
store,
|
store,
|
||||||
session_manager: None,
|
session_manager: None,
|
||||||
|
system_prompt_provider: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,6 +30,11 @@ impl GetCurrentSessionCommandHandler {
|
|||||||
self.session_manager = Some(session_manager);
|
self.session_manager = Some(session_manager);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_system_prompt_provider(mut self, provider: Arc<dyn SystemPromptProvider>) -> Self {
|
||||||
|
self.system_prompt_provider = Some(provider);
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -79,17 +87,41 @@ async fn handle_get_current_session(
|
|||||||
).await?;
|
).await?;
|
||||||
|
|
||||||
let actual_message_count = messages.len();
|
let actual_message_count = messages.len();
|
||||||
let estimated_tokens = estimate_tokens(&messages);
|
let message_tokens = estimate_tokens(&messages);
|
||||||
|
|
||||||
|
// Calculate system prompt tokens if provider is available
|
||||||
|
let system_prompt_tokens = if let Some(ref provider) = handler.system_prompt_provider {
|
||||||
|
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
|
||||||
|
let system_prompt_context = SystemPromptContext {
|
||||||
|
session_id: ctx.session_id.clone(),
|
||||||
|
chat_id: chat_id.to_string(),
|
||||||
|
user_message_count,
|
||||||
|
};
|
||||||
|
|
||||||
|
provider.build(&system_prompt_context)
|
||||||
|
.map(|sp| {
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
|
let system_msg = ChatMessage::system(&sp.content);
|
||||||
|
estimate_tokens(&[system_msg])
|
||||||
|
})
|
||||||
|
.unwrap_or(0)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
let total_tokens = system_prompt_tokens + message_tokens;
|
||||||
|
|
||||||
let last_active = format_time_ago(topic.last_active_at);
|
let last_active = format_time_ago(topic.last_active_at);
|
||||||
let created_at = format_time_ago(topic.created_at);
|
let created_at = format_time_ago(topic.created_at);
|
||||||
|
|
||||||
let message = format!(
|
let message = format!(
|
||||||
"Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{}\n Created: {}\n Last Active: {}",
|
"Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{} (系统提示词: ~{}, 用户消息: ~{})\n Created: {}\n Last Active: {}",
|
||||||
topic.id,
|
topic.id,
|
||||||
topic.title,
|
topic.title,
|
||||||
actual_message_count,
|
actual_message_count,
|
||||||
estimated_tokens,
|
total_tokens,
|
||||||
|
system_prompt_tokens,
|
||||||
|
message_tokens,
|
||||||
created_at,
|
created_at,
|
||||||
last_active
|
last_active
|
||||||
);
|
);
|
||||||
@ -99,7 +131,9 @@ async fn handle_get_current_session(
|
|||||||
.with_metadata("topic_id", &topic.id)
|
.with_metadata("topic_id", &topic.id)
|
||||||
.with_metadata("title", &topic.title)
|
.with_metadata("title", &topic.title)
|
||||||
.with_metadata("message_count", &actual_message_count.to_string())
|
.with_metadata("message_count", &actual_message_count.to_string())
|
||||||
.with_metadata("estimated_tokens", &estimated_tokens.to_string()))
|
.with_metadata("estimated_tokens", &total_tokens.to_string())
|
||||||
|
.with_metadata("system_prompt_tokens", &system_prompt_tokens.to_string())
|
||||||
|
.with_metadata("message_tokens", &message_tokens.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format_time_ago(timestamp_ms: i64) -> String {
|
fn format_time_ago(timestamp_ms: i64) -> String {
|
||||||
|
|||||||
@ -55,25 +55,29 @@ impl InboundProcessor {
|
|||||||
.with_session_manager(session_manager.clone());
|
.with_session_manager(session_manager.clone());
|
||||||
command_router.register(Box::new(switch_handler));
|
command_router.register(Box::new(switch_handler));
|
||||||
|
|
||||||
// 注册 get_current 处理器
|
// 创建 system_prompt_provider(用于 save_session, save_topic, get_current)
|
||||||
command_router.register(Box::new(GetCurrentSessionCommandHandler::new(
|
|
||||||
store.clone(),
|
|
||||||
).with_session_manager(session_manager.clone())));
|
|
||||||
|
|
||||||
// 注册 load_session 处理器
|
|
||||||
command_router.register(Box::new(LoadSessionCommandHandler::new(store.clone())));
|
|
||||||
|
|
||||||
// 注册 save_session 处理器
|
|
||||||
let skills = session_manager.skills();
|
let skills = session_manager.skills();
|
||||||
let prompt_repository = session_manager.store().clone();
|
let prompt_repository = session_manager.store().clone();
|
||||||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||||||
Box::new(AgentPromptProvider::new(
|
Box::new(AgentPromptProvider::new(
|
||||||
0, // save_session 不需要 reinject 逻辑
|
0, // 不需要 reinject 逻辑
|
||||||
provider_config.clone(),
|
provider_config.clone(),
|
||||||
prompt_repository,
|
prompt_repository,
|
||||||
)),
|
)),
|
||||||
Box::new(SkillPromptProvider::new(skills)),
|
Box::new(SkillPromptProvider::new(skills)),
|
||||||
]));
|
]));
|
||||||
|
|
||||||
|
// 注册 get_current 处理器
|
||||||
|
command_router.register(Box::new(
|
||||||
|
GetCurrentSessionCommandHandler::new(store.clone())
|
||||||
|
.with_session_manager(session_manager.clone())
|
||||||
|
.with_system_prompt_provider(system_prompt_provider.clone())
|
||||||
|
));
|
||||||
|
|
||||||
|
// 注册 load_session 处理器
|
||||||
|
command_router.register(Box::new(LoadSessionCommandHandler::new(store.clone())));
|
||||||
|
|
||||||
|
// 注册 save_session 处理器
|
||||||
command_router.register(Box::new(SaveSessionCommandHandler::new(
|
command_router.register(Box::new(SaveSessionCommandHandler::new(
|
||||||
store.clone(),
|
store.clone(),
|
||||||
system_prompt_provider.clone(),
|
system_prompt_provider.clone(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user