From 81dcc67932a72cb8509ff569d4dccf38b0546cfe Mon Sep 17 00:00:00 2001 From: xiaoski Date: Fri, 24 Apr 2026 10:04:21 +0800 Subject: [PATCH] =?UTF-8?q?compressor=E4=B8=8D=E5=86=8D=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E7=8B=AC=E7=AB=8B=E7=9A=84LLMProvider=E5=AE=9E=E4=BE=8B?= =?UTF-8?q?=E3=80=82=E5=87=8F=E5=B0=91=E5=BC=80=E9=94=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/agent_loop.rs | 32 ++++++++++++++-- src/agent/context_compressor.rs | 67 ++++++++++++++++++++++++--------- src/gateway/session.rs | 17 +++++++-- src/gateway/ws.rs | 2 +- 4 files changed, 94 insertions(+), 24 deletions(-) diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 9db10f9..8e0c9a8 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -218,7 +218,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message { /// AgentLoop - Stateless agent that processes messages with tool calling support. /// History is managed externally by SessionManager. pub struct AgentLoop { - provider: Box, + provider: Arc, tools: Arc, observer: Option>, max_iterations: usize, @@ -231,32 +231,58 @@ pub struct AgentProcessResult { } impl AgentLoop { + /// Create a new AgentLoop with a provider created from config. pub fn new(provider_config: LLMProviderConfig) -> Result { let max_iterations = provider_config.max_tool_iterations; let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { - provider, + provider: Arc::from(provider), tools: Arc::new(ToolRegistry::new()), observer: None, max_iterations, }) } + /// Create a new AgentLoop with provider created from config and given tools. 
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { let max_iterations = provider_config.max_tool_iterations; let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; Ok(Self { - provider, + provider: Arc::from(provider), tools, observer: None, max_iterations, }) } + /// Create a new AgentLoop with an existing shared provider. + pub fn with_provider(provider: Arc, max_iterations: usize) -> Self { + Self { + provider, + tools: Arc::new(ToolRegistry::new()), + observer: None, + max_iterations, + } + } + + /// Create a new AgentLoop with an existing shared provider and given tools. + pub fn with_provider_and_tools( + provider: Arc, + tools: Arc, + max_iterations: usize, + ) -> Self { + Self { + provider, + tools, + observer: None, + max_iterations, + } + } + /// Set an observer for tracking events. pub fn with_observer(mut self, observer: Arc) -> Self { self.observer = Some(observer); diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index d53224c..e2161fb 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -1,6 +1,7 @@ +use std::sync::Arc; + use crate::bus::ChatMessage; -use crate::config::LLMProviderConfig; -use crate::providers::{create_provider, ChatCompletionRequest, Message}; +use crate::providers::{ChatCompletionRequest, LLMProvider, Message}; use crate::agent::AgentError; @@ -46,24 +47,32 @@ pub struct ContextCompressor { context_window: usize, /// Threshold ratio to trigger compression (50% of context window) threshold_ratio: f64, + /// Shared LLM provider for summarization + provider: Arc, } impl ContextCompressor { - /// Create a new compressor with the given context window size. - pub fn new(context_window: usize) -> Self { + /// Create a new compressor with the given provider and context window size. 
+ pub fn new(provider: Arc, context_window: usize) -> Self { Self { config: ContextCompressionConfig::default(), context_window, threshold_ratio: 0.5, + provider, } } /// Create with custom configuration. - pub fn with_config(context_window: usize, config: ContextCompressionConfig) -> Self { + pub fn with_config( + provider: Arc, + context_window: usize, + config: ContextCompressionConfig, + ) -> Self { Self { config, context_window, threshold_ratio: 0.5, + provider, } } @@ -97,7 +106,6 @@ impl ContextCompressor { pub async fn compress_if_needed( &self, history: Vec, - provider_config: &LLMProviderConfig, ) -> Result, AgentError> { // Check if compression is needed let tokens = estimate_tokens(&history); @@ -149,7 +157,7 @@ impl ContextCompressor { "Compression pass" ); - match self.compress_once(&current_history, provider_config).await { + match self.compress_once(&current_history).await { Ok(Some(compressed)) => { current_history = compressed; } @@ -178,7 +186,6 @@ impl ContextCompressor { async fn compress_once( &self, history: &[ChatMessage], - provider_config: &LLMProviderConfig, ) -> Result>, AgentError> { if history.len() <= self.config.protect_first_n + self.config.protect_last_n { return Ok(None); @@ -214,7 +221,7 @@ impl ContextCompressor { if between_start < between_end { let between = &history[between_start..between_end]; - let summary = self.summarize_segment(between, provider_config).await?; + let summary = self.summarize_segment(between).await?; // Add summary as a special user message new_messages.push(ChatMessage::user(format!( @@ -245,7 +252,6 @@ impl ContextCompressor { async fn summarize_segment( &self, messages: &[ChatMessage], - provider_config: &LLMProviderConfig, ) -> Result { if messages.is_empty() { return Ok(String::new()); @@ -304,10 +310,6 @@ Be concise, aim for {} characters or less. 
self.config.summary_max_chars, transcript ); - // Create provider and call LLM - let provider = create_provider(provider_config.clone()) - .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; - let request = ChatCompletionRequest { messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], temperature: Some(0.3), @@ -315,7 +317,7 @@ Be concise, aim for {} characters or less. tools: None, }; - match provider.chat(request).await { + match (*self.provider).chat(request).await { Ok(response) => Ok(response.content), Err(e) => { // Fallback: just truncate the transcript @@ -329,6 +331,37 @@ Be concise, aim for {} characters or less. #[cfg(test)] mod tests { use super::*; + use crate::providers::ChatCompletionResponse; + use async_trait::async_trait; + + /// Mock provider for testing - panics if actually used for LLM calls + struct MockProvider; + + #[async_trait] + impl LLMProvider for MockProvider { + async fn chat( + &self, + _request: ChatCompletionRequest, + ) -> Result> { + panic!("MockProvider.chat() called - not expected in test") + } + + fn ptype(&self) -> &str { + "mock" + } + + fn name(&self) -> &str { + "mock" + } + + fn model_id(&self) -> &str { + "mock" + } + } + + fn mock_provider() -> Arc { + Arc::new(MockProvider) + } #[test] fn test_estimate_tokens() { @@ -352,7 +385,7 @@ mod tests { tool_result_trim_chars: 50, ..Default::default() }; - let compressor = ContextCompressor::with_config(100_000, config); + let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config); let mut messages = vec![ ChatMessage::user("Hello"), @@ -366,7 +399,7 @@ mod tests { #[test] fn test_threshold() { - let compressor = ContextCompressor::new(128_000); + let compressor = ContextCompressor::new(mock_provider(), 128_000); assert_eq!(compressor.threshold(), 64_000); } } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index b5bf8f3..d76b81c 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -7,6 
+7,7 @@ use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::protocol::WsOutbound; +use crate::providers::{create_provider, LLMProvider}; use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, @@ -22,6 +23,7 @@ pub struct Session { chat_histories: HashMap>, pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, + provider: Arc, tools: Arc, compressor: ContextCompressor, store: Arc, @@ -35,14 +37,19 @@ impl Session { tools: Arc, store: Arc, ) -> Result { + let provider_box = create_provider(provider_config.clone()) + .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; + let provider: Arc = Arc::from(provider_box); + Ok(Self { id: Uuid::new_v4(), channel_name, chat_histories: HashMap::new(), user_tx, provider_config: provider_config.clone(), + provider: provider.clone(), tools, - compressor: ContextCompressor::new(provider_config.token_limit), + compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit), store, }) } @@ -179,7 +186,11 @@ impl Session { /// 创建一个临时的 AgentLoop 实例来处理消息 pub fn create_agent(&self) -> Result { - AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone()) + Ok(AgentLoop::with_provider_and_tools( + self.provider.clone(), + self.tools.clone(), + self.provider_config.max_tool_iterations, + )) } } @@ -429,7 +440,7 @@ impl SessionManager { // 压缩历史(如果需要) let history = session_guard.compressor - .compress_if_needed(history, &session_guard.provider_config) + .compress_if_needed(history) .await?; // 创建 agent 并处理 diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 7c545ff..5d46fa1 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -170,7 +170,7 @@ async fn handle_inbound( let raw_history = session_guard.get_or_create_history(&chat_id).clone(); let history = match session_guard 
.compressor() - .compress_if_needed(raw_history, session_guard.provider_config()) + .compress_if_needed(raw_history) .await { Ok(history) => history,