PicoBot/src/gateway/session.rs

2065 lines
73 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler};
#[cfg(test)]
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::scheduler::ScheduledAgentTaskOptions;
use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository};
use crate::tools::ToolRegistry;
use crate::tools::task::repository::TaskRepository;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use super::agent_factory::{AgentBuildRequest, AgentFactory};
use super::cli_session::CliSessionService;
#[cfg(test)]
use super::execution::should_display_message_to_user;
#[cfg(test)]
use super::memory_maintenance::{
MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
extract_json_object, is_recoverable_maintenance_llm_error,
strip_json_code_fence,
};
use super::memory_maintenance::{MemoryMaintenanceScopeResult, MemoryOrganizationOutput};
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
use super::session_history::SessionHistory;
use super::session_lifecycle::SessionLifecycleService;
use super::session_message_service::SessionMessageService;
/// A conversation session scoped to one channel (there is one `Session` per channel).
///
/// Per-chat state lives inside the session: history is isolated by `chat_id` and managed
/// centrally here, and the current topic is likewise tracked per `chat_id` inside the
/// [`SessionHistory`].
pub struct Session {
    /// Random identifier generated when this in-memory session is constructed.
    pub id: Uuid,
    /// Name of the channel this session belongs to.
    pub channel_name: String,
    /// Outbound websocket sender used by `Session::send`.
    pub user_tx: mpsc::Sender<WsOutbound>,
    /// Default provider configuration used by `Session::create_agent`.
    provider_config: LLMProviderConfig,
    skills: Arc<SkillRuntime>,
    /// Builds a fresh `AgentLoop` for each request.
    agent_factory: AgentFactory,
    /// Context compressor derived from `provider_config` at construction time.
    compressor: ContextCompressor,
    /// Per-chat history and topic tracking.
    history: SessionHistory,
    store: Arc<SessionStore>,
}
/// Forwards messages emitted during an agent run to the [`MessageBus`] as live outbound
/// messages for a fixed channel and chat.
///
/// Each emitted [`ChatMessage`] is converted with `OutboundMessage::from_chat_message`
/// and every resulting message is published on the bus.
pub struct BusToolCallEmitter {
    bus: Arc<MessageBus>,
    channel_name: String,
    chat_id: String,
    /// Metadata attached to every outbound message published by this emitter.
    metadata: HashMap<String, String>,
}
impl BusToolCallEmitter {
    /// Creates an emitter that publishes to `bus` on behalf of `channel_name` / `chat_id`.
    ///
    /// `metadata` is attached to every outbound message the emitter publishes.
    pub fn new(
        bus: Arc<MessageBus>,
        channel_name: impl Into<String>,
        chat_id: impl Into<String>,
        metadata: HashMap<String, String>,
    ) -> Self {
        let channel_name: String = channel_name.into();
        let chat_id: String = chat_id.into();
        Self {
            bus,
            channel_name,
            chat_id,
            metadata,
        }
    }
}
#[async_trait]
impl EmittedMessageHandler for BusToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
for outbound in OutboundMessage::from_chat_message(
&self.channel_name,
&self.chat_id,
None, // session_id
None,
&self.metadata,
&message,
) {
if let Err(error) = self.bus.publish_outbound(outbound).await {
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
}
}
}
async fn handle_tool_result(&self, message: ChatMessage, duration_ms: Option<u64>) {
let mut metadata = self.metadata.clone();
if let Some(ms) = duration_ms {
metadata.insert("tool_duration_ms".to_string(), ms.to_string());
}
for outbound in OutboundMessage::from_chat_message(
&self.channel_name,
&self.chat_id,
None, // session_id
None,
&metadata,
&message,
) {
if let Err(error) = self.bus.publish_outbound(outbound).await {
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
}
}
}
}
impl Session {
    /// Creates a session for `channel_name` backed by `store`.
    ///
    /// `store` serves as the conversation, skill-event, and prompt-injection repository.
    /// An [`AgentFactory`] is built from `tools`, `skills`, and `agent_prompt_reinject_every`.
    /// A value of `agent_prompt_reinject_every` that does not fit in `usize` on the current
    /// target saturates to `usize::MAX`. A plain `as` cast would silently truncate it on
    /// 32-bit targets.
    ///
    /// # Errors
    /// Returns whatever [`Session::with_factories`] returns.
    pub async fn new(
        channel_name: String,
        provider_config: LLMProviderConfig,
        user_tx: mpsc::Sender<WsOutbound>,
        tools: Arc<ToolRegistry>,
        skills: Arc<SkillRuntime>,
        store: Arc<SessionStore>,
        agent_prompt_reinject_every: u64,
    ) -> Result<Self, AgentError> {
        let conversations: Arc<dyn ConversationRepository> = store.clone();
        let skill_events: Arc<dyn SkillEventRepository> = store.clone();
        let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
        let reinject_every = usize::try_from(agent_prompt_reinject_every).unwrap_or(usize::MAX);
        let agent_factory = AgentFactory::new(
            tools,
            skills.clone(),
            reinject_every,
            prompt_repository,
        );
        Self::with_factories(
            channel_name,
            provider_config,
            user_tx,
            skills,
            agent_factory,
            conversations,
            skill_events,
            store,
        )
        .await
    }

    /// Assembles a session from pre-built collaborators.
    ///
    /// Generates a fresh random session id and derives the context compressor from
    /// `provider_config`. At present this always returns `Ok`.
    pub(crate) async fn with_factories(
        channel_name: String,
        provider_config: LLMProviderConfig,
        user_tx: mpsc::Sender<WsOutbound>,
        skills: Arc<SkillRuntime>,
        agent_factory: AgentFactory,
        conversations: Arc<dyn ConversationRepository>,
        skill_events: Arc<dyn SkillEventRepository>,
        store: Arc<SessionStore>,
    ) -> Result<Self, AgentError> {
        // Derive everything that only needs a borrow first, so the owned values can be
        // moved into the struct without extra clones.
        let compressor = ContextCompressor::from_provider_config(&provider_config);
        let history = SessionHistory::new(channel_name.clone(), conversations, skill_events);
        Ok(Self {
            id: Uuid::new_v4(),
            channel_name,
            user_tx,
            provider_config,
            skills,
            agent_factory,
            compressor,
            history,
            store,
        })
    }

    /// Returns the persistent session id used for `chat_id` in storage.
    pub fn persistent_session_id(&self, chat_id: &str) -> String {
        self.history.persistent_session_id(chat_id)
    }

    /// Sets the current topic for `chat_id`; `None` clears it.
    pub fn set_current_topic(&mut self, chat_id: &str, topic_id: Option<String>) {
        if let Some(topic_id) = topic_id {
            self.history.set_chat_topic(chat_id, topic_id);
        } else {
            self.history.clear_chat_topic(chat_id);
        }
    }

    /// Returns the current topic id for `chat_id`, if one is set.
    pub fn current_topic(&self, chat_id: &str) -> Option<&str> {
        self.history.chat_topic(chat_id)
    }

    /// Returns the topic id that the in-memory history of `chat_id` was loaded for.
    pub fn history_topic(&self, chat_id: &str) -> Option<&str> {
        self.history.history_topic(chat_id)
    }

    /// Switches `chat_id` to `topic_id`, replacing the in-memory history with that topic's
    /// stored messages.
    ///
    /// The topic's messages are loaded before any in-memory state is touched. A storage
    /// failure therefore leaves the current topic and history unchanged.
    ///
    /// # Errors
    /// Returns `AgentError::Other` if the topic's messages cannot be loaded.
    pub fn switch_topic(&mut self, chat_id: &str, topic_id: &str) -> Result<(), AgentError> {
        let messages = self
            .store
            .load_messages_for_topic(topic_id)
            .map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
        // Drop the current history.
        self.history.remove_history(chat_id);
        // The topic must be set before `set_history`, which relies on it.
        self.history.set_chat_topic(chat_id, topic_id.to_string());
        self.history.set_history(chat_id, messages);
        tracing::info!(
            topic_id = %topic_id,
            chat_id = %chat_id,
            "Switched to topic"
        );
        Ok(())
    }

    /// Ensures a persistent session record exists for `chat_id` and returns it.
    pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
        self.history.ensure_persistent_session(chat_id)
    }

    /// Ensures the history for `chat_id` is loaded and matches the current topic.
    ///
    /// If history is already loaded but was loaded for a different topic, it is reloaded.
    /// If no history is loaded yet, it is loaded for the current topic (when one is set).
    pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
        // Take an owned copy of the current topic so it does not keep `self.history`
        // borrowed across the mutable reload below.
        let current_topic: Option<String> = self.history.chat_topic(chat_id).map(|s| s.to_string());
        let stored_topic = self.history.history_topic(chat_id);
        if self.chat_history_exists(chat_id) {
            // History exists but belongs to another topic: reload it.
            if current_topic.as_deref() != stored_topic {
                tracing::info!(
                    chat_id = %chat_id,
                    current_topic = ?current_topic,
                    stored_topic = ?stored_topic,
                    "Topic changed, reloading history"
                );
                self.reload_chat_history(chat_id)?;
            }
            return Ok(());
        }
        // No history yet: load it, scoped to the current topic if one is set.
        self.history.ensure_chat_loaded(chat_id, current_topic.as_deref())
    }

    fn chat_history_exists(&self, chat_id: &str) -> bool {
        self.history.get_history(chat_id).is_some()
    }

    /// Delegates to `SessionHistory::ensure_agent_prompt_before_user_message` for `chat_id`.
    pub fn ensure_agent_prompt_before_user_message(
        &mut self,
        chat_id: &str,
    ) -> Result<(), AgentError> {
        self.history
            .ensure_agent_prompt_before_user_message(chat_id)
    }

    /// Returns the history for `chat_id`, creating an empty one if needed.
    pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
        self.history.get_or_create_history(chat_id)
    }

    /// Returns the history for `chat_id` without creating it.
    pub fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
        self.history.get_history(chat_id)
    }

    /// Appends a complete message to the in-memory history of `chat_id`.
    pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
        self.history.add_message(chat_id, message);
    }

    /// Removes the in-memory history of `chat_id`.
    pub fn remove_history(&mut self, chat_id: &str) {
        self.history.remove_history(chat_id);
    }

    /// Clears the history of `chat_id` via the history layer.
    pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
        self.history.clear_chat_history(chat_id)
    }

    /// Persists `message` under the current topic of `chat_id`, then appends it to the
    /// in-memory history.
    ///
    /// When a topic is set, its last-active time is also refreshed. A failure to refresh
    /// it is logged and otherwise ignored.
    ///
    /// # Errors
    /// Returns `AgentError::Other` if the message cannot be persisted. In that case the
    /// in-memory history is not modified.
    pub fn append_persisted_message(
        &mut self,
        chat_id: &str,
        message: ChatMessage,
    ) -> Result<(), AgentError> {
        let session_id = self.persistent_session_id(chat_id);
        let topic_id = self.history.chat_topic(chat_id).map(|s| s.to_string());
        self.store
            .append_message_with_topic(&session_id, topic_id.as_deref(), &message)
            .map_err(|err| {
                AgentError::Other(format!("append message persistence error: {}", err))
            })?;
        self.add_message(chat_id, message);
        // Refresh the topic's last-active timestamp.
        if let Some(ref topic_id) = topic_id {
            if let Err(e) = self.store.touch_topic(topic_id) {
                tracing::warn!(error = %e, topic_id = %topic_id, "Failed to touch topic");
            }
        }
        Ok(())
    }

    /// Persists several messages for `chat_id` via the history layer.
    pub fn append_persisted_messages<I>(
        &mut self,
        chat_id: &str,
        messages: I,
    ) -> Result<(), AgentError>
    where
        I: IntoIterator<Item = ChatMessage>,
    {
        self.history.append_persisted_messages(chat_id, messages)
    }

    /// Saves `messages` to `topic_id` directly in storage without updating in-memory history.
    pub fn append_messages_to_topic(
        &self,
        chat_id: &str,
        topic_id: &str,
        messages: &[ChatMessage],
    ) -> Result<(), AgentError> {
        self.history.append_to_topic(chat_id, topic_id, messages)
    }

    /// Builds a user message, attaching `media_refs` only when there are any.
    pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
        if media_refs.is_empty() {
            ChatMessage::user(content)
        } else {
            ChatMessage::user_with_media(content, media_refs)
        }
    }

    #[cfg(test)]
    fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
        self.latest_user_message(chat_id)
            .map(|message| message.id.as_str())
    }

    #[cfg(test)]
    fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
        self.history.latest_user_message(chat_id)
    }

    #[cfg(test)]
    fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
        self.latest_user_message_id(chat_id)
            .map(|current_id| current_id == message_id)
            .unwrap_or(false)
    }

    /// Returns whether `message` belongs to the current user turn of `chat_id`, as decided
    /// by the history layer.
    pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
        self.history.matches_current_user_turn(chat_id, message)
    }

    /// Returns the history layer's stale-result diagnostics for `chat_id`.
    ///
    /// NOTE(review): see `SessionHistory::stale_result_diagnostics` for the meaning of
    /// each tuple field.
    pub(crate) fn stale_result_diagnostics(
        &self,
        chat_id: &str,
    ) -> (Option<&str>, Option<String>, bool, usize) {
        self.history.stale_result_diagnostics(chat_id)
    }

    /// Clears the history of every chat in this session.
    pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
        self.history.clear_all_history()
    }

    /// Sends `msg` to the user over the websocket channel.
    ///
    /// Best-effort: if the receiving half has been closed, the message is dropped and the
    /// drop is logged at debug level.
    pub async fn send(&self, msg: WsOutbound) {
        if self.user_tx.send(msg).await.is_err() {
            tracing::debug!(
                channel = %self.channel_name,
                "Dropped outbound websocket message because the receiver is closed"
            );
        }
    }

    /// Returns the session's default provider configuration.
    pub fn provider_config(&self) -> &LLMProviderConfig {
        &self.provider_config
    }

    /// Returns the session's context compressor.
    pub fn compressor(&self) -> &ContextCompressor {
        &self.compressor
    }

    /// Asks the history layer to mark background compaction as started for `chat_id`.
    ///
    /// Presumably returns `false` when compaction is already running — TODO confirm.
    pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
        self.history.try_start_background_compaction(chat_id)
    }

    /// Marks background compaction as finished for `chat_id`.
    pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
        self.history.finish_background_compaction(chat_id);
    }

    /// Reloads the in-memory history of `chat_id` from storage.
    ///
    /// With a current topic, only that topic's messages are loaded. Otherwise the history
    /// layer reloads all of the session's messages.
    pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
        if let Some(topic_id) = self.history.chat_topic(chat_id) {
            let messages = self
                .store
                .load_messages_for_topic(topic_id)
                .map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
            self.history.set_history(chat_id, messages);
        } else {
            self.history.reload_chat_history(chat_id)?;
        }
        Ok(())
    }

    /// Returns the conversation repository used by this session's history.
    pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
        self.history.conversations()
    }

    /// Records an "offered" skill event for `chat_id`. This is a no-op when no skills are
    /// loaded.
    pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
        if self.skills.is_empty() {
            return Ok(());
        }
        self.history.append_skill_event(
            chat_id,
            "offered",
            None,
            &self.skills.offered_event_payload(),
        )
    }

    /// Creates a short-lived [`AgentLoop`] for `chat_id` using the session's default
    /// provider configuration.
    pub fn create_agent(
        &self,
        chat_id: &str,
        sender_id: Option<&str>,
        message_id: Option<&str>,
    ) -> Result<AgentLoop, AgentError> {
        self.create_agent_with_provider_config(
            chat_id,
            None, // notification_chat_id = None: fall back to session_chat_id
            sender_id,
            message_id,
            self.provider_config.clone(),
        )
    }

    /// Creates a short-lived [`AgentLoop`] with an explicit provider configuration.
    ///
    /// `notification_chat_id` may differ from `session_chat_id`. Both are forwarded to the
    /// agent factory.
    pub fn create_agent_with_provider_config(
        &self,
        session_chat_id: &str,
        notification_chat_id: Option<&str>,
        sender_id: Option<&str>,
        message_id: Option<&str>,
        provider_config: LLMProviderConfig,
    ) -> Result<AgentLoop, AgentError> {
        self.agent_factory.create(AgentBuildRequest {
            channel_name: &self.channel_name,
            session_chat_id,
            notification_chat_id,
            sender_id,
            message_id,
            provider_config,
        })
    }
}
/// Manages every [`Session`], routing work to the session owned by a given `channel_name`.
///
/// Most operations delegate to the dedicated services held here (session lifecycle,
/// message handling, scheduled agent tasks, memory maintenance, CLI sessions).
#[derive(Clone)]
pub struct SessionManager {
    tools: Arc<ToolRegistry>,
    skills: Arc<SkillRuntime>,
    store: Arc<SessionStore>,
    /// Exposed through `SessionManager::show_tool_results`.
    show_tool_results: bool,
    lifecycle: SessionLifecycleService,
    cli_sessions: CliSessionService,
    messages: SessionMessageService,
    scheduled_tasks: ScheduledAgentTaskService,
    memory_maintenance: MemoryMaintenanceCoordinator,
    task_repository: Arc<dyn TaskRepository>,
}
/// Pre-built collaborators consumed by `SessionManager::from_services`.
///
/// Each field is moved into the `SessionManager` field of the same name.
pub(crate) struct SessionManagerServices {
    pub(crate) tools: Arc<ToolRegistry>,
    pub(crate) skills: Arc<SkillRuntime>,
    pub(crate) store: Arc<SessionStore>,
    pub(crate) show_tool_results: bool,
    pub(crate) lifecycle: SessionLifecycleService,
    pub(crate) cli_sessions: CliSessionService,
    pub(crate) messages: SessionMessageService,
    pub(crate) scheduled_tasks: ScheduledAgentTaskService,
    pub(crate) memory_maintenance: MemoryMaintenanceCoordinator,
    pub(crate) task_repository: Arc<dyn TaskRepository>,
}
impl SessionManager {
pub(crate) fn from_services(services: SessionManagerServices) -> Self {
Self {
tools: services.tools,
skills: services.skills,
store: services.store,
show_tool_results: services.show_tool_results,
lifecycle: services.lifecycle,
cli_sessions: services.cli_sessions,
messages: services.messages,
scheduled_tasks: services.scheduled_tasks,
memory_maintenance: services.memory_maintenance,
task_repository: services.task_repository,
}
}
pub fn new(
agent_prompt_reinject_every: u64,
show_tool_results: bool,
default_timezone: String,
provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
skills: Arc<SkillRuntime>,
disabled_tools: std::collections::HashSet<String>,
task_config: crate::config::TaskConfig,
subagents_config: crate::config::SubagentsConfig,
maintenance_config: crate::config::MemoryMaintenanceConfig,
session_ttl_hours: Option<u64>,
mcp_config: crate::mcp::McpConfig,
) -> Result<Self, AgentError> {
super::runtime::build_session_manager(
agent_prompt_reinject_every,
show_tool_results,
default_timezone,
provider_config,
provider_configs,
skills,
disabled_tools,
task_config,
subagents_config,
maintenance_config,
session_ttl_hours,
mcp_config,
None,
)
.map(|(session_manager, _)| session_manager)
}
pub fn tools(&self) -> Arc<ToolRegistry> {
self.tools.clone()
}
pub fn store(&self) -> Arc<SessionStore> {
self.store.clone()
}
pub fn show_tool_results(&self) -> bool {
self.show_tool_results
}
pub fn task_repository(&self) -> Arc<dyn TaskRepository> {
self.task_repository.clone()
}
pub fn skills(&self) -> Arc<SkillRuntime> {
self.skills.clone()
}
pub(crate) fn cli_sessions(&self) -> CliSessionService {
self.cli_sessions.clone()
}
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) async fn organize_memory_maintenance_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
self.memory_maintenance.organize_for_scope(scope_key).await
}
pub(crate) async fn run_memory_maintenance_for_all_scopes(
&self,
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
self.memory_maintenance.run_for_all_scopes().await
}
/// 确保 session 存在且未超时,超时则重建
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
self.lifecycle.ensure_session(channel_name).await
}
/// 获取 session不检查超时
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
self.lifecycle.get(channel_name).await
}
/// 获取指定 chat 的当前话题(确保 session 存在,自动从数据库恢复)
pub async fn get_current_topic(&self, channel_name: &str, chat_id: &str) -> Result<Option<String>, AgentError> {
self.ensure_session(channel_name).await?;
if let Some(session) = self.get(channel_name).await {
let mut guard = session.lock().await;
// 如果内存中没有当前话题,从数据库恢复最近活跃的话题
if guard.current_topic(chat_id).is_none() {
let session_id = guard.persistent_session_id(chat_id);
let topics = self.store.list_topics(&session_id)
.map_err(|e| AgentError::Other(format!("Failed to list topics: {}", e)))?;
if let Some(latest_topic) = topics.first() {
// 设置最近活跃的话题为当前话题
guard.set_current_topic(chat_id, Some(latest_topic.id.clone()));
tracing::info!(
chat_id = %chat_id,
topic_id = %latest_topic.id,
topic_title = %latest_topic.title,
"Restored current topic from database"
);
}
}
Ok(guard.current_topic(chat_id).map(|s| s.to_string()))
} else {
Ok(None)
}
}
/// 更新最后活跃时间
pub async fn touch(&self, channel_name: &str) {
self.lifecycle.touch(channel_name).await;
}
pub async fn cleanup_expired_sessions(&self) -> usize {
self.lifecycle.cleanup_expired_sessions().await
}
/// 处理消息:路由到对应 session 的 agent
pub async fn handle_message(
&self,
channel_name: &str,
sender_id: &str,
chat_id: &str,
content: &str,
media: Vec<crate::bus::MediaItem>,
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
) -> Result<Vec<OutboundMessage>, AgentError> {
self.messages
.handle_message(
channel_name,
sender_id,
chat_id,
content,
media,
live_emitter,
)
.await
}
pub async fn run_scheduled_agent_task(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> Result<Vec<OutboundMessage>, AgentError> {
self.scheduled_tasks
.run(channel_name, chat_id, None, prompt, options)
.await
}
/// 执行 SilentAgentTask支持 notification_chat_id 分离
pub async fn run_silent_agent_task(
&self,
channel_name: &str,
session_chat_id: &str,
notification_chat_id: Option<&str>,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> Result<Vec<OutboundMessage>, AgentError> {
self.scheduled_tasks
.run(channel_name, session_chat_id, notification_chat_id, prompt, options)
.await
}
/// 清除指定 session 的所有历史
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
if let Some(session) = self.get(channel_name).await {
let mut session_guard = session.lock().await;
session_guard.clear_all_history()?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::MessageBus;
use crate::bus::message::OutboundEventKind;
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
use crate::storage::MemoryRecord;
use crate::tools::NoopSessionMessageSender;
use axum::http::StatusCode;
use axum::{Json, Router, routing::post};
use serde_json::{Value, json};
use std::collections::{HashMap, HashSet};
use std::sync::{
Arc as StdArc,
atomic::{AtomicUsize, Ordering},
};
use tokio::net::TcpListener;
use tokio::sync::mpsc;
/// Baseline OpenAI-style provider config for tests.
///
/// It points at `http://localhost`, uses deterministic sampling, and has small token and
/// iteration limits.
fn test_provider_config() -> LLMProviderConfig {
    LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "test".to_string(),
        base_url: "http://localhost".to_string(),
        api_key: "test-key".to_string(),
        extra_headers: HashMap::new(),
        llm_timeout_secs: 120,
        memory_maintenance_timeout_secs: 600,
        model_id: "test-model".to_string(),
        temperature: Some(0.0),
        max_tokens: Some(32),
        context_window_tokens: None,
        model_extra: HashMap::new(),
        max_tool_iterations: 1,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
        max_images_in_context: 1,
        max_image_age_rounds: 10,
    }
}
/// The latest-user-message guard must always point at the newest persisted user turn.
#[tokio::test]
async fn test_latest_user_message_guard_tracks_current_turn() {
    const CHAT: &str = "chat-1";
    let store = Arc::new(SessionStore::in_memory().unwrap());
    let skills = Arc::new(SkillRuntime::default());
    let registry = ToolRegistryFactory::new(
        skills.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        Arc::new(NoopSessionMessageSender),
        HashSet::new(),
        "Asia/Shanghai".to_string(),
        HashSet::new(),
        Default::default(),
    )
    .build();
    let (user_tx, _user_rx) = mpsc::channel(4);
    let mut session = Session::new(
        "test-channel".to_string(),
        test_provider_config(),
        user_tx,
        Arc::new(registry),
        skills,
        store,
        100,
    )
    .await
    .unwrap();
    session.ensure_persistent_session(CHAT).unwrap();
    session.ensure_chat_loaded(CHAT).unwrap();

    let earlier = session.create_user_message("first", Vec::new());
    let earlier_id = earlier.id.clone();
    session.append_persisted_message(CHAT, earlier).unwrap();
    assert!(session.is_latest_user_message(CHAT, &earlier_id));

    let newer = session.create_user_message("second", Vec::new());
    let newer_id = newer.id.clone();
    session.append_persisted_message(CHAT, newer).unwrap();
    assert!(!session.is_latest_user_message(CHAT, &earlier_id));
    assert!(session.is_latest_user_message(CHAT, &newer_id));
}
/// After the stored history is compacted and reloaded, the id-based guard no longer finds
/// either user message, but `matches_current_user_turn` must still recognise the newest
/// user turn.
#[tokio::test]
async fn test_current_user_turn_match_survives_history_compaction_reload() {
    const CHAT: &str = "chat-1";
    let store = Arc::new(SessionStore::in_memory().unwrap());
    let skills = Arc::new(SkillRuntime::default());
    let registry = ToolRegistryFactory::new(
        skills.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        Arc::new(NoopSessionMessageSender),
        HashSet::new(),
        "Asia/Shanghai".to_string(),
        HashSet::new(),
        Default::default(),
    )
    .build();
    let (user_tx, _user_rx) = mpsc::channel(4);
    let mut session = Session::new(
        "test-channel".to_string(),
        test_provider_config(),
        user_tx,
        Arc::new(registry),
        skills,
        Arc::clone(&store),
        100,
    )
    .await
    .unwrap();
    session.ensure_persistent_session(CHAT).unwrap();
    session.ensure_chat_loaded(CHAT).unwrap();

    // Two full user/assistant exchanges.
    let first_turn = session.create_user_message("first", Vec::new());
    let first_turn_id = first_turn.id.clone();
    session.append_persisted_message(CHAT, first_turn).unwrap();
    session
        .append_persisted_message(CHAT, ChatMessage::assistant("answer-1"))
        .unwrap();
    let second_turn = session.create_user_message("second", Vec::new());
    session
        .append_persisted_message(CHAT, second_turn.clone())
        .unwrap();
    session
        .append_persisted_message(CHAT, ChatMessage::assistant("answer-2"))
        .unwrap();

    // Compact everything persisted so far, keeping the in-memory messages as the tail.
    let persistent_id = session.persistent_session_id(CHAT);
    let compact_through_seq = store
        .get_session(&persistent_id)
        .unwrap()
        .unwrap()
        .message_count;
    let kept_messages = session.get_history(CHAT).unwrap().clone();
    let summary = ChatMessage::system("[Compressed History]\n\nsummary");
    store
        .compact_active_history(
            &persistent_id,
            compact_through_seq,
            &[],
            &summary,
            &kept_messages,
        )
        .unwrap();
    session.reload_chat_history(CHAT).unwrap();

    assert!(!session.is_latest_user_message(CHAT, &first_turn_id));
    assert!(!session.is_latest_user_message(CHAT, &second_turn.id));
    assert!(session.matches_current_user_turn(CHAT, &second_turn));
}
/// Starts a mock OpenAI-compatible server that replies with `reply from <model>`.
/// Returns its base URL.
async fn start_mock_openai_server() -> String {
    let no_fixed_reply: Option<String> = None;
    start_mock_openai_server_with_content(no_fixed_reply).await
}
async fn start_mock_openai_server_with_content(
mock_response_content: Option<String>,
) -> String {
async fn handle(
axum::extract::State(mock_response_content): axum::extract::State<Option<String>>,
Json(body): Json<Value>,
) -> Json<Value> {
let model = body
.get("model")
.and_then(|value| value.as_str())
.unwrap_or("unknown-model");
let content = mock_response_content.unwrap_or_else(|| format!("reply from {}", model));
Json(json!({
"id": "mock-response",
"model": model,
"choices": [
{
"message": {
"content": content,
"tool_calls": []
}
}
],
"usage": {
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2
}
}))
}
let app = Router::new()
.route("/chat/completions", post(handle))
.with_state(mock_response_content);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let address = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
format!("http://{}", address)
}
/// Starts a mock server whose `/chat/completions` always answers
/// `504 Gateway Timeout` with body `stream timeout`. Returns its base URL.
async fn start_mock_openai_504_server() -> String {
    async fn always_gateway_timeout() -> (StatusCode, &'static str) {
        (StatusCode::GATEWAY_TIMEOUT, "stream timeout")
    }

    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let address = listener.local_addr().unwrap();
    let app = Router::new().route("/chat/completions", post(always_gateway_timeout));
    tokio::spawn(async move {
        axum::serve(listener, app).await.unwrap();
    });
    format!("http://{address}")
}
async fn start_mock_openai_flaky_server(mock_response_content: String) -> String {
let attempts = StdArc::new(AtomicUsize::new(0));
let state = (attempts, mock_response_content);
async fn handle(
axum::extract::State((attempts, mock_response_content)): axum::extract::State<(
StdArc<AtomicUsize>,
String,
)>,
Json(body): Json<Value>,
) -> (StatusCode, Json<Value>) {
let attempt = attempts.fetch_add(1, Ordering::SeqCst);
if attempt == 0 {
return (
StatusCode::GATEWAY_TIMEOUT,
Json(json!({"error": "stream timeout"})),
);
}
let model = body
.get("model")
.and_then(|value| value.as_str())
.unwrap_or("unknown-model");
(
StatusCode::OK,
Json(json!({
"id": "mock-response",
"model": model,
"choices": [
{
"message": {
"content": mock_response_content,
"tool_calls": []
}
}
],
"usage": {
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2
}
})),
)
}
let app = Router::new()
.route("/chat/completions", post(handle))
.with_state(state);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let address = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
format!("http://{}", address)
}
#[tokio::test]
async fn test_handle_message_returns_recoverable_reply_on_llm_504() {
let base_url = start_mock_openai_504_server().await;
let provider_config = LLMProviderConfig {
provider_type: "openai".to_string(),
name: "timeout-provider".to_string(),
base_url: base_url.clone(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
model_id: "timeout-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
llm_timeout_secs: 30,
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
100,
false,
"Asia/Shanghai".to_string(),
provider_config.clone(),
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::SubagentsConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(24),
crate::mcp::McpConfig::default(),
)
.unwrap();
let outbound = session_manager
.handle_message("test-channel", "user-1", "chat-1", "hello", Vec::new(), None)
.await
.unwrap();
assert_eq!(outbound.len(), 1);
assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时"));
}
#[tokio::test]
async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() {
let base_url = start_mock_openai_server().await;
let default_provider = LLMProviderConfig {
provider_type: "openai".to_string(),
name: "default-provider".to_string(),
base_url: base_url.clone(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
model_id: "default-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
llm_timeout_secs: 30,
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let planner_provider = LLMProviderConfig {
model_id: "planner-model".to_string(),
name: "planner-provider".to_string(),
..default_provider.clone()
};
let session_manager = SessionManager::new(
100,
false,
"Asia/Shanghai".to_string(),
default_provider.clone(),
HashMap::from([
("default".to_string(), default_provider),
("planner".to_string(), planner_provider),
]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::SubagentsConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(24),
crate::mcp::McpConfig::default(),
)
.unwrap();
let planner_outbound = session_manager
.run_scheduled_agent_task(
"test-channel",
"chat-planner",
"请规划今天工作",
ScheduledAgentTaskOptions {
agent: Some("planner".to_string()),
..Default::default()
},
)
.await
.unwrap();
assert_eq!(planner_outbound.len(), 1);
assert!(planner_outbound[0].content.contains("planner-model"));
let default_outbound = session_manager
.run_scheduled_agent_task(
"test-channel",
"chat-default",
"请规划今天工作",
ScheduledAgentTaskOptions {
agent: Some("default".to_string()),
..Default::default()
},
)
.await
.unwrap();
assert_eq!(default_outbound.len(), 1);
assert!(default_outbound[0].content.contains("default-model"));
}
#[tokio::test]
async fn test_run_scheduled_agent_task_persists_execution_guard_prompt() {
let base_url = start_mock_openai_server().await;
let provider_config = LLMProviderConfig {
provider_type: "openai".to_string(),
name: "default-provider".to_string(),
base_url,
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
model_id: "default-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
llm_timeout_secs: 30,
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
100,
false,
"Asia/Shanghai".to_string(),
provider_config.clone(),
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::SubagentsConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(24),
crate::mcp::McpConfig::default(),
)
.unwrap();
session_manager
.run_scheduled_agent_task(
"test-channel",
"chat-guard",
"每小时执行以下流程:检查邮箱并同步待办",
ScheduledAgentTaskOptions {
system_prompt: Some("你是邮箱待办同步助手。".to_string()),
..Default::default()
},
)
.await
.unwrap();
let session = session_manager.get("test-channel").await.unwrap();
let session_guard = session.lock().await;
let persisted_messages = session_guard
.store()
.load_messages(&session_guard.persistent_session_id("chat-guard"))
.unwrap();
let scheduled_prompt = persisted_messages
.iter()
.find(|message| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT))
.expect("missing scheduled system prompt");
assert!(scheduled_prompt.content.contains("已经触发的定时任务执行"));
assert!(
scheduled_prompt
.content
.contains("不要调用任何定时任务管理工具")
);
assert!(scheduled_prompt.content.contains("你是邮箱待办同步助手。"));
}
#[tokio::test]
async fn test_summarize_memory_maintenance_for_scope_uses_model_output() {
let mock_response_content = serde_json::to_string(&json!({
"user_facts": ["用户在做AI产品"],
"preferences": ["偏好简洁表达"],
"behavior_patterns": ["习惯先问方案再要代码"],
"merges": [],
"conflicts": [],
"low_value_ids": [],
"managed_markdown": "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达\n\n### 行为模式\n- 习惯先问方案再要代码"
}))
.unwrap();
let base_url =
start_mock_openai_server_with_content(Some(mock_response_content.clone())).await;
let provider_config = LLMProviderConfig {
provider_type: "openai".to_string(),
name: "maintenance-provider".to_string(),
base_url,
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
model_id: "maintenance-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(256),
context_window_tokens: None,
model_extra: HashMap::from([(
"mock_response_content".to_string(),
json!(mock_response_content),
)]),
max_tool_iterations: 1,
llm_timeout_secs: 30,
memory_maintenance_timeout_secs: 600,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
max_images_in_context: 1,
max_image_age_rounds: 10,
};
let session_manager = SessionManager::new(
100,
false,
"Asia/Shanghai".to_string(),
provider_config.clone(),
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::SubagentsConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(24),
crate::mcp::McpConfig::default(),
)
.unwrap();
session_manager
.store()
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: "feishu:user-1".to_string(),
namespace: "user".to_string(),
memory_key: "work".to_string(),
content: "用户在做AI产品".to_string(),
source_type: "message".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.unwrap();
let output = session_manager
.organize_memory_maintenance_for_scope("feishu:user-1")
.await
.unwrap()
.unwrap();
assert!(output.merges.is_empty());
assert!(output.conflicts.is_empty());
assert!(output.low_value_ids.is_empty());
}
/// Transport failures and gateway timeouts count as recoverable; authentication errors
/// do not.
#[test]
fn test_is_recoverable_maintenance_llm_error_detects_transport_failures() {
    let cases = [
        (
            "error sending request for url (https://example.invalid/v1/chat/completions)",
            true,
        ),
        ("API error 504 Gateway Timeout: stream timeout", true),
        ("API error 401 Unauthorized", false),
    ];
    for (error_text, expected_recoverable) in cases {
        assert_eq!(
            is_recoverable_maintenance_llm_error(error_text),
            expected_recoverable,
            "unexpected classification for {error_text:?}"
        );
    }
}
/// A model reply that wraps the JSON payload in prose and a ```json fence must
/// yield just the JSON object after fence stripping plus object extraction.
#[test]
fn test_extract_json_object_skips_wrapping_text() {
    let wrapped = "下面是结果:\n```json\n{\n \"user_facts\": [],\n \"preferences\": []\n}\n```\n请查收";
    let expected = "{\n \"user_facts\": [],\n \"preferences\": []\n}";

    let extracted = extract_json_object(strip_json_code_fence(wrapped))
        .expect("a JSON object should be found inside the fenced reply");

    assert_eq!(extracted, expected);
}
/// When the provider endpoint is unreachable, the organization error must carry
/// enough context (provider, model, URL, timeout, underlying cause) to diagnose it.
#[tokio::test]
async fn test_summarize_memory_maintenance_transport_error_includes_provider_context() {
    // `.invalid` is a reserved TLD, so the HTTP request cannot be delivered.
    let provider_config = LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "maintenance-provider".to_string(),
        base_url: "https://example.invalid/v1".to_string(),
        api_key: "test-key".to_string(),
        extra_headers: HashMap::new(),
        model_id: "maintenance-model".to_string(),
        temperature: Some(0.0),
        max_tokens: Some(256),
        context_window_tokens: None,
        model_extra: HashMap::new(),
        max_tool_iterations: 1,
        llm_timeout_secs: 1,
        memory_maintenance_timeout_secs: 600,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
        max_images_in_context: 1,
        max_image_age_rounds: 10,
    };
    let session_manager = SessionManager::new(
        100,
        false,
        "Asia/Shanghai".to_string(),
        provider_config.clone(),
        HashMap::from([("default".to_string(), provider_config)]),
        Arc::new(SkillRuntime::default()),
        HashSet::new(),
        crate::config::TaskConfig::default(),
        crate::config::SubagentsConfig::default(),
        crate::config::MemoryMaintenanceConfig::default(),
        Some(24),
        crate::mcp::McpConfig::default(),
    )
    .unwrap();

    let seed = crate::storage::MemoryUpsert {
        scope_kind: "user".to_string(),
        scope_key: "feishu:user-1".to_string(),
        namespace: "user".to_string(),
        memory_key: "work".to_string(),
        content: "用户在做AI产品".to_string(),
        source_type: "message".to_string(),
        source_session_id: None,
        source_message_id: None,
        source_message_seq: None,
        source_channel_name: None,
        source_chat_id: None,
    };
    session_manager.store().put_memory(&seed).unwrap();

    let error = session_manager
        .organize_memory_maintenance_for_scope("feishu:user-1")
        .await
        .unwrap_err()
        .to_string();

    let known_prefixes = [
        "memory organization model error: transport error:",
        "memory summary model error: transport error:",
    ];
    let has_known_prefix = known_prefixes
        .iter()
        .any(|prefix| error.contains(*prefix));
    assert!(
        has_known_prefix,
        "Error did not contain expected message: {}",
        error
    );

    // The reported timeout is `memory_maintenance_timeout_secs` (600), not
    // `llm_timeout_secs` (1).
    let required_fragments = [
        "provider=maintenance-provider",
        "model=maintenance-model",
        "url=https://example.invalid/v1/chat/completions",
        "timeout_secs=600",
        "error sending request for url",
    ];
    for fragment in required_fragments {
        assert!(error.contains(fragment), "missing `{fragment}` in: {error}");
    }
}
/// Organization must still succeed when the provider first answers with a
/// recoverable error.
///
/// NOTE(review): assumes `start_mock_openai_flaky_server` fails early requests
/// before serving the configured reply; confirm against the helper.
#[tokio::test]
async fn test_summarize_memory_maintenance_retries_recoverable_provider_errors() {
    let reply = json!({
        "user_facts": ["用户在做AI产品"],
        "preferences": [],
        "behavior_patterns": [],
        "merges": [],
        "conflicts": [],
        "low_value_ids": [],
        "managed_markdown": "### 用户事实\n- 用户在做AI产品"
    });
    let mock_response_content = serde_json::to_string(&reply).unwrap();
    let base_url = start_mock_openai_flaky_server(mock_response_content.clone()).await;

    let model_extra = HashMap::from([(
        "mock_response_content".to_string(),
        json!(mock_response_content),
    )]);
    let provider_config = LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "maintenance-provider".to_string(),
        base_url,
        api_key: "test-key".to_string(),
        extra_headers: HashMap::new(),
        model_id: "maintenance-model".to_string(),
        temperature: Some(0.0),
        max_tokens: Some(256),
        context_window_tokens: None,
        model_extra,
        max_tool_iterations: 1,
        llm_timeout_secs: 30,
        memory_maintenance_timeout_secs: 600,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
        max_images_in_context: 1,
        max_image_age_rounds: 10,
    };
    let session_manager = SessionManager::new(
        100,
        false,
        "Asia/Shanghai".to_string(),
        provider_config.clone(),
        HashMap::from([("default".to_string(), provider_config)]),
        Arc::new(SkillRuntime::default()),
        HashSet::new(),
        crate::config::TaskConfig::default(),
        crate::config::SubagentsConfig::default(),
        crate::config::MemoryMaintenanceConfig::default(),
        Some(24),
        crate::mcp::McpConfig::default(),
    )
    .unwrap();

    let seed = crate::storage::MemoryUpsert {
        scope_kind: "user".to_string(),
        scope_key: "feishu:user-1".to_string(),
        namespace: "user".to_string(),
        memory_key: "work".to_string(),
        content: "用户在做AI产品".to_string(),
        source_type: "message".to_string(),
        source_session_id: None,
        source_message_id: None,
        source_message_seq: None,
        source_channel_name: None,
        source_chat_id: None,
    };
    session_manager.store().put_memory(&seed).unwrap();

    let output = session_manager
        .organize_memory_maintenance_for_scope("feishu:user-1")
        .await
        .expect("organization should succeed after retrying")
        .expect("organization should produce an output");
    assert!(output.merges.is_empty());
}
/// A reply whose JSON is wrapped in prose and a ```json fence must still be parsed
/// into an organization output.
#[tokio::test]
async fn test_summarize_memory_maintenance_for_scope_extracts_wrapped_json_object() {
    let mock_response_content = "结果如下:\n```json\n{\n \"user_facts\": [\"用户在做AI产品\"],\n \"preferences\": [],\n \"behavior_patterns\": [],\n \"merges\": [],\n \"conflicts\": [],\n \"low_value_ids\": [],\n \"managed_markdown\": \"### 用户事实\\n- 用户在做AI产品\"\n}\n```\n";
    let base_url =
        start_mock_openai_server_with_content(Some(mock_response_content.to_string())).await;

    let model_extra = HashMap::from([(
        "mock_response_content".to_string(),
        json!(mock_response_content),
    )]);
    let provider_config = LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "maintenance-provider".to_string(),
        base_url,
        api_key: "test-key".to_string(),
        extra_headers: HashMap::new(),
        model_id: "maintenance-model".to_string(),
        temperature: Some(0.0),
        max_tokens: Some(256),
        context_window_tokens: None,
        model_extra,
        max_tool_iterations: 1,
        llm_timeout_secs: 30,
        memory_maintenance_timeout_secs: 600,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
        max_images_in_context: 1,
        max_image_age_rounds: 10,
    };
    let session_manager = SessionManager::new(
        100,
        false,
        "Asia/Shanghai".to_string(),
        provider_config.clone(),
        HashMap::from([("default".to_string(), provider_config)]),
        Arc::new(SkillRuntime::default()),
        HashSet::new(),
        crate::config::TaskConfig::default(),
        crate::config::SubagentsConfig::default(),
        crate::config::MemoryMaintenanceConfig::default(),
        Some(24),
        crate::mcp::McpConfig::default(),
    )
    .unwrap();

    let seed = crate::storage::MemoryUpsert {
        scope_kind: "user".to_string(),
        scope_key: "feishu:user-1".to_string(),
        namespace: "user".to_string(),
        memory_key: "work".to_string(),
        content: "用户在做AI产品".to_string(),
        source_type: "message".to_string(),
        source_session_id: None,
        source_message_id: None,
        source_message_seq: None,
        source_channel_name: None,
        source_chat_id: None,
    };
    session_manager.store().put_memory(&seed).unwrap();

    let output = session_manager
        .organize_memory_maintenance_for_scope("feishu:user-1")
        .await
        .expect("organization should succeed on a fenced reply")
        .expect("organization should produce an output");
    assert!(output.merges.is_empty());
}
/// A maintenance run over all scopes must produce an aggregated result
/// (`scope_key == "all"`) with non-empty managed markdown.
#[tokio::test]
async fn test_run_memory_maintenance_for_all_scopes_scans_all_scopes_even_without_recent_updates() {
    let mock_response_content = serde_json::to_string(&json!({
        "user_facts": ["用户在做AI产品"],
        "preferences": [],
        "behavior_patterns": [],
        "merges": [],
        "conflicts": [],
        "low_value_ids": [],
        "managed_markdown": "### 用户事实\n- 用户在做AI产品"
    }))
    .unwrap();
    let base_url =
        start_mock_openai_server_with_content(Some(mock_response_content.clone())).await;
    let provider_config = LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "maintenance-provider".to_string(),
        base_url,
        api_key: "test-key".to_string(),
        extra_headers: HashMap::new(),
        model_id: "maintenance-model".to_string(),
        temperature: Some(0.0),
        max_tokens: Some(256),
        context_window_tokens: None,
        model_extra: HashMap::from([(
            "mock_response_content".to_string(),
            json!(mock_response_content),
        )]),
        max_tool_iterations: 1,
        llm_timeout_secs: 30,
        memory_maintenance_timeout_secs: 600,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
        max_images_in_context: 1,
        max_image_age_rounds: 10,
    };
    let session_manager = SessionManager::new(
        100,
        false,
        "Asia/Shanghai".to_string(),
        provider_config.clone(),
        HashMap::from([("default".to_string(), provider_config)]),
        Arc::new(SkillRuntime::default()),
        HashSet::new(),
        crate::config::TaskConfig::default(),
        crate::config::SubagentsConfig::default(),
        crate::config::MemoryMaintenanceConfig::default(),
        Some(24),
        crate::mcp::McpConfig::default(),
    )
    .unwrap();
    session_manager
        .store()
        .put_memory(&crate::storage::MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: "feishu:user-1".to_string(),
            namespace: "user".to_string(),
            memory_key: "work".to_string(),
            content: "用户在做AI产品".to_string(),
            source_type: "message".to_string(),
            source_session_id: None,
            source_message_id: None,
            source_message_seq: None,
            source_channel_name: None,
            source_chat_id: None,
        })
        .unwrap();

    // `expect` replaces the `assert!(is_some())` + `unwrap()` pair and yields a
    // descriptive failure message in one step.
    let result = session_manager
        .run_memory_maintenance_for_all_scopes()
        .await
        .unwrap()
        .expect("maintenance over all scopes should produce a result");
    assert_eq!(result.scope_key, "all");
    // Step 2 needs a new prompt and input format, so only basic behavior is checked here.
    assert!(!result.managed_markdown.is_empty());
}
/// With an unreachable provider (a recoverable transport failure), the all-scopes
/// run must not error out. It reports `None` because no memories were organized.
#[tokio::test]
async fn test_run_memory_maintenance_for_all_scopes_skips_recoverable_transport_failures() {
    let provider_config = LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "maintenance-provider".to_string(),
        base_url: "https://example.invalid/v1".to_string(),
        api_key: "test-key".to_string(),
        extra_headers: HashMap::new(),
        model_id: "maintenance-model".to_string(),
        temperature: Some(0.0),
        max_tokens: Some(256),
        context_window_tokens: None,
        model_extra: HashMap::new(),
        max_tool_iterations: 1,
        llm_timeout_secs: 1,
        memory_maintenance_timeout_secs: 600,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
        max_images_in_context: 1,
        max_image_age_rounds: 10,
    };
    let session_manager = SessionManager::new(
        100,
        false,
        "Asia/Shanghai".to_string(),
        provider_config.clone(),
        HashMap::from([("default".to_string(), provider_config)]),
        Arc::new(SkillRuntime::default()),
        HashSet::new(),
        crate::config::TaskConfig::default(),
        crate::config::SubagentsConfig::default(),
        crate::config::MemoryMaintenanceConfig::default(),
        Some(24),
        crate::mcp::McpConfig::default(),
    )
    .unwrap();

    // Two independent scopes, each with one memory.
    let scopes = ["feishu:user-1", "feishu:user-2"];
    for scope in scopes {
        let upsert = crate::storage::MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: scope.to_string(),
            namespace: "user".to_string(),
            memory_key: "work".to_string(),
            content: format!("{} 在做AI产品", scope),
            source_type: "message".to_string(),
            source_session_id: None,
            source_message_id: None,
            source_message_seq: None,
            source_channel_name: None,
            source_chat_id: None,
        };
        session_manager.store().put_memory(&upsert).unwrap();
    }

    let outcome = session_manager
        .run_memory_maintenance_for_all_scopes()
        .await
        .unwrap();
    // After a recoverable error nothing was organized, so the result is None.
    assert!(outcome.is_none());
}
/// Applying an organization output must merge the listed sources into one record
/// and delete the records flagged as low value.
#[test]
fn test_apply_memory_maintenance_output_merges_and_deletes_low_value_records() {
    let store = SessionStore::in_memory().unwrap();
    let scope_key = "feishu:user-1";

    // Inserts one message-sourced memory for `scope_key` and returns the stored record.
    let put = |namespace: &str, memory_key: &str, content: &str| {
        store
            .put_memory(&crate::storage::MemoryUpsert {
                scope_kind: "user".to_string(),
                scope_key: scope_key.to_string(),
                namespace: namespace.to_string(),
                memory_key: memory_key.to_string(),
                content: content.to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
            })
            .unwrap()
    };

    // Seed 7 memories so the merge passes the protection limits: at least
    // `min_memories_to_keep` (5 in the default config) must survive the merge.
    let work = put("user", "work_short", "用户在做AI产品");
    let role = put("user", "work_detail", "用户主要在做AI产品设计和实现");
    let noise = put("other", "temporary", "今天临时提到过一个无后续的小细节");
    // Filler memories that satisfy min_memories_to_keep = 5.
    for i in 0..4 {
        put("user", &format!("extra_{}", i), &format!("额外记忆 {}", i));
    }

    let plan = build_memory_maintenance_plan(
        &store.list_memories_for_scope("user", scope_key).unwrap(),
    );
    assert_eq!(plan.candidates.len(), 7);

    let output = MemoryOrganizationOutput {
        merges: vec![MemoryMaintenanceMerge {
            source_ids: vec![work.id.clone(), role.id.clone()],
            namespace: "user".to_string(),
            memory_key: "work".to_string(),
            content: "用户主要在做AI产品设计与实现".to_string(),
        }],
        conflicts: Vec::new(),
        low_value_ids: vec![noise.id.clone()],
    };

    // Validate against the default protection settings.
    let config = crate::config::MemoryMaintenanceConfig::default();
    apply_memory_maintenance_output(
        &store,
        scope_key,
        &plan,
        &output,
        config.max_merge_ratio,
        config.min_memories_to_keep,
        config.max_merge_per_group,
    )
    .unwrap();

    let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
    // Ignore bookkeeping records stored under the `_meta` namespace.
    let user_memories: Vec<_> = all_memories
        .iter()
        .filter(|m| m.namespace != "_meta")
        .collect();
    // Merge 2 records into 1 and delete 1 low-value record: 7 - 2 + 1 - 1 = 5.
    assert_eq!(user_memories.len(), 5);
    // The merged record exists.
    assert!(user_memories
        .iter()
        .any(|m| m.namespace == "user" && m.memory_key == "work"));
    // The low-value record was deleted, as the test name promises.
    assert!(
        !user_memories.iter().any(|m| m.id == noise.id),
        "low-value memory should have been deleted"
    );
}
/// With the first argument `false`, completed tool results are hidden and tool
/// messages pending user action stay visible. With `true`, completed results are
/// shown as well.
#[test]
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
    let completed = ChatMessage::tool("call-1", "calculator", "2");
    let pending = ChatMessage::tool_with_state(
        "call-2",
        "bash",
        "waiting",
        crate::bus::message::ToolMessageState::PendingUserAction,
    );

    // (first argument, message, expected visibility)
    let cases = [
        (false, &completed, false),
        (false, &pending, true),
        (true, &completed, true),
    ];
    for (flag, message, expected) in cases {
        assert_eq!(should_display_message_to_user(flag, message), expected);
    }
}
/// A completed tool result handed to the emitter must be published on the
/// outbound bus as a `ToolResult` event.
#[tokio::test]
async fn test_bus_tool_call_emitter_emits_completed_tool_results() {
    let bus = MessageBus::new(4);
    let emitter = BusToolCallEmitter::new(bus.clone(), "test-channel", "chat-1", HashMap::new());

    let tool_result = ChatMessage::tool("call-1", "calculator", "2");
    emitter.handle(tool_result).await;

    let wait_limit = std::time::Duration::from_millis(500);
    let outbound = tokio::time::timeout(wait_limit, bus.consume_outbound())
        .await
        .expect("timeout waiting for outbound message")
        .expect("bus outbound closed");
    assert_eq!(outbound.event_kind, OutboundEventKind::ToolResult);
}
/// Loading a fresh chat leaves its persisted history empty.
///
/// NOTE: the test name predates a redesign. The system prompt is no longer
/// persisted into history; it is injected dynamically on every request.
#[tokio::test]
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
    let store = Arc::new(SessionStore::in_memory().unwrap());
    let skills = Arc::new(SkillRuntime::default());
    let registry = ToolRegistryFactory::new(
        skills.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        Arc::new(NoopSessionMessageSender),
        HashSet::new(),
        "Asia/Shanghai".to_string(),
        HashSet::new(),
        Default::default(),
    )
    .build();
    let tools = Arc::new(registry);
    let (user_tx, _user_rx) = mpsc::channel(4);

    let mut session = Session::new(
        "test-channel".to_string(),
        test_provider_config(),
        user_tx,
        tools,
        skills,
        store.clone(),
        100,
    )
    .await
    .unwrap();

    let chat_id = "chat-1";
    session.ensure_persistent_session(chat_id).unwrap();
    session.ensure_chat_loaded(chat_id).unwrap();

    assert_eq!(session.get_history(chat_id).unwrap().len(), 0);
}
/// After 100 user turns, ensuring the agent prompt must not add anything to the
/// persisted history. It also leaves the stored reinjection counter at 0, because
/// that counter only changes during real request processing.
#[tokio::test]
async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() {
    let store = Arc::new(SessionStore::in_memory().unwrap());
    let skills = Arc::new(SkillRuntime::default());
    let registry = ToolRegistryFactory::new(
        skills.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        Arc::new(NoopSessionMessageSender),
        HashSet::new(),
        "Asia/Shanghai".to_string(),
        HashSet::new(),
        Default::default(),
    )
    .build();
    let tools = Arc::new(registry);
    let (user_tx, _user_rx) = mpsc::channel(4);

    let mut session = Session::new(
        "test-channel".to_string(),
        test_provider_config(),
        user_tx,
        tools,
        skills,
        store.clone(),
        100,
    )
    .await
    .unwrap();

    let chat_id = "chat-1";
    session.ensure_persistent_session(chat_id).unwrap();
    session.ensure_chat_loaded(chat_id).unwrap();
    for i in 0..100 {
        let message = ChatMessage::user(format!("user-{i}"));
        session.append_persisted_message(chat_id, message).unwrap();
    }

    session
        .ensure_agent_prompt_before_user_message(chat_id)
        .unwrap();
    // System prompts are no longer persisted, so only the 100 user turns remain.
    let user_count = session
        .get_history(chat_id)
        .unwrap()
        .iter()
        .filter(|m| m.role == "user")
        .count();
    assert_eq!(user_count, 100);

    // The reinjection counter is updated by AgentPromptProvider while a real
    // request is processed. The call above only simulates that, so the stored
    // value is still its initial 0.
    let stored = store
        .get_session(&session.persistent_session_id(chat_id))
        .unwrap()
        .unwrap();
    assert_eq!(stored.agent_prompt_reinjection_count, 0);

    session
        .ensure_agent_prompt_before_user_message(chat_id)
        .unwrap();
    let user_count = session
        .get_history(chat_id)
        .unwrap()
        .iter()
        .filter(|m| m.role == "user")
        .count();
    assert_eq!(user_count, 100);
}
/// With the final `Session::new` argument set to 0, ensuring the agent prompt after
/// 100 user turns leaves the history holding exactly those 100 user messages.
///
/// NOTE(review): 0 is presumably the "reinjection disabled" setting, per the test
/// name; confirm against `Session::new`.
#[tokio::test]
async fn test_agent_prompt_reinjection_can_be_disabled_by_config() {
    let store = Arc::new(SessionStore::in_memory().unwrap());
    let skills = Arc::new(SkillRuntime::default());
    let registry = ToolRegistryFactory::new(
        skills.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        Arc::new(NoopSessionMessageSender),
        HashSet::new(),
        "Asia/Shanghai".to_string(),
        HashSet::new(),
        Default::default(),
    )
    .build();
    let tools = Arc::new(registry);
    let (user_tx, _user_rx) = mpsc::channel(4);

    let mut session = Session::new(
        "test-channel".to_string(),
        test_provider_config(),
        user_tx,
        tools,
        skills,
        store.clone(),
        0,
    )
    .await
    .unwrap();

    let chat_id = "chat-1";
    session.ensure_persistent_session(chat_id).unwrap();
    session.ensure_chat_loaded(chat_id).unwrap();
    for i in 0..100 {
        let message = ChatMessage::user(format!("user-{i}"));
        session.append_persisted_message(chat_id, message).unwrap();
    }

    session
        .ensure_agent_prompt_before_user_message(chat_id)
        .unwrap();
    // System prompts are no longer persisted, so only the user turns are counted.
    let user_count = session
        .get_history(chat_id)
        .unwrap()
        .iter()
        .filter(|m| m.role == "user")
        .count();
    assert_eq!(user_count, 100);
}
/// The default tool registry must expose the `get_time` tool.
#[test]
fn test_default_tools_registers_get_time() {
    let store = Arc::new(SessionStore::in_memory().unwrap());
    let skills = Arc::new(SkillRuntime::default());
    let registry = ToolRegistryFactory::new(
        skills,
        store.clone(),
        store.clone(),
        store,
        Arc::new(NoopSessionMessageSender),
        HashSet::new(),
        "Asia/Shanghai".to_string(),
        HashSet::new(),
        Default::default(),
    )
    .build();

    let names = registry.tool_names();
    assert!(names.iter().any(|name| name == "get_time"));
}
/// Records with identical namespace/key/content collapse into one candidate, and
/// every distinct memory content survives into the plan.
#[test]
fn test_build_memory_maintenance_plan_deduplicates_and_categorizes() {
    // Builds a message-sourced user-scope record; `ts` fills both timestamps.
    let record = |id: &str, namespace: &str, memory_key: &str, content: &str, ts| MemoryRecord {
        id: id.to_string(),
        scope_kind: "user".to_string(),
        scope_key: "feishu:user-1".to_string(),
        namespace: namespace.to_string(),
        memory_key: memory_key.to_string(),
        content: content.to_string(),
        source_type: "message".to_string(),
        source_session_id: None,
        source_message_id: None,
        source_message_seq: None,
        source_channel_name: None,
        source_chat_id: None,
        created_at: ts,
        updated_at: ts,
    };

    let memories = vec![
        record("1", "user", "work", "用户在做AI产品", 1),
        record("2", "user", "work", "用户在做AI产品", 2),
        record("3", "user", "style", "偏好简洁表达", 3),
        record("4", "patterns", "workflow", "习惯先问方案再要代码", 4),
    ];

    let plan = build_memory_maintenance_plan(&memories);
    // Records 1 and 2 are duplicates, leaving 3 candidates.
    assert_eq!(plan.candidates.len(), 3);

    // Every unique content must still be represented.
    let contents: Vec<String> = plan
        .candidates
        .iter()
        .map(|candidate| candidate.content.clone())
        .collect();
    for expected in ["用户在做AI产品", "偏好简洁表达", "习惯先问方案再要代码"] {
        assert!(contents.contains(&expected.to_string()));
    }
}
}
#[async_trait]
impl crate::scheduler::MaintenanceExecutor for SessionManager {
async fn cleanup_expired_sessions(&self) -> usize {
self.cleanup_expired_sessions().await
}
async fn run_memory_maintenance_for_all_scopes(
&self,
) -> anyhow::Result<Vec<crate::scheduler::MaintenanceRunSummary>> {
match self.run_memory_maintenance_for_all_scopes().await {
Ok(Some(result)) => Ok(vec![crate::scheduler::MaintenanceRunSummary {
scope_key: result.scope_key,
merges: result.output.merges.len(),
conflicts: result.output.conflicts.len(),
low_value: result.output.low_value_ids.len(),
}]),
Ok(None) => Ok(vec![]),
Err(error) => Err(anyhow::anyhow!(error.to_string())),
}
}
}