compressor不再使用独立的LLMProvider实例。减少开销

This commit is contained in:
xiaoski 2026-04-24 10:04:21 +08:00
parent 393d980742
commit 81dcc67932
4 changed files with 94 additions and 24 deletions

View File

@ -218,7 +218,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop {
provider: Box<dyn LLMProvider>,
provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
observer: Option<Arc<dyn Observer>>,
max_iterations: usize,
@ -231,32 +231,58 @@ pub struct AgentProcessResult {
}
impl AgentLoop {
/// Create a new AgentLoop with a provider created from config.
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider,
provider: Arc::from(provider),
tools: Arc::new(ToolRegistry::new()),
observer: None,
max_iterations,
})
}
/// Create a new AgentLoop with provider created from config and given tools.
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider,
provider: Arc::from(provider),
tools,
observer: None,
max_iterations,
})
}
/// Create a new AgentLoop with an existing shared provider.
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize) -> Self {
Self {
provider,
tools: Arc::new(ToolRegistry::new()),
observer: None,
max_iterations,
}
}
/// Create a new AgentLoop with an existing shared provider and given tools.
pub fn with_provider_and_tools(
provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
max_iterations: usize,
) -> Self {
Self {
provider,
tools,
observer: None,
max_iterations,
}
}
/// Set an observer for tracking events.
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
self.observer = Some(observer);

View File

@ -1,6 +1,7 @@
use std::sync::Arc;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, ChatCompletionRequest, Message};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message};
use crate::agent::AgentError;
@ -46,24 +47,32 @@ pub struct ContextCompressor {
context_window: usize,
/// Threshold ratio to trigger compression (50% of context window)
threshold_ratio: f64,
/// Shared LLM provider for summarization
provider: Arc<dyn LLMProvider>,
}
impl ContextCompressor {
/// Create a new compressor with the given context window size.
pub fn new(context_window: usize) -> Self {
/// Create a new compressor with the given provider and context window size.
pub fn new(provider: Arc<dyn LLMProvider>, context_window: usize) -> Self {
Self {
config: ContextCompressionConfig::default(),
context_window,
threshold_ratio: 0.5,
provider,
}
}
/// Create with custom configuration.
pub fn with_config(context_window: usize, config: ContextCompressionConfig) -> Self {
pub fn with_config(
provider: Arc<dyn LLMProvider>,
context_window: usize,
config: ContextCompressionConfig,
) -> Self {
Self {
config,
context_window,
threshold_ratio: 0.5,
provider,
}
}
@ -97,7 +106,6 @@ impl ContextCompressor {
pub async fn compress_if_needed(
&self,
history: Vec<ChatMessage>,
provider_config: &LLMProviderConfig,
) -> Result<Vec<ChatMessage>, AgentError> {
// Check if compression is needed
let tokens = estimate_tokens(&history);
@ -149,7 +157,7 @@ impl ContextCompressor {
"Compression pass"
);
match self.compress_once(&current_history, provider_config).await {
match self.compress_once(&current_history).await {
Ok(Some(compressed)) => {
current_history = compressed;
}
@ -178,7 +186,6 @@ impl ContextCompressor {
async fn compress_once(
&self,
history: &[ChatMessage],
provider_config: &LLMProviderConfig,
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
return Ok(None);
@ -214,7 +221,7 @@ impl ContextCompressor {
if between_start < between_end {
let between = &history[between_start..between_end];
let summary = self.summarize_segment(between, provider_config).await?;
let summary = self.summarize_segment(between).await?;
// Add summary as a special user message
new_messages.push(ChatMessage::user(format!(
@ -245,7 +252,6 @@ impl ContextCompressor {
async fn summarize_segment(
&self,
messages: &[ChatMessage],
provider_config: &LLMProviderConfig,
) -> Result<String, AgentError> {
if messages.is_empty() {
return Ok(String::new());
@ -304,10 +310,6 @@ Be concise, aim for {} characters or less.
self.config.summary_max_chars, transcript
);
// Create provider and call LLM
let provider = create_provider(provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
let request = ChatCompletionRequest {
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
temperature: Some(0.3),
@ -315,7 +317,7 @@ Be concise, aim for {} characters or less.
tools: None,
};
match provider.chat(request).await {
match (*self.provider).chat(request).await {
Ok(response) => Ok(response.content),
Err(e) => {
// Fallback: just truncate the transcript
@ -329,6 +331,37 @@ Be concise, aim for {} characters or less.
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::ChatCompletionResponse;
use async_trait::async_trait;
/// Mock provider for testing - panics if actually used for LLM calls
struct MockProvider;
#[async_trait]
impl LLMProvider for MockProvider {
async fn chat(
&self,
_request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
panic!("MockProvider.chat() called - not expected in test")
}
fn ptype(&self) -> &str {
"mock"
}
fn name(&self) -> &str {
"mock"
}
fn model_id(&self) -> &str {
"mock"
}
}
fn mock_provider() -> Arc<dyn LLMProvider> {
Arc::new(MockProvider)
}
#[test]
fn test_estimate_tokens() {
@ -352,7 +385,7 @@ mod tests {
tool_result_trim_chars: 50,
..Default::default()
};
let compressor = ContextCompressor::with_config(100_000, config);
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config);
let mut messages = vec![
ChatMessage::user("Hello"),
@ -366,7 +399,7 @@ mod tests {
#[test]
fn test_threshold() {
let compressor = ContextCompressor::new(128_000);
let compressor = ContextCompressor::new(mock_provider(), 128_000);
assert_eq!(compressor.threshold(), 64_000);
}
}

View File

@ -7,6 +7,7 @@ use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::protocol::WsOutbound;
use crate::providers::{create_provider, LLMProvider};
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
@ -22,6 +23,7 @@ pub struct Session {
chat_histories: HashMap<String, Vec<ChatMessage>>,
pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig,
provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
compressor: ContextCompressor,
store: Arc<SessionStore>,
@ -35,14 +37,19 @@ impl Session {
tools: Arc<ToolRegistry>,
store: Arc<SessionStore>,
) -> Result<Self, AgentError> {
let provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
Ok(Self {
id: Uuid::new_v4(),
channel_name,
chat_histories: HashMap::new(),
user_tx,
provider_config: provider_config.clone(),
provider: provider.clone(),
tools,
compressor: ContextCompressor::new(provider_config.token_limit),
compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit),
store,
})
}
@ -179,7 +186,11 @@ impl Session {
/// 创建一个临时的 AgentLoop 实例来处理消息
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())
Ok(AgentLoop::with_provider_and_tools(
self.provider.clone(),
self.tools.clone(),
self.provider_config.max_tool_iterations,
))
}
}
@ -429,7 +440,7 @@ impl SessionManager {
// 压缩历史(如果需要)
let history = session_guard.compressor
.compress_if_needed(history, &session_guard.provider_config)
.compress_if_needed(history)
.await?;
// 创建 agent 并处理

View File

@ -170,7 +170,7 @@ async fn handle_inbound(
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
let history = match session_guard
.compressor()
.compress_if_needed(raw_history, session_guard.provider_config())
.compress_if_needed(raw_history)
.await
{
Ok(history) => history,