2132 lines
76 KiB
Rust
2132 lines
76 KiB
Rust
use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler};
|
||
#[cfg(test)]
|
||
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
|
||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||
use crate::config::LLMProviderConfig;
|
||
use crate::protocol::WsOutbound;
|
||
use crate::scheduler::ScheduledAgentTaskOptions;
|
||
use crate::skills::SkillRuntime;
|
||
use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository};
|
||
use crate::tools::ToolRegistry;
|
||
use crate::tools::task::repository::TaskRepository;
|
||
use async_trait::async_trait;
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use tokio::sync::{Mutex, mpsc};
|
||
use uuid::Uuid;
|
||
|
||
use super::agent_factory::{AgentBuildRequest, AgentFactory};
|
||
use super::cli_session::CliSessionService;
|
||
#[cfg(test)]
|
||
use super::execution::should_display_message_to_user;
|
||
#[cfg(test)]
|
||
use super::memory_maintenance::{
|
||
MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
|
||
extract_json_object, is_recoverable_maintenance_llm_error,
|
||
strip_json_code_fence,
|
||
};
|
||
use super::memory_maintenance::{MemoryMaintenanceScopeResult, MemoryOrganizationOutput};
|
||
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
|
||
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
|
||
use super::session_history::SessionHistory;
|
||
use super::session_lifecycle::SessionLifecycleService;
|
||
use super::session_message_service::SessionMessageService;
|
||
|
||
/// Per-channel conversation state: sessions are isolated by channel, one
/// `Session` per channel.
///
/// Chat history is isolated per `chat_id` and managed centrally by the
/// session. The current topic is also tracked per `chat_id`, stored inside
/// the session's [`SessionHistory`].
pub struct Session {
    /// Identifier of this in-memory session instance (a fresh v4 UUID is
    /// assigned at construction).
    pub id: Uuid,
    /// Name of the channel this session serves.
    pub channel_name: String,
    /// Outbound sender used by `Session::send` to push frames to the user.
    pub user_tx: mpsc::Sender<WsOutbound>,
    /// Default provider configuration; `create_agent` passes a clone of it to
    /// the agent factory.
    provider_config: LLMProviderConfig,
    /// Skill runtime; consulted by `record_skill_offer`.
    skills: Arc<SkillRuntime>,
    /// Builds a fresh `AgentLoop` for each `create_agent*` call.
    agent_factory: AgentFactory,
    /// Context compressor derived from `provider_config` at construction.
    compressor: ContextCompressor,
    /// Per-chat in-memory history, topic tracking and persistence helpers.
    history: SessionHistory,
    /// Backing store used for direct topic/message persistence calls.
    store: Arc<SessionStore>,
    /// Pending cancellation-signal receivers, keyed by chat_id.
    ///
    /// Injected externally before an agent runs; consumed (removed) when the
    /// agent for that chat is built.
    pending_cancel_tokens: HashMap<String, tokio::sync::watch::Receiver<()>>,
}
/// [`EmittedMessageHandler`] that republishes messages emitted during an
/// agent run (e.g. tool calls and tool results) onto the [`MessageBus`] as
/// outbound messages for one fixed channel/chat pair.
pub struct BusToolCallEmitter {
    /// Bus that receives the converted outbound messages.
    bus: Arc<MessageBus>,
    /// Channel every outbound message is addressed to.
    channel_name: String,
    /// Chat every outbound message is addressed to.
    chat_id: String,
    /// Metadata attached to every outbound message published by this emitter.
    metadata: HashMap<String, String>,
}
impl BusToolCallEmitter {
|
||
pub fn new(
|
||
bus: Arc<MessageBus>,
|
||
channel_name: impl Into<String>,
|
||
chat_id: impl Into<String>,
|
||
metadata: HashMap<String, String>,
|
||
) -> Self {
|
||
Self {
|
||
bus,
|
||
channel_name: channel_name.into(),
|
||
chat_id: chat_id.into(),
|
||
metadata,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl EmittedMessageHandler for BusToolCallEmitter {
|
||
async fn handle(&self, message: ChatMessage) {
|
||
for outbound in OutboundMessage::from_chat_message(
|
||
&self.channel_name,
|
||
&self.chat_id,
|
||
None, // session_id
|
||
None,
|
||
&self.metadata,
|
||
&message,
|
||
) {
|
||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn handle_tool_result(&self, message: ChatMessage, duration_ms: Option<u64>) {
|
||
let mut metadata = self.metadata.clone();
|
||
if let Some(ms) = duration_ms {
|
||
metadata.insert("tool_duration_ms".to_string(), ms.to_string());
|
||
}
|
||
for outbound in OutboundMessage::from_chat_message(
|
||
&self.channel_name,
|
||
&self.chat_id,
|
||
None, // session_id
|
||
None,
|
||
&metadata,
|
||
&message,
|
||
) {
|
||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Session {
|
||
pub async fn new(
|
||
channel_name: String,
|
||
provider_config: LLMProviderConfig,
|
||
user_tx: mpsc::Sender<WsOutbound>,
|
||
tools: Arc<ToolRegistry>,
|
||
skills: Arc<SkillRuntime>,
|
||
store: Arc<SessionStore>,
|
||
agent_prompt_reinject_every: u64,
|
||
) -> Result<Self, AgentError> {
|
||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
||
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
|
||
let agent_factory = AgentFactory::new(
|
||
tools,
|
||
skills.clone(),
|
||
agent_prompt_reinject_every as usize,
|
||
prompt_repository.clone(),
|
||
);
|
||
Self::with_factories(
|
||
channel_name,
|
||
provider_config,
|
||
user_tx,
|
||
skills,
|
||
agent_factory,
|
||
conversations,
|
||
skill_events,
|
||
store,
|
||
)
|
||
.await
|
||
}
|
||
|
||
pub(crate) async fn with_factories(
|
||
channel_name: String,
|
||
provider_config: LLMProviderConfig,
|
||
user_tx: mpsc::Sender<WsOutbound>,
|
||
skills: Arc<SkillRuntime>,
|
||
agent_factory: AgentFactory,
|
||
conversations: Arc<dyn ConversationRepository>,
|
||
skill_events: Arc<dyn SkillEventRepository>,
|
||
store: Arc<SessionStore>,
|
||
) -> Result<Self, AgentError> {
|
||
Ok(Self {
|
||
id: Uuid::new_v4(),
|
||
channel_name: channel_name.clone(),
|
||
user_tx,
|
||
provider_config: provider_config.clone(),
|
||
skills,
|
||
agent_factory,
|
||
compressor: ContextCompressor::from_provider_config(&provider_config),
|
||
history: SessionHistory::new(
|
||
channel_name,
|
||
conversations,
|
||
skill_events,
|
||
),
|
||
store,
|
||
pending_cancel_tokens: HashMap::new(),
|
||
})
|
||
}
|
||
|
||
pub fn persistent_session_id(&self, chat_id: &str) -> String {
|
||
self.history.persistent_session_id(chat_id)
|
||
}
|
||
|
||
/// 设置当前话题 ID(指定 chat)
|
||
pub fn set_current_topic(&mut self, chat_id: &str, topic_id: Option<String>) {
|
||
if let Some(topic_id) = topic_id {
|
||
self.history.set_chat_topic(chat_id, topic_id);
|
||
} else {
|
||
self.history.clear_chat_topic(chat_id);
|
||
}
|
||
}
|
||
|
||
/// 存入待使用的取消信号接收端。
|
||
///
|
||
/// 在 Agent 执行前由处理器调用,Agent 构建时(create_agent)自动消费。
|
||
/// 每个 chat_id 同时只允许一个 pending token;新 token 会替换旧 token。
|
||
pub fn set_cancel_receiver(
|
||
&mut self,
|
||
chat_id: &str,
|
||
receiver: tokio::sync::watch::Receiver<()>,
|
||
) {
|
||
self.pending_cancel_tokens
|
||
.insert(chat_id.to_string(), receiver);
|
||
}
|
||
|
||
/// 获取当前话题 ID(指定 chat)
|
||
pub fn current_topic(&self, chat_id: &str) -> Option<&str> {
|
||
self.history.chat_topic(chat_id)
|
||
}
|
||
|
||
/// 获取历史所对应的话题 ID(指定 chat)
|
||
pub fn history_topic(&self, chat_id: &str) -> Option<&str> {
|
||
self.history.history_topic(chat_id)
|
||
}
|
||
|
||
/// 切换话题 - 清除当前历史并加载新话题的历史
|
||
pub fn switch_topic(&mut self, chat_id: &str, topic_id: &str) -> Result<(), AgentError> {
|
||
// 清除当前历史
|
||
self.history.remove_history(chat_id);
|
||
|
||
// 先设置当前话题(set_history 需要这个)
|
||
self.history.set_chat_topic(chat_id, topic_id.to_string());
|
||
|
||
// 加载新话题的历史
|
||
let messages = self
|
||
.store
|
||
.load_messages_for_topic(topic_id)
|
||
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||
|
||
self.history.set_history(chat_id, messages);
|
||
|
||
tracing::info!(
|
||
topic_id = %topic_id,
|
||
chat_id = %chat_id,
|
||
"Switched to topic"
|
||
);
|
||
Ok(())
|
||
}
|
||
|
||
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
|
||
self.history.ensure_persistent_session(chat_id)
|
||
}
|
||
|
||
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
// 检查历史是否存在且对应正确的话题
|
||
// 先获取 topic 信息并转换为 owned String,避免借用冲突
|
||
let current_topic: Option<String> = self.history.chat_topic(chat_id).map(|s| s.to_string());
|
||
let stored_topic = self.history.history_topic(chat_id);
|
||
|
||
if self.chat_history_exists(chat_id) {
|
||
// 如果历史已存在,但话题不匹配,需要重新加载
|
||
if current_topic.as_deref() != stored_topic {
|
||
tracing::info!(
|
||
chat_id = %chat_id,
|
||
current_topic = ?current_topic,
|
||
stored_topic = ?stored_topic,
|
||
"Topic changed, reloading history"
|
||
);
|
||
self.reload_chat_history(chat_id)?;
|
||
}
|
||
return Ok(());
|
||
}
|
||
|
||
// 历史不存在,按 topic 加载(如果设置了 topic)
|
||
self.history.ensure_chat_loaded(chat_id, current_topic.as_deref())
|
||
}
|
||
|
||
fn chat_history_exists(&self, chat_id: &str) -> bool {
|
||
self.history.get_history(chat_id).is_some()
|
||
}
|
||
|
||
pub fn ensure_agent_prompt_before_user_message(
|
||
&mut self,
|
||
chat_id: &str,
|
||
) -> Result<(), AgentError> {
|
||
self.history
|
||
.ensure_agent_prompt_before_user_message(chat_id)
|
||
}
|
||
|
||
/// 获取或创建指定 chat_id 的会话历史
|
||
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
|
||
self.history.get_or_create_history(chat_id)
|
||
}
|
||
|
||
/// 获取指定 chat_id 的会话历史(不创建)
|
||
pub fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
|
||
self.history.get_history(chat_id)
|
||
}
|
||
|
||
/// 使用完整消息追加到历史
|
||
pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||
self.history.add_message(chat_id, message);
|
||
}
|
||
|
||
pub fn remove_history(&mut self, chat_id: &str) {
|
||
self.history.remove_history(chat_id);
|
||
}
|
||
|
||
pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
self.history.clear_chat_history(chat_id)
|
||
}
|
||
|
||
/// 将消息写入内存与持久化层(使用当前 topic)
|
||
pub fn append_persisted_message(
|
||
&mut self,
|
||
chat_id: &str,
|
||
message: ChatMessage,
|
||
) -> Result<(), AgentError> {
|
||
let session_id = self.persistent_session_id(chat_id);
|
||
let topic_id = self.history.chat_topic(chat_id).map(|s| s.to_string());
|
||
self.store
|
||
.append_message_with_topic(&session_id, topic_id.as_deref(), &message)
|
||
.map_err(|err| {
|
||
AgentError::Other(format!("append message persistence error: {}", err))
|
||
})?;
|
||
self.add_message(chat_id, message);
|
||
|
||
// 更新 topic 的最后活跃时间
|
||
if let Some(ref topic_id) = topic_id {
|
||
if let Err(e) = self.store.touch_topic(topic_id) {
|
||
tracing::warn!(error = %e, topic_id = %topic_id, "Failed to touch topic");
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub fn append_persisted_messages<I>(
|
||
&mut self,
|
||
chat_id: &str,
|
||
messages: I,
|
||
) -> Result<(), AgentError>
|
||
where
|
||
I: IntoIterator<Item = ChatMessage>,
|
||
{
|
||
self.history.append_persisted_messages(chat_id, messages)
|
||
}
|
||
|
||
/// 将消息保存到指定话题(直接写入数据库,不更新内存历史)
|
||
pub fn append_messages_to_topic(
|
||
&self,
|
||
chat_id: &str,
|
||
topic_id: &str,
|
||
messages: &[ChatMessage],
|
||
) -> Result<(), AgentError> {
|
||
self.history.append_to_topic(chat_id, topic_id, messages)
|
||
}
|
||
|
||
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
||
if media_refs.is_empty() {
|
||
ChatMessage::user(content)
|
||
} else {
|
||
ChatMessage::user_with_media(content, media_refs)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
|
||
self.latest_user_message(chat_id)
|
||
.map(|message| message.id.as_str())
|
||
}
|
||
|
||
#[cfg(test)]
|
||
fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
|
||
self.history.latest_user_message(chat_id)
|
||
}
|
||
|
||
#[cfg(test)]
|
||
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
|
||
self.latest_user_message_id(chat_id)
|
||
.map(|current_id| current_id == message_id)
|
||
.unwrap_or(false)
|
||
}
|
||
|
||
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
|
||
self.history.matches_current_user_turn(chat_id, message)
|
||
}
|
||
|
||
pub(crate) fn stale_result_diagnostics(
|
||
&self,
|
||
chat_id: &str,
|
||
) -> (Option<&str>, Option<String>, bool, usize) {
|
||
self.history.stale_result_diagnostics(chat_id)
|
||
}
|
||
|
||
/// 清除所有历史
|
||
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
||
self.history.clear_all_history()
|
||
}
|
||
|
||
pub async fn send(&self, msg: WsOutbound) {
|
||
let _ = self.user_tx.send(msg).await;
|
||
}
|
||
|
||
/// 获取 provider_config 引用
|
||
pub fn provider_config(&self) -> &LLMProviderConfig {
|
||
&self.provider_config
|
||
}
|
||
|
||
/// 获取 compressor 引用
|
||
pub fn compressor(&self) -> &ContextCompressor {
|
||
&self.compressor
|
||
}
|
||
|
||
pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
||
self.history.try_start_background_compaction(chat_id)
|
||
}
|
||
|
||
pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
|
||
self.history.finish_background_compaction(chat_id);
|
||
}
|
||
|
||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
// 如果当前有 topic,加载该 topic 的消息
|
||
if let Some(topic_id) = self.history.chat_topic(chat_id) {
|
||
let messages = self
|
||
.store
|
||
.load_messages_for_topic(topic_id)
|
||
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||
self.history.set_history(chat_id, messages);
|
||
} else {
|
||
// 否则加载 session 的所有消息
|
||
self.history.reload_chat_history(chat_id)?;
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
pub(crate) fn store(&self) -> Arc<dyn ConversationRepository> {
|
||
self.history.conversations()
|
||
}
|
||
|
||
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
|
||
if self.skills.is_empty() {
|
||
return Ok(());
|
||
}
|
||
|
||
self.history.append_skill_event(
|
||
chat_id,
|
||
"offered",
|
||
None,
|
||
&self.skills.offered_event_payload(),
|
||
)
|
||
}
|
||
|
||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||
pub fn create_agent(
|
||
&mut self,
|
||
chat_id: &str,
|
||
sender_id: Option<&str>,
|
||
message_id: Option<&str>,
|
||
) -> Result<AgentLoop, AgentError> {
|
||
self.create_agent_with_provider_config(
|
||
chat_id,
|
||
None, // notification_chat_id = None,使用 session_chat_id
|
||
sender_id,
|
||
message_id,
|
||
self.provider_config.clone(),
|
||
)
|
||
}
|
||
|
||
pub fn create_agent_with_provider_config(
|
||
&mut self,
|
||
session_chat_id: &str,
|
||
notification_chat_id: Option<&str>,
|
||
sender_id: Option<&str>,
|
||
message_id: Option<&str>,
|
||
provider_config: LLMProviderConfig,
|
||
) -> Result<AgentLoop, AgentError> {
|
||
// 消费 pending 的取消信号接收端(如果存在)
|
||
let cancel_token = self.pending_cancel_tokens.remove(session_chat_id);
|
||
self.agent_factory.create(AgentBuildRequest {
|
||
channel_name: &self.channel_name,
|
||
session_chat_id,
|
||
notification_chat_id,
|
||
sender_id,
|
||
message_id,
|
||
provider_config,
|
||
cancel_token,
|
||
})
|
||
}
|
||
}
|
||
|
||
/// SessionManager manages all [`Session`]s and routes requests to them by
/// `channel_name`.
///
/// Most operations delegate to dedicated services: session lifecycle, CLI
/// sessions, message handling, scheduled tasks and memory maintenance.
#[derive(Clone)]
pub struct SessionManager {
    /// Tool registry handed out via `tools()`.
    tools: Arc<ToolRegistry>,
    /// Skill runtime handed out via `skills()`.
    skills: Arc<SkillRuntime>,
    /// Persistent store; also used directly for topic lookup/creation.
    store: Arc<SessionStore>,
    /// Flag exposed via `show_tool_results()`; presumably controls whether
    /// tool results are shown to users — confirm at call sites.
    show_tool_results: bool,
    /// Session creation, lookup, expiry and activity tracking.
    lifecycle: SessionLifecycleService,
    /// CLI session service handed out via `cli_sessions()`.
    cli_sessions: CliSessionService,
    /// Handles incoming user messages (`handle_message`).
    messages: SessionMessageService,
    /// Runs scheduled and silent agent tasks.
    scheduled_tasks: ScheduledAgentTaskService,
    /// Memory maintenance / organization runs.
    memory_maintenance: MemoryMaintenanceCoordinator,
    /// Task repository handed out via `task_repository()`.
    task_repository: Arc<dyn TaskRepository>,
}
/// Pre-built dependencies that `SessionManager::from_services` moves into a
/// new [`SessionManager`].
///
/// Each field maps one-to-one onto the `SessionManager` field of the same name.
pub(crate) struct SessionManagerServices {
    pub(crate) tools: Arc<ToolRegistry>,
    pub(crate) skills: Arc<SkillRuntime>,
    pub(crate) store: Arc<SessionStore>,
    pub(crate) show_tool_results: bool,
    pub(crate) lifecycle: SessionLifecycleService,
    pub(crate) cli_sessions: CliSessionService,
    pub(crate) messages: SessionMessageService,
    pub(crate) scheduled_tasks: ScheduledAgentTaskService,
    pub(crate) memory_maintenance: MemoryMaintenanceCoordinator,
    pub(crate) task_repository: Arc<dyn TaskRepository>,
}
impl SessionManager {
|
||
pub(crate) fn from_services(services: SessionManagerServices) -> Self {
|
||
Self {
|
||
tools: services.tools,
|
||
skills: services.skills,
|
||
store: services.store,
|
||
show_tool_results: services.show_tool_results,
|
||
lifecycle: services.lifecycle,
|
||
cli_sessions: services.cli_sessions,
|
||
messages: services.messages,
|
||
scheduled_tasks: services.scheduled_tasks,
|
||
memory_maintenance: services.memory_maintenance,
|
||
task_repository: services.task_repository,
|
||
}
|
||
}
|
||
|
||
pub fn new(
|
||
agent_prompt_reinject_every: u64,
|
||
show_tool_results: bool,
|
||
default_timezone: String,
|
||
provider_config: LLMProviderConfig,
|
||
provider_configs: HashMap<String, LLMProviderConfig>,
|
||
skills: Arc<SkillRuntime>,
|
||
disabled_tools: std::collections::HashSet<String>,
|
||
task_config: crate::config::TaskConfig,
|
||
subagents_config: crate::config::SubagentsConfig,
|
||
maintenance_config: crate::config::MemoryMaintenanceConfig,
|
||
session_ttl_hours: Option<u64>,
|
||
mcp_config: crate::mcp::McpConfig,
|
||
) -> Result<Self, AgentError> {
|
||
super::runtime::build_session_manager(
|
||
agent_prompt_reinject_every,
|
||
show_tool_results,
|
||
default_timezone,
|
||
provider_config,
|
||
provider_configs,
|
||
skills,
|
||
disabled_tools,
|
||
task_config,
|
||
subagents_config,
|
||
maintenance_config,
|
||
session_ttl_hours,
|
||
mcp_config,
|
||
None,
|
||
)
|
||
.map(|(session_manager, _)| session_manager)
|
||
}
|
||
|
||
pub fn tools(&self) -> Arc<ToolRegistry> {
|
||
self.tools.clone()
|
||
}
|
||
|
||
pub fn store(&self) -> Arc<SessionStore> {
|
||
self.store.clone()
|
||
}
|
||
|
||
pub fn show_tool_results(&self) -> bool {
|
||
self.show_tool_results
|
||
}
|
||
|
||
pub fn task_repository(&self) -> Arc<dyn TaskRepository> {
|
||
self.task_repository.clone()
|
||
}
|
||
|
||
pub fn skills(&self) -> Arc<SkillRuntime> {
|
||
self.skills.clone()
|
||
}
|
||
|
||
pub(crate) fn cli_sessions(&self) -> CliSessionService {
|
||
self.cli_sessions.clone()
|
||
}
|
||
|
||
#[cfg_attr(not(test), allow(dead_code))]
|
||
pub(crate) async fn organize_memory_maintenance_for_scope(
|
||
&self,
|
||
scope_key: &str,
|
||
) -> Result<Option<MemoryOrganizationOutput>, AgentError> {
|
||
self.memory_maintenance.organize_for_scope(scope_key).await
|
||
}
|
||
|
||
pub(crate) async fn run_memory_maintenance_for_all_scopes(
|
||
&self,
|
||
) -> Result<Option<MemoryMaintenanceScopeResult>, AgentError> {
|
||
self.memory_maintenance.run_for_all_scopes().await
|
||
}
|
||
|
||
/// 确保 session 存在且未超时,超时则重建
|
||
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||
self.lifecycle.ensure_session(channel_name).await
|
||
}
|
||
|
||
/// 获取 session(不检查超时)
|
||
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
|
||
self.lifecycle.get(channel_name).await
|
||
}
|
||
|
||
/// 获取指定 chat 的当前话题(确保 session 存在,自动从数据库恢复)
|
||
pub async fn get_current_topic(&self, channel_name: &str, chat_id: &str) -> Result<Option<String>, AgentError> {
|
||
self.ensure_session(channel_name).await?;
|
||
if let Some(session) = self.get(channel_name).await {
|
||
let mut guard = session.lock().await;
|
||
|
||
// 如果内存中没有当前话题,从数据库恢复最近活跃的话题
|
||
if guard.current_topic(chat_id).is_none() {
|
||
let session_id = guard.persistent_session_id(chat_id);
|
||
let topics = self.store.list_topics(&session_id)
|
||
.map_err(|e| AgentError::Other(format!("Failed to list topics: {}", e)))?;
|
||
|
||
if let Some(latest_topic) = topics.first() {
|
||
// 设置最近活跃的话题为当前话题
|
||
guard.set_current_topic(chat_id, Some(latest_topic.id.clone()));
|
||
tracing::info!(
|
||
chat_id = %chat_id,
|
||
topic_id = %latest_topic.id,
|
||
topic_title = %latest_topic.title,
|
||
"Restored current topic from database"
|
||
);
|
||
} else {
|
||
// 数据库中也没有话题,自动创建默认话题
|
||
let title = format!(
|
||
"话题 {}",
|
||
chrono::Local::now().format("%m/%d %H:%M")
|
||
);
|
||
match self.store.create_topic(&session_id, &title, None) {
|
||
Ok(topic) => {
|
||
guard.set_current_topic(chat_id, Some(topic.id.clone()));
|
||
tracing::info!(
|
||
chat_id = %chat_id,
|
||
topic_id = %topic.id,
|
||
topic_title = %topic.title,
|
||
session_id = %session_id,
|
||
"Auto-created default topic for new chat"
|
||
);
|
||
}
|
||
Err(e) => {
|
||
tracing::error!(
|
||
error = %e,
|
||
session_id = %session_id,
|
||
"Failed to auto-create default topic"
|
||
);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(guard.current_topic(chat_id).map(|s| s.to_string()))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 存入 Agent 取消信号接收端,供 Agent 构建时消费。
|
||
///
|
||
/// 在 Agent 执行前由处理器调用。Agent 在 create_agent() 时自动取出。
|
||
pub async fn set_agent_cancel_token(
|
||
&self,
|
||
channel_name: &str,
|
||
chat_id: &str,
|
||
token: tokio::sync::watch::Receiver<()>,
|
||
) {
|
||
if let Some(session) = self.get(channel_name).await {
|
||
session.lock().await.set_cancel_receiver(chat_id, token);
|
||
}
|
||
}
|
||
|
||
/// 更新最后活跃时间
|
||
pub async fn touch(&self, channel_name: &str) {
|
||
self.lifecycle.touch(channel_name).await;
|
||
}
|
||
|
||
pub async fn cleanup_expired_sessions(&self) -> usize {
|
||
self.lifecycle.cleanup_expired_sessions().await
|
||
}
|
||
|
||
/// 处理消息:路由到对应 session 的 agent
|
||
pub async fn handle_message(
|
||
&self,
|
||
channel_name: &str,
|
||
sender_id: &str,
|
||
chat_id: &str,
|
||
content: &str,
|
||
media: Vec<crate::bus::MediaItem>,
|
||
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
|
||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||
self.messages
|
||
.handle_message(
|
||
channel_name,
|
||
sender_id,
|
||
chat_id,
|
||
content,
|
||
media,
|
||
live_emitter,
|
||
)
|
||
.await
|
||
}
|
||
|
||
pub async fn run_scheduled_agent_task(
|
||
&self,
|
||
channel_name: &str,
|
||
chat_id: &str,
|
||
prompt: &str,
|
||
options: ScheduledAgentTaskOptions,
|
||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||
self.scheduled_tasks
|
||
.run(channel_name, chat_id, None, prompt, options)
|
||
.await
|
||
}
|
||
|
||
/// 执行 SilentAgentTask,支持 notification_chat_id 分离
|
||
pub async fn run_silent_agent_task(
|
||
&self,
|
||
channel_name: &str,
|
||
session_chat_id: &str,
|
||
notification_chat_id: Option<&str>,
|
||
prompt: &str,
|
||
options: ScheduledAgentTaskOptions,
|
||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||
self.scheduled_tasks
|
||
.run(channel_name, session_chat_id, notification_chat_id, prompt, options)
|
||
.await
|
||
}
|
||
|
||
/// 清除指定 session 的所有历史
|
||
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
|
||
if let Some(session) = self.get(channel_name).await {
|
||
let mut session_guard = session.lock().await;
|
||
session_guard.clear_all_history()?;
|
||
}
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
use crate::bus::MessageBus;
|
||
use crate::bus::message::OutboundEventKind;
|
||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||
use crate::storage::MemoryRecord;
|
||
use crate::tools::NoopSessionMessageSender;
|
||
use axum::http::StatusCode;
|
||
use axum::{Json, Router, routing::post};
|
||
use serde_json::{Value, json};
|
||
use std::collections::{HashMap, HashSet};
|
||
use std::sync::{
|
||
Arc as StdArc,
|
||
atomic::{AtomicUsize, Ordering},
|
||
};
|
||
use tokio::net::TcpListener;
|
||
use tokio::sync::mpsc;
|
||
|
||
fn test_provider_config() -> LLMProviderConfig {
|
||
LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "test".to_string(),
|
||
base_url: "http://localhost".to_string(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
llm_timeout_secs: 120,
|
||
memory_maintenance_timeout_secs: 600,
|
||
model_id: "test-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_latest_user_message_guard_tracks_current_turn() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
HashSet::new(),
|
||
Default::default(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"test-channel".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store,
|
||
100,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
let first = session.create_user_message("first", Vec::new());
|
||
let first_id = first.id.clone();
|
||
session.append_persisted_message("chat-1", first).unwrap();
|
||
assert!(session.is_latest_user_message("chat-1", &first_id));
|
||
|
||
let second = session.create_user_message("second", Vec::new());
|
||
let second_id = second.id.clone();
|
||
session.append_persisted_message("chat-1", second).unwrap();
|
||
|
||
assert!(!session.is_latest_user_message("chat-1", &first_id));
|
||
assert!(session.is_latest_user_message("chat-1", &second_id));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_current_user_turn_match_survives_history_compaction_reload() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
HashSet::new(),
|
||
Default::default(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"test-channel".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store.clone(),
|
||
100,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
let first = session.create_user_message("first", Vec::new());
|
||
let first_id = first.id.clone();
|
||
session.append_persisted_message("chat-1", first).unwrap();
|
||
session
|
||
.append_persisted_message("chat-1", ChatMessage::assistant("answer-1"))
|
||
.unwrap();
|
||
|
||
let second = session.create_user_message("second", Vec::new());
|
||
session
|
||
.append_persisted_message("chat-1", second.clone())
|
||
.unwrap();
|
||
session
|
||
.append_persisted_message("chat-1", ChatMessage::assistant("answer-2"))
|
||
.unwrap();
|
||
|
||
let session_id = session.persistent_session_id("chat-1");
|
||
let snapshot_end_seq = store
|
||
.get_session(&session_id)
|
||
.unwrap()
|
||
.unwrap()
|
||
.message_count;
|
||
let preserved_messages = session.get_history("chat-1").unwrap().clone();
|
||
|
||
store
|
||
.compact_active_history(
|
||
&session_id,
|
||
snapshot_end_seq,
|
||
&[],
|
||
&ChatMessage::system("[Compressed History]\n\nsummary"),
|
||
&preserved_messages,
|
||
)
|
||
.unwrap();
|
||
|
||
session.reload_chat_history("chat-1").unwrap();
|
||
|
||
assert!(!session.is_latest_user_message("chat-1", &first_id));
|
||
assert!(!session.is_latest_user_message("chat-1", &second.id));
|
||
assert!(session.matches_current_user_turn("chat-1", &second));
|
||
}
|
||
|
||
async fn start_mock_openai_server() -> String {
|
||
start_mock_openai_server_with_content(None).await
|
||
}
|
||
|
||
async fn start_mock_openai_server_with_content(
|
||
mock_response_content: Option<String>,
|
||
) -> String {
|
||
async fn handle(
|
||
axum::extract::State(mock_response_content): axum::extract::State<Option<String>>,
|
||
Json(body): Json<Value>,
|
||
) -> Json<Value> {
|
||
let model = body
|
||
.get("model")
|
||
.and_then(|value| value.as_str())
|
||
.unwrap_or("unknown-model");
|
||
let content = mock_response_content.unwrap_or_else(|| format!("reply from {}", model));
|
||
|
||
Json(json!({
|
||
"id": "mock-response",
|
||
"model": model,
|
||
"choices": [
|
||
{
|
||
"message": {
|
||
"content": content,
|
||
"tool_calls": []
|
||
}
|
||
}
|
||
],
|
||
"usage": {
|
||
"prompt_tokens": 1,
|
||
"completion_tokens": 1,
|
||
"total_tokens": 2
|
||
}
|
||
}))
|
||
}
|
||
|
||
let app = Router::new()
|
||
.route("/chat/completions", post(handle))
|
||
.with_state(mock_response_content);
|
||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
let address = listener.local_addr().unwrap();
|
||
tokio::spawn(async move {
|
||
axum::serve(listener, app).await.unwrap();
|
||
});
|
||
format!("http://{}", address)
|
||
}
|
||
|
||
async fn start_mock_openai_504_server() -> String {
|
||
async fn handle() -> (StatusCode, &'static str) {
|
||
(StatusCode::GATEWAY_TIMEOUT, "stream timeout")
|
||
}
|
||
|
||
let app = Router::new().route("/chat/completions", post(handle));
|
||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
let address = listener.local_addr().unwrap();
|
||
tokio::spawn(async move {
|
||
axum::serve(listener, app).await.unwrap();
|
||
});
|
||
format!("http://{}", address)
|
||
}
|
||
|
||
async fn start_mock_openai_flaky_server(mock_response_content: String) -> String {
|
||
let attempts = StdArc::new(AtomicUsize::new(0));
|
||
let state = (attempts, mock_response_content);
|
||
|
||
async fn handle(
|
||
axum::extract::State((attempts, mock_response_content)): axum::extract::State<(
|
||
StdArc<AtomicUsize>,
|
||
String,
|
||
)>,
|
||
Json(body): Json<Value>,
|
||
) -> (StatusCode, Json<Value>) {
|
||
let attempt = attempts.fetch_add(1, Ordering::SeqCst);
|
||
if attempt == 0 {
|
||
return (
|
||
StatusCode::GATEWAY_TIMEOUT,
|
||
Json(json!({"error": "stream timeout"})),
|
||
);
|
||
}
|
||
|
||
let model = body
|
||
.get("model")
|
||
.and_then(|value| value.as_str())
|
||
.unwrap_or("unknown-model");
|
||
(
|
||
StatusCode::OK,
|
||
Json(json!({
|
||
"id": "mock-response",
|
||
"model": model,
|
||
"choices": [
|
||
{
|
||
"message": {
|
||
"content": mock_response_content,
|
||
"tool_calls": []
|
||
}
|
||
}
|
||
],
|
||
"usage": {
|
||
"prompt_tokens": 1,
|
||
"completion_tokens": 1,
|
||
"total_tokens": 2
|
||
}
|
||
})),
|
||
)
|
||
}
|
||
|
||
let app = Router::new()
|
||
.route("/chat/completions", post(handle))
|
||
.with_state(state);
|
||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
let address = listener.local_addr().unwrap();
|
||
tokio::spawn(async move {
|
||
axum::serve(listener, app).await.unwrap();
|
||
});
|
||
format!("http://{}", address)
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_handle_message_returns_recoverable_reply_on_llm_504() {
|
||
let base_url = start_mock_openai_504_server().await;
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "timeout-provider".to_string(),
|
||
base_url: base_url.clone(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "timeout-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
let outbound = session_manager
|
||
.handle_message("test-channel", "user-1", "chat-1", "hello", Vec::new(), None)
|
||
.await
|
||
.unwrap();
|
||
|
||
assert_eq!(outbound.len(), 1);
|
||
assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() {
|
||
let base_url = start_mock_openai_server().await;
|
||
let default_provider = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "default-provider".to_string(),
|
||
base_url: base_url.clone(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "default-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
let planner_provider = LLMProviderConfig {
|
||
model_id: "planner-model".to_string(),
|
||
name: "planner-provider".to_string(),
|
||
..default_provider.clone()
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
default_provider.clone(),
|
||
HashMap::from([
|
||
("default".to_string(), default_provider),
|
||
("planner".to_string(), planner_provider),
|
||
]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
let planner_outbound = session_manager
|
||
.run_scheduled_agent_task(
|
||
"test-channel",
|
||
"chat-planner",
|
||
"请规划今天工作",
|
||
ScheduledAgentTaskOptions {
|
||
agent: Some("planner".to_string()),
|
||
..Default::default()
|
||
},
|
||
)
|
||
.await
|
||
.unwrap();
|
||
assert_eq!(planner_outbound.len(), 1);
|
||
assert!(planner_outbound[0].content.contains("planner-model"));
|
||
|
||
let default_outbound = session_manager
|
||
.run_scheduled_agent_task(
|
||
"test-channel",
|
||
"chat-default",
|
||
"请规划今天工作",
|
||
ScheduledAgentTaskOptions {
|
||
agent: Some("default".to_string()),
|
||
..Default::default()
|
||
},
|
||
)
|
||
.await
|
||
.unwrap();
|
||
assert_eq!(default_outbound.len(), 1);
|
||
assert!(default_outbound[0].content.contains("default-model"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_run_scheduled_agent_task_persists_execution_guard_prompt() {
|
||
let base_url = start_mock_openai_server().await;
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "default-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "default-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(32),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.run_scheduled_agent_task(
|
||
"test-channel",
|
||
"chat-guard",
|
||
"每小时执行以下流程:检查邮箱并同步待办",
|
||
ScheduledAgentTaskOptions {
|
||
system_prompt: Some("你是邮箱待办同步助手。".to_string()),
|
||
..Default::default()
|
||
},
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
let session = session_manager.get("test-channel").await.unwrap();
|
||
let session_guard = session.lock().await;
|
||
let persisted_messages = session_guard
|
||
.store()
|
||
.load_messages(&session_guard.persistent_session_id("chat-guard"))
|
||
.unwrap();
|
||
|
||
let scheduled_prompt = persisted_messages
|
||
.iter()
|
||
.find(|message| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT))
|
||
.expect("missing scheduled system prompt");
|
||
|
||
assert!(scheduled_prompt.content.contains("已经触发的定时任务执行"));
|
||
assert!(
|
||
scheduled_prompt
|
||
.content
|
||
.contains("不要调用任何定时任务管理工具")
|
||
);
|
||
assert!(scheduled_prompt.content.contains("你是邮箱待办同步助手。"));
|
||
}
|
||
|
||
/// 测试专用的 MemoryMaintenanceConfig,降低 min_memories_to_keep 以便于单条记忆测试
|
||
fn test_maintenance_config() -> crate::config::MemoryMaintenanceConfig {
|
||
crate::config::MemoryMaintenanceConfig {
|
||
min_memories_to_keep: 1,
|
||
..Default::default()
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_summarize_memory_maintenance_for_scope_uses_model_output() {
|
||
let mock_response_content = serde_json::to_string(&json!({
|
||
"user_facts": ["用户在做AI产品"],
|
||
"preferences": ["偏好简洁表达"],
|
||
"behavior_patterns": ["习惯先问方案再要代码"],
|
||
"merges": [],
|
||
"conflicts": [],
|
||
"low_value_ids": [],
|
||
"managed_markdown": "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达\n\n### 行为模式\n- 习惯先问方案再要代码"
|
||
}))
|
||
.unwrap();
|
||
let base_url =
|
||
start_mock_openai_server_with_content(Some(mock_response_content.clone())).await;
|
||
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::from([(
|
||
"mock_response_content".to_string(),
|
||
json!(mock_response_content),
|
||
)]),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "user".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let output = session_manager
|
||
.organize_memory_maintenance_for_scope("feishu:user-1")
|
||
.await
|
||
.unwrap()
|
||
.unwrap();
|
||
|
||
assert!(output.merges.is_empty());
|
||
assert!(output.conflicts.is_empty());
|
||
assert!(output.low_value_ids.is_empty());
|
||
}
|
||
|
||
/// Transport failures and gateway timeouts count as recoverable maintenance
/// errors; authentication failures do not.
#[test]
fn test_is_recoverable_maintenance_llm_error_detects_transport_failures() {
    let recoverable = [
        "error sending request for url (https://example.invalid/v1/chat/completions)",
        "API error 504 Gateway Timeout: stream timeout",
    ];
    for message in recoverable {
        assert!(is_recoverable_maintenance_llm_error(message));
    }

    let unauthorized = "API error 401 Unauthorized";
    assert!(!is_recoverable_maintenance_llm_error(unauthorized));
}
|
||
|
||
/// Prose before and after a fenced ```json block must be discarded. Only the
/// JSON object itself survives fence stripping plus object extraction.
#[test]
fn test_extract_json_object_skips_wrapping_text() {
    let model_reply = "下面是结果:\n```json\n{\n \"user_facts\": [],\n \"preferences\": []\n}\n```\n请查收";
    let expected_object = "{\n \"user_facts\": [],\n \"preferences\": []\n}";

    let unfenced = strip_json_code_fence(model_reply);
    let object = extract_json_object(unfenced).unwrap();

    assert_eq!(object, expected_object);
}
|
||
|
||
#[tokio::test]
|
||
async fn test_summarize_memory_maintenance_transport_error_includes_provider_context() {
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url: "https://example.invalid/v1".to_string(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 1,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "user".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let error = session_manager
|
||
.organize_memory_maintenance_for_scope("feishu:user-1")
|
||
.await
|
||
.unwrap_err()
|
||
.to_string();
|
||
|
||
assert!(
|
||
error.contains("memory organization model error: transport error:")
|
||
|| error.contains("memory summary model error: transport error:"),
|
||
"Error did not contain expected message: {}",
|
||
error
|
||
);
|
||
assert!(error.contains("provider=maintenance-provider"));
|
||
assert!(error.contains("model=maintenance-model"));
|
||
assert!(error.contains("url=https://example.invalid/v1/chat/completions"));
|
||
assert!(error.contains("timeout_secs=600"));
|
||
assert!(error.contains("error sending request for url"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_summarize_memory_maintenance_retries_recoverable_provider_errors() {
|
||
let mock_response_content = serde_json::to_string(&json!({
|
||
"user_facts": ["用户在做AI产品"],
|
||
"preferences": [],
|
||
"behavior_patterns": [],
|
||
"merges": [],
|
||
"conflicts": [],
|
||
"low_value_ids": [],
|
||
"managed_markdown": "### 用户事实\n- 用户在做AI产品"
|
||
}))
|
||
.unwrap();
|
||
let base_url = start_mock_openai_flaky_server(mock_response_content.clone()).await;
|
||
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::from([(
|
||
"mock_response_content".to_string(),
|
||
json!(mock_response_content),
|
||
)]),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "user".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let output = session_manager
|
||
.organize_memory_maintenance_for_scope("feishu:user-1")
|
||
.await
|
||
.unwrap()
|
||
.unwrap();
|
||
|
||
assert!(output.merges.is_empty());
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_summarize_memory_maintenance_for_scope_extracts_wrapped_json_object() {
|
||
let mock_response_content = "结果如下:\n```json\n{\n \"user_facts\": [\"用户在做AI产品\"],\n \"preferences\": [],\n \"behavior_patterns\": [],\n \"merges\": [],\n \"conflicts\": [],\n \"low_value_ids\": [],\n \"managed_markdown\": \"### 用户事实\\n- 用户在做AI产品\"\n}\n```\n";
|
||
let base_url =
|
||
start_mock_openai_server_with_content(Some(mock_response_content.to_string())).await;
|
||
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::from([(
|
||
"mock_response_content".to_string(),
|
||
json!(mock_response_content),
|
||
)]),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "user".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let output = session_manager
|
||
.organize_memory_maintenance_for_scope("feishu:user-1")
|
||
.await
|
||
.unwrap()
|
||
.unwrap();
|
||
|
||
assert!(output.merges.is_empty());
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_run_memory_maintenance_for_all_scopes_scans_all_scopes_even_without_recent_updates() {
|
||
let mock_response_content = serde_json::to_string(&json!({
|
||
"user_facts": ["用户在做AI产品"],
|
||
"preferences": [],
|
||
"behavior_patterns": [],
|
||
"merges": [],
|
||
"conflicts": [],
|
||
"low_value_ids": [],
|
||
"managed_markdown": "### 用户事实\n- 用户在做AI产品"
|
||
}))
|
||
.unwrap();
|
||
let base_url =
|
||
start_mock_openai_server_with_content(Some(mock_response_content.clone())).await;
|
||
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url,
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::from([(
|
||
"mock_response_content".to_string(),
|
||
json!(mock_response_content),
|
||
)]),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 30,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: "feishu:user-1".to_string(),
|
||
namespace: "user".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: "用户在做AI产品".to_string(),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
|
||
let result = session_manager
|
||
.run_memory_maintenance_for_all_scopes()
|
||
.await
|
||
.unwrap();
|
||
|
||
assert!(result.is_some());
|
||
let result = result.unwrap();
|
||
assert_eq!(result.scope_key, "all");
|
||
// 由于步骤2需要新的提示词和输入格式,这里只验证基本功能
|
||
assert!(!result.managed_markdown.is_empty());
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_run_memory_maintenance_for_all_scopes_skips_recoverable_transport_failures() {
|
||
let provider_config = LLMProviderConfig {
|
||
provider_type: "openai".to_string(),
|
||
name: "maintenance-provider".to_string(),
|
||
base_url: "https://example.invalid/v1".to_string(),
|
||
api_key: "test-key".to_string(),
|
||
extra_headers: HashMap::new(),
|
||
model_id: "maintenance-model".to_string(),
|
||
temperature: Some(0.0),
|
||
max_tokens: Some(256),
|
||
context_window_tokens: None,
|
||
model_extra: HashMap::new(),
|
||
max_tool_iterations: 1,
|
||
llm_timeout_secs: 1,
|
||
memory_maintenance_timeout_secs: 600,
|
||
tool_result_max_chars: 100_000,
|
||
context_tool_result_trim_chars: 100_000,
|
||
max_images_in_context: 1,
|
||
max_image_age_rounds: 10,
|
||
};
|
||
|
||
let session_manager = SessionManager::new(
|
||
100,
|
||
false,
|
||
"Asia/Shanghai".to_string(),
|
||
provider_config.clone(),
|
||
HashMap::from([("default".to_string(), provider_config)]),
|
||
Arc::new(SkillRuntime::default()),
|
||
HashSet::new(),
|
||
crate::config::TaskConfig::default(),
|
||
crate::config::SubagentsConfig::default(),
|
||
test_maintenance_config(),
|
||
Some(24),
|
||
crate::mcp::McpConfig::default(),
|
||
)
|
||
.unwrap();
|
||
|
||
for scope_key in ["feishu:user-1", "feishu:user-2"] {
|
||
session_manager
|
||
.store()
|
||
.put_memory(&crate::storage::MemoryUpsert {
|
||
scope_kind: "user".to_string(),
|
||
scope_key: scope_key.to_string(),
|
||
namespace: "user".to_string(),
|
||
memory_key: "work".to_string(),
|
||
content: format!("{} 在做AI产品", scope_key),
|
||
source_type: "message".to_string(),
|
||
source_session_id: None,
|
||
source_message_id: None,
|
||
source_message_seq: None,
|
||
source_channel_name: None,
|
||
source_chat_id: None,
|
||
})
|
||
.unwrap();
|
||
}
|
||
|
||
let result = session_manager
|
||
.run_memory_maintenance_for_all_scopes()
|
||
.await
|
||
.unwrap();
|
||
|
||
// 当遇到可恢复错误时,没有整理任何记忆,返回 None
|
||
assert!(result.is_none());
|
||
}
|
||
|
||
/// Applying maintenance output must replace the merged source records with
/// the merged record and delete the low-value record. The limits come from
/// the default `MemoryMaintenanceConfig`.
#[test]
fn test_apply_memory_maintenance_output_merges_and_deletes_low_value_records() {
    let store = SessionStore::in_memory().unwrap();
    let scope_key = "feishu:user-1";

    // Builds an upsert for this scope; only namespace, key and content vary.
    let upsert = |namespace: &str, memory_key: &str, content: &str| crate::storage::MemoryUpsert {
        scope_kind: "user".to_string(),
        scope_key: scope_key.to_string(),
        namespace: namespace.to_string(),
        memory_key: memory_key.to_string(),
        content: content.to_string(),
        source_type: "message".to_string(),
        source_session_id: None,
        source_message_id: None,
        source_message_seq: None,
        source_channel_name: None,
        source_chat_id: None,
    };

    // Seed 7 memories so that the merge and deletion stay within the
    // protection limits. The original note says the default
    // `min_memories_to_keep` is 5; this is not verified here.
    let work = store
        .put_memory(&upsert("user", "work_short", "用户在做AI产品"))
        .unwrap();
    let role = store
        .put_memory(&upsert("user", "work_detail", "用户主要在做AI产品设计和实现"))
        .unwrap();
    let noise = store
        .put_memory(&upsert("other", "temporary", "今天临时提到过一个无后续的小细节"))
        .unwrap();
    for i in 0..4 {
        store
            .put_memory(&upsert(
                "user",
                &format!("extra_{}", i),
                &format!("额外记忆 {}", i),
            ))
            .unwrap();
    }

    let plan = build_memory_maintenance_plan(
        &store.list_memories_for_scope("user", scope_key).unwrap(),
    );
    assert_eq!(plan.candidates.len(), 7); // all 7 seeded memories are candidates

    let output = MemoryOrganizationOutput {
        merges: vec![MemoryMaintenanceMerge {
            source_ids: vec![work.id.clone(), role.id.clone()],
            namespace: "user".to_string(),
            memory_key: "work".to_string(),
            content: "用户主要在做AI产品设计与实现".to_string(),
        }],
        conflicts: Vec::new(),
        low_value_ids: vec![noise.id.clone()],
    };

    // Build the default configuration once and pass its limits through.
    let defaults = crate::config::MemoryMaintenanceConfig::default();
    apply_memory_maintenance_output(
        &store,
        scope_key,
        &plan,
        &output,
        defaults.max_merge_ratio,
        defaults.min_memories_to_keep,
        defaults.max_merge_per_group,
    )
    .unwrap();

    let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
    // `_meta` bookkeeping records are not user memories; leave them out of the count.
    let user_memories: Vec<_> = all_memories
        .iter()
        .filter(|m| m.namespace != "_meta")
        .collect();
    // 2 sources merged into 1 and 1 low-value record deleted: 7 - 2 + 1 - 1 = 5.
    assert_eq!(user_memories.len(), 5);
    // The merged record must exist under its new key.
    assert!(user_memories
        .iter()
        .any(|m| m.namespace == "user" && m.memory_key == "work"));
}
|
||
|
||
/// Completed tool results are hidden when the display flag is `false`, and
/// tool messages pending user action are always shown. With the flag `true`,
/// completed results are shown too. The flag is presumably a "show tool
/// results" setting; confirm against `should_display_message_to_user`.
#[test]
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
    let finished = ChatMessage::tool("call-1", "calculator", "2");
    let awaiting_user = ChatMessage::tool_with_state(
        "call-2",
        "bash",
        "waiting",
        crate::bus::message::ToolMessageState::PendingUserAction,
    );

    // Flag off: only the message still waiting on the user is displayed.
    assert!(!should_display_message_to_user(false, &finished));
    assert!(should_display_message_to_user(false, &awaiting_user));
    // Flag on: completed results are displayed as well.
    assert!(should_display_message_to_user(true, &finished));
}
|
||
|
||
#[tokio::test]
|
||
async fn test_bus_tool_call_emitter_emits_completed_tool_results() {
|
||
let bus = MessageBus::new(4);
|
||
let emitter =
|
||
BusToolCallEmitter::new(
|
||
bus.clone(),
|
||
"test-channel",
|
||
"chat-1",
|
||
HashMap::new(),
|
||
);
|
||
|
||
emitter
|
||
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
||
.await;
|
||
|
||
let msg = tokio::time::timeout(std::time::Duration::from_millis(500), bus.consume_outbound())
|
||
.await
|
||
.expect("timeout waiting for outbound message")
|
||
.expect("bus outbound closed");
|
||
assert_eq!(msg.event_kind, OutboundEventKind::ToolResult);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
HashSet::new(),
|
||
Default::default(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"test-channel".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store.clone(),
|
||
100,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
let history = session.get_history("chat-1").unwrap();
|
||
// 新设计:系统提示词不再持久化到历史记录,而是每次请求时动态注入
|
||
assert_eq!(history.len(), 0);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
HashSet::new(),
|
||
Default::default(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"test-channel".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store.clone(),
|
||
100,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
for turn in 0..100 {
|
||
session
|
||
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
|
||
.unwrap();
|
||
}
|
||
|
||
session
|
||
.ensure_agent_prompt_before_user_message("chat-1")
|
||
.unwrap();
|
||
|
||
// 新设计:系统提示词不再持久化到历史记录
|
||
let history = session.get_history("chat-1").unwrap();
|
||
let user_messages = history
|
||
.iter()
|
||
.filter(|message| message.role == "user")
|
||
.count();
|
||
assert_eq!(user_messages, 100);
|
||
|
||
// 注入计数在实际处理请求时由 AgentPromptProvider 更新
|
||
// 此处仅为模拟调用,不会触发实际注入
|
||
let stored = store
|
||
.get_session(&session.persistent_session_id("chat-1"))
|
||
.unwrap()
|
||
.unwrap();
|
||
// 初始值为 0,只有在实际 process 调用时才会更新
|
||
assert_eq!(stored.agent_prompt_reinjection_count, 0);
|
||
|
||
session
|
||
.ensure_agent_prompt_before_user_message("chat-1")
|
||
.unwrap();
|
||
let history = session.get_history("chat-1").unwrap();
|
||
let user_messages = history
|
||
.iter()
|
||
.filter(|message| message.role == "user")
|
||
.count();
|
||
assert_eq!(user_messages, 100);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_agent_prompt_reinjection_can_be_disabled_by_config() {
|
||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||
let skills = Arc::new(SkillRuntime::default());
|
||
let tools = Arc::new(
|
||
ToolRegistryFactory::new(
|
||
skills.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
store.clone(),
|
||
Arc::new(NoopSessionMessageSender),
|
||
HashSet::new(),
|
||
"Asia/Shanghai".to_string(),
|
||
HashSet::new(),
|
||
Default::default(),
|
||
)
|
||
.build(),
|
||
);
|
||
let mut session = Session::new(
|
||
"test-channel".to_string(),
|
||
test_provider_config(),
|
||
user_tx,
|
||
tools,
|
||
skills,
|
||
store.clone(),
|
||
0,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
session.ensure_persistent_session("chat-1").unwrap();
|
||
session.ensure_chat_loaded("chat-1").unwrap();
|
||
|
||
for turn in 0..100 {
|
||
session
|
||
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
|
||
.unwrap();
|
||
}
|
||
|
||
session
|
||
.ensure_agent_prompt_before_user_message("chat-1")
|
||
.unwrap();
|
||
|
||
// 新设计:系统提示词不再持久化到历史记录
|
||
let history = session.get_history("chat-1").unwrap();
|
||
let user_messages = history
|
||
.iter()
|
||
.filter(|message| message.role == "user")
|
||
.count();
|
||
assert_eq!(user_messages, 100);
|
||
}
|
||
|
||
/// The default tool registry built by `ToolRegistryFactory` must include the
/// `get_time` tool.
#[test]
fn test_default_tools_registers_get_time() {
    let store = Arc::new(SessionStore::in_memory().unwrap());
    let skills = Arc::new(SkillRuntime::default());

    let registry = ToolRegistryFactory::new(
        skills,
        store.clone(),
        store.clone(),
        store,
        Arc::new(NoopSessionMessageSender),
        HashSet::new(),
        "Asia/Shanghai".to_string(),
        HashSet::new(),
        Default::default(),
    )
    .build();

    let registered = registry.tool_names();
    assert!(registered.iter().any(|tool| tool == "get_time"));
}
|
||
|
||
/// Records with identical namespace, key and content collapse to a single
/// candidate. Every distinct memory still appears in the plan.
#[test]
fn test_build_memory_maintenance_plan_deduplicates_and_categorizes() {
    // Builds a user-scope record for `feishu:user-1`; only the identifying
    // fields and the timestamp vary between records.
    let record = |id: &str, namespace: &str, memory_key: &str, content: &str, stamp| MemoryRecord {
        id: id.to_string(),
        scope_kind: "user".to_string(),
        scope_key: "feishu:user-1".to_string(),
        namespace: namespace.to_string(),
        memory_key: memory_key.to_string(),
        content: content.to_string(),
        source_type: "message".to_string(),
        source_session_id: None,
        source_message_id: None,
        source_message_seq: None,
        source_channel_name: None,
        source_chat_id: None,
        created_at: stamp,
        updated_at: stamp,
    };

    let memories = vec![
        record("1", "user", "work", "用户在做AI产品", 1),
        // Exact duplicate of record 1 apart from id and timestamp.
        record("2", "user", "work", "用户在做AI产品", 2),
        record("3", "user", "style", "偏好简洁表达", 3),
        record("4", "patterns", "workflow", "习惯先问方案再要代码", 4),
    ];

    let plan = build_memory_maintenance_plan(&memories);

    // Records 1 and 2 are duplicates, so 3 candidates remain.
    assert_eq!(plan.candidates.len(), 3);

    let contents: Vec<&str> = plan
        .candidates
        .iter()
        .map(|candidate| candidate.content.as_str())
        .collect();
    for expected in ["用户在做AI产品", "偏好简洁表达", "习惯先问方案再要代码"] {
        assert!(contents.contains(&expected));
    }
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl crate::scheduler::MaintenanceExecutor for SessionManager {
|
||
async fn cleanup_expired_sessions(&self) -> usize {
|
||
self.cleanup_expired_sessions().await
|
||
}
|
||
|
||
async fn run_memory_maintenance_for_all_scopes(
|
||
&self,
|
||
) -> anyhow::Result<Vec<crate::scheduler::MaintenanceRunSummary>> {
|
||
match self.run_memory_maintenance_for_all_scopes().await {
|
||
Ok(Some(result)) => Ok(vec![crate::scheduler::MaintenanceRunSummary {
|
||
scope_key: result.scope_key,
|
||
merges: result.output.merges.len(),
|
||
conflicts: result.output.conflicts.len(),
|
||
low_value: result.output.low_value_ids.len(),
|
||
}]),
|
||
Ok(None) => Ok(vec![]),
|
||
Err(error) => Err(anyhow::anyhow!(error.to_string())),
|
||
}
|
||
}
|
||
}
|