compressor不再使用独立的LLMProvider实例。减少开销
This commit is contained in:
parent
393d980742
commit
81dcc67932
@ -218,7 +218,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
||||
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
||||
/// History is managed externally by SessionManager.
|
||||
pub struct AgentLoop {
|
||||
provider: Box<dyn LLMProvider>,
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
observer: Option<Arc<dyn Observer>>,
|
||||
max_iterations: usize,
|
||||
@ -231,32 +231,58 @@ pub struct AgentProcessResult {
|
||||
}
|
||||
|
||||
impl AgentLoop {
|
||||
/// Create a new AgentLoop with a provider created from config.
|
||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||
let max_iterations = provider_config.max_tool_iterations;
|
||||
let provider = create_provider(provider_config)
|
||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
provider: Arc::from(provider),
|
||||
tools: Arc::new(ToolRegistry::new()),
|
||||
observer: None,
|
||||
max_iterations,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new AgentLoop with provider created from config and given tools.
|
||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
||||
let max_iterations = provider_config.max_tool_iterations;
|
||||
let provider = create_provider(provider_config)
|
||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
provider: Arc::from(provider),
|
||||
tools,
|
||||
observer: None,
|
||||
max_iterations,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new AgentLoop with an existing shared provider.
|
||||
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
tools: Arc::new(ToolRegistry::new()),
|
||||
observer: None,
|
||||
max_iterations,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new AgentLoop with an existing shared provider and given tools.
|
||||
pub fn with_provider_and_tools(
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
max_iterations: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
tools,
|
||||
observer: None,
|
||||
max_iterations,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set an observer for tracking events.
|
||||
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
||||
self.observer = Some(observer);
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{create_provider, ChatCompletionRequest, Message};
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
|
||||
@ -46,24 +47,32 @@ pub struct ContextCompressor {
|
||||
context_window: usize,
|
||||
/// Threshold ratio to trigger compression (50% of context window)
|
||||
threshold_ratio: f64,
|
||||
/// Shared LLM provider for summarization
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
}
|
||||
|
||||
impl ContextCompressor {
|
||||
/// Create a new compressor with the given context window size.
|
||||
pub fn new(context_window: usize) -> Self {
|
||||
/// Create a new compressor with the given provider and context window size.
|
||||
pub fn new(provider: Arc<dyn LLMProvider>, context_window: usize) -> Self {
|
||||
Self {
|
||||
config: ContextCompressionConfig::default(),
|
||||
context_window,
|
||||
threshold_ratio: 0.5,
|
||||
provider,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom configuration.
|
||||
pub fn with_config(context_window: usize, config: ContextCompressionConfig) -> Self {
|
||||
pub fn with_config(
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
context_window: usize,
|
||||
config: ContextCompressionConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
context_window,
|
||||
threshold_ratio: 0.5,
|
||||
provider,
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,7 +106,6 @@ impl ContextCompressor {
|
||||
pub async fn compress_if_needed(
|
||||
&self,
|
||||
history: Vec<ChatMessage>,
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<Vec<ChatMessage>, AgentError> {
|
||||
// Check if compression is needed
|
||||
let tokens = estimate_tokens(&history);
|
||||
@ -149,7 +157,7 @@ impl ContextCompressor {
|
||||
"Compression pass"
|
||||
);
|
||||
|
||||
match self.compress_once(¤t_history, provider_config).await {
|
||||
match self.compress_once(¤t_history).await {
|
||||
Ok(Some(compressed)) => {
|
||||
current_history = compressed;
|
||||
}
|
||||
@ -178,7 +186,6 @@ impl ContextCompressor {
|
||||
async fn compress_once(
|
||||
&self,
|
||||
history: &[ChatMessage],
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
|
||||
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
|
||||
return Ok(None);
|
||||
@ -214,7 +221,7 @@ impl ContextCompressor {
|
||||
|
||||
if between_start < between_end {
|
||||
let between = &history[between_start..between_end];
|
||||
let summary = self.summarize_segment(between, provider_config).await?;
|
||||
let summary = self.summarize_segment(between).await?;
|
||||
|
||||
// Add summary as a special user message
|
||||
new_messages.push(ChatMessage::user(format!(
|
||||
@ -245,7 +252,6 @@ impl ContextCompressor {
|
||||
async fn summarize_segment(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<String, AgentError> {
|
||||
if messages.is_empty() {
|
||||
return Ok(String::new());
|
||||
@ -304,10 +310,6 @@ Be concise, aim for {} characters or less.
|
||||
self.config.summary_max_chars, transcript
|
||||
);
|
||||
|
||||
// Create provider and call LLM
|
||||
let provider = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
|
||||
temperature: Some(0.3),
|
||||
@ -315,7 +317,7 @@ Be concise, aim for {} characters or less.
|
||||
tools: None,
|
||||
};
|
||||
|
||||
match provider.chat(request).await {
|
||||
match (*self.provider).chat(request).await {
|
||||
Ok(response) => Ok(response.content),
|
||||
Err(e) => {
|
||||
// Fallback: just truncate the transcript
|
||||
@ -329,6 +331,37 @@ Be concise, aim for {} characters or less.
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::providers::ChatCompletionResponse;
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Mock provider for testing - panics if actually used for LLM calls
|
||||
struct MockProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for MockProvider {
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
panic!("MockProvider.chat() called - not expected in test")
|
||||
}
|
||||
|
||||
fn ptype(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
fn model_id(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_provider() -> Arc<dyn LLMProvider> {
|
||||
Arc::new(MockProvider)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens() {
|
||||
@ -352,7 +385,7 @@ mod tests {
|
||||
tool_result_trim_chars: 50,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(100_000, config);
|
||||
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config);
|
||||
|
||||
let mut messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
@ -366,7 +399,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_threshold() {
|
||||
let compressor = ContextCompressor::new(128_000);
|
||||
let compressor = ContextCompressor::new(mock_provider(), 128_000);
|
||||
assert_eq!(compressor.threshold(), 64_000);
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::providers::{create_provider, LLMProvider};
|
||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||
@ -22,6 +23,7 @@ pub struct Session {
|
||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
provider_config: LLMProviderConfig,
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
compressor: ContextCompressor,
|
||||
store: Arc<SessionStore>,
|
||||
@ -35,14 +37,19 @@ impl Session {
|
||||
tools: Arc<ToolRegistry>,
|
||||
store: Arc<SessionStore>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let provider_box = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
channel_name,
|
||||
chat_histories: HashMap::new(),
|
||||
user_tx,
|
||||
provider_config: provider_config.clone(),
|
||||
provider: provider.clone(),
|
||||
tools,
|
||||
compressor: ContextCompressor::new(provider_config.token_limit),
|
||||
compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit),
|
||||
store,
|
||||
})
|
||||
}
|
||||
@ -179,7 +186,11 @@ impl Session {
|
||||
|
||||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||||
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
||||
AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())
|
||||
Ok(AgentLoop::with_provider_and_tools(
|
||||
self.provider.clone(),
|
||||
self.tools.clone(),
|
||||
self.provider_config.max_tool_iterations,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@ -429,7 +440,7 @@ impl SessionManager {
|
||||
|
||||
// 压缩历史(如果需要)
|
||||
let history = session_guard.compressor
|
||||
.compress_if_needed(history, &session_guard.provider_config)
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
|
||||
// 创建 agent 并处理
|
||||
|
||||
@ -170,7 +170,7 @@ async fn handle_inbound(
|
||||
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
let history = match session_guard
|
||||
.compressor()
|
||||
.compress_if_needed(raw_history, session_guard.provider_config())
|
||||
.compress_if_needed(raw_history)
|
||||
.await
|
||||
{
|
||||
Ok(history) => history,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user