use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler};
#[cfg(test)]
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::scheduler::ScheduledAgentTaskOptions;
use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, SessionRecord, SessionStore, SkillEventRepository};
use crate::tools::ToolRegistry;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;

use super::agent_factory::{AgentBuildRequest, AgentFactory};
use super::cli_session::CliSessionService;
use super::execution::should_display_message_to_user;
#[cfg(test)]
use super::memory_maintenance::{
    MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
    combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error,
    strip_json_code_fence,
};
use super::memory_maintenance::{MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult};
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
use super::prompt_injector::PromptInjector;
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
use super::session_history::SessionHistory;
use super::session_lifecycle::SessionLifecycleService;
use super::session_message_service::SessionMessageService;

/// Per-channel session state: sessions are isolated by channel, one `Session` per channel.
///
/// Conversation history is isolated by `chat_id` and managed centrally by the session
/// through its [`SessionHistory`].
pub struct Session {
    /// Random identifier assigned when the session is constructed (`Uuid::new_v4`).
    pub id: Uuid,
    /// Name of the channel this session belongs to.
    pub channel_name: String,
    /// Outbound sender used by [`Session::send`] to push `WsOutbound` messages to the user.
    pub user_tx: mpsc::Sender,
    /// Default provider configuration used by [`Session::create_agent`].
    provider_config: LLMProviderConfig,
    /// Skill runtime; consulted by [`Session::record_skill_offer`].
    skills: Arc,
    /// Builds the per-message `AgentLoop` instances.
    agent_factory: AgentFactory,
    /// Context compressor derived from `provider_config` at construction time.
    compressor: ContextCompressor,
    /// Per-`chat_id` history, in memory and persisted.
    history: SessionHistory,
}

/// Publishes messages emitted while an agent runs (e.g. live tool calls) to the
/// message bus as outbound messages for a fixed channel/chat.
pub struct BusToolCallEmitter {
    bus: Arc,
    channel_name: String,
    chat_id: String,
    /// Metadata attached to every outbound message built from an emitted message.
    metadata: HashMap,
    /// Forwarded to `should_display_message_to_user` to decide which messages are shown.
    show_tool_results: bool,
}

impl BusToolCallEmitter {
    /// Creates an emitter that publishes to `bus` on behalf of `channel_name` / `chat_id`.
    ///
    /// `metadata` is attached to each outbound message; `show_tool_results` controls
    /// which emitted messages are considered displayable.
    pub fn new(
        bus: Arc,
        channel_name: impl Into,
        chat_id: impl Into,
        metadata: HashMap,
        show_tool_results: bool,
    ) -> Self {
        Self {
            bus,
            channel_name: channel_name.into(),
            chat_id: chat_id.into(),
            metadata,
            show_tool_results,
        }
    }
}

#[async_trait]
impl EmittedMessageHandler for BusToolCallEmitter {
    /// Converts `message` into outbound messages and publishes each one on the bus.
    ///
    /// Messages rejected by `should_display_message_to_user` are dropped silently.
    async fn handle(&self, message: ChatMessage) {
        if !should_display_message_to_user(self.show_tool_results, &message) {
            return;
        }
        // A single chat message may expand into several outbound messages.
        for outbound in OutboundMessage::from_chat_message(
            &self.channel_name,
            &self.chat_id,
            None,
            &self.metadata,
            &message,
        ) {
            // Best-effort live delivery: a publish failure is logged and the
            // remaining outbound messages are still attempted.
            if let Err(error) = self.bus.publish_outbound(outbound).await {
                tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
            }
        }
    }
}

impl Session {
    /// Builds a session for `channel_name`.
    ///
    /// The shared `store` is used as the conversation repository, the skill event
    /// repository, and the backing store of the `PromptInjector`.
    /// `agent_prompt_reinject_every` is passed straight to `PromptInjector::new`
    /// (presumably the re-injection interval for the agent prompt — confirm there).
    ///
    /// # Errors
    /// Propagates the result of [`Session::with_factories`], which as written always
    /// returns `Ok`.
    pub async fn new(
        channel_name: String,
        provider_config: LLMProviderConfig,
        user_tx: mpsc::Sender,
        tools: Arc,
        skills: Arc,
        store: Arc,
        agent_prompt_reinject_every: u64,
    ) -> Result {
        let agent_factory = AgentFactory::new(tools, skills.clone());
        // The same store backs both repository roles.
        let conversations: Arc = store.clone();
        let skill_events: Arc = store.clone();
        let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
        Self::with_factories(
            channel_name,
            provider_config,
            user_tx,
            skills,
            agent_factory,
            prompt_injector,
            conversations,
            skill_events,
        )
        .await
    }

    /// Assembles a session from pre-built collaborators.
    ///
    /// A fresh random `id` is assigned, and the context compressor is derived from
    /// `provider_config`. As written this never fails.
    pub(crate) async fn with_factories(
        channel_name: String,
        provider_config: LLMProviderConfig,
        user_tx: mpsc::Sender,
        skills: Arc,
        agent_factory: AgentFactory,
        prompt_injector: PromptInjector,
        conversations: Arc,
        skill_events: Arc,
    ) -> Result {
        Ok(Self {
            id: Uuid::new_v4(),
            // Cloned because the original is moved into `SessionHistory::new` below.
            channel_name: channel_name.clone(),
            user_tx,
            provider_config: provider_config.clone(),
            skills,
            agent_factory,
            compressor: ContextCompressor::from_provider_config(&provider_config),
            history: SessionHistory::new(
                channel_name,
                prompt_injector,
                conversations,
                skill_events,
            ),
        })
    }

    /// Returns the persistent (storage-level) session id for `chat_id`.
    pub fn persistent_session_id(&self, chat_id: &str) -> String {
        self.history.persistent_session_id(chat_id)
    }

    /// Ensures a persisted session record exists for `chat_id` (delegates to `SessionHistory`).
    pub fn ensure_persistent_session(&self, chat_id: &str) -> Result {
        self.history.ensure_persistent_session(chat_id)
    }

    /// Ensures the history for `chat_id` has been loaded into memory.
    pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
        self.history.ensure_chat_loaded(chat_id)
    }

    /// Ensures the agent prompt is present in `chat_id`'s history before a user message
    /// is added (delegates to `SessionHistory`).
    pub fn ensure_agent_prompt_before_user_message(
        &mut self,
        chat_id: &str,
    ) -> Result<(), AgentError> {
        self.history
            .ensure_agent_prompt_before_user_message(chat_id)
    }

    /// Gets the conversation history for `chat_id`, creating it if it does not exist.
    pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec {
        self.history.get_or_create_history(chat_id)
    }

    /// Gets the conversation history for `chat_id` without creating it.
    pub fn get_history(&self, chat_id: &str) -> Option<&Vec> {
        self.history.get_history(chat_id)
    }

    /// Appends a complete message to the history of `chat_id`.
    ///
    /// NOTE(review): presumably in-memory only; `append_persisted_message` is the
    /// variant that also writes to storage — confirm in `SessionHistory::add_message`.
    pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
        self.history.add_message(chat_id, message);
    }

    /// Removes the history entry for `chat_id`.
    pub fn remove_history(&mut self, chat_id: &str) {
        self.history.remove_history(chat_id);
    }

    /// Clears the history of `chat_id`.
    pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
        self.history.clear_chat_history(chat_id)
    }

    /// Resets the conversation context of `chat_id`.
    pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
        self.history.reset_chat_context(chat_id)
    }

    /// Writes a message to both the in-memory history and the persistence layer.
    pub fn append_persisted_message(
        &mut self,
        chat_id: &str,
        message: ChatMessage,
    ) -> Result<(), AgentError> {
        self.history.append_persisted_message(chat_id, message)
    }

    /// Batch form of [`Session::append_persisted_message`].
    pub fn append_persisted_messages(
        &mut self,
        chat_id: &str,
        messages: I,
    ) -> Result<(), AgentError>
    where
        I: IntoIterator,
    {
        self.history.append_persisted_messages(chat_id, messages)
    }

    /// Builds a user message without touching any history.
    ///
    /// Media is attached (via `ChatMessage::user_with_media`) only when `media_refs`
    /// is non-empty; otherwise a plain `ChatMessage::user` is returned.
    pub fn create_user_message(&self, content: &str, media_refs: Vec) -> ChatMessage {
        if media_refs.is_empty() {
            ChatMessage::user(content)
        } else {
            ChatMessage::user_with_media(content, media_refs)
        }
    }

    /// Test helper: id of the latest user message in `chat_id`, if any.
    #[cfg(test)]
    fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
        self.latest_user_message(chat_id)
            .map(|message| message.id.as_str())
    }

    /// Test helper: the latest user message in `chat_id`, if any.
    #[cfg(test)]
    fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
        self.history.latest_user_message(chat_id)
    }

    /// Test helper: whether `message_id` is the id of the latest user message in `chat_id`.
    /// Returns `false` when there is no user message.
    #[cfg(test)]
    fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
        self.latest_user_message_id(chat_id)
            .map(|current_id| current_id == message_id)
            .unwrap_or(false)
    }

    /// Checks whether `message` belongs to the current user turn of `chat_id`.
    ///
    /// This module's tests expect it to keep matching after a history compaction
    /// followed by `reload_chat_history`, even where a plain id comparison no longer does.
    pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
        self.history.matches_current_user_turn(chat_id, message)
    }

    /// Returns diagnostic data from `SessionHistory::stale_result_diagnostics`.
    ///
    /// NOTE(review): the meaning of each tuple field is defined by `SessionHistory`;
    /// presumably used when reporting stale agent results — confirm there.
    pub(crate) fn stale_result_diagnostics(
        &self,
        chat_id: &str,
    ) -> (Option<&str>, Option, bool, usize) {
        self.history.stale_result_diagnostics(chat_id)
    }

    /// Clears all history in this session.
    pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
        self.history.clear_all_history()
    }

    /// Sends `msg` to the user over `user_tx`.
    pub async fn send(&self, msg: WsOutbound) {
        // Best-effort: a tokio mpsc send only fails once the receiver is closed,
        // and that failure is deliberately ignored here.
        let _ = self.user_tx.send(msg).await;
    }

    /// Returns a reference to the session's default provider configuration.
    pub fn provider_config(&self) -> &LLMProviderConfig {
        &self.provider_config
    }

    /// Returns a reference to the session's context compressor.
    pub fn compressor(&self) -> &ContextCompressor {
        &self.compressor
    }

    /// Attempts to mark a background compaction as started for `chat_id`.
    ///
    /// NOTE(review): presumably returns `false` if one is already in flight — confirm in
    /// `SessionHistory::try_start_background_compaction`.
    pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
        self.history.try_start_background_compaction(chat_id)
    }

    /// Marks the background compaction for `chat_id` as finished.
    pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
        self.history.finish_background_compaction(chat_id);
    }

    /// Reloads the history of `chat_id` (delegates to `SessionHistory`).
    pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
        self.history.reload_chat_history(chat_id)
    }

    /// Returns the conversation repository held by the session history.
    pub(crate) fn store(&self) -> Arc {
        self.history.conversations()
    }

    /// Records an `"offered"` skill event for `chat_id`, carrying the skill runtime's
    /// offered-event payload.
    ///
    /// This is a no-op returning `Ok(())` when the skill runtime is empty.
    pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
        if self.skills.is_empty() {
            return Ok(());
        }
        self.history.append_skill_event(
            chat_id,
            "offered",
            None,
            &self.skills.offered_event_payload(),
        )
    }

    /// Creates a temporary `AgentLoop` instance to handle a message, using the
    /// session's default provider configuration.
    pub fn create_agent(
        &self,
        chat_id: &str,
        sender_id: Option<&str>,
        message_id: Option<&str>,
    ) -> Result {
        self.create_agent_with_provider_config(
            chat_id,
            sender_id,
            message_id,
            self.provider_config.clone(),
        )
    }

    /// Like [`Session::create_agent`], but with an explicit provider configuration
    /// (e.g. a task-specific agent provider).
    pub fn create_agent_with_provider_config(
        &self,
        chat_id: &str,
        sender_id: Option<&str>,
        message_id: Option<&str>,
        provider_config: LLMProviderConfig,
    ) -> Result {
        self.agent_factory.create(AgentBuildRequest {
            channel_name: &self.channel_name,
            chat_id,
            sender_id,
            message_id,
            provider_config,
        })
    }
}

/// Manages all sessions and routes requests by `channel_name`.
///
/// Cloning yields another handle over the same underlying services.
#[derive(Clone)]
pub struct SessionManager {
    tools: Arc,
    skills: Arc,
    store: Arc,
    show_tool_results: bool,
    lifecycle: SessionLifecycleService,
    cli_sessions: CliSessionService,
    messages: SessionMessageService,
    scheduled_tasks: ScheduledAgentTaskService,
    memory_maintenance: MemoryMaintenanceCoordinator,
}

/// Pre-built services from which a [`SessionManager`] is assembled via
/// [`SessionManager::from_services`]. Fields map one-to-one onto the manager's fields.
pub(crate) struct SessionManagerServices {
    pub(crate) tools: Arc,
    pub(crate) skills: Arc,
    pub(crate) store: Arc,
    pub(crate) show_tool_results: bool,
    pub(crate) lifecycle: SessionLifecycleService,
    pub(crate) cli_sessions: CliSessionService,
    pub(crate) messages: SessionMessageService,
    pub(crate) scheduled_tasks: ScheduledAgentTaskService,
    pub(crate) memory_maintenance: MemoryMaintenanceCoordinator,
}

impl SessionManager {
    /// Assembles a manager by moving each service out of `services`.
    pub(crate) fn from_services(services: SessionManagerServices) -> Self {
        Self {
            tools: services.tools,
            skills: services.skills,
            store: services.store,
            show_tool_results: services.show_tool_results,
            lifecycle: services.lifecycle,
            cli_sessions: services.cli_sessions,
            messages: services.messages,
            scheduled_tasks: services.scheduled_tasks,
            memory_maintenance: services.memory_maintenance,
        }
    }

    /// Builds a fully wired manager.
    ///
    /// All arguments are forwarded unchanged to `super::runtime::build_session_manager`.
    /// `provider_configs` presumably maps agent names (e.g. `"default"`, `"planner"`) to
    /// provider configurations — TODO confirm against `build_session_manager`.
    pub fn new(
        session_ttl_hours: u64,
        agent_prompt_reinject_every: u64,
        show_tool_results: bool,
        default_timezone: String,
        provider_config: LLMProviderConfig,
        provider_configs: HashMap,
        skills: Arc,
    ) -> Result {
        super::runtime::build_session_manager(
            session_ttl_hours,
            agent_prompt_reinject_every,
            show_tool_results,
            default_timezone,
            provider_config,
            provider_configs,
            skills,
        )
    }

    /// Returns a shared handle to the tool registry.
    pub fn tools(&self) -> Arc {
        self.tools.clone()
    }

    /// Returns a shared handle to the session store.
    pub fn store(&self) -> Arc {
        self.store.clone()
    }

    /// Whether tool results are shown to the user.
    pub fn show_tool_results(&self) -> bool {
        self.show_tool_results
    }

    /// Returns a shared handle to the skill runtime.
    pub fn skills(&self) -> Arc {
        self.skills.clone()
    }

    /// Returns a clone of the CLI session service.
    pub(crate) fn cli_sessions(&self) -> CliSessionService {
        self.cli_sessions.clone()
    }

    /// Runs memory-maintenance summarization for one memory scope (delegates to the
    /// coordinator). Only exercised from tests at present, hence the `dead_code` allowance.
    #[cfg_attr(not(test), allow(dead_code))]
    pub(crate) async fn summarize_memory_maintenance_for_scope(
        &self,
        scope_key: &str,
    ) -> Result, AgentError> {
        self.memory_maintenance.summarize_for_scope(scope_key).await
    }

    /// Runs memory maintenance across all scopes (delegates to the coordinator).
    ///
    /// NOTE(review): `updated_since` presumably limits the run to scopes with memories
    /// updated at or after that timestamp, with `None` meaning all scopes — confirm in
    /// `MemoryMaintenanceCoordinator::run_for_all_scopes`.
    pub(crate) async fn run_memory_maintenance_for_all_scopes(
        &self,
        updated_since: Option,
    ) -> Result, AgentError> {
        self.memory_maintenance
            .run_for_all_scopes(updated_since)
            .await
    }

    /// Ensures a session exists for `channel_name` and has not expired; an expired
    /// session is rebuilt.
    pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
        self.lifecycle.ensure_session(channel_name).await
    }

    /// Returns the session for `channel_name`, without checking for expiry.
    pub async fn get(&self, channel_name: &str) -> Option>> {
        self.lifecycle.get(channel_name).await
    }

    /// Updates the last-active time of the session for `channel_name`.
    pub async fn touch(&self, channel_name: &str) {
        self.lifecycle.touch(channel_name).await;
    }

    /// Removes expired sessions and returns the count reported by the lifecycle service
    /// (presumably the number of sessions removed).
    pub async fn cleanup_expired_sessions(&self) -> usize {
        self.lifecycle.cleanup_expired_sessions().await
    }

    /// Handles an incoming message by routing it to the agent of the matching session.
    ///
    /// `live_emitter`, when provided, is passed through to the message service
    /// (presumably to stream intermediate messages such as tool calls).
    pub async fn handle_message(
        &self,
        channel_name: &str,
        sender_id: &str,
        chat_id: &str,
        content: &str,
        media: Vec,
        live_emitter: Option>,
    ) -> Result, AgentError> {
        self.messages
            .handle_message(
                channel_name,
                sender_id,
                chat_id,
                content,
                media,
                live_emitter,
            )
            .await
    }

    /// Runs a scheduled agent task with `prompt` in `channel_name` / `chat_id`, as
    /// configured by `options` (delegates to the scheduled task service).
    pub async fn run_scheduled_agent_task(
        &self,
        channel_name: &str,
        chat_id: &str,
        prompt: &str,
        options: ScheduledAgentTaskOptions,
    ) -> Result, AgentError> {
        self.scheduled_tasks
            .run(channel_name, chat_id, prompt, options)
            .await
    }

    /// Clears all history of the session for `channel_name`.
    ///
    /// This is a no-op returning `Ok(())` when no such session exists.
    /// Errors from `Session::clear_all_history` are propagated.
    pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
        if let Some(session) = self.get(channel_name).await {
            let mut session_guard = session.lock().await;
            session_guard.clear_all_history()?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::bus::MessageBus;
    use crate::gateway::tool_registry_factory::ToolRegistryFactory;
    use crate::storage::MemoryRecord;
    use crate::tools::NoopSessionMessageSender;
    use axum::http::StatusCode;
    use axum::{Json, Router, routing::post};
    use serde_json::{Value, json};
    use std::collections::{HashMap, HashSet};
    use std::sync::{
Arc as StdArc, atomic::{AtomicUsize, Ordering}, }; use tokio::net::TcpListener; use tokio::sync::mpsc; fn test_provider_config() -> LLMProviderConfig { LLMProviderConfig { provider_type: "openai".to_string(), name: "test".to_string(), base_url: "http://localhost".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), llm_timeout_secs: 120, model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, } } #[tokio::test] async fn test_latest_user_message_guard_tracks_current_turn() { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), store.clone(), store.clone(), store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), ) .build(), ); let mut session = Session::new( "feishu".to_string(), test_provider_config(), user_tx, tools, skills, store, 100, ) .await .unwrap(); session.ensure_persistent_session("chat-1").unwrap(); session.ensure_chat_loaded("chat-1").unwrap(); let first = session.create_user_message("first", Vec::new()); let first_id = first.id.clone(); session.append_persisted_message("chat-1", first).unwrap(); assert!(session.is_latest_user_message("chat-1", &first_id)); let second = session.create_user_message("second", Vec::new()); let second_id = second.id.clone(); session.append_persisted_message("chat-1", second).unwrap(); assert!(!session.is_latest_user_message("chat-1", &first_id)); assert!(session.is_latest_user_message("chat-1", &second_id)); } #[tokio::test] async fn test_current_user_turn_match_survives_history_compaction_reload() { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = 
Arc::new(SkillRuntime::default()); let tools = Arc::new( ToolRegistryFactory::new( skills.clone(), store.clone(), store.clone(), store.clone(), Arc::new(NoopSessionMessageSender), HashSet::new(), "Asia/Shanghai".to_string(), ) .build(), ); let mut session = Session::new( "feishu".to_string(), test_provider_config(), user_tx, tools, skills, store.clone(), 100, ) .await .unwrap(); session.ensure_persistent_session("chat-1").unwrap(); session.ensure_chat_loaded("chat-1").unwrap(); let first = session.create_user_message("first", Vec::new()); let first_id = first.id.clone(); session.append_persisted_message("chat-1", first).unwrap(); session .append_persisted_message("chat-1", ChatMessage::assistant("answer-1")) .unwrap(); let second = session.create_user_message("second", Vec::new()); session .append_persisted_message("chat-1", second.clone()) .unwrap(); session .append_persisted_message("chat-1", ChatMessage::assistant("answer-2")) .unwrap(); let session_id = session.persistent_session_id("chat-1"); let snapshot_end_seq = store .get_session(&session_id) .unwrap() .unwrap() .message_count; let preserved_messages = session.get_history("chat-1").unwrap().clone(); store .compact_active_history( &session_id, 0, snapshot_end_seq, &[], &ChatMessage::system("[Compressed History]\n\nsummary"), &preserved_messages, ) .unwrap(); session.reload_chat_history("chat-1").unwrap(); assert!(!session.is_latest_user_message("chat-1", &first_id)); assert!(!session.is_latest_user_message("chat-1", &second.id)); assert!(session.matches_current_user_turn("chat-1", &second)); } async fn start_mock_openai_server() -> String { start_mock_openai_server_with_content(None).await } async fn start_mock_openai_server_with_content( mock_response_content: Option, ) -> String { async fn handle( axum::extract::State(mock_response_content): axum::extract::State>, Json(body): Json, ) -> Json { let model = body .get("model") .and_then(|value| value.as_str()) .unwrap_or("unknown-model"); let content = 
mock_response_content.unwrap_or_else(|| format!("reply from {}", model)); Json(json!({ "id": "mock-response", "model": model, "choices": [ { "message": { "content": content, "tool_calls": [] } } ], "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 } })) } let app = Router::new() .route("/chat/completions", post(handle)) .with_state(mock_response_content); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); format!("http://{}", address) } async fn start_mock_openai_504_server() -> String { async fn handle() -> (StatusCode, &'static str) { (StatusCode::GATEWAY_TIMEOUT, "stream timeout") } let app = Router::new().route("/chat/completions", post(handle)); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); format!("http://{}", address) } async fn start_mock_openai_flaky_server(mock_response_content: String) -> String { let attempts = StdArc::new(AtomicUsize::new(0)); let state = (attempts, mock_response_content); async fn handle( axum::extract::State((attempts, mock_response_content)): axum::extract::State<( StdArc, String, )>, Json(body): Json, ) -> (StatusCode, Json) { let attempt = attempts.fetch_add(1, Ordering::SeqCst); if attempt == 0 { return ( StatusCode::GATEWAY_TIMEOUT, Json(json!({"error": "stream timeout"})), ); } let model = body .get("model") .and_then(|value| value.as_str()) .unwrap_or("unknown-model"); ( StatusCode::OK, Json(json!({ "id": "mock-response", "model": model, "choices": [ { "message": { "content": mock_response_content, "tool_calls": [] } } ], "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 } })), ) } let app = Router::new() .route("/chat/completions", post(handle)) .with_state(state); let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); format!("http://{}", address) } #[tokio::test] async fn test_handle_message_returns_recoverable_reply_on_llm_504() { let base_url = start_mock_openai_504_server().await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "timeout-provider".to_string(), base_url: base_url.clone(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "timeout-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let session_manager = SessionManager::new( 4, 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), ) .unwrap(); let outbound = session_manager .handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None) .await .unwrap(); assert_eq!(outbound.len(), 1); assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时")); } #[tokio::test] async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() { let base_url = start_mock_openai_server().await; let default_provider = LLMProviderConfig { provider_type: "openai".to_string(), name: "default-provider".to_string(), base_url: base_url.clone(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "default-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let planner_provider = LLMProviderConfig { model_id: "planner-model".to_string(), name: "planner-provider".to_string(), 
..default_provider.clone() }; let session_manager = SessionManager::new( 4, 100, false, "Asia/Shanghai".to_string(), default_provider.clone(), HashMap::from([ ("default".to_string(), default_provider), ("planner".to_string(), planner_provider), ]), Arc::new(SkillRuntime::default()), ) .unwrap(); let planner_outbound = session_manager .run_scheduled_agent_task( "feishu", "chat-planner", "请规划今天工作", ScheduledAgentTaskOptions { agent: Some("planner".to_string()), fresh_session: true, ..Default::default() }, ) .await .unwrap(); assert_eq!(planner_outbound.len(), 1); assert!(planner_outbound[0].content.contains("planner-model")); let default_outbound = session_manager .run_scheduled_agent_task( "feishu", "chat-default", "请规划今天工作", ScheduledAgentTaskOptions { agent: Some("default".to_string()), fresh_session: true, ..Default::default() }, ) .await .unwrap(); assert_eq!(default_outbound.len(), 1); assert!(default_outbound[0].content.contains("default-model")); } #[tokio::test] async fn test_run_scheduled_agent_task_persists_execution_guard_prompt() { let base_url = start_mock_openai_server().await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "default-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "default-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 30, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let session_manager = SessionManager::new( 4, 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), ) .unwrap(); session_manager .run_scheduled_agent_task( "feishu", "chat-guard", "每小时执行以下流程:检查邮箱并同步待办", ScheduledAgentTaskOptions { fresh_session: true, system_prompt: Some("你是邮箱待办同步助手。".to_string()), ..Default::default() }, ) 
.await .unwrap(); let session = session_manager.get("feishu").await.unwrap(); let session_guard = session.lock().await; let persisted_messages = session_guard .store() .load_messages(&session_guard.persistent_session_id("chat-guard")) .unwrap(); let scheduled_prompt = persisted_messages .iter() .find(|message| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT)) .expect("missing scheduled system prompt"); assert!(scheduled_prompt.content.contains("已经触发的定时任务执行")); assert!( scheduled_prompt .content .contains("不要调用任何定时任务管理工具") ); assert!(scheduled_prompt.content.contains("你是邮箱待办同步助手。")); } #[tokio::test] async fn test_summarize_memory_maintenance_for_scope_uses_model_output() { let mock_response_content = serde_json::to_string(&json!({ "user_facts": ["用户在做AI产品"], "preferences": ["偏好简洁表达"], "behavior_patterns": ["习惯先问方案再要代码"], "merges": [], "conflicts": [], "low_value_ids": [], "managed_markdown": "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达\n\n### 行为模式\n- 习惯先问方案再要代码" })) .unwrap(); let base_url = start_mock_openai_server_with_content(Some(mock_response_content.clone())).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), )]), max_tool_iterations: 1, llm_timeout_secs: 30, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let session_manager = SessionManager::new( 4, 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), ) .unwrap(); session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), 
namespace: "profile".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let output = session_manager .summarize_memory_maintenance_for_scope("feishu:user-1") .await .unwrap() .unwrap(); assert_eq!(output.user_facts, vec!["用户在做AI产品".to_string()]); assert_eq!(output.preferences, vec!["偏好简洁表达".to_string()]); assert_eq!( output.behavior_patterns, vec!["习惯先问方案再要代码".to_string()] ); assert!(output.managed_markdown.contains("### 用户事实")); } #[test] fn test_is_recoverable_maintenance_llm_error_detects_transport_failures() { assert!(is_recoverable_maintenance_llm_error( "error sending request for url (https://example.invalid/v1/chat/completions)" )); assert!(is_recoverable_maintenance_llm_error( "API error 504 Gateway Timeout: stream timeout" )); assert!(!is_recoverable_maintenance_llm_error( "API error 401 Unauthorized" )); } #[test] fn test_extract_json_object_skips_wrapping_text() { let wrapped = "下面是结果:\n```json\n{\n \"user_facts\": [],\n \"preferences\": []\n}\n```\n请查收"; let stripped = strip_json_code_fence(wrapped); let extracted = extract_json_object(stripped).unwrap(); assert_eq!( extracted, "{\n \"user_facts\": [],\n \"preferences\": []\n}" ); } #[tokio::test] async fn test_summarize_memory_maintenance_transport_error_includes_provider_context() { let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url: "https://example.invalid/v1".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::new(), max_tool_iterations: 1, llm_timeout_secs: 1, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let session_manager = 
SessionManager::new( 4, 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), ) .unwrap(); session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "profile".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let error = session_manager .summarize_memory_maintenance_for_scope("feishu:user-1") .await .unwrap_err() .to_string(); assert!(error.contains("memory maintenance model error: transport error:")); assert!(error.contains("provider=maintenance-provider")); assert!(error.contains("model=maintenance-model")); assert!(error.contains("url=https://example.invalid/v1/chat/completions")); assert!(error.contains("timeout_secs=1")); assert!(error.contains("error sending request for url")); } #[tokio::test] async fn test_summarize_memory_maintenance_retries_recoverable_provider_errors() { let mock_response_content = serde_json::to_string(&json!({ "user_facts": ["用户在做AI产品"], "preferences": [], "behavior_patterns": [], "merges": [], "conflicts": [], "low_value_ids": [], "managed_markdown": "### 用户事实\n- 用户在做AI产品" })) .unwrap(); let base_url = start_mock_openai_flaky_server(mock_response_content.clone()).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), )]), max_tool_iterations: 1, llm_timeout_secs: 30, 
tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let session_manager = SessionManager::new( 4, 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), ) .unwrap(); session_manager .store() .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: "feishu:user-1".to_string(), namespace: "profile".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), source_session_id: None, source_message_id: None, source_message_seq: None, source_channel_name: None, source_chat_id: None, }) .unwrap(); let output = session_manager .summarize_memory_maintenance_for_scope("feishu:user-1") .await .unwrap() .unwrap(); assert_eq!(output.user_facts, vec!["用户在做AI产品".to_string()]); } #[tokio::test] async fn test_summarize_memory_maintenance_for_scope_extracts_wrapped_json_object() { let mock_response_content = "结果如下:\n```json\n{\n \"user_facts\": [\"用户在做AI产品\"],\n \"preferences\": [],\n \"behavior_patterns\": [],\n \"merges\": [],\n \"conflicts\": [],\n \"low_value_ids\": [],\n \"managed_markdown\": \"### 用户事实\\n- 用户在做AI产品\"\n}\n```\n"; let base_url = start_mock_openai_server_with_content(Some(mock_response_content.to_string())).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), name: "maintenance-provider".to_string(), base_url, api_key: "test-key".to_string(), extra_headers: HashMap::new(), model_id: "maintenance-model".to_string(), temperature: Some(0.0), max_tokens: Some(256), context_window_tokens: None, model_extra: HashMap::from([( "mock_response_content".to_string(), json!(mock_response_content), )]), max_tool_iterations: 1, llm_timeout_secs: 30, tool_result_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }; let session_manager = SessionManager::new( 4, 100, false, "Asia/Shanghai".to_string(), provider_config.clone(), 
HashMap::from([("default".to_string(), provider_config)]),
            Arc::new(SkillRuntime::default()),
        )
        .unwrap();
        // Seed a single "profile" memory for the scope that is summarized below.
        session_manager
            .store()
            .put_memory(&crate::storage::MemoryUpsert {
                scope_kind: "user".to_string(),
                scope_key: "feishu:user-1".to_string(),
                namespace: "profile".to_string(),
                memory_key: "work".to_string(),
                content: "用户在做AI产品".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
            })
            .unwrap();
        // The per-scope summary must report the seeded fact ("the user is building
        // AI products") and render a "用户事实" ("user facts") markdown heading.
        let output = session_manager
            .summarize_memory_maintenance_for_scope("feishu:user-1")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(output.user_facts, vec!["用户在做AI产品".to_string()]);
        assert!(output.managed_markdown.contains("### 用户事实"));
    }

    // Stores one memory, then runs maintenance over all scopes with an argument
    // one unit past that memory's `updated_at`, and expects an empty result set.
    // NOTE(review): presumably the argument is an "updated since" cutoff (the test
    // name says "no recent updates"); confirm against
    // `run_memory_maintenance_for_all_scopes`.
    #[tokio::test]
    async fn test_run_memory_maintenance_for_all_scopes_returns_empty_when_no_recent_updates() {
        let provider_config = LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "maintenance-provider".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            model_id: "maintenance-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(256),
            context_window_tokens: None,
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            llm_timeout_secs: 30,
            tool_result_max_chars: 20_000,
            context_tool_result_trim_chars: 20_000,
        };
        let session_manager = SessionManager::new(
            4,
            100,
            false,
            "Asia/Shanghai".to_string(),
            provider_config.clone(),
            HashMap::from([("default".to_string(), provider_config)]),
            Arc::new(SkillRuntime::default()),
        )
        .unwrap();
        let memory = session_manager
            .store()
            .put_memory(&crate::storage::MemoryUpsert {
                scope_kind: "user".to_string(),
                scope_key: "feishu:user-1".to_string(),
                namespace: "profile".to_string(),
                memory_key: "work".to_string(),
                content: "用户在做AI产品".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
            })
            .unwrap();
        let results = session_manager
            .run_memory_maintenance_for_all_scopes(Some(memory.updated_at + 1))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // Seeds memories for two user scopes and points the maintenance provider at
    // `example.invalid` with a 1-second LLM timeout. `.invalid` is a reserved TLD
    // that should never resolve, so model calls are expected to fail at the
    // transport level. The batch call must still return `Ok`, with no per-scope
    // results, rather than propagating the error.
    // NOTE(review): this makes a real (failing) network attempt and may wait up to
    // the configured timeout per scope.
    #[tokio::test]
    async fn test_run_memory_maintenance_for_all_scopes_skips_recoverable_transport_failures() {
        let provider_config = LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "maintenance-provider".to_string(),
            base_url: "https://example.invalid/v1".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            model_id: "maintenance-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(256),
            context_window_tokens: None,
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            llm_timeout_secs: 1,
            tool_result_max_chars: 20_000,
            context_tool_result_trim_chars: 20_000,
        };
        let session_manager = SessionManager::new(
            4,
            100,
            false,
            "Asia/Shanghai".to_string(),
            provider_config.clone(),
            HashMap::from([("default".to_string(), provider_config)]),
            Arc::new(SkillRuntime::default()),
        )
        .unwrap();
        for scope_key in ["feishu:user-1", "feishu:user-2"] {
            session_manager
                .store()
                .put_memory(&crate::storage::MemoryUpsert {
                    scope_kind: "user".to_string(),
                    scope_key: scope_key.to_string(),
                    namespace: "profile".to_string(),
                    memory_key: "work".to_string(),
                    content: format!("{} 在做AI产品", scope_key),
                    source_type: "message".to_string(),
                    source_session_id: None,
                    source_message_id: None,
                    source_message_seq: None,
                    source_channel_name: None,
                    source_chat_id: None,
                })
                .unwrap();
        }
        // `.unwrap()` here is the real assertion: the failed scopes must not turn
        // the whole run into an `Err`.
        let results = session_manager
            .run_memory_maintenance_for_all_scopes(None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // Exercises `apply_memory_maintenance_output` directly against an in-memory
    // store, without any LLM:
    // - two overlapping "profile" memories are merged into a new `profile/work`
    //   record;
    // - a throwaway "notes" memory is listed as low value;
    // - afterwards exactly one record must remain for the scope.
    #[test]
    fn test_apply_memory_maintenance_output_merges_and_deletes_low_value_records() {
        let store = SessionStore::in_memory().unwrap();
        let scope_key = "feishu:user-1";
        // Two overlapping "profile" facts (short and detailed wording of the same
        // work description) that the output below merges.
        let work = store
            .put_memory(&crate::storage::MemoryUpsert {
                scope_kind: "user".to_string(),
                scope_key: scope_key.to_string(),
                namespace: "profile".to_string(),
                memory_key: "work_short".to_string(),
                content: "用户在做AI产品".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
            })
            .unwrap();
        let role = store
            .put_memory(&crate::storage::MemoryUpsert {
                scope_kind: "user".to_string(),
                scope_key: scope_key.to_string(),
                namespace: "profile".to_string(),
                memory_key: "work_detail".to_string(),
                content: "用户主要在做AI产品设计和实现".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
            })
            .unwrap();
        // A throwaway note ("a passing detail mentioned today with no follow-up")
        // that the output below flags as low value.
        let noise = store
            .put_memory(&crate::storage::MemoryUpsert {
                scope_kind: "user".to_string(),
                scope_key: scope_key.to_string(),
                namespace: "notes".to_string(),
                memory_key: "temporary".to_string(),
                content: "今天临时提到过一个无后续的小细节".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
            })
            .unwrap();
        let plan = build_memory_maintenance_plan(
            &store.list_memories_for_scope("user", scope_key).unwrap(),
        );
        // Hand-built model output (presumably the shape the maintenance LLM would
        // return): merge `work` + `role` into `profile/work` and drop `noise`.
        let output = MemoryMaintenanceModelOutput {
            user_facts: vec!["用户在做AI产品".to_string()],
            preferences: Vec::new(),
            behavior_patterns: Vec::new(),
            merges: vec![MemoryMaintenanceMerge {
                source_ids: vec![work.id.clone(), role.id.clone()],
                namespace: "profile".to_string(),
                memory_key: "work".to_string(),
                content: "用户主要在做AI产品设计与实现".to_string(),
            }],
            conflicts: Vec::new(),
            low_value_ids: vec![noise.id.clone()],
            managed_markdown: "### 用户事实\n- 用户在做AI产品".to_string(),
        };
        apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap();

        // Exactly one record remains, carrying the merged namespace, key and
        // content. Neither source key (`work_short`, `work_detail`) survives.
        let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
        assert_eq!(all_memories.len(), 1);
        assert_eq!(all_memories[0].namespace, "profile");
        assert_eq!(all_memories[0].memory_key, "work");
        assert_eq!(all_memories[0].content, "用户主要在做AI产品设计与实现");
    }

    // `combine_managed_memory_markdown` receives three inputs:
    // - a richer summary with a "user facts" (用户事实) and a "user preferences"
    //   (用户偏好) section;
    // - a bare bullet that repeats one of that summary's lines;
    // - a summary with an extra fact (the user's name and nickname).
    // The richer summary and the extra fact must survive verbatim, while the
    // repeated bullet must appear exactly once in the result.
    #[test]
    fn test_combine_managed_memory_markdown_prefers_richer_summary_over_subset() {
        let combined = combine_managed_memory_markdown(&[
            "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达".to_string(),
            "- 用户在做AI产品".to_string(),
            "### 用户事实\n- 用户名为区德成,昵称DC。".to_string(),
        ]);
        assert!(
            combined.contains("### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达")
        );
        assert!(combined.contains("### 用户事实\n- 用户名为区德成,昵称DC。"));
        assert_eq!(combined.matches("- 用户在做AI产品").count(), 1);
    }

    // Visibility rules for tool messages:
    // - with `show_tool_results == false`, a plain tool result (`ChatMessage::tool`)
    //   is hidden, but a tool message in the `PendingUserAction` state is still
    //   shown;
    // - with `true`, the plain tool result is shown as well.
    #[test]
    fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
        let completed = ChatMessage::tool("call-1", "calculator", "2");
        let pending = ChatMessage::tool_with_state(
            "call-2",
            "bash",
            "waiting",
            crate::bus::message::ToolMessageState::PendingUserAction,
        );
        assert!(!should_display_message_to_user(false, &completed));
        assert!(should_display_message_to_user(false, &pending));
        assert!(should_display_message_to_user(true, &completed));
    }

    // With `show_tool_results == false`, handling a plain tool result must not
    // publish anything. `tokio::time::timeout` returns `Err` only when
    // `consume_outbound` yields nothing within 50 ms.
    #[tokio::test]
    async fn test_bus_tool_call_emitter_hides_completed_tool_results_when_disabled() {
        let bus = MessageBus::new(4);
        let emitter =
            BusToolCallEmitter::new(bus.clone(), "feishu", "chat-1", HashMap::new(), false);
        emitter
            .handle(ChatMessage::tool("call-1", "calculator", "2"))
            .await;
        assert!(
            tokio::time::timeout(std::time::Duration::from_millis(50), bus.consume_outbound())
                .await
                .is_err()
        );
    }

    // Loading a fresh chat seeds its history with exactly one system message that
    // contains the agent prompt marker "PicoBot 代理配置" ("PicoBot agent
    // configuration").
    #[tokio::test]
    async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
        let store = Arc::new(SessionStore::in_memory().unwrap());
        // Bound to a named `_user_rx` (not `_`) so the receiver is not dropped
        // immediately.
        let (user_tx, _user_rx) = mpsc::channel(4);
        let skills = Arc::new(SkillRuntime::default());
        let tools = Arc::new(
            ToolRegistryFactory::new(
                skills.clone(),
                store.clone(),
                store.clone(),
                store.clone(),
                Arc::new(NoopSessionMessageSender),
                HashSet::new(),
                "Asia/Shanghai".to_string(),
            )
            .build(),
        );
        let mut session = Session::new(
            "feishu".to_string(),
            test_provider_config(),
            user_tx,
            tools,
            skills,
            store.clone(),
            100,
        )
        .await
        .unwrap();
        session.ensure_persistent_session("chat-1").unwrap();
        session.ensure_chat_loaded("chat-1").unwrap();
        let history = session.get_history("chat-1").unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, "system");
        assert!(history[0].content.contains("PicoBot 代理配置"));
    }

    // With a reinjection interval of 100 (the last `Session::new` argument),
    // appending 100 persisted user messages and then calling
    // `ensure_agent_prompt_before_user_message` must:
    // - add a second system message;
    // - bump the stored `agent_prompt_reinjection_count` to 1.
    // Calling it again right away must not add a third system message.
    #[tokio::test]
    async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() {
        let store = Arc::new(SessionStore::in_memory().unwrap());
        let (user_tx, _user_rx) = mpsc::channel(4);
        let skills = Arc::new(SkillRuntime::default());
        let tools = Arc::new(
            ToolRegistryFactory::new(
                skills.clone(),
                store.clone(),
                store.clone(),
                store.clone(),
                Arc::new(NoopSessionMessageSender),
                HashSet::new(),
                "Asia/Shanghai".to_string(),
            )
            .build(),
        );
        let mut session = Session::new(
            "feishu".to_string(),
            test_provider_config(),
            user_tx,
            tools,
            skills,
            store.clone(),
            100,
        )
        .await
        .unwrap();
        session.ensure_persistent_session("chat-1").unwrap();
        session.ensure_chat_loaded("chat-1").unwrap();
        for turn in 0..100 {
            session
                .append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
                .unwrap();
        }
        session
            .ensure_agent_prompt_before_user_message("chat-1")
            .unwrap();
        let history = session.get_history("chat-1").unwrap();
        let system_messages = history
            .iter()
            .filter(|message| message.role == "system")
            .count();
        assert_eq!(system_messages, 2);
        // The reinjection is also recorded on the persisted session row.
        let stored = store
            .get_session(&session.persistent_session_id("chat-1"))
            .unwrap()
            .unwrap();
        assert_eq!(stored.agent_prompt_reinjection_count, 1);
        // A second call with no new user turns must be a no-op.
        session
            .ensure_agent_prompt_before_user_message("chat-1")
            .unwrap();
        let history = session.get_history("chat-1").unwrap();
        let system_messages = history
            .iter()
            .filter(|message| message.role == "system")
            .count();
        assert_eq!(system_messages, 2);
    }

    // An interval of 0 disables reinjection: after 100 user messages the history
    // still holds only the initial system message.
    #[tokio::test]
    async fn test_agent_prompt_reinjection_can_be_disabled_by_config() {
        let store = Arc::new(SessionStore::in_memory().unwrap());
        let (user_tx, _user_rx) = mpsc::channel(4);
        let skills = Arc::new(SkillRuntime::default());
        let tools = Arc::new(
            ToolRegistryFactory::new(
                skills.clone(),
                store.clone(),
                store.clone(),
                store.clone(),
                Arc::new(NoopSessionMessageSender),
                HashSet::new(),
                "Asia/Shanghai".to_string(),
            )
            .build(),
        );
        let mut session = Session::new(
            "feishu".to_string(),
            test_provider_config(),
            user_tx,
            tools,
            skills,
            store.clone(),
            0,
        )
        .await
        .unwrap();
        session.ensure_persistent_session("chat-1").unwrap();
        session.ensure_chat_loaded("chat-1").unwrap();
        for turn in 0..100 {
            session
                .append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
                .unwrap();
        }
        session
            .ensure_agent_prompt_before_user_message("chat-1")
            .unwrap();
        let history = session.get_history("chat-1").unwrap();
        let system_messages = history
            .iter()
            .filter(|message| message.role == "system")
            .count();
        assert_eq!(system_messages, 1);
    }

    // The registry produced by `ToolRegistryFactory::build` includes `get_time`.
    #[test]
    fn test_default_tools_registers_get_time() {
        let skills = Arc::new(SkillRuntime::default());
        let store = Arc::new(SessionStore::in_memory().unwrap());
        let tools = ToolRegistryFactory::new(
            skills,
            store.clone(),
            store.clone(),
            store,
            Arc::new(NoopSessionMessageSender),
            HashSet::new(),
            "Asia/Shanghai".to_string(),
        )
        .build();
        assert!(tools.tool_names().iter().any(|name| name == "get_time"));
    }

    // `build_memory_maintenance_plan` over four records:
    // - ids 1 and 2 are identical `profile/work` records and must collapse into a
    //   single user fact;
    // - the `preferences` record lands in `preferences`;
    // - the `patterns` record lands in `behavior_patterns`;
    // - `others` stays empty.
    #[test]
    fn test_build_memory_maintenance_plan_deduplicates_and_categorizes() {
        let memories = vec![
            MemoryRecord {
                id: "1".to_string(),
                scope_kind: "user".to_string(),
                scope_key: "feishu:user-1".to_string(),
                namespace: "profile".to_string(),
                memory_key: "work".to_string(),
                content: "用户在做AI产品".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
                created_at: 1,
                updated_at: 1,
            },
            MemoryRecord {
                id: "2".to_string(),
                scope_kind: "user".to_string(),
                scope_key: "feishu:user-1".to_string(),
                namespace: "profile".to_string(),
                memory_key: "work".to_string(),
                content: "用户在做AI产品".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
                created_at: 2,
                updated_at: 2,
            },
            MemoryRecord {
                id: "3".to_string(),
                scope_kind: "user".to_string(),
                scope_key: "feishu:user-1".to_string(),
                namespace: "preferences".to_string(),
                memory_key: "style".to_string(),
                content: "偏好简洁表达".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
                created_at: 3,
                updated_at: 3,
            },
            MemoryRecord {
                id: "4".to_string(),
                scope_kind: "user".to_string(),
                scope_key: "feishu:user-1".to_string(),
                namespace: "patterns".to_string(),
                memory_key: "workflow".to_string(),
                content: "习惯先问方案再要代码".to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
                created_at: 4,
                updated_at: 4,
            },
        ];
        let plan = build_memory_maintenance_plan(&memories);
        assert_eq!(plan.user_facts.len(), 1);
        assert_eq!(plan.preferences.len(), 1);
        assert_eq!(plan.behavior_patterns.len(), 1);
        assert!(plan.others.is_empty());
        assert_eq!(plan.user_facts[0].content, "用户在做AI产品");
        assert_eq!(plan.preferences[0].content, "偏好简洁表达");
        assert_eq!(plan.behavior_patterns[0].content, "习惯先问方案再要代码");
    }
}