1729 lines
60 KiB
Rust
1729 lines
60 KiB
Rust
use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler};
|
||
#[cfg(test)]
|
||
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
|
||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||
use crate::config::LLMProviderConfig;
|
||
use crate::protocol::WsOutbound;
|
||
use crate::scheduler::ScheduledAgentTaskOptions;
|
||
use crate::skills::SkillRuntime;
|
||
use crate::storage::{ConversationRepository, SessionRecord, SessionStore, SkillEventRepository};
|
||
use crate::tools::ToolRegistry;
|
||
use async_trait::async_trait;
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use tokio::sync::{Mutex, mpsc};
|
||
use uuid::Uuid;
|
||
|
||
use super::agent_factory::{AgentBuildRequest, AgentFactory};
|
||
use super::cli_session::CliSessionService;
|
||
use super::execution::should_display_message_to_user;
|
||
#[cfg(test)]
|
||
use super::memory_maintenance::{
|
||
MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
|
||
combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error,
|
||
strip_json_code_fence,
|
||
};
|
||
use super::memory_maintenance::{MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult};
|
||
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
|
||
use super::prompt_injector::PromptInjector;
|
||
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
|
||
use super::session_history::SessionHistory;
|
||
use super::session_lifecycle::SessionLifecycleService;
|
||
use super::session_message_service::SessionMessageService;
|
||
|
||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||
/// History 按 chat_id 隔离,由 Session 统一管理
|
||
pub struct Session {
|
||
pub id: Uuid,
|
||
pub channel_name: String,
|
||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||
provider_config: LLMProviderConfig,
|
||
skills: Arc<SkillRuntime>,
|
||
agent_factory: AgentFactory,
|
||
compressor: ContextCompressor,
|
||
history: SessionHistory,
|
||
}
|
||
|
||
pub struct BusToolCallEmitter {
|
||
bus: Arc<MessageBus>,
|
||
channel_name: String,
|
||
chat_id: String,
|
||
metadata: HashMap<String, String>,
|
||
show_tool_results: bool,
|
||
}
|
||
|
||
impl BusToolCallEmitter {
|
||
pub fn new(
|
||
bus: Arc<MessageBus>,
|
||
channel_name: impl Into<String>,
|
||
chat_id: impl Into<String>,
|
||
metadata: HashMap<String, String>,
|
||
show_tool_results: bool,
|
||
) -> Self {
|
||
Self {
|
||
bus,
|
||
channel_name: channel_name.into(),
|
||
chat_id: chat_id.into(),
|
||
metadata,
|
||
show_tool_results,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl EmittedMessageHandler for BusToolCallEmitter {
|
||
async fn handle(&self, message: ChatMessage) {
|
||
if !should_display_message_to_user(self.show_tool_results, &message) {
|
||
return;
|
||
}
|
||
|
||
for outbound in OutboundMessage::from_chat_message(
|
||
&self.channel_name,
|
||
&self.chat_id,
|
||
None,
|
||
&self.metadata,
|
||
&message,
|
||
) {
|
||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Session {
|
||
pub async fn new(
|
||
channel_name: String,
|
||
provider_config: LLMProviderConfig,
|
||
user_tx: mpsc::Sender<WsOutbound>,
|
||
tools: Arc<ToolRegistry>,
|
||
skills: Arc<SkillRuntime>,
|
||
store: Arc<SessionStore>,
|
||
agent_prompt_reinject_every: u64,
|
||
) -> Result<Self, AgentError> {
|
||
let agent_factory = AgentFactory::new(tools, skills.clone());
|
||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
||
let prompt_injector = PromptInjector::new(store.clone(), agent_prompt_reinject_every);
|
||
Self::with_factories(
|
||
channel_name,
|
||
provider_config,
|
||
user_tx,
|
||
skills,
|
||
agent_factory,
|
||
prompt_injector,
|
||
conversations,
|
||
skill_events,
|
||
)
|
||
.await
|
||
}
|
||
|
||
pub(crate) async fn with_factories(
|
||
channel_name: String,
|
||
provider_config: LLMProviderConfig,
|
||
user_tx: mpsc::Sender<WsOutbound>,
|
||
skills: Arc<SkillRuntime>,
|
||
agent_factory: AgentFactory,
|
||
prompt_injector: PromptInjector,
|
||
conversations: Arc<dyn ConversationRepository>,
|
||
skill_events: Arc<dyn SkillEventRepository>,
|
||
) -> Result<Self, AgentError> {
|
||
Ok(Self {
|
||
id: Uuid::new_v4(),
|
||
channel_name: channel_name.clone(),
|
||
user_tx,
|
||
provider_config: provider_config.clone(),
|
||
skills,
|
||
agent_factory,
|
||
compressor: ContextCompressor::from_provider_config(&provider_config),
|
||
history: SessionHistory::new(
|
||
channel_name,
|
||
prompt_injector,
|
||
conversations,
|
||
skill_events,
|
||
),
|
||
})
|
||
}
|
||
|
||
pub fn persistent_session_id(&self, chat_id: &str) -> String {
|
||
self.history.persistent_session_id(chat_id)
|
||
}
|
||
|
||
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
|
||
self.history.ensure_persistent_session(chat_id)
|
||
}
|
||
|
||
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
self.history.ensure_chat_loaded(chat_id)
|
||
}
|
||
|
||
pub fn ensure_agent_prompt_before_user_message(
|
||
&mut self,
|
||
chat_id: &str,
|
||
) -> Result<(), AgentError> {
|
||
self.history
|
||
.ensure_agent_prompt_before_user_message(chat_id)
|
||
}
|
||
|
||
/// 获取或创建指定 chat_id 的会话历史
|
||
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
|
||
self.history.get_or_create_history(chat_id)
|
||
}
|
||
|
||
/// 获取指定 chat_id 的会话历史(不创建)
|
||
pub fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
|
||
self.history.get_history(chat_id)
|
||
}
|
||
|
||
/// 使用完整消息追加到历史
|
||
pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||
self.history.add_message(chat_id, message);
|
||
}
|
||
|
||
pub fn remove_history(&mut self, chat_id: &str) {
|
||
self.history.remove_history(chat_id);
|
||
}
|
||
|
||
pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
self.history.clear_chat_history(chat_id)
|
||
}
|
||
|
||
pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
self.history.reset_chat_context(chat_id)
|
||
}
|
||
|
||
/// 将消息写入内存与持久化层
|
||
pub fn append_persisted_message(
|
||
&mut self,
|
||
chat_id: &str,
|
||
message: ChatMessage,
|
||
) -> Result<(), AgentError> {
|
||
self.history.append_persisted_message(chat_id, message)
|
||
}
|
||
|
||
pub fn append_persisted_messages<I>(
|
||
&mut self,
|
||
chat_id: &str,
|
||
messages: I,
|
||
) -> Result<(), AgentError>
|
||
where
|
||
I: IntoIterator<Item = ChatMessage>,
|
||
{
|
||
self.history.append_persisted_messages(chat_id, messages)
|
||
}
|
||
|
||
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
||
if media_refs.is_empty() {
|
||
ChatMessage::user(content)
|
||
} else {
|
||
ChatMessage::user_with_media(content, media_refs)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
|
||
self.latest_user_message(chat_id)
|
||
.map(|message| message.id.as_str())
|
||
}
|
||
|
||
#[cfg(test)]
|
||
fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
|
||
self.history.latest_user_message(chat_id)
|
||
}
|
||
|
||
#[cfg(test)]
|
||
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
|
||
self.latest_user_message_id(chat_id)
|
||
.map(|current_id| current_id == message_id)
|
||
.unwrap_or(false)
|
||
}
|
||
|
||
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
|
||
self.history.matches_current_user_turn(chat_id, message)
|
||
}
|
||
|
||
pub(crate) fn stale_result_diagnostics(
|
||
&self,
|
||
chat_id: &str,
|
||
) -> (Option<&str>, Option<String>, bool, usize) {
|
||
self.history.stale_result_diagnostics(chat_id)
|
||
}
|
||
|
||
/// 清除所有历史
|
||
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
||
self.history.clear_all_history()
|
||
}
|
||
|
||
pub async fn send(&self, msg: WsOutbound) {
|
||
let _ = self.user_tx.send(msg).await;
|
||
}
|
||
|
||
/// 获取 provider_config 引用
|
||
pub fn provider_config(&self) -> &LLMProviderConfig {
|
||
&self.provider_config
|
||
}
|
||
|
||
/// 获取 compressor 引用
|
||
pub fn compressor(&self) -> &ContextCompressor {
|
||
&self.compressor
|
||
}
|
||
|
||
pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
||
self.history.try_start_background_compaction(chat_id)
|
||
}
|
||
|
||
pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
|
||
self.history.finish_background_compaction(chat_id);
|
||
}
|
||
|
||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
self.history.reload_chat_history(chat_id)
|
||
}
|
||
|
||
pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
|
||
self.history.conversations()
|
||
}
|
||
|
||
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
|
||
if self.skills.is_empty() {
|
||
return Ok(());
|
||
}
|
||
|
||
self.history.append_skill_event(
|
||
chat_id,
|
||
"offered",
|
||
None,
|
||
&self.skills.offered_event_payload(),
|
||
)
|
||
}
|
||
|
||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||
pub fn create_agent(
|
||
&self,
|
||
chat_id: &str,
|
||
sender_id: Option<&str>,
|
||
message_id: Option<&str>,
|
||
) -> Result<AgentLoop, AgentError> {
|
||
self.create_agent_with_provider_config(
|
||
chat_id,
|
||
sender_id,
|
||
message_id,
|
||
self.provider_config.clone(),
|
||
)
|
||
}
|
||
|
||
pub fn create_agent_with_provider_config(
|
||
&self,
|
||
chat_id: &str,
|
||
sender_id: Option<&str>,
|
||
message_id: Option<&str>,
|
||
provider_config: LLMProviderConfig,
|
||
) -> Result<AgentLoop, AgentError> {
|
||
self.agent_factory.create(AgentBuildRequest {
|
||
channel_name: &self.channel_name,
|
||
chat_id,
|
||
sender_id,
|
||
message_id,
|
||
provider_config,
|
||
})
|
||
}
|
||
}
|
||
|
||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||
#[derive(Clone)]
|
||
pub struct SessionManager {
|
||
tools: Arc<ToolRegistry>,
|
||
skills: Arc<SkillRuntime>,
|
||
store: Arc<SessionStore>,
|
||
show_tool_results: bool,
|
||
lifecycle: SessionLifecycleService,
|
||
cli_sessions: CliSessionService,
|
||
messages: SessionMessageService,
|
||
scheduled_tasks: ScheduledAgentTaskService,
|
||
memory_maintenance: MemoryMaintenanceCoordinator,
|
||
}
|
||
|
||
pub(crate) struct SessionManagerServices {
|
||
pub(crate) tools: Arc<ToolRegistry>,
|
||
pub(crate) skills: Arc<SkillRuntime>,
|
||
pub(crate) store: Arc<SessionStore>,
|
||
pub(crate) show_tool_results: bool,
|
||
pub(crate) lifecycle: SessionLifecycleService,
|
||
pub(crate) cli_sessions: CliSessionService,
|
||
pub(crate) messages: SessionMessageService,
|
||
pub(crate) scheduled_tasks: ScheduledAgentTaskService,
|
||
pub(crate) memory_maintenance: MemoryMaintenanceCoordinator,
|
||
}
|
||
|
||
impl SessionManager {
|
||
pub(crate) fn from_services(services: SessionManagerServices) -> Self {
|
||
Self {
|
||
tools: services.tools,
|
||
skills: services.skills,
|
||
store: services.store,
|
||
show_tool_results: services.show_tool_results,
|
||
lifecycle: services.lifecycle,
|
||
cli_sessions: services.cli_sessions,
|
||
messages: services.messages,
|
||
scheduled_tasks: services.scheduled_tasks,
|
||
memory_maintenance: services.memory_maintenance,
|
||
}
|
||
}
|
||
|
||
pub fn new(
|
||
session_ttl_hours: u64,
|
||
agent_prompt_reinject_every: u64,
|
||
show_tool_results: bool,
|
||
default_timezone: String,
|
||
provider_config: LLMProviderConfig,
|
||
provider_configs: HashMap<String, LLMProviderConfig>,
|
||
skills: Arc<SkillRuntime>,
|
||
) -> Result<Self, AgentError> {
|
||
super::runtime::build_session_manager(
|
||
session_ttl_hours,
|
||
agent_prompt_reinject_every,
|
||
show_tool_results,
|
||
default_timezone,
|
||
provider_config,
|
||
provider_configs,
|
||
skills,
|
||
)
|
||
}
|
||
|
||
pub fn tools(&self) -> Arc<ToolRegistry> {
|
||
self.tools.clone()
|
||
}
|
||
|
||
pub fn store(&self) -> Arc<SessionStore> {
|
||
self.store.clone()
|
||
}
|
||
|
||
pub fn show_tool_results(&self) -> bool {
|
||
self.show_tool_results
|
||
}
|
||
|
||
pub fn skills(&self) -> Arc<SkillRuntime> {
|
||
self.skills.clone()
|
||
}
|
||
|
||
pub(crate) fn cli_sessions(&self) -> CliSessionService {
|
||
self.cli_sessions.clone()
|
||
}
|
||
|
||
#[cfg_attr(not(test), allow(dead_code))]
|
||
pub(crate) async fn summarize_memory_maintenance_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||
self.memory_maintenance.summarize_for_scope(scope_key).await
|
||
}
|
||
|
||
pub(crate) async fn run_memory_maintenance_for_all_scopes(
|
||
&self,
|
||
updated_since: Option<i64>,
|
||
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||
self.memory_maintenance
|
||
.run_for_all_scopes(updated_since)
|
||
.await
|
||
}
|
||
|
||
/// 确保 session 存在且未超时,超时则重建
|
||
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||
self.lifecycle.ensure_session(channel_name).await
|
||
}
|
||
|
||
/// 获取 session(不检查超时)
|
||
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
|
||
self.lifecycle.get(channel_name).await
|
||
}
|
||
|
||
/// 更新最后活跃时间
|
||
pub async fn touch(&self, channel_name: &str) {
|
||
self.lifecycle.touch(channel_name).await;
|
||
}
|
||
|
||
pub async fn cleanup_expired_sessions(&self) -> usize {
|
||
self.lifecycle.cleanup_expired_sessions().await
|
||
}
|
||
|
||
/// 处理消息:路由到对应 session 的 agent
|
||
pub async fn handle_message(
|
||
&self,
|
||
channel_name: &str,
|
||
sender_id: &str,
|
||
chat_id: &str,
|
||
content: &str,
|
||
media: Vec<crate::bus::MediaItem>,
|
||
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
|
||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||
self.messages
|
||
.handle_message(
|
||
channel_name,
|
||
sender_id,
|
||
chat_id,
|
||
content,
|
||
media,
|
||
live_emitter,
|
||
)
|
||
.await
|
||
}
|
||
|
||
pub async fn run_scheduled_agent_task(
|
||
&self,
|
||
channel_name: &str,
|
||
chat_id: &str,
|
||
prompt: &str,
|
||
options: ScheduledAgentTaskOptions,
|
||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||
self.scheduled_tasks
|
||
.run(channel_name, chat_id, prompt, options)
|
||
.await
|
||
}
|
||
|
||
/// 清除指定 session 的所有历史
|
||
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
|
||
if let Some(session) = self.get(channel_name).await {
|
||
let mut session_guard = session.lock().await;
|
||
session_guard.clear_all_history()?;
|
||
}
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
use crate::bus::MessageBus;
|
||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||
use crate::storage::MemoryRecord;
|
||
use crate::tools::NoopSessionMessageSender;
|
||
use axum::http::StatusCode;
|
||
use axum::{Json, Router, routing::post};
|
||
use serde_json::{Value, json};
|
||
use std::collections::{HashMap, HashSet};
|
||
use std::sync::{
|
||
Arc as StdArc,
|
||
atomic::{AtomicUsize, Ordering},
|
||
};
|
||
use tokio::net::TcpListener;
|
||
use tokio::sync::mpsc;
|
||
|
||
fn test_provider_config() -> LLMProviderConfig {
|
||
LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "test".to_string(),
|
||
base_url: "http://localhost".to_string(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
llm_timeout_secs: 120,
|
||
model_id: "test-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_latest_user_message_guard_tracks_current_turn() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"feishu".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store,
|
||
100,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
let first = session.create_user_message("first", Vec::new());
|
||
let first_id = first.id.clone();
|
||
session.append_persisted_message("chat-1", first).unwrap();
|
||
assert!(session.is_latest_user_message("chat-1", &first_id));
|
||
|
||
let second = session.create_user_message("second", Vec::new());
|
||
let second_id = second.id.clone();
|
||
session.append_persisted_message("chat-1", second).unwrap();
|
||
|
||
assert!(!session.is_latest_user_message("chat-1", &first_id));
|
||
assert!(session.is_latest_user_message("chat-1", &second_id));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_current_user_turn_match_survives_history_compaction_reload() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"feishu".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store.clone(),
|
||
100,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
let first = session.create_user_message("first", Vec::new());
|
||
let first_id = first.id.clone();
|
||
session.append_persisted_message("chat-1", first).unwrap();
|
||
session
|
||
.append_persisted_message("chat-1", ChatMessage::assistant("answer-1"))
|
||
.unwrap();
|
||
|
||
let second = session.create_user_message("second", Vec::new());
|
||
session
|
||
.append_persisted_message("chat-1", second.clone())
|
||
.unwrap();
|
||
session
|
||
.append_persisted_message("chat-1", ChatMessage::assistant("answer-2"))
|
||
.unwrap();
|
||
|
||
let session_id = session.persistent_session_id("chat-1");
|
||
let snapshot_end_seq = store
|
||
.get_session(&session_id)
|
||
.unwrap()
|
||
.unwrap()
|
||
.message_count;
|
||
let preserved_messages = session.get_history("chat-1").unwrap().clone();
|
||
|
||
store
|
||
.compact_active_history(
|
||
&session_id,
|
||
0,
|
||
snapshot_end_seq,
|
||
&[],
|
||
&ChatMessage::system("[Compressed History]\n\nsummary"),
|
||
&preserved_messages,
|
||
)
|
||
.unwrap();
|
||
|
||
session.reload_chat_history("chat-1").unwrap();
|
||
|
||
assert!(!session.is_latest_user_message("chat-1", &first_id));
|
||
assert!(!session.is_latest_user_message("chat-1", &second.id));
|
||
assert!(session.matches_current_user_turn("chat-1", &second));
|
||
}
|
||
|
||
async fn start_mock_openai_server() -> String {
|
||
start_mock_openai_server_with_content(None).await
|
||
}
|
||
|
||
async fn start_mock_openai_server_with_content(
|
||
mock_response_content: Option<String>,
|
||
) -> String {
|
||
async fn handle(
|
||
axum::extract::State(mock_response_content): axum::extract::State<Option<String>>,
|
||
Json(body): Json<Value>,
|
||
) -> Json<Value> {
|
||
let model = body
|
||
.get("model")
|
||
.and_then(|value| value.as_str())
|
||
.unwrap_or("unknown-model");
|
||
let content = mock_response_content.unwrap_or_else(|| format!("reply from {}", model));
|
||
|
||
Json(json!({
|
||
"id": "mock-response",
|
||
"model": model,
|
||
"choices": [
|
||
{
|
||
"message": {
|
||
"content": content,
|
||
"tool_calls": []
|
||
}
|
||
}
|
||
],
|
||
"usage": {
|
||
"prompt_tokens": 1,
|
||
"completion_tokens": 1,
|
||
"total_tokens": 2
|
||
}
|
||
}))
|
||
}
|
||
|
||
let app = Router::new()
|
||
.route("/chat/completions", post(handle))
|
||
.with_state(mock_response_content);
|
||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
let address = listener.local_addr().unwrap();
|
||
tokio::spawn(async move {
|
||
axum::serve(listener, app).await.unwrap();
|
||
});
|
||
format!("http://{}", address)
|
||
}
|
||
|
||
async fn start_mock_openai_504_server() -> String {
|
||
async fn handle() -> (StatusCode, &'static str) {
|
||
(StatusCode::GATEWAY_TIMEOUT, "stream timeout")
|
||
}
|
||
|
||
let app = Router::new().route("/chat/completions", post(handle));
|
||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
let address = listener.local_addr().unwrap();
|
||
tokio::spawn(async move {
|
||
axum::serve(listener, app).await.unwrap();
|
||
});
|
||
format!("http://{}", address)
|
||
}
|
||
|
||
async fn start_mock_openai_flaky_server(mock_response_content: String) -> String {
|
||
let attempts = StdArc::new(AtomicUsize::new(0));
|
||
let state = (attempts, mock_response_content);
|
||
|
||
async fn handle(
|
||
axum::extract::State((attempts, mock_response_content)): axum::extract::State<(
|
||
StdArc<AtomicUsize>,
|
||
String,
|
||
)>,
|
||
Json(body): Json<Value>,
|
||
) -> (StatusCode, Json<Value>) {
|
||
let attempt = attempts.fetch_add(1, Ordering::SeqCst);
|
||
if attempt == 0 {
|
||
return (
|
||
StatusCode::GATEWAY_TIMEOUT,
|
||
Json(json!({"error": "stream timeout"})),
|
||
);
|
||
}
|
||
|
||
let model = body
|
||
.get("model")
|
||
.and_then(|value| value.as_str())
|
||
.unwrap_or("unknown-model");
|
||
(
|
||
StatusCode::OK,
|
||
Json(json!({
|
||
"id": "mock-response",
|
||
"model": model,
|
||
"choices": [
|
||
{
|
||
"message": {
|
||
"content": mock_response_content,
|
||
"tool_calls": []
|
||
}
|
||
}
|
||
],
|
||
"usage": {
|
||
"prompt_tokens": 1,
|
||
"completion_tokens": 1,
|
||
"total_tokens": 2
|
||
}
|
||
})),
|
||
)
|
||
}
|
||
|
||
let app = Router::new()
|
||
.route("/chat/completions", post(handle))
|
||
.with_state(state);
|
||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
let address = listener.local_addr().unwrap();
|
||
tokio::spawn(async move {
|
||
axum::serve(listener, app).await.unwrap();
|
||
});
|
||
format!("http://{}", address)
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_handle_message_returns_recoverable_reply_on_llm_504() {
|
||
let base_url = start_mock_openai_504_server().await;
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "timeout-provider".to_string(),
|
||
base_url: base_url.clone(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "timeout-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
let outbound = session_manager
|
||
.handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None)
|
||
.await
|
||
.unwrap();
|
||
|
||
assert_eq!(outbound.len(), 1);
|
||
assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() {
|
||
let base_url = start_mock_openai_server().await;
|
||
let default_provider = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "default-provider".to_string(),
|
||
base_url: base_url.clone(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "default-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
let planner_provider = LLMProviderConfig {
|
||
model_id: "planner-model".to_string(),
|
||
name: "planner-provider".to_string(),
|
||
..default_provider.clone()
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
default_provider.clone(),
|
||
HashMap::from([
|
||
("default".to_string(), default_provider),
|
||
("planner".to_string(), planner_provider),
|
||
]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
let planner_outbound = session_manager
|
||
.run_scheduled_agent_task(
|
||
"feishu",
|
||
"chat-planner",
|
||
"请规划今天工作",
|
||
ScheduledAgentTaskOptions {
|
||
agent: Some("planner".to_string()),
|
||
fresh_session: true,
|
||
..Default::default()
|
||
},
|
||
)
|
||
.await
|
||
.unwrap();
|
||
assert_eq!(planner_outbound.len(), 1);
|
||
assert!(planner_outbound[0].content.contains("planner-model"));
|
||
|
||
let default_outbound = session_manager
|
||
.run_scheduled_agent_task(
|
||
"feishu",
|
||
"chat-default",
|
||
"请规划今天工作",
|
||
ScheduledAgentTaskOptions {
|
||
agent: Some("default".to_string()),
|
||
fresh_session: true,
|
||
..Default::default()
|
||
},
|
||
)
|
||
.await
|
||
.unwrap();
|
||
assert_eq!(default_outbound.len(), 1);
|
||
assert!(default_outbound[0].content.contains("default-model"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_run_scheduled_agent_task_persists_execution_guard_prompt() {
|
||
let base_url = start_mock_openai_server().await;
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "default-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "default-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.run_scheduled_agent_task(
|
||
"feishu",
|
||
"chat-guard",
|
||
"每小时执行以下流程:检查邮箱并同步待办",
|
||
ScheduledAgentTaskOptions {
|
||
fresh_session: true,
|
||
system_prompt: Some("你是邮箱待办同步助手。".to_string()),
|
||
..Default::default()
|
||
},
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
let session = session_manager.get("feishu").await.unwrap();
|
||
let session_guard = session.lock().await;
|
||
let persisted_messages = session_guard
|
||
.store()
|
||
.load_messages(&session_guard.persistent_session_id("chat-guard"))
|
||
.unwrap();
|
||
|
||
let scheduled_prompt = persisted_messages
|
||
.iter()
|
||
.find(|message| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT))
|
||
.expect("missing scheduled system prompt");
|
||
|
||
assert!(scheduled_prompt.content.contains("已经触发的定时任务执行"));
|
||
assert!(
|
||
scheduled_prompt
|
||
.content
|
||
.contains("不要调用任何定时任务管理工具")
|
||
);
|
||
assert!(scheduled_prompt.content.contains("你是邮箱待办同步助手。"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_summarize_memory_maintenance_for_scope_uses_model_output() {
|
||
let mock_response_content = serde_json::to_string(&json!({
|
||
"user_facts": ["用户在做AI产品"],
|
||
"preferences": ["偏好简洁表达"],
|
||
"behavior_patterns": ["习惯先问方案再要代码"],
|
||
"merges": [],
|
||
"conflicts": [],
|
||
"low_value_ids": [],
|
||
"managed_markdown": "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达\n\n### 行为模式\n- 习惯先问方案再要代码"
|
||
}))
|
||
.unwrap();
|
||
let base_url =
|
||
start_mock_openai_server_with_content(Some(mock_response_content.clone())).await;
|
||
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::from([(
|
||
"mock_response_content".to_string(),
|
||
json!(mock_response_content),
|
||
)]),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "profile".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let output = session_manager
|
||
.summarize_memory_maintenance_for_scope("feishu:user-1")
|
||
.await
|
||
.unwrap()
|
||
.unwrap();
|
||
|
||
assert_eq!(output.user_facts, vec!["用户在做AI产品".to_string()]);
|
||
assert_eq!(output.preferences, vec!["偏好简洁表达".to_string()]);
|
||
assert_eq!(
|
||
output.behavior_patterns,
|
||
vec!["习惯先问方案再要代码".to_string()]
|
||
);
|
||
assert!(output.managed_markdown.contains("### 用户事实"));
|
||
}
|
||
|
||
/// Request-send failures and 504 responses count as recoverable; a 401 does not.
#[test]
fn test_is_recoverable_maintenance_llm_error_detects_transport_failures() {
    let recoverable = [
        "error sending request for url (https://example.invalid/v1/chat/completions)",
        "API error 504 Gateway Timeout: stream timeout",
    ];
    for message in recoverable {
        assert!(is_recoverable_maintenance_llm_error(message));
    }

    let unauthorized = "API error 401 Unauthorized";
    assert!(!is_recoverable_maintenance_llm_error(unauthorized));
}
|
||
|
||
/// A JSON object wrapped in explanatory text and a fenced `json` block is extracted verbatim
/// after passing through `strip_json_code_fence` and `extract_json_object`.
#[test]
fn test_extract_json_object_skips_wrapping_text() {
    let wrapped = "下面是结果:\n```json\n{\n \"user_facts\": [],\n \"preferences\": []\n}\n```\n请查收";
    let expected = "{\n \"user_facts\": [],\n \"preferences\": []\n}";

    let without_fence = strip_json_code_fence(wrapped);
    let extracted = extract_json_object(without_fence).unwrap();

    assert_eq!(extracted, expected);
}
|
||
|
||
#[tokio::test]
|
||
async fn test_summarize_memory_maintenance_transport_error_includes_provider_context() {
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url: "https://example.invalid/v1".to_string(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 1,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "profile".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let error = session_manager
|
||
.summarize_memory_maintenance_for_scope("feishu:user-1")
|
||
.await
|
||
.unwrap_err()
|
||
.to_string();
|
||
|
||
assert!(error.contains("memory maintenance model error: transport error:"));
|
||
assert!(error.contains("provider=maintenance-provider"));
|
||
assert!(error.contains("model=maintenance-model"));
|
||
assert!(error.contains("url=https://example.invalid/v1/chat/completions"));
|
||
assert!(error.contains("timeout_secs=1"));
|
||
assert!(error.contains("error sending request for url"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_summarize_memory_maintenance_retries_recoverable_provider_errors() {
|
||
let mock_response_content = serde_json::to_string(&json!({
|
||
"user_facts": ["用户在做AI产品"],
|
||
"preferences": [],
|
||
"behavior_patterns": [],
|
||
"merges": [],
|
||
"conflicts": [],
|
||
"low_value_ids": [],
|
||
"managed_markdown": "### 用户事实\n- 用户在做AI产品"
|
||
}))
|
||
.unwrap();
|
||
let base_url = start_mock_openai_flaky_server(mock_response_content.clone()).await;
|
||
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::from([(
|
||
"mock_response_content".to_string(),
|
||
json!(mock_response_content),
|
||
)]),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "profile".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let output = session_manager
|
||
.summarize_memory_maintenance_for_scope("feishu:user-1")
|
||
.await
|
||
.unwrap()
|
||
.unwrap();
|
||
|
||
assert_eq!(output.user_facts, vec!["用户在做AI产品".to_string()]);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_summarize_memory_maintenance_for_scope_extracts_wrapped_json_object() {
|
||
let mock_response_content = "结果如下:\n```json\n{\n \"user_facts\": [\"用户在做AI产品\"],\n \"preferences\": [],\n \"behavior_patterns\": [],\n \"merges\": [],\n \"conflicts\": [],\n \"low_value_ids\": [],\n \"managed_markdown\": \"### 用户事实\\n- 用户在做AI产品\"\n}\n```\n";
|
||
let base_url =
|
||
start_mock_openai_server_with_content(Some(mock_response_content.to_string())).await;
|
||
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::from([(
|
||
"mock_response_content".to_string(),
|
||
json!(mock_response_content),
|
||
)]),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "profile".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let output = session_manager
|
||
.summarize_memory_maintenance_for_scope("feishu:user-1")
|
||
.await
|
||
.unwrap()
|
||
.unwrap();
|
||
|
||
assert_eq!(output.user_facts, vec!["用户在做AI产品".to_string()]);
|
||
assert!(output.managed_markdown.contains("### 用户事实"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_run_memory_maintenance_for_all_scopes_returns_empty_when_no_recent_updates() {
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url: "http://localhost".to_string(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
let memory = session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "profile".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let results = session_manager
|
||
.run_memory_maintenance_for_all_scopes(Some(memory.updated_at + 1))
|
||
.await
|
||
.unwrap();
|
||
|
||
assert!(results.is_empty());
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_run_memory_maintenance_for_all_scopes_skips_recoverable_transport_failures() {
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url: "https://example.invalid/v1".to_string(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 1,
|
||
tool_result_max_chars: 20_000,
|
||
context_tool_result_trim_chars: 20_000,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
4,
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
)
|
||
.unwrap();
|
||
|
||
for scope_key in ["feishu:user-1", "feishu:user-2"] {
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: scope_key.to_string(),
|
||
namespace: "profile".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: format!("{} 在做AI产品", scope_key),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
}
|
||
|
||
let results = session_manager
|
||
.run_memory_maintenance_for_all_scopes(None)
|
||
.await
|
||
.unwrap();
|
||
|
||
assert!(results.is_empty());
|
||
}
|
||
|
||
/// Applying a maintenance output that merges two profile records and marks a
/// third as low value must leave exactly one record: the merged one.
#[test]
fn test_apply_memory_maintenance_output_merges_and_deletes_low_value_records() {
    let store = SessionStore::in_memory().unwrap();
    let scope_key = "feishu:user-1";

    // Inserts one user-scoped memory; only namespace, key and content vary.
    let seed = |namespace: &str, memory_key: &str, content: &str| {
        store
            .put_memory(&crate::storage::MemoryUpsert {
                scope_kind: "user".to_string(),
                scope_key: scope_key.to_string(),
                namespace: namespace.to_string(),
                memory_key: memory_key.to_string(),
                content: content.to_string(),
                source_type: "message".to_string(),
                source_session_id: None,
                source_message_id: None,
                source_message_seq: None,
                source_channel_name: None,
                source_chat_id: None,
            })
            .unwrap()
    };

    let work = seed("profile", "work_short", "用户在做AI产品");
    let role = seed("profile", "work_detail", "用户主要在做AI产品设计和实现");
    let noise = seed("notes", "temporary", "今天临时提到过一个无后续的小细节");

    let existing = store.list_memories_for_scope("user", scope_key).unwrap();
    let plan = build_memory_maintenance_plan(&existing);

    let merge = MemoryMaintenanceMerge {
        source_ids: vec![work.id.clone(), role.id.clone()],
        namespace: "profile".to_string(),
        memory_key: "work".to_string(),
        content: "用户主要在做AI产品设计与实现".to_string(),
    };
    let output = MemoryMaintenanceModelOutput {
        user_facts: vec!["用户在做AI产品".to_string()],
        preferences: Vec::new(),
        behavior_patterns: Vec::new(),
        merges: vec![merge],
        conflicts: Vec::new(),
        low_value_ids: vec![noise.id.clone()],
        managed_markdown: "### 用户事实\n- 用户在做AI产品".to_string(),
    };

    apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap();

    let remaining = store.list_memories_for_scope("user", scope_key).unwrap();
    assert_eq!(remaining.len(), 1);

    let survivor = &remaining[0];
    assert_eq!(survivor.namespace, "profile");
    assert_eq!(survivor.memory_key, "work");
    assert_eq!(survivor.content, "用户主要在做AI产品设计与实现");
}
|
||
|
||
/// Combining three markdown summaries must keep the rich summary and the
/// unrelated one intact, while the bullet that is already contained in the
/// rich summary appears only once in the result.
#[test]
fn test_combine_managed_memory_markdown_prefers_richer_summary_over_subset() {
    let rich_summary = "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达";
    let subset_bullet = "- 用户在做AI产品";
    let unrelated_summary = "### 用户事实\n- 用户名为区德成,昵称DC。";

    let combined = combine_managed_memory_markdown(&[
        rich_summary.to_string(),
        subset_bullet.to_string(),
        unrelated_summary.to_string(),
    ]);

    assert!(combined.contains(rich_summary));
    assert!(combined.contains(unrelated_summary));
    assert_eq!(combined.matches(subset_bullet).count(), 1);
}
|
||
|
||
/// A completed tool result is hidden when the flag is `false` and shown when
/// it is `true`; a tool message pending user action is shown even when the
/// flag is `false`.
#[test]
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
    let finished = ChatMessage::tool("call-1", "calculator", "2");
    let awaiting_user = ChatMessage::tool_with_state(
        "call-2",
        "bash",
        "waiting",
        crate::bus::message::ToolMessageState::PendingUserAction,
    );

    // (flag, message, expected visibility)
    let cases = [
        (false, &finished, false),
        (false, &awaiting_user, true),
        (true, &finished, true),
    ];
    for (flag, message, expected) in cases {
        assert_eq!(should_display_message_to_user(flag, message), expected);
    }
}
|
||
|
||
#[tokio::test]
|
||
async fn test_bus_tool_call_emitter_hides_completed_tool_results_when_disabled() {
|
||
let bus = MessageBus::new(4);
|
||
let emitter =
|
||
BusToolCallEmitter::new(bus.clone(), "feishu", "chat-1", HashMap::new(), false);
|
||
|
||
emitter
|
||
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
||
.await;
|
||
|
||
assert!(
|
||
tokio::time::timeout(std::time::Duration::from_millis(50), bus.consume_outbound())
|
||
.await
|
||
.is_err()
|
||
);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"feishu".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store.clone(),
|
||
100,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
let history = session.get_history("chat-1").unwrap();
|
||
assert_eq!(history.len(), 1);
|
||
assert_eq!(history[0].role, "system");
|
||
assert!(history[0].content.contains("PicoBot 代理配置"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"feishu".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store.clone(),
|
||
100,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
for turn in 0..100 {
|
||
session
|
||
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
|
||
.unwrap();
|
||
}
|
||
|
||
session
|
||
.ensure_agent_prompt_before_user_message("chat-1")
|
||
.unwrap();
|
||
|
||
let history = session.get_history("chat-1").unwrap();
|
||
let system_messages = history
|
||
.iter()
|
||
.filter(|message| message.role == "system")
|
||
.count();
|
||
assert_eq!(system_messages, 2);
|
||
|
||
let stored = store
|
||
.get_session(&session.persistent_session_id("chat-1"))
|
||
.unwrap()
|
||
.unwrap();
|
||
assert_eq!(stored.agent_prompt_reinjection_count, 1);
|
||
|
||
session
|
||
.ensure_agent_prompt_before_user_message("chat-1")
|
||
.unwrap();
|
||
let history = session.get_history("chat-1").unwrap();
|
||
let system_messages = history
|
||
.iter()
|
||
.filter(|message| message.role == "system")
|
||
.count();
|
||
assert_eq!(system_messages, 2);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_agent_prompt_reinjection_can_be_disabled_by_config() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"feishu".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store.clone(),
|
||
0,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
for turn in 0..100 {
|
||
session
|
||
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
|
||
.unwrap();
|
||
}
|
||
|
||
session
|
||
.ensure_agent_prompt_before_user_message("chat-1")
|
||
.unwrap();
|
||
|
||
let history = session.get_history("chat-1").unwrap();
|
||
let system_messages = history
|
||
.iter()
|
||
.filter(|message| message.role == "system")
|
||
.count();
|
||
assert_eq!(system_messages, 1);
|
||
}
|
||
|
||
/// The registry built by `ToolRegistryFactory` must expose a tool named
/// `get_time`.
#[test]
fn test_default_tools_registers_get_time() {
    let skills = Arc::new(SkillRuntime::default());
    let store = Arc::new(SessionStore::in_memory().unwrap());

    let factory = ToolRegistryFactory::new(
        skills,
        store.clone(),
        store.clone(),
        store,
        Arc::new(NoopSessionMessageSender),
        HashSet::new(),
        "Asia/Shanghai".to_string(),
    );
    let tools = factory.build();

    let has_get_time = tools.tool_names().iter().any(|name| name == "get_time");
    assert!(has_get_time);
}
|
||
|
||
/// Four records (two identical `profile` entries, one `preferences`, one
/// `patterns`) must yield a plan with one user fact, one preference, one
/// behavior pattern and nothing in `others`.
#[test]
fn test_build_memory_maintenance_plan_deduplicates_and_categorizes() {
    // (id, namespace, memory_key, content, timestamp); the timestamp is used
    // for both `created_at` and `updated_at`.
    let rows = [
        ("1", "profile", "work", "用户在做AI产品", 1),
        ("2", "profile", "work", "用户在做AI产品", 2),
        ("3", "preferences", "style", "偏好简洁表达", 3),
        ("4", "patterns", "workflow", "习惯先问方案再要代码", 4),
    ];
    let memories: Vec<MemoryRecord> = rows
        .into_iter()
        .map(|(id, namespace, memory_key, content, timestamp)| MemoryRecord {
            id: id.to_string(),
            scope_kind: "user".to_string(),
            scope_key: "feishu:user-1".to_string(),
            namespace: namespace.to_string(),
            memory_key: memory_key.to_string(),
            content: content.to_string(),
            source_type: "message".to_string(),
            source_session_id: None,
            source_message_id: None,
            source_message_seq: None,
            source_channel_name: None,
            source_chat_id: None,
            created_at: timestamp,
            updated_at: timestamp,
        })
        .collect();

    let plan = build_memory_maintenance_plan(&memories);

    assert_eq!(plan.user_facts.len(), 1);
    assert_eq!(plan.preferences.len(), 1);
    assert_eq!(plan.behavior_patterns.len(), 1);
    assert!(plan.others.is_empty());

    assert_eq!(plan.user_facts[0].content, "用户在做AI产品");
    assert_eq!(plan.preferences[0].content, "偏好简洁表达");
    assert_eq!(plan.behavior_patterns[0].content, "习惯先问方案再要代码");
}
|
||
}
|