feat: Enhance ChatMessage with system context and background compaction

- Added an optional `system_context` field to `ChatMessage` that tags a system message with its origin (agent prompt, scheduled prompt, or history-compaction summary).
- Introduced constants for system context prompts in `message.rs`.
- Updated `Session` to manage background history compaction, including methods to start and finish compaction.
- Implemented logic to schedule background compaction after message processing in `SessionManager`.
- Enhanced database schema to support new `system_context` field in messages.
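- Added functionality to compact active history: older user turns are replaced by a single summary message, while agent-prompt and scheduled-prompt system messages and the most recent user turns are preserved verbatim.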
- Updated tests to validate new compaction logic and ensure message integrity.
- Removed unused functions and cleaned up code in various modules for better maintainability.

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
ooodc 2026-04-26 09:31:13 +08:00
parent 3792472b83
commit 3045a6b596
10 changed files with 690 additions and 293 deletions

View File

@ -1,4 +1,9 @@
use crate::bus::ChatMessage; use crate::bus::{
ChatMessage,
SYSTEM_CONTEXT_AGENT_PROMPT,
SYSTEM_CONTEXT_HISTORY_COMPACTION,
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
};
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, ChatCompletionRequest, Message}; use crate::providers::{create_provider, ChatCompletionRequest, Message};
use crate::text::{char_count, take_prefix_chars}; use crate::text::{char_count, take_prefix_chars};
@ -17,26 +22,32 @@ pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
/// Configuration for context compression. /// Configuration for context compression.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ContextCompressionConfig { pub struct ContextCompressionConfig {
/// Protect first N messages (system prompt, etc.) /// Preserve the latest N complete user turns in full.
pub protect_first_n: usize, pub retain_last_user_turns: usize,
/// Protect last N messages (recent context)
pub protect_last_n: usize,
/// Maximum compression passes
pub max_passes: u32,
/// Maximum characters in summary /// Maximum characters in summary
pub summary_max_chars: usize, pub summary_max_chars: usize,
/// Characters to keep when trimming tool results }
pub tool_result_trim_chars: usize,
/// Half-open index range `[start, end_exclusive)` covering one user turn in a
/// chat history slice.
#[derive(Debug, Clone, PartialEq, Eq)]
struct UserTurnRange {
// Index of the `user`-role message that opens the turn.
start: usize,
// Index of the next `user`-role message, or the history length for the last turn.
end_exclusive: usize,
}
#[derive(Debug, Clone)]
pub struct HistoryCompactionPlan {
pub preserved_system_messages: Vec<ChatMessage>,
pub summary_message: ChatMessage,
pub preserved_messages: Vec<ChatMessage>,
pub compressed_turns: usize,
pub preserved_turns: usize,
} }
impl Default for ContextCompressionConfig { impl Default for ContextCompressionConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
protect_first_n: 1, retain_last_user_turns: 3,
protect_last_n: 4,
max_passes: 3,
summary_max_chars: 20_000, summary_max_chars: 20_000,
tool_result_trim_chars: 2_000,
} }
} }
} }
@ -65,7 +76,6 @@ impl ContextCompressor {
provider_config.token_limit, provider_config.token_limit,
ContextCompressionConfig { ContextCompressionConfig {
summary_max_chars: provider_config.context_summary_max_chars, summary_max_chars: provider_config.context_summary_max_chars,
tool_result_trim_chars: provider_config.context_tool_result_trim_chars,
..ContextCompressionConfig::default() ..ContextCompressionConfig::default()
}, },
) )
@ -85,26 +95,88 @@ impl ContextCompressor {
(self.context_window as f64 * self.threshold_ratio) as usize (self.context_window as f64 * self.threshold_ratio) as usize
} }
/// Fast-path: trim oversized tool results without LLM call. pub fn should_compress(&self, history: &[ChatMessage]) -> bool {
/// Returns the number of messages modified. estimate_tokens(history) > self.threshold()
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize { }
let limit = self.config.tool_result_trim_chars;
let mut modified = 0;
for msg in messages.iter_mut() { fn user_turn_ranges(&self, history: &[ChatMessage]) -> Vec<UserTurnRange> {
let content_chars = char_count(&msg.content); let user_indices: Vec<usize> = history
if msg.role == "tool" && content_chars > limit { .iter()
let removed = content_chars - limit; .enumerate()
msg.content = format!( .filter(|(_, message)| message.role == "user")
"{}...\n\n[Output truncated - {} characters removed]", .map(|(index, _)| index)
take_prefix_chars(&msg.content, limit), .collect();
removed
); user_indices
modified += 1; .iter()
} .enumerate()
.map(|(turn_index, start)| UserTurnRange {
start: *start,
end_exclusive: user_indices
.get(turn_index + 1)
.copied()
.unwrap_or(history.len()),
})
.collect()
}
/// Reports whether `message` must survive compaction verbatim.
///
/// Only `system`-role messages tagged as the agent prompt or as a scheduled
/// prompt qualify; every other message is eligible for summarization.
fn should_preserve_system_message(&self, message: &ChatMessage) -> bool {
    if message.role != "system" {
        return false;
    }

    [SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_SCHEDULED_PROMPT]
        .iter()
        .any(|context| message.has_system_context(context))
}
/// Splits `history` into `(preserved system messages, messages to summarize)`.
///
/// Both vectors hold clones and keep the relative order the messages had in
/// `history`. A message lands in the first vector exactly when
/// `should_preserve_system_message` accepts it.
fn split_prefix_messages(&self, history: &[ChatMessage]) -> (Vec<ChatMessage>, Vec<ChatMessage>) {
    history
        .iter()
        .cloned()
        .partition(|message| self.should_preserve_system_message(message))
}
pub async fn build_compaction_plan(
&self,
history: &[ChatMessage],
provider_config: &LLMProviderConfig,
) -> Result<Option<HistoryCompactionPlan>, AgentError> {
if !self.should_compress(history) {
return Ok(None);
} }
modified let turn_ranges = self.user_turn_ranges(history);
if turn_ranges.len() <= self.config.retain_last_user_turns {
return Ok(None);
}
let preserved_turn_start = turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start;
if preserved_turn_start == 0 {
return Ok(None);
}
let (preserved_system_messages, summary_source) =
self.split_prefix_messages(&history[..preserved_turn_start]);
let summary = self
.summarize_segment(&summary_source, provider_config)
.await?;
Ok(Some(HistoryCompactionPlan {
preserved_system_messages,
summary_message: ChatMessage::system_with_context(format!(
"[Compressed History]\n\n{}",
summary
), Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string())),
preserved_messages: history[preserved_turn_start..].to_vec(),
compressed_turns: turn_ranges.len() - self.config.retain_last_user_turns,
preserved_turns: self.config.retain_last_user_turns,
}))
} }
/// Main entry point - compresses history if over threshold. /// Main entry point - compresses history if over threshold.
@ -113,8 +185,6 @@ impl ContextCompressor {
history: Vec<ChatMessage>, history: Vec<ChatMessage>,
provider_config: &LLMProviderConfig, provider_config: &LLMProviderConfig,
) -> Result<Vec<ChatMessage>, AgentError> { ) -> Result<Vec<ChatMessage>, AgentError> {
let mut history = history;
// Check if compression is needed
let tokens = estimate_tokens(&history); let tokens = estimate_tokens(&history);
if tokens <= self.threshold() { if tokens <= self.threshold() {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -134,50 +204,18 @@ impl ContextCompressor {
"Starting context compression" "Starting context compression"
); );
// Fast trim pass first let current_history = match self.build_compaction_plan(&history, provider_config).await? {
let trimmed = self.fast_trim_tool_results(&mut history); Some(plan) => {
if trimmed > 0 { let mut compressed = Vec::with_capacity(
let tokens_after = estimate_tokens(&history); plan.preserved_system_messages.len() + plan.preserved_messages.len() + 1,
#[cfg(debug_assertions)] );
tracing::debug!( compressed.extend(plan.preserved_system_messages);
trimmed_messages = trimmed, compressed.push(plan.summary_message);
tokens_after = tokens_after, compressed.extend(plan.preserved_messages);
"Fast trim completed" compressed
);
if tokens_after <= self.threshold() {
return Ok(history);
} }
} None => history,
};
// LLM summarization pass
let mut current_history = history;
for pass in 0..self.config.max_passes {
let tokens = estimate_tokens(&current_history);
if tokens <= self.threshold() {
break;
}
#[cfg(debug_assertions)]
tracing::debug!(
pass = pass + 1,
tokens = tokens,
"Compression pass"
);
match self.compress_once(&current_history, provider_config).await {
Ok(Some(compressed)) => {
current_history = compressed;
}
Ok(None) => {
// No more compressible content
break;
}
Err(e) => {
tracing::warn!(error = %e, "Compression pass failed, using current history");
break;
}
}
}
tracing::info!( tracing::info!(
final_tokens = estimate_tokens(&current_history), final_tokens = estimate_tokens(&current_history),
@ -188,74 +226,6 @@ impl ContextCompressor {
Ok(current_history) Ok(current_history)
} }
/// Single compression pass - summarize middle messages between user turns.
/// Returns Some(compressed) if compression happened, None if nothing to compress.
async fn compress_once(
&self,
history: &[ChatMessage],
provider_config: &LLMProviderConfig,
) -> Result<Option<Vec<ChatMessage>>, AgentError> {
if history.len() <= self.config.protect_first_n + self.config.protect_last_n {
return Ok(None);
}
// Find user message indices (excluding protected first messages)
let user_indices: Vec<usize> = history
.iter()
.enumerate()
.skip(self.config.protect_first_n)
.filter(|(_, m)| m.role == "user")
.map(|(i, _)| i)
.collect();
// Need at least one user message and content between users to compress
if user_indices.len() < 2 {
return Ok(None);
}
// Build segments: user -> (assistant turns) -> next user
// We'll summarize the assistant turns between consecutive user messages
let mut new_messages = history[..=user_indices[0]].to_vec();
for i in 0..user_indices.len() - 1 {
let user_idx = user_indices[i];
let next_user_idx = user_indices[i + 1];
new_messages.push(history[user_idx].clone());
// Check if there's assistant content between these two user messages
let between_start = user_idx + 1;
let between_end = next_user_idx;
if between_start < between_end {
let between = &history[between_start..between_end];
let summary = self.summarize_segment(between, provider_config).await?;
// Add summary as a special user message
new_messages.push(ChatMessage::user(format!(
"[Context Summary]\n\n{}",
summary
)));
}
}
// Add last user and everything after (protected)
let last_user_idx = user_indices[user_indices.len() - 1];
if last_user_idx < history.len() - 1 {
// Add everything from last user onwards (protected)
for i in last_user_idx..history.len() {
new_messages.push(history[i].clone());
}
}
// If nothing changed, return None
if new_messages.len() == history.len() {
return Ok(None);
}
Ok(Some(new_messages))
}
/// Summarize a segment of messages using LLM. /// Summarize a segment of messages using LLM.
async fn summarize_segment( async fn summarize_segment(
&self, &self,
@ -299,6 +269,7 @@ impl ContextCompressor {
r#"You are a conversation compaction engine. Summarize the following conversation segment. r#"You are a conversation compaction engine. Summarize the following conversation segment.
PRESERVE: PRESERVE:
- Each user question or request in full or as a near-verbatim restatement
- All identifiers (UUIDs, hashes, file paths, URLs) - All identifiers (UUIDs, hashes, file paths, URLs)
- Actions taken (tool calls, file operations, commands) - Actions taken (tool calls, file operations, commands)
- Key information obtained (results, data, errors) - Key information obtained (results, data, errors)
@ -306,9 +277,11 @@ PRESERVE:
- Current task status - Current task status
OMIT: OMIT:
- Verbose tool output (keep key results only) - Reproducing full tool output verbatim unless it is essential
- Repeated greetings or filler - Repeated greetings or filler
Do not assume tool content was pre-trimmed. You may receive long tool outputs; keep the important results, errors, and artifacts.
Be concise, aim for {} characters or less. Be concise, aim for {} characters or less.
--- ---
@ -362,43 +335,48 @@ mod tests {
} }
#[test] #[test]
fn test_fast_trim() { fn test_should_compress() {
let config = ContextCompressionConfig { let compressor = ContextCompressor::new(20);
tool_result_trim_chars: 50, let messages = vec![ChatMessage::user(&"x".repeat(200))];
..Default::default() assert!(compressor.should_compress(&messages));
}; }
let compressor = ContextCompressor::with_config(100_000, config);
let mut messages = vec![ #[test]
ChatMessage::user("Hello"), fn test_user_turn_ranges_follow_user_boundaries() {
ChatMessage::tool("call1", "bash", &"x".repeat(200)), let compressor = ContextCompressor::new(100_000);
let history = vec![
ChatMessage::system("system"),
ChatMessage::user("u1"),
ChatMessage::assistant("a1"),
ChatMessage::tool("call-1", "bash", "t1"),
ChatMessage::user("u2"),
ChatMessage::assistant("a2"),
ChatMessage::user("u3"),
]; ];
let modified = compressor.fast_trim_tool_results(&mut messages); let turns = compressor.user_turn_ranges(&history);
assert_eq!(modified, 1); assert_eq!(turns, vec![
assert!(messages[1].content.len() < 100); UserTurnRange { start: 1, end_exclusive: 4 },
UserTurnRange { start: 4, end_exclusive: 6 },
UserTurnRange { start: 6, end_exclusive: 7 },
]);
} }
#[test] #[test]
fn test_fast_trim_handles_utf8_char_boundaries() { fn test_split_prefix_messages_preserves_key_system_messages() {
let config = ContextCompressionConfig { let compressor = ContextCompressor::new(50);
tool_result_trim_chars: 5, let prefix = vec![
..Default::default() ChatMessage::system_with_context("agent prompt", Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string())),
}; ChatMessage::user("u1"),
let compressor = ContextCompressor::with_config(100_000, config); ChatMessage::assistant("a1"),
ChatMessage::system_with_context("scheduled prompt", Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string())),
];
let mut messages = vec![ChatMessage::tool("call1", "bash", &"".repeat(20))]; let (preserved_system_messages, summary_source) = compressor.split_prefix_messages(&prefix);
assert_eq!(preserved_system_messages.len(), 2);
let modified = compressor.fast_trim_tool_results(&mut messages); assert_eq!(summary_source.len(), 2);
assert_eq!(modified, 1); assert!(preserved_system_messages[0].has_system_context(SYSTEM_CONTEXT_AGENT_PROMPT));
assert!(messages[0].content.contains("Output truncated")); assert!(preserved_system_messages[1].has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT));
assert!(messages[0].content.is_char_boundary(messages[0].content.len()));
}
#[test]
fn test_default_tool_result_trim_chars() {
let config = ContextCompressionConfig::default();
assert_eq!(config.tool_result_trim_chars, 2_000);
} }
#[test] #[test]

View File

@ -3,6 +3,10 @@ use serde::{Deserialize, Serialize};
use crate::providers::ToolCall; use crate::providers::ToolCall;
/// `system_context` tag attached to the system message that carries the agent prompt.
pub const SYSTEM_CONTEXT_AGENT_PROMPT: &str = "agent_prompt";
/// `system_context` tag attached to the system prompt supplied with a scheduled run.
pub const SYSTEM_CONTEXT_SCHEDULED_PROMPT: &str = "scheduled_system_prompt";
/// `system_context` tag attached to the summary message produced by history compaction.
pub const SYSTEM_CONTEXT_HISTORY_COMPACTION: &str = "history_compaction";
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum ToolMessageState { pub enum ToolMessageState {
@ -75,6 +79,8 @@ pub struct ChatMessage {
pub media_refs: Vec<String>, // Paths to media files for context pub media_refs: Vec<String>, // Paths to media files for context
pub timestamp: i64, pub timestamp: i64,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub system_context: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>, pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>, pub tool_call_id: Option<String>,
@ -94,6 +100,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
system_context: None,
reasoning_content: None, reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
@ -109,6 +116,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs, media_refs,
timestamp: current_timestamp(), timestamp: current_timestamp(),
system_context: None,
reasoning_content: None, reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
@ -124,6 +132,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
system_context: None,
reasoning_content: None, reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
@ -148,6 +157,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
system_context: None,
reasoning_content: None, reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
@ -167,12 +177,20 @@ impl ChatMessage {
} }
pub fn system(content: impl Into<String>) -> Self { pub fn system(content: impl Into<String>) -> Self {
Self::system_with_context(content, None::<String>)
}
pub fn system_with_context(
content: impl Into<String>,
system_context: impl Into<Option<String>>,
) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
role: "system".to_string(), role: "system".to_string(),
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
system_context: system_context.into(),
reasoning_content: None, reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
@ -197,6 +215,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
system_context: None,
reasoning_content: None, reasoning_content: None,
tool_call_id: Some(tool_call_id.into()), tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name.into()), tool_name: Some(tool_name.into()),
@ -205,6 +224,10 @@ impl ChatMessage {
} }
} }
/// Returns `true` when this message carries a `system_context` tag equal to
/// `expected`; a message without a tag never matches.
pub fn has_system_context(&self, expected: &str) -> bool {
    matches!(self.system_context.as_deref(), Some(actual) if actual == expected)
}
pub fn is_assistant_tool_call_message(&self) -> bool { pub fn is_assistant_tool_call_message(&self) -> bool {
self.role == "assistant" self.role == "assistant"
&& self && self

View File

@ -2,7 +2,16 @@ pub mod dispatcher;
pub mod message; pub mod message;
pub use dispatcher::OutboundDispatcher; pub use dispatcher::OutboundDispatcher;
pub use message::{ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage}; pub use message::{
ChatMessage,
ContentBlock,
InboundMessage,
MediaItem,
OutboundMessage,
SYSTEM_CONTEXT_AGENT_PROMPT,
SYSTEM_CONTEXT_HISTORY_COMPACTION,
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};

View File

@ -7,7 +7,13 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, mpsc}; use tokio::sync::{Mutex, mpsc};
use uuid::Uuid; use uuid::Uuid;
use crate::bus::{ChatMessage, MessageBus, OutboundMessage}; use crate::bus::{
ChatMessage,
MessageBus,
OutboundMessage,
SYSTEM_CONTEXT_AGENT_PROMPT,
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
};
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler}; use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
use crate::providers::{create_provider, ChatCompletionRequest, Message}; use crate::providers::{create_provider, ChatCompletionRequest, Message};
@ -399,6 +405,7 @@ pub struct Session {
pub channel_name: String, pub channel_name: String,
/// 按 chat_id 路由到不同会话历史,支持多用户多会话 /// 按 chat_id 路由到不同会话历史,支持多用户多会话
chat_histories: HashMap<String, Vec<ChatMessage>>, chat_histories: HashMap<String, Vec<ChatMessage>>,
compression_in_flight: HashSet<String>,
pub user_tx: mpsc::Sender<WsOutbound>, pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
@ -478,6 +485,7 @@ impl Session {
id: Uuid::new_v4(), id: Uuid::new_v4(),
channel_name, channel_name,
chat_histories: HashMap::new(), chat_histories: HashMap::new(),
compression_in_flight: HashSet::new(),
user_tx, user_tx,
provider_config: provider_config.clone(), provider_config: provider_config.clone(),
tools, tools,
@ -532,7 +540,13 @@ impl Session {
> session_record.agent_prompt_reinjection_count > session_record.agent_prompt_reinjection_count
{ {
if let Some(agent_prompt) = load_agent_prompt()? { if let Some(agent_prompt) = load_agent_prompt()? {
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?; self.append_persisted_message(
chat_id,
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
),
)?;
self.store self.store
.mark_agent_prompt_reinjected(&session_id) .mark_agent_prompt_reinjected(&session_id)
.map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?; .map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?;
@ -562,6 +576,7 @@ impl Session {
pub fn remove_history(&mut self, chat_id: &str) { pub fn remove_history(&mut self, chat_id: &str) {
self.chat_histories.remove(chat_id); self.chat_histories.remove(chat_id);
self.compression_in_flight.remove(chat_id);
} }
pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
@ -640,6 +655,7 @@ impl Session {
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect(); let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
let total: usize = self.chat_histories.values().map(|h| h.len()).sum(); let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
self.chat_histories.clear(); self.chat_histories.clear();
self.compression_in_flight.clear();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(previous_total = total, "All chat histories cleared"); tracing::debug!(previous_total = total, "All chat histories cleared");
@ -666,6 +682,23 @@ impl Session {
&self.compressor &self.compressor
} }
/// Marks `chat_id` as having a background compaction in flight.
///
/// Returns `true` when the marker was newly added and `false` when the chat
/// was already marked.
fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
    let marker = String::from(chat_id);
    self.compression_in_flight.insert(marker)
}
/// Clears the in-flight compaction marker for `chat_id`; does nothing when
/// the chat is not marked.
fn finish_background_compaction(&mut self, chat_id: &str) {
self.compression_in_flight.remove(chat_id);
}
/// Replaces the in-memory history of `chat_id` with the messages currently
/// held by the store for its persistent session.
///
/// # Errors
///
/// Returns `AgentError::Other` when the store fails to load the messages; the
/// in-memory history is left untouched in that case.
fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
    let session_id = self.persistent_session_id(chat_id);
    let reloaded = match self.store.load_messages(&session_id) {
        Ok(messages) => messages,
        Err(err) => {
            return Err(AgentError::Other(format!(
                "session history reload error: {}",
                err
            )));
        }
    };

    self.chat_histories.insert(chat_id.to_string(), reloaded);
    Ok(())
}
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> { pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
if self.skills.is_empty() { if self.skills.is_empty() {
return Ok(()); return Ok(());
@ -729,7 +762,13 @@ impl Session {
} }
if let Some(agent_prompt) = load_agent_prompt()? { if let Some(agent_prompt) = load_agent_prompt()? {
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?; self.append_persisted_message(
chat_id,
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
),
)?;
} }
Ok(()) Ok(())
@ -830,6 +869,94 @@ pub(crate) fn handle_in_chat_command(
} }
} }
/// Schedules a background compaction of the history of `chat_id`.
///
/// Under the session lock this takes a snapshot of the chat history and of
/// everything the compaction needs, and returns early without spawning when
/// the history is not over the compression threshold or when a compaction for
/// this chat is already marked as in flight. Otherwise it spawns a Tokio task
/// that builds a compaction plan from the snapshot and asks the store to
/// commit it. After a successful commit the in-memory history is reloaded
/// from the store. The in-flight marker is cleared at the end of the task
/// regardless of the outcome.
///
/// # Errors
///
/// Returns an error only when `ensure_persistent_session` or
/// `ensure_chat_loaded` fails. Failures inside the spawned task are logged and
/// not reported to the caller.
pub(crate) async fn schedule_background_history_compaction(
session: Arc<Mutex<Session>>,
chat_id: impl Into<String>,
) -> Result<(), AgentError> {
let chat_id = chat_id.into();
// The session lock is held only for this block, so the summarization
// performed by the spawned task does not run under it.
let snapshot = {
let mut session_guard = session.lock().await;
let session_record = session_guard.ensure_persistent_session(&chat_id)?;
session_guard.ensure_chat_loaded(&chat_id)?;
let history = session_guard.get_or_create_history(&chat_id).clone();
let compressor = session_guard.compressor().clone();
if !compressor.should_compress(&history) {
return Ok(());
}
// `false` means another compaction for this chat is already in flight.
if !session_guard.try_start_background_compaction(&chat_id) {
return Ok(());
}
(
session_guard.store.clone(),
session_guard.persistent_session_id(&chat_id),
session_record.reset_cutoff_seq,
session_record.message_count,
history,
compressor,
session_guard.provider_config().clone(),
)
};
let (store, session_id, expected_reset_cutoff_seq, snapshot_end_seq, history, compressor, provider_config) = snapshot;
let session_for_task = session.clone();
let chat_id_for_task = chat_id.clone();
// NOTE(review): the in-flight marker is cleared only by the last statement
// of this task; if the task panics before reaching it the marker would
// stay set for this chat — confirm whether that needs a guard.
tokio::spawn(async move {
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction");
let compaction_result = compressor.build_compaction_plan(&history, &provider_config).await;
let mut committed = false;
match compaction_result {
// A plan was produced: hand it to the store together with the
// snapshot markers captured under the lock.
Ok(Some(plan)) => match store.compact_active_history(
&session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
&plan.preserved_system_messages,
&plan.summary_message,
&plan.preserved_messages,
) {
Ok(true) => {
committed = true;
tracing::info!(
chat_id = %chat_id_for_task,
snapshot_end_seq,
compressed_turns = plan.compressed_turns,
preserved_turns = plan.preserved_turns,
"Background history compaction committed"
);
}
// The store did not apply the plan; logged as a stale snapshot.
Ok(false) => {
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot");
}
Err(error) => {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed");
}
},
Ok(None) => {
tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis");
}
Err(error) => {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed");
}
}
// Re-acquire the session lock to refresh the in-memory history (only
// after a commit) and to clear the in-flight marker.
let mut session_guard = session_for_task.lock().await;
if committed {
if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction");
}
}
session_guard.finish_background_compaction(&chat_id_for_task);
});
Ok(())
}
impl SessionManager { impl SessionManager {
pub fn new( pub fn new(
session_ttl_hours: u64, session_ttl_hours: u64,
@ -1224,7 +1351,7 @@ impl SessionManager {
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?; .ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
// 处理消息 // 处理消息
let (history, compressor, provider_config, agent, user_message_id) = { let (history, agent, user_message_id) = {
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
session_guard.ensure_persistent_session(chat_id)?; session_guard.ensure_persistent_session(chat_id)?;
@ -1254,9 +1381,6 @@ impl SessionManager {
session_guard.append_persisted_message(chat_id, user_message)?; session_guard.append_persisted_message(chat_id, user_message)?;
let history = session_guard.get_or_create_history(chat_id).clone(); let history = session_guard.get_or_create_history(chat_id).clone();
let compressor = session_guard.compressor().clone();
let provider_config = session_guard.provider_config().clone();
session_guard.record_skill_offer(chat_id)?; session_guard.record_skill_offer(chat_id)?;
// 创建 agent 并处理 // 创建 agent 并处理
@ -1265,14 +1389,12 @@ impl SessionManager {
agent = agent.with_emitted_message_handler(handler); agent = agent.with_emitted_message_handler(handler);
} }
(history, compressor, provider_config, agent, user_message_id) (history, agent, user_message_id)
}; };
let history = compressor
.compress_if_needed(history, &provider_config)
.await?;
let result = agent.process(history).await?; let result = agent.process(history).await?;
let mut should_schedule_compaction = false;
let response = { let response = {
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
@ -1287,6 +1409,7 @@ impl SessionManager {
} else { } else {
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
should_schedule_compaction = true;
result result
.emitted_messages .emitted_messages
@ -1308,6 +1431,12 @@ impl SessionManager {
} }
}; };
if should_schedule_compaction {
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.to_string()).await {
tracing::warn!(channel = %channel_name, chat_id = %chat_id, error = %error, "Failed to schedule background history compaction");
}
}
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(
channel = %channel_name, channel = %channel_name,
@ -1340,7 +1469,7 @@ impl SessionManager {
.unwrap_or_else(|| "scheduler".to_string()); .unwrap_or_else(|| "scheduler".to_string());
let provider_config = self.provider_config_for_agent(options.agent.as_deref())?; let provider_config = self.provider_config_for_agent(options.agent.as_deref())?;
let (history, compressor, agent, user_message_id) = { let (history, agent, user_message_id) = {
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
session_guard.ensure_persistent_session(chat_id)?; session_guard.ensure_persistent_session(chat_id)?;
@ -1353,7 +1482,13 @@ impl SessionManager {
session_guard.ensure_agent_prompt_before_user_message(chat_id)?; session_guard.ensure_agent_prompt_before_user_message(chat_id)?;
if let Some(system_prompt) = options.system_prompt.as_deref() { if let Some(system_prompt) = options.system_prompt.as_deref() {
session_guard.append_persisted_message(chat_id, ChatMessage::system(system_prompt))?; session_guard.append_persisted_message(
chat_id,
ChatMessage::system_with_context(
system_prompt,
Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
),
)?;
} }
let user_message = session_guard.create_user_message(prompt, Vec::new()); let user_message = session_guard.create_user_message(prompt, Vec::new());
@ -1361,7 +1496,6 @@ impl SessionManager {
session_guard.append_persisted_message(chat_id, user_message)?; session_guard.append_persisted_message(chat_id, user_message)?;
let history = session_guard.get_or_create_history(chat_id).clone(); let history = session_guard.get_or_create_history(chat_id).clone();
let compressor = session_guard.compressor().clone();
session_guard.record_skill_offer(chat_id)?; session_guard.record_skill_offer(chat_id)?;
@ -1372,14 +1506,12 @@ impl SessionManager {
provider_config.clone(), provider_config.clone(),
)?; )?;
(history, compressor, agent, user_message_id) (history, agent, user_message_id)
}; };
let history = compressor
.compress_if_needed(history, &provider_config)
.await?;
let result = agent.process(history).await?; let result = agent.process(history).await?;
let mut should_schedule_compaction = false;
let response = { let response = {
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
@ -1393,6 +1525,7 @@ impl SessionManager {
Vec::new() Vec::new()
} else { } else {
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
should_schedule_compaction = true;
result result
.emitted_messages .emitted_messages
@ -1411,6 +1544,12 @@ impl SessionManager {
} }
}; };
if should_schedule_compaction {
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.to_string()).await {
tracing::warn!(channel = %channel_name, chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for scheduled task");
}
}
Ok(response) Ok(response)
} }

View File

@ -9,7 +9,7 @@ use crate::agent::EmittedMessageHandler;
use crate::bus::message::{format_tool_call_content, ToolMessageState}; use crate::bus::message::{format_tool_call_content, ToolMessageState};
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::{Session, handle_in_chat_command}}; use super::{GatewayState, session::{Session, handle_in_chat_command, schedule_background_history_compaction}};
struct WsToolCallEmitter { struct WsToolCallEmitter {
sender: mpsc::Sender<WsOutbound>, sender: mpsc::Sender<WsOutbound>,
@ -246,52 +246,46 @@ async fn handle_inbound(
WsInbound::UserInput { content, chat_id, sender_id, .. } => { WsInbound::UserInput { content, chat_id, sender_id, .. } => {
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id); let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
let mut session_guard = session.lock().await; let (history, agent, user_tx) = {
let mut session_guard = session.lock().await;
session_guard.ensure_persistent_session(&chat_id)?; session_guard.ensure_persistent_session(&chat_id)?;
session_guard.ensure_chat_loaded(&chat_id)?; session_guard.ensure_chat_loaded(&chat_id)?;
if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? { if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? {
let _ = session_guard let _ = session_guard
.send(WsOutbound::AssistantResponse { .send(WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
content: command_response, content: command_response,
role: "assistant".to_string(), role: "assistant".to_string(),
}) })
.await; .await;
return Ok(()); return Ok(());
}
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?;
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
let history = match session_guard
.compressor()
.compress_if_needed(raw_history, session_guard.provider_config())
.await
{
Ok(history) => history,
Err(error) => {
tracing::warn!(chat_id = %chat_id, error = %error, "Compression failed, using original history");
session_guard.get_or_create_history(&chat_id).clone()
} }
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?;
let history = session_guard.get_or_create_history(&chat_id).clone();
session_guard.record_skill_offer(&chat_id)?;
let live_emitter = Arc::new(WsToolCallEmitter {
sender: session_guard.user_tx.clone(),
show_tool_results: state.config.gateway.show_tool_results,
});
let agent = session_guard
.create_agent(&chat_id, Some(&sender_id), Some(&user_message_id))?
.with_emitted_message_handler(live_emitter);
(history, agent, session_guard.user_tx.clone())
}; };
session_guard.record_skill_offer(&chat_id)?;
let live_emitter = Arc::new(WsToolCallEmitter {
sender: session_guard.user_tx.clone(),
show_tool_results: state.config.gateway.show_tool_results,
});
let agent = session_guard
.create_agent(&chat_id, Some(&sender_id), Some(&user_message_id))?
.with_emitted_message_handler(live_emitter);
match agent.process(history).await { match agent.process(history).await {
Ok(result) => { Ok(result) => {
let mut session_guard = session.lock().await;
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?; session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
for outbound in result for outbound in result
.emitted_messages .emitted_messages
@ -304,10 +298,16 @@ async fn handle_inbound(
{ {
let _ = session_guard.send(outbound).await; let _ = session_guard.send(outbound).await;
} }
drop(session_guard);
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.clone()).await {
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
}
} }
Err(error) => { Err(error) => {
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error"); tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
let _ = session_guard let _ = user_tx
.send(WsOutbound::Error { .send(WsOutbound::Error {
code: "LLM_ERROR".to_string(), code: "LLM_ERROR".to_string(),
message: error.to_string(), message: error.to_string(),

View File

@ -136,6 +136,7 @@ struct AnthropicResponse {
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent { enum AnthropicContent {
Text { text: String }, Text { text: String },
#[allow(dead_code)]
Thinking { thinking: String }, Thinking { thinking: String },
#[serde(rename = "tool_use")] #[serde(rename = "tool_use")]
ToolUse { ToolUse {

View File

@ -150,6 +150,7 @@ struct OpenAIMessage {
content: Option<String>, content: Option<String>,
#[serde(default)] #[serde(default)]
reasoning_content: Option<String>, reasoning_content: Option<String>,
#[allow(dead_code)]
#[serde(default)] #[serde(default)]
name: Option<String>, name: Option<String>,
#[serde(default)] #[serde(default)]
@ -161,6 +162,7 @@ struct OpenAIToolCall {
id: String, id: String,
#[serde(rename = "function")] #[serde(rename = "function")]
function: OAIFunction, function: OAIFunction,
#[allow(dead_code)]
#[serde(default)] #[serde(default)]
index: Option<u32>, index: Option<u32>,
} }

View File

@ -1,3 +1,4 @@
#[cfg(not(test))]
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
@ -190,11 +191,18 @@ pub struct SchedulerJobUpsert {
} }
impl SessionStore { impl SessionStore {
/// Test-build constructor: backs the store with a fresh in-memory SQLite
/// connection and hands it to `from_connection` for setup.
#[cfg(test)]
pub fn new() -> Result<Self, StorageError> {
    let connection = Connection::open_in_memory()?;
    Self::from_connection(connection)
}
#[cfg(not(test))]
pub fn new() -> Result<Self, StorageError> { pub fn new() -> Result<Self, StorageError> {
let db_path = default_session_db_path()?; let db_path = default_session_db_path()?;
Self::open_at_path(&db_path) Self::open_at_path(&db_path)
} }
#[cfg(not(test))]
fn open_at_path(path: &Path) -> Result<Self, StorageError> { fn open_at_path(path: &Path) -> Result<Self, StorageError> {
if let Some(parent) = path.parent() { if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?; std::fs::create_dir_all(parent)?;
@ -205,6 +213,7 @@ impl SessionStore {
} }
fn from_connection(conn: Connection) -> Result<Self, StorageError> { fn from_connection(conn: Connection) -> Result<Self, StorageError> {
conn.busy_timeout(std::time::Duration::from_secs(5))?;
conn.execute_batch( conn.execute_batch(
" "
PRAGMA journal_mode = WAL; PRAGMA journal_mode = WAL;
@ -238,6 +247,7 @@ impl SessionStore {
seq INTEGER NOT NULL, seq INTEGER NOT NULL,
role TEXT NOT NULL, role TEXT NOT NULL,
content TEXT NOT NULL, content TEXT NOT NULL,
system_context TEXT,
reasoning_content TEXT, reasoning_content TEXT,
media_refs_json TEXT NOT NULL, media_refs_json TEXT NOT NULL,
tool_call_id TEXT, tool_call_id TEXT,
@ -555,8 +565,8 @@ impl SessionStore {
" "
INSERT INTO messages ( INSERT INTO messages (
id, session_id, seq, role, content, id, session_id, seq, role, content,
reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at system_context, reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)
", ",
params![ params![
message.id, message.id,
@ -564,6 +574,7 @@ impl SessionStore {
seq, seq,
message.role, message.role,
message.content, message.content,
message.system_context,
message.reasoning_content, message.reasoning_content,
media_refs_json, media_refs_json,
message.tool_call_id, message.tool_call_id,
@ -592,6 +603,85 @@ impl SessionStore {
Ok(()) Ok(())
} }
/// Rewrites the active history of `session_id` as a compacted segment appended
/// after the current tail, then moves the session's reset cutoff past the old
/// segment.
///
/// The new segment is written in this order, every message under a fresh id:
/// 1. `preserved_system_messages` (original timestamps kept),
/// 2. `summary_message` (stamped with the current time),
/// 3. `preserved_messages`, followed by every stored message whose `seq` is
///    greater than `snapshot_end_seq` (i.e. appended after the snapshot).
///
/// The session row is then updated: `reset_cutoff_seq` becomes the previous
/// maximum `seq`, `message_count` grows by the number of inserted rows, and
/// `user_turn_count` is set to the number of `"user"` messages in step 3.
///
/// Returns `Ok(true)` when the compaction was committed, and `Ok(false)` with
/// nothing persisted when:
/// - the stored reset cutoff differs from `expected_reset_cutoff_seq`,
/// - `snapshot_end_seq` is not within `(cutoff, max seq]`, or
/// - no non-deleted session row matched the final update.
///
/// # Errors
/// SQLite and JSON serialization failures are returned as `StorageError`.
pub fn compact_active_history(
    &self,
    session_id: &str,
    expected_reset_cutoff_seq: i64,
    snapshot_end_seq: i64,
    preserved_system_messages: &[ChatMessage],
    summary_message: &ChatMessage,
    preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError> {
    let conn = self.conn.lock().expect("session db mutex poisoned");
    // Every early `return` below drops `tx` without committing; rusqlite's
    // default drop behavior for a transaction is rollback.
    let tx = conn.unchecked_transaction()?;

    // Optimistic check: bail out if the cutoff moved since the caller's snapshot.
    let current_cutoff = active_reset_cutoff(&tx, session_id)?;
    if current_cutoff != expected_reset_cutoff_seq {
        return Ok(false);
    }

    let current_max_seq: i64 = tx.query_row(
        "SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1",
        params![session_id],
        |row| row.get(0),
    )?;
    if snapshot_end_seq <= current_cutoff || snapshot_end_seq > current_max_seq {
        return Ok(false);
    }

    // Messages appended after the snapshot must survive the compaction.
    let delta_messages =
        load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;

    let mut next_seq = current_max_seq + 1;
    let now = current_timestamp();
    let mut inserted_count = 0_i64;
    let mut active_user_turn_count = 0_i64;

    for message in preserved_system_messages {
        let copied = clone_message_for_compaction(message, message.timestamp);
        insert_message_with_seq(&tx, session_id, next_seq, &copied)?;
        next_seq += 1;
        inserted_count += 1;
    }

    let summary_copy = clone_message_for_compaction(summary_message, now);
    insert_message_with_seq(&tx, session_id, next_seq, &summary_copy)?;
    next_seq += 1;
    inserted_count += 1;

    for message in preserved_messages.iter().chain(delta_messages.iter()) {
        let copied = clone_message_for_compaction(message, message.timestamp);
        if copied.role == "user" {
            active_user_turn_count += 1;
        }
        insert_message_with_seq(&tx, session_id, next_seq, &copied)?;
        next_seq += 1;
        inserted_count += 1;
    }

    let updated_rows = tx.execute(
        "
        UPDATE sessions
        SET reset_cutoff_seq = ?2,
        message_count = message_count + ?3,
        user_turn_count = ?4,
        updated_at = ?5,
        last_active_at = ?5,
        archived_at = NULL
        WHERE id = ?1 AND deleted_at IS NULL
        ",
        params![
            session_id,
            current_max_seq,
            inserted_count,
            active_user_turn_count,
            now,
        ],
    )?;
    if updated_rows == 0 {
        // No live session row was updated (missing or soft-deleted). Committing
        // here would leave duplicated messages behind an unchanged cutoff, so
        // abandon the transaction and report that nothing was compacted.
        return Ok(false);
    }

    tx.commit()?;
    Ok(true)
}
pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> { pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
let now = current_timestamp(); let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned"); let conn = self.conn.lock().expect("session db mutex poisoned");
@ -1206,6 +1296,7 @@ pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
} }
} }
#[cfg(not(test))]
fn default_session_db_path() -> Result<PathBuf, std::io::Error> { fn default_session_db_path() -> Result<PathBuf, std::io::Error> {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
Ok(home.join(".picobot").join("storage").join("sessions.db")) Ok(home.join(".picobot").join("storage").join("sessions.db"))
@ -1329,23 +1420,23 @@ fn map_scheduler_job_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<Schedul
fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> { fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
if !has_column(conn, "sessions", "reset_cutoff_seq")? { if !has_column(conn, "sessions", "reset_cutoff_seq")? {
conn.execute( add_column_if_missing(
conn,
"ALTER TABLE sessions ADD COLUMN reset_cutoff_seq INTEGER NOT NULL DEFAULT 0", "ALTER TABLE sessions ADD COLUMN reset_cutoff_seq INTEGER NOT NULL DEFAULT 0",
[],
)?; )?;
} }
if !has_column(conn, "sessions", "user_turn_count")? { if !has_column(conn, "sessions", "user_turn_count")? {
conn.execute( add_column_if_missing(
conn,
"ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0", "ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0",
[],
)?; )?;
} }
if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? { if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? {
conn.execute( add_column_if_missing(
conn,
"ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0", "ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0",
[],
)?; )?;
} }
@ -1353,11 +1444,12 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
} }
fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> { fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> {
if !has_column(conn, "messages", "system_context")? {
add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN system_context TEXT")?;
}
if !has_column(conn, "messages", "reasoning_content")? { if !has_column(conn, "messages", "reasoning_content")? {
conn.execute( add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN reasoning_content TEXT")?;
"ALTER TABLE messages ADD COLUMN reasoning_content TEXT",
[],
)?;
} }
Ok(()) Ok(())
@ -1438,6 +1530,15 @@ fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result<
Ok(false) Ok(false)
} }
fn add_column_if_missing(conn: &Connection, sql: &str) -> Result<(), StorageError> {
match conn.execute(sql, []) {
Ok(_) => Ok(()),
Err(rusqlite::Error::SqliteFailure(_, Some(message)))
if message.contains("duplicate column name") => Ok(()),
Err(error) => Err(StorageError::Database(error)),
}
}
fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, StorageError> { fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, StorageError> {
let cutoff = conn let cutoff = conn
.query_row( .query_row(
@ -1450,22 +1551,72 @@ fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, Stora
Ok(cutoff.unwrap_or(0)) Ok(cutoff.unwrap_or(0))
} }
fn load_messages_after( fn insert_message_with_seq(
conn: &Connection, conn: &rusqlite::Transaction<'_>,
session_id: &str, session_id: &str,
cutoff_seq: i64, seq: i64,
message: &ChatMessage,
) -> Result<(), StorageError> {
let media_refs_json = serde_json::to_string(&message.media_refs)?;
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
conn.execute(
"
INSERT INTO messages (
id, session_id, seq, role, content,
system_context, reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)
",
params![
message.id,
session_id,
seq,
message.role,
message.content,
message.system_context,
message.reasoning_content,
media_refs_json,
message.tool_call_id,
message.tool_name,
tool_calls_json,
message.timestamp,
],
)?;
Ok(())
}
/// Produces a copy of `message` for re-insertion during compaction.
///
/// The copy gets a newly generated UUID v4 as its `id` and the supplied
/// `timestamp`; every other field is cloned from `message` unchanged.
fn clone_message_for_compaction(message: &ChatMessage, timestamp: i64) -> ChatMessage {
    let fresh_id = uuid::Uuid::new_v4().to_string();
    ChatMessage {
        id: fresh_id,
        timestamp,
        // All remaining fields are carried over verbatim from the source message.
        ..message.clone()
    }
}
fn load_messages_between(
conn: &rusqlite::Transaction<'_>,
session_id: &str,
start_seq_exclusive: i64,
end_seq_inclusive: i64,
) -> Result<Vec<ChatMessage>, StorageError> { ) -> Result<Vec<ChatMessage>, StorageError> {
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
" "
SELECT id, role, content, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
FROM messages FROM messages
WHERE session_id = ?1 AND seq > ?2 WHERE session_id = ?1 AND seq > ?2 AND seq <= ?3
ORDER BY seq ASC ORDER BY seq ASC
", ",
)?; )?;
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| { let rows = stmt.query_map(params![session_id, start_seq_exclusive, end_seq_inclusive], |row| {
let media_refs_json: String = row.get(4)?; let media_refs_json: String = row.get(5)?;
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| { let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure( rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(), media_refs_json.len(),
@ -1474,14 +1625,14 @@ fn load_messages_after(
) )
})?; })?;
let tool_calls_json: Option<String> = row.get(8)?; let tool_calls_json: Option<String> = row.get(9)?;
let tool_calls = tool_calls_json let tool_calls = tool_calls_json
.as_deref() .as_deref()
.map(serde_json::from_str) .map(serde_json::from_str)
.transpose() .transpose()
.map_err(|err| { .map_err(|err| {
rusqlite::Error::FromSqlConversionFailure( rusqlite::Error::FromSqlConversionFailure(
8, 9,
rusqlite::types::Type::Text, rusqlite::types::Type::Text,
Box::new(err), Box::new(err),
) )
@ -1491,11 +1642,71 @@ fn load_messages_after(
id: row.get(0)?, id: row.get(0)?,
role: row.get(1)?, role: row.get(1)?,
content: row.get(2)?, content: row.get(2)?,
reasoning_content: row.get(3)?, system_context: row.get(3)?,
reasoning_content: row.get(4)?,
media_refs, media_refs,
timestamp: row.get(5)?, timestamp: row.get(6)?,
tool_call_id: row.get(6)?, tool_call_id: row.get(7)?,
tool_name: row.get(7)?, tool_name: row.get(8)?,
tool_state: None,
tool_calls,
})
})?;
let mut messages = Vec::new();
for row in rows {
messages.push(row?);
}
Ok(messages)
}
fn load_messages_after(
conn: &Connection,
session_id: &str,
cutoff_seq: i64,
) -> Result<Vec<ChatMessage>, StorageError> {
let mut stmt = conn.prepare(
"
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
FROM messages
WHERE session_id = ?1 AND seq > ?2
ORDER BY seq ASC
",
)?;
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| {
let media_refs_json: String = row.get(5)?;
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(),
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
let tool_calls_json: Option<String> = row.get(9)?;
let tool_calls = tool_calls_json
.as_deref()
.map(serde_json::from_str)
.transpose()
.map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
9,
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
Ok(ChatMessage {
id: row.get(0)?,
role: row.get(1)?,
content: row.get(2)?,
system_context: row.get(3)?,
reasoning_content: row.get(4)?,
media_refs,
timestamp: row.get(6)?,
tool_call_id: row.get(7)?,
tool_name: row.get(8)?,
tool_state: None, tool_state: None,
tool_calls, tool_calls,
}) })
@ -1532,6 +1743,7 @@ fn quote_fts_or_query(queries: &[String]) -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT;
use crate::providers::ToolCall; use crate::providers::ToolCall;
#[test] #[test]
@ -1792,6 +2004,72 @@ mod tests {
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2); assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
} }
/// Seeds a session with an agent prompt and four user/assistant turns, takes a
/// snapshot, appends a fifth turn, then compacts and checks that the active
/// segment is: agent prompt, summary, turns 2-4, and the post-snapshot turn 5.
#[test]
fn test_compact_active_history_rebuilds_active_segment_with_delta_messages() {
    let store = SessionStore::in_memory().unwrap();
    let session = store.create_cli_session(Some("compact-history")).unwrap();

    let agent_prompt = ChatMessage::system_with_context(
        "agent",
        Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
    );
    let seeded = [
        agent_prompt.clone(),
        ChatMessage::user("u1"),
        ChatMessage::assistant("a1"),
        ChatMessage::user("u2"),
        ChatMessage::assistant("a2"),
        ChatMessage::user("u3"),
        ChatMessage::assistant("a3"),
        ChatMessage::user("u4"),
        ChatMessage::assistant("a4"),
    ];
    for seeded_message in seeded.iter() {
        store.append_message(&session.id, seeded_message).unwrap();
    }

    // Snapshot state as the compactor would have seen it.
    let snapshot_end_seq = store
        .get_session(&session.id)
        .unwrap()
        .unwrap()
        .message_count;
    let loaded_at_snapshot = store.load_messages(&session.id).unwrap();
    let kept_tail = loaded_at_snapshot[3..].to_vec();
    let kept_system = vec![agent_prompt];

    // A further turn arrives after the snapshot; it must be carried over.
    store
        .append_message(&session.id, &ChatMessage::user("u5"))
        .unwrap();
    store
        .append_message(&session.id, &ChatMessage::assistant("a5"))
        .unwrap();

    let summary = ChatMessage::system("[Compressed History]\n\nsummary");
    let was_compacted = store
        .compact_active_history(
            &session.id,
            0,
            snapshot_end_seq,
            &kept_system,
            &summary,
            &kept_tail,
        )
        .unwrap();
    assert!(was_compacted);

    let active = store.load_messages(&session.id).unwrap();
    assert_eq!(active.len(), 10);
    assert_eq!(active[0].role, "system");
    assert_eq!(
        active[0].system_context.as_deref(),
        Some(SYSTEM_CONTEXT_AGENT_PROMPT)
    );
    assert_eq!(active[1].role, "system");

    let expected_contents = [
        (0, "agent"),
        (1, "[Compressed History]\n\nsummary"),
        (2, "u2"),
        (3, "a2"),
        (8, "u5"),
        (9, "a5"),
    ];
    for &(index, content) in expected_contents.iter() {
        assert_eq!(active[index].content, content);
    }

    let session_after = store.get_session(&session.id).unwrap().unwrap();
    assert_eq!(session_after.reset_cutoff_seq, 11);
    assert_eq!(session_after.user_turn_count, 4);

    // 11 original rows plus the 10 rows of the rebuilt segment.
    let every_message = store.load_all_messages(&session.id).unwrap();
    assert_eq!(every_message.len(), 21);
}
#[test] #[test]
fn test_mark_agent_prompt_reinjected_increments_counter() { fn test_mark_agent_prompt_reinjected_increments_counter() {
let store = SessionStore::in_memory().unwrap(); let store = SessionStore::in_memory().unwrap();

View File

@ -1,10 +1,8 @@
use std::io::Read;
use std::path::Path; use std::path::Path;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::json; use serde_json::json;
use crate::bus::message::ContentBlock;
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
const MAX_CHARS: usize = 128_000; const MAX_CHARS: usize = 128_000;

View File

@ -205,37 +205,6 @@ fn strip_all_tags(s: &str) -> String {
result result
} }
/// Recognizes a named HTML entity at the very start of `s`.
///
/// Matching is ASCII case-insensitive. On success returns the replacement
/// character together with the byte length of the entity text that was
/// matched (including the leading `&` and trailing `;`).
///
/// Returns `None` when `s` does not begin with one of the known named
/// entities. Numeric character references (`&#…;` / `&#x…;`) are not
/// decoded and also yield `None`.
fn extract_html_entity(s: &str) -> Option<(char, usize)> {
    // Entity names are stored in lowercase; the comparison below ignores ASCII case.
    const ENTITIES: [(&str, char); 11] = [
        ("&nbsp;", ' '),
        ("&lt;", '<'),
        ("&gt;", '>'),
        ("&amp;", '&'),
        ("&quot;", '"'),
        ("&apos;", '\''),
        ("&mdash;", '\u{2014}'), // em dash
        ("&ndash;", '\u{2013}'), // en dash
        ("&copy;", '\u{00A9}'),
        ("&reg;", '\u{00AE}'),
        ("&trade;", '\u{2122}'),
    ];

    let bytes = s.as_bytes();
    ENTITIES.iter().find_map(|&(entity, replacement)| {
        // Compare only the leading bytes, so no lowercase copy of `s` is needed
        // and inputs shorter than the entity simply fail to match.
        let prefix = bytes.get(..entity.len())?;
        if prefix.eq_ignore_ascii_case(entity.as_bytes()) {
            Some((replacement, entity.len()))
        } else {
            None
        }
    })
}
fn extract_host(url: &str) -> Result<String, String> { fn extract_host(url: &str) -> Result<String, String> {
let rest = url let rest = url
.strip_prefix("http://") .strip_prefix("http://")