From 008aba91acb8ab4f6dab9372a53a2f875725ddbc Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Tue, 28 Apr 2026 12:55:30 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E8=B0=83=E5=BA=A6?= =?UTF-8?q?=E5=99=A8=E4=BB=A5=E4=BD=BF=E7=94=A8=20AgentTaskExecutor=20?= =?UTF-8?q?=E5=92=8C=20SchedulerMaintenanceService?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 更新调度器,将 SessionManager 替换为 AgentTaskExecutor 和 SchedulerMaintenanceService。 - 修改作业执行逻辑,使用新服务处理代理任务和内部事件。 - 添加新的 CliChannel 以处理 CLI 连接,并包括适当的注册和注销逻辑。 - 引入 AgentTaskExecutor 和 SchedulerMaintenanceService,用于管理代理任务和会话维护。 - 实现聊天命令处理,用于重置会话上下文。 - 添加后台历史压缩功能,以优化会话存储。 - 创建实用函数,用于准备通过 WebSocket 通信的出站消息。 - 为新功能添加测试,并确保现有测试通过。 Co-authored-by: Copilot --- README.md | 12 +- src/channels/cli.rs | 155 +++++++++ src/channels/manager.rs | 13 +- src/channels/mod.rs | 2 + src/gateway/agent_task_executor.rs | 52 +++ src/gateway/command.rs | 159 ++++++++++ src/gateway/compaction.rs | 105 +++++++ src/gateway/execution.rs | 8 +- src/gateway/message_prepare.rs | 39 +++ src/gateway/mod.rs | 9 +- src/gateway/processor.rs | 17 +- src/gateway/session.rs | 274 +--------------- src/gateway/ws.rs | 487 ++++++----------------------- src/gateway/ws_adapter.rs | 229 ++++++++++++++ src/scheduler/mod.rs | 52 +-- 15 files changed, 927 insertions(+), 686 deletions(-) create mode 100644 src/channels/cli.rs create mode 100644 src/gateway/agent_task_executor.rs create mode 100644 src/gateway/command.rs create mode 100644 src/gateway/compaction.rs create mode 100644 src/gateway/message_prepare.rs create mode 100644 src/gateway/ws_adapter.rs diff --git a/README.md b/README.md index a2114e6..2b7f579 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ PicoBot 是一个用 Rust 构建的多通道 Agent 网关。它把消息接入 PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个可持续运行的 Agent 基础设施: - 消息从不同渠道进入统一总线 -- SessionManager 负责会话路由、上下文恢复、工具执行和回复生成 +- SessionManager 负责会话路由和运行时服务编排,AgentExecutionService 负责上下文准备、AgentLoop 执行、结果持久化和回复生成 - SQLite 作为事实来源保存跨重启状态 - Agent 在每轮推理时可以读取文件、执行命令、发 HTTP 请求、读写记忆、管理技能和调度任务 @@ -37,13 +37,13 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个 主要模块如下: -- src/gateway:网关入口、HTTP 健康检查、WebSocket 服务、Session 池、CLI 会话服务与 Agent 执行编排 +- src/gateway:网关入口、HTTP 健康检查、WebSocket 控制面、Session 池、CLI 会话服务与 Agent 执行编排 - src/bus:消息总线与消息结构定义 - src/agent:AgentLoop 与上下文压缩器 - src/providers:不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic - src/tools:内置工具集合 - src/storage:SQLite 持久化实现 -- src/channels:渠道适配层,当前已有飞书通道 +- src/channels:渠道适配层,当前已有 CLI 与飞书通道 - src/scheduler:数据库驱动的计划任务调度器 - src/skills:技能发现、加载与运行时管理 - src/client / src/cli:本地 CLI 客户端和交互命令 @@ -632,11 +632,11 @@ PicoBot/ ├── src/ │ ├── agent/ # AgentLoop、上下文压缩 │ ├── bus/ # 消息总线与消息结构 -│ ├── channels/ # 渠道适配 +│ ├── channels/ # CLI / 飞书等渠道适配 │ ├── cli/ # CLI 输入命令 │ ├── client/ # WebSocket CLI 客户端 │ ├── config/ # 配置解析 -│ ├── gateway/ # Gateway、SessionManager、WS/HTTP +│ ├── gateway/ # Gateway、Session 编排、WS/HTTP 控制面 │ ├── providers/ # OpenAI / Anthropic Provider │ ├── scheduler/ # 定时任务系统 │ ├── skills/ # 技能运行时 @@ -656,7 +656,7 @@ PicoBot/ 建议维护时重点关注: - docs/PERSISTENCE.md:持久化结构是否与代码一致 -- src/gateway/session.rs:消息流、工具注册、记忆维护、会话恢复主逻辑 +- src/gateway/session.rs:会话状态、会话路由和运行时服务编排 - src/storage/mod.rs:SQLite schema 变更 - src/config/mod.rs:配置项变更是否同步到 README diff --git a/src/channels/cli.rs b/src/channels/cli.rs new file mode 100644 index 0000000..4571ced --- /dev/null +++ b/src/channels/cli.rs @@ -0,0 +1,155 @@ +use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{RwLock, mpsc}; + +use crate::bus::{MessageBus, OutboundMessage}; +use crate::gateway::ws_adapter::ws_outbound_from_outbound_message; +use crate::protocol::WsOutbound; + +use super::base::{Channel, ChannelError}; + +#[derive(Clone)] +struct CliConnection { + connection_id: String, + sender: mpsc::Sender, +} + +#[derive(Clone)] +pub struct CliChannel { + connections: Arc>>, +} + +impl CliChannel { + pub fn new() -> Self { + Self { + connections: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn register_connection( + &self, + session_id: impl Into, + connection_id: impl Into, + sender: mpsc::Sender, + ) { + let session_id = session_id.into(); + let connection_id = connection_id.into(); + let previous = self.connections.write().await.insert( + session_id.clone(), + CliConnection { + connection_id: connection_id.clone(), + sender, + }, + ); + + if previous.is_some() { + tracing::info!(session_id = %session_id, connection_id = %connection_id, "CLI session sender replaced"); + } + } + + pub async fn unregister_connection(&self, connection_id: &str) { + self.connections + .write() + .await + .retain(|_, connection| connection.connection_id != connection_id); + } +} + +impl Default for CliChannel { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl Channel for CliChannel { + fn name(&self) -> &str { + "cli" + } + + fn is_running(&self) -> bool { + true + } + + async fn start(&self, _bus: Arc) -> Result<(), ChannelError> { + Ok(()) + } + + async fn stop(&self) -> Result<(), ChannelError> { + self.connections.write().await.clear(); + Ok(()) + } + + async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { + let connection = self.connections.read().await.get(&msg.chat_id).cloned(); + let Some(connection) = connection else { + return Err(ChannelError::SendError(format!( + "No active CLI connection for session {}", + msg.chat_id + ))); + }; + + for outbound in ws_outbound_from_outbound_message(&msg) { + connection + .sender + .send(outbound) + .await + .map_err(|_| ChannelError::SendError("CLI websocket sender closed".to_string()))?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::bus::OutboundMessage; + + #[tokio::test] + async fn test_cli_channel_sends_to_registered_session() { + let channel = CliChannel::new(); + let (sender, mut receiver) = mpsc::channel(4); + channel + .register_connection("session-1", "conn-1", sender) + .await; + + channel + .send(OutboundMessage::assistant( + "cli", + "session-1", + "hello", + None, + HashMap::new(), + )) + .await + .unwrap(); + + let outbound = receiver.recv().await.unwrap(); + assert!(matches!(outbound, WsOutbound::AssistantResponse { .. })); + } + + #[tokio::test] + async fn test_cli_channel_unregisters_connection_sessions() { + let channel = CliChannel::new(); + let (sender, _receiver) = mpsc::channel(4); + channel + .register_connection("session-1", "conn-1", sender) + .await; + channel.unregister_connection("conn-1").await; + + let error = channel + .send(OutboundMessage::assistant( + "cli", + "session-1", + "hello", + None, + HashMap::new(), + )) + .await + .unwrap_err(); + + assert!(error.to_string().contains("No active CLI connection")); + } +} diff --git a/src/channels/manager.rs b/src/channels/manager.rs index ae5a932..4a6ea36 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -4,6 +4,7 @@ use tokio::sync::RwLock; use crate::bus::MessageBus; use crate::channels::base::{Channel, ChannelError}; +use crate::channels::cli::CliChannel; use crate::channels::feishu::FeishuChannel; use crate::config::Config; @@ -12,13 +13,19 @@ use crate::config::Config; pub struct ChannelManager { channels: Arc>>>, bus: Arc, + cli_channel: Arc, } impl ChannelManager { pub fn new() -> Self { + let cli_channel = Arc::new(CliChannel::new()); + let mut channels: HashMap> = HashMap::new(); + channels.insert("cli".to_string(), cli_channel.clone()); + Self { - channels: Arc::new(RwLock::new(HashMap::new())), + channels: Arc::new(RwLock::new(channels)), bus: MessageBus::new(100), + cli_channel, } } @@ -27,6 +34,10 @@ impl ChannelManager { self.bus.clone() } + pub fn cli_channel(&self) -> Arc { + self.cli_channel.clone() + } + /// Initialize all Channel instances from config pub async fn init( &self, diff --git a/src/channels/mod.rs b/src/channels/mod.rs index e6a94de..53ce40d 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1,7 +1,9 @@ pub mod base; +pub mod cli; pub mod feishu; pub mod manager; pub use base::{Channel, ChannelError}; +pub use cli::CliChannel; pub use feishu::FeishuChannel; pub use manager::ChannelManager; diff --git a/src/gateway/agent_task_executor.rs b/src/gateway/agent_task_executor.rs new file mode 100644 index 0000000..a77376b --- /dev/null +++ b/src/gateway/agent_task_executor.rs @@ -0,0 +1,52 @@ +use crate::agent::AgentError; +use crate::bus::OutboundMessage; + +use super::memory_maintenance::MemoryMaintenanceScopeResult; +use super::session::{ScheduledAgentTaskOptions, SessionManager}; + +#[derive(Clone)] +pub struct AgentTaskExecutor { + session_manager: SessionManager, +} + +impl AgentTaskExecutor { + pub fn new(session_manager: SessionManager) -> Self { + Self { session_manager } + } + + pub(crate) async fn execute( + &self, + channel_name: &str, + chat_id: &str, + prompt: &str, + options: ScheduledAgentTaskOptions, + ) -> Result, AgentError> { + self.session_manager + .run_scheduled_agent_task(channel_name, chat_id, prompt, options) + .await + } +} + +#[derive(Clone)] +pub struct SchedulerMaintenanceService { + session_manager: SessionManager, +} + +impl SchedulerMaintenanceService { + pub fn new(session_manager: SessionManager) -> Self { + Self { session_manager } + } + + pub(crate) async fn cleanup_expired_sessions(&self) -> usize { + self.session_manager.cleanup_expired_sessions().await + } + + pub(crate) async fn run_memory_maintenance_for_all_scopes( + &self, + updated_since: Option, + ) -> Result, AgentError> { + self.session_manager + .run_memory_maintenance_for_all_scopes(updated_since) + .await + } +} diff --git a/src/gateway/command.rs b/src/gateway/command.rs new file mode 100644 index 0000000..3763c58 --- /dev/null +++ b/src/gateway/command.rs @@ -0,0 +1,159 @@ +use crate::agent::AgentError; + +use super::session::Session; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum InChatCommand { + FreshConversation, +} + +fn parse_in_chat_command(content: &str) -> Option { + match content.trim() { + "/new" | "/reset" => Some(InChatCommand::FreshConversation), + _ => None, + } +} + +pub(crate) fn handle_in_chat_command( + session: &mut Session, + chat_id: &str, + content: &str, +) -> Result, AgentError> { + match parse_in_chat_command(content) { + Some(InChatCommand::FreshConversation) => { + session.reset_chat_context(chat_id)?; + Ok(Some("Started a fresh conversation.".to_string())) + } + None => Ok(None), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::bus::ChatMessage; + use crate::config::LLMProviderConfig; + use crate::skills::SkillRuntime; + use crate::storage::SessionStore; + use crate::tools::ToolRegistry; + use std::collections::HashMap; + use std::sync::Arc; + use tokio::sync::mpsc; + + fn test_provider_config() -> LLMProviderConfig { + LLMProviderConfig { + provider_type: "openai".to_string(), + name: "test".to_string(), + base_url: "http://localhost".to_string(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + llm_timeout_secs: 120, + model_id: "test-model".to_string(), + temperature: Some(0.0), + max_tokens: Some(32), + context_window_tokens: None, + model_extra: HashMap::new(), + max_tool_iterations: 1, + tool_result_max_chars: 20_000, + context_tool_result_trim_chars: 20_000, + } + } + + #[test] + fn test_parse_in_chat_command_aliases() { + assert_eq!( + parse_in_chat_command("/new"), + Some(InChatCommand::FreshConversation) + ); + assert_eq!( + parse_in_chat_command(" /reset \n"), + Some(InChatCommand::FreshConversation) + ); + assert_eq!(parse_in_chat_command("/new planning"), None); + assert_eq!(parse_in_chat_command("please /reset"), None); + } + + #[tokio::test] + async fn test_handle_in_chat_command_resets_active_history_only() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let (user_tx, _user_rx) = mpsc::channel(4); + let skills = Arc::new(SkillRuntime::default()); + let tools = Arc::new(ToolRegistry::new()); + let mut session = Session::new( + "feishu".to_string(), + test_provider_config(), + user_tx, + tools, + skills, + store.clone(), + 100, + ) + .await + .unwrap(); + + session.ensure_persistent_session("chat-1").unwrap(); + session.ensure_chat_loaded("chat-1").unwrap(); + session + .append_persisted_message("chat-1", ChatMessage::user("hello")) + .unwrap(); + + let response = handle_in_chat_command(&mut session, "chat-1", "/reset") + .unwrap() + .unwrap(); + + assert_eq!(response, "Started a fresh conversation."); + assert!(session.get_history("chat-1").unwrap().is_empty()); + assert!( + store + .load_messages(&session.persistent_session_id("chat-1")) + .unwrap() + .is_empty() + ); + assert_eq!( + store + .load_all_messages(&session.persistent_session_id("chat-1")) + .unwrap() + .len(), + 2, + ); + + session.ensure_chat_loaded("chat-1").unwrap(); + let history = session.get_history("chat-1").unwrap(); + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, "system"); + } + + #[tokio::test] + async fn test_reset_reinjects_agent_prompt_before_next_user_message() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let (user_tx, _user_rx) = mpsc::channel(4); + let skills = Arc::new(SkillRuntime::default()); + let tools = Arc::new(ToolRegistry::new()); + let mut session = Session::new( + "feishu".to_string(), + test_provider_config(), + user_tx, + tools, + skills, + store, + 100, + ) + .await + .unwrap(); + + session.ensure_persistent_session("chat-1").unwrap(); + session.ensure_chat_loaded("chat-1").unwrap(); + session + .append_persisted_message("chat-1", ChatMessage::user("hello")) + .unwrap(); + + handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap(); + session + .ensure_agent_prompt_before_user_message("chat-1") + .unwrap(); + + let history = session.get_history("chat-1").unwrap(); + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, "system"); + } +} diff --git a/src/gateway/compaction.rs b/src/gateway/compaction.rs new file mode 100644 index 0000000..ffe8e56 --- /dev/null +++ b/src/gateway/compaction.rs @@ -0,0 +1,105 @@ +use std::sync::Arc; + +use tokio::sync::Mutex; + +use crate::agent::AgentError; + +use super::session::Session; + +pub(crate) async fn schedule_background_history_compaction( + session: Arc>, + chat_id: impl Into, +) -> Result<(), AgentError> { + let chat_id = chat_id.into(); + + let snapshot = { + let mut session_guard = session.lock().await; + let session_record = session_guard.ensure_persistent_session(&chat_id)?; + session_guard.ensure_chat_loaded(&chat_id)?; + + let history = session_guard.get_or_create_history(&chat_id).clone(); + let compressor = session_guard.compressor().clone(); + if !compressor.should_compress(&history) { + return Ok(()); + } + + if !session_guard.try_start_background_compaction(&chat_id) { + return Ok(()); + } + + ( + session_guard.store(), + session_guard.persistent_session_id(&chat_id), + session_record.reset_cutoff_seq, + session_record.message_count, + history, + compressor, + session_guard.provider_config().clone(), + ) + }; + + let ( + store, + session_id, + expected_reset_cutoff_seq, + snapshot_end_seq, + history, + compressor, + provider_config, + ) = snapshot; + let session_for_task = session.clone(); + let chat_id_for_task = chat_id.clone(); + + tokio::spawn(async move { + tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction"); + + let compaction_result = compressor + .build_compaction_plan(&history, &provider_config) + .await; + let mut committed = false; + + match compaction_result { + Ok(Some(plan)) => match store.compact_active_history( + &session_id, + expected_reset_cutoff_seq, + snapshot_end_seq, + &plan.preserved_system_messages, + &plan.summary_message, + &plan.preserved_messages, + ) { + Ok(true) => { + committed = true; + tracing::info!( + chat_id = %chat_id_for_task, + snapshot_end_seq, + compressed_turns = plan.compressed_turns, + preserved_turns = plan.preserved_turns, + "Background history compaction committed" + ); + } + Ok(false) => { + tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot"); + } + Err(error) => { + tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed"); + } + }, + Ok(None) => { + tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis"); + } + Err(error) => { + tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed"); + } + } + + let mut session_guard = session_for_task.lock().await; + if committed { + if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) { + tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction"); + } + } + session_guard.finish_background_compaction(&chat_id_for_task); + }); + + Ok(()) +} diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index 4953aa8..1988bbc 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -7,10 +7,10 @@ use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDUL use crate::config::LLMProviderConfig; use tokio::sync::Mutex; -use super::session::{ - Session, enrich_user_content_with_media_refs, handle_in_chat_command, - schedule_background_history_compaction, -}; +use super::command::handle_in_chat_command; +use super::compaction::schedule_background_history_compaction; +use super::message_prepare::enrich_user_content_with_media_refs; +use super::session::Session; const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。"; diff --git a/src/gateway/message_prepare.rs b/src/gateway/message_prepare.rs new file mode 100644 index 0000000..fdc7ff5 --- /dev/null +++ b/src/gateway/message_prepare.rs @@ -0,0 +1,39 @@ +use crate::agent::AgentError; + +pub(crate) fn enrich_user_content_with_media_refs( + content: &str, + media_refs: &[String], +) -> Result { + if media_refs.is_empty() { + return Ok(content.to_string()); + } + + let media_refs_json = serde_json::to_string(media_refs) + .map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?; + + Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_enrich_user_content_with_media_refs_appends_tagged_json() { + let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()]; + + let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap(); + + assert_eq!( + enriched, + "hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]" + ); + } + + #[test] + fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() { + let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap(); + + assert_eq!(enriched, "hello"); + } +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 00c1991..e578c14 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,13 +1,18 @@ +pub mod agent_task_executor; pub mod cli_session; +pub mod command; +pub mod compaction; pub mod execution; pub mod http; pub mod memory_maintenance; +pub mod message_prepare; pub mod processor; pub mod prompt; pub mod session; pub mod session_factory; pub mod session_pool; pub mod ws; +pub mod ws_adapter; use axum::{Router, routing}; use std::collections::HashMap; @@ -21,6 +26,7 @@ use crate::config::LLMProviderConfig; use crate::logging; use crate::scheduler::Scheduler; use crate::skills::SkillRuntime; +use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService}; use processor::InboundProcessor; use session::SessionManager; @@ -122,7 +128,8 @@ pub async fn run( state.config.scheduler.clone(), timezone, state.session_manager.store(), - state.session_manager.clone(), + AgentTaskExecutor::new(state.session_manager.clone()), + SchedulerMaintenanceService::new(state.session_manager.clone()), ); tokio::spawn(async move { diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 466602d..eda2812 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::bus::MessageBus; +use crate::bus::{MessageBus, OutboundMessage}; use super::session::{BusToolCallEmitter, SessionManager}; @@ -70,6 +70,21 @@ impl InboundProcessor { } Err(error) => { tracing::error!(error = %error, "Failed to handle message"); + let mut metadata = inbound.forwarded_metadata.clone(); + metadata.insert("error_kind".to_string(), "agent_execution".to_string()); + if let Err(publish_error) = self + .bus + .publish_outbound(OutboundMessage::error_notification( + inbound.channel, + inbound.chat_id, + error.to_string(), + None, + metadata, + )) + .await + { + tracing::error!(error = %publish_error, "Failed to publish execution error outbound"); + } } } } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index eccab60..2e9a54f 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -43,20 +43,6 @@ fn preview_text(content: &str, max_chars: usize) -> String { preview.replace('\n', "\\n") } -pub(crate) fn enrich_user_content_with_media_refs( - content: &str, - media_refs: &[String], -) -> Result { - if media_refs.is_empty() { - return Ok(content.to_string()); - } - - let media_refs_json = serde_json::to_string(media_refs) - .map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?; - - Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}")) -} - /// Session 按 channel 隔离,每个 channel 一个 Session /// History 按 chat_id 隔离,由 Session 统一管理 pub struct Session { @@ -393,15 +379,15 @@ impl Session { &self.compressor } - fn try_start_background_compaction(&mut self, chat_id: &str) -> bool { + pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool { self.compression_in_flight.insert(chat_id.to_string()) } - fn finish_background_compaction(&mut self, chat_id: &str) { + pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) { self.compression_in_flight.remove(chat_id); } - fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { + pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { let history = self .store .load_messages(&self.persistent_session_id(chat_id)) @@ -410,6 +396,10 @@ impl Session { Ok(()) } + pub(crate) fn store(&self) -> Arc { + self.store.clone() + } + pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> { if self.skills.is_empty() { return Ok(()); @@ -528,129 +518,6 @@ fn default_tools( registry } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum InChatCommand { - FreshConversation, -} - -fn parse_in_chat_command(content: &str) -> Option { - match content.trim() { - "/new" | "/reset" => Some(InChatCommand::FreshConversation), - _ => None, - } -} - -pub(crate) fn handle_in_chat_command( - session: &mut Session, - chat_id: &str, - content: &str, -) -> Result, AgentError> { - match parse_in_chat_command(content) { - Some(InChatCommand::FreshConversation) => { - session.reset_chat_context(chat_id)?; - Ok(Some("Started a fresh conversation.".to_string())) - } - None => Ok(None), - } -} - -pub(crate) async fn schedule_background_history_compaction( - session: Arc>, - chat_id: impl Into, -) -> Result<(), AgentError> { - let chat_id = chat_id.into(); - - let snapshot = { - let mut session_guard = session.lock().await; - let session_record = session_guard.ensure_persistent_session(&chat_id)?; - session_guard.ensure_chat_loaded(&chat_id)?; - - let history = session_guard.get_or_create_history(&chat_id).clone(); - let compressor = session_guard.compressor().clone(); - if !compressor.should_compress(&history) { - return Ok(()); - } - - if !session_guard.try_start_background_compaction(&chat_id) { - return Ok(()); - } - - ( - session_guard.store.clone(), - session_guard.persistent_session_id(&chat_id), - session_record.reset_cutoff_seq, - session_record.message_count, - history, - compressor, - session_guard.provider_config().clone(), - ) - }; - - let ( - store, - session_id, - expected_reset_cutoff_seq, - snapshot_end_seq, - history, - compressor, - provider_config, - ) = snapshot; - let session_for_task = session.clone(); - let chat_id_for_task = chat_id.clone(); - - tokio::spawn(async move { - tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction"); - - let compaction_result = compressor - .build_compaction_plan(&history, &provider_config) - .await; - let mut committed = false; - - match compaction_result { - Ok(Some(plan)) => match store.compact_active_history( - &session_id, - expected_reset_cutoff_seq, - snapshot_end_seq, - &plan.preserved_system_messages, - &plan.summary_message, - &plan.preserved_messages, - ) { - Ok(true) => { - committed = true; - tracing::info!( - chat_id = %chat_id_for_task, - snapshot_end_seq, - compressed_turns = plan.compressed_turns, - preserved_turns = plan.preserved_turns, - "Background history compaction committed" - ); - } - Ok(false) => { - tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot"); - } - Err(error) => { - tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed"); - } - }, - Ok(None) => { - tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis"); - } - Err(error) => { - tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed"); - } - } - - let mut session_guard = session_for_task.lock().await; - if committed { - if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) { - tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction"); - } - } - session_guard.finish_background_compaction(&chat_id_for_task); - }); - - Ok(()) -} impl SessionManager { pub fn new( session_ttl_hours: u64, @@ -913,25 +780,6 @@ mod tests { } } - #[test] - fn test_enrich_user_content_with_media_refs_appends_tagged_json() { - let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()]; - - let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap(); - - assert_eq!( - enriched, - "hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]" - ); - } - - #[test] - fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() { - let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap(); - - assert_eq!(enriched, "hello"); - } - #[tokio::test] async fn test_latest_user_message_guard_tracks_current_turn() { let store = Arc::new(SessionStore::in_memory().unwrap()); @@ -1750,75 +1598,6 @@ mod tests { ); } - #[test] - fn test_parse_in_chat_command_aliases() { - assert_eq!( - parse_in_chat_command("/new"), - Some(InChatCommand::FreshConversation) - ); - assert_eq!( - parse_in_chat_command(" /reset \n"), - Some(InChatCommand::FreshConversation) - ); - assert_eq!(parse_in_chat_command("/new planning"), None); - assert_eq!(parse_in_chat_command("please /reset"), None); - } - - #[tokio::test] - async fn test_handle_in_chat_command_resets_active_history_only() { - let store = Arc::new(SessionStore::in_memory().unwrap()); - let (user_tx, _user_rx) = mpsc::channel(4); - let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools( - skills.clone(), - store.clone(), - HashSet::new(), - "Asia/Shanghai".to_string(), - )); - let mut session = Session::new( - "feishu".to_string(), - test_provider_config(), - user_tx, - tools, - skills, - store.clone(), - 100, - ) - .await - .unwrap(); - - session.ensure_persistent_session("chat-1").unwrap(); - session.ensure_chat_loaded("chat-1").unwrap(); - session - .append_persisted_message("chat-1", ChatMessage::user("hello")) - .unwrap(); - - let response = handle_in_chat_command(&mut session, "chat-1", "/reset") - .unwrap() - .unwrap(); - - assert_eq!(response, "Started a fresh conversation."); - assert!(session.get_history("chat-1").unwrap().is_empty()); - assert!( - store - .load_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .is_empty() - ); - assert_eq!( - store - .load_all_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .len(), - 2, - ); - - session.ensure_chat_loaded("chat-1").unwrap(); - let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 1); - assert_eq!(history[0].role, "system"); - } - #[tokio::test] async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() { let store = Arc::new(SessionStore::in_memory().unwrap()); @@ -1955,45 +1734,6 @@ mod tests { assert_eq!(system_messages, 1); } - #[tokio::test] - async fn test_reset_reinjects_agent_prompt_before_next_user_message() { - let store = Arc::new(SessionStore::in_memory().unwrap()); - let (user_tx, _user_rx) = mpsc::channel(4); - let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools( - skills.clone(), - store.clone(), - HashSet::new(), - "Asia/Shanghai".to_string(), - )); - let mut session = Session::new( - "feishu".to_string(), - test_provider_config(), - user_tx, - tools, - skills, - store.clone(), - 100, - ) - .await - .unwrap(); - - session.ensure_persistent_session("chat-1").unwrap(); - session.ensure_chat_loaded("chat-1").unwrap(); - session - .append_persisted_message("chat-1", ChatMessage::user("hello")) - .unwrap(); - - handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap(); - session - .ensure_agent_prompt_before_user_message("chat-1") - .unwrap(); - - let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 1); - assert_eq!(history[0].role, "system"); - } - #[test] fn test_default_tools_registers_get_time() { let skills = Arc::new(SkillRuntime::default()); diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 87ce66f..ca0be31 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,40 +1,17 @@ -use super::{ - GatewayState, - execution::{AgentExecutionService, MessageExecutionRequest, should_display_message_to_user}, - session::Session, -}; -use crate::agent::{AgentError, EmittedMessageHandler}; -use crate::bus::message::{OutboundEventKind, ToolMessageState, format_tool_call_content}; -use crate::bus::{ChatMessage, OutboundMessage}; +use super::GatewayState; +use crate::agent::AgentError; +use crate::bus::InboundMessage; use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; -use async_trait::async_trait; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; +use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::{Mutex, mpsc}; +use tokio::sync::mpsc; const CLI_CHANNEL_NAME: &str = "cli"; -struct WsToolCallEmitter { - sender: mpsc::Sender, - show_tool_results: bool, -} - -#[async_trait] -impl EmittedMessageHandler for WsToolCallEmitter { - async fn handle(&self, message: ChatMessage) { - if !should_display_message_to_user(self.show_tool_results, &message) { - return; - } - - for outbound in ws_outbound_from_chat_message(&message) { - let _ = self.sender.send(outbound).await; - } - } -} - pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async { handle_socket(socket, state).await; @@ -44,14 +21,6 @@ pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State) { let (sender, receiver) = mpsc::channel::(100); - let provider_config = match state.config.get_provider_config("default") { - Ok(cfg) => cfg, - Err(e) => { - tracing::error!(error = %e, "Failed to get provider config"); - return; - } - }; - let cli_sessions = state.session_manager.cli_sessions(); let initial_record = match cli_sessions.create(None) { Ok(record) => record, @@ -61,39 +30,20 @@ async fn handle_socket(ws: WebSocket, state: Arc) { } }; - let channel_name = CLI_CHANNEL_NAME.to_string(); - - // 创建 CLI session - let session = match Session::new( - channel_name.clone(), - provider_config, - sender, - state.session_manager.tools(), - state.session_manager.skills(), - state.session_manager.store(), - state.config.gateway.agent_prompt_reinject_every, - ) - .await - { - Ok(s) => Arc::new(Mutex::new(s)), - Err(e) => { - tracing::error!(error = %e, "Failed to create session"); - return; - } - }; - - if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) { - tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history"); - return; - } - - let runtime_session_id = session.lock().await.id.to_string(); + let runtime_session_id = uuid::Uuid::new_v4().to_string(); let mut current_session_id = initial_record.id.clone(); + state + .channel_manager + .cli_channel() + .register_connection( + current_session_id.clone(), + runtime_session_id.clone(), + sender.clone(), + ) + .await; tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established"); - let _ = session - .lock() - .await + let _ = sender .send(WsOutbound::SessionEstablished { session_id: current_session_id.clone(), }) @@ -123,7 +73,7 @@ async fn handle_socket(ws: WebSocket, state: Arc) { Ok(inbound) => { if let Err(e) = handle_inbound( &state, - &session, + &sender, &runtime_session_id, &mut current_session_id, inbound, @@ -131,9 +81,7 @@ async fn handle_socket(ws: WebSocket, state: Arc) { .await { tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message"); - let _ = session - .lock() - .await + let _ = sender .send(WsOutbound::Error { code: "SESSION_ERROR".to_string(), message: e.to_string(), @@ -143,9 +91,7 @@ async fn handle_socket(ws: WebSocket, state: Arc) { } Err(e) => { tracing::warn!(error = %e, "Failed to parse inbound message"); - let _ = session - .lock() - .await + let _ = sender .send(WsOutbound::Error { code: "PARSE_ERROR".to_string(), message: e.to_string(), @@ -163,6 +109,11 @@ async fn handle_socket(ws: WebSocket, state: Arc) { } } + state + .channel_manager + .cli_channel() + .unregister_connection(&runtime_session_id) + .await; tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); } @@ -178,115 +129,9 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary { } } -fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { - match message.role.as_str() { - "assistant" => { - if let Some(tool_calls) = &message.tool_calls { - let mut outbound = Vec::new(); - if !message.content.trim().is_empty() { - outbound.push(WsOutbound::AssistantResponse { - id: message.id.clone(), - content: message.content.clone(), - role: message.role.clone(), - }); - } - - outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall { - id: message.id.clone(), - tool_call_id: tool_call.id.clone(), - tool_name: tool_call.name.clone(), - arguments: tool_call.arguments.clone(), - content: format_tool_call_content(&tool_call.name, &tool_call.arguments), - role: message.role.clone(), - })); - outbound - } else { - vec![WsOutbound::AssistantResponse { - id: message.id.clone(), - content: message.content.clone(), - role: message.role.clone(), - }] - } - } - "tool" => match message - .tool_state - .as_ref() - .unwrap_or(&ToolMessageState::Completed) - { - ToolMessageState::Completed => vec![WsOutbound::ToolResult { - id: message.id.clone(), - tool_call_id: message.tool_call_id.clone().unwrap_or_default(), - tool_name: message.tool_name.clone().unwrap_or_default(), - content: message.content.clone(), - role: message.role.clone(), - }], - ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending { - id: message.id.clone(), - tool_call_id: message.tool_call_id.clone().unwrap_or_default(), - tool_name: message.tool_name.clone().unwrap_or_default(), - content: message.content.clone(), - role: message.role.clone(), - resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(), - }], - }, - _ => Vec::new(), - } -} - -fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec { - match message.event_kind { - OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => { - vec![WsOutbound::AssistantResponse { - id: uuid::Uuid::new_v4().to_string(), - content: message.content.clone(), - role: message.role.clone(), - }] - } - OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall { - id: message - .tool_call_id - .clone() - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - tool_call_id: message.tool_call_id.clone().unwrap_or_default(), - tool_name: message.tool_name.clone().unwrap_or_default(), - arguments: message - .tool_arguments - .clone() - .unwrap_or(serde_json::Value::Null), - content: message.content.clone(), - role: message.role.clone(), - }], - OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult { - id: message - .tool_call_id - .clone() - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - tool_call_id: message.tool_call_id.clone().unwrap_or_default(), - tool_name: message.tool_name.clone().unwrap_or_default(), - content: message.content.clone(), - role: message.role.clone(), - }], - OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending { - id: message - .tool_call_id - .clone() - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - tool_call_id: message.tool_call_id.clone().unwrap_or_default(), - tool_name: message.tool_name.clone().unwrap_or_default(), - content: message.content.clone(), - role: message.role.clone(), - resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(), - }], - OutboundEventKind::ErrorNotification => vec![WsOutbound::Error { - code: "AGENT_ERROR".to_string(), - message: message.content.clone(), - }], - } -} - async fn handle_inbound( state: &Arc, - session: &Arc>, + sender: &mpsc::Sender, runtime_session_id: &str, current_session_id: &mut String, inbound: WsInbound, @@ -300,43 +145,31 @@ async fn handle_inbound( } => { let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id); - let user_tx = session.lock().await.user_tx.clone(); - let live_emitter = Arc::new(WsToolCallEmitter { - sender: user_tx.clone(), - show_tool_results: state.config.gateway.show_tool_results, - }); - match AgentExecutionService::new(state.config.gateway.show_tool_results) - .prepare_and_execute_message(MessageExecutionRequest { - session: session.clone(), - channel_name: CLI_CHANNEL_NAME, - sender_id: &sender_id, - chat_id: &chat_id, - content: &content, + state + .channel_manager + .cli_channel() + .register_connection( + chat_id.clone(), + runtime_session_id.to_string(), + sender.clone(), + ) + .await; + + state + .bus + .publish_inbound(InboundMessage { + channel: CLI_CHANNEL_NAME.to_string(), + sender_id, + chat_id, + content, + timestamp: current_timestamp(), media: Vec::new(), - live_emitter: Some(live_emitter), + metadata: HashMap::new(), + forwarded_metadata: HashMap::new(), }) .await - { - Ok(outbound_messages) => { - for outbound in outbound_messages - .iter() - .flat_map(ws_outbound_from_outbound_message) - { - let _ = user_tx.send(outbound).await; - } - } - Err(AgentError::LlmError(error)) => { - tracing::error!(chat_id = %chat_id, error = %error, "Agent process error"); - let _ = user_tx - .send(WsOutbound::Error { - code: "LLM_ERROR".to_string(), - message: error, - }) - .await; - } - Err(error) => return Err(error), - } + .map_err(|error| AgentError::Other(error.to_string()))?; Ok(()) } @@ -352,9 +185,11 @@ async fn handle_inbound( .cli_sessions() .clear_messages(&target)?; - let mut session_guard = session.lock().await; - session_guard.remove_history(&target); - let _ = session_guard + if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await { + session.lock().await.remove_history(&target); + } + + let _ = sender .send(WsOutbound::HistoryCleared { session_id: target }) .await; Ok(()) @@ -366,9 +201,16 @@ async fn handle_inbound( .create(title.as_deref())?; *current_session_id = record.id.clone(); - let mut session_guard = session.lock().await; - session_guard.ensure_chat_loaded(&record.id)?; - let _ = session_guard + state + .channel_manager + .cli_channel() + .register_connection( + record.id.clone(), + runtime_session_id.to_string(), + sender.clone(), + ) + .await; + let _ = sender .send(WsOutbound::SessionCreated { session_id: record.id, title: record.title, @@ -383,8 +225,7 @@ async fn handle_inbound( .list(include_archived)?; let summaries = records.into_iter().map(to_session_summary).collect(); - let session_guard = session.lock().await; - let _ = session_guard + let _ = sender .send(WsOutbound::SessionList { sessions: summaries, current_session_id: Some(current_session_id.clone()), @@ -394,8 +235,7 @@ async fn handle_inbound( } WsInbound::LoadSession { session_id } => { let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else { - let session_guard = session.lock().await; - let _ = session_guard + let _ = sender .send(WsOutbound::Error { code: "SESSION_NOT_FOUND".to_string(), message: format!("Session not found: {}", session_id), @@ -405,9 +245,16 @@ async fn handle_inbound( }; *current_session_id = record.id.clone(); - let mut session_guard = session.lock().await; - session_guard.ensure_chat_loaded(&record.id)?; - let _ = session_guard + state + .channel_manager + .cli_channel() + .register_connection( + record.id.clone(), + runtime_session_id.to_string(), + sender.clone(), + ) + .await; + let _ = sender .send(WsOutbound::SessionLoaded { session_id: record.id, title: record.title, @@ -422,8 +269,7 @@ async fn handle_inbound( .session_manager .cli_sessions() .rename(&target, &title)?; - let session_guard = session.lock().await; - let _ = session_guard + let _ = sender .send(WsOutbound::SessionRenamed { session_id: target, title, @@ -434,8 +280,7 @@ async fn handle_inbound( WsInbound::ArchiveSession { session_id } => { let target = session_id.unwrap_or_else(|| current_session_id.clone()); state.session_manager.cli_sessions().archive(&target)?; - let session_guard = session.lock().await; - let _ = session_guard + let _ = sender .send(WsOutbound::SessionArchived { session_id: target }) .await; Ok(()) @@ -450,9 +295,11 @@ async fn handle_inbound( None }; - let mut session_guard = session.lock().await; - session_guard.remove_history(&target); - let _ = session_guard + if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await { + session.lock().await.remove_history(&target); + } + + let _ = sender .send(WsOutbound::SessionDeleted { session_id: target.clone(), }) @@ -460,8 +307,16 @@ async fn handle_inbound( if let Some(record) = replacement { *current_session_id = record.id.clone(); - session_guard.ensure_chat_loaded(&record.id)?; - let _ = session_guard + state + .channel_manager + .cli_channel() + .register_connection( + record.id.clone(), + runtime_session_id.to_string(), + sender.clone(), + ) + .await; + let _ = sender .send(WsOutbound::SessionCreated { session_id: record.id, title: record.title, @@ -472,13 +327,19 @@ async fn handle_inbound( Ok(()) } WsInbound::Ping => { - let session_guard = session.lock().await; - let _ = session_guard.send(WsOutbound::Pong).await; + let _ = sender.send(WsOutbound::Pong).await; Ok(()) } } } +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as i64 +} + fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String { sender_id .map(str::trim) @@ -489,138 +350,7 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St #[cfg(test)] mod tests { - use super::{ - WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user, - ws_outbound_from_chat_message, - }; - use crate::agent::EmittedMessageHandler; - use crate::bus::message::ToolMessageState; - use crate::bus::{ChatMessage, OutboundMessage}; - use crate::protocol::WsOutbound; - use crate::providers::ToolCall; - use serde_json::json; - use tokio::sync::mpsc; - - #[test] - fn test_ws_outbound_from_chat_message_expands_tool_calls() { - let message = ChatMessage::assistant_with_tool_calls( - "", - vec![ToolCall { - id: "call-1".to_string(), - name: "calculator".to_string(), - arguments: json!({"expression": "1 + 1"}), - }], - ); - - let outbound = ws_outbound_from_chat_message(&message); - - assert_eq!(outbound.len(), 1); - match &outbound[0] { - WsOutbound::ToolCall { - tool_call_id, - tool_name, - arguments, - content, - .. - } => { - assert_eq!(tool_call_id, "call-1"); - assert_eq!(tool_name, "calculator"); - assert_eq!(arguments["expression"], "1 + 1"); - assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}"); - } - other => panic!("unexpected outbound variant: {:?}", other), - } - } - - #[test] - fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() { - let message = ChatMessage::assistant_with_tool_calls( - "日报已整理完成。", - vec![ToolCall { - id: "call-1".to_string(), - name: "memory_manage".to_string(), - arguments: json!({"action": "put"}), - }], - ); - - let outbound = ws_outbound_from_chat_message(&message); - - assert_eq!(outbound.len(), 2); - assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. })); - assert!(matches!(outbound[1], WsOutbound::ToolCall { .. })); - } - - #[test] - fn test_ws_outbound_from_chat_message_includes_tool_results() { - let message = ChatMessage::tool("call-1", "calculator", "2"); - - let outbound = ws_outbound_from_chat_message(&message); - - assert_eq!(outbound.len(), 1); - assert!(matches!(outbound[0], WsOutbound::ToolResult { .. })); - } - - #[test] - fn test_ws_outbound_from_chat_message_includes_tool_pending() { - let message = ChatMessage::tool_with_state( - "call-1", - "bash", - "等待你完成授权后再继续。", - ToolMessageState::PendingUserAction, - ); - - let outbound = ws_outbound_from_chat_message(&message); - - assert_eq!(outbound.len(), 1); - assert!(matches!(outbound[0], WsOutbound::ToolPending { .. })); - } - - #[test] - fn test_should_display_message_to_user_hides_completed_tool_results_by_default() { - let completed = ChatMessage::tool("call-1", "calculator", "2"); - let pending = ChatMessage::tool_with_state( - "call-2", - "bash", - "waiting", - ToolMessageState::PendingUserAction, - ); - - assert!(!should_display_message_to_user(false, &completed)); - assert!(should_display_message_to_user(false, &pending)); - assert!(should_display_message_to_user(true, &completed)); - } - - #[test] - fn test_ws_outbound_from_outbound_message_maps_tool_call() { - let message = OutboundMessage::tool_call( - "cli", - "session-1", - "call-1", - "calculator", - json!({"expression": "1 + 1"}), - None, - Default::default(), - ); - - let outbound = super::ws_outbound_from_outbound_message(&message); - - assert_eq!(outbound.len(), 1); - match &outbound[0] { - WsOutbound::ToolCall { - tool_call_id, - tool_name, - arguments, - content, - .. - } => { - assert_eq!(tool_call_id, "call-1"); - assert_eq!(tool_name, "calculator"); - assert_eq!(arguments["expression"], "1 + 1"); - assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}"); - } - other => panic!("unexpected outbound variant: {:?}", other), - } - } + use super::resolve_ws_sender_id; #[test] fn test_resolve_ws_sender_id_prefers_inbound_sender() { @@ -639,23 +369,4 @@ mod tests { assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1"); } - - #[tokio::test] - async fn test_ws_tool_call_emitter_hides_completed_tool_results_when_disabled() { - let (sender, mut receiver) = mpsc::channel(4); - let emitter = WsToolCallEmitter { - sender, - show_tool_results: false, - }; - - emitter - .handle(ChatMessage::tool("call-1", "calculator", "2")) - .await; - - assert!( - tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv()) - .await - .is_err() - ); - } } diff --git a/src/gateway/ws_adapter.rs b/src/gateway/ws_adapter.rs new file mode 100644 index 0000000..c013a7b --- /dev/null +++ b/src/gateway/ws_adapter.rs @@ -0,0 +1,229 @@ +#[cfg(test)] +use crate::bus::ChatMessage; +use crate::bus::OutboundMessage; +use crate::bus::message::OutboundEventKind; +#[cfg(test)] +use crate::bus::message::{ToolMessageState, format_tool_call_content}; +use crate::protocol::WsOutbound; + +const TOOL_PENDING_RESUME_HINT: &str = "完成外部操作后,直接发一条继续消息即可。"; + +#[cfg(test)] +pub(crate) fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { + match message.role.as_str() { + "assistant" => { + if let Some(tool_calls) = &message.tool_calls { + let mut outbound = Vec::new(); + if !message.content.trim().is_empty() { + outbound.push(WsOutbound::AssistantResponse { + id: message.id.clone(), + content: message.content.clone(), + role: message.role.clone(), + }); + } + + outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall { + id: message.id.clone(), + tool_call_id: tool_call.id.clone(), + tool_name: tool_call.name.clone(), + arguments: tool_call.arguments.clone(), + content: format_tool_call_content(&tool_call.name, &tool_call.arguments), + role: message.role.clone(), + })); + outbound + } else { + vec![WsOutbound::AssistantResponse { + id: message.id.clone(), + content: message.content.clone(), + role: message.role.clone(), + }] + } + } + "tool" => match message + .tool_state + .as_ref() + .unwrap_or(&ToolMessageState::Completed) + { + ToolMessageState::Completed => vec![WsOutbound::ToolResult { + id: message.id.clone(), + tool_call_id: message.tool_call_id.clone().unwrap_or_default(), + tool_name: message.tool_name.clone().unwrap_or_default(), + content: message.content.clone(), + role: message.role.clone(), + }], + ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending { + id: message.id.clone(), + tool_call_id: message.tool_call_id.clone().unwrap_or_default(), + tool_name: message.tool_name.clone().unwrap_or_default(), + content: message.content.clone(), + role: message.role.clone(), + resume_hint: TOOL_PENDING_RESUME_HINT.to_string(), + }], + }, + _ => Vec::new(), + } +} + +pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec { + match message.event_kind { + OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => { + vec![WsOutbound::AssistantResponse { + id: uuid::Uuid::new_v4().to_string(), + content: message.content.clone(), + role: message.role.clone(), + }] + } + OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall { + id: message + .tool_call_id + .clone() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + tool_call_id: message.tool_call_id.clone().unwrap_or_default(), + tool_name: message.tool_name.clone().unwrap_or_default(), + arguments: message + .tool_arguments + .clone() + .unwrap_or(serde_json::Value::Null), + content: message.content.clone(), + role: message.role.clone(), + }], + OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult { + id: message + .tool_call_id + .clone() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + tool_call_id: message.tool_call_id.clone().unwrap_or_default(), + tool_name: message.tool_name.clone().unwrap_or_default(), + content: message.content.clone(), + role: message.role.clone(), + }], + OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending { + id: message + .tool_call_id + .clone() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + tool_call_id: message.tool_call_id.clone().unwrap_or_default(), + tool_name: message.tool_name.clone().unwrap_or_default(), + content: message.content.clone(), + role: message.role.clone(), + resume_hint: TOOL_PENDING_RESUME_HINT.to_string(), + }], + OutboundEventKind::ErrorNotification => vec![WsOutbound::Error { + code: "AGENT_ERROR".to_string(), + message: message.content.clone(), + }], + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::providers::ToolCall; + use serde_json::json; + + #[test] + fn test_ws_outbound_from_chat_message_expands_tool_calls() { + let message = ChatMessage::assistant_with_tool_calls( + "", + vec![ToolCall { + id: "call-1".to_string(), + name: "calculator".to_string(), + arguments: json!({"expression": "1 + 1"}), + }], + ); + + let outbound = ws_outbound_from_chat_message(&message); + + assert_eq!(outbound.len(), 1); + match &outbound[0] { + WsOutbound::ToolCall { + tool_call_id, + tool_name, + arguments, + content, + .. + } => { + assert_eq!(tool_call_id, "call-1"); + assert_eq!(tool_name, "calculator"); + assert_eq!(arguments["expression"], "1 + 1"); + assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}"); + } + other => panic!("unexpected outbound variant: {:?}", other), + } + } + + #[test] + fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() { + let message = ChatMessage::assistant_with_tool_calls( + "日报已整理完成。", + vec![ToolCall { + id: "call-1".to_string(), + name: "memory_manage".to_string(), + arguments: json!({"action": "put"}), + }], + ); + + let outbound = ws_outbound_from_chat_message(&message); + + assert_eq!(outbound.len(), 2); + assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. })); + assert!(matches!(outbound[1], WsOutbound::ToolCall { .. })); + } + + #[test] + fn test_ws_outbound_from_chat_message_includes_tool_results() { + let message = ChatMessage::tool("call-1", "calculator", "2"); + + let outbound = ws_outbound_from_chat_message(&message); + + assert_eq!(outbound.len(), 1); + assert!(matches!(outbound[0], WsOutbound::ToolResult { .. })); + } + + #[test] + fn test_ws_outbound_from_chat_message_includes_tool_pending() { + let message = ChatMessage::tool_with_state( + "call-1", + "bash", + "等待你完成授权后再继续。", + ToolMessageState::PendingUserAction, + ); + + let outbound = ws_outbound_from_chat_message(&message); + + assert_eq!(outbound.len(), 1); + assert!(matches!(outbound[0], WsOutbound::ToolPending { .. })); + } + + #[test] + fn test_ws_outbound_from_outbound_message_maps_tool_call() { + let message = OutboundMessage::tool_call( + "cli", + "session-1", + "call-1", + "calculator", + json!({"expression": "1 + 1"}), + None, + Default::default(), + ); + + let outbound = ws_outbound_from_outbound_message(&message); + + assert_eq!(outbound.len(), 1); + match &outbound[0] { + WsOutbound::ToolCall { + tool_call_id, + tool_name, + arguments, + content, + .. + } => { + assert_eq!(tool_call_id, "call-1"); + assert_eq!(tool_name, "calculator"); + assert_eq!(arguments["expression"], "1 + 1"); + assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}"); + } + other => panic!("unexpected outbound variant: {:?}", other), + } + } +} diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index 1472704..2144bba 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -11,8 +11,8 @@ use crate::config::{ SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget, SchedulerMisfirePolicy, SchedulerSchedule, }; +use crate::gateway::agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService}; use crate::gateway::session::ScheduledAgentTaskOptions; -use crate::gateway::session::SessionManager; use crate::storage::{ SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore, }; @@ -22,7 +22,8 @@ pub struct Scheduler { config: SchedulerConfig, timezone: Tz, store: Arc, - session_manager: SessionManager, + agent_task_executor: AgentTaskExecutor, + maintenance_service: SchedulerMaintenanceService, } impl Scheduler { @@ -31,14 +32,16 @@ impl Scheduler { config: SchedulerConfig, timezone: Tz, store: Arc, - session_manager: SessionManager, + agent_task_executor: AgentTaskExecutor, + maintenance_service: SchedulerMaintenanceService, ) -> Self { Self { bus, config, timezone, store, - session_manager, + agent_task_executor, + maintenance_service, } } @@ -168,11 +171,11 @@ impl Scheduler { self.bus.publish_outbound(message).await?; } SchedulerJobKind::InternalEvent => { - execute_internal_event(&self.session_manager, job).await?; + execute_internal_event(&self.maintenance_service, job).await?; } SchedulerJobKind::AgentTask => { let outbound_messages = execute_agent_task( - &self.session_manager, + &self.agent_task_executor, job, required_notification_chat_id(job, "agent_task")?, ) @@ -184,7 +187,7 @@ impl Scheduler { SchedulerJobKind::SilentAgentTask => { let execution_chat_id = resolve_execution_chat_id(job)?; if let Err(error) = - execute_agent_task(&self.session_manager, job, &execution_chat_id).await + execute_agent_task(&self.agent_task_executor, job, &execution_chat_id).await { if let Err(notify_error) = self.notify_silent_agent_task_failure(job, &error).await @@ -587,7 +590,7 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result { } async fn execute_internal_event( - session_manager: &SessionManager, + maintenance_service: &SchedulerMaintenanceService, job: &RuntimeJob, ) -> anyhow::Result<()> { let event = job @@ -598,12 +601,12 @@ async fn execute_internal_event( match event { "session_cleanup" => { - let removed = session_manager.cleanup_expired_sessions().await; + let removed = maintenance_service.cleanup_expired_sessions().await; tracing::info!(job_id = %job.id, removed, "Scheduler session cleanup completed"); Ok(()) } "memory_maintenance" => { - let results = session_manager + let results = maintenance_service .run_memory_maintenance_for_all_scopes(job.last_fired_at) .await?; for result in &results { @@ -627,7 +630,7 @@ async fn execute_internal_event( } async fn execute_agent_task( - session_manager: &SessionManager, + agent_task_executor: &AgentTaskExecutor, job: &RuntimeJob, execution_chat_id: &str, ) -> anyhow::Result> { @@ -643,8 +646,8 @@ async fn execute_agent_task( .ok_or_else(|| anyhow::anyhow!("agent_task payload.prompt must be a string"))?; let options = parse_scheduled_agent_task_options(job)?; - session_manager - .run_scheduled_agent_task(channel_name, execution_chat_id, prompt, options) + agent_task_executor + .execute(channel_name, execution_chat_id, prompt, options) .await .map_err(|error| anyhow::anyhow!(error.to_string())) } @@ -964,6 +967,7 @@ mod tests { use super::*; use crate::bus::MessageBus; use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig}; + use crate::gateway::agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService}; use crate::gateway::session::SessionManager; use crate::skills::SkillRuntime; use crate::storage::{SchedulerJobUpsert, SessionStore}; @@ -1002,6 +1006,14 @@ mod tests { .unwrap() } + fn test_scheduler_services() -> (AgentTaskExecutor, SchedulerMaintenanceService) { + let session_manager = test_session_manager(); + ( + AgentTaskExecutor::new(session_manager.clone()), + SchedulerMaintenanceService::new(session_manager), + ) + } + #[test] fn runtime_job_skip_policy_advances_from_now() { let now = Utc @@ -1129,7 +1141,7 @@ mod tests { }) .unwrap(); - let session_manager = test_session_manager(); + let (agent_task_executor, maintenance_service) = test_scheduler_services(); let scheduler = Scheduler::new( MessageBus::new(8), SchedulerConfig { @@ -1141,7 +1153,8 @@ mod tests { }, chrono_tz::Asia::Shanghai, store.clone(), - session_manager, + agent_task_executor, + maintenance_service, ); scheduler.process_tick().await.unwrap(); @@ -1159,13 +1172,14 @@ mod tests { fn sync_config_jobs_persists_builtin_memory_maintenance_job() { let store = Arc::new(SessionStore::in_memory().unwrap()); - let session_manager = test_session_manager(); + let (agent_task_executor, maintenance_service) = test_scheduler_services(); let scheduler = Scheduler::new( MessageBus::new(8), SchedulerConfig::default(), chrono_tz::Asia::Shanghai, store.clone(), - session_manager, + agent_task_executor, + maintenance_service, ); scheduler.sync_config_jobs().unwrap(); @@ -1204,6 +1218,7 @@ mod tests { async fn silent_agent_task_failure_notifies_primary_chat() { let store = Arc::new(SessionStore::in_memory().unwrap()); let bus = MessageBus::new(8); + let (agent_task_executor, maintenance_service) = test_scheduler_services(); let scheduler = Scheduler::new( bus.clone(), SchedulerConfig { @@ -1215,7 +1230,8 @@ mod tests { }, chrono_tz::Asia::Shanghai, store, - test_session_manager(), + agent_task_executor, + maintenance_service, ); let job = RuntimeJob {