compressor不再使用独立的LLMProvider实例。减少开销

This commit is contained in:
xiaoski 2026-04-24 10:04:21 +08:00
parent 393d980742
commit 81dcc67932
4 changed files with 94 additions and 24 deletions

View File

@ -218,7 +218,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
/// AgentLoop - Stateless agent that processes messages with tool calling support. /// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager. /// History is managed externally by SessionManager.
pub struct AgentLoop { pub struct AgentLoop {
provider: Box<dyn LLMProvider>, provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
observer: Option<Arc<dyn Observer>>, observer: Option<Arc<dyn Observer>>,
max_iterations: usize, max_iterations: usize,
@ -231,32 +231,58 @@ pub struct AgentProcessResult {
} }
impl AgentLoop { impl AgentLoop {
/// Create a new AgentLoop with a provider created from config.
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> { pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config) let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?; .map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self { Ok(Self {
provider, provider: Arc::from(provider),
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
observer: None, observer: None,
max_iterations, max_iterations,
}) })
} }
/// Create a new AgentLoop with provider created from config and given tools.
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> { pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config) let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?; .map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self { Ok(Self {
provider, provider: Arc::from(provider),
tools, tools,
observer: None, observer: None,
max_iterations, max_iterations,
}) })
} }
/// Build an `AgentLoop` around an already-constructed, shared provider.
///
/// The given `provider` handle is stored as-is, so no new provider instance
/// is created here. The loop gets a new registry from `ToolRegistry::new()`,
/// has no observer, and keeps `max_iterations` exactly as passed in.
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize) -> Self {
    // A registry private to this loop; callers that already hold one should
    // use the constructor that accepts tools instead.
    let tools = Arc::new(ToolRegistry::new());
    Self {
        provider,
        tools,
        observer: None,
        max_iterations,
    }
}
/// Assemble an `AgentLoop` from a shared provider and a caller-supplied tool registry.
///
/// Both `Arc` handles are stored as passed in, `max_iterations` is kept
/// unchanged, and the observer starts out as `None`.
pub fn with_provider_and_tools(
    provider: Arc<dyn LLMProvider>,
    tools: Arc<ToolRegistry>,
    max_iterations: usize,
) -> Self {
    // No observer is attached at construction time.
    let observer = None;
    Self {
        provider,
        tools,
        observer,
        max_iterations,
    }
}
/// Set an observer for tracking events. /// Set an observer for tracking events.
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self { pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
self.observer = Some(observer); self.observer = Some(observer);

View File

@ -1,6 +1,7 @@
use std::sync::Arc;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig; use crate::providers::{ChatCompletionRequest, LLMProvider, Message};
use crate::providers::{create_provider, ChatCompletionRequest, Message};
use crate::agent::AgentError; use crate::agent::AgentError;
@ -46,24 +47,32 @@ pub struct ContextCompressor {
context_window: usize, context_window: usize,
/// Threshold ratio to trigger compression (50% of context window) /// Threshold ratio to trigger compression (50% of context window)
threshold_ratio: f64, threshold_ratio: f64,
/// Shared LLM provider for summarization
provider: Arc<dyn LLMProvider>,
} }
impl ContextCompressor { impl ContextCompressor {
/// Create a new compressor with the given context window size. /// Create a new compressor with the given provider and context window size.
pub fn new(context_window: usize) -> Self { pub fn new(provider: Arc<dyn LLMProvider>, context_window: usize) -> Self {
Self { Self {
config: ContextCompressionConfig::default(), config: ContextCompressionConfig::default(),
context_window, context_window,
threshold_ratio: 0.5, threshold_ratio: 0.5,
provider,
} }
} }
/// Create with custom configuration. /// Create with custom configuration.
pub fn with_config(context_window: usize, config: ContextCompressionConfig) -> Self { pub fn with_config(
provider: Arc<dyn LLMProvider>,
context_window: usize,
config: ContextCompressionConfig,
) -> Self {
Self { Self {
config, config,
context_window, context_window,
threshold_ratio: 0.5, threshold_ratio: 0.5,
provider,
} }
} }
@ -97,7 +106,6 @@ impl ContextCompressor {
pub async fn compress_if_needed( pub async fn compress_if_needed(
&self, &self,
history: Vec<ChatMessage>, history: Vec<ChatMessage>,
provider_config: &LLMProviderConfig,
) -> Result<Vec<ChatMessage>, AgentError> { ) -> Result<Vec<ChatMessage>, AgentError> {
// Check if compression is needed // Check if compression is needed
let tokens = estimate_tokens(&history); let tokens = estimate_tokens(&history);
@ -149,7 +157,7 @@ impl ContextCompressor {
"Compression pass" "Compression pass"
); );
match self.compress_once(&current_history, provider_config).await { match self.compress_once(&current_history).await {
Ok(Some(compressed)) => { Ok(Some(compressed)) => {
current_history = compressed; current_history = compressed;
} }
@ -178,7 +186,6 @@ impl ContextCompressor {
async fn compress_once( async fn compress_once(
&self, &self,
history: &[ChatMessage], history: &[ChatMessage],
provider_config: &LLMProviderConfig,
) -> Result<Option<Vec<ChatMessage>>, AgentError> { ) -> Result<Option<Vec<ChatMessage>>, AgentError> {
if history.len() <= self.config.protect_first_n + self.config.protect_last_n { if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
return Ok(None); return Ok(None);
@ -214,7 +221,7 @@ impl ContextCompressor {
if between_start < between_end { if between_start < between_end {
let between = &history[between_start..between_end]; let between = &history[between_start..between_end];
let summary = self.summarize_segment(between, provider_config).await?; let summary = self.summarize_segment(between).await?;
// Add summary as a special user message // Add summary as a special user message
new_messages.push(ChatMessage::user(format!( new_messages.push(ChatMessage::user(format!(
@ -245,7 +252,6 @@ impl ContextCompressor {
async fn summarize_segment( async fn summarize_segment(
&self, &self,
messages: &[ChatMessage], messages: &[ChatMessage],
provider_config: &LLMProviderConfig,
) -> Result<String, AgentError> { ) -> Result<String, AgentError> {
if messages.is_empty() { if messages.is_empty() {
return Ok(String::new()); return Ok(String::new());
@ -304,10 +310,6 @@ Be concise, aim for {} characters or less.
self.config.summary_max_chars, transcript self.config.summary_max_chars, transcript
); );
// Create provider and call LLM
let provider = create_provider(provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
temperature: Some(0.3), temperature: Some(0.3),
@ -315,7 +317,7 @@ Be concise, aim for {} characters or less.
tools: None, tools: None,
}; };
match provider.chat(request).await { match (*self.provider).chat(request).await {
Ok(response) => Ok(response.content), Ok(response) => Ok(response.content),
Err(e) => { Err(e) => {
// Fallback: just truncate the transcript // Fallback: just truncate the transcript
@ -329,6 +331,37 @@ Be concise, aim for {} characters or less.
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::providers::ChatCompletionResponse;
use async_trait::async_trait;
/// Label returned by every identifying accessor of `MockProvider`.
const MOCK_LABEL: &str = "mock";

/// Stand-in `LLMProvider` for tests that only need a provider handle.
///
/// Calling `chat` on it panics, so any test that unexpectedly reaches the
/// LLM call fails instead of continuing.
struct MockProvider;

#[async_trait]
impl LLMProvider for MockProvider {
    /// Always panics: tests using this mock are not expected to call the LLM.
    async fn chat(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
        panic!("MockProvider.chat() called - not expected in test")
    }

    /// Provider type identifier; fixed to `"mock"`.
    fn ptype(&self) -> &str {
        MOCK_LABEL
    }

    /// Provider name; fixed to `"mock"`.
    fn name(&self) -> &str {
        MOCK_LABEL
    }

    /// Model identifier; fixed to `"mock"`.
    fn model_id(&self) -> &str {
        MOCK_LABEL
    }
}
/// Return a `MockProvider` wrapped in an `Arc` and typed as the
/// `LLMProvider` trait object, ready to hand to a `ContextCompressor`.
fn mock_provider() -> Arc<dyn LLMProvider> {
    Arc::new(MockProvider) as Arc<dyn LLMProvider>
}
#[test] #[test]
fn test_estimate_tokens() { fn test_estimate_tokens() {
@ -352,7 +385,7 @@ mod tests {
tool_result_trim_chars: 50, tool_result_trim_chars: 50,
..Default::default() ..Default::default()
}; };
let compressor = ContextCompressor::with_config(100_000, config); let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config);
let mut messages = vec![ let mut messages = vec![
ChatMessage::user("Hello"), ChatMessage::user("Hello"),
@ -366,7 +399,7 @@ mod tests {
#[test] #[test]
fn test_threshold() { fn test_threshold() {
let compressor = ContextCompressor::new(128_000); let compressor = ContextCompressor::new(mock_provider(), 128_000);
assert_eq!(compressor.threshold(), 64_000); assert_eq!(compressor.threshold(), 64_000);
} }
} }

View File

@ -7,6 +7,7 @@ use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use crate::providers::{create_provider, LLMProvider};
use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::tools::{ use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
@ -22,6 +23,7 @@ pub struct Session {
chat_histories: HashMap<String, Vec<ChatMessage>>, chat_histories: HashMap<String, Vec<ChatMessage>>,
pub user_tx: mpsc::Sender<WsOutbound>, pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
compressor: ContextCompressor, compressor: ContextCompressor,
store: Arc<SessionStore>, store: Arc<SessionStore>,
@ -35,14 +37,19 @@ impl Session {
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
store: Arc<SessionStore>, store: Arc<SessionStore>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
Ok(Self { Ok(Self {
id: Uuid::new_v4(), id: Uuid::new_v4(),
channel_name, channel_name,
chat_histories: HashMap::new(), chat_histories: HashMap::new(),
user_tx, user_tx,
provider_config: provider_config.clone(), provider_config: provider_config.clone(),
provider: provider.clone(),
tools, tools,
compressor: ContextCompressor::new(provider_config.token_limit), compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit),
store, store,
}) })
} }
@ -179,7 +186,11 @@ impl Session {
/// 创建一个临时的 AgentLoop 实例来处理消息 /// 创建一个临时的 AgentLoop 实例来处理消息
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> { pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone()) Ok(AgentLoop::with_provider_and_tools(
self.provider.clone(),
self.tools.clone(),
self.provider_config.max_tool_iterations,
))
} }
} }
@ -429,7 +440,7 @@ impl SessionManager {
// 压缩历史(如果需要) // 压缩历史(如果需要)
let history = session_guard.compressor let history = session_guard.compressor
.compress_if_needed(history, &session_guard.provider_config) .compress_if_needed(history)
.await?; .await?;
// 创建 agent 并处理 // 创建 agent 并处理

View File

@ -170,7 +170,7 @@ async fn handle_inbound(
let raw_history = session_guard.get_or_create_history(&chat_id).clone(); let raw_history = session_guard.get_or_create_history(&chat_id).clone();
let history = match session_guard let history = match session_guard
.compressor() .compressor()
.compress_if_needed(raw_history, session_guard.provider_config()) .compress_if_needed(raw_history)
.await .await
{ {
Ok(history) => history, Ok(history) => history,