compressor不再使用独立的LLMProvider实例。减少开销
This commit is contained in:
parent
393d980742
commit
81dcc67932
@ -218,7 +218,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
|||||||
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
/// AgentLoop - Stateless agent that processes messages with tool calling support.
|
||||||
/// History is managed externally by SessionManager.
|
/// History is managed externally by SessionManager.
|
||||||
pub struct AgentLoop {
|
pub struct AgentLoop {
|
||||||
provider: Box<dyn LLMProvider>,
|
provider: Arc<dyn LLMProvider>,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
observer: Option<Arc<dyn Observer>>,
|
observer: Option<Arc<dyn Observer>>,
|
||||||
max_iterations: usize,
|
max_iterations: usize,
|
||||||
@ -231,32 +231,58 @@ pub struct AgentProcessResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
|
/// Create a new AgentLoop with a provider created from config.
|
||||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let max_iterations = provider_config.max_tool_iterations;
|
||||||
let provider = create_provider(provider_config)
|
let provider = create_provider(provider_config)
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
provider,
|
provider: Arc::from(provider),
|
||||||
tools: Arc::new(ToolRegistry::new()),
|
tools: Arc::new(ToolRegistry::new()),
|
||||||
observer: None,
|
observer: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a new AgentLoop with provider created from config and given tools.
|
||||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let max_iterations = provider_config.max_tool_iterations;
|
||||||
let provider = create_provider(provider_config)
|
let provider = create_provider(provider_config)
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
provider,
|
provider: Arc::from(provider),
|
||||||
tools,
|
tools,
|
||||||
observer: None,
|
observer: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a new AgentLoop with an existing shared provider.
|
||||||
|
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
provider,
|
||||||
|
tools: Arc::new(ToolRegistry::new()),
|
||||||
|
observer: None,
|
||||||
|
max_iterations,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new AgentLoop with an existing shared provider and given tools.
|
||||||
|
pub fn with_provider_and_tools(
|
||||||
|
provider: Arc<dyn LLMProvider>,
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
max_iterations: usize,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
provider,
|
||||||
|
tools,
|
||||||
|
observer: None,
|
||||||
|
max_iterations,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Set an observer for tracking events.
|
/// Set an observer for tracking events.
|
||||||
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
||||||
self.observer = Some(observer);
|
self.observer = Some(observer);
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::providers::{ChatCompletionRequest, LLMProvider, Message};
|
||||||
use crate::providers::{create_provider, ChatCompletionRequest, Message};
|
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
|
|
||||||
@ -46,24 +47,32 @@ pub struct ContextCompressor {
|
|||||||
context_window: usize,
|
context_window: usize,
|
||||||
/// Threshold ratio to trigger compression (50% of context window)
|
/// Threshold ratio to trigger compression (50% of context window)
|
||||||
threshold_ratio: f64,
|
threshold_ratio: f64,
|
||||||
|
/// Shared LLM provider for summarization
|
||||||
|
provider: Arc<dyn LLMProvider>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ContextCompressor {
|
impl ContextCompressor {
|
||||||
/// Create a new compressor with the given context window size.
|
/// Create a new compressor with the given provider and context window size.
|
||||||
pub fn new(context_window: usize) -> Self {
|
pub fn new(provider: Arc<dyn LLMProvider>, context_window: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
config: ContextCompressionConfig::default(),
|
config: ContextCompressionConfig::default(),
|
||||||
context_window,
|
context_window,
|
||||||
threshold_ratio: 0.5,
|
threshold_ratio: 0.5,
|
||||||
|
provider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create with custom configuration.
|
/// Create with custom configuration.
|
||||||
pub fn with_config(context_window: usize, config: ContextCompressionConfig) -> Self {
|
pub fn with_config(
|
||||||
|
provider: Arc<dyn LLMProvider>,
|
||||||
|
context_window: usize,
|
||||||
|
config: ContextCompressionConfig,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
config,
|
config,
|
||||||
context_window,
|
context_window,
|
||||||
threshold_ratio: 0.5,
|
threshold_ratio: 0.5,
|
||||||
|
provider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +106,6 @@ impl ContextCompressor {
|
|||||||
pub async fn compress_if_needed(
|
pub async fn compress_if_needed(
|
||||||
&self,
|
&self,
|
||||||
history: Vec<ChatMessage>,
|
history: Vec<ChatMessage>,
|
||||||
provider_config: &LLMProviderConfig,
|
|
||||||
) -> Result<Vec<ChatMessage>, AgentError> {
|
) -> Result<Vec<ChatMessage>, AgentError> {
|
||||||
// Check if compression is needed
|
// Check if compression is needed
|
||||||
let tokens = estimate_tokens(&history);
|
let tokens = estimate_tokens(&history);
|
||||||
@ -149,7 +157,7 @@ impl ContextCompressor {
|
|||||||
"Compression pass"
|
"Compression pass"
|
||||||
);
|
);
|
||||||
|
|
||||||
match self.compress_once(&current_history, provider_config).await {
|
match self.compress_once(&current_history).await {
|
||||||
Ok(Some(compressed)) => {
|
Ok(Some(compressed)) => {
|
||||||
current_history = compressed;
|
current_history = compressed;
|
||||||
}
|
}
|
||||||
@ -178,7 +186,6 @@ impl ContextCompressor {
|
|||||||
async fn compress_once(
|
async fn compress_once(
|
||||||
&self,
|
&self,
|
||||||
history: &[ChatMessage],
|
history: &[ChatMessage],
|
||||||
provider_config: &LLMProviderConfig,
|
|
||||||
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
|
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
|
||||||
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
|
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
@ -214,7 +221,7 @@ impl ContextCompressor {
|
|||||||
|
|
||||||
if between_start < between_end {
|
if between_start < between_end {
|
||||||
let between = &history[between_start..between_end];
|
let between = &history[between_start..between_end];
|
||||||
let summary = self.summarize_segment(between, provider_config).await?;
|
let summary = self.summarize_segment(between).await?;
|
||||||
|
|
||||||
// Add summary as a special user message
|
// Add summary as a special user message
|
||||||
new_messages.push(ChatMessage::user(format!(
|
new_messages.push(ChatMessage::user(format!(
|
||||||
@ -245,7 +252,6 @@ impl ContextCompressor {
|
|||||||
async fn summarize_segment(
|
async fn summarize_segment(
|
||||||
&self,
|
&self,
|
||||||
messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
provider_config: &LLMProviderConfig,
|
|
||||||
) -> Result<String, AgentError> {
|
) -> Result<String, AgentError> {
|
||||||
if messages.is_empty() {
|
if messages.is_empty() {
|
||||||
return Ok(String::new());
|
return Ok(String::new());
|
||||||
@ -304,10 +310,6 @@ Be concise, aim for {} characters or less.
|
|||||||
self.config.summary_max_chars, transcript
|
self.config.summary_max_chars, transcript
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create provider and call LLM
|
|
||||||
let provider = create_provider(provider_config.clone())
|
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
|
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
|
||||||
temperature: Some(0.3),
|
temperature: Some(0.3),
|
||||||
@ -315,7 +317,7 @@ Be concise, aim for {} characters or less.
|
|||||||
tools: None,
|
tools: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
match provider.chat(request).await {
|
match (*self.provider).chat(request).await {
|
||||||
Ok(response) => Ok(response.content),
|
Ok(response) => Ok(response.content),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Fallback: just truncate the transcript
|
// Fallback: just truncate the transcript
|
||||||
@ -329,6 +331,37 @@ Be concise, aim for {} characters or less.
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::providers::ChatCompletionResponse;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
/// Mock provider for testing - panics if actually used for LLM calls
|
||||||
|
struct MockProvider;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LLMProvider for MockProvider {
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
_request: ChatCompletionRequest,
|
||||||
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
panic!("MockProvider.chat() called - not expected in test")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptype(&self) -> &str {
|
||||||
|
"mock"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"mock"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model_id(&self) -> &str {
|
||||||
|
"mock"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mock_provider() -> Arc<dyn LLMProvider> {
|
||||||
|
Arc::new(MockProvider)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_estimate_tokens() {
|
fn test_estimate_tokens() {
|
||||||
@ -352,7 +385,7 @@ mod tests {
|
|||||||
tool_result_trim_chars: 50,
|
tool_result_trim_chars: 50,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let compressor = ContextCompressor::with_config(100_000, config);
|
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config);
|
||||||
|
|
||||||
let mut messages = vec![
|
let mut messages = vec![
|
||||||
ChatMessage::user("Hello"),
|
ChatMessage::user("Hello"),
|
||||||
@ -366,7 +399,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_threshold() {
|
fn test_threshold() {
|
||||||
let compressor = ContextCompressor::new(128_000);
|
let compressor = ContextCompressor::new(mock_provider(), 128_000);
|
||||||
assert_eq!(compressor.threshold(), 64_000);
|
assert_eq!(compressor.threshold(), 64_000);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,6 +7,7 @@ use crate::bus::ChatMessage;
|
|||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
|
use crate::providers::{create_provider, LLMProvider};
|
||||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||||
use crate::tools::{
|
use crate::tools::{
|
||||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||||
@ -22,6 +23,7 @@ pub struct Session {
|
|||||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
|
provider: Arc<dyn LLMProvider>,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
compressor: ContextCompressor,
|
compressor: ContextCompressor,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
@ -35,14 +37,19 @@ impl Session {
|
|||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
|
let provider_box = create_provider(provider_config.clone())
|
||||||
|
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||||
|
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
channel_name,
|
channel_name,
|
||||||
chat_histories: HashMap::new(),
|
chat_histories: HashMap::new(),
|
||||||
user_tx,
|
user_tx,
|
||||||
provider_config: provider_config.clone(),
|
provider_config: provider_config.clone(),
|
||||||
|
provider: provider.clone(),
|
||||||
tools,
|
tools,
|
||||||
compressor: ContextCompressor::new(provider_config.token_limit),
|
compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit),
|
||||||
store,
|
store,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -179,7 +186,11 @@ impl Session {
|
|||||||
|
|
||||||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||||||
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
||||||
AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())
|
Ok(AgentLoop::with_provider_and_tools(
|
||||||
|
self.provider.clone(),
|
||||||
|
self.tools.clone(),
|
||||||
|
self.provider_config.max_tool_iterations,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -429,7 +440,7 @@ impl SessionManager {
|
|||||||
|
|
||||||
// 压缩历史(如果需要)
|
// 压缩历史(如果需要)
|
||||||
let history = session_guard.compressor
|
let history = session_guard.compressor
|
||||||
.compress_if_needed(history, &session_guard.provider_config)
|
.compress_if_needed(history)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// 创建 agent 并处理
|
// 创建 agent 并处理
|
||||||
|
|||||||
@ -170,7 +170,7 @@ async fn handle_inbound(
|
|||||||
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
|
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
|
||||||
let history = match session_guard
|
let history = match session_guard
|
||||||
.compressor()
|
.compressor()
|
||||||
.compress_if_needed(raw_history, session_guard.provider_config())
|
.compress_if_needed(raw_history)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(history) => history,
|
Ok(history) => history,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user