feat: 重构调度器以使用 AgentTaskExecutor 和 SchedulerMaintenanceService
- 更新调度器,将 SessionManager 替换为 AgentTaskExecutor 和 SchedulerMaintenanceService。 - 修改作业执行逻辑,使用新服务处理代理任务和内部事件。 - 添加新的 CliChannel 以处理 CLI 连接,并包括适当的注册和注销逻辑。 - 引入 AgentTaskExecutor 和 SchedulerMaintenanceService,用于管理代理任务和会话维护。 - 实现聊天命令处理,用于重置会话上下文。 - 添加后台历史压缩功能,以优化会话存储。 - 创建实用函数,用于准备通过 WebSocket 通信的出站消息。 - 为新功能添加测试,并确保现有测试通过。 Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
62b38eac73
commit
008aba91ac
12
README.md
12
README.md
@ -18,7 +18,7 @@ PicoBot 是一个用 Rust 构建的多通道 Agent 网关。它把消息接入
|
|||||||
PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个可持续运行的 Agent 基础设施:
|
PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个可持续运行的 Agent 基础设施:
|
||||||
|
|
||||||
- 消息从不同渠道进入统一总线
|
- 消息从不同渠道进入统一总线
|
||||||
- SessionManager 负责会话路由、上下文恢复、工具执行和回复生成
|
- SessionManager 负责会话路由和运行时服务编排,AgentExecutionService 负责上下文准备、AgentLoop 执行、结果持久化和回复生成
|
||||||
- SQLite 作为事实来源保存跨重启状态
|
- SQLite 作为事实来源保存跨重启状态
|
||||||
- Agent 在每轮推理时可以读取文件、执行命令、发 HTTP 请求、读写记忆、管理技能和调度任务
|
- Agent 在每轮推理时可以读取文件、执行命令、发 HTTP 请求、读写记忆、管理技能和调度任务
|
||||||
|
|
||||||
@ -37,13 +37,13 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个
|
|||||||
|
|
||||||
主要模块如下:
|
主要模块如下:
|
||||||
|
|
||||||
- src/gateway:网关入口、HTTP 健康检查、WebSocket 服务、Session 池、CLI 会话服务与 Agent 执行编排
|
- src/gateway:网关入口、HTTP 健康检查、WebSocket 控制面、Session 池、CLI 会话服务与 Agent 执行编排
|
||||||
- src/bus:消息总线与消息结构定义
|
- src/bus:消息总线与消息结构定义
|
||||||
- src/agent:AgentLoop 与上下文压缩器
|
- src/agent:AgentLoop 与上下文压缩器
|
||||||
- src/providers:不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic
|
- src/providers:不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic
|
||||||
- src/tools:内置工具集合
|
- src/tools:内置工具集合
|
||||||
- src/storage:SQLite 持久化实现
|
- src/storage:SQLite 持久化实现
|
||||||
- src/channels:渠道适配层,当前已有飞书通道
|
- src/channels:渠道适配层,当前已有 CLI 与飞书通道
|
||||||
- src/scheduler:数据库驱动的计划任务调度器
|
- src/scheduler:数据库驱动的计划任务调度器
|
||||||
- src/skills:技能发现、加载与运行时管理
|
- src/skills:技能发现、加载与运行时管理
|
||||||
- src/client / src/cli:本地 CLI 客户端和交互命令
|
- src/client / src/cli:本地 CLI 客户端和交互命令
|
||||||
@ -632,11 +632,11 @@ PicoBot/
|
|||||||
├── src/
|
├── src/
|
||||||
│ ├── agent/ # AgentLoop、上下文压缩
|
│ ├── agent/ # AgentLoop、上下文压缩
|
||||||
│ ├── bus/ # 消息总线与消息结构
|
│ ├── bus/ # 消息总线与消息结构
|
||||||
│ ├── channels/ # 渠道适配
|
│ ├── channels/ # CLI / 飞书等渠道适配
|
||||||
│ ├── cli/ # CLI 输入命令
|
│ ├── cli/ # CLI 输入命令
|
||||||
│ ├── client/ # WebSocket CLI 客户端
|
│ ├── client/ # WebSocket CLI 客户端
|
||||||
│ ├── config/ # 配置解析
|
│ ├── config/ # 配置解析
|
||||||
│ ├── gateway/ # Gateway、SessionManager、WS/HTTP
|
│ ├── gateway/ # Gateway、Session 编排、WS/HTTP 控制面
|
||||||
│ ├── providers/ # OpenAI / Anthropic Provider
|
│ ├── providers/ # OpenAI / Anthropic Provider
|
||||||
│ ├── scheduler/ # 定时任务系统
|
│ ├── scheduler/ # 定时任务系统
|
||||||
│ ├── skills/ # 技能运行时
|
│ ├── skills/ # 技能运行时
|
||||||
@ -656,7 +656,7 @@ PicoBot/
|
|||||||
建议维护时重点关注:
|
建议维护时重点关注:
|
||||||
|
|
||||||
- docs/PERSISTENCE.md:持久化结构是否与代码一致
|
- docs/PERSISTENCE.md:持久化结构是否与代码一致
|
||||||
- src/gateway/session.rs:消息流、工具注册、记忆维护、会话恢复主逻辑
|
- src/gateway/session.rs:会话状态、会话路由和运行时服务编排
|
||||||
- src/storage/mod.rs:SQLite schema 变更
|
- src/storage/mod.rs:SQLite schema 变更
|
||||||
- src/config/mod.rs:配置项变更是否同步到 README
|
- src/config/mod.rs:配置项变更是否同步到 README
|
||||||
|
|
||||||
|
|||||||
155
src/channels/cli.rs
Normal file
155
src/channels/cli.rs
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::{RwLock, mpsc};
|
||||||
|
|
||||||
|
use crate::bus::{MessageBus, OutboundMessage};
|
||||||
|
use crate::gateway::ws_adapter::ws_outbound_from_outbound_message;
|
||||||
|
use crate::protocol::WsOutbound;
|
||||||
|
|
||||||
|
use super::base::{Channel, ChannelError};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct CliConnection {
|
||||||
|
connection_id: String,
|
||||||
|
sender: mpsc::Sender<WsOutbound>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CliChannel {
|
||||||
|
connections: Arc<RwLock<HashMap<String, CliConnection>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CliChannel {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
connections: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn register_connection(
|
||||||
|
&self,
|
||||||
|
session_id: impl Into<String>,
|
||||||
|
connection_id: impl Into<String>,
|
||||||
|
sender: mpsc::Sender<WsOutbound>,
|
||||||
|
) {
|
||||||
|
let session_id = session_id.into();
|
||||||
|
let connection_id = connection_id.into();
|
||||||
|
let previous = self.connections.write().await.insert(
|
||||||
|
session_id.clone(),
|
||||||
|
CliConnection {
|
||||||
|
connection_id: connection_id.clone(),
|
||||||
|
sender,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
if previous.is_some() {
|
||||||
|
tracing::info!(session_id = %session_id, connection_id = %connection_id, "CLI session sender replaced");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn unregister_connection(&self, connection_id: &str) {
|
||||||
|
self.connections
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.retain(|_, connection| connection.connection_id != connection_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CliChannel {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Channel for CliChannel {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"cli"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_running(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start(&self, _bus: Arc<MessageBus>) -> Result<(), ChannelError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stop(&self) -> Result<(), ChannelError> {
|
||||||
|
self.connections.write().await.clear();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||||
|
let connection = self.connections.read().await.get(&msg.chat_id).cloned();
|
||||||
|
let Some(connection) = connection else {
|
||||||
|
return Err(ChannelError::SendError(format!(
|
||||||
|
"No active CLI connection for session {}",
|
||||||
|
msg.chat_id
|
||||||
|
)));
|
||||||
|
};
|
||||||
|
|
||||||
|
for outbound in ws_outbound_from_outbound_message(&msg) {
|
||||||
|
connection
|
||||||
|
.sender
|
||||||
|
.send(outbound)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelError::SendError("CLI websocket sender closed".to_string()))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::bus::OutboundMessage;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_cli_channel_sends_to_registered_session() {
|
||||||
|
let channel = CliChannel::new();
|
||||||
|
let (sender, mut receiver) = mpsc::channel(4);
|
||||||
|
channel
|
||||||
|
.register_connection("session-1", "conn-1", sender)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
channel
|
||||||
|
.send(OutboundMessage::assistant(
|
||||||
|
"cli",
|
||||||
|
"session-1",
|
||||||
|
"hello",
|
||||||
|
None,
|
||||||
|
HashMap::new(),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let outbound = receiver.recv().await.unwrap();
|
||||||
|
assert!(matches!(outbound, WsOutbound::AssistantResponse { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_cli_channel_unregisters_connection_sessions() {
|
||||||
|
let channel = CliChannel::new();
|
||||||
|
let (sender, _receiver) = mpsc::channel(4);
|
||||||
|
channel
|
||||||
|
.register_connection("session-1", "conn-1", sender)
|
||||||
|
.await;
|
||||||
|
channel.unregister_connection("conn-1").await;
|
||||||
|
|
||||||
|
let error = channel
|
||||||
|
.send(OutboundMessage::assistant(
|
||||||
|
"cli",
|
||||||
|
"session-1",
|
||||||
|
"hello",
|
||||||
|
None,
|
||||||
|
HashMap::new(),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap_err();
|
||||||
|
|
||||||
|
assert!(error.to_string().contains("No active CLI connection"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -4,6 +4,7 @@ use tokio::sync::RwLock;
|
|||||||
|
|
||||||
use crate::bus::MessageBus;
|
use crate::bus::MessageBus;
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
|
use crate::channels::cli::CliChannel;
|
||||||
use crate::channels::feishu::FeishuChannel;
|
use crate::channels::feishu::FeishuChannel;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
|
||||||
@ -12,13 +13,19 @@ use crate::config::Config;
|
|||||||
pub struct ChannelManager {
|
pub struct ChannelManager {
|
||||||
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
|
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
|
||||||
bus: Arc<MessageBus>,
|
bus: Arc<MessageBus>,
|
||||||
|
cli_channel: Arc<CliChannel>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChannelManager {
|
impl ChannelManager {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
|
let cli_channel = Arc::new(CliChannel::new());
|
||||||
|
let mut channels: HashMap<String, Arc<dyn Channel + Send + Sync>> = HashMap::new();
|
||||||
|
channels.insert("cli".to_string(), cli_channel.clone());
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
channels: Arc::new(RwLock::new(channels)),
|
||||||
bus: MessageBus::new(100),
|
bus: MessageBus::new(100),
|
||||||
|
cli_channel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,6 +34,10 @@ impl ChannelManager {
|
|||||||
self.bus.clone()
|
self.bus.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn cli_channel(&self) -> Arc<CliChannel> {
|
||||||
|
self.cli_channel.clone()
|
||||||
|
}
|
||||||
|
|
||||||
/// Initialize all Channel instances from config
|
/// Initialize all Channel instances from config
|
||||||
pub async fn init(
|
pub async fn init(
|
||||||
&self,
|
&self,
|
||||||
|
|||||||
@ -1,7 +1,9 @@
|
|||||||
pub mod base;
|
pub mod base;
|
||||||
|
pub mod cli;
|
||||||
pub mod feishu;
|
pub mod feishu;
|
||||||
pub mod manager;
|
pub mod manager;
|
||||||
|
|
||||||
pub use base::{Channel, ChannelError};
|
pub use base::{Channel, ChannelError};
|
||||||
|
pub use cli::CliChannel;
|
||||||
pub use feishu::FeishuChannel;
|
pub use feishu::FeishuChannel;
|
||||||
pub use manager::ChannelManager;
|
pub use manager::ChannelManager;
|
||||||
|
|||||||
52
src/gateway/agent_task_executor.rs
Normal file
52
src/gateway/agent_task_executor.rs
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
use crate::agent::AgentError;
|
||||||
|
use crate::bus::OutboundMessage;
|
||||||
|
|
||||||
|
use super::memory_maintenance::MemoryMaintenanceScopeResult;
|
||||||
|
use super::session::{ScheduledAgentTaskOptions, SessionManager};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AgentTaskExecutor {
|
||||||
|
session_manager: SessionManager,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentTaskExecutor {
|
||||||
|
pub fn new(session_manager: SessionManager) -> Self {
|
||||||
|
Self { session_manager }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn execute(
|
||||||
|
&self,
|
||||||
|
channel_name: &str,
|
||||||
|
chat_id: &str,
|
||||||
|
prompt: &str,
|
||||||
|
options: ScheduledAgentTaskOptions,
|
||||||
|
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||||
|
self.session_manager
|
||||||
|
.run_scheduled_agent_task(channel_name, chat_id, prompt, options)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SchedulerMaintenanceService {
|
||||||
|
session_manager: SessionManager,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SchedulerMaintenanceService {
|
||||||
|
pub fn new(session_manager: SessionManager) -> Self {
|
||||||
|
Self { session_manager }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
|
||||||
|
self.session_manager.cleanup_expired_sessions().await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn run_memory_maintenance_for_all_scopes(
|
||||||
|
&self,
|
||||||
|
updated_since: Option<i64>,
|
||||||
|
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||||
|
self.session_manager
|
||||||
|
.run_memory_maintenance_for_all_scopes(updated_since)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
159
src/gateway/command.rs
Normal file
159
src/gateway/command.rs
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
use crate::agent::AgentError;
|
||||||
|
|
||||||
|
use super::session::Session;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum InChatCommand {
|
||||||
|
FreshConversation,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
|
||||||
|
match content.trim() {
|
||||||
|
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn handle_in_chat_command(
|
||||||
|
session: &mut Session,
|
||||||
|
chat_id: &str,
|
||||||
|
content: &str,
|
||||||
|
) -> Result<Option<String>, AgentError> {
|
||||||
|
match parse_in_chat_command(content) {
|
||||||
|
Some(InChatCommand::FreshConversation) => {
|
||||||
|
session.reset_chat_context(chat_id)?;
|
||||||
|
Ok(Some("Started a fresh conversation.".to_string()))
|
||||||
|
}
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
use crate::skills::SkillRuntime;
|
||||||
|
use crate::storage::SessionStore;
|
||||||
|
use crate::tools::ToolRegistry;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
fn test_provider_config() -> LLMProviderConfig {
|
||||||
|
LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: "test".to_string(),
|
||||||
|
base_url: "http://localhost".to_string(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 120,
|
||||||
|
model_id: "test-model".to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(32),
|
||||||
|
context_window_tokens: None,
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
max_tool_iterations: 1,
|
||||||
|
tool_result_max_chars: 20_000,
|
||||||
|
context_tool_result_trim_chars: 20_000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_in_chat_command_aliases() {
|
||||||
|
assert_eq!(
|
||||||
|
parse_in_chat_command("/new"),
|
||||||
|
Some(InChatCommand::FreshConversation)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
parse_in_chat_command(" /reset \n"),
|
||||||
|
Some(InChatCommand::FreshConversation)
|
||||||
|
);
|
||||||
|
assert_eq!(parse_in_chat_command("/new planning"), None);
|
||||||
|
assert_eq!(parse_in_chat_command("please /reset"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_handle_in_chat_command_resets_active_history_only() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
|
let tools = Arc::new(ToolRegistry::new());
|
||||||
|
let mut session = Session::new(
|
||||||
|
"feishu".to_string(),
|
||||||
|
test_provider_config(),
|
||||||
|
user_tx,
|
||||||
|
tools,
|
||||||
|
skills,
|
||||||
|
store.clone(),
|
||||||
|
100,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
session.ensure_persistent_session("chat-1").unwrap();
|
||||||
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
|
session
|
||||||
|
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(response, "Started a fresh conversation.");
|
||||||
|
assert!(session.get_history("chat-1").unwrap().is_empty());
|
||||||
|
assert!(
|
||||||
|
store
|
||||||
|
.load_messages(&session.persistent_session_id("chat-1"))
|
||||||
|
.unwrap()
|
||||||
|
.is_empty()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
store
|
||||||
|
.load_all_messages(&session.persistent_session_id("chat-1"))
|
||||||
|
.unwrap()
|
||||||
|
.len(),
|
||||||
|
2,
|
||||||
|
);
|
||||||
|
|
||||||
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
|
let history = session.get_history("chat-1").unwrap();
|
||||||
|
assert_eq!(history.len(), 1);
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
|
let tools = Arc::new(ToolRegistry::new());
|
||||||
|
let mut session = Session::new(
|
||||||
|
"feishu".to_string(),
|
||||||
|
test_provider_config(),
|
||||||
|
user_tx,
|
||||||
|
tools,
|
||||||
|
skills,
|
||||||
|
store,
|
||||||
|
100,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
session.ensure_persistent_session("chat-1").unwrap();
|
||||||
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
|
session
|
||||||
|
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
|
||||||
|
session
|
||||||
|
.ensure_agent_prompt_before_user_message("chat-1")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let history = session.get_history("chat-1").unwrap();
|
||||||
|
assert_eq!(history.len(), 1);
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
|
}
|
||||||
|
}
|
||||||
105
src/gateway/compaction.rs
Normal file
105
src/gateway/compaction.rs
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use crate::agent::AgentError;
|
||||||
|
|
||||||
|
use super::session::Session;
|
||||||
|
|
||||||
|
pub(crate) async fn schedule_background_history_compaction(
|
||||||
|
session: Arc<Mutex<Session>>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
) -> Result<(), AgentError> {
|
||||||
|
let chat_id = chat_id.into();
|
||||||
|
|
||||||
|
let snapshot = {
|
||||||
|
let mut session_guard = session.lock().await;
|
||||||
|
let session_record = session_guard.ensure_persistent_session(&chat_id)?;
|
||||||
|
session_guard.ensure_chat_loaded(&chat_id)?;
|
||||||
|
|
||||||
|
let history = session_guard.get_or_create_history(&chat_id).clone();
|
||||||
|
let compressor = session_guard.compressor().clone();
|
||||||
|
if !compressor.should_compress(&history) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !session_guard.try_start_background_compaction(&chat_id) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
(
|
||||||
|
session_guard.store(),
|
||||||
|
session_guard.persistent_session_id(&chat_id),
|
||||||
|
session_record.reset_cutoff_seq,
|
||||||
|
session_record.message_count,
|
||||||
|
history,
|
||||||
|
compressor,
|
||||||
|
session_guard.provider_config().clone(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let (
|
||||||
|
store,
|
||||||
|
session_id,
|
||||||
|
expected_reset_cutoff_seq,
|
||||||
|
snapshot_end_seq,
|
||||||
|
history,
|
||||||
|
compressor,
|
||||||
|
provider_config,
|
||||||
|
) = snapshot;
|
||||||
|
let session_for_task = session.clone();
|
||||||
|
let chat_id_for_task = chat_id.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction");
|
||||||
|
|
||||||
|
let compaction_result = compressor
|
||||||
|
.build_compaction_plan(&history, &provider_config)
|
||||||
|
.await;
|
||||||
|
let mut committed = false;
|
||||||
|
|
||||||
|
match compaction_result {
|
||||||
|
Ok(Some(plan)) => match store.compact_active_history(
|
||||||
|
&session_id,
|
||||||
|
expected_reset_cutoff_seq,
|
||||||
|
snapshot_end_seq,
|
||||||
|
&plan.preserved_system_messages,
|
||||||
|
&plan.summary_message,
|
||||||
|
&plan.preserved_messages,
|
||||||
|
) {
|
||||||
|
Ok(true) => {
|
||||||
|
committed = true;
|
||||||
|
tracing::info!(
|
||||||
|
chat_id = %chat_id_for_task,
|
||||||
|
snapshot_end_seq,
|
||||||
|
compressed_turns = plan.compressed_turns,
|
||||||
|
preserved_turns = plan.preserved_turns,
|
||||||
|
"Background history compaction committed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(false) => {
|
||||||
|
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot");
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Ok(None) => {
|
||||||
|
tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis");
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut session_guard = session_for_task.lock().await;
|
||||||
|
if committed {
|
||||||
|
if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) {
|
||||||
|
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
session_guard.finish_background_compaction(&chat_id_for_task);
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@ -7,10 +7,10 @@ use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDUL
|
|||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
use super::session::{
|
use super::command::handle_in_chat_command;
|
||||||
Session, enrich_user_content_with_media_refs, handle_in_chat_command,
|
use super::compaction::schedule_background_history_compaction;
|
||||||
schedule_background_history_compaction,
|
use super::message_prepare::enrich_user_content_with_media_refs;
|
||||||
};
|
use super::session::Session;
|
||||||
|
|
||||||
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
||||||
|
|
||||||
|
|||||||
39
src/gateway/message_prepare.rs
Normal file
39
src/gateway/message_prepare.rs
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
use crate::agent::AgentError;
|
||||||
|
|
||||||
|
pub(crate) fn enrich_user_content_with_media_refs(
|
||||||
|
content: &str,
|
||||||
|
media_refs: &[String],
|
||||||
|
) -> Result<String, AgentError> {
|
||||||
|
if media_refs.is_empty() {
|
||||||
|
return Ok(content.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let media_refs_json = serde_json::to_string(media_refs)
|
||||||
|
.map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?;
|
||||||
|
|
||||||
|
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
||||||
|
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
||||||
|
|
||||||
|
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
enriched,
|
||||||
|
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() {
|
||||||
|
let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(enriched, "hello");
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,13 +1,18 @@
|
|||||||
|
pub mod agent_task_executor;
|
||||||
pub mod cli_session;
|
pub mod cli_session;
|
||||||
|
pub mod command;
|
||||||
|
pub mod compaction;
|
||||||
pub mod execution;
|
pub mod execution;
|
||||||
pub mod http;
|
pub mod http;
|
||||||
pub mod memory_maintenance;
|
pub mod memory_maintenance;
|
||||||
|
pub mod message_prepare;
|
||||||
pub mod processor;
|
pub mod processor;
|
||||||
pub mod prompt;
|
pub mod prompt;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod session_factory;
|
pub mod session_factory;
|
||||||
pub mod session_pool;
|
pub mod session_pool;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
pub mod ws_adapter;
|
||||||
|
|
||||||
use axum::{Router, routing};
|
use axum::{Router, routing};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -21,6 +26,7 @@ use crate::config::LLMProviderConfig;
|
|||||||
use crate::logging;
|
use crate::logging;
|
||||||
use crate::scheduler::Scheduler;
|
use crate::scheduler::Scheduler;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
|
use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
|
||||||
use processor::InboundProcessor;
|
use processor::InboundProcessor;
|
||||||
use session::SessionManager;
|
use session::SessionManager;
|
||||||
|
|
||||||
@ -122,7 +128,8 @@ pub async fn run(
|
|||||||
state.config.scheduler.clone(),
|
state.config.scheduler.clone(),
|
||||||
timezone,
|
timezone,
|
||||||
state.session_manager.store(),
|
state.session_manager.store(),
|
||||||
state.session_manager.clone(),
|
AgentTaskExecutor::new(state.session_manager.clone()),
|
||||||
|
SchedulerMaintenanceService::new(state.session_manager.clone()),
|
||||||
);
|
);
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::bus::MessageBus;
|
use crate::bus::{MessageBus, OutboundMessage};
|
||||||
|
|
||||||
use super::session::{BusToolCallEmitter, SessionManager};
|
use super::session::{BusToolCallEmitter, SessionManager};
|
||||||
|
|
||||||
@ -70,6 +70,21 @@ impl InboundProcessor {
|
|||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
tracing::error!(error = %error, "Failed to handle message");
|
tracing::error!(error = %error, "Failed to handle message");
|
||||||
|
let mut metadata = inbound.forwarded_metadata.clone();
|
||||||
|
metadata.insert("error_kind".to_string(), "agent_execution".to_string());
|
||||||
|
if let Err(publish_error) = self
|
||||||
|
.bus
|
||||||
|
.publish_outbound(OutboundMessage::error_notification(
|
||||||
|
inbound.channel,
|
||||||
|
inbound.chat_id,
|
||||||
|
error.to_string(),
|
||||||
|
None,
|
||||||
|
metadata,
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
tracing::error!(error = %publish_error, "Failed to publish execution error outbound");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,20 +43,6 @@ fn preview_text(content: &str, max_chars: usize) -> String {
|
|||||||
preview.replace('\n', "\\n")
|
preview.replace('\n', "\\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn enrich_user_content_with_media_refs(
|
|
||||||
content: &str,
|
|
||||||
media_refs: &[String],
|
|
||||||
) -> Result<String, AgentError> {
|
|
||||||
if media_refs.is_empty() {
|
|
||||||
return Ok(content.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
let media_refs_json = serde_json::to_string(media_refs)
|
|
||||||
.map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?;
|
|
||||||
|
|
||||||
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
@ -393,15 +379,15 @@ impl Session {
|
|||||||
&self.compressor
|
&self.compressor
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
||||||
self.compression_in_flight.insert(chat_id.to_string())
|
self.compression_in_flight.insert(chat_id.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn finish_background_compaction(&mut self, chat_id: &str) {
|
pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
|
||||||
self.compression_in_flight.remove(chat_id);
|
self.compression_in_flight.remove(chat_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
let history = self
|
let history = self
|
||||||
.store
|
.store
|
||||||
.load_messages(&self.persistent_session_id(chat_id))
|
.load_messages(&self.persistent_session_id(chat_id))
|
||||||
@ -410,6 +396,10 @@ impl Session {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn store(&self) -> Arc<SessionStore> {
|
||||||
|
self.store.clone()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
|
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
if self.skills.is_empty() {
|
if self.skills.is_empty() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
@ -528,129 +518,6 @@ fn default_tools(
|
|||||||
registry
|
registry
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
enum InChatCommand {
|
|
||||||
FreshConversation,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
|
|
||||||
match content.trim() {
|
|
||||||
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn handle_in_chat_command(
|
|
||||||
session: &mut Session,
|
|
||||||
chat_id: &str,
|
|
||||||
content: &str,
|
|
||||||
) -> Result<Option<String>, AgentError> {
|
|
||||||
match parse_in_chat_command(content) {
|
|
||||||
Some(InChatCommand::FreshConversation) => {
|
|
||||||
session.reset_chat_context(chat_id)?;
|
|
||||||
Ok(Some("Started a fresh conversation.".to_string()))
|
|
||||||
}
|
|
||||||
None => Ok(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn schedule_background_history_compaction(
|
|
||||||
session: Arc<Mutex<Session>>,
|
|
||||||
chat_id: impl Into<String>,
|
|
||||||
) -> Result<(), AgentError> {
|
|
||||||
let chat_id = chat_id.into();
|
|
||||||
|
|
||||||
let snapshot = {
|
|
||||||
let mut session_guard = session.lock().await;
|
|
||||||
let session_record = session_guard.ensure_persistent_session(&chat_id)?;
|
|
||||||
session_guard.ensure_chat_loaded(&chat_id)?;
|
|
||||||
|
|
||||||
let history = session_guard.get_or_create_history(&chat_id).clone();
|
|
||||||
let compressor = session_guard.compressor().clone();
|
|
||||||
if !compressor.should_compress(&history) {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
if !session_guard.try_start_background_compaction(&chat_id) {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
(
|
|
||||||
session_guard.store.clone(),
|
|
||||||
session_guard.persistent_session_id(&chat_id),
|
|
||||||
session_record.reset_cutoff_seq,
|
|
||||||
session_record.message_count,
|
|
||||||
history,
|
|
||||||
compressor,
|
|
||||||
session_guard.provider_config().clone(),
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
let (
|
|
||||||
store,
|
|
||||||
session_id,
|
|
||||||
expected_reset_cutoff_seq,
|
|
||||||
snapshot_end_seq,
|
|
||||||
history,
|
|
||||||
compressor,
|
|
||||||
provider_config,
|
|
||||||
) = snapshot;
|
|
||||||
let session_for_task = session.clone();
|
|
||||||
let chat_id_for_task = chat_id.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction");
|
|
||||||
|
|
||||||
let compaction_result = compressor
|
|
||||||
.build_compaction_plan(&history, &provider_config)
|
|
||||||
.await;
|
|
||||||
let mut committed = false;
|
|
||||||
|
|
||||||
match compaction_result {
|
|
||||||
Ok(Some(plan)) => match store.compact_active_history(
|
|
||||||
&session_id,
|
|
||||||
expected_reset_cutoff_seq,
|
|
||||||
snapshot_end_seq,
|
|
||||||
&plan.preserved_system_messages,
|
|
||||||
&plan.summary_message,
|
|
||||||
&plan.preserved_messages,
|
|
||||||
) {
|
|
||||||
Ok(true) => {
|
|
||||||
committed = true;
|
|
||||||
tracing::info!(
|
|
||||||
chat_id = %chat_id_for_task,
|
|
||||||
snapshot_end_seq,
|
|
||||||
compressed_turns = plan.compressed_turns,
|
|
||||||
preserved_turns = plan.preserved_turns,
|
|
||||||
"Background history compaction committed"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(false) => {
|
|
||||||
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot");
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Ok(None) => {
|
|
||||||
tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis");
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut session_guard = session_for_task.lock().await;
|
|
||||||
if committed {
|
|
||||||
if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) {
|
|
||||||
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
session_guard.finish_background_compaction(&chat_id_for_task);
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
impl SessionManager {
|
impl SessionManager {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
session_ttl_hours: u64,
|
session_ttl_hours: u64,
|
||||||
@ -913,25 +780,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
|
||||||
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
|
||||||
|
|
||||||
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
enriched,
|
|
||||||
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() {
|
|
||||||
let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(enriched, "hello");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_latest_user_message_guard_tracks_current_turn() {
|
async fn test_latest_user_message_guard_tracks_current_turn() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
@ -1750,75 +1598,6 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_parse_in_chat_command_aliases() {
|
|
||||||
assert_eq!(
|
|
||||||
parse_in_chat_command("/new"),
|
|
||||||
Some(InChatCommand::FreshConversation)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
parse_in_chat_command(" /reset \n"),
|
|
||||||
Some(InChatCommand::FreshConversation)
|
|
||||||
);
|
|
||||||
assert_eq!(parse_in_chat_command("/new planning"), None);
|
|
||||||
assert_eq!(parse_in_chat_command("please /reset"), None);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_handle_in_chat_command_resets_active_history_only() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
|
||||||
let tools = Arc::new(default_tools(
|
|
||||||
skills.clone(),
|
|
||||||
store.clone(),
|
|
||||||
HashSet::new(),
|
|
||||||
"Asia/Shanghai".to_string(),
|
|
||||||
));
|
|
||||||
let mut session = Session::new(
|
|
||||||
"feishu".to_string(),
|
|
||||||
test_provider_config(),
|
|
||||||
user_tx,
|
|
||||||
tools,
|
|
||||||
skills,
|
|
||||||
store.clone(),
|
|
||||||
100,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
session.ensure_persistent_session("chat-1").unwrap();
|
|
||||||
session.ensure_chat_loaded("chat-1").unwrap();
|
|
||||||
session
|
|
||||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(response, "Started a fresh conversation.");
|
|
||||||
assert!(session.get_history("chat-1").unwrap().is_empty());
|
|
||||||
assert!(
|
|
||||||
store
|
|
||||||
.load_messages(&session.persistent_session_id("chat-1"))
|
|
||||||
.unwrap()
|
|
||||||
.is_empty()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
store
|
|
||||||
.load_all_messages(&session.persistent_session_id("chat-1"))
|
|
||||||
.unwrap()
|
|
||||||
.len(),
|
|
||||||
2,
|
|
||||||
);
|
|
||||||
|
|
||||||
session.ensure_chat_loaded("chat-1").unwrap();
|
|
||||||
let history = session.get_history("chat-1").unwrap();
|
|
||||||
assert_eq!(history.len(), 1);
|
|
||||||
assert_eq!(history[0].role, "system");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
|
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
@ -1955,45 +1734,6 @@ mod tests {
|
|||||||
assert_eq!(system_messages, 1);
|
assert_eq!(system_messages, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
|
||||||
let tools = Arc::new(default_tools(
|
|
||||||
skills.clone(),
|
|
||||||
store.clone(),
|
|
||||||
HashSet::new(),
|
|
||||||
"Asia/Shanghai".to_string(),
|
|
||||||
));
|
|
||||||
let mut session = Session::new(
|
|
||||||
"feishu".to_string(),
|
|
||||||
test_provider_config(),
|
|
||||||
user_tx,
|
|
||||||
tools,
|
|
||||||
skills,
|
|
||||||
store.clone(),
|
|
||||||
100,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
session.ensure_persistent_session("chat-1").unwrap();
|
|
||||||
session.ensure_chat_loaded("chat-1").unwrap();
|
|
||||||
session
|
|
||||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
|
|
||||||
session
|
|
||||||
.ensure_agent_prompt_before_user_message("chat-1")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let history = session.get_history("chat-1").unwrap();
|
|
||||||
assert_eq!(history.len(), 1);
|
|
||||||
assert_eq!(history[0].role, "system");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_default_tools_registers_get_time() {
|
fn test_default_tools_registers_get_time() {
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
|
|||||||
@ -1,40 +1,17 @@
|
|||||||
use super::{
|
use super::GatewayState;
|
||||||
GatewayState,
|
use crate::agent::AgentError;
|
||||||
execution::{AgentExecutionService, MessageExecutionRequest, should_display_message_to_user},
|
use crate::bus::InboundMessage;
|
||||||
session::Session,
|
|
||||||
};
|
|
||||||
use crate::agent::{AgentError, EmittedMessageHandler};
|
|
||||||
use crate::bus::message::{OutboundEventKind, ToolMessageState, format_tool_call_content};
|
|
||||||
use crate::bus::{ChatMessage, OutboundMessage};
|
|
||||||
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||||
use async_trait::async_trait;
|
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
const CLI_CHANNEL_NAME: &str = "cli";
|
const CLI_CHANNEL_NAME: &str = "cli";
|
||||||
|
|
||||||
struct WsToolCallEmitter {
|
|
||||||
sender: mpsc::Sender<WsOutbound>,
|
|
||||||
show_tool_results: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl EmittedMessageHandler for WsToolCallEmitter {
|
|
||||||
async fn handle(&self, message: ChatMessage) {
|
|
||||||
if !should_display_message_to_user(self.show_tool_results, &message) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for outbound in ws_outbound_from_chat_message(&message) {
|
|
||||||
let _ = self.sender.send(outbound).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||||
ws.on_upgrade(|socket| async {
|
ws.on_upgrade(|socket| async {
|
||||||
handle_socket(socket, state).await;
|
handle_socket(socket, state).await;
|
||||||
@ -44,14 +21,6 @@ pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewaySta
|
|||||||
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||||
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
||||||
|
|
||||||
let provider_config = match state.config.get_provider_config("default") {
|
|
||||||
Ok(cfg) => cfg,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(error = %e, "Failed to get provider config");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let cli_sessions = state.session_manager.cli_sessions();
|
let cli_sessions = state.session_manager.cli_sessions();
|
||||||
let initial_record = match cli_sessions.create(None) {
|
let initial_record = match cli_sessions.create(None) {
|
||||||
Ok(record) => record,
|
Ok(record) => record,
|
||||||
@ -61,39 +30,20 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let channel_name = CLI_CHANNEL_NAME.to_string();
|
let runtime_session_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
|
||||||
// 创建 CLI session
|
|
||||||
let session = match Session::new(
|
|
||||||
channel_name.clone(),
|
|
||||||
provider_config,
|
|
||||||
sender,
|
|
||||||
state.session_manager.tools(),
|
|
||||||
state.session_manager.skills(),
|
|
||||||
state.session_manager.store(),
|
|
||||||
state.config.gateway.agent_prompt_reinject_every,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(s) => Arc::new(Mutex::new(s)),
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(error = %e, "Failed to create session");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
|
|
||||||
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let runtime_session_id = session.lock().await.id.to_string();
|
|
||||||
let mut current_session_id = initial_record.id.clone();
|
let mut current_session_id = initial_record.id.clone();
|
||||||
|
state
|
||||||
|
.channel_manager
|
||||||
|
.cli_channel()
|
||||||
|
.register_connection(
|
||||||
|
current_session_id.clone(),
|
||||||
|
runtime_session_id.clone(),
|
||||||
|
sender.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
|
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
|
||||||
|
|
||||||
let _ = session
|
let _ = sender
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.send(WsOutbound::SessionEstablished {
|
.send(WsOutbound::SessionEstablished {
|
||||||
session_id: current_session_id.clone(),
|
session_id: current_session_id.clone(),
|
||||||
})
|
})
|
||||||
@ -123,7 +73,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
Ok(inbound) => {
|
Ok(inbound) => {
|
||||||
if let Err(e) = handle_inbound(
|
if let Err(e) = handle_inbound(
|
||||||
&state,
|
&state,
|
||||||
&session,
|
&sender,
|
||||||
&runtime_session_id,
|
&runtime_session_id,
|
||||||
&mut current_session_id,
|
&mut current_session_id,
|
||||||
inbound,
|
inbound,
|
||||||
@ -131,9 +81,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
||||||
let _ = session
|
let _ = sender
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.send(WsOutbound::Error {
|
.send(WsOutbound::Error {
|
||||||
code: "SESSION_ERROR".to_string(),
|
code: "SESSION_ERROR".to_string(),
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
@ -143,9 +91,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||||||
let _ = session
|
let _ = sender
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.send(WsOutbound::Error {
|
.send(WsOutbound::Error {
|
||||||
code: "PARSE_ERROR".to_string(),
|
code: "PARSE_ERROR".to_string(),
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
@ -163,6 +109,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state
|
||||||
|
.channel_manager
|
||||||
|
.cli_channel()
|
||||||
|
.unregister_connection(&runtime_session_id)
|
||||||
|
.await;
|
||||||
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,115 +129,9 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
|
||||||
match message.role.as_str() {
|
|
||||||
"assistant" => {
|
|
||||||
if let Some(tool_calls) = &message.tool_calls {
|
|
||||||
let mut outbound = Vec::new();
|
|
||||||
if !message.content.trim().is_empty() {
|
|
||||||
outbound.push(WsOutbound::AssistantResponse {
|
|
||||||
id: message.id.clone(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
|
|
||||||
id: message.id.clone(),
|
|
||||||
tool_call_id: tool_call.id.clone(),
|
|
||||||
tool_name: tool_call.name.clone(),
|
|
||||||
arguments: tool_call.arguments.clone(),
|
|
||||||
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
|
||||||
role: message.role.clone(),
|
|
||||||
}));
|
|
||||||
outbound
|
|
||||||
} else {
|
|
||||||
vec![WsOutbound::AssistantResponse {
|
|
||||||
id: message.id.clone(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"tool" => match message
|
|
||||||
.tool_state
|
|
||||||
.as_ref()
|
|
||||||
.unwrap_or(&ToolMessageState::Completed)
|
|
||||||
{
|
|
||||||
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
|
||||||
id: message.id.clone(),
|
|
||||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
}],
|
|
||||||
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
|
|
||||||
id: message.id.clone(),
|
|
||||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
|
|
||||||
}],
|
|
||||||
},
|
|
||||||
_ => Vec::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec<WsOutbound> {
|
|
||||||
match message.event_kind {
|
|
||||||
OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => {
|
|
||||||
vec![WsOutbound::AssistantResponse {
|
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall {
|
|
||||||
id: message
|
|
||||||
.tool_call_id
|
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
|
||||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
||||||
arguments: message
|
|
||||||
.tool_arguments
|
|
||||||
.clone()
|
|
||||||
.unwrap_or(serde_json::Value::Null),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
}],
|
|
||||||
OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult {
|
|
||||||
id: message
|
|
||||||
.tool_call_id
|
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
|
||||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
}],
|
|
||||||
OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending {
|
|
||||||
id: message
|
|
||||||
.tool_call_id
|
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
|
||||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
|
|
||||||
}],
|
|
||||||
OutboundEventKind::ErrorNotification => vec![WsOutbound::Error {
|
|
||||||
code: "AGENT_ERROR".to_string(),
|
|
||||||
message: message.content.clone(),
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_inbound(
|
async fn handle_inbound(
|
||||||
state: &Arc<GatewayState>,
|
state: &Arc<GatewayState>,
|
||||||
session: &Arc<Mutex<Session>>,
|
sender: &mpsc::Sender<WsOutbound>,
|
||||||
runtime_session_id: &str,
|
runtime_session_id: &str,
|
||||||
current_session_id: &mut String,
|
current_session_id: &mut String,
|
||||||
inbound: WsInbound,
|
inbound: WsInbound,
|
||||||
@ -300,43 +145,31 @@ async fn handle_inbound(
|
|||||||
} => {
|
} => {
|
||||||
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
||||||
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
||||||
let user_tx = session.lock().await.user_tx.clone();
|
|
||||||
let live_emitter = Arc::new(WsToolCallEmitter {
|
|
||||||
sender: user_tx.clone(),
|
|
||||||
show_tool_results: state.config.gateway.show_tool_results,
|
|
||||||
});
|
|
||||||
|
|
||||||
match AgentExecutionService::new(state.config.gateway.show_tool_results)
|
state
|
||||||
.prepare_and_execute_message(MessageExecutionRequest {
|
.channel_manager
|
||||||
session: session.clone(),
|
.cli_channel()
|
||||||
channel_name: CLI_CHANNEL_NAME,
|
.register_connection(
|
||||||
sender_id: &sender_id,
|
chat_id.clone(),
|
||||||
chat_id: &chat_id,
|
runtime_session_id.to_string(),
|
||||||
content: &content,
|
sender.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
state
|
||||||
|
.bus
|
||||||
|
.publish_inbound(InboundMessage {
|
||||||
|
channel: CLI_CHANNEL_NAME.to_string(),
|
||||||
|
sender_id,
|
||||||
|
chat_id,
|
||||||
|
content,
|
||||||
|
timestamp: current_timestamp(),
|
||||||
media: Vec::new(),
|
media: Vec::new(),
|
||||||
live_emitter: Some(live_emitter),
|
metadata: HashMap::new(),
|
||||||
|
forwarded_metadata: HashMap::new(),
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
{
|
.map_err(|error| AgentError::Other(error.to_string()))?;
|
||||||
Ok(outbound_messages) => {
|
|
||||||
for outbound in outbound_messages
|
|
||||||
.iter()
|
|
||||||
.flat_map(ws_outbound_from_outbound_message)
|
|
||||||
{
|
|
||||||
let _ = user_tx.send(outbound).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(AgentError::LlmError(error)) => {
|
|
||||||
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
|
|
||||||
let _ = user_tx
|
|
||||||
.send(WsOutbound::Error {
|
|
||||||
code: "LLM_ERROR".to_string(),
|
|
||||||
message: error,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
Err(error) => return Err(error),
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -352,9 +185,11 @@ async fn handle_inbound(
|
|||||||
.cli_sessions()
|
.cli_sessions()
|
||||||
.clear_messages(&target)?;
|
.clear_messages(&target)?;
|
||||||
|
|
||||||
let mut session_guard = session.lock().await;
|
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
||||||
session_guard.remove_history(&target);
|
session.lock().await.remove_history(&target);
|
||||||
let _ = session_guard
|
}
|
||||||
|
|
||||||
|
let _ = sender
|
||||||
.send(WsOutbound::HistoryCleared { session_id: target })
|
.send(WsOutbound::HistoryCleared { session_id: target })
|
||||||
.await;
|
.await;
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -366,9 +201,16 @@ async fn handle_inbound(
|
|||||||
.create(title.as_deref())?;
|
.create(title.as_deref())?;
|
||||||
*current_session_id = record.id.clone();
|
*current_session_id = record.id.clone();
|
||||||
|
|
||||||
let mut session_guard = session.lock().await;
|
state
|
||||||
session_guard.ensure_chat_loaded(&record.id)?;
|
.channel_manager
|
||||||
let _ = session_guard
|
.cli_channel()
|
||||||
|
.register_connection(
|
||||||
|
record.id.clone(),
|
||||||
|
runtime_session_id.to_string(),
|
||||||
|
sender.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let _ = sender
|
||||||
.send(WsOutbound::SessionCreated {
|
.send(WsOutbound::SessionCreated {
|
||||||
session_id: record.id,
|
session_id: record.id,
|
||||||
title: record.title,
|
title: record.title,
|
||||||
@ -383,8 +225,7 @@ async fn handle_inbound(
|
|||||||
.list(include_archived)?;
|
.list(include_archived)?;
|
||||||
let summaries = records.into_iter().map(to_session_summary).collect();
|
let summaries = records.into_iter().map(to_session_summary).collect();
|
||||||
|
|
||||||
let session_guard = session.lock().await;
|
let _ = sender
|
||||||
let _ = session_guard
|
|
||||||
.send(WsOutbound::SessionList {
|
.send(WsOutbound::SessionList {
|
||||||
sessions: summaries,
|
sessions: summaries,
|
||||||
current_session_id: Some(current_session_id.clone()),
|
current_session_id: Some(current_session_id.clone()),
|
||||||
@ -394,8 +235,7 @@ async fn handle_inbound(
|
|||||||
}
|
}
|
||||||
WsInbound::LoadSession { session_id } => {
|
WsInbound::LoadSession { session_id } => {
|
||||||
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
|
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
|
||||||
let session_guard = session.lock().await;
|
let _ = sender
|
||||||
let _ = session_guard
|
|
||||||
.send(WsOutbound::Error {
|
.send(WsOutbound::Error {
|
||||||
code: "SESSION_NOT_FOUND".to_string(),
|
code: "SESSION_NOT_FOUND".to_string(),
|
||||||
message: format!("Session not found: {}", session_id),
|
message: format!("Session not found: {}", session_id),
|
||||||
@ -405,9 +245,16 @@ async fn handle_inbound(
|
|||||||
};
|
};
|
||||||
|
|
||||||
*current_session_id = record.id.clone();
|
*current_session_id = record.id.clone();
|
||||||
let mut session_guard = session.lock().await;
|
state
|
||||||
session_guard.ensure_chat_loaded(&record.id)?;
|
.channel_manager
|
||||||
let _ = session_guard
|
.cli_channel()
|
||||||
|
.register_connection(
|
||||||
|
record.id.clone(),
|
||||||
|
runtime_session_id.to_string(),
|
||||||
|
sender.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let _ = sender
|
||||||
.send(WsOutbound::SessionLoaded {
|
.send(WsOutbound::SessionLoaded {
|
||||||
session_id: record.id,
|
session_id: record.id,
|
||||||
title: record.title,
|
title: record.title,
|
||||||
@ -422,8 +269,7 @@ async fn handle_inbound(
|
|||||||
.session_manager
|
.session_manager
|
||||||
.cli_sessions()
|
.cli_sessions()
|
||||||
.rename(&target, &title)?;
|
.rename(&target, &title)?;
|
||||||
let session_guard = session.lock().await;
|
let _ = sender
|
||||||
let _ = session_guard
|
|
||||||
.send(WsOutbound::SessionRenamed {
|
.send(WsOutbound::SessionRenamed {
|
||||||
session_id: target,
|
session_id: target,
|
||||||
title,
|
title,
|
||||||
@ -434,8 +280,7 @@ async fn handle_inbound(
|
|||||||
WsInbound::ArchiveSession { session_id } => {
|
WsInbound::ArchiveSession { session_id } => {
|
||||||
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||||
state.session_manager.cli_sessions().archive(&target)?;
|
state.session_manager.cli_sessions().archive(&target)?;
|
||||||
let session_guard = session.lock().await;
|
let _ = sender
|
||||||
let _ = session_guard
|
|
||||||
.send(WsOutbound::SessionArchived { session_id: target })
|
.send(WsOutbound::SessionArchived { session_id: target })
|
||||||
.await;
|
.await;
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -450,9 +295,11 @@ async fn handle_inbound(
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut session_guard = session.lock().await;
|
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
||||||
session_guard.remove_history(&target);
|
session.lock().await.remove_history(&target);
|
||||||
let _ = session_guard
|
}
|
||||||
|
|
||||||
|
let _ = sender
|
||||||
.send(WsOutbound::SessionDeleted {
|
.send(WsOutbound::SessionDeleted {
|
||||||
session_id: target.clone(),
|
session_id: target.clone(),
|
||||||
})
|
})
|
||||||
@ -460,8 +307,16 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
if let Some(record) = replacement {
|
if let Some(record) = replacement {
|
||||||
*current_session_id = record.id.clone();
|
*current_session_id = record.id.clone();
|
||||||
session_guard.ensure_chat_loaded(&record.id)?;
|
state
|
||||||
let _ = session_guard
|
.channel_manager
|
||||||
|
.cli_channel()
|
||||||
|
.register_connection(
|
||||||
|
record.id.clone(),
|
||||||
|
runtime_session_id.to_string(),
|
||||||
|
sender.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let _ = sender
|
||||||
.send(WsOutbound::SessionCreated {
|
.send(WsOutbound::SessionCreated {
|
||||||
session_id: record.id,
|
session_id: record.id,
|
||||||
title: record.title,
|
title: record.title,
|
||||||
@ -472,13 +327,19 @@ async fn handle_inbound(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
WsInbound::Ping => {
|
WsInbound::Ping => {
|
||||||
let session_guard = session.lock().await;
|
let _ = sender.send(WsOutbound::Pong).await;
|
||||||
let _ = session_guard.send(WsOutbound::Pong).await;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn current_timestamp() -> i64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_millis() as i64
|
||||||
|
}
|
||||||
|
|
||||||
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
|
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
|
||||||
sender_id
|
sender_id
|
||||||
.map(str::trim)
|
.map(str::trim)
|
||||||
@ -489,138 +350,7 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{
|
use super::resolve_ws_sender_id;
|
||||||
WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
|
|
||||||
ws_outbound_from_chat_message,
|
|
||||||
};
|
|
||||||
use crate::agent::EmittedMessageHandler;
|
|
||||||
use crate::bus::message::ToolMessageState;
|
|
||||||
use crate::bus::{ChatMessage, OutboundMessage};
|
|
||||||
use crate::protocol::WsOutbound;
|
|
||||||
use crate::providers::ToolCall;
|
|
||||||
use serde_json::json;
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
|
|
||||||
let message = ChatMessage::assistant_with_tool_calls(
|
|
||||||
"",
|
|
||||||
vec![ToolCall {
|
|
||||||
id: "call-1".to_string(),
|
|
||||||
name: "calculator".to_string(),
|
|
||||||
arguments: json!({"expression": "1 + 1"}),
|
|
||||||
}],
|
|
||||||
);
|
|
||||||
|
|
||||||
let outbound = ws_outbound_from_chat_message(&message);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
|
||||||
match &outbound[0] {
|
|
||||||
WsOutbound::ToolCall {
|
|
||||||
tool_call_id,
|
|
||||||
tool_name,
|
|
||||||
arguments,
|
|
||||||
content,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
assert_eq!(tool_call_id, "call-1");
|
|
||||||
assert_eq!(tool_name, "calculator");
|
|
||||||
assert_eq!(arguments["expression"], "1 + 1");
|
|
||||||
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
|
||||||
}
|
|
||||||
other => panic!("unexpected outbound variant: {:?}", other),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
|
|
||||||
let message = ChatMessage::assistant_with_tool_calls(
|
|
||||||
"日报已整理完成。",
|
|
||||||
vec![ToolCall {
|
|
||||||
id: "call-1".to_string(),
|
|
||||||
name: "memory_manage".to_string(),
|
|
||||||
arguments: json!({"action": "put"}),
|
|
||||||
}],
|
|
||||||
);
|
|
||||||
|
|
||||||
let outbound = ws_outbound_from_chat_message(&message);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 2);
|
|
||||||
assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
|
|
||||||
assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_ws_outbound_from_chat_message_includes_tool_results() {
|
|
||||||
let message = ChatMessage::tool("call-1", "calculator", "2");
|
|
||||||
|
|
||||||
let outbound = ws_outbound_from_chat_message(&message);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
|
||||||
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
|
|
||||||
let message = ChatMessage::tool_with_state(
|
|
||||||
"call-1",
|
|
||||||
"bash",
|
|
||||||
"等待你完成授权后再继续。",
|
|
||||||
ToolMessageState::PendingUserAction,
|
|
||||||
);
|
|
||||||
|
|
||||||
let outbound = ws_outbound_from_chat_message(&message);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
|
||||||
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
|
||||||
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
|
||||||
let pending = ChatMessage::tool_with_state(
|
|
||||||
"call-2",
|
|
||||||
"bash",
|
|
||||||
"waiting",
|
|
||||||
ToolMessageState::PendingUserAction,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(!should_display_message_to_user(false, &completed));
|
|
||||||
assert!(should_display_message_to_user(false, &pending));
|
|
||||||
assert!(should_display_message_to_user(true, &completed));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_ws_outbound_from_outbound_message_maps_tool_call() {
|
|
||||||
let message = OutboundMessage::tool_call(
|
|
||||||
"cli",
|
|
||||||
"session-1",
|
|
||||||
"call-1",
|
|
||||||
"calculator",
|
|
||||||
json!({"expression": "1 + 1"}),
|
|
||||||
None,
|
|
||||||
Default::default(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let outbound = super::ws_outbound_from_outbound_message(&message);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
|
||||||
match &outbound[0] {
|
|
||||||
WsOutbound::ToolCall {
|
|
||||||
tool_call_id,
|
|
||||||
tool_name,
|
|
||||||
arguments,
|
|
||||||
content,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
assert_eq!(tool_call_id, "call-1");
|
|
||||||
assert_eq!(tool_name, "calculator");
|
|
||||||
assert_eq!(arguments["expression"], "1 + 1");
|
|
||||||
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
|
||||||
}
|
|
||||||
other => panic!("unexpected outbound variant: {:?}", other),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
|
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
|
||||||
@ -639,23 +369,4 @@ mod tests {
|
|||||||
assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
|
assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
|
||||||
assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
|
assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_ws_tool_call_emitter_hides_completed_tool_results_when_disabled() {
|
|
||||||
let (sender, mut receiver) = mpsc::channel(4);
|
|
||||||
let emitter = WsToolCallEmitter {
|
|
||||||
sender,
|
|
||||||
show_tool_results: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
emitter
|
|
||||||
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
assert!(
|
|
||||||
tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
|
|
||||||
.await
|
|
||||||
.is_err()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
229
src/gateway/ws_adapter.rs
Normal file
229
src/gateway/ws_adapter.rs
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
|
use crate::bus::OutboundMessage;
|
||||||
|
use crate::bus::message::OutboundEventKind;
|
||||||
|
#[cfg(test)]
|
||||||
|
use crate::bus::message::{ToolMessageState, format_tool_call_content};
|
||||||
|
use crate::protocol::WsOutbound;
|
||||||
|
|
||||||
|
const TOOL_PENDING_RESUME_HINT: &str = "完成外部操作后,直接发一条继续消息即可。";
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||||
|
match message.role.as_str() {
|
||||||
|
"assistant" => {
|
||||||
|
if let Some(tool_calls) = &message.tool_calls {
|
||||||
|
let mut outbound = Vec::new();
|
||||||
|
if !message.content.trim().is_empty() {
|
||||||
|
outbound.push(WsOutbound::AssistantResponse {
|
||||||
|
id: message.id.clone(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
role: message.role.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
|
||||||
|
id: message.id.clone(),
|
||||||
|
tool_call_id: tool_call.id.clone(),
|
||||||
|
tool_name: tool_call.name.clone(),
|
||||||
|
arguments: tool_call.arguments.clone(),
|
||||||
|
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
||||||
|
role: message.role.clone(),
|
||||||
|
}));
|
||||||
|
outbound
|
||||||
|
} else {
|
||||||
|
vec![WsOutbound::AssistantResponse {
|
||||||
|
id: message.id.clone(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
role: message.role.clone(),
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"tool" => match message
|
||||||
|
.tool_state
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&ToolMessageState::Completed)
|
||||||
|
{
|
||||||
|
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
||||||
|
id: message.id.clone(),
|
||||||
|
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||||
|
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
role: message.role.clone(),
|
||||||
|
}],
|
||||||
|
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
|
||||||
|
id: message.id.clone(),
|
||||||
|
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||||
|
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
role: message.role.clone(),
|
||||||
|
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
_ => Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec<WsOutbound> {
|
||||||
|
match message.event_kind {
|
||||||
|
OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => {
|
||||||
|
vec![WsOutbound::AssistantResponse {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
role: message.role.clone(),
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall {
|
||||||
|
id: message
|
||||||
|
.tool_call_id
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
|
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||||
|
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||||
|
arguments: message
|
||||||
|
.tool_arguments
|
||||||
|
.clone()
|
||||||
|
.unwrap_or(serde_json::Value::Null),
|
||||||
|
content: message.content.clone(),
|
||||||
|
role: message.role.clone(),
|
||||||
|
}],
|
||||||
|
OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult {
|
||||||
|
id: message
|
||||||
|
.tool_call_id
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
|
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||||
|
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
role: message.role.clone(),
|
||||||
|
}],
|
||||||
|
OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending {
|
||||||
|
id: message
|
||||||
|
.tool_call_id
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
|
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||||
|
tool_name: message.tool_name.clone().unwrap_or_default(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
role: message.role.clone(),
|
||||||
|
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
|
||||||
|
}],
|
||||||
|
OutboundEventKind::ErrorNotification => vec![WsOutbound::Error {
|
||||||
|
code: "AGENT_ERROR".to_string(),
|
||||||
|
message: message.content.clone(),
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::providers::ToolCall;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
|
||||||
|
let message = ChatMessage::assistant_with_tool_calls(
|
||||||
|
"",
|
||||||
|
vec![ToolCall {
|
||||||
|
id: "call-1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: json!({"expression": "1 + 1"}),
|
||||||
|
}],
|
||||||
|
);
|
||||||
|
|
||||||
|
let outbound = ws_outbound_from_chat_message(&message);
|
||||||
|
|
||||||
|
assert_eq!(outbound.len(), 1);
|
||||||
|
match &outbound[0] {
|
||||||
|
WsOutbound::ToolCall {
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
content,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(tool_call_id, "call-1");
|
||||||
|
assert_eq!(tool_name, "calculator");
|
||||||
|
assert_eq!(arguments["expression"], "1 + 1");
|
||||||
|
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||||
|
}
|
||||||
|
other => panic!("unexpected outbound variant: {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
|
||||||
|
let message = ChatMessage::assistant_with_tool_calls(
|
||||||
|
"日报已整理完成。",
|
||||||
|
vec![ToolCall {
|
||||||
|
id: "call-1".to_string(),
|
||||||
|
name: "memory_manage".to_string(),
|
||||||
|
arguments: json!({"action": "put"}),
|
||||||
|
}],
|
||||||
|
);
|
||||||
|
|
||||||
|
let outbound = ws_outbound_from_chat_message(&message);
|
||||||
|
|
||||||
|
assert_eq!(outbound.len(), 2);
|
||||||
|
assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
|
||||||
|
assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ws_outbound_from_chat_message_includes_tool_results() {
|
||||||
|
let message = ChatMessage::tool("call-1", "calculator", "2");
|
||||||
|
|
||||||
|
let outbound = ws_outbound_from_chat_message(&message);
|
||||||
|
|
||||||
|
assert_eq!(outbound.len(), 1);
|
||||||
|
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
|
||||||
|
let message = ChatMessage::tool_with_state(
|
||||||
|
"call-1",
|
||||||
|
"bash",
|
||||||
|
"等待你完成授权后再继续。",
|
||||||
|
ToolMessageState::PendingUserAction,
|
||||||
|
);
|
||||||
|
|
||||||
|
let outbound = ws_outbound_from_chat_message(&message);
|
||||||
|
|
||||||
|
assert_eq!(outbound.len(), 1);
|
||||||
|
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ws_outbound_from_outbound_message_maps_tool_call() {
|
||||||
|
let message = OutboundMessage::tool_call(
|
||||||
|
"cli",
|
||||||
|
"session-1",
|
||||||
|
"call-1",
|
||||||
|
"calculator",
|
||||||
|
json!({"expression": "1 + 1"}),
|
||||||
|
None,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let outbound = ws_outbound_from_outbound_message(&message);
|
||||||
|
|
||||||
|
assert_eq!(outbound.len(), 1);
|
||||||
|
match &outbound[0] {
|
||||||
|
WsOutbound::ToolCall {
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
content,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(tool_call_id, "call-1");
|
||||||
|
assert_eq!(tool_name, "calculator");
|
||||||
|
assert_eq!(arguments["expression"], "1 + 1");
|
||||||
|
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||||
|
}
|
||||||
|
other => panic!("unexpected outbound variant: {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -11,8 +11,8 @@ use crate::config::{
|
|||||||
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget,
|
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget,
|
||||||
SchedulerMisfirePolicy, SchedulerSchedule,
|
SchedulerMisfirePolicy, SchedulerSchedule,
|
||||||
};
|
};
|
||||||
|
use crate::gateway::agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
|
||||||
use crate::gateway::session::ScheduledAgentTaskOptions;
|
use crate::gateway::session::ScheduledAgentTaskOptions;
|
||||||
use crate::gateway::session::SessionManager;
|
|
||||||
use crate::storage::{
|
use crate::storage::{
|
||||||
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore,
|
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore,
|
||||||
};
|
};
|
||||||
@ -22,7 +22,8 @@ pub struct Scheduler {
|
|||||||
config: SchedulerConfig,
|
config: SchedulerConfig,
|
||||||
timezone: Tz,
|
timezone: Tz,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
session_manager: SessionManager,
|
agent_task_executor: AgentTaskExecutor,
|
||||||
|
maintenance_service: SchedulerMaintenanceService,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Scheduler {
|
impl Scheduler {
|
||||||
@ -31,14 +32,16 @@ impl Scheduler {
|
|||||||
config: SchedulerConfig,
|
config: SchedulerConfig,
|
||||||
timezone: Tz,
|
timezone: Tz,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
session_manager: SessionManager,
|
agent_task_executor: AgentTaskExecutor,
|
||||||
|
maintenance_service: SchedulerMaintenanceService,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
bus,
|
bus,
|
||||||
config,
|
config,
|
||||||
timezone,
|
timezone,
|
||||||
store,
|
store,
|
||||||
session_manager,
|
agent_task_executor,
|
||||||
|
maintenance_service,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,11 +171,11 @@ impl Scheduler {
|
|||||||
self.bus.publish_outbound(message).await?;
|
self.bus.publish_outbound(message).await?;
|
||||||
}
|
}
|
||||||
SchedulerJobKind::InternalEvent => {
|
SchedulerJobKind::InternalEvent => {
|
||||||
execute_internal_event(&self.session_manager, job).await?;
|
execute_internal_event(&self.maintenance_service, job).await?;
|
||||||
}
|
}
|
||||||
SchedulerJobKind::AgentTask => {
|
SchedulerJobKind::AgentTask => {
|
||||||
let outbound_messages = execute_agent_task(
|
let outbound_messages = execute_agent_task(
|
||||||
&self.session_manager,
|
&self.agent_task_executor,
|
||||||
job,
|
job,
|
||||||
required_notification_chat_id(job, "agent_task")?,
|
required_notification_chat_id(job, "agent_task")?,
|
||||||
)
|
)
|
||||||
@ -184,7 +187,7 @@ impl Scheduler {
|
|||||||
SchedulerJobKind::SilentAgentTask => {
|
SchedulerJobKind::SilentAgentTask => {
|
||||||
let execution_chat_id = resolve_execution_chat_id(job)?;
|
let execution_chat_id = resolve_execution_chat_id(job)?;
|
||||||
if let Err(error) =
|
if let Err(error) =
|
||||||
execute_agent_task(&self.session_manager, job, &execution_chat_id).await
|
execute_agent_task(&self.agent_task_executor, job, &execution_chat_id).await
|
||||||
{
|
{
|
||||||
if let Err(notify_error) =
|
if let Err(notify_error) =
|
||||||
self.notify_silent_agent_task_failure(job, &error).await
|
self.notify_silent_agent_task_failure(job, &error).await
|
||||||
@ -587,7 +590,7 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_internal_event(
|
async fn execute_internal_event(
|
||||||
session_manager: &SessionManager,
|
maintenance_service: &SchedulerMaintenanceService,
|
||||||
job: &RuntimeJob,
|
job: &RuntimeJob,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let event = job
|
let event = job
|
||||||
@ -598,12 +601,12 @@ async fn execute_internal_event(
|
|||||||
|
|
||||||
match event {
|
match event {
|
||||||
"session_cleanup" => {
|
"session_cleanup" => {
|
||||||
let removed = session_manager.cleanup_expired_sessions().await;
|
let removed = maintenance_service.cleanup_expired_sessions().await;
|
||||||
tracing::info!(job_id = %job.id, removed, "Scheduler session cleanup completed");
|
tracing::info!(job_id = %job.id, removed, "Scheduler session cleanup completed");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
"memory_maintenance" => {
|
"memory_maintenance" => {
|
||||||
let results = session_manager
|
let results = maintenance_service
|
||||||
.run_memory_maintenance_for_all_scopes(job.last_fired_at)
|
.run_memory_maintenance_for_all_scopes(job.last_fired_at)
|
||||||
.await?;
|
.await?;
|
||||||
for result in &results {
|
for result in &results {
|
||||||
@ -627,7 +630,7 @@ async fn execute_internal_event(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_agent_task(
|
async fn execute_agent_task(
|
||||||
session_manager: &SessionManager,
|
agent_task_executor: &AgentTaskExecutor,
|
||||||
job: &RuntimeJob,
|
job: &RuntimeJob,
|
||||||
execution_chat_id: &str,
|
execution_chat_id: &str,
|
||||||
) -> anyhow::Result<Vec<OutboundMessage>> {
|
) -> anyhow::Result<Vec<OutboundMessage>> {
|
||||||
@ -643,8 +646,8 @@ async fn execute_agent_task(
|
|||||||
.ok_or_else(|| anyhow::anyhow!("agent_task payload.prompt must be a string"))?;
|
.ok_or_else(|| anyhow::anyhow!("agent_task payload.prompt must be a string"))?;
|
||||||
let options = parse_scheduled_agent_task_options(job)?;
|
let options = parse_scheduled_agent_task_options(job)?;
|
||||||
|
|
||||||
session_manager
|
agent_task_executor
|
||||||
.run_scheduled_agent_task(channel_name, execution_chat_id, prompt, options)
|
.execute(channel_name, execution_chat_id, prompt, options)
|
||||||
.await
|
.await
|
||||||
.map_err(|error| anyhow::anyhow!(error.to_string()))
|
.map_err(|error| anyhow::anyhow!(error.to_string()))
|
||||||
}
|
}
|
||||||
@ -964,6 +967,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::bus::MessageBus;
|
use crate::bus::MessageBus;
|
||||||
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig};
|
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig};
|
||||||
|
use crate::gateway::agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
|
||||||
use crate::gateway::session::SessionManager;
|
use crate::gateway::session::SessionManager;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{SchedulerJobUpsert, SessionStore};
|
use crate::storage::{SchedulerJobUpsert, SessionStore};
|
||||||
@ -1002,6 +1006,14 @@ mod tests {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn test_scheduler_services() -> (AgentTaskExecutor, SchedulerMaintenanceService) {
|
||||||
|
let session_manager = test_session_manager();
|
||||||
|
(
|
||||||
|
AgentTaskExecutor::new(session_manager.clone()),
|
||||||
|
SchedulerMaintenanceService::new(session_manager),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn runtime_job_skip_policy_advances_from_now() {
|
fn runtime_job_skip_policy_advances_from_now() {
|
||||||
let now = Utc
|
let now = Utc
|
||||||
@ -1129,7 +1141,7 @@ mod tests {
|
|||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let session_manager = test_session_manager();
|
let (agent_task_executor, maintenance_service) = test_scheduler_services();
|
||||||
let scheduler = Scheduler::new(
|
let scheduler = Scheduler::new(
|
||||||
MessageBus::new(8),
|
MessageBus::new(8),
|
||||||
SchedulerConfig {
|
SchedulerConfig {
|
||||||
@ -1141,7 +1153,8 @@ mod tests {
|
|||||||
},
|
},
|
||||||
chrono_tz::Asia::Shanghai,
|
chrono_tz::Asia::Shanghai,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
session_manager,
|
agent_task_executor,
|
||||||
|
maintenance_service,
|
||||||
);
|
);
|
||||||
|
|
||||||
scheduler.process_tick().await.unwrap();
|
scheduler.process_tick().await.unwrap();
|
||||||
@ -1159,13 +1172,14 @@ mod tests {
|
|||||||
fn sync_config_jobs_persists_builtin_memory_maintenance_job() {
|
fn sync_config_jobs_persists_builtin_memory_maintenance_job() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
|
||||||
let session_manager = test_session_manager();
|
let (agent_task_executor, maintenance_service) = test_scheduler_services();
|
||||||
let scheduler = Scheduler::new(
|
let scheduler = Scheduler::new(
|
||||||
MessageBus::new(8),
|
MessageBus::new(8),
|
||||||
SchedulerConfig::default(),
|
SchedulerConfig::default(),
|
||||||
chrono_tz::Asia::Shanghai,
|
chrono_tz::Asia::Shanghai,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
session_manager,
|
agent_task_executor,
|
||||||
|
maintenance_service,
|
||||||
);
|
);
|
||||||
|
|
||||||
scheduler.sync_config_jobs().unwrap();
|
scheduler.sync_config_jobs().unwrap();
|
||||||
@ -1204,6 +1218,7 @@ mod tests {
|
|||||||
async fn silent_agent_task_failure_notifies_primary_chat() {
|
async fn silent_agent_task_failure_notifies_primary_chat() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let bus = MessageBus::new(8);
|
let bus = MessageBus::new(8);
|
||||||
|
let (agent_task_executor, maintenance_service) = test_scheduler_services();
|
||||||
let scheduler = Scheduler::new(
|
let scheduler = Scheduler::new(
|
||||||
bus.clone(),
|
bus.clone(),
|
||||||
SchedulerConfig {
|
SchedulerConfig {
|
||||||
@ -1215,7 +1230,8 @@ mod tests {
|
|||||||
},
|
},
|
||||||
chrono_tz::Asia::Shanghai,
|
chrono_tz::Asia::Shanghai,
|
||||||
store,
|
store,
|
||||||
test_session_manager(),
|
agent_task_executor,
|
||||||
|
maintenance_service,
|
||||||
);
|
);
|
||||||
|
|
||||||
let job = RuntimeJob {
|
let job = RuntimeJob {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user