use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler}; #[cfg(test)] use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT; use crate::bus::{ChatMessage, MessageBus, OutboundMessage}; use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; use crate::scheduler::ScheduledAgentTaskOptions; use crate::skills::SkillRuntime; use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository}; use crate::tools::ToolRegistry; use crate::tools::task::repository::TaskRepository; use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; use super::agent_factory::{AgentBuildRequest, AgentFactory}; use super::cli_session::CliSessionService; #[cfg(test)] use super::execution::should_display_message_to_user; #[cfg(test)] use super::memory_maintenance::{ MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan, extract_json_object, is_recoverable_maintenance_llm_error, strip_json_code_fence, }; use super::memory_maintenance::{MemoryMaintenanceScopeResult, MemoryOrganizationOutput}; use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator; use super::scheduled_agent_task_service::ScheduledAgentTaskService; use super::session_history::SessionHistory; use super::session_lifecycle::SessionLifecycleService; use super::session_message_service::SessionMessageService; /// Session 按 channel 隔离,每个 channel 一个 Session /// History 按 chat_id 隔离,由 Session 统一管理 /// Topic 按 chat_id 隔离,存储在 SessionHistory 中 pub struct Session { pub id: Uuid, pub channel_name: String, pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, skills: Arc, agent_factory: AgentFactory, compressor: ContextCompressor, history: SessionHistory, store: Arc, } pub struct BusToolCallEmitter { bus: Arc, channel_name: String, chat_id: String, metadata: HashMap, } impl BusToolCallEmitter { pub fn new( bus: Arc, channel_name: impl Into, chat_id: impl Into, metadata: HashMap, ) -> Self { Self { bus, channel_name: channel_name.into(), chat_id: chat_id.into(), metadata, } } } #[async_trait] impl EmittedMessageHandler for BusToolCallEmitter { async fn handle(&self, message: ChatMessage) { for outbound in OutboundMessage::from_chat_message( &self.channel_name, &self.chat_id, None, // session_id None, &self.metadata, &message, ) { if let Err(error) = self.bus.publish_outbound(outbound).await { tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call"); } } } async fn handle_tool_result(&self, message: ChatMessage, duration_ms: Option) { let mut metadata = self.metadata.clone(); if let Some(ms) = duration_ms { metadata.insert("tool_duration_ms".to_string(), ms.to_string()); } for outbound in OutboundMessage::from_chat_message( &self.channel_name, &self.chat_id, None, // session_id None, &metadata, &message, ) { if let Err(error) = self.bus.publish_outbound(outbound).await { tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call"); } } } } impl Session { pub async fn new( channel_name: String, provider_config: LLMProviderConfig, user_tx: mpsc::Sender, tools: Arc, skills: Arc, store: Arc, agent_prompt_reinject_every: u64, ) -> Result { let conversations: Arc = store.clone(); let skill_events: Arc = store.clone(); let prompt_repository: Arc = store.clone(); let agent_factory = AgentFactory::new( tools, skills.clone(), agent_prompt_reinject_every as usize, prompt_repository.clone(), ); Self::with_factories( channel_name, provider_config, user_tx, skills, agent_factory, conversations, skill_events, store, ) .await } pub(crate) async fn with_factories( channel_name: String, provider_config: LLMProviderConfig, user_tx: mpsc::Sender, skills: Arc, agent_factory: AgentFactory, conversations: Arc, skill_events: Arc, store: Arc, ) -> Result { Ok(Self { id: Uuid::new_v4(), channel_name: channel_name.clone(), user_tx, provider_config: provider_config.clone(), skills, agent_factory, compressor: ContextCompressor::from_provider_config(&provider_config), history: SessionHistory::new( channel_name, conversations, skill_events, ), store, }) } pub fn persistent_session_id(&self, chat_id: &str) -> String { self.history.persistent_session_id(chat_id) } /// 设置当前话题 ID(指定 chat) pub fn set_current_topic(&mut self, chat_id: &str, topic_id: Option) { if let Some(topic_id) = topic_id { self.history.set_chat_topic(chat_id, topic_id); } else { self.history.clear_chat_topic(chat_id); } } /// 获取当前话题 ID(指定 chat) pub fn current_topic(&self, chat_id: &str) -> Option<&str> { self.history.chat_topic(chat_id) } /// 获取历史所对应的话题 ID(指定 chat) pub fn history_topic(&self, chat_id: &str) -> Option<&str> { self.history.history_topic(chat_id) } /// 切换话题 - 清除当前历史并加载新话题的历史 pub fn switch_topic(&mut self, chat_id: &str, topic_id: &str) -> Result<(), AgentError> { // 清除当前历史 self.history.remove_history(chat_id); // 先设置当前话题(set_history 需要这个) self.history.set_chat_topic(chat_id, topic_id.to_string()); // 加载新话题的历史 let messages = self .store .load_messages_for_topic(topic_id) .map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?; self.history.set_history(chat_id, messages); tracing::info!( topic_id = %topic_id, chat_id = %chat_id, "Switched to topic" ); Ok(()) } pub fn ensure_persistent_session(&self, chat_id: &str) -> Result { self.history.ensure_persistent_session(chat_id) } pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> { // 检查历史是否存在且对应正确的话题 // 先获取 topic 信息并转换为 owned String,避免借用冲突 let current_topic: Option = self.history.chat_topic(chat_id).map(|s| s.to_string()); let stored_topic = self.history.history_topic(chat_id); if self.chat_history_exists(chat_id) { // 如果历史已存在,但话题不匹配,需要重新加载 if current_topic.as_deref() != stored_topic { tracing::info!( chat_id = %chat_id, current_topic = ?current_topic, stored_topic = ?stored_topic, "Topic changed, reloading history" ); self.reload_chat_history(chat_id)?; } return Ok(()); } // 历史不存在,按 topic 加载(如果设置了 topic) self.history.ensure_chat_loaded(chat_id, current_topic.as_deref()) } fn chat_history_exists(&self, chat_id: &str) -> bool { self.history.get_history(chat_id).is_some() } pub fn ensure_agent_prompt_before_user_message( &mut self, chat_id: &str, ) -> Result<(), AgentError> { self.history .ensure_agent_prompt_before_user_message(chat_id) } /// 获取或创建指定 chat_id 的会话历史 pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec { self.history.get_or_create_history(chat_id) } /// 获取指定 chat_id 的会话历史(不创建) pub fn get_history(&self, chat_id: &str) -> Option<&Vec> { self.history.get_history(chat_id) } /// 使用完整消息追加到历史 pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) { self.history.add_message(chat_id, message); } pub fn remove_history(&mut self, chat_id: &str) { self.history.remove_history(chat_id); } pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { self.history.clear_chat_history(chat_id) } /// 将消息写入内存与持久化层(使用当前 topic) pub fn append_persisted_message( &mut self, chat_id: &str, message: ChatMessage, ) -> Result<(), AgentError> { let session_id = self.persistent_session_id(chat_id); let topic_id = self.history.chat_topic(chat_id).map(|s| s.to_string()); self.store .append_message_with_topic(&session_id, topic_id.as_deref(), &message) .map_err(|err| { AgentError::Other(format!("append message persistence error: {}", err)) })?; self.add_message(chat_id, message); // 更新 topic 的最后活跃时间 if let Some(ref topic_id) = topic_id { if let Err(e) = self.store.touch_topic(topic_id) { tracing::warn!(error = %e, topic_id = %topic_id, "Failed to touch topic"); } } Ok(()) } pub fn append_persisted_messages( &mut self, chat_id: &str, messages: I, ) -> Result<(), AgentError> where I: IntoIterator, { self.history.append_persisted_messages(chat_id, messages) } /// 将消息保存到指定话题(直接写入数据库,不更新内存历史) pub fn append_messages_to_topic( &self, chat_id: &str, topic_id: &str, messages: &[ChatMessage], ) -> Result<(), AgentError> { self.history.append_to_topic(chat_id, topic_id, messages) } pub fn create_user_message(&self, content: &str, media_refs: Vec) -> ChatMessage { if media_refs.is_empty() { ChatMessage::user(content) } else { ChatMessage::user_with_media(content, media_refs) } } #[cfg(test)] fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> { self.latest_user_message(chat_id) .map(|message| message.id.as_str()) } #[cfg(test)] fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> { self.history.latest_user_message(chat_id) } #[cfg(test)] fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool { self.latest_user_message_id(chat_id) .map(|current_id| current_id == message_id) .unwrap_or(false) } pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool { self.history.matches_current_user_turn(chat_id, message) } pub(crate) fn stale_result_diagnostics( &self, chat_id: &str, ) -> (Option<&str>, Option, bool, usize) { self.history.stale_result_diagnostics(chat_id) } /// 清除所有历史 pub fn clear_all_history(&mut self) -> Result<(), AgentError> { self.history.clear_all_history() } pub async fn send(&self, msg: WsOutbound) { let _ = self.user_tx.send(msg).await; } /// 获取 provider_config 引用 pub fn provider_config(&self) -> &LLMProviderConfig { &self.provider_config } /// 获取 compressor 引用 pub fn compressor(&self) -> &ContextCompressor { &self.compressor } pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool { self.history.try_start_background_compaction(chat_id) } pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) { self.history.finish_background_compaction(chat_id); } pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { // 如果当前有 topic,加载该 topic 的消息 if let Some(topic_id) = self.history.chat_topic(chat_id) { let messages = self .store .load_messages_for_topic(topic_id) .map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?; self.history.set_history(chat_id, messages); } else { // 否则加载 session 的所有消息 self.history.reload_chat_history(chat_id)?; } Ok(()) } pub(crate) fn store(&self) -> Arc { self.history.conversations() } pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> { if self.skills.is_empty() { return Ok(()); } self.history.append_skill_event( chat_id, "offered", None, &self.skills.offered_event_payload(), ) } /// 创建一个临时的 AgentLoop 实例来处理消息 pub fn create_agent( &self, chat_id: &str, sender_id: Option<&str>, message_id: Option<&str>, ) -> Result { self.create_agent_with_provider_config( chat_id, None, // notification_chat_id = None,使用 session_chat_id sender_id, message_id, self.provider_config.clone(), ) } pub fn create_agent_with_provider_config( &self, session_chat_id: &str, notification_chat_id: Option<&str>, sender_id: Option<&str>, message_id: Option<&str>, provider_config: LLMProviderConfig, ) -> Result { self.agent_factory.create(AgentBuildRequest { channel_name: &self.channel_name, session_chat_id, notification_chat_id, sender_id, message_id, provider_config, }) } } /// SessionManager 管理所有 Session,按 channel_name 路由 #[derive(Clone)] pub struct SessionManager { tools: Arc, skills: Arc, store: Arc, show_tool_results: bool, lifecycle: SessionLifecycleService, cli_sessions: CliSessionService, messages: SessionMessageService, scheduled_tasks: ScheduledAgentTaskService, memory_maintenance: MemoryMaintenanceCoordinator, task_repository: Arc, } pub(crate) struct SessionManagerServices { pub(crate) tools: Arc, pub(crate) skills: Arc, pub(crate) store: Arc, pub(crate) show_tool_results: bool, pub(crate) lifecycle: SessionLifecycleService, pub(crate) cli_sessions: CliSessionService, pub(crate) messages: SessionMessageService, pub(crate) scheduled_tasks: ScheduledAgentTaskService, pub(crate) memory_maintenance: MemoryMaintenanceCoordinator, pub(crate) task_repository: Arc, } impl SessionManager { pub(crate) fn from_services(services: SessionManagerServices) -> Self { Self { tools: services.tools, skills: services.skills, store: services.store, show_tool_results: services.show_tool_results, lifecycle: services.lifecycle, cli_sessions: services.cli_sessions, messages: services.messages, scheduled_tasks: services.scheduled_tasks, memory_maintenance: services.memory_maintenance, task_repository: services.task_repository, } } pub fn new( agent_prompt_reinject_every: u64, show_tool_results: bool, default_timezone: String, provider_config: LLMProviderConfig, provider_configs: HashMap, skills: Arc, disabled_tools: std::collections::HashSet, task_config: crate::config::TaskConfig, subagents_config: crate::config::SubagentsConfig, maintenance_config: crate::config::MemoryMaintenanceConfig, session_ttl_hours: Option, mcp_config: crate::mcp::McpConfig, ) -> Result { super::runtime::build_session_manager( agent_prompt_reinject_every, show_tool_results, default_timezone, provider_config, provider_configs, skills, disabled_tools, task_config, subagents_config, maintenance_config, session_ttl_hours, mcp_config, None, ) .map(|(session_manager, _)| session_manager) } pub fn tools(&self) -> Arc { self.tools.clone() } pub fn store(&self) -> Arc { self.store.clone() } pub fn show_tool_results(&self) -> bool { self.show_tool_results } pub fn task_repository(&self) -> Arc { self.task_repository.clone() } pub fn skills(&self) -> Arc { self.skills.clone() } pub(crate) fn cli_sessions(&self) -> CliSessionService { self.cli_sessions.clone() } #[cfg_attr(not(test), allow(dead_code))] pub(crate) async fn organize_memory_maintenance_for_scope( &self, scope_key: &str, ) -> Result, AgentError> { self.memory_maintenance.organize_for_scope(scope_key).await } pub(crate) async fn run_memory_maintenance_for_all_scopes( &self, ) -> Result, AgentError> { self.memory_maintenance.run_for_all_scopes().await } /// 确保 session 存在且未超时,超时则重建 pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { self.lifecycle.ensure_session(channel_name).await } /// 获取 session(不检查超时) pub async fn get(&self, channel_name: &str) -> Option>> { self.lifecycle.get(channel_name).await } /// 获取指定 chat 的当前话题(确保 session 存在,自动从数据库恢复) pub async fn get_current_topic(&self, channel_name: &str, chat_id: &str) -> Result, AgentError> { self.ensure_session(channel_name).await?; if let Some(session) = self.get(channel_name).await { let mut guard = session.lock().await; // 如果内存中没有当前话题,从数据库恢复最近活跃的话题 if guard.current_topic(chat_id).is_none() { let session_id = guard.persistent_session_id(chat_id); let topics = self.store.list_topics(&session_id) .map_err(|e| AgentError::Other(format!("Failed to list topics: {}", e)))?; if let Some(latest_topic) = topics.first() { // 设置最近活跃的话题为当前话题 guard.set_current_topic(chat_id, Some(latest_topic.id.clone())); tracing::info!( chat_id = %chat_id, topic_id = %latest_topic.id, topic_title = %latest_topic.title, "Restored current topic from database" ); } } Ok(guard.current_topic(chat_id).map(|s| s.to_string())) } else { Ok(None) } } /// 更新最后活跃时间 pub async fn touch(&self, channel_name: &str) { self.lifecycle.touch(channel_name).await; } pub async fn cleanup_expired_sessions(&self) -> usize { self.lifecycle.cleanup_expired_sessions().await } /// 处理消息:路由到对应 session 的 agent pub async fn handle_message( &self, channel_name: &str, sender_id: &str, chat_id: &str, content: &str, media: Vec, live_emitter: Option>, ) -> Result, AgentError> { self.messages .handle_message( channel_name, sender_id, chat_id, content, media, live_emitter, ) .await } pub async fn run_scheduled_agent_task( &self, channel_name: &str, chat_id: &str, prompt: &str, options: ScheduledAgentTaskOptions, ) -> Result, AgentError> { self.scheduled_tasks .run(channel_name, chat_id, None, prompt, options) .await } /// 执行 SilentAgentTask,支持 notification_chat_id 分离 pub async fn run_silent_agent_task( &self, channel_name: &str, session_chat_id: &str, notification_chat_id: Option<&str>, prompt: &str, options: ScheduledAgentTaskOptions, ) -> Result, AgentError> { self.scheduled_tasks .run(channel_name, session_chat_id, notification_chat_id, prompt, options) .await } /// 清除指定 session 的所有历史 pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> { if let Some(session) = self.get(channel_name).await { let mut session_guard = session.lock().await; session_guard.clear_all_history()?; } Ok(()) } } #[cfg(test)] mod tests { use super::*; use crate::bus::MessageBus; use crate::bus::message::OutboundEventKind; use crate::gateway::tool_registry_factory::ToolRegistryFactory; use crate::storage::MemoryRecord; use crate::tools::NoopSessionMessageSender; use axum::http::StatusCode; use axum::{Json, Router, routing::post}; use serde_json::{Value, json}; use std::collections::{HashMap, HashSet}; use std::sync::{ Arc as StdArc, atomic::{AtomicUsize, Ordering}, }; use tokio::net::TcpListener; use tokio::sync::mpsc; fn test_provider_config() -> LLMProviderConfig { LLMProviderConfig { provider_type: "openai".to_string(), name: "test".to_string(), base_url: "http://localhost".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), llm_timeout_secs: 120, memory_maintenance_timeout_secs: 600, model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, } } #[tokio::test] async fn test_latest_user_message_guard_tracks_current_turn() { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), store.clone(), store.clone(), store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), HashSet::new(), Default::default(), ) .build(), ); let mut session = Session::new( "test-channel".to_string(), test_provider_config(), user_tx, tools, skills, store, 100, ) .await .unwrap(); session.ensure_persistent_session("chat-1").unwrap(); session.ensure_chat_loaded("chat-1").unwrap(); let first = session.create_user_message("first", Vec::new()); let first_id = first.id.clone(); session.append_persisted_message("chat-1", first).unwrap(); assert!(session.is_latest_user_message("chat-1", &first_id)); let second = session.create_user_message("second", Vec::new()); let second_id = second.id.clone(); session.append_persisted_message("chat-1", second).unwrap(); assert!(!session.is_latest_user_message("chat-1", &first_id)); assert!(session.is_latest_user_message("chat-1", &second_id)); } #[tokio::test] async fn test_current_user_turn_match_survives_history_compaction_reload() { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), store.clone(), store.clone(), store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), HashSet::new(), Default::default(), ) .build(), ); let mut session = Session::new( "test-channel".to_string(), test_provider_config(), user_tx, tools, skills, store.clone(), 100, ) .await .unwrap(); session.ensure_persistent_session("chat-1").unwrap(); session.ensure_chat_loaded("chat-1").unwrap(); let first = session.create_user_message("first", Vec::new()); let first_id = first.id.clone(); session.append_persisted_message("chat-1", first).unwrap(); session .append_persisted_message("chat-1", ChatMessage::assistant("answer-1")) .unwrap(); let second = session.create_user_message("second", Vec::new()); session .append_persisted_message("chat-1", second.clone()) .unwrap(); session .append_persisted_message("chat-1", ChatMessage::assistant("answer-2")) .unwrap(); let session_id = session.persistent_session_id("chat-1"); let snapshot_end_seq = store .get_session(&session_id) .unwrap() .unwrap() .message_count; let preserved_messages = session.get_history("chat-1").unwrap().clone(); store .compact_active_history( &session_id, snapshot_end_seq, &[], &ChatMessage::system("[Compressed History]\n\nsummary"), &preserved_messages, ) .unwrap(); session.reload_chat_history("chat-1").unwrap(); assert!(!session.is_latest_user_message("chat-1", &first_id)); assert!(!session.is_latest_user_message("chat-1", &second.id)); assert!(session.matches_current_user_turn("chat-1", &second)); } async fn start_mock_openai_server() -> String { start_mock_openai_server_with_content(None).await } async fn start_mock_openai_server_with_content( mock_response_content: Option, ) -> String { async fn handle( axum::extract::State(mock_response_content): axum::extract::State>, Json(body): Json, ) -> Json { let model = body .get("model") .and_then(|value| value.as_str()) .unwrap_or("unknown-model"); let content = mock_response_content.unwrap_or_else(|| format!("reply from {}", model)); Json(json!({ "id": "mock-response", "model": model, "choices": [ { "message": { "content": content, "tool_calls": [] } } ], "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 } })) } let app = Router::new() .route("/chat/completions", post(handle)) .with_state(mock_response_content); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); format!("http://{}", address) } async fn start_mock_openai_504_server() -> String { async fn handle() -> (StatusCode, &'static str) { (StatusCode::GATEWAY_TIMEOUT, "stream timeout") } let app = Router::new().route("/chat/completions", post(handle)); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); format!("http://{}", address) } async fn start_mock_openai_flaky_server(mock_response_content: String) -> String { let attempts = StdArc::new(AtomicUsize::new(0)); let state = (attempts, mock_response_content); async fn handle( axum::extract::State((attempts, mock_response_content)): axum::extract::State<( StdArc, String, )>, Json(body): Json, ) -> (StatusCode, Json) { let attempt = attempts.fetch_add(1, Ordering::SeqCst); if attempt == 0 { return ( StatusCode::GATEWAY_TIMEOUT, Json(json!({"error": "stream timeout"})), ); } let model = body .get("model") .and_then(|value| value.as_str()) .unwrap_or("unknown-model"); ( StatusCode::OK, Json(json!({ "id": "mock-response", "model": model, "choices": [ { "message": { "content": mock_response_content, "tool_calls": [] } } ], "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 } })), ) } let app = Router::new() .route("/chat/completions", post(handle)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); format!("http://{}", address) } #[tokio::test] async fn test_handle_message_returns_recoverable_reply_on_llm_504() { let base_url = start_mock_openai_504_server().await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "timeout-provider".to_string(), base_url: base_url.clone(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "timeout-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); let outbound = session_manager .handle_message("test-channel", "user-1", "chat-1", "hello", Vec::new(), None) .await .unwrap(); assert_eq!(outbound.len(), 1); assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时")); } #[tokio::test] async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() { let base_url = start_mock_openai_server().await; let default_provider = LLMProviderConfig { provider_type: "openai".to_string(), name: "default-provider".to_string(), base_url: base_url.clone(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "default-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let planner_provider = LLMProviderConfig { model_id: "planner-model".to_string(), name: "planner-provider".to_string(), ..default_provider.clone() }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), default_provider.clone(), HashMap::from([ ("default".to_string(), default_provider), ("planner".to_string(), planner_provider), ]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); let planner_outbound = session_manager .run_scheduled_agent_task( "test-channel", "chat-planner", "请规划今天工作", ScheduledAgentTaskOptions { agent: Some("planner".to_string()), ..Default::default() }, ) .await .unwrap(); assert_eq!(planner_outbound.len(), 1); assert!(planner_outbound[0].content.contains("planner-model")); let default_outbound = session_manager .run_scheduled_agent_task( "test-channel", "chat-default", "请规划今天工作", ScheduledAgentTaskOptions { agent: Some("default".to_string()), ..Default::default() }, ) .await .unwrap(); assert_eq!(default_outbound.len(), 1); assert!(default_outbound[0].content.contains("default-model")); } #[tokio::test] async fn test_run_scheduled_agent_task_persists_execution_guard_prompt() { let base_url = start_mock_openai_server().await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "default-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "default-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); session_manager .run_scheduled_agent_task( "test-channel", "chat-guard", "每小时执行以下流程:检查邮箱并同步待办", ScheduledAgentTaskOptions { system_prompt: Some("你是邮箱待办同步助手。".to_string()), ..Default::default() }, ) .await .unwrap(); let session = session_manager.get("test-channel").await.unwrap(); let session_guard = session.lock().await; let persisted_messages = session_guard .store() .load_messages(&session_guard.persistent_session_id("chat-guard")) .unwrap(); let scheduled_prompt = persisted_messages .iter() .find(|message| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT)) .expect("missing scheduled system prompt"); assert!(scheduled_prompt.content.contains("已经触发的定时任务执行")); assert!( scheduled_prompt .content .contains("不要调用任何定时任务管理工具") ); assert!(scheduled_prompt.content.contains("你是邮箱待办同步助手。")); } #[tokio::test] async fn test_summarize_memory_maintenance_for_scope_uses_model_output() { let mock_response_content = serde_json::to_string(&json!({ "user_facts": ["用户在做AI产品"], "preferences": ["偏好简洁表达"], "behavior_patterns": ["习惯先问方案再要代码"], "merges": [], "conflicts": [], "low_value_ids": [], "managed_markdown": "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达\n\n### 行为模式\n- 习惯先问方案再要代码" })) .unwrap(); let base_url = start_mock_openai_server_with_content(Some(mock_response_content.clone())).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), )]), max_tool_iterations: 1, llm_timeout_secs: 30, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "user".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let output = session_manager .organize_memory_maintenance_for_scope("feishu:user-1") .await .unwrap() .unwrap(); assert!(output.merges.is_empty()); assert!(output.conflicts.is_empty()); assert!(output.low_value_ids.is_empty()); } #[test] fn test_is_recoverable_maintenance_llm_error_detects_transport_failures() { assert!(is_recoverable_maintenance_llm_error( "error sending request for url (https://example.invalid/v1/chat/completions)" )); assert!(is_recoverable_maintenance_llm_error( "API error 504 Gateway Timeout: stream timeout" )); assert!(!is_recoverable_maintenance_llm_error( "API error 401 Unauthorized" )); } #[test] fn test_extract_json_object_skips_wrapping_text() { let wrapped = "下面是结果:\n```json\n{\n \"user_facts\": [],\n \"preferences\": []\n}\n```\n请查收"; let stripped = strip_json_code_fence(wrapped); let extracted = extract_json_object(stripped).unwrap(); assert_eq!( extracted, "{\n \"user_facts\": [],\n \"preferences\": []\n}" ); } #[tokio::test] async fn test_summarize_memory_maintenance_transport_error_includes_provider_context() { let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url: "https://example.invalid/v1".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 1, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "user".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let error = session_manager .organize_memory_maintenance_for_scope("feishu:user-1") .await .unwrap_err() .to_string(); assert!( error.contains("memory organization model error: transport error:") || error.contains("memory summary model error: transport error:"), "Error did not contain expected message: {}", error ); assert!(error.contains("provider=maintenance-provider")); assert!(error.contains("model=maintenance-model")); assert!(error.contains("url=https://example.invalid/v1/chat/completions")); assert!(error.contains("timeout_secs=600")); assert!(error.contains("error sending request for url")); } #[tokio::test] async fn test_summarize_memory_maintenance_retries_recoverable_provider_errors() { let mock_response_content = serde_json::to_string(&json!({ "user_facts": ["用户在做AI产品"], "preferences": [], "behavior_patterns": [], "merges": [], "conflicts": [], "low_value_ids": [], "managed_markdown": "### 用户事实\n- 用户在做AI产品" })) .unwrap(); let base_url = start_mock_openai_flaky_server(mock_response_content.clone()).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), )]), max_tool_iterations: 1, llm_timeout_secs: 30, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "user".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let output = session_manager .organize_memory_maintenance_for_scope("feishu:user-1") .await .unwrap() .unwrap(); assert!(output.merges.is_empty()); } #[tokio::test] async fn test_summarize_memory_maintenance_for_scope_extracts_wrapped_json_object() { let mock_response_content = "结果如下:\n```json\n{\n \"user_facts\": [\"用户在做AI产品\"],\n \"preferences\": [],\n \"behavior_patterns\": [],\n \"merges\": [],\n \"conflicts\": [],\n \"low_value_ids\": [],\n \"managed_markdown\": \"### 用户事实\\n- 用户在做AI产品\"\n}\n```\n"; let base_url = start_mock_openai_server_with_content(Some(mock_response_content.to_string())).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), )]), max_tool_iterations: 1, llm_timeout_secs: 30, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "user".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let output = session_manager .organize_memory_maintenance_for_scope("feishu:user-1") .await .unwrap() .unwrap(); assert!(output.merges.is_empty()); } #[tokio::test] async fn test_run_memory_maintenance_for_all_scopes_scans_all_scopes_even_without_recent_updates() { let mock_response_content = serde_json::to_string(&json!({ "user_facts": ["用户在做AI产品"], "preferences": [], "behavior_patterns": [], "merges": [], "conflicts": [], "low_value_ids": [], "managed_markdown": "### 用户事实\n- 用户在做AI产品" })) .unwrap(); let base_url = start_mock_openai_server_with_content(Some(mock_response_content.clone())).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), )]), max_tool_iterations: 1, llm_timeout_secs: 30, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "user".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let result = session_manager .run_memory_maintenance_for_all_scopes() .await .unwrap(); assert!(result.is_some()); let result = result.unwrap(); assert_eq!(result.scope_key, "all"); // 由于步骤2需要新的提示词和输入格式,这里只验证基本功能 assert!(!result.managed_markdown.is_empty()); } #[tokio::test] async fn test_run_memory_maintenance_for_all_scopes_skips_recoverable_transport_failures() { let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url: "https://example.invalid/v1".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 1, memory_maintenance_timeout_secs: 600, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, max_images_in_context: 1, max_image_age_rounds: 10, }; let session_manager = SessionManager::new( 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), crate::config::TaskConfig::default(), crate::config::SubagentsConfig::default(), crate::config::MemoryMaintenanceConfig::default(), Some(24), crate::mcp::McpConfig::default(), ) .unwrap(); for scope_key in ["feishu:user-1", "feishu:user-2"] { session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: scope_key.to_string(), namespace: "user".to_string(), memory_key: "work".to_string(), content: format!("{} 在做AI产品", scope_key), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); } let result = session_manager .run_memory_maintenance_for_all_scopes() .await .unwrap(); // 当遇到可恢复错误时,没有整理任何记忆,返回 None assert!(result.is_none()); } #[test] fn test_apply_memory_maintenance_output_merges_and_deletes_low_value_records() { let store = SessionStore::in_memory().unwrap(); let scope_key = "feishu:user-1"; // 创建足够的记忆(7条),让合并操作满足保护限制 // 合并后需要保留至少 5 条(min_memories_to_keep) let work = store .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: scope_key.to_string(), namespace: "user".to_string(), memory_key: "work_short".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let role = store .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: scope_key.to_string(), namespace: "user".to_string(), memory_key: "work_detail".to_string(), content: "用户主要在做AI产品设计和实现".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let noise = store .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: scope_key.to_string(), namespace: "other".to_string(), memory_key: "temporary".to_string(), content: "今天临时提到过一个无后续的小细节".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); // 添加额外的记忆以满足 min_memories_to_keep = 5 的要求 for i in 0..4 { store .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: scope_key.to_string(), namespace: "user".to_string(), memory_key: format!("extra_{}", i), content: format!("额外记忆 {}", i), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); } let plan = build_memory_maintenance_plan( &store.list_memories_for_scope("user", scope_key).unwrap(), ); assert_eq!(plan.candidates.len(), 7); // 7 条候选记忆 let output = MemoryOrganizationOutput { merges: vec![MemoryMaintenanceMerge { source_ids: vec![work.id.clone(), role.id.clone()], namespace: "user".to_string(), memory_key: "work".to_string(), content: "用户主要在做AI产品设计与实现".to_string(), }], conflicts: Vec::new(), low_value_ids: vec![noise.id.clone()], }; // 使用默认配置进行验证 apply_memory_maintenance_output( &store, scope_key, &plan, &output, crate::config::MemoryMaintenanceConfig::default().max_merge_ratio, crate::config::MemoryMaintenanceConfig::default().min_memories_to_keep, crate::config::MemoryMaintenanceConfig::default().max_merge_per_group, ) .unwrap(); let all_memories = store.list_memories_for_scope("user", scope_key).unwrap(); // 过滤掉 _meta 记录 let user_memories: Vec<_> = all_memories.iter().filter(|m| m.namespace != "_meta").collect(); // 合并 2 条为 1 条,删除 1 条,7 - 2 + 1 = 5 条 assert_eq!(user_memories.len(), 5); // 验证合并后的记忆存在 assert!(user_memories.iter().any(|m| m.namespace == "user" && m.memory_key == "work")); } #[test] fn test_should_display_message_to_user_hides_completed_tool_results_by_default() { let completed = ChatMessage::tool("call-1", "calculator", "2"); let pending = ChatMessage::tool_with_state( "call-2", "bash", "waiting", crate::bus::message::ToolMessageState::PendingUserAction, ); assert!(!should_display_message_to_user(false, &completed)); assert!(should_display_message_to_user(false, &pending)); assert!(should_display_message_to_user(true, &completed)); } #[tokio::test] async fn test_bus_tool_call_emitter_emits_completed_tool_results() { let bus = MessageBus::new(4); let emitter = BusToolCallEmitter::new( bus.clone(), "test-channel", "chat-1", HashMap::new(), ); emitter .handle(ChatMessage::tool("call-1", "calculator", "2")) .await; let msg = tokio::time::timeout(std::time::Duration::from_millis(500), bus.consume_outbound()) .await .expect("timeout waiting for outbound message") .expect("bus outbound closed"); assert_eq!(msg.event_kind, OutboundEventKind::ToolResult); } #[tokio::test] async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), store.clone(), store.clone(), store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), HashSet::new(), Default::default(), ) .build(), ); let mut session = Session::new( "test-channel".to_string(), test_provider_config(), user_tx, tools, skills, store.clone(), 100, ) .await .unwrap(); session.ensure_persistent_session("chat-1").unwrap(); session.ensure_chat_loaded("chat-1").unwrap(); let history = session.get_history("chat-1").unwrap(); // 新设计:系统提示词不再持久化到历史记录,而是每次请求时动态注入 assert_eq!(history.len(), 0); } #[tokio::test] async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), store.clone(), store.clone(), store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), HashSet::new(), Default::default(), ) .build(), ); let mut session = Session::new( "test-channel".to_string(), test_provider_config(), user_tx, tools, skills, store.clone(), 100, ) .await .unwrap(); session.ensure_persistent_session("chat-1").unwrap(); session.ensure_chat_loaded("chat-1").unwrap(); for turn in 0..100 { session .append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}"))) .unwrap(); } session .ensure_agent_prompt_before_user_message("chat-1") .unwrap(); // 新设计:系统提示词不再持久化到历史记录 let history = session.get_history("chat-1").unwrap(); let user_messages = history .iter() .filter(|message| message.role == "user") .count(); assert_eq!(user_messages, 100); // 注入计数在实际处理请求时由 AgentPromptProvider 更新 // 此处仅为模拟调用,不会触发实际注入 let stored = store .get_session(&session.persistent_session_id("chat-1")) .unwrap() .unwrap(); // 初始值为 0,只有在实际 process 调用时才会更新 assert_eq!(stored.agent_prompt_reinjection_count, 0); session .ensure_agent_prompt_before_user_message("chat-1") .unwrap(); let history = session.get_history("chat-1").unwrap(); let user_messages = history .iter() .filter(|message| message.role == "user") .count(); assert_eq!(user_messages, 100); } #[tokio::test] async fn test_agent_prompt_reinjection_can_be_disabled_by_config() { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), store.clone(), store.clone(), store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), HashSet::new(), Default::default(), ) .build(), ); let mut session = Session::new( "test-channel".to_string(), test_provider_config(), user_tx, tools, skills, store.clone(), 0, ) .await .unwrap(); session.ensure_persistent_session("chat-1").unwrap(); session.ensure_chat_loaded("chat-1").unwrap(); for turn in 0..100 { session .append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}"))) .unwrap(); } session .ensure_agent_prompt_before_user_message("chat-1") .unwrap(); // 新设计:系统提示词不再持久化到历史记录 let history = session.get_history("chat-1").unwrap(); let user_messages = history .iter() .filter(|message| message.role == "user") .count(); assert_eq!(user_messages, 100); } #[test] fn test_default_tools_registers_get_time() { let skills = Arc::new(SkillRuntime::default()); let store = Arc::new(SessionStore::in_memory().unwrap()); let tools = ToolRegistryFactory::new( skills, store.clone(), store.clone(), store, Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), HashSet::new(), Default::default(), ) .build(); assert!(tools.tool_names().iter().any(|name| name == "get_time")); } #[test] fn test_build_memory_maintenance_plan_deduplicates_and_categorizes() { let memories = vec![ MemoryRecord { id: "1".to_string(), scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "user".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, created_at: 1, updated_at: 1, }, MemoryRecord { id: "2".to_string(), scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "user".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, created_at: 2, updated_at: 2, }, MemoryRecord { id: "3".to_string(), scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "user".to_string(), memory_key: "style".to_string(), content: "偏好简洁表达".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, created_at: 3, updated_at: 3, }, MemoryRecord { id: "4".to_string(), scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "patterns".to_string(), memory_key: "workflow".to_string(), content: "习惯先问方案再要代码".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, created_at: 4, updated_at: 4, }, ]; let plan = build_memory_maintenance_plan(&memories); // 去重后应该有3条(第1、2条重复) assert_eq!(plan.candidates.len(), 3); // 验证内容包含所有唯一的记忆 let contents: Vec = plan.candidates.iter().map(|c| c.content.clone()).collect(); assert!(contents.contains(&"用户在做AI产品".to_string())); assert!(contents.contains(&"偏好简洁表达".to_string())); assert!(contents.contains(&"习惯先问方案再要代码".to_string())); } } #[async_trait] impl crate::scheduler::MaintenanceExecutor for SessionManager { async fn cleanup_expired_sessions(&self) -> usize { self.cleanup_expired_sessions().await } async fn run_memory_maintenance_for_all_scopes( &self, ) -> anyhow::Result> { match self.run_memory_maintenance_for_all_scopes().await { Ok(Some(result)) => Ok(vec![crate::scheduler::MaintenanceRunSummary { scope_key: result.scope_key, merges: result.output.merges.len(), conflicts: result.output.conflicts.len(), low_value: result.output.low_value_ids.len(), }]), Ok(None) => Ok(vec![]), Err(error) => Err(anyhow::anyhow!(error.to_string())), } } }