feat: 重构调度器以使用 AgentTaskExecutor 和 SchedulerMaintenanceService
- 更新调度器,将 SessionManager 替换为 AgentTaskExecutor 和 SchedulerMaintenanceService。 - 修改作业执行逻辑,使用新服务处理代理任务和内部事件。 - 添加新的 CliChannel 以处理 CLI 连接,并包括适当的注册和注销逻辑。 - 引入 AgentTaskExecutor 和 SchedulerMaintenanceService,用于管理代理任务和会话维护。 - 实现聊天命令处理,用于重置会话上下文。 - 添加后台历史压缩功能,以优化会话存储。 - 创建实用函数,用于准备通过 WebSocket 通信的出站消息。 - 为新功能添加测试,并确保现有测试通过。 Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
62b38eac73
commit
008aba91ac
12
README.md
12
README.md
@ -18,7 +18,7 @@ PicoBot 是一个用 Rust 构建的多通道 Agent 网关。它把消息接入
|
||||
PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个可持续运行的 Agent 基础设施:
|
||||
|
||||
- 消息从不同渠道进入统一总线
|
||||
- SessionManager 负责会话路由、上下文恢复、工具执行和回复生成
|
||||
- SessionManager 负责会话路由和运行时服务编排,AgentExecutionService 负责上下文准备、AgentLoop 执行、结果持久化和回复生成
|
||||
- SQLite 作为事实来源保存跨重启状态
|
||||
- Agent 在每轮推理时可以读取文件、执行命令、发 HTTP 请求、读写记忆、管理技能和调度任务
|
||||
|
||||
@ -37,13 +37,13 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个
|
||||
|
||||
主要模块如下:
|
||||
|
||||
- src/gateway:网关入口、HTTP 健康检查、WebSocket 服务、Session 池、CLI 会话服务与 Agent 执行编排
|
||||
- src/gateway:网关入口、HTTP 健康检查、WebSocket 控制面、Session 池、CLI 会话服务与 Agent 执行编排
|
||||
- src/bus:消息总线与消息结构定义
|
||||
- src/agent:AgentLoop 与上下文压缩器
|
||||
- src/providers:不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic
|
||||
- src/tools:内置工具集合
|
||||
- src/storage:SQLite 持久化实现
|
||||
- src/channels:渠道适配层,当前已有飞书通道
|
||||
- src/channels:渠道适配层,当前已有 CLI 与飞书通道
|
||||
- src/scheduler:数据库驱动的计划任务调度器
|
||||
- src/skills:技能发现、加载与运行时管理
|
||||
- src/client / src/cli:本地 CLI 客户端和交互命令
|
||||
@ -632,11 +632,11 @@ PicoBot/
|
||||
├── src/
|
||||
│ ├── agent/ # AgentLoop、上下文压缩
|
||||
│ ├── bus/ # 消息总线与消息结构
|
||||
│ ├── channels/ # 渠道适配
|
||||
│ ├── channels/ # CLI / 飞书等渠道适配
|
||||
│ ├── cli/ # CLI 输入命令
|
||||
│ ├── client/ # WebSocket CLI 客户端
|
||||
│ ├── config/ # 配置解析
|
||||
│ ├── gateway/ # Gateway、SessionManager、WS/HTTP
|
||||
│ ├── gateway/ # Gateway、Session 编排、WS/HTTP 控制面
|
||||
│ ├── providers/ # OpenAI / Anthropic Provider
|
||||
│ ├── scheduler/ # 定时任务系统
|
||||
│ ├── skills/ # 技能运行时
|
||||
@ -656,7 +656,7 @@ PicoBot/
|
||||
建议维护时重点关注:
|
||||
|
||||
- docs/PERSISTENCE.md:持久化结构是否与代码一致
|
||||
- src/gateway/session.rs:消息流、工具注册、记忆维护、会话恢复主逻辑
|
||||
- src/gateway/session.rs:会话状态、会话路由和运行时服务编排
|
||||
- src/storage/mod.rs:SQLite schema 变更
|
||||
- src/config/mod.rs:配置项变更是否同步到 README
|
||||
|
||||
|
||||
155
src/channels/cli.rs
Normal file
155
src/channels/cli.rs
Normal file
@ -0,0 +1,155 @@
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{RwLock, mpsc};
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::gateway::ws_adapter::ws_outbound_from_outbound_message;
|
||||
use crate::protocol::WsOutbound;
|
||||
|
||||
use super::base::{Channel, ChannelError};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CliConnection {
|
||||
connection_id: String,
|
||||
sender: mpsc::Sender<WsOutbound>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CliChannel {
|
||||
connections: Arc<RwLock<HashMap<String, CliConnection>>>,
|
||||
}
|
||||
|
||||
impl CliChannel {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn register_connection(
|
||||
&self,
|
||||
session_id: impl Into<String>,
|
||||
connection_id: impl Into<String>,
|
||||
sender: mpsc::Sender<WsOutbound>,
|
||||
) {
|
||||
let session_id = session_id.into();
|
||||
let connection_id = connection_id.into();
|
||||
let previous = self.connections.write().await.insert(
|
||||
session_id.clone(),
|
||||
CliConnection {
|
||||
connection_id: connection_id.clone(),
|
||||
sender,
|
||||
},
|
||||
);
|
||||
|
||||
if previous.is_some() {
|
||||
tracing::info!(session_id = %session_id, connection_id = %connection_id, "CLI session sender replaced");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn unregister_connection(&self, connection_id: &str) {
|
||||
self.connections
|
||||
.write()
|
||||
.await
|
||||
.retain(|_, connection| connection.connection_id != connection_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CliChannel {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for CliChannel {
|
||||
fn name(&self) -> &str {
|
||||
"cli"
|
||||
}
|
||||
|
||||
fn is_running(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn start(&self, _bus: Arc<MessageBus>) -> Result<(), ChannelError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), ChannelError> {
|
||||
self.connections.write().await.clear();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
let connection = self.connections.read().await.get(&msg.chat_id).cloned();
|
||||
let Some(connection) = connection else {
|
||||
return Err(ChannelError::SendError(format!(
|
||||
"No active CLI connection for session {}",
|
||||
msg.chat_id
|
||||
)));
|
||||
};
|
||||
|
||||
for outbound in ws_outbound_from_outbound_message(&msg) {
|
||||
connection
|
||||
.sender
|
||||
.send(outbound)
|
||||
.await
|
||||
.map_err(|_| ChannelError::SendError("CLI websocket sender closed".to_string()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::bus::OutboundMessage;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cli_channel_sends_to_registered_session() {
|
||||
let channel = CliChannel::new();
|
||||
let (sender, mut receiver) = mpsc::channel(4);
|
||||
channel
|
||||
.register_connection("session-1", "conn-1", sender)
|
||||
.await;
|
||||
|
||||
channel
|
||||
.send(OutboundMessage::assistant(
|
||||
"cli",
|
||||
"session-1",
|
||||
"hello",
|
||||
None,
|
||||
HashMap::new(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let outbound = receiver.recv().await.unwrap();
|
||||
assert!(matches!(outbound, WsOutbound::AssistantResponse { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cli_channel_unregisters_connection_sessions() {
|
||||
let channel = CliChannel::new();
|
||||
let (sender, _receiver) = mpsc::channel(4);
|
||||
channel
|
||||
.register_connection("session-1", "conn-1", sender)
|
||||
.await;
|
||||
channel.unregister_connection("conn-1").await;
|
||||
|
||||
let error = channel
|
||||
.send(OutboundMessage::assistant(
|
||||
"cli",
|
||||
"session-1",
|
||||
"hello",
|
||||
None,
|
||||
HashMap::new(),
|
||||
))
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert!(error.to_string().contains("No active CLI connection"));
|
||||
}
|
||||
}
|
||||
@ -4,6 +4,7 @@ use tokio::sync::RwLock;
|
||||
|
||||
use crate::bus::MessageBus;
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::channels::cli::CliChannel;
|
||||
use crate::channels::feishu::FeishuChannel;
|
||||
use crate::config::Config;
|
||||
|
||||
@ -12,13 +13,19 @@ use crate::config::Config;
|
||||
pub struct ChannelManager {
|
||||
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
|
||||
bus: Arc<MessageBus>,
|
||||
cli_channel: Arc<CliChannel>,
|
||||
}
|
||||
|
||||
impl ChannelManager {
|
||||
pub fn new() -> Self {
|
||||
let cli_channel = Arc::new(CliChannel::new());
|
||||
let mut channels: HashMap<String, Arc<dyn Channel + Send + Sync>> = HashMap::new();
|
||||
channels.insert("cli".to_string(), cli_channel.clone());
|
||||
|
||||
Self {
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
channels: Arc::new(RwLock::new(channels)),
|
||||
bus: MessageBus::new(100),
|
||||
cli_channel,
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,6 +34,10 @@ impl ChannelManager {
|
||||
self.bus.clone()
|
||||
}
|
||||
|
||||
pub fn cli_channel(&self) -> Arc<CliChannel> {
|
||||
self.cli_channel.clone()
|
||||
}
|
||||
|
||||
/// Initialize all Channel instances from config
|
||||
pub async fn init(
|
||||
&self,
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
pub mod base;
|
||||
pub mod cli;
|
||||
pub mod feishu;
|
||||
pub mod manager;
|
||||
|
||||
pub use base::{Channel, ChannelError};
|
||||
pub use cli::CliChannel;
|
||||
pub use feishu::FeishuChannel;
|
||||
pub use manager::ChannelManager;
|
||||
|
||||
52
src/gateway/agent_task_executor.rs
Normal file
52
src/gateway/agent_task_executor.rs
Normal file
@ -0,0 +1,52 @@
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::OutboundMessage;
|
||||
|
||||
use super::memory_maintenance::MemoryMaintenanceScopeResult;
|
||||
use super::session::{ScheduledAgentTaskOptions, SessionManager};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AgentTaskExecutor {
|
||||
session_manager: SessionManager,
|
||||
}
|
||||
|
||||
impl AgentTaskExecutor {
|
||||
pub fn new(session_manager: SessionManager) -> Self {
|
||||
Self { session_manager }
|
||||
}
|
||||
|
||||
pub(crate) async fn execute(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
chat_id: &str,
|
||||
prompt: &str,
|
||||
options: ScheduledAgentTaskOptions,
|
||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||
self.session_manager
|
||||
.run_scheduled_agent_task(channel_name, chat_id, prompt, options)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SchedulerMaintenanceService {
|
||||
session_manager: SessionManager,
|
||||
}
|
||||
|
||||
impl SchedulerMaintenanceService {
|
||||
pub fn new(session_manager: SessionManager) -> Self {
|
||||
Self { session_manager }
|
||||
}
|
||||
|
||||
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
|
||||
self.session_manager.cleanup_expired_sessions().await
|
||||
}
|
||||
|
||||
pub(crate) async fn run_memory_maintenance_for_all_scopes(
|
||||
&self,
|
||||
updated_since: Option<i64>,
|
||||
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||
self.session_manager
|
||||
.run_memory_maintenance_for_all_scopes(updated_since)
|
||||
.await
|
||||
}
|
||||
}
|
||||
159
src/gateway/command.rs
Normal file
159
src/gateway/command.rs
Normal file
@ -0,0 +1,159 @@
|
||||
use crate::agent::AgentError;
|
||||
|
||||
use super::session::Session;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum InChatCommand {
|
||||
FreshConversation,
|
||||
}
|
||||
|
||||
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
|
||||
match content.trim() {
|
||||
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_in_chat_command(
|
||||
session: &mut Session,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<Option<String>, AgentError> {
|
||||
match parse_in_chat_command(content) {
|
||||
Some(InChatCommand::FreshConversation) => {
|
||||
session.reset_chat_context(chat_id)?;
|
||||
Ok(Some("Started a fresh conversation.".to_string()))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::SessionStore;
|
||||
use crate::tools::ToolRegistry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
fn test_provider_config() -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_in_chat_command_aliases() {
|
||||
assert_eq!(
|
||||
parse_in_chat_command("/new"),
|
||||
Some(InChatCommand::FreshConversation)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_in_chat_command(" /reset \n"),
|
||||
Some(InChatCommand::FreshConversation)
|
||||
);
|
||||
assert_eq!(parse_in_chat_command("/new planning"), None);
|
||||
assert_eq!(parse_in_chat_command("please /reset"), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_in_chat_command_resets_active_history_only() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(ToolRegistry::new());
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
|
||||
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response, "Started a fresh conversation.");
|
||||
assert!(session.get_history("chat-1").unwrap().is_empty());
|
||||
assert!(
|
||||
store
|
||||
.load_messages(&session.persistent_session_id("chat-1"))
|
||||
.unwrap()
|
||||
.is_empty()
|
||||
);
|
||||
assert_eq!(
|
||||
store
|
||||
.load_all_messages(&session.persistent_session_id("chat-1"))
|
||||
.unwrap()
|
||||
.len(),
|
||||
2,
|
||||
);
|
||||
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "system");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(ToolRegistry::new());
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store,
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
|
||||
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
|
||||
session
|
||||
.ensure_agent_prompt_before_user_message("chat-1")
|
||||
.unwrap();
|
||||
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "system");
|
||||
}
|
||||
}
|
||||
105
src/gateway/compaction.rs
Normal file
105
src/gateway/compaction.rs
Normal file
@ -0,0 +1,105 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
|
||||
use super::session::Session;
|
||||
|
||||
pub(crate) async fn schedule_background_history_compaction(
|
||||
session: Arc<Mutex<Session>>,
|
||||
chat_id: impl Into<String>,
|
||||
) -> Result<(), AgentError> {
|
||||
let chat_id = chat_id.into();
|
||||
|
||||
let snapshot = {
|
||||
let mut session_guard = session.lock().await;
|
||||
let session_record = session_guard.ensure_persistent_session(&chat_id)?;
|
||||
session_guard.ensure_chat_loaded(&chat_id)?;
|
||||
|
||||
let history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
let compressor = session_guard.compressor().clone();
|
||||
if !compressor.should_compress(&history) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !session_guard.try_start_background_compaction(&chat_id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
(
|
||||
session_guard.store(),
|
||||
session_guard.persistent_session_id(&chat_id),
|
||||
session_record.reset_cutoff_seq,
|
||||
session_record.message_count,
|
||||
history,
|
||||
compressor,
|
||||
session_guard.provider_config().clone(),
|
||||
)
|
||||
};
|
||||
|
||||
let (
|
||||
store,
|
||||
session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
history,
|
||||
compressor,
|
||||
provider_config,
|
||||
) = snapshot;
|
||||
let session_for_task = session.clone();
|
||||
let chat_id_for_task = chat_id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction");
|
||||
|
||||
let compaction_result = compressor
|
||||
.build_compaction_plan(&history, &provider_config)
|
||||
.await;
|
||||
let mut committed = false;
|
||||
|
||||
match compaction_result {
|
||||
Ok(Some(plan)) => match store.compact_active_history(
|
||||
&session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
&plan.preserved_system_messages,
|
||||
&plan.summary_message,
|
||||
&plan.preserved_messages,
|
||||
) {
|
||||
Ok(true) => {
|
||||
committed = true;
|
||||
tracing::info!(
|
||||
chat_id = %chat_id_for_task,
|
||||
snapshot_end_seq,
|
||||
compressed_turns = plan.compressed_turns,
|
||||
preserved_turns = plan.preserved_turns,
|
||||
"Background history compaction committed"
|
||||
);
|
||||
}
|
||||
Ok(false) => {
|
||||
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot");
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed");
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis");
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed");
|
||||
}
|
||||
}
|
||||
|
||||
let mut session_guard = session_for_task.lock().await;
|
||||
if committed {
|
||||
if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction");
|
||||
}
|
||||
}
|
||||
session_guard.finish_background_compaction(&chat_id_for_task);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -7,10 +7,10 @@ use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDUL
|
||||
use crate::config::LLMProviderConfig;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::session::{
|
||||
Session, enrich_user_content_with_media_refs, handle_in_chat_command,
|
||||
schedule_background_history_compaction,
|
||||
};
|
||||
use super::command::handle_in_chat_command;
|
||||
use super::compaction::schedule_background_history_compaction;
|
||||
use super::message_prepare::enrich_user_content_with_media_refs;
|
||||
use super::session::Session;
|
||||
|
||||
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
||||
|
||||
|
||||
39
src/gateway/message_prepare.rs
Normal file
39
src/gateway/message_prepare.rs
Normal file
@ -0,0 +1,39 @@
|
||||
use crate::agent::AgentError;
|
||||
|
||||
pub(crate) fn enrich_user_content_with_media_refs(
|
||||
content: &str,
|
||||
media_refs: &[String],
|
||||
) -> Result<String, AgentError> {
|
||||
if media_refs.is_empty() {
|
||||
return Ok(content.to_string());
|
||||
}
|
||||
|
||||
let media_refs_json = serde_json::to_string(media_refs)
|
||||
.map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?;
|
||||
|
||||
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
||||
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
||||
|
||||
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
enriched,
|
||||
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() {
|
||||
let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap();
|
||||
|
||||
assert_eq!(enriched, "hello");
|
||||
}
|
||||
}
|
||||
@ -1,13 +1,18 @@
|
||||
pub mod agent_task_executor;
|
||||
pub mod cli_session;
|
||||
pub mod command;
|
||||
pub mod compaction;
|
||||
pub mod execution;
|
||||
pub mod http;
|
||||
pub mod memory_maintenance;
|
||||
pub mod message_prepare;
|
||||
pub mod processor;
|
||||
pub mod prompt;
|
||||
pub mod session;
|
||||
pub mod session_factory;
|
||||
pub mod session_pool;
|
||||
pub mod ws;
|
||||
pub mod ws_adapter;
|
||||
|
||||
use axum::{Router, routing};
|
||||
use std::collections::HashMap;
|
||||
@ -21,6 +26,7 @@ use crate::config::LLMProviderConfig;
|
||||
use crate::logging;
|
||||
use crate::scheduler::Scheduler;
|
||||
use crate::skills::SkillRuntime;
|
||||
use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
|
||||
use processor::InboundProcessor;
|
||||
use session::SessionManager;
|
||||
|
||||
@ -122,7 +128,8 @@ pub async fn run(
|
||||
state.config.scheduler.clone(),
|
||||
timezone,
|
||||
state.session_manager.store(),
|
||||
state.session_manager.clone(),
|
||||
AgentTaskExecutor::new(state.session_manager.clone()),
|
||||
SchedulerMaintenanceService::new(state.session_manager.clone()),
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::bus::MessageBus;
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
|
||||
use super::session::{BusToolCallEmitter, SessionManager};
|
||||
|
||||
@ -70,6 +70,21 @@ impl InboundProcessor {
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::error!(error = %error, "Failed to handle message");
|
||||
let mut metadata = inbound.forwarded_metadata.clone();
|
||||
metadata.insert("error_kind".to_string(), "agent_execution".to_string());
|
||||
if let Err(publish_error) = self
|
||||
.bus
|
||||
.publish_outbound(OutboundMessage::error_notification(
|
||||
inbound.channel,
|
||||
inbound.chat_id,
|
||||
error.to_string(),
|
||||
None,
|
||||
metadata,
|
||||
))
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = %publish_error, "Failed to publish execution error outbound");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,20 +43,6 @@ fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
preview.replace('\n', "\\n")
|
||||
}
|
||||
|
||||
pub(crate) fn enrich_user_content_with_media_refs(
|
||||
content: &str,
|
||||
media_refs: &[String],
|
||||
) -> Result<String, AgentError> {
|
||||
if media_refs.is_empty() {
|
||||
return Ok(content.to_string());
|
||||
}
|
||||
|
||||
let media_refs_json = serde_json::to_string(media_refs)
|
||||
.map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?;
|
||||
|
||||
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
|
||||
}
|
||||
|
||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||
pub struct Session {
|
||||
@ -393,15 +379,15 @@ impl Session {
|
||||
&self.compressor
|
||||
}
|
||||
|
||||
fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
||||
pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
||||
self.compression_in_flight.insert(chat_id.to_string())
|
||||
}
|
||||
|
||||
fn finish_background_compaction(&mut self, chat_id: &str) {
|
||||
pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
|
||||
self.compression_in_flight.remove(chat_id);
|
||||
}
|
||||
|
||||
fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
let history = self
|
||||
.store
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
@ -410,6 +396,10 @@ impl Session {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn store(&self) -> Arc<SessionStore> {
|
||||
self.store.clone()
|
||||
}
|
||||
|
||||
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if self.skills.is_empty() {
|
||||
return Ok(());
|
||||
@ -528,129 +518,6 @@ fn default_tools(
|
||||
registry
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum InChatCommand {
|
||||
FreshConversation,
|
||||
}
|
||||
|
||||
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
|
||||
match content.trim() {
|
||||
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_in_chat_command(
|
||||
session: &mut Session,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<Option<String>, AgentError> {
|
||||
match parse_in_chat_command(content) {
|
||||
Some(InChatCommand::FreshConversation) => {
|
||||
session.reset_chat_context(chat_id)?;
|
||||
Ok(Some("Started a fresh conversation.".to_string()))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn schedule_background_history_compaction(
|
||||
session: Arc<Mutex<Session>>,
|
||||
chat_id: impl Into<String>,
|
||||
) -> Result<(), AgentError> {
|
||||
let chat_id = chat_id.into();
|
||||
|
||||
let snapshot = {
|
||||
let mut session_guard = session.lock().await;
|
||||
let session_record = session_guard.ensure_persistent_session(&chat_id)?;
|
||||
session_guard.ensure_chat_loaded(&chat_id)?;
|
||||
|
||||
let history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
let compressor = session_guard.compressor().clone();
|
||||
if !compressor.should_compress(&history) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !session_guard.try_start_background_compaction(&chat_id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
(
|
||||
session_guard.store.clone(),
|
||||
session_guard.persistent_session_id(&chat_id),
|
||||
session_record.reset_cutoff_seq,
|
||||
session_record.message_count,
|
||||
history,
|
||||
compressor,
|
||||
session_guard.provider_config().clone(),
|
||||
)
|
||||
};
|
||||
|
||||
let (
|
||||
store,
|
||||
session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
history,
|
||||
compressor,
|
||||
provider_config,
|
||||
) = snapshot;
|
||||
let session_for_task = session.clone();
|
||||
let chat_id_for_task = chat_id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction");
|
||||
|
||||
let compaction_result = compressor
|
||||
.build_compaction_plan(&history, &provider_config)
|
||||
.await;
|
||||
let mut committed = false;
|
||||
|
||||
match compaction_result {
|
||||
Ok(Some(plan)) => match store.compact_active_history(
|
||||
&session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
&plan.preserved_system_messages,
|
||||
&plan.summary_message,
|
||||
&plan.preserved_messages,
|
||||
) {
|
||||
Ok(true) => {
|
||||
committed = true;
|
||||
tracing::info!(
|
||||
chat_id = %chat_id_for_task,
|
||||
snapshot_end_seq,
|
||||
compressed_turns = plan.compressed_turns,
|
||||
preserved_turns = plan.preserved_turns,
|
||||
"Background history compaction committed"
|
||||
);
|
||||
}
|
||||
Ok(false) => {
|
||||
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot");
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed");
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis");
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed");
|
||||
}
|
||||
}
|
||||
|
||||
let mut session_guard = session_for_task.lock().await;
|
||||
if committed {
|
||||
if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) {
|
||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction");
|
||||
}
|
||||
}
|
||||
session_guard.finish_background_compaction(&chat_id_for_task);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
impl SessionManager {
|
||||
pub fn new(
|
||||
session_ttl_hours: u64,
|
||||
@ -913,25 +780,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
||||
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
||||
|
||||
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
enriched,
|
||||
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() {
|
||||
let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap();
|
||||
|
||||
assert_eq!(enriched, "hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_latest_user_message_guard_tracks_current_turn() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
@ -1750,75 +1598,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_in_chat_command_aliases() {
|
||||
assert_eq!(
|
||||
parse_in_chat_command("/new"),
|
||||
Some(InChatCommand::FreshConversation)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_in_chat_command(" /reset \n"),
|
||||
Some(InChatCommand::FreshConversation)
|
||||
);
|
||||
assert_eq!(parse_in_chat_command("/new planning"), None);
|
||||
assert_eq!(parse_in_chat_command("please /reset"), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_in_chat_command_resets_active_history_only() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
|
||||
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response, "Started a fresh conversation.");
|
||||
assert!(session.get_history("chat-1").unwrap().is_empty());
|
||||
assert!(
|
||||
store
|
||||
.load_messages(&session.persistent_session_id("chat-1"))
|
||||
.unwrap()
|
||||
.is_empty()
|
||||
);
|
||||
assert_eq!(
|
||||
store
|
||||
.load_all_messages(&session.persistent_session_id("chat-1"))
|
||||
.unwrap()
|
||||
.len(),
|
||||
2,
|
||||
);
|
||||
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "system");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
@ -1955,45 +1734,6 @@ mod tests {
|
||||
assert_eq!(system_messages, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(
|
||||
skills.clone(),
|
||||
store.clone(),
|
||||
HashSet::new(),
|
||||
"Asia/Shanghai".to_string(),
|
||||
));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
|
||||
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
|
||||
session
|
||||
.ensure_agent_prompt_before_user_message("chat-1")
|
||||
.unwrap();
|
||||
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "system");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_tools_registers_get_time() {
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
|
||||
@ -1,40 +1,17 @@
|
||||
use super::{
|
||||
GatewayState,
|
||||
execution::{AgentExecutionService, MessageExecutionRequest, should_display_message_to_user},
|
||||
session::Session,
|
||||
};
|
||||
use crate::agent::{AgentError, EmittedMessageHandler};
|
||||
use crate::bus::message::{OutboundEventKind, ToolMessageState, format_tool_call_content};
|
||||
use crate::bus::{ChatMessage, OutboundMessage};
|
||||
use super::GatewayState;
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::InboundMessage;
|
||||
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||
use async_trait::async_trait;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
const CLI_CHANNEL_NAME: &str = "cli";
|
||||
|
||||
struct WsToolCallEmitter {
|
||||
sender: mpsc::Sender<WsOutbound>,
|
||||
show_tool_results: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmittedMessageHandler for WsToolCallEmitter {
|
||||
async fn handle(&self, message: ChatMessage) {
|
||||
if !should_display_message_to_user(self.show_tool_results, &message) {
|
||||
return;
|
||||
}
|
||||
|
||||
for outbound in ws_outbound_from_chat_message(&message) {
|
||||
let _ = self.sender.send(outbound).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||
ws.on_upgrade(|socket| async {
|
||||
handle_socket(socket, state).await;
|
||||
@ -44,14 +21,6 @@ pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewaySta
|
||||
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
||||
|
||||
let provider_config = match state.config.get_provider_config("default") {
|
||||
Ok(cfg) => cfg,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to get provider config");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let cli_sessions = state.session_manager.cli_sessions();
|
||||
let initial_record = match cli_sessions.create(None) {
|
||||
Ok(record) => record,
|
||||
@ -61,39 +30,20 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
}
|
||||
};
|
||||
|
||||
let channel_name = CLI_CHANNEL_NAME.to_string();
|
||||
|
||||
// 创建 CLI session
|
||||
let session = match Session::new(
|
||||
channel_name.clone(),
|
||||
provider_config,
|
||||
sender,
|
||||
state.session_manager.tools(),
|
||||
state.session_manager.skills(),
|
||||
state.session_manager.store(),
|
||||
state.config.gateway.agent_prompt_reinject_every,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(s) => Arc::new(Mutex::new(s)),
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to create session");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
|
||||
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
|
||||
return;
|
||||
}
|
||||
|
||||
let runtime_session_id = session.lock().await.id.to_string();
|
||||
let runtime_session_id = uuid::Uuid::new_v4().to_string();
|
||||
let mut current_session_id = initial_record.id.clone();
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
.register_connection(
|
||||
current_session_id.clone(),
|
||||
runtime_session_id.clone(),
|
||||
sender.clone(),
|
||||
)
|
||||
.await;
|
||||
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
|
||||
|
||||
let _ = session
|
||||
.lock()
|
||||
.await
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionEstablished {
|
||||
session_id: current_session_id.clone(),
|
||||
})
|
||||
@ -123,7 +73,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
Ok(inbound) => {
|
||||
if let Err(e) = handle_inbound(
|
||||
&state,
|
||||
&session,
|
||||
&sender,
|
||||
&runtime_session_id,
|
||||
&mut current_session_id,
|
||||
inbound,
|
||||
@ -131,9 +81,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
.await
|
||||
{
|
||||
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
||||
let _ = session
|
||||
.lock()
|
||||
.await
|
||||
let _ = sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "SESSION_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
@ -143,9 +91,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||||
let _ = session
|
||||
.lock()
|
||||
.await
|
||||
let _ = sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "PARSE_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
@ -163,6 +109,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
}
|
||||
}
|
||||
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
.unregister_connection(&runtime_session_id)
|
||||
.await;
|
||||
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
||||
}
|
||||
|
||||
@ -178,115 +129,9 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
||||
}
|
||||
}
|
||||
|
||||
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||
match message.role.as_str() {
|
||||
"assistant" => {
|
||||
if let Some(tool_calls) = &message.tool_calls {
|
||||
let mut outbound = Vec::new();
|
||||
if !message.content.trim().is_empty() {
|
||||
outbound.push(WsOutbound::AssistantResponse {
|
||||
id: message.id.clone(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
||||
role: message.role.clone(),
|
||||
}));
|
||||
outbound
|
||||
} else {
|
||||
vec![WsOutbound::AssistantResponse {
|
||||
id: message.id.clone(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}]
|
||||
}
|
||||
}
|
||||
"tool" => match message
|
||||
.tool_state
|
||||
.as_ref()
|
||||
.unwrap_or(&ToolMessageState::Completed)
|
||||
{
|
||||
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}],
|
||||
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
|
||||
}],
|
||||
},
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec<WsOutbound> {
|
||||
match message.event_kind {
|
||||
OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => {
|
||||
vec![WsOutbound::AssistantResponse {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}]
|
||||
}
|
||||
OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall {
|
||||
id: message
|
||||
.tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
arguments: message
|
||||
.tool_arguments
|
||||
.clone()
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}],
|
||||
OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult {
|
||||
id: message
|
||||
.tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}],
|
||||
OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending {
|
||||
id: message
|
||||
.tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
|
||||
}],
|
||||
OutboundEventKind::ErrorNotification => vec![WsOutbound::Error {
|
||||
code: "AGENT_ERROR".to_string(),
|
||||
message: message.content.clone(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_inbound(
|
||||
state: &Arc<GatewayState>,
|
||||
session: &Arc<Mutex<Session>>,
|
||||
sender: &mpsc::Sender<WsOutbound>,
|
||||
runtime_session_id: &str,
|
||||
current_session_id: &mut String,
|
||||
inbound: WsInbound,
|
||||
@ -300,43 +145,31 @@ async fn handle_inbound(
|
||||
} => {
|
||||
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
||||
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
||||
let user_tx = session.lock().await.user_tx.clone();
|
||||
let live_emitter = Arc::new(WsToolCallEmitter {
|
||||
sender: user_tx.clone(),
|
||||
show_tool_results: state.config.gateway.show_tool_results,
|
||||
});
|
||||
|
||||
match AgentExecutionService::new(state.config.gateway.show_tool_results)
|
||||
.prepare_and_execute_message(MessageExecutionRequest {
|
||||
session: session.clone(),
|
||||
channel_name: CLI_CHANNEL_NAME,
|
||||
sender_id: &sender_id,
|
||||
chat_id: &chat_id,
|
||||
content: &content,
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
.register_connection(
|
||||
chat_id.clone(),
|
||||
runtime_session_id.to_string(),
|
||||
sender.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
state
|
||||
.bus
|
||||
.publish_inbound(InboundMessage {
|
||||
channel: CLI_CHANNEL_NAME.to_string(),
|
||||
sender_id,
|
||||
chat_id,
|
||||
content,
|
||||
timestamp: current_timestamp(),
|
||||
media: Vec::new(),
|
||||
live_emitter: Some(live_emitter),
|
||||
metadata: HashMap::new(),
|
||||
forwarded_metadata: HashMap::new(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(outbound_messages) => {
|
||||
for outbound in outbound_messages
|
||||
.iter()
|
||||
.flat_map(ws_outbound_from_outbound_message)
|
||||
{
|
||||
let _ = user_tx.send(outbound).await;
|
||||
}
|
||||
}
|
||||
Err(AgentError::LlmError(error)) => {
|
||||
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
|
||||
let _ = user_tx
|
||||
.send(WsOutbound::Error {
|
||||
code: "LLM_ERROR".to_string(),
|
||||
message: error,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
.map_err(|error| AgentError::Other(error.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -352,9 +185,11 @@ async fn handle_inbound(
|
||||
.cli_sessions()
|
||||
.clear_messages(&target)?;
|
||||
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.remove_history(&target);
|
||||
let _ = session_guard
|
||||
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
||||
session.lock().await.remove_history(&target);
|
||||
}
|
||||
|
||||
let _ = sender
|
||||
.send(WsOutbound::HistoryCleared { session_id: target })
|
||||
.await;
|
||||
Ok(())
|
||||
@ -366,9 +201,16 @@ async fn handle_inbound(
|
||||
.create(title.as_deref())?;
|
||||
*current_session_id = record.id.clone();
|
||||
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.ensure_chat_loaded(&record.id)?;
|
||||
let _ = session_guard
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
.register_connection(
|
||||
record.id.clone(),
|
||||
runtime_session_id.to_string(),
|
||||
sender.clone(),
|
||||
)
|
||||
.await;
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionCreated {
|
||||
session_id: record.id,
|
||||
title: record.title,
|
||||
@ -383,8 +225,7 @@ async fn handle_inbound(
|
||||
.list(include_archived)?;
|
||||
let summaries = records.into_iter().map(to_session_summary).collect();
|
||||
|
||||
let session_guard = session.lock().await;
|
||||
let _ = session_guard
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionList {
|
||||
sessions: summaries,
|
||||
current_session_id: Some(current_session_id.clone()),
|
||||
@ -394,8 +235,7 @@ async fn handle_inbound(
|
||||
}
|
||||
WsInbound::LoadSession { session_id } => {
|
||||
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
|
||||
let session_guard = session.lock().await;
|
||||
let _ = session_guard
|
||||
let _ = sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "SESSION_NOT_FOUND".to_string(),
|
||||
message: format!("Session not found: {}", session_id),
|
||||
@ -405,9 +245,16 @@ async fn handle_inbound(
|
||||
};
|
||||
|
||||
*current_session_id = record.id.clone();
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.ensure_chat_loaded(&record.id)?;
|
||||
let _ = session_guard
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
.register_connection(
|
||||
record.id.clone(),
|
||||
runtime_session_id.to_string(),
|
||||
sender.clone(),
|
||||
)
|
||||
.await;
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionLoaded {
|
||||
session_id: record.id,
|
||||
title: record.title,
|
||||
@ -422,8 +269,7 @@ async fn handle_inbound(
|
||||
.session_manager
|
||||
.cli_sessions()
|
||||
.rename(&target, &title)?;
|
||||
let session_guard = session.lock().await;
|
||||
let _ = session_guard
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionRenamed {
|
||||
session_id: target,
|
||||
title,
|
||||
@ -434,8 +280,7 @@ async fn handle_inbound(
|
||||
WsInbound::ArchiveSession { session_id } => {
|
||||
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||
state.session_manager.cli_sessions().archive(&target)?;
|
||||
let session_guard = session.lock().await;
|
||||
let _ = session_guard
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionArchived { session_id: target })
|
||||
.await;
|
||||
Ok(())
|
||||
@ -450,9 +295,11 @@ async fn handle_inbound(
|
||||
None
|
||||
};
|
||||
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.remove_history(&target);
|
||||
let _ = session_guard
|
||||
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
||||
session.lock().await.remove_history(&target);
|
||||
}
|
||||
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionDeleted {
|
||||
session_id: target.clone(),
|
||||
})
|
||||
@ -460,8 +307,16 @@ async fn handle_inbound(
|
||||
|
||||
if let Some(record) = replacement {
|
||||
*current_session_id = record.id.clone();
|
||||
session_guard.ensure_chat_loaded(&record.id)?;
|
||||
let _ = session_guard
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
.register_connection(
|
||||
record.id.clone(),
|
||||
runtime_session_id.to_string(),
|
||||
sender.clone(),
|
||||
)
|
||||
.await;
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionCreated {
|
||||
session_id: record.id,
|
||||
title: record.title,
|
||||
@ -472,13 +327,19 @@ async fn handle_inbound(
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::Ping => {
|
||||
let session_guard = session.lock().await;
|
||||
let _ = session_guard.send(WsOutbound::Pong).await;
|
||||
let _ = sender.send(WsOutbound::Pong).await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as i64
|
||||
}
|
||||
|
||||
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
|
||||
sender_id
|
||||
.map(str::trim)
|
||||
@ -489,138 +350,7 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
|
||||
ws_outbound_from_chat_message,
|
||||
};
|
||||
use crate::agent::EmittedMessageHandler;
|
||||
use crate::bus::message::ToolMessageState;
|
||||
use crate::bus::{ChatMessage, OutboundMessage};
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::providers::ToolCall;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
|
||||
let message = ChatMessage::assistant_with_tool_calls(
|
||||
"",
|
||||
vec![ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "calculator".to_string(),
|
||||
arguments: json!({"expression": "1 + 1"}),
|
||||
}],
|
||||
);
|
||||
|
||||
let outbound = ws_outbound_from_chat_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
match &outbound[0] {
|
||||
WsOutbound::ToolCall {
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
arguments,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert_eq!(arguments["expression"], "1 + 1");
|
||||
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||
}
|
||||
other => panic!("unexpected outbound variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
|
||||
let message = ChatMessage::assistant_with_tool_calls(
|
||||
"日报已整理完成。",
|
||||
vec![ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "memory_manage".to_string(),
|
||||
arguments: json!({"action": "put"}),
|
||||
}],
|
||||
);
|
||||
|
||||
let outbound = ws_outbound_from_chat_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 2);
|
||||
assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
|
||||
assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_from_chat_message_includes_tool_results() {
|
||||
let message = ChatMessage::tool("call-1", "calculator", "2");
|
||||
|
||||
let outbound = ws_outbound_from_chat_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
|
||||
let message = ChatMessage::tool_with_state(
|
||||
"call-1",
|
||||
"bash",
|
||||
"等待你完成授权后再继续。",
|
||||
ToolMessageState::PendingUserAction,
|
||||
);
|
||||
|
||||
let outbound = ws_outbound_from_chat_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
||||
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
||||
let pending = ChatMessage::tool_with_state(
|
||||
"call-2",
|
||||
"bash",
|
||||
"waiting",
|
||||
ToolMessageState::PendingUserAction,
|
||||
);
|
||||
|
||||
assert!(!should_display_message_to_user(false, &completed));
|
||||
assert!(should_display_message_to_user(false, &pending));
|
||||
assert!(should_display_message_to_user(true, &completed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_from_outbound_message_maps_tool_call() {
|
||||
let message = OutboundMessage::tool_call(
|
||||
"cli",
|
||||
"session-1",
|
||||
"call-1",
|
||||
"calculator",
|
||||
json!({"expression": "1 + 1"}),
|
||||
None,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let outbound = super::ws_outbound_from_outbound_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
match &outbound[0] {
|
||||
WsOutbound::ToolCall {
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
arguments,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert_eq!(arguments["expression"], "1 + 1");
|
||||
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||
}
|
||||
other => panic!("unexpected outbound variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
use super::resolve_ws_sender_id;
|
||||
|
||||
#[test]
|
||||
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
|
||||
@ -639,23 +369,4 @@ mod tests {
|
||||
assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
|
||||
assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ws_tool_call_emitter_hides_completed_tool_results_when_disabled() {
|
||||
let (sender, mut receiver) = mpsc::channel(4);
|
||||
let emitter = WsToolCallEmitter {
|
||||
sender,
|
||||
show_tool_results: false,
|
||||
};
|
||||
|
||||
emitter
|
||||
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
229
src/gateway/ws_adapter.rs
Normal file
229
src/gateway/ws_adapter.rs
Normal file
@ -0,0 +1,229 @@
|
||||
#[cfg(test)]
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::bus::OutboundMessage;
|
||||
use crate::bus::message::OutboundEventKind;
|
||||
#[cfg(test)]
|
||||
use crate::bus::message::{ToolMessageState, format_tool_call_content};
|
||||
use crate::protocol::WsOutbound;
|
||||
|
||||
const TOOL_PENDING_RESUME_HINT: &str = "完成外部操作后,直接发一条继续消息即可。";
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||
match message.role.as_str() {
|
||||
"assistant" => {
|
||||
if let Some(tool_calls) = &message.tool_calls {
|
||||
let mut outbound = Vec::new();
|
||||
if !message.content.trim().is_empty() {
|
||||
outbound.push(WsOutbound::AssistantResponse {
|
||||
id: message.id.clone(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
||||
role: message.role.clone(),
|
||||
}));
|
||||
outbound
|
||||
} else {
|
||||
vec![WsOutbound::AssistantResponse {
|
||||
id: message.id.clone(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}]
|
||||
}
|
||||
}
|
||||
"tool" => match message
|
||||
.tool_state
|
||||
.as_ref()
|
||||
.unwrap_or(&ToolMessageState::Completed)
|
||||
{
|
||||
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}],
|
||||
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
|
||||
}],
|
||||
},
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec<WsOutbound> {
|
||||
match message.event_kind {
|
||||
OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => {
|
||||
vec![WsOutbound::AssistantResponse {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}]
|
||||
}
|
||||
OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall {
|
||||
id: message
|
||||
.tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
arguments: message
|
||||
.tool_arguments
|
||||
.clone()
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}],
|
||||
OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult {
|
||||
id: message
|
||||
.tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}],
|
||||
OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending {
|
||||
id: message
|
||||
.tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
|
||||
}],
|
||||
OutboundEventKind::ErrorNotification => vec![WsOutbound::Error {
|
||||
code: "AGENT_ERROR".to_string(),
|
||||
message: message.content.clone(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::providers::ToolCall;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
|
||||
let message = ChatMessage::assistant_with_tool_calls(
|
||||
"",
|
||||
vec![ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "calculator".to_string(),
|
||||
arguments: json!({"expression": "1 + 1"}),
|
||||
}],
|
||||
);
|
||||
|
||||
let outbound = ws_outbound_from_chat_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
match &outbound[0] {
|
||||
WsOutbound::ToolCall {
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
arguments,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert_eq!(arguments["expression"], "1 + 1");
|
||||
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||
}
|
||||
other => panic!("unexpected outbound variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
|
||||
let message = ChatMessage::assistant_with_tool_calls(
|
||||
"日报已整理完成。",
|
||||
vec![ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "memory_manage".to_string(),
|
||||
arguments: json!({"action": "put"}),
|
||||
}],
|
||||
);
|
||||
|
||||
let outbound = ws_outbound_from_chat_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 2);
|
||||
assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
|
||||
assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_from_chat_message_includes_tool_results() {
|
||||
let message = ChatMessage::tool("call-1", "calculator", "2");
|
||||
|
||||
let outbound = ws_outbound_from_chat_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
|
||||
let message = ChatMessage::tool_with_state(
|
||||
"call-1",
|
||||
"bash",
|
||||
"等待你完成授权后再继续。",
|
||||
ToolMessageState::PendingUserAction,
|
||||
);
|
||||
|
||||
let outbound = ws_outbound_from_chat_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ws_outbound_from_outbound_message_maps_tool_call() {
|
||||
let message = OutboundMessage::tool_call(
|
||||
"cli",
|
||||
"session-1",
|
||||
"call-1",
|
||||
"calculator",
|
||||
json!({"expression": "1 + 1"}),
|
||||
None,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let outbound = ws_outbound_from_outbound_message(&message);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
match &outbound[0] {
|
||||
WsOutbound::ToolCall {
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
arguments,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert_eq!(arguments["expression"], "1 + 1");
|
||||
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||
}
|
||||
other => panic!("unexpected outbound variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -11,8 +11,8 @@ use crate::config::{
|
||||
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget,
|
||||
SchedulerMisfirePolicy, SchedulerSchedule,
|
||||
};
|
||||
use crate::gateway::agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
|
||||
use crate::gateway::session::ScheduledAgentTaskOptions;
|
||||
use crate::gateway::session::SessionManager;
|
||||
use crate::storage::{
|
||||
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore,
|
||||
};
|
||||
@ -22,7 +22,8 @@ pub struct Scheduler {
|
||||
config: SchedulerConfig,
|
||||
timezone: Tz,
|
||||
store: Arc<SessionStore>,
|
||||
session_manager: SessionManager,
|
||||
agent_task_executor: AgentTaskExecutor,
|
||||
maintenance_service: SchedulerMaintenanceService,
|
||||
}
|
||||
|
||||
impl Scheduler {
|
||||
@ -31,14 +32,16 @@ impl Scheduler {
|
||||
config: SchedulerConfig,
|
||||
timezone: Tz,
|
||||
store: Arc<SessionStore>,
|
||||
session_manager: SessionManager,
|
||||
agent_task_executor: AgentTaskExecutor,
|
||||
maintenance_service: SchedulerMaintenanceService,
|
||||
) -> Self {
|
||||
Self {
|
||||
bus,
|
||||
config,
|
||||
timezone,
|
||||
store,
|
||||
session_manager,
|
||||
agent_task_executor,
|
||||
maintenance_service,
|
||||
}
|
||||
}
|
||||
|
||||
@ -168,11 +171,11 @@ impl Scheduler {
|
||||
self.bus.publish_outbound(message).await?;
|
||||
}
|
||||
SchedulerJobKind::InternalEvent => {
|
||||
execute_internal_event(&self.session_manager, job).await?;
|
||||
execute_internal_event(&self.maintenance_service, job).await?;
|
||||
}
|
||||
SchedulerJobKind::AgentTask => {
|
||||
let outbound_messages = execute_agent_task(
|
||||
&self.session_manager,
|
||||
&self.agent_task_executor,
|
||||
job,
|
||||
required_notification_chat_id(job, "agent_task")?,
|
||||
)
|
||||
@ -184,7 +187,7 @@ impl Scheduler {
|
||||
SchedulerJobKind::SilentAgentTask => {
|
||||
let execution_chat_id = resolve_execution_chat_id(job)?;
|
||||
if let Err(error) =
|
||||
execute_agent_task(&self.session_manager, job, &execution_chat_id).await
|
||||
execute_agent_task(&self.agent_task_executor, job, &execution_chat_id).await
|
||||
{
|
||||
if let Err(notify_error) =
|
||||
self.notify_silent_agent_task_failure(job, &error).await
|
||||
@ -587,7 +590,7 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
|
||||
}
|
||||
|
||||
async fn execute_internal_event(
|
||||
session_manager: &SessionManager,
|
||||
maintenance_service: &SchedulerMaintenanceService,
|
||||
job: &RuntimeJob,
|
||||
) -> anyhow::Result<()> {
|
||||
let event = job
|
||||
@ -598,12 +601,12 @@ async fn execute_internal_event(
|
||||
|
||||
match event {
|
||||
"session_cleanup" => {
|
||||
let removed = session_manager.cleanup_expired_sessions().await;
|
||||
let removed = maintenance_service.cleanup_expired_sessions().await;
|
||||
tracing::info!(job_id = %job.id, removed, "Scheduler session cleanup completed");
|
||||
Ok(())
|
||||
}
|
||||
"memory_maintenance" => {
|
||||
let results = session_manager
|
||||
let results = maintenance_service
|
||||
.run_memory_maintenance_for_all_scopes(job.last_fired_at)
|
||||
.await?;
|
||||
for result in &results {
|
||||
@ -627,7 +630,7 @@ async fn execute_internal_event(
|
||||
}
|
||||
|
||||
async fn execute_agent_task(
|
||||
session_manager: &SessionManager,
|
||||
agent_task_executor: &AgentTaskExecutor,
|
||||
job: &RuntimeJob,
|
||||
execution_chat_id: &str,
|
||||
) -> anyhow::Result<Vec<OutboundMessage>> {
|
||||
@ -643,8 +646,8 @@ async fn execute_agent_task(
|
||||
.ok_or_else(|| anyhow::anyhow!("agent_task payload.prompt must be a string"))?;
|
||||
let options = parse_scheduled_agent_task_options(job)?;
|
||||
|
||||
session_manager
|
||||
.run_scheduled_agent_task(channel_name, execution_chat_id, prompt, options)
|
||||
agent_task_executor
|
||||
.execute(channel_name, execution_chat_id, prompt, options)
|
||||
.await
|
||||
.map_err(|error| anyhow::anyhow!(error.to_string()))
|
||||
}
|
||||
@ -964,6 +967,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::bus::MessageBus;
|
||||
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig};
|
||||
use crate::gateway::agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
|
||||
use crate::gateway::session::SessionManager;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SchedulerJobUpsert, SessionStore};
|
||||
@ -1002,6 +1006,14 @@ mod tests {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn test_scheduler_services() -> (AgentTaskExecutor, SchedulerMaintenanceService) {
|
||||
let session_manager = test_session_manager();
|
||||
(
|
||||
AgentTaskExecutor::new(session_manager.clone()),
|
||||
SchedulerMaintenanceService::new(session_manager),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_job_skip_policy_advances_from_now() {
|
||||
let now = Utc
|
||||
@ -1129,7 +1141,7 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let session_manager = test_session_manager();
|
||||
let (agent_task_executor, maintenance_service) = test_scheduler_services();
|
||||
let scheduler = Scheduler::new(
|
||||
MessageBus::new(8),
|
||||
SchedulerConfig {
|
||||
@ -1141,7 +1153,8 @@ mod tests {
|
||||
},
|
||||
chrono_tz::Asia::Shanghai,
|
||||
store.clone(),
|
||||
session_manager,
|
||||
agent_task_executor,
|
||||
maintenance_service,
|
||||
);
|
||||
|
||||
scheduler.process_tick().await.unwrap();
|
||||
@ -1159,13 +1172,14 @@ mod tests {
|
||||
fn sync_config_jobs_persists_builtin_memory_maintenance_job() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
|
||||
let session_manager = test_session_manager();
|
||||
let (agent_task_executor, maintenance_service) = test_scheduler_services();
|
||||
let scheduler = Scheduler::new(
|
||||
MessageBus::new(8),
|
||||
SchedulerConfig::default(),
|
||||
chrono_tz::Asia::Shanghai,
|
||||
store.clone(),
|
||||
session_manager,
|
||||
agent_task_executor,
|
||||
maintenance_service,
|
||||
);
|
||||
|
||||
scheduler.sync_config_jobs().unwrap();
|
||||
@ -1204,6 +1218,7 @@ mod tests {
|
||||
async fn silent_agent_task_failure_notifies_primary_chat() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let bus = MessageBus::new(8);
|
||||
let (agent_task_executor, maintenance_service) = test_scheduler_services();
|
||||
let scheduler = Scheduler::new(
|
||||
bus.clone(),
|
||||
SchedulerConfig {
|
||||
@ -1215,7 +1230,8 @@ mod tests {
|
||||
},
|
||||
chrono_tz::Asia::Shanghai,
|
||||
store,
|
||||
test_session_manager(),
|
||||
agent_task_executor,
|
||||
maintenance_service,
|
||||
);
|
||||
|
||||
let job = RuntimeJob {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user