feat: Enhance ChatMessage with system context and background compaction
- Added `system_context` field to `ChatMessage` for better message context handling. - Introduced constants for system context prompts in `message.rs`. - Updated `Session` to manage background history compaction, including methods to start and finish compaction. - Implemented logic to schedule background compaction after message processing in `SessionManager`. - Enhanced database schema to support new `system_context` field in messages. - Added functionality to compact active history, preserving system messages and summaries. - Updated tests to validate new compaction logic and ensure message integrity. - Removed unused functions and cleaned up code in various modules for better maintainability. Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
3792472b83
commit
3045a6b596
@ -1,4 +1,9 @@
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::bus::{
|
||||
ChatMessage,
|
||||
SYSTEM_CONTEXT_AGENT_PROMPT,
|
||||
SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
||||
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
||||
};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{create_provider, ChatCompletionRequest, Message};
|
||||
use crate::text::{char_count, take_prefix_chars};
|
||||
@ -17,26 +22,32 @@ pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
||||
/// Configuration for context compression.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextCompressionConfig {
|
||||
/// Protect first N messages (system prompt, etc.)
|
||||
pub protect_first_n: usize,
|
||||
/// Protect last N messages (recent context)
|
||||
pub protect_last_n: usize,
|
||||
/// Maximum compression passes
|
||||
pub max_passes: u32,
|
||||
/// Preserve the latest N complete user turns in full.
|
||||
pub retain_last_user_turns: usize,
|
||||
/// Maximum characters in summary
|
||||
pub summary_max_chars: usize,
|
||||
/// Characters to keep when trimming tool results
|
||||
pub tool_result_trim_chars: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct UserTurnRange {
|
||||
start: usize,
|
||||
end_exclusive: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HistoryCompactionPlan {
|
||||
pub preserved_system_messages: Vec<ChatMessage>,
|
||||
pub summary_message: ChatMessage,
|
||||
pub preserved_messages: Vec<ChatMessage>,
|
||||
pub compressed_turns: usize,
|
||||
pub preserved_turns: usize,
|
||||
}
|
||||
|
||||
impl Default for ContextCompressionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
protect_first_n: 1,
|
||||
protect_last_n: 4,
|
||||
max_passes: 3,
|
||||
retain_last_user_turns: 3,
|
||||
summary_max_chars: 20_000,
|
||||
tool_result_trim_chars: 2_000,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -65,7 +76,6 @@ impl ContextCompressor {
|
||||
provider_config.token_limit,
|
||||
ContextCompressionConfig {
|
||||
summary_max_chars: provider_config.context_summary_max_chars,
|
||||
tool_result_trim_chars: provider_config.context_tool_result_trim_chars,
|
||||
..ContextCompressionConfig::default()
|
||||
},
|
||||
)
|
||||
@ -85,26 +95,88 @@ impl ContextCompressor {
|
||||
(self.context_window as f64 * self.threshold_ratio) as usize
|
||||
}
|
||||
|
||||
/// Fast-path: trim oversized tool results without LLM call.
|
||||
/// Returns the number of messages modified.
|
||||
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize {
|
||||
let limit = self.config.tool_result_trim_chars;
|
||||
let mut modified = 0;
|
||||
|
||||
for msg in messages.iter_mut() {
|
||||
let content_chars = char_count(&msg.content);
|
||||
if msg.role == "tool" && content_chars > limit {
|
||||
let removed = content_chars - limit;
|
||||
msg.content = format!(
|
||||
"{}...\n\n[Output truncated - {} characters removed]",
|
||||
take_prefix_chars(&msg.content, limit),
|
||||
removed
|
||||
);
|
||||
modified += 1;
|
||||
}
|
||||
pub fn should_compress(&self, history: &[ChatMessage]) -> bool {
|
||||
estimate_tokens(history) > self.threshold()
|
||||
}
|
||||
|
||||
modified
|
||||
fn user_turn_ranges(&self, history: &[ChatMessage]) -> Vec<UserTurnRange> {
|
||||
let user_indices: Vec<usize> = history
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, message)| message.role == "user")
|
||||
.map(|(index, _)| index)
|
||||
.collect();
|
||||
|
||||
user_indices
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(turn_index, start)| UserTurnRange {
|
||||
start: *start,
|
||||
end_exclusive: user_indices
|
||||
.get(turn_index + 1)
|
||||
.copied()
|
||||
.unwrap_or(history.len()),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn should_preserve_system_message(&self, message: &ChatMessage) -> bool {
|
||||
message.role == "system"
|
||||
&& (message.has_system_context(SYSTEM_CONTEXT_AGENT_PROMPT)
|
||||
|| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT))
|
||||
}
|
||||
|
||||
fn split_prefix_messages(&self, history: &[ChatMessage]) -> (Vec<ChatMessage>, Vec<ChatMessage>) {
|
||||
let preserved_system_messages = history
|
||||
.iter()
|
||||
.filter(|message| self.should_preserve_system_message(message))
|
||||
.cloned()
|
||||
.collect();
|
||||
let summary_source = history
|
||||
.iter()
|
||||
.filter(|message| !self.should_preserve_system_message(message))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
(preserved_system_messages, summary_source)
|
||||
}
|
||||
|
||||
pub async fn build_compaction_plan(
|
||||
&self,
|
||||
history: &[ChatMessage],
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<Option<HistoryCompactionPlan>, AgentError> {
|
||||
if !self.should_compress(history) {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let turn_ranges = self.user_turn_ranges(history);
|
||||
if turn_ranges.len() <= self.config.retain_last_user_turns {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let preserved_turn_start = turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start;
|
||||
if preserved_turn_start == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let (preserved_system_messages, summary_source) =
|
||||
self.split_prefix_messages(&history[..preserved_turn_start]);
|
||||
|
||||
let summary = self
|
||||
.summarize_segment(&summary_source, provider_config)
|
||||
.await?;
|
||||
|
||||
Ok(Some(HistoryCompactionPlan {
|
||||
preserved_system_messages,
|
||||
summary_message: ChatMessage::system_with_context(format!(
|
||||
"[Compressed History]\n\n{}",
|
||||
summary
|
||||
), Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string())),
|
||||
preserved_messages: history[preserved_turn_start..].to_vec(),
|
||||
compressed_turns: turn_ranges.len() - self.config.retain_last_user_turns,
|
||||
preserved_turns: self.config.retain_last_user_turns,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Main entry point - compresses history if over threshold.
|
||||
@ -113,8 +185,6 @@ impl ContextCompressor {
|
||||
history: Vec<ChatMessage>,
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<Vec<ChatMessage>, AgentError> {
|
||||
let mut history = history;
|
||||
// Check if compression is needed
|
||||
let tokens = estimate_tokens(&history);
|
||||
if tokens <= self.threshold() {
|
||||
#[cfg(debug_assertions)]
|
||||
@ -134,50 +204,18 @@ impl ContextCompressor {
|
||||
"Starting context compression"
|
||||
);
|
||||
|
||||
// Fast trim pass first
|
||||
let trimmed = self.fast_trim_tool_results(&mut history);
|
||||
if trimmed > 0 {
|
||||
let tokens_after = estimate_tokens(&history);
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
trimmed_messages = trimmed,
|
||||
tokens_after = tokens_after,
|
||||
"Fast trim completed"
|
||||
let current_history = match self.build_compaction_plan(&history, provider_config).await? {
|
||||
Some(plan) => {
|
||||
let mut compressed = Vec::with_capacity(
|
||||
plan.preserved_system_messages.len() + plan.preserved_messages.len() + 1,
|
||||
);
|
||||
if tokens_after <= self.threshold() {
|
||||
return Ok(history);
|
||||
}
|
||||
}
|
||||
|
||||
// LLM summarization pass
|
||||
let mut current_history = history;
|
||||
for pass in 0..self.config.max_passes {
|
||||
let tokens = estimate_tokens(¤t_history);
|
||||
if tokens <= self.threshold() {
|
||||
break;
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
pass = pass + 1,
|
||||
tokens = tokens,
|
||||
"Compression pass"
|
||||
);
|
||||
|
||||
match self.compress_once(¤t_history, provider_config).await {
|
||||
Ok(Some(compressed)) => {
|
||||
current_history = compressed;
|
||||
}
|
||||
Ok(None) => {
|
||||
// No more compressible content
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Compression pass failed, using current history");
|
||||
break;
|
||||
}
|
||||
}
|
||||
compressed.extend(plan.preserved_system_messages);
|
||||
compressed.push(plan.summary_message);
|
||||
compressed.extend(plan.preserved_messages);
|
||||
compressed
|
||||
}
|
||||
None => history,
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
final_tokens = estimate_tokens(¤t_history),
|
||||
@ -188,74 +226,6 @@ impl ContextCompressor {
|
||||
Ok(current_history)
|
||||
}
|
||||
|
||||
/// Single compression pass - summarize middle messages between user turns.
|
||||
/// Returns Some(compressed) if compression happened, None if nothing to compress.
|
||||
async fn compress_once(
|
||||
&self,
|
||||
history: &[ChatMessage],
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
|
||||
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Find user message indices (excluding protected first messages)
|
||||
let user_indices: Vec<usize> = history
|
||||
.iter()
|
||||
.enumerate()
|
||||
.skip(self.config.protect_first_n)
|
||||
.filter(|(_, m)| m.role == "user")
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
// Need at least one user message and content between users to compress
|
||||
if user_indices.len() < 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Build segments: user -> (assistant turns) -> next user
|
||||
// We'll summarize the assistant turns between consecutive user messages
|
||||
let mut new_messages = history[..=user_indices[0]].to_vec();
|
||||
|
||||
for i in 0..user_indices.len() - 1 {
|
||||
let user_idx = user_indices[i];
|
||||
let next_user_idx = user_indices[i + 1];
|
||||
|
||||
new_messages.push(history[user_idx].clone());
|
||||
|
||||
// Check if there's assistant content between these two user messages
|
||||
let between_start = user_idx + 1;
|
||||
let between_end = next_user_idx;
|
||||
|
||||
if between_start < between_end {
|
||||
let between = &history[between_start..between_end];
|
||||
let summary = self.summarize_segment(between, provider_config).await?;
|
||||
|
||||
// Add summary as a special user message
|
||||
new_messages.push(ChatMessage::user(format!(
|
||||
"[Context Summary]\n\n{}",
|
||||
summary
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Add last user and everything after (protected)
|
||||
let last_user_idx = user_indices[user_indices.len() - 1];
|
||||
if last_user_idx < history.len() - 1 {
|
||||
// Add everything from last user onwards (protected)
|
||||
for i in last_user_idx..history.len() {
|
||||
new_messages.push(history[i].clone());
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing changed, return None
|
||||
if new_messages.len() == history.len() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(new_messages))
|
||||
}
|
||||
|
||||
/// Summarize a segment of messages using LLM.
|
||||
async fn summarize_segment(
|
||||
&self,
|
||||
@ -299,6 +269,7 @@ impl ContextCompressor {
|
||||
r#"You are a conversation compaction engine. Summarize the following conversation segment.
|
||||
|
||||
PRESERVE:
|
||||
- Each user question or request in full or as a near-verbatim restatement
|
||||
- All identifiers (UUIDs, hashes, file paths, URLs)
|
||||
- Actions taken (tool calls, file operations, commands)
|
||||
- Key information obtained (results, data, errors)
|
||||
@ -306,9 +277,11 @@ PRESERVE:
|
||||
- Current task status
|
||||
|
||||
OMIT:
|
||||
- Verbose tool output (keep key results only)
|
||||
- Reproducing full tool output verbatim unless it is essential
|
||||
- Repeated greetings or filler
|
||||
|
||||
Do not assume tool content was pre-trimmed. You may receive long tool outputs; keep the important results, errors, and artifacts.
|
||||
|
||||
Be concise, aim for {} characters or less.
|
||||
|
||||
---
|
||||
@ -362,43 +335,48 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fast_trim() {
|
||||
let config = ContextCompressionConfig {
|
||||
tool_result_trim_chars: 50,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(100_000, config);
|
||||
fn test_should_compress() {
|
||||
let compressor = ContextCompressor::new(20);
|
||||
let messages = vec![ChatMessage::user(&"x".repeat(200))];
|
||||
assert!(compressor.should_compress(&messages));
|
||||
}
|
||||
|
||||
let mut messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::tool("call1", "bash", &"x".repeat(200)),
|
||||
#[test]
|
||||
fn test_user_turn_ranges_follow_user_boundaries() {
|
||||
let compressor = ContextCompressor::new(100_000);
|
||||
let history = vec![
|
||||
ChatMessage::system("system"),
|
||||
ChatMessage::user("u1"),
|
||||
ChatMessage::assistant("a1"),
|
||||
ChatMessage::tool("call-1", "bash", "t1"),
|
||||
ChatMessage::user("u2"),
|
||||
ChatMessage::assistant("a2"),
|
||||
ChatMessage::user("u3"),
|
||||
];
|
||||
|
||||
let modified = compressor.fast_trim_tool_results(&mut messages);
|
||||
assert_eq!(modified, 1);
|
||||
assert!(messages[1].content.len() < 100);
|
||||
let turns = compressor.user_turn_ranges(&history);
|
||||
assert_eq!(turns, vec![
|
||||
UserTurnRange { start: 1, end_exclusive: 4 },
|
||||
UserTurnRange { start: 4, end_exclusive: 6 },
|
||||
UserTurnRange { start: 6, end_exclusive: 7 },
|
||||
]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fast_trim_handles_utf8_char_boundaries() {
|
||||
let config = ContextCompressionConfig {
|
||||
tool_result_trim_chars: 5,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(100_000, config);
|
||||
fn test_split_prefix_messages_preserves_key_system_messages() {
|
||||
let compressor = ContextCompressor::new(50);
|
||||
let prefix = vec![
|
||||
ChatMessage::system_with_context("agent prompt", Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string())),
|
||||
ChatMessage::user("u1"),
|
||||
ChatMessage::assistant("a1"),
|
||||
ChatMessage::system_with_context("scheduled prompt", Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string())),
|
||||
];
|
||||
|
||||
let mut messages = vec![ChatMessage::tool("call1", "bash", &"雅".repeat(20))];
|
||||
|
||||
let modified = compressor.fast_trim_tool_results(&mut messages);
|
||||
assert_eq!(modified, 1);
|
||||
assert!(messages[0].content.contains("Output truncated"));
|
||||
assert!(messages[0].content.is_char_boundary(messages[0].content.len()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_tool_result_trim_chars() {
|
||||
let config = ContextCompressionConfig::default();
|
||||
assert_eq!(config.tool_result_trim_chars, 2_000);
|
||||
let (preserved_system_messages, summary_source) = compressor.split_prefix_messages(&prefix);
|
||||
assert_eq!(preserved_system_messages.len(), 2);
|
||||
assert_eq!(summary_source.len(), 2);
|
||||
assert!(preserved_system_messages[0].has_system_context(SYSTEM_CONTEXT_AGENT_PROMPT));
|
||||
assert!(preserved_system_messages[1].has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -3,6 +3,10 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::providers::ToolCall;
|
||||
|
||||
pub const SYSTEM_CONTEXT_AGENT_PROMPT: &str = "agent_prompt";
|
||||
pub const SYSTEM_CONTEXT_SCHEDULED_PROMPT: &str = "scheduled_system_prompt";
|
||||
pub const SYSTEM_CONTEXT_HISTORY_COMPACTION: &str = "history_compaction";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolMessageState {
|
||||
@ -75,6 +79,8 @@ pub struct ChatMessage {
|
||||
pub media_refs: Vec<String>, // Paths to media files for context
|
||||
pub timestamp: i64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_context: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
@ -94,6 +100,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
system_context: None,
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -109,6 +116,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs,
|
||||
timestamp: current_timestamp(),
|
||||
system_context: None,
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -124,6 +132,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
system_context: None,
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -148,6 +157,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
system_context: None,
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -167,12 +177,20 @@ impl ChatMessage {
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self::system_with_context(content, None::<String>)
|
||||
}
|
||||
|
||||
pub fn system_with_context(
|
||||
content: impl Into<String>,
|
||||
system_context: impl Into<Option<String>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "system".to_string(),
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
system_context: system_context.into(),
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -197,6 +215,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
system_context: None,
|
||||
reasoning_content: None,
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
tool_name: Some(tool_name.into()),
|
||||
@ -205,6 +224,10 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_system_context(&self, expected: &str) -> bool {
|
||||
self.system_context.as_deref() == Some(expected)
|
||||
}
|
||||
|
||||
pub fn is_assistant_tool_call_message(&self) -> bool {
|
||||
self.role == "assistant"
|
||||
&& self
|
||||
|
||||
@ -2,7 +2,16 @@ pub mod dispatcher;
|
||||
pub mod message;
|
||||
|
||||
pub use dispatcher::OutboundDispatcher;
|
||||
pub use message::{ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage};
|
||||
pub use message::{
|
||||
ChatMessage,
|
||||
ContentBlock,
|
||||
InboundMessage,
|
||||
MediaItem,
|
||||
OutboundMessage,
|
||||
SYSTEM_CONTEXT_AGENT_PROMPT,
|
||||
SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
||||
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
||||
};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
@ -7,7 +7,13 @@ use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||||
use crate::bus::{
|
||||
ChatMessage,
|
||||
MessageBus,
|
||||
OutboundMessage,
|
||||
SYSTEM_CONTEXT_AGENT_PROMPT,
|
||||
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
||||
};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
|
||||
use crate::providers::{create_provider, ChatCompletionRequest, Message};
|
||||
@ -399,6 +405,7 @@ pub struct Session {
|
||||
pub channel_name: String,
|
||||
/// 按 chat_id 路由到不同会话历史,支持多用户多会话
|
||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||
compression_in_flight: HashSet<String>,
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
@ -478,6 +485,7 @@ impl Session {
|
||||
id: Uuid::new_v4(),
|
||||
channel_name,
|
||||
chat_histories: HashMap::new(),
|
||||
compression_in_flight: HashSet::new(),
|
||||
user_tx,
|
||||
provider_config: provider_config.clone(),
|
||||
tools,
|
||||
@ -532,7 +540,13 @@ impl Session {
|
||||
> session_record.agent_prompt_reinjection_count
|
||||
{
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
|
||||
self.append_persisted_message(
|
||||
chat_id,
|
||||
ChatMessage::system_with_context(
|
||||
agent_prompt,
|
||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||
),
|
||||
)?;
|
||||
self.store
|
||||
.mark_agent_prompt_reinjected(&session_id)
|
||||
.map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?;
|
||||
@ -562,6 +576,7 @@ impl Session {
|
||||
|
||||
pub fn remove_history(&mut self, chat_id: &str) {
|
||||
self.chat_histories.remove(chat_id);
|
||||
self.compression_in_flight.remove(chat_id);
|
||||
}
|
||||
|
||||
pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
@ -640,6 +655,7 @@ impl Session {
|
||||
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
|
||||
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
||||
self.chat_histories.clear();
|
||||
self.compression_in_flight.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||||
|
||||
@ -666,6 +682,23 @@ impl Session {
|
||||
&self.compressor
|
||||
}
|
||||
|
||||
fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
||||
self.compression_in_flight.insert(chat_id.to_string())
|
||||
}
|
||||
|
||||
fn finish_background_compaction(&mut self, chat_id: &str) {
|
||||
self.compression_in_flight.remove(chat_id);
|
||||
}
|
||||
|
||||
fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
let history = self
|
||||
.store
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?;
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if self.skills.is_empty() {
|
||||
return Ok(());
|
||||
@ -729,7 +762,13 @@ impl Session {
|
||||
}
|
||||
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
|
||||
self.append_persisted_message(
|
||||
chat_id,
|
||||
ChatMessage::system_with_context(
|
||||
agent_prompt,
|
||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||
),
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -830,6 +869,94 @@ pub(crate) fn handle_in_chat_command(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub(crate) async fn schedule_background_history_compaction(
|
||||
session: Arc<Mutex<Session>>,
|
||||
chat_id: impl Into<String>,
|
||||
) -> Result<(), AgentError> {
|
||||
let chat_id = chat_id.into();
|
||||
|
||||
let snapshot = {
|
||||
let mut session_guard = session.lock().await;
|
||||
let session_record = session_guard.ensure_persistent_session(&chat_id)?;
|
||||
session_guard.ensure_chat_loaded(&chat_id)?;
|
||||
|
||||
let history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
let compressor = session_guard.compressor().clone();
|
||||
if !compressor.should_compress(&history) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !session_guard.try_start_background_compaction(&chat_id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
(
|
||||
session_guard.store.clone(),
|
||||
session_guard.persistent_session_id(&chat_id),
|
||||
session_record.reset_cutoff_seq,
|
||||
session_record.message_count,
|
||||
history,
|
||||
compressor,
|
||||
session_guard.provider_config().clone(),
|
||||
)
|
||||
};
|
||||
|
||||
let (store, session_id, expected_reset_cutoff_seq, snapshot_end_seq, history, compressor, provider_config) = snapshot;
|
||||
let session_for_task = session.clone();
|
||||
let chat_id_for_task = chat_id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction");
|
||||
|
||||
let compaction_result = compressor.build_compaction_plan(&history, &provider_config).await;
|
||||
let mut committed = false;
|
||||
|
||||
match compaction_result {
|
||||
Ok(Some(plan)) => match store.compact_active_history(
|
||||
&session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
&plan.preserved_system_messages,
|
||||
&plan.summary_message,
|
||||
&plan.preserved_messages,
|
||||
) {
|
||||
Ok(true) => {
|
||||
committed = true;
|
||||
tracing::info!(
|
||||
chat_id = %chat_id_for_task,
|
||||
snapshot_end_seq,
|
||||
compressed_turns = plan.compressed_turns,
|
||||
preserved_turns = plan.preserved_turns,
|
||||
"Background history compaction committed"
|
||||
);
|
||||
}
|
||||
Ok(false) => {
|
||||
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot");
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed");
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis");
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed");
|
||||
}
|
||||
}
|
||||
|
||||
let mut session_guard = session_for_task.lock().await;
|
||||
if committed {
|
||||
if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction");
|
||||
}
|
||||
}
|
||||
session_guard.finish_background_compaction(&chat_id_for_task);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
impl SessionManager {
|
||||
pub fn new(
|
||||
session_ttl_hours: u64,
|
||||
@ -1224,7 +1351,7 @@ impl SessionManager {
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
|
||||
// 处理消息
|
||||
let (history, compressor, provider_config, agent, user_message_id) = {
|
||||
let (history, agent, user_message_id) = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
session_guard.ensure_persistent_session(chat_id)?;
|
||||
@ -1254,9 +1381,6 @@ impl SessionManager {
|
||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||
|
||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||
let compressor = session_guard.compressor().clone();
|
||||
let provider_config = session_guard.provider_config().clone();
|
||||
|
||||
session_guard.record_skill_offer(chat_id)?;
|
||||
|
||||
// 创建 agent 并处理
|
||||
@ -1265,14 +1389,12 @@ impl SessionManager {
|
||||
agent = agent.with_emitted_message_handler(handler);
|
||||
}
|
||||
|
||||
(history, compressor, provider_config, agent, user_message_id)
|
||||
(history, agent, user_message_id)
|
||||
};
|
||||
|
||||
let history = compressor
|
||||
.compress_if_needed(history, &provider_config)
|
||||
.await?;
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
let mut should_schedule_compaction = false;
|
||||
let response = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
@ -1287,6 +1409,7 @@ impl SessionManager {
|
||||
} else {
|
||||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||
should_schedule_compaction = true;
|
||||
|
||||
result
|
||||
.emitted_messages
|
||||
@ -1308,6 +1431,12 @@ impl SessionManager {
|
||||
}
|
||||
};
|
||||
|
||||
if should_schedule_compaction {
|
||||
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.to_string()).await {
|
||||
tracing::warn!(channel = %channel_name, chat_id = %chat_id, error = %error, "Failed to schedule background history compaction");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
@ -1340,7 +1469,7 @@ impl SessionManager {
|
||||
.unwrap_or_else(|| "scheduler".to_string());
|
||||
let provider_config = self.provider_config_for_agent(options.agent.as_deref())?;
|
||||
|
||||
let (history, compressor, agent, user_message_id) = {
|
||||
let (history, agent, user_message_id) = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
session_guard.ensure_persistent_session(chat_id)?;
|
||||
@ -1353,7 +1482,13 @@ impl SessionManager {
|
||||
session_guard.ensure_agent_prompt_before_user_message(chat_id)?;
|
||||
|
||||
if let Some(system_prompt) = options.system_prompt.as_deref() {
|
||||
session_guard.append_persisted_message(chat_id, ChatMessage::system(system_prompt))?;
|
||||
session_guard.append_persisted_message(
|
||||
chat_id,
|
||||
ChatMessage::system_with_context(
|
||||
system_prompt,
|
||||
Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
|
||||
),
|
||||
)?;
|
||||
}
|
||||
|
||||
let user_message = session_guard.create_user_message(prompt, Vec::new());
|
||||
@ -1361,7 +1496,6 @@ impl SessionManager {
|
||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||
|
||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||
let compressor = session_guard.compressor().clone();
|
||||
|
||||
session_guard.record_skill_offer(chat_id)?;
|
||||
|
||||
@ -1372,14 +1506,12 @@ impl SessionManager {
|
||||
provider_config.clone(),
|
||||
)?;
|
||||
|
||||
(history, compressor, agent, user_message_id)
|
||||
(history, agent, user_message_id)
|
||||
};
|
||||
|
||||
let history = compressor
|
||||
.compress_if_needed(history, &provider_config)
|
||||
.await?;
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
let mut should_schedule_compaction = false;
|
||||
let response = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
@ -1393,6 +1525,7 @@ impl SessionManager {
|
||||
Vec::new()
|
||||
} else {
|
||||
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||
should_schedule_compaction = true;
|
||||
|
||||
result
|
||||
.emitted_messages
|
||||
@ -1411,6 +1544,12 @@ impl SessionManager {
|
||||
}
|
||||
};
|
||||
|
||||
if should_schedule_compaction {
|
||||
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.to_string()).await {
|
||||
tracing::warn!(channel = %channel_name, chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for scheduled task");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ use crate::agent::EmittedMessageHandler;
|
||||
use crate::bus::message::{format_tool_call_content, ToolMessageState};
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
||||
use super::{GatewayState, session::{Session, handle_in_chat_command}};
|
||||
use super::{GatewayState, session::{Session, handle_in_chat_command, schedule_background_history_compaction}};
|
||||
|
||||
struct WsToolCallEmitter {
|
||||
sender: mpsc::Sender<WsOutbound>,
|
||||
@ -246,6 +246,7 @@ async fn handle_inbound(
|
||||
WsInbound::UserInput { content, chat_id, sender_id, .. } => {
|
||||
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
||||
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
||||
let (history, agent, user_tx) = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
session_guard.ensure_persistent_session(&chat_id)?;
|
||||
@ -268,19 +269,7 @@ async fn handle_inbound(
|
||||
let user_message_id = user_message.id.clone();
|
||||
session_guard.append_persisted_message(&chat_id, user_message)?;
|
||||
|
||||
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
let history = match session_guard
|
||||
.compressor()
|
||||
.compress_if_needed(raw_history, session_guard.provider_config())
|
||||
.await
|
||||
{
|
||||
Ok(history) => history,
|
||||
Err(error) => {
|
||||
tracing::warn!(chat_id = %chat_id, error = %error, "Compression failed, using original history");
|
||||
session_guard.get_or_create_history(&chat_id).clone()
|
||||
}
|
||||
};
|
||||
|
||||
let history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
session_guard.record_skill_offer(&chat_id)?;
|
||||
|
||||
let live_emitter = Arc::new(WsToolCallEmitter {
|
||||
@ -290,8 +279,13 @@ async fn handle_inbound(
|
||||
let agent = session_guard
|
||||
.create_agent(&chat_id, Some(&sender_id), Some(&user_message_id))?
|
||||
.with_emitted_message_handler(live_emitter);
|
||||
|
||||
(history, agent, session_guard.user_tx.clone())
|
||||
};
|
||||
|
||||
match agent.process(history).await {
|
||||
Ok(result) => {
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||
for outbound in result
|
||||
.emitted_messages
|
||||
@ -304,10 +298,16 @@ async fn handle_inbound(
|
||||
{
|
||||
let _ = session_guard.send(outbound).await;
|
||||
}
|
||||
|
||||
drop(session_guard);
|
||||
|
||||
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.clone()).await {
|
||||
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
|
||||
let _ = session_guard
|
||||
let _ = user_tx
|
||||
.send(WsOutbound::Error {
|
||||
code: "LLM_ERROR".to_string(),
|
||||
message: error.to_string(),
|
||||
|
||||
@ -136,6 +136,7 @@ struct AnthropicResponse {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum AnthropicContent {
|
||||
Text { text: String },
|
||||
#[allow(dead_code)]
|
||||
Thinking { thinking: String },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
|
||||
@ -150,6 +150,7 @@ struct OpenAIMessage {
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
#[serde(default)]
|
||||
@ -161,6 +162,7 @@ struct OpenAIToolCall {
|
||||
id: String,
|
||||
#[serde(rename = "function")]
|
||||
function: OAIFunction,
|
||||
#[allow(dead_code)]
|
||||
#[serde(default)]
|
||||
index: Option<u32>,
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
#[cfg(not(test))]
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
@ -190,11 +191,18 @@ pub struct SchedulerJobUpsert {
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
#[cfg(test)]
|
||||
pub fn new() -> Result<Self, StorageError> {
|
||||
Self::from_connection(Connection::open_in_memory()?)
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub fn new() -> Result<Self, StorageError> {
|
||||
let db_path = default_session_db_path()?;
|
||||
Self::open_at_path(&db_path)
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
fn open_at_path(path: &Path) -> Result<Self, StorageError> {
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
@ -205,6 +213,7 @@ impl SessionStore {
|
||||
}
|
||||
|
||||
fn from_connection(conn: Connection) -> Result<Self, StorageError> {
|
||||
conn.busy_timeout(std::time::Duration::from_secs(5))?;
|
||||
conn.execute_batch(
|
||||
"
|
||||
PRAGMA journal_mode = WAL;
|
||||
@ -238,6 +247,7 @@ impl SessionStore {
|
||||
seq INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
system_context TEXT,
|
||||
reasoning_content TEXT,
|
||||
media_refs_json TEXT NOT NULL,
|
||||
tool_call_id TEXT,
|
||||
@ -555,8 +565,8 @@ impl SessionStore {
|
||||
"
|
||||
INSERT INTO messages (
|
||||
id, session_id, seq, role, content,
|
||||
reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
|
||||
system_context, reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)
|
||||
",
|
||||
params![
|
||||
message.id,
|
||||
@ -564,6 +574,7 @@ impl SessionStore {
|
||||
seq,
|
||||
message.role,
|
||||
message.content,
|
||||
message.system_context,
|
||||
message.reasoning_content,
|
||||
media_refs_json,
|
||||
message.tool_call_id,
|
||||
@ -592,6 +603,85 @@ impl SessionStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn compact_active_history(
|
||||
&self,
|
||||
session_id: &str,
|
||||
expected_reset_cutoff_seq: i64,
|
||||
snapshot_end_seq: i64,
|
||||
preserved_system_messages: &[ChatMessage],
|
||||
summary_message: &ChatMessage,
|
||||
preserved_messages: &[ChatMessage],
|
||||
) -> Result<bool, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let tx = conn.unchecked_transaction()?;
|
||||
|
||||
let current_cutoff = active_reset_cutoff(&tx, session_id)?;
|
||||
if current_cutoff != expected_reset_cutoff_seq {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let current_max_seq: i64 = tx.query_row(
|
||||
"SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
|row| row.get(0),
|
||||
)?;
|
||||
|
||||
if snapshot_end_seq <= current_cutoff || snapshot_end_seq > current_max_seq {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let delta_messages = load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;
|
||||
let mut next_seq = current_max_seq + 1;
|
||||
let now = current_timestamp();
|
||||
let mut inserted_count = 0_i64;
|
||||
let mut active_user_turn_count = 0_i64;
|
||||
|
||||
for message in preserved_system_messages {
|
||||
let copied = clone_message_for_compaction(message, message.timestamp);
|
||||
insert_message_with_seq(&tx, session_id, next_seq, &copied)?;
|
||||
next_seq += 1;
|
||||
inserted_count += 1;
|
||||
}
|
||||
|
||||
let summary_copy = clone_message_for_compaction(summary_message, now);
|
||||
insert_message_with_seq(&tx, session_id, next_seq, &summary_copy)?;
|
||||
next_seq += 1;
|
||||
inserted_count += 1;
|
||||
|
||||
for message in preserved_messages.iter().chain(delta_messages.iter()) {
|
||||
let copied = clone_message_for_compaction(message, message.timestamp);
|
||||
if copied.role == "user" {
|
||||
active_user_turn_count += 1;
|
||||
}
|
||||
insert_message_with_seq(&tx, session_id, next_seq, &copied)?;
|
||||
next_seq += 1;
|
||||
inserted_count += 1;
|
||||
}
|
||||
|
||||
tx.execute(
|
||||
"
|
||||
UPDATE sessions
|
||||
SET reset_cutoff_seq = ?2,
|
||||
message_count = message_count + ?3,
|
||||
user_turn_count = ?4,
|
||||
updated_at = ?5,
|
||||
last_active_at = ?5,
|
||||
archived_at = NULL
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
",
|
||||
params![
|
||||
session_id,
|
||||
current_max_seq,
|
||||
inserted_count,
|
||||
active_user_turn_count,
|
||||
now,
|
||||
],
|
||||
)?;
|
||||
|
||||
tx.commit()?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
let now = current_timestamp();
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
@ -1206,6 +1296,7 @@ pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
fn default_session_db_path() -> Result<PathBuf, std::io::Error> {
|
||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
Ok(home.join(".picobot").join("storage").join("sessions.db"))
|
||||
@ -1329,23 +1420,23 @@ fn map_scheduler_job_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<Schedul
|
||||
|
||||
fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
if !has_column(conn, "sessions", "reset_cutoff_seq")? {
|
||||
conn.execute(
|
||||
add_column_if_missing(
|
||||
conn,
|
||||
"ALTER TABLE sessions ADD COLUMN reset_cutoff_seq INTEGER NOT NULL DEFAULT 0",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "sessions", "user_turn_count")? {
|
||||
conn.execute(
|
||||
add_column_if_missing(
|
||||
conn,
|
||||
"ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? {
|
||||
conn.execute(
|
||||
add_column_if_missing(
|
||||
conn,
|
||||
"ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
@ -1353,11 +1444,12 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
}
|
||||
|
||||
fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
if !has_column(conn, "messages", "system_context")? {
|
||||
add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN system_context TEXT")?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "messages", "reasoning_content")? {
|
||||
conn.execute(
|
||||
"ALTER TABLE messages ADD COLUMN reasoning_content TEXT",
|
||||
[],
|
||||
)?;
|
||||
add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN reasoning_content TEXT")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -1438,6 +1530,15 @@ fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result<
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn add_column_if_missing(conn: &Connection, sql: &str) -> Result<(), StorageError> {
|
||||
match conn.execute(sql, []) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(rusqlite::Error::SqliteFailure(_, Some(message)))
|
||||
if message.contains("duplicate column name") => Ok(()),
|
||||
Err(error) => Err(StorageError::Database(error)),
|
||||
}
|
||||
}
|
||||
|
||||
fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, StorageError> {
|
||||
let cutoff = conn
|
||||
.query_row(
|
||||
@ -1450,22 +1551,72 @@ fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, Stora
|
||||
Ok(cutoff.unwrap_or(0))
|
||||
}
|
||||
|
||||
fn load_messages_after(
|
||||
conn: &Connection,
|
||||
fn insert_message_with_seq(
|
||||
conn: &rusqlite::Transaction<'_>,
|
||||
session_id: &str,
|
||||
cutoff_seq: i64,
|
||||
seq: i64,
|
||||
message: &ChatMessage,
|
||||
) -> Result<(), StorageError> {
|
||||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||||
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
|
||||
conn.execute(
|
||||
"
|
||||
INSERT INTO messages (
|
||||
id, session_id, seq, role, content,
|
||||
system_context, reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)
|
||||
",
|
||||
params![
|
||||
message.id,
|
||||
session_id,
|
||||
seq,
|
||||
message.role,
|
||||
message.content,
|
||||
message.system_context,
|
||||
message.reasoning_content,
|
||||
media_refs_json,
|
||||
message.tool_call_id,
|
||||
message.tool_name,
|
||||
tool_calls_json,
|
||||
message.timestamp,
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clone_message_for_compaction(message: &ChatMessage, timestamp: i64) -> ChatMessage {
|
||||
ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: message.role.clone(),
|
||||
content: message.content.clone(),
|
||||
media_refs: message.media_refs.clone(),
|
||||
timestamp,
|
||||
system_context: message.system_context.clone(),
|
||||
reasoning_content: message.reasoning_content.clone(),
|
||||
tool_call_id: message.tool_call_id.clone(),
|
||||
tool_name: message.tool_name.clone(),
|
||||
tool_state: message.tool_state.clone(),
|
||||
tool_calls: message.tool_calls.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn load_messages_between(
|
||||
conn: &rusqlite::Transaction<'_>,
|
||||
session_id: &str,
|
||||
start_seq_exclusive: i64,
|
||||
end_seq_inclusive: i64,
|
||||
) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, role, content, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
||||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
||||
FROM messages
|
||||
WHERE session_id = ?1 AND seq > ?2
|
||||
WHERE session_id = ?1 AND seq > ?2 AND seq <= ?3
|
||||
ORDER BY seq ASC
|
||||
",
|
||||
)?;
|
||||
|
||||
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| {
|
||||
let media_refs_json: String = row.get(4)?;
|
||||
let rows = stmt.query_map(params![session_id, start_seq_exclusive, end_seq_inclusive], |row| {
|
||||
let media_refs_json: String = row.get(5)?;
|
||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
media_refs_json.len(),
|
||||
@ -1474,14 +1625,14 @@ fn load_messages_after(
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_calls_json: Option<String> = row.get(8)?;
|
||||
let tool_calls_json: Option<String> = row.get(9)?;
|
||||
let tool_calls = tool_calls_json
|
||||
.as_deref()
|
||||
.map(serde_json::from_str)
|
||||
.transpose()
|
||||
.map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
8,
|
||||
9,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
@ -1491,11 +1642,71 @@ fn load_messages_after(
|
||||
id: row.get(0)?,
|
||||
role: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
reasoning_content: row.get(3)?,
|
||||
system_context: row.get(3)?,
|
||||
reasoning_content: row.get(4)?,
|
||||
media_refs,
|
||||
timestamp: row.get(5)?,
|
||||
tool_call_id: row.get(6)?,
|
||||
tool_name: row.get(7)?,
|
||||
timestamp: row.get(6)?,
|
||||
tool_call_id: row.get(7)?,
|
||||
tool_name: row.get(8)?,
|
||||
tool_state: None,
|
||||
tool_calls,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
for row in rows {
|
||||
messages.push(row?);
|
||||
}
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
fn load_messages_after(
|
||||
conn: &Connection,
|
||||
session_id: &str,
|
||||
cutoff_seq: i64,
|
||||
) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
||||
FROM messages
|
||||
WHERE session_id = ?1 AND seq > ?2
|
||||
ORDER BY seq ASC
|
||||
",
|
||||
)?;
|
||||
|
||||
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| {
|
||||
let media_refs_json: String = row.get(5)?;
|
||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
media_refs_json.len(),
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_calls_json: Option<String> = row.get(9)?;
|
||||
let tool_calls = tool_calls_json
|
||||
.as_deref()
|
||||
.map(serde_json::from_str)
|
||||
.transpose()
|
||||
.map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
9,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(ChatMessage {
|
||||
id: row.get(0)?,
|
||||
role: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
system_context: row.get(3)?,
|
||||
reasoning_content: row.get(4)?,
|
||||
media_refs,
|
||||
timestamp: row.get(6)?,
|
||||
tool_call_id: row.get(7)?,
|
||||
tool_name: row.get(8)?,
|
||||
tool_state: None,
|
||||
tool_calls,
|
||||
})
|
||||
@ -1532,6 +1743,7 @@ fn quote_fts_or_query(queries: &[String]) -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT;
|
||||
use crate::providers::ToolCall;
|
||||
|
||||
#[test]
|
||||
@ -1792,6 +2004,72 @@ mod tests {
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compact_active_history_rebuilds_active_segment_with_delta_messages() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("compact-history")).unwrap();
|
||||
|
||||
let agent_prompt = ChatMessage::system_with_context(
|
||||
"agent",
|
||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||
);
|
||||
let seed_messages = vec![
|
||||
agent_prompt.clone(),
|
||||
ChatMessage::user("u1"),
|
||||
ChatMessage::assistant("a1"),
|
||||
ChatMessage::user("u2"),
|
||||
ChatMessage::assistant("a2"),
|
||||
ChatMessage::user("u3"),
|
||||
ChatMessage::assistant("a3"),
|
||||
ChatMessage::user("u4"),
|
||||
ChatMessage::assistant("a4"),
|
||||
];
|
||||
|
||||
for message in &seed_messages {
|
||||
store.append_message(&session.id, message).unwrap();
|
||||
}
|
||||
|
||||
let snapshot_end_seq = store.get_session(&session.id).unwrap().unwrap().message_count;
|
||||
let preserved_messages = store.load_messages(&session.id).unwrap()[3..].to_vec();
|
||||
let preserved_system_messages = vec![agent_prompt];
|
||||
|
||||
store.append_message(&session.id, &ChatMessage::user("u5")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::assistant("a5")).unwrap();
|
||||
|
||||
let summary_message = ChatMessage::system("[Compressed History]\n\nsummary");
|
||||
let compacted = store
|
||||
.compact_active_history(
|
||||
&session.id,
|
||||
0,
|
||||
snapshot_end_seq,
|
||||
&preserved_system_messages,
|
||||
&summary_message,
|
||||
&preserved_messages,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(compacted);
|
||||
|
||||
let active_messages = store.load_messages(&session.id).unwrap();
|
||||
assert_eq!(active_messages.len(), 10);
|
||||
assert_eq!(active_messages[0].role, "system");
|
||||
assert_eq!(active_messages[0].content, "agent");
|
||||
assert_eq!(active_messages[0].system_context.as_deref(), Some(SYSTEM_CONTEXT_AGENT_PROMPT));
|
||||
assert_eq!(active_messages[1].role, "system");
|
||||
assert_eq!(active_messages[1].content, "[Compressed History]\n\nsummary");
|
||||
assert_eq!(active_messages[2].content, "u2");
|
||||
assert_eq!(active_messages[3].content, "a2");
|
||||
assert_eq!(active_messages[8].content, "u5");
|
||||
assert_eq!(active_messages[9].content, "a5");
|
||||
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(stored.reset_cutoff_seq, 11);
|
||||
assert_eq!(stored.user_turn_count, 4);
|
||||
|
||||
let all_messages = store.load_all_messages(&session.id).unwrap();
|
||||
assert_eq!(all_messages.len(), 21);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mark_agent_prompt_reinjected_increments_counter() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
const MAX_CHARS: usize = 128_000;
|
||||
|
||||
@ -205,37 +205,6 @@ fn strip_all_tags(s: &str) -> String {
|
||||
result
|
||||
}
|
||||
|
||||
fn extract_html_entity(s: &str) -> Option<(char, usize)> {
|
||||
let s_lower = s.to_lowercase();
|
||||
|
||||
let entities = [
|
||||
(" ", ' '),
|
||||
("<", '<'),
|
||||
(">", '>'),
|
||||
("&", '&'),
|
||||
(""", '"'),
|
||||
("'", '\''),
|
||||
("—", '—'),
|
||||
("–", '–'),
|
||||
("©", '©'),
|
||||
("®", '®'),
|
||||
("™", '™'),
|
||||
];
|
||||
|
||||
for (entity, replacement) in entities {
|
||||
if s_lower.starts_with(&entity.to_lowercase()) {
|
||||
return Some((replacement, entity.len()));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle numeric entities
|
||||
if s_lower.starts_with("&#x") || s_lower.starts_with("&#") {
|
||||
// Skip for now
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_host(url: &str) -> Result<String, String> {
|
||||
let rest = url
|
||||
.strip_prefix("http://")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user