Compare commits

..

16 Commits

Author SHA1 Message Date
04fc2c0710 feat: 添加记忆维护的错误处理逻辑,优化传输错误的上下文信息
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 18:10:23 +08:00
891830779f feat: 重构存储逻辑,使用 ConversationRepository 和 PromptInjectionRepository 替代 SessionStore,优化会话和提示注入管理
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 15:55:27 +08:00
f48b132bb9 feat: 重构调度器存储逻辑,使用 SchedulerJobRepository 替代 SessionStore,添加更新调度作业运行时的方法 2026-04-28 15:40:50 +08:00
90e44950cb feat: 重构技能事件处理逻辑,移除 SkillEventSink,添加 SkillActivateTool 模块以优化技能激活流程
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 15:31:56 +08:00
396504dffb Refactor agent and storage components to introduce SkillProvider and repository patterns
- Introduced `SkillProvider` trait to abstract skill-related functionalities.
- Replaced `SkillRuntime` with `EmptySkillProvider` in `AgentLoop` for default behavior.
- Updated `AgentFactory` to accept `SkillProvider` instead of `SkillRuntime`.
- Created `SessionHistory` struct to manage chat histories and interactions.
- Added `MemoryRepository`, `SchedulerJobRepository`, and `SkillEventRepository` traits for better storage abstraction.
- Refactored tools to use new repository traits for memory and scheduler management.
- Cleaned up session management logic by consolidating chat history handling into `SessionHistory`.

Co-authored-by: Copilot <copilot@github.com>
2026-04-28 15:12:45 +08:00
6756a3d0ae feat: 添加 OutboundDispatcher 模块,重构消息分发逻辑,优化渠道消息处理
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 14:52:33 +08:00
c547b88a12 feat: 添加持久化技能事件处理逻辑,重构技能事件记录机制
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 14:49:31 +08:00
e5e2b37246 feat: 重构会话管理逻辑,添加多个服务以优化会话和任务调度
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 14:43:46 +08:00
acc8f63da0 feat: 添加 SessionLifecycleService 模块,重构会话管理逻辑以优化会话生命周期处理
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 14:32:14 +08:00
8f27bd2735 feat: 重构工具和协议模块,添加工具注册和会话管理逻辑,优化消息处理 2026-04-28 14:16:30 +08:00
af7860f2fd feat: 重构消息模块,添加 ContentBlock 和 ToolCall 结构,优化消息处理逻辑
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 14:04:59 +08:00
c65921b5e8 feat: 添加 AgentFactory 和 PromptInjector,重构工具注册逻辑以优化会话管理 2026-04-28 13:06:00 +08:00
008aba91ac feat: 重构调度器以使用 AgentTaskExecutor 和 SchedulerMaintenanceService
- 更新调度器,将 SessionManager 替换为 AgentTaskExecutor 和 SchedulerMaintenanceService。
- 修改作业执行逻辑,使用新服务处理代理任务和内部事件。
- 添加新的 CliChannel 以处理 CLI 连接,并包括适当的注册和注销逻辑。
- 引入 AgentTaskExecutor 和 SchedulerMaintenanceService,用于管理代理任务和会话维护。
- 实现聊天命令处理,用于重置会话上下文。
- 添加后台历史压缩功能,以优化会话存储。
- 创建实用函数,用于准备通过 WebSocket 通信的出站消息。
- 为新功能添加测试,并确保现有测试通过。

Co-authored-by: Copilot <copilot@github.com>
2026-04-28 12:55:30 +08:00
62b38eac73 feat: 添加 SessionFactory 模块,重构 SessionPool 以优化会话创建逻辑
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 12:08:34 +08:00
65bcf34b75 feat: 添加 CLI 会话服务和会话池,重构 SessionManager 以优化会话管理逻辑
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 11:55:55 +08:00
14476bb101 feat: 添加 context_window_tokens 配置,调整模型温度并重构消息执行逻辑 2026-04-28 11:45:36 +08:00
55 changed files with 3438 additions and 1842 deletions

View File

@ -18,7 +18,7 @@ PicoBot 是一个用 Rust 构建的多通道 Agent 网关。它把消息接入
PicoBot 的设计目标不是“只会聊天”的单进程 Bot而是一个可持续运行的 Agent 基础设施: PicoBot 的设计目标不是“只会聊天”的单进程 Bot而是一个可持续运行的 Agent 基础设施:
- 消息从不同渠道进入统一总线 - 消息从不同渠道进入统一总线
- SessionManager 负责会话路由、上下文恢复、工具执行和回复生成 - SessionManager 负责会话路由和运行时服务编排AgentExecutionService 负责上下文准备、AgentLoop 执行、结果持久化和回复生成
- SQLite 作为事实来源保存跨重启状态 - SQLite 作为事实来源保存跨重启状态
- Agent 在每轮推理时可以读取文件、执行命令、发 HTTP 请求、读写记忆、管理技能和调度任务 - Agent 在每轮推理时可以读取文件、执行命令、发 HTTP 请求、读写记忆、管理技能和调度任务
@ -30,20 +30,20 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot而是一个
1. Channel 接收外部消息 1. Channel 接收外部消息
2. MessageBus 将消息送入统一的 inbound 队列 2. MessageBus 将消息送入统一的 inbound 队列
3. Gateway 启动的 inbound processor 调用 SessionManager 处理消息 3. Gateway 启动的 InboundProcessor 调用 SessionManager 定位目标 Session
4. SessionManager 加载持久化历史、注入系统提示、运行 AgentLoop、执行工具调用 4. AgentExecutionService 准备上下文、运行 AgentLoop、执行工具调用并收集结果
5. 生成的 assistant / tool / system 消息写入 SQLite 5. 生成的 user / assistant / tool / system 消息按真实顺序写入 SQLite
6. OutboundDispatcher 将结果投递到目标通道 6. OutboundDispatcher 将结果投递到目标通道
主要模块如下: 主要模块如下:
- src/gateway网关入口、HTTP 健康检查、WebSocket 服务、Session 管理 - src/gateway网关入口、HTTP 健康检查、WebSocket 控制面、Session 池、CLI 会话服务、OutboundDispatcher 与 Agent 执行编排
- src/bus消息总线与消息结构定义 - src/bus消息总线队列与消息结构定义,不包含渠道投递逻辑
- src/agentAgentLoop 与上下文压缩器 - src/agentAgentLoop 与上下文压缩器
- src/providers不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic - src/providers不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic
- src/tools内置工具集合 - src/tools内置工具集合
- src/storageSQLite 持久化实现 - src/storageSQLite 持久化实现
- src/channels渠道适配层当前已有飞书通道 - src/channels渠道适配层当前已有 CLI 与飞书通道
- src/scheduler数据库驱动的计划任务调度器 - src/scheduler数据库驱动的计划任务调度器
- src/skills技能发现、加载与运行时管理 - src/skills技能发现、加载与运行时管理
- src/client / src/cli本地 CLI 客户端和交互命令 - src/client / src/cli本地 CLI 客户端和交互命令
@ -549,7 +549,8 @@ CLI 中已实现的交互命令包括:
"models": { "models": {
"default": { "default": {
"model_id": "<OPENAI_MODEL_NAME>", "model_id": "<OPENAI_MODEL_NAME>",
"temperature": 0.2 "temperature": 0.2,
"context_window_tokens": 128000
} }
}, },
"agents": { "agents": {
@ -631,11 +632,11 @@ PicoBot/
├── src/ ├── src/
│ ├── agent/ # AgentLoop、上下文压缩 │ ├── agent/ # AgentLoop、上下文压缩
│ ├── bus/ # 消息总线与消息结构 │ ├── bus/ # 消息总线与消息结构
│ ├── channels/ # 渠道适配 │ ├── channels/ # CLI / 飞书等渠道适配
│ ├── cli/ # CLI 输入命令 │ ├── cli/ # CLI 输入命令
│ ├── client/ # WebSocket CLI 客户端 │ ├── client/ # WebSocket CLI 客户端
│ ├── config/ # 配置解析 │ ├── config/ # 配置解析
│ ├── gateway/ # Gateway、SessionManager、WS/HTTP │ ├── gateway/ # Gateway、Session 编排、WS/HTTP 控制面
│ ├── providers/ # OpenAI / Anthropic Provider │ ├── providers/ # OpenAI / Anthropic Provider
│ ├── scheduler/ # 定时任务系统 │ ├── scheduler/ # 定时任务系统
│ ├── skills/ # 技能运行时 │ ├── skills/ # 技能运行时
@ -655,7 +656,7 @@ PicoBot/
建议维护时重点关注: 建议维护时重点关注:
- docs/PERSISTENCE.md持久化结构是否与代码一致 - docs/PERSISTENCE.md持久化结构是否与代码一致
- src/gateway/session.rs消息流、工具注册、记忆维护、会话恢复主逻辑 - src/gateway/session.rs会话状态、会话路由和运行时服务编排
- src/storage/mod.rsSQLite schema 变更 - src/storage/mod.rsSQLite schema 变更
- src/config/mod.rs配置项变更是否同步到 README - src/config/mod.rs配置项变更是否同步到 README

View File

@ -10,7 +10,8 @@
"models": { "models": {
"default": { "default": {
"model_id": "<OPENAI_MODEL_NAME>", "model_id": "<OPENAI_MODEL_NAME>",
"temperature": 0.2 "temperature": 0.7,
"context_window_tokens": 128000
} }
}, },
"agents": { "agents": {

View File

@ -1,13 +1,11 @@
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::bus::message::ContentBlock;
use crate::bus::message::ToolMessageState; use crate::bus::message::ToolMessageState;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::domain::messages::{ContentBlock, ToolCall};
use crate::observability::{ use crate::observability::{
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
}; };
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider}; use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::text::{char_count, take_prefix_chars, take_suffix_chars}; use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
use crate::tools::{ToolContext, ToolRegistry}; use crate::tools::{ToolContext, ToolRegistry};
use async_trait::async_trait; use async_trait::async_trait;
@ -297,9 +295,7 @@ pub struct AgentLoop {
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
provider: Box<dyn LLMProvider>, provider: Box<dyn LLMProvider>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>, skills: Arc<dyn SkillProvider>,
skill_event_store: Option<Arc<SessionStore>>,
skill_event_session_id: Option<String>,
tool_context: ToolContext, tool_context: ToolContext,
observer: Option<Arc<dyn Observer>>, observer: Option<Arc<dyn Observer>>,
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>, emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
@ -317,6 +313,19 @@ pub trait EmittedMessageHandler: Send + Sync + 'static {
async fn handle(&self, message: ChatMessage); async fn handle(&self, message: ChatMessage);
} }
pub trait SkillProvider: Send + Sync + 'static {
fn system_index_prompt(&self) -> Option<String>;
}
#[derive(Default)]
struct EmptySkillProvider;
impl SkillProvider for EmptySkillProvider {
fn system_index_prompt(&self) -> Option<String> {
None
}
}
impl AgentLoop { impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> { pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
@ -327,9 +336,7 @@ impl AgentLoop {
provider_config, provider_config,
provider, provider,
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
skills: Arc::new(SkillRuntime::default()), skills: Arc::new(EmptySkillProvider),
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(), tool_context: ToolContext::default(),
observer: None, observer: None,
emitted_message_handler: None, emitted_message_handler: None,
@ -349,9 +356,7 @@ impl AgentLoop {
provider_config, provider_config,
provider, provider,
tools, tools,
skills: Arc::new(SkillRuntime::default()), skills: Arc::new(EmptySkillProvider),
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(), tool_context: ToolContext::default(),
observer: None, observer: None,
emitted_message_handler: None, emitted_message_handler: None,
@ -359,10 +364,10 @@ impl AgentLoop {
}) })
} }
pub fn with_tools_and_skills( pub fn with_tools_and_skill_provider(
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>, skills: Arc<dyn SkillProvider>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config.clone()) let provider = create_provider(provider_config.clone())
@ -373,8 +378,6 @@ impl AgentLoop {
provider, provider,
tools, tools,
skills, skills,
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(), tool_context: ToolContext::default(),
observer: None, observer: None,
emitted_message_handler: None, emitted_message_handler: None,
@ -382,12 +385,6 @@ impl AgentLoop {
}) })
} }
pub fn with_skill_event_store(mut self, store: Arc<SessionStore>, session_id: String) -> Self {
self.skill_event_store = Some(store);
self.skill_event_session_id = Some(session_id);
self
}
pub fn with_tool_context(mut self, context: ToolContext) -> Self { pub fn with_tool_context(mut self, context: ToolContext) -> Self {
self.tool_context = context; self.tool_context = context;
self self
@ -443,10 +440,7 @@ impl AgentLoop {
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message)); messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
// Build request // Build request
let mut tool_defs = self.tools.get_definitions(); let tool_defs = self.tools.get_definitions();
if let Some(skill_tool) = self.skills.skill_tool_definition() {
tool_defs.push(skill_tool);
}
let tools = if tool_defs.is_empty() { let tools = if tool_defs.is_empty() {
None None
} else { } else {
@ -782,46 +776,6 @@ impl AgentLoop {
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome { async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let normalized_arguments = normalize_tool_arguments(&tool_call.arguments); let normalized_arguments = normalize_tool_arguments(&tool_call.arguments);
if tool_call.name == "skill_activate" {
let skill_name = match normalized_arguments.get("name").and_then(|v| v.as_str()) {
Some(name) if !name.trim().is_empty() => name,
_ => {
self.record_skill_event(
"activation_failed",
None,
serde_json::json!({
"reason": "missing_name",
"arguments": normalized_arguments,
}),
);
return ToolExecutionOutcome::failure(
"Error: Missing required parameter: name".to_string(),
Some("Missing required parameter: name".to_string()),
);
}
};
return match self.skills.activation_payload(skill_name) {
Ok(output) => {
if let Ok(payload) = self.skills.activation_event_payload(skill_name) {
self.record_skill_event("activated", Some(skill_name), payload);
}
ToolExecutionOutcome::success(output)
}
Err(err) => {
self.record_skill_event(
"activation_failed",
Some(skill_name),
serde_json::json!({
"reason": err,
"arguments": normalized_arguments,
}),
);
ToolExecutionOutcome::failure(format!("Error: {}", err), Some(err))
}
};
}
let tool = match self.tools.get(&tool_call.name) { let tool = match self.tools.get(&tool_call.name) {
Some(t) => t, Some(t) => t,
None => { None => {
@ -870,26 +824,6 @@ impl AgentLoop {
} }
} }
} }
fn record_skill_event(
&self,
event_type: &str,
skill_name: Option<&str>,
payload: serde_json::Value,
) {
let (Some(store), Some(session_id)) = (
self.skill_event_store.as_ref(),
self.skill_event_session_id.as_ref(),
) else {
return;
};
if let Err(err) =
store.append_skill_event(Some(session_id), event_type, skill_name, &payload)
{
tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event");
}
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -1,5 +1,7 @@
pub mod agent_loop; pub mod agent_loop;
pub mod context_compressor; pub mod context_compressor;
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler}; pub use agent_loop::{
AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler, SkillProvider,
};
pub use context_compressor::ContextCompressor; pub use context_compressor::ContextCompressor;

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use crate::providers::ToolCall; use crate::domain::messages::ToolCall;
pub const SYSTEM_CONTEXT_AGENT_PROMPT: &str = "agent_prompt"; pub const SYSTEM_CONTEXT_AGENT_PROMPT: &str = "agent_prompt";
pub const SYSTEM_CONTEXT_SCHEDULED_PROMPT: &str = "scheduled_system_prompt"; pub const SYSTEM_CONTEXT_SCHEDULED_PROMPT: &str = "scheduled_system_prompt";
@ -14,38 +14,6 @@ pub enum ToolMessageState {
PendingUserAction, PendingUserAction,
} }
// ============================================================================
// ContentBlock - Multimodal content representation (OpenAI-style)
// ============================================================================
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrlBlock },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrlBlock {
pub url: String,
}
impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self {
Self::Text {
text: content.into(),
}
}
pub fn image_url(url: impl Into<String>) -> Self {
Self::ImageUrl {
image_url: ImageUrlBlock { url: url.into() },
}
}
}
// ============================================================================ // ============================================================================
// MediaItem - Media metadata for messages // MediaItem - Media metadata for messages
// ============================================================================ // ============================================================================
@ -566,7 +534,7 @@ fn current_timestamp() -> i64 {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ChatMessage, OutboundEventKind, OutboundMessage, ToolMessageState}; use super::{ChatMessage, OutboundEventKind, OutboundMessage, ToolMessageState};
use crate::providers::ToolCall; use crate::domain::messages::ToolCall;
use serde_json::json; use serde_json::json;
use std::collections::HashMap; use std::collections::HashMap;

View File

@ -1,18 +1,16 @@
pub mod dispatcher;
pub mod message; pub mod message;
pub use dispatcher::OutboundDispatcher; pub use crate::domain::messages::ContentBlock;
pub use message::{ pub use message::{
ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage, ChatMessage, InboundMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_AGENT_PROMPT,
SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION, SYSTEM_CONTEXT_HISTORY_COMPACTION, SYSTEM_CONTEXT_SCHEDULED_PROMPT,
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
}; };
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{Mutex, mpsc}; use tokio::sync::{Mutex, mpsc};
// ============================================================================ // ============================================================================
// MessageBus - Async message queue for Channel <-> Agent communication // MessageBus - async inbound/outbound queues
// ============================================================================ // ============================================================================
pub struct MessageBus { pub struct MessageBus {
@ -35,7 +33,7 @@ impl MessageBus {
}) })
} }
/// Publish an inbound message (Channel -> Bus) /// Publish a message to the inbound queue
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> { pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, content_len = %msg.content.len(), media_count = %msg.media.len(), "Bus: publishing inbound message"); tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, content_len = %msg.content.len(), media_count = %msg.media.len(), "Bus: publishing inbound message");
@ -45,7 +43,7 @@ impl MessageBus {
.map_err(|_| BusError::Closed) .map_err(|_| BusError::Closed)
} }
/// Consume an inbound message (Agent -> Bus) /// Consume a message from the inbound queue
pub async fn consume_inbound(&self) -> InboundMessage { pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self let msg = self
.inbound_rx .inbound_rx
@ -59,7 +57,7 @@ impl MessageBus {
msg msg
} }
/// Publish an outbound message (Agent -> Bus) /// Publish a message to the outbound queue
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> { pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, chat_id = %msg.chat_id, content_len = %msg.content.len(), "Bus: publishing outbound message"); tracing::debug!(channel = %msg.channel, chat_id = %msg.chat_id, content_len = %msg.content.len(), "Bus: publishing outbound message");
@ -69,7 +67,7 @@ impl MessageBus {
.map_err(|_| BusError::Closed) .map_err(|_| BusError::Closed)
} }
/// Consume an outbound message (Dispatcher -> Bus) /// Consume an outbound message from the outbound queue
pub async fn consume_outbound(&self) -> OutboundMessage { pub async fn consume_outbound(&self) -> OutboundMessage {
self.outbound_rx self.outbound_rx
.lock() .lock()

View File

@ -43,7 +43,7 @@ pub trait Channel: Send + Sync + 'static {
/// Stop the channel /// Stop the channel
async fn stop(&self) -> Result<(), ChannelError>; async fn stop(&self) -> Result<(), ChannelError>;
/// Send a message to the channel (called by OutboundDispatcher) /// Send a message to the channel (called by gateway outbound dispatch)
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError>; async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError>;
/// Send a streaming delta (optional, for channels that support it) /// Send a streaming delta (optional, for channels that support it)

155
src/channels/cli.rs Normal file
View File

@ -0,0 +1,155 @@
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use crate::bus::{MessageBus, OutboundMessage};
use crate::protocol::WsOutbound;
use crate::protocol::ws_adapter::ws_outbound_from_outbound_message;
use super::base::{Channel, ChannelError};
#[derive(Clone)]
struct CliConnection {
connection_id: String,
sender: mpsc::Sender<WsOutbound>,
}
#[derive(Clone)]
pub struct CliChannel {
connections: Arc<RwLock<HashMap<String, CliConnection>>>,
}
impl CliChannel {
pub fn new() -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register_connection(
&self,
session_id: impl Into<String>,
connection_id: impl Into<String>,
sender: mpsc::Sender<WsOutbound>,
) {
let session_id = session_id.into();
let connection_id = connection_id.into();
let previous = self.connections.write().await.insert(
session_id.clone(),
CliConnection {
connection_id: connection_id.clone(),
sender,
},
);
if previous.is_some() {
tracing::info!(session_id = %session_id, connection_id = %connection_id, "CLI session sender replaced");
}
}
pub async fn unregister_connection(&self, connection_id: &str) {
self.connections
.write()
.await
.retain(|_, connection| connection.connection_id != connection_id);
}
}
impl Default for CliChannel {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Channel for CliChannel {
fn name(&self) -> &str {
"cli"
}
fn is_running(&self) -> bool {
true
}
async fn start(&self, _bus: Arc<MessageBus>) -> Result<(), ChannelError> {
Ok(())
}
async fn stop(&self) -> Result<(), ChannelError> {
self.connections.write().await.clear();
Ok(())
}
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let connection = self.connections.read().await.get(&msg.chat_id).cloned();
let Some(connection) = connection else {
return Err(ChannelError::SendError(format!(
"No active CLI connection for session {}",
msg.chat_id
)));
};
for outbound in ws_outbound_from_outbound_message(&msg) {
connection
.sender
.send(outbound)
.await
.map_err(|_| ChannelError::SendError("CLI websocket sender closed".to_string()))?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::OutboundMessage;
#[tokio::test]
async fn test_cli_channel_sends_to_registered_session() {
let channel = CliChannel::new();
let (sender, mut receiver) = mpsc::channel(4);
channel
.register_connection("session-1", "conn-1", sender)
.await;
channel
.send(OutboundMessage::assistant(
"cli",
"session-1",
"hello",
None,
HashMap::new(),
))
.await
.unwrap();
let outbound = receiver.recv().await.unwrap();
assert!(matches!(outbound, WsOutbound::AssistantResponse { .. }));
}
#[tokio::test]
async fn test_cli_channel_unregisters_connection_sessions() {
let channel = CliChannel::new();
let (sender, _receiver) = mpsc::channel(4);
channel
.register_connection("session-1", "conn-1", sender)
.await;
channel.unregister_connection("conn-1").await;
let error = channel
.send(OutboundMessage::assistant(
"cli",
"session-1",
"hello",
None,
HashMap::new(),
))
.await
.unwrap_err();
assert!(error.to_string().contains("No active CLI connection"));
}
}

View File

@ -4,6 +4,7 @@ use tokio::sync::RwLock;
use crate::bus::MessageBus; use crate::bus::MessageBus;
use crate::channels::base::{Channel, ChannelError}; use crate::channels::base::{Channel, ChannelError};
use crate::channels::cli::CliChannel;
use crate::channels::feishu::FeishuChannel; use crate::channels::feishu::FeishuChannel;
use crate::config::Config; use crate::config::Config;
@ -12,13 +13,19 @@ use crate::config::Config;
pub struct ChannelManager { pub struct ChannelManager {
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>, channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
cli_channel: Arc<CliChannel>,
} }
impl ChannelManager { impl ChannelManager {
pub fn new() -> Self { pub fn new() -> Self {
let cli_channel = Arc::new(CliChannel::new());
let mut channels: HashMap<String, Arc<dyn Channel + Send + Sync>> = HashMap::new();
channels.insert("cli".to_string(), cli_channel.clone());
Self { Self {
channels: Arc::new(RwLock::new(HashMap::new())), channels: Arc::new(RwLock::new(channels)),
bus: MessageBus::new(100), bus: MessageBus::new(100),
cli_channel,
} }
} }
@ -27,6 +34,10 @@ impl ChannelManager {
self.bus.clone() self.bus.clone()
} }
pub fn cli_channel(&self) -> Arc<CliChannel> {
self.cli_channel.clone()
}
/// Initialize all Channel instances from config /// Initialize all Channel instances from config
pub async fn init( pub async fn init(
&self, &self,

View File

@ -1,7 +1,9 @@
pub mod base; pub mod base;
pub mod cli;
pub mod feishu; pub mod feishu;
pub mod manager; pub mod manager;
pub use base::{Channel, ChannelError}; pub use base::{Channel, ChannelError};
pub use cli::CliChannel;
pub use feishu::FeishuChannel; pub use feishu::FeishuChannel;
pub use manager::ChannelManager; pub use manager::ChannelManager;

36
src/domain/messages.rs Normal file
View File

@ -0,0 +1,36 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrlBlock },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrlBlock {
pub url: String,
}
impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self {
Self::Text {
text: content.into(),
}
}
pub fn image_url(url: impl Into<String>) -> Self {
Self::ImageUrl {
image_url: ImageUrlBlock { url: url.into() },
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}

2
src/domain/mod.rs Normal file
View File

@ -0,0 +1,2 @@
pub mod messages;
pub mod tools;

15
src/domain/tools.rs Normal file
View File

@ -0,0 +1,15 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}

View File

@ -0,0 +1,45 @@
use std::sync::Arc;
use crate::agent::{AgentError, AgentLoop, SkillProvider};
use crate::config::LLMProviderConfig;
use crate::storage::persistent_session_id;
use crate::tools::{ToolContext, ToolRegistry};
#[derive(Clone)]
pub(crate) struct AgentFactory {
tools: Arc<ToolRegistry>,
skills: Arc<dyn SkillProvider>,
}
pub(crate) struct AgentBuildRequest<'a> {
pub(crate) channel_name: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) sender_id: Option<&'a str>,
pub(crate) message_id: Option<&'a str>,
pub(crate) provider_config: LLMProviderConfig,
}
impl AgentFactory {
pub(crate) fn new(tools: Arc<ToolRegistry>, skills: Arc<dyn SkillProvider>) -> Self {
Self { tools, skills }
}
pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
let session_id = persistent_session_id(request.channel_name, request.chat_id);
AgentLoop::with_tools_and_skill_provider(
request.provider_config,
self.tools.clone(),
self.skills.clone(),
)
.map(|agent| {
agent.with_tool_context(ToolContext {
channel_name: Some(request.channel_name.to_string()),
sender_id: request.sender_id.map(str::to_string),
chat_id: Some(request.chat_id.to_string()),
session_id: Some(session_id),
message_id: request.message_id.map(str::to_string),
message_seq: None,
})
})
}
}

View File

@ -0,0 +1,102 @@
use crate::agent::AgentError;
use crate::bus::OutboundMessage;
use crate::scheduler::{
AgentTaskExecutor as SchedulerAgentTaskExecutor, MaintenanceExecutor, MaintenanceRunSummary,
ScheduledAgentTaskOptions,
};
use async_trait::async_trait;
use super::memory_maintenance::MemoryMaintenanceScopeResult;
use super::session::SessionManager;
#[derive(Clone)]
pub struct AgentTaskExecutor {
session_manager: SessionManager,
}
impl AgentTaskExecutor {
pub fn new(session_manager: SessionManager) -> Self {
Self { session_manager }
}
async fn execute_agent_task(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> Result<Vec<OutboundMessage>, AgentError> {
self.session_manager
.run_scheduled_agent_task(channel_name, chat_id, prompt, options)
.await
}
}
#[async_trait]
impl SchedulerAgentTaskExecutor for AgentTaskExecutor {
async fn execute(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> anyhow::Result<Vec<OutboundMessage>> {
self.execute_agent_task(channel_name, chat_id, prompt, options)
.await
.map_err(|error| anyhow::anyhow!(error.to_string()))
}
}
#[derive(Clone)]
pub struct SchedulerMaintenanceService {
session_manager: SessionManager,
}
impl SchedulerMaintenanceService {
pub fn new(session_manager: SessionManager) -> Self {
Self { session_manager }
}
async fn cleanup_sessions(&self) -> usize {
self.session_manager.cleanup_expired_sessions().await
}
async fn run_memory_maintenance(
&self,
updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
self.session_manager
.run_memory_maintenance_for_all_scopes(updated_since)
.await
}
}
#[async_trait]
impl MaintenanceExecutor for SchedulerMaintenanceService {
async fn cleanup_expired_sessions(&self) -> usize {
self.cleanup_sessions().await
}
async fn run_memory_maintenance_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> anyhow::Result<Vec<MaintenanceRunSummary>> {
self.run_memory_maintenance(updated_since)
.await
.map(|results| {
results
.into_iter()
.map(|result| MaintenanceRunSummary {
scope_key: result.scope_key,
user_facts: result.output.user_facts.len(),
preferences: result.output.preferences.len(),
behavior_patterns: result.output.behavior_patterns.len(),
merges: result.output.merges.len(),
conflicts: result.output.conflicts.len(),
low_value: result.output.low_value_ids.len(),
})
.collect()
})
.map_err(|error| anyhow::anyhow!(error.to_string()))
}
}

View File

@ -0,0 +1,57 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::storage::{SessionRecord, SessionStore};
#[derive(Clone)]
pub(crate) struct CliSessionService {
store: Arc<SessionStore>,
}
impl CliSessionService {
pub(crate) fn new(store: Arc<SessionStore>) -> Self {
Self { store }
}
pub(crate) fn create(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
self.store
.create_cli_session(title)
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
pub(crate) fn get(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
self.store
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
}
pub(crate) fn list(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> {
self.store
.list_sessions("cli", include_archived)
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
}
pub(crate) fn rename(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
self.store
.rename_session(session_id, title)
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
}
pub(crate) fn archive(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.archive_session(session_id)
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
}
pub(crate) fn delete(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.delete_session(session_id)
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
}
pub(crate) fn clear_messages(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.clear_messages(session_id)
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
}
}

159
src/gateway/command.rs Normal file
View File

@ -0,0 +1,159 @@
use crate::agent::AgentError;
use super::session::Session;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InChatCommand {
FreshConversation,
}
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
match content.trim() {
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
_ => None,
}
}
pub(crate) fn handle_in_chat_command(
session: &mut Session,
chat_id: &str,
content: &str,
) -> Result<Option<String>, AgentError> {
match parse_in_chat_command(content) {
Some(InChatCommand::FreshConversation) => {
session.reset_chat_context(chat_id)?;
Ok(Some("Started a fresh conversation.".to_string()))
}
None => Ok(None),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::ToolRegistry;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
fn test_provider_config() -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: "test".to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_parse_in_chat_command_aliases() {
assert_eq!(
parse_in_chat_command("/new"),
Some(InChatCommand::FreshConversation)
);
assert_eq!(
parse_in_chat_command(" /reset \n"),
Some(InChatCommand::FreshConversation)
);
assert_eq!(parse_in_chat_command("/new planning"), None);
assert_eq!(parse_in_chat_command("please /reset"), None);
}
#[tokio::test]
async fn test_handle_in_chat_command_resets_active_history_only() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(ToolRegistry::new());
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
session
.append_persisted_message("chat-1", ChatMessage::user("hello"))
.unwrap();
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
.unwrap()
.unwrap();
assert_eq!(response, "Started a fresh conversation.");
assert!(session.get_history("chat-1").unwrap().is_empty());
assert!(
store
.load_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.is_empty()
);
assert_eq!(
store
.load_all_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.len(),
2,
);
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
}
#[tokio::test]
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(ToolRegistry::new());
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store,
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
session
.append_persisted_message("chat-1", ChatMessage::user("hello"))
.unwrap();
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
session
.ensure_agent_prompt_before_user_message("chat-1")
.unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
}
}

105
src/gateway/compaction.rs Normal file
View File

@ -0,0 +1,105 @@
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::agent::AgentError;
use super::session::Session;
pub(crate) async fn schedule_background_history_compaction(
session: Arc<Mutex<Session>>,
chat_id: impl Into<String>,
) -> Result<(), AgentError> {
let chat_id = chat_id.into();
let snapshot = {
let mut session_guard = session.lock().await;
let session_record = session_guard.ensure_persistent_session(&chat_id)?;
session_guard.ensure_chat_loaded(&chat_id)?;
let history = session_guard.get_or_create_history(&chat_id).clone();
let compressor = session_guard.compressor().clone();
if !compressor.should_compress(&history) {
return Ok(());
}
if !session_guard.try_start_background_compaction(&chat_id) {
return Ok(());
}
(
session_guard.store(),
session_guard.persistent_session_id(&chat_id),
session_record.reset_cutoff_seq,
session_record.message_count,
history,
compressor,
session_guard.provider_config().clone(),
)
};
let (
store,
session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
history,
compressor,
provider_config,
) = snapshot;
let session_for_task = session.clone();
let chat_id_for_task = chat_id.clone();
tokio::spawn(async move {
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction");
let compaction_result = compressor
.build_compaction_plan(&history, &provider_config)
.await;
let mut committed = false;
match compaction_result {
Ok(Some(plan)) => match store.compact_active_history(
&session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
&plan.preserved_system_messages,
&plan.summary_message,
&plan.preserved_messages,
) {
Ok(true) => {
committed = true;
tracing::info!(
chat_id = %chat_id_for_task,
snapshot_end_seq,
compressed_turns = plan.compressed_turns,
preserved_turns = plan.preserved_turns,
"Background history compaction committed"
);
}
Ok(false) => {
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot");
}
Err(error) => {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed");
}
},
Ok(None) => {
tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis");
}
Err(error) => {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed");
}
}
let mut session_guard = session_for_task.lock().await;
if committed {
if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction");
}
}
session_guard.finish_background_compaction(&chat_id_for_task);
});
Ok(())
}

View File

@ -1,13 +1,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use crate::agent::{AgentError, AgentProcessResult}; use crate::agent::{AgentError, AgentProcessResult, EmittedMessageHandler};
use crate::bus::message::ToolMessageState; use crate::bus::message::ToolMessageState;
use crate::bus::{ChatMessage, OutboundMessage}; use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDULED_PROMPT};
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use super::session::{Session, schedule_background_history_compaction}; use super::command::handle_in_chat_command;
use super::compaction::schedule_background_history_compaction;
use super::message_prepare::enrich_user_content_with_media_refs;
use super::session::Session;
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器否则不要调用任何定时任务管理工具像“每小时”、“每天”、“cron”、“定时”等词只应视为任务背景不应再解释为新的建任务请求。"; const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器否则不要调用任何定时任务管理工具像“每小时”、“每天”、“cron”、“定时”等词只应视为任务背景不应再解释为新的建任务请求。";
@ -24,19 +27,6 @@ pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>)
} }
} }
pub(crate) fn select_provider_config(
default_provider_config: &LLMProviderConfig,
provider_configs: &HashMap<String, LLMProviderConfig>,
agent_name: Option<&str>,
) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(default_provider_config.clone()),
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
pub(crate) struct AgentExecutionService { pub(crate) struct AgentExecutionService {
show_tool_results: bool, show_tool_results: bool,
} }
@ -56,6 +46,28 @@ pub(crate) struct FinalizedAgentResult {
pub(crate) should_schedule_compaction: bool, pub(crate) should_schedule_compaction: bool,
} }
pub(crate) struct MessageExecutionRequest<'a> {
pub(crate) session: Arc<Mutex<Session>>,
pub(crate) channel_name: &'a str,
pub(crate) sender_id: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) content: &'a str,
pub(crate) media: Vec<MediaItem>,
pub(crate) live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
}
pub(crate) struct ScheduledExecutionRequest<'a> {
pub(crate) session: Arc<Mutex<Session>>,
pub(crate) channel_name: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) prompt: &'a str,
pub(crate) sender_id: &'a str,
pub(crate) provider_config: LLMProviderConfig,
pub(crate) fresh_session: bool,
pub(crate) system_prompt: Option<&'a str>,
pub(crate) metadata: &'a HashMap<String, String>,
}
impl AgentExecutionService { impl AgentExecutionService {
pub(crate) fn new(show_tool_results: bool) -> Self { pub(crate) fn new(show_tool_results: bool) -> Self {
Self { show_tool_results } Self { show_tool_results }
@ -115,6 +127,136 @@ impl AgentExecutionService {
}) })
} }
pub(crate) async fn prepare_and_execute_message(
&self,
request: MessageExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let (history, agent, user_message) = {
let mut session_guard = request.session.lock().await;
session_guard.ensure_persistent_session(request.chat_id)?;
session_guard.ensure_chat_loaded(request.chat_id)?;
if let Some(command_response) =
handle_in_chat_command(&mut session_guard, request.chat_id, request.content)?
{
return Ok(vec![OutboundMessage::assistant(
request.channel_name.to_string(),
request.chat_id.to_string(),
command_response,
None,
HashMap::new(),
)]);
}
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
let media_refs: Vec<String> = request
.media
.iter()
.map(|media| media.path.clone())
.collect();
#[cfg(debug_assertions)]
if !media_refs.is_empty() {
tracing::debug!(media_count = %request.media.len(), media_refs = ?media_refs, "Adding user message with media");
}
let enriched_content =
enrich_user_content_with_media_refs(request.content, &media_refs)?;
let user_message = session_guard.create_user_message(&enriched_content, media_refs);
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
let history = session_guard.get_or_create_history(request.chat_id).clone();
session_guard.record_skill_offer(request.chat_id)?;
let mut agent = session_guard.create_agent(
request.chat_id,
Some(request.sender_id),
Some(&user_message.id),
)?;
if let Some(handler) = request.live_emitter.clone() {
agent = agent.with_emitted_message_handler(handler);
}
(history, agent, user_message)
};
let result = agent.process(history).await?;
let metadata = HashMap::new();
self.finalize_result_and_schedule_compaction(
request.session.clone(),
FinalizeAgentResultRequest {
channel_name: request.channel_name,
chat_id: request.chat_id,
user_message: &user_message,
result,
metadata: &metadata,
suppress_live_tool_calls: request.live_emitter.is_some(),
execution_kind: "message",
},
)
.await
}
pub(crate) async fn prepare_and_execute_scheduled_task(
&self,
request: ScheduledExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let (history, agent, user_message) = {
let mut session_guard = request.session.lock().await;
session_guard.ensure_persistent_session(request.chat_id)?;
if request.fresh_session {
session_guard.reset_chat_context(request.chat_id)?;
}
session_guard.ensure_chat_loaded(request.chat_id)?;
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
let scheduled_system_prompt =
compose_scheduled_task_system_prompt(request.system_prompt);
session_guard.append_persisted_message(
request.chat_id,
ChatMessage::system_with_context(
&scheduled_system_prompt,
Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
),
)?;
let user_message = session_guard.create_user_message(request.prompt, Vec::new());
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
let history = session_guard.get_or_create_history(request.chat_id).clone();
session_guard.record_skill_offer(request.chat_id)?;
let agent = session_guard.create_agent_with_provider_config(
request.chat_id,
Some(request.sender_id),
Some(&user_message.id),
request.provider_config.clone(),
)?;
(history, agent, user_message)
};
let result = agent.process(history).await?;
self.finalize_result_and_schedule_compaction(
request.session.clone(),
FinalizeAgentResultRequest {
channel_name: request.channel_name,
chat_id: request.chat_id,
user_message: &user_message,
result,
metadata: request.metadata,
suppress_live_tool_calls: false,
execution_kind: "scheduled_task",
},
)
.await
}
pub(crate) async fn finalize_result_and_schedule_compaction( pub(crate) async fn finalize_result_and_schedule_compaction(
&self, &self,
session: Arc<Mutex<Session>>, session: Arc<Mutex<Session>>,
@ -170,50 +312,6 @@ mod tests {
use super::*; use super::*;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_provider_config_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]);
let selected =
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_select_provider_config_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::new();
let selected =
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
#[test] #[test]
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() { fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 ")); let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));

View File

@ -238,8 +238,18 @@ impl MemoryMaintenanceService {
let mut results = Vec::new(); let mut results = Vec::new();
for scope_key in scope_keys { for scope_key in scope_keys {
let Some(output) = self.run_for_scope(&scope_key).await? else { let output = match self.run_for_scope(&scope_key).await {
Ok(Some(output)) => output,
Ok(None) => continue,
Err(error) if is_recoverable_maintenance_scope_error(&error) => {
tracing::warn!(
scope_key = %scope_key,
error = %error,
"Memory maintenance skipped scope after recoverable model failure"
);
continue; continue;
}
Err(error) => return Err(error),
}; };
results.push(MemoryMaintenanceScopeResult { scope_key, output }); results.push(MemoryMaintenanceScopeResult { scope_key, output });
@ -319,6 +329,10 @@ pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|| normalized.contains("timeout") || normalized.contains("timeout")
} }
fn is_recoverable_maintenance_scope_error(error: &AgentError) -> bool {
is_recoverable_maintenance_llm_error(&error.to_string())
}
pub(crate) fn strip_json_code_fence(content: &str) -> &str { pub(crate) fn strip_json_code_fence(content: &str) -> &str {
let trimmed = content.trim(); let trimmed = content.trim();
if let Some(rest) = trimmed.strip_prefix("```json") { if let Some(rest) = trimmed.strip_prefix("```json") {

View File

@ -0,0 +1,46 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::storage::SessionStore;
use super::memory_maintenance::{
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
};
use super::provider_config_service::ProviderConfigService;
#[derive(Clone)]
pub(crate) struct MemoryMaintenanceCoordinator {
store: Arc<SessionStore>,
provider_configs: ProviderConfigService,
}
impl MemoryMaintenanceCoordinator {
pub(crate) fn new(store: Arc<SessionStore>, provider_configs: ProviderConfigService) -> Self {
Self {
store,
provider_configs,
}
}
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) async fn summarize_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
self.service()?.summarize_for_scope(scope_key).await
}
pub(crate) async fn run_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
self.service()?.run_for_all_scopes(updated_since).await
}
fn service(&self) -> Result<MemoryMaintenanceService, AgentError> {
Ok(MemoryMaintenanceService::new(
self.store.clone(),
self.provider_configs.default_provider_config(),
))
}
}

View File

@ -0,0 +1,39 @@
use crate::agent::AgentError;
pub(crate) fn enrich_user_content_with_media_refs(
content: &str,
media_refs: &[String],
) -> Result<String, AgentError> {
if media_refs.is_empty() {
return Ok(content.to_string());
}
let media_refs_json = serde_json::to_string(media_refs)
.map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?;
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
assert_eq!(
enriched,
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
);
}
#[test]
fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() {
let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap();
assert_eq!(enriched, "hello");
}
}

View File

@ -1,9 +1,26 @@
pub mod agent_factory;
pub mod agent_task_executor;
pub mod cli_session;
pub mod command;
pub mod compaction;
pub mod execution; pub mod execution;
pub mod http; pub mod http;
pub mod memory_maintenance; pub mod memory_maintenance;
pub mod memory_maintenance_coordinator;
pub mod message_prepare;
pub mod outbound_dispatcher;
pub mod processor; pub mod processor;
pub mod prompt; pub mod prompt;
pub mod prompt_injector;
pub mod provider_config_service;
pub mod scheduled_agent_task_service;
pub mod session; pub mod session;
pub mod session_factory;
pub mod session_history;
pub mod session_lifecycle;
pub mod session_message_service;
pub mod session_pool;
pub mod tool_registry_factory;
pub mod ws; pub mod ws;
use axum::{Router, routing}; use axum::{Router, routing};
@ -11,13 +28,15 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use crate::bus::{MessageBus, OutboundDispatcher}; use crate::bus::MessageBus;
use crate::channels::ChannelManager; use crate::channels::ChannelManager;
use crate::config::Config; use crate::config::Config;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::logging; use crate::logging;
use crate::scheduler::Scheduler; use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
use outbound_dispatcher::OutboundDispatcher;
use processor::InboundProcessor; use processor::InboundProcessor;
use session::SessionManager; use session::SessionManager;
@ -119,7 +138,8 @@ pub async fn run(
state.config.scheduler.clone(), state.config.scheduler.clone(),
timezone, timezone,
state.session_manager.store(), state.session_manager.store(),
state.session_manager.clone(), AgentTaskExecutor::new(state.session_manager.clone()),
SchedulerMaintenanceService::new(state.session_manager.clone()),
); );
tokio::spawn(async move { tokio::spawn(async move {

View File

@ -1,12 +1,12 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::bus::{MessageBus, OutboundMessage}; use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError}; use crate::channels::base::{Channel, ChannelError};
/// OutboundDispatcher consumes outbound messages from the MessageBus /// Consumes outbound messages from MessageBus and dispatches them to channels.
/// and dispatches them to the appropriate Channel
pub struct OutboundDispatcher { pub struct OutboundDispatcher {
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>, channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
@ -20,7 +20,6 @@ impl OutboundDispatcher {
} }
} }
/// Register a channel with the dispatcher
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) { pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
self.channels self.channels
.write() .write()
@ -28,7 +27,6 @@ impl OutboundDispatcher {
.insert(name.to_string(), channel); .insert(name.to_string(), channel);
} }
/// Run the dispatcher loop - consumes from bus and dispatches to channels
pub async fn run(&self) { pub async fn run(&self) {
tracing::info!("OutboundDispatcher started"); tracing::info!("OutboundDispatcher started");
@ -47,8 +45,8 @@ impl OutboundDispatcher {
match channel { match channel {
Some(ch) => { Some(ch) => {
if let Err(e) = self.send_with_retry(&*ch, msg).await { if let Err(error) = self.send_with_retry(&*ch, msg).await {
tracing::error!(channel = %channel_name, error = %e, "Failed to send message after retries"); tracing::error!(channel = %channel_name, error = %error, "Failed to send message after retries");
} }
} }
None => { None => {
@ -58,7 +56,6 @@ impl OutboundDispatcher {
} }
} }
/// Send a message with exponential retry
async fn send_with_retry( async fn send_with_retry(
&self, &self,
channel: &dyn Channel, channel: &dyn Channel,
@ -66,21 +63,22 @@ impl OutboundDispatcher {
) -> Result<(), ChannelError> { ) -> Result<(), ChannelError> {
const DELAYS: [u64; 3] = [1, 2, 4]; const DELAYS: [u64; 3] = [1, 2, 4];
for (i, delay) in DELAYS.iter().enumerate() { for (attempt_index, delay) in DELAYS.iter().enumerate() {
match channel.send(msg.clone()).await { match channel.send(msg.clone()).await {
Ok(()) => return Ok(()), Ok(()) => return Ok(()),
Err(e) if i < DELAYS.len() - 1 => { Err(error) if attempt_index < DELAYS.len() - 1 => {
tracing::warn!( tracing::warn!(
attempt = i + 1, attempt = attempt_index + 1,
delay = delay, delay = delay,
error = %e, error = %error,
"Send failed, retrying" "Send failed, retrying"
); );
tokio::time::sleep(tokio::time::Duration::from_secs(*delay)).await; tokio::time::sleep(tokio::time::Duration::from_secs(*delay)).await;
} }
Err(e) => return Err(e), Err(error) => return Err(error),
} }
} }
unreachable!() unreachable!()
} }
} }

View File

@ -1,6 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use crate::bus::MessageBus; use crate::bus::{MessageBus, OutboundMessage};
use super::session::{BusToolCallEmitter, SessionManager}; use super::session::{BusToolCallEmitter, SessionManager};
@ -70,6 +70,21 @@ impl InboundProcessor {
} }
Err(error) => { Err(error) => {
tracing::error!(error = %error, "Failed to handle message"); tracing::error!(error = %error, "Failed to handle message");
let mut metadata = inbound.forwarded_metadata.clone();
metadata.insert("error_kind".to_string(), "agent_execution".to_string());
if let Err(publish_error) = self
.bus
.publish_outbound(OutboundMessage::error_notification(
inbound.channel,
inbound.chat_id,
error.to_string(),
None,
metadata,
))
.await
{
tracing::error!(error = %publish_error, "Failed to publish execution error outbound");
}
} }
} }
} }

View File

@ -0,0 +1,86 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
use crate::storage::PromptInjectionRepository;
use super::prompt::load_agent_prompt;
#[derive(Clone)]
pub(crate) struct PromptInjector {
repository: Arc<dyn PromptInjectionRepository>,
reinject_every: i64,
}
impl PromptInjector {
pub(crate) fn new(repository: Arc<dyn PromptInjectionRepository>, reinject_every: u64) -> Self {
Self {
repository,
reinject_every: reinject_every as i64,
}
}
pub(crate) fn ensure_initial_prompt<F>(
&self,
history_is_empty: bool,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
if !history_is_empty {
return Ok(());
}
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
}
Ok(())
}
pub(crate) fn ensure_reinjected_prompt<F>(
&self,
session_id: &str,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
let session_record = self
.repository
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns = self
.repository
.count_active_user_messages(session_id)
.map_err(|err| {
AgentError::Other(format!("count active user messages error: {}", err))
})?;
if self.reinject_every > 0
&& active_user_turns > 0
&& active_user_turns / self.reinject_every
> session_record.agent_prompt_reinjection_count
{
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
self.repository
.mark_agent_prompt_reinjected(session_id)
.map_err(|err| {
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
})?;
}
}
Ok(())
}
fn agent_prompt_message(agent_prompt: String) -> ChatMessage {
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
)
}
}

View File

@ -0,0 +1,90 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
#[derive(Clone)]
pub(crate) struct ProviderConfigService {
default_provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
}
impl ProviderConfigService {
pub(crate) fn new(
default_provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
) -> Self {
Self {
default_provider_config,
provider_configs: Arc::new(provider_configs),
}
}
pub(crate) fn select(&self, agent_name: Option<&str>) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(self.default_provider_config.clone()),
Some(agent_name) => self
.provider_configs
.get(agent_name)
.cloned()
.ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
self.default_provider_config.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(
default_provider,
HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]),
);
let selected = service.select(Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_select_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(default_provider, HashMap::new());
let selected = service.select(Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
}

View File

@ -0,0 +1,57 @@
use crate::agent::AgentError;
use crate::bus::OutboundMessage;
use crate::scheduler::ScheduledAgentTaskOptions;
use super::execution::{AgentExecutionService, ScheduledExecutionRequest};
use super::provider_config_service::ProviderConfigService;
use super::session_lifecycle::SessionLifecycleService;
#[derive(Clone)]
pub(crate) struct ScheduledAgentTaskService {
lifecycle: SessionLifecycleService,
provider_configs: ProviderConfigService,
show_tool_results: bool,
}
impl ScheduledAgentTaskService {
pub(crate) fn new(
lifecycle: SessionLifecycleService,
provider_configs: ProviderConfigService,
show_tool_results: bool,
) -> Self {
Self {
lifecycle,
provider_configs,
show_tool_results,
}
}
pub(crate) async fn run(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> Result<Vec<OutboundMessage>, AgentError> {
let session = self.lifecycle.active_session(channel_name).await?;
let sender_id = options
.sender_id
.clone()
.unwrap_or_else(|| "scheduler".to_string());
let provider_config = self.provider_configs.select(options.agent.as_deref())?;
AgentExecutionService::new(self.show_tool_results)
.prepare_and_execute_scheduled_task(ScheduledExecutionRequest {
session,
channel_name,
chat_id,
prompt,
sender_id: &sender_id,
provider_config,
fresh_session: options.fresh_session,
system_prompt: options.system_prompt.as_deref(),
metadata: &options.metadata,
})
.await
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,61 @@
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, SkillEventRepository};
use super::agent_factory::AgentFactory;
use super::prompt_injector::PromptInjector;
use super::session::Session;
#[derive(Clone)]
pub(crate) struct SessionFactory {
provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
}
impl SessionFactory {
pub(crate) fn new(
provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Self {
Self {
provider_config,
skills,
agent_factory,
prompt_injector,
conversations,
skill_events,
}
}
pub(crate) async fn create(
&self,
channel_name: impl Into<String>,
user_tx: mpsc::Sender<WsOutbound>,
) -> Result<Session, AgentError> {
Session::with_factories(
channel_name.into(),
self.provider_config.clone(),
user_tx,
self.skills.clone(),
self.agent_factory.clone(),
self.prompt_injector.clone(),
self.conversations.clone(),
self.skill_events.clone(),
)
.await
}
}

View File

@ -0,0 +1,267 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::ChatMessage;
use crate::storage::{
ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id,
};
use super::prompt_injector::PromptInjector;
fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>();
if content.chars().count() > max_chars {
preview.push_str("...");
}
preview.replace('\n', "\\n")
}
pub(crate) struct SessionHistory {
channel_name: String,
chat_histories: HashMap<String, Vec<ChatMessage>>,
compression_in_flight: HashSet<String>,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
}
impl SessionHistory {
pub(crate) fn new(
channel_name: impl Into<String>,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Self {
Self {
channel_name: channel_name.into(),
chat_histories: HashMap::new(),
compression_in_flight: HashSet::new(),
prompt_injector,
conversations,
skill_events,
}
}
pub(crate) fn persistent_session_id(&self, chat_id: &str) -> String {
persistent_session_id(&self.channel_name, chat_id)
}
pub(crate) fn ensure_persistent_session(
&self,
chat_id: &str,
) -> Result<SessionRecord, AgentError> {
self.conversations
.ensure_channel_session(&self.channel_name, chat_id)
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
}
pub(crate) fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
if self.chat_histories.contains_key(chat_id) {
return self.ensure_initial_agent_prompt(chat_id);
}
let history = self
.conversations
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
self.ensure_initial_agent_prompt(chat_id)?;
Ok(())
}
pub(crate) fn ensure_agent_prompt_before_user_message(
&mut self,
chat_id: &str,
) -> Result<(), AgentError> {
self.ensure_chat_loaded(chat_id)?;
let session_id = self.persistent_session_id(chat_id);
let prompt_injector = self.prompt_injector.clone();
prompt_injector.ensure_reinjected_prompt(&session_id, |message| {
self.append_persisted_message(chat_id, message)
})
}
pub(crate) fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
self.chat_histories.entry(chat_id.to_string()).or_default()
}
pub(crate) fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
self.chat_histories.get(chat_id)
}
pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
self.get_or_create_history(chat_id).push(message);
}
pub(crate) fn remove_history(&mut self, chat_id: &str) {
self.chat_histories.remove(chat_id);
self.compression_in_flight.remove(chat_id);
}
pub(crate) fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
let len = history.len();
history.clear();
#[cfg(debug_assertions)]
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
}
self.conversations
.clear_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
}
pub(crate) fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
let len = history.len();
history.clear();
#[cfg(debug_assertions)]
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
}
self.conversations
.reset_session(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
}
pub(crate) fn append_persisted_message(
&mut self,
chat_id: &str,
message: ChatMessage,
) -> Result<(), AgentError> {
let session_id = self.persistent_session_id(chat_id);
self.conversations
.append_message(&session_id, &message)
.map_err(|err| {
AgentError::Other(format!("append message persistence error: {}", err))
})?;
self.add_message(chat_id, message);
Ok(())
}
pub(crate) fn append_persisted_messages<I>(
&mut self,
chat_id: &str,
messages: I,
) -> Result<(), AgentError>
where
I: IntoIterator<Item = ChatMessage>,
{
for message in messages {
self.append_persisted_message(chat_id, message)?;
}
Ok(())
}
pub(crate) fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
self.get_history(chat_id)
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
}
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
self.latest_user_message(chat_id)
.map(|current| {
current.id == message.id
|| (current.content == message.content
&& current.timestamp == message.timestamp
&& current.media_refs == message.media_refs)
})
.unwrap_or(false)
}
pub(crate) fn stale_result_diagnostics(
&self,
chat_id: &str,
) -> (Option<&str>, Option<String>, bool, usize) {
let latest_user = self.latest_user_message(chat_id);
let latest_user_id = latest_user.map(|message| message.id.as_str());
let latest_user_preview = latest_user.map(|message| preview_text(&message.content, 80));
let compression_in_flight = self.compression_in_flight.contains(chat_id);
let history_len = self
.get_history(chat_id)
.map(|history| history.len())
.unwrap_or(0);
(
latest_user_id,
latest_user_preview,
compression_in_flight,
history_len,
)
}
pub(crate) fn clear_all_history(&mut self) -> Result<(), AgentError> {
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
self.chat_histories.clear();
self.compression_in_flight.clear();
#[cfg(debug_assertions)]
tracing::debug!(previous_total = total, "All chat histories cleared");
for chat_id in chat_ids {
self.conversations
.clear_messages(&self.persistent_session_id(&chat_id))
.map_err(|err| {
AgentError::Other(format!("clear history persistence error: {}", err))
})?;
}
Ok(())
}
pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
self.compression_in_flight.insert(chat_id.to_string())
}
pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
self.compression_in_flight.remove(chat_id);
}
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history = self
.conversations
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
Ok(())
}
pub(crate) fn conversations(&self) -> Arc<dyn ConversationRepository> {
self.conversations.clone()
}
pub(crate) fn append_skill_event(
&self,
chat_id: &str,
event_type: &str,
skill_name: Option<&str>,
payload: &serde_json::Value,
) -> Result<(), AgentError> {
self.skill_events
.append_skill_event(
Some(&self.persistent_session_id(chat_id)),
event_type,
skill_name,
payload,
)
.map_err(|err| AgentError::Other(format!("append skill event error: {}", err)))
}
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history_is_empty = self
.get_history(chat_id)
.map(|history| history.is_empty())
.unwrap_or(true);
if !history_is_empty {
return Ok(());
}
let prompt_injector = self.prompt_injector.clone();
prompt_injector.ensure_initial_prompt(history_is_empty, |message| {
self.append_persisted_message(chat_id, message)
})
}
}

View File

@ -0,0 +1,49 @@
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::agent::AgentError;
use super::session::Session;
use super::session_factory::SessionFactory;
use super::session_pool::SessionPool;
#[derive(Clone)]
pub(crate) struct SessionLifecycleService {
session_pool: SessionPool,
}
impl SessionLifecycleService {
pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self {
Self {
session_pool: SessionPool::new(session_ttl_hours, session_factory),
}
}
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
self.session_pool.ensure_session(channel_name).await
}
pub(crate) async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
self.session_pool.get(channel_name).await
}
pub(crate) async fn touch(&self, channel_name: &str) {
self.session_pool.touch(channel_name).await;
}
pub(crate) async fn active_session(
&self,
channel_name: &str,
) -> Result<Arc<Mutex<Session>>, AgentError> {
self.ensure_session(channel_name).await?;
self.touch(channel_name).await;
self.get(channel_name)
.await
.ok_or_else(|| AgentError::Other("Session not found".to_string()))
}
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
self.session_pool.cleanup_expired_sessions().await
}
}

View File

@ -0,0 +1,69 @@
use std::sync::Arc;
use crate::agent::{AgentError, EmittedMessageHandler};
use crate::bus::{MediaItem, OutboundMessage};
use super::execution::{AgentExecutionService, MessageExecutionRequest};
use super::session_lifecycle::SessionLifecycleService;
#[derive(Clone)]
pub(crate) struct SessionMessageService {
lifecycle: SessionLifecycleService,
show_tool_results: bool,
}
impl SessionMessageService {
pub(crate) fn new(lifecycle: SessionLifecycleService, show_tool_results: bool) -> Self {
Self {
lifecycle,
show_tool_results,
}
}
pub(crate) async fn handle_message(
&self,
channel_name: &str,
sender_id: &str,
chat_id: &str,
content: &str,
media: Vec<MediaItem>,
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
) -> Result<Vec<OutboundMessage>, AgentError> {
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
content_len = content.len(),
media_count = %media.len(),
"Routing message to agent"
);
for (i, m) in media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
}
}
let session = self.lifecycle.active_session(channel_name).await?;
let outbound_messages = AgentExecutionService::new(self.show_tool_results)
.prepare_and_execute_message(MessageExecutionRequest {
session,
channel_name,
sender_id,
chat_id,
content,
media,
live_emitter,
})
.await?;
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
outbound_count = outbound_messages.len(),
"Agent response sequence received"
);
Ok(outbound_messages)
}
}

109
src/gateway/session_pool.rs Normal file
View File

@ -0,0 +1,109 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use crate::agent::AgentError;
use crate::protocol::WsOutbound;
use super::session::Session;
use super::session_factory::SessionFactory;
#[derive(Clone)]
pub(crate) struct SessionPool {
inner: Arc<Mutex<SessionPoolInner>>,
session_factory: SessionFactory,
}
struct SessionPoolInner {
sessions: HashMap<String, Arc<Mutex<Session>>>,
session_timestamps: HashMap<String, Instant>,
session_ttl: Duration,
}
impl SessionPool {
pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self {
Self {
inner: Arc::new(Mutex::new(SessionPoolInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
session_factory,
}
}
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await;
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name)
{
let elapsed = last_active.elapsed();
if elapsed > inner.session_ttl {
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
true
} else {
false
}
} else {
#[cfg(debug_assertions)]
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
if should_recreate {
inner.sessions.remove(channel_name);
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = self
.session_factory
.create(channel_name.to_string(), user_tx)
.await?;
inner
.sessions
.insert(channel_name.to_string(), Arc::new(Mutex::new(session)));
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
Ok(())
}
pub(crate) async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
self.inner.lock().await.sessions.get(channel_name).cloned()
}
pub(crate) async fn touch(&self, channel_name: &str) {
self.inner
.lock()
.await
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
let mut inner = self.inner.lock().await;
let now = Instant::now();
let expired_channels: Vec<String> = inner
.session_timestamps
.iter()
.filter_map(|(channel_name, last_active)| {
if now.duration_since(*last_active) > inner.session_ttl {
Some(channel_name.clone())
} else {
None
}
})
.collect();
for channel_name in &expired_channels {
inner.sessions.remove(channel_name);
inner.session_timestamps.remove(channel_name);
}
expired_channels.len()
}
}

View File

@ -0,0 +1,63 @@
use std::collections::HashSet;
use std::sync::Arc;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillActivateTool, SkillListTool,
SkillManageTool, TimeTool, ToolRegistry, WebFetchTool,
};
pub(crate) struct ToolRegistryFactory {
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
known_agents: HashSet<String>,
default_timezone: String,
}
impl ToolRegistryFactory {
pub(crate) fn new(
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
known_agents: HashSet<String>,
default_timezone: String,
) -> Self {
Self {
skills,
store,
known_agents,
default_timezone,
}
}
pub(crate) fn build(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(TimeTool::new(self.default_timezone.clone()));
registry.register(FileReadTool::new());
registry.register(FileWriteTool::new());
registry.register(FileEditTool::new());
registry.register(MemorySearchTool::new(self.store.clone()));
registry.register(MemoryManageTool::new(self.store.clone()));
registry.register(SchedulerManageTool::new(
self.store.clone(),
self.known_agents.clone(),
));
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.store.clone(),
));
registry.register(SkillListTool::new(self.skills.clone()));
registry.register(SkillManageTool::new(self.skills.clone()));
registry.register(BashTool::new());
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
registry.register(WebFetchTool::new(50_000, 30));
registry
}
}

View File

@ -1,36 +1,16 @@
use super::{ use super::GatewayState;
GatewayState, use crate::agent::AgentError;
session::{Session, handle_in_chat_command, schedule_background_history_compaction}, use crate::bus::InboundMessage;
};
use crate::agent::EmittedMessageHandler;
use crate::bus::ChatMessage;
use crate::bus::message::{ToolMessageState, format_tool_call_content};
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
use async_trait::async_trait;
use axum::extract::State; use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response; use axum::response::Response;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{Mutex, mpsc}; use tokio::sync::mpsc;
struct WsToolCallEmitter { const CLI_CHANNEL_NAME: &str = "cli";
sender: mpsc::Sender<WsOutbound>,
show_tool_results: bool,
}
#[async_trait]
impl EmittedMessageHandler for WsToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
if !should_display_message_to_user(self.show_tool_results, &message) {
return;
}
for outbound in ws_outbound_from_chat_message(&message) {
let _ = self.sender.send(outbound).await;
}
}
}
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response { pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async { ws.on_upgrade(|socket| async {
@ -41,15 +21,8 @@ pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewaySta
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) { async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100); let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let provider_config = match state.config.get_provider_config("default") { let cli_sessions = state.session_manager.cli_sessions();
Ok(cfg) => cfg, let initial_record = match cli_sessions.create(None) {
Err(e) => {
tracing::error!(error = %e, "Failed to get provider config");
return;
}
};
let initial_record = match state.session_manager.create_cli_session(None) {
Ok(record) => record, Ok(record) => record,
Err(e) => { Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session"); tracing::error!(error = %e, "Failed to create initial CLI session");
@ -57,39 +30,20 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
} }
}; };
let channel_name = "cli".to_string(); let runtime_session_id = uuid::Uuid::new_v4().to_string();
// 创建 CLI session
let session = match Session::new(
channel_name.clone(),
provider_config,
sender,
state.session_manager.tools(),
state.session_manager.skills(),
state.session_manager.store(),
state.config.gateway.agent_prompt_reinject_every,
)
.await
{
Ok(s) => Arc::new(Mutex::new(s)),
Err(e) => {
tracing::error!(error = %e, "Failed to create session");
return;
}
};
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
return;
}
let runtime_session_id = session.lock().await.id.to_string();
let mut current_session_id = initial_record.id.clone(); let mut current_session_id = initial_record.id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
current_session_id.clone(),
runtime_session_id.clone(),
sender.clone(),
)
.await;
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established"); tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
let _ = session let _ = sender
.lock()
.await
.send(WsOutbound::SessionEstablished { .send(WsOutbound::SessionEstablished {
session_id: current_session_id.clone(), session_id: current_session_id.clone(),
}) })
@ -119,7 +73,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
Ok(inbound) => { Ok(inbound) => {
if let Err(e) = handle_inbound( if let Err(e) = handle_inbound(
&state, &state,
&session, &sender,
&runtime_session_id, &runtime_session_id,
&mut current_session_id, &mut current_session_id,
inbound, inbound,
@ -127,9 +81,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
.await .await
{ {
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message"); tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = session let _ = sender
.lock()
.await
.send(WsOutbound::Error { .send(WsOutbound::Error {
code: "SESSION_ERROR".to_string(), code: "SESSION_ERROR".to_string(),
message: e.to_string(), message: e.to_string(),
@ -139,9 +91,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
} }
Err(e) => { Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message"); tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = session let _ = sender
.lock()
.await
.send(WsOutbound::Error { .send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(), code: "PARSE_ERROR".to_string(),
message: e.to_string(), message: e.to_string(),
@ -159,6 +109,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
} }
} }
state
.channel_manager
.cli_channel()
.unregister_connection(&runtime_session_id)
.await;
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended"); tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
} }
@ -174,79 +129,9 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
} }
} }
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
match message.role.as_str() {
"assistant" => {
if let Some(tool_calls) = &message.tool_calls {
let mut outbound = Vec::new();
if !message.content.trim().is_empty() {
outbound.push(WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
});
}
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
id: message.id.clone(),
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
role: message.role.clone(),
}));
outbound
} else {
vec![WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
}]
}
}
"tool" => match message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed)
{
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
}],
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
}],
},
_ => Vec::new(),
}
}
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
if message.role != "tool" {
return true;
}
show_tool_results
|| matches!(
message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed),
ToolMessageState::PendingUserAction
)
}
async fn handle_inbound( async fn handle_inbound(
state: &Arc<GatewayState>, state: &Arc<GatewayState>,
session: &Arc<Mutex<Session>>, sender: &mpsc::Sender<WsOutbound>,
runtime_session_id: &str, runtime_session_id: &str,
current_session_id: &mut String, current_session_id: &mut String,
inbound: WsInbound, inbound: WsInbound,
@ -260,84 +145,31 @@ async fn handle_inbound(
} => { } => {
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id); let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
let (history, agent, user_tx) = {
let mut session_guard = session.lock().await;
session_guard.ensure_persistent_session(&chat_id)?; state
session_guard.ensure_chat_loaded(&chat_id)?; .channel_manager
.cli_channel()
if let Some(command_response) = .register_connection(
handle_in_chat_command(&mut session_guard, &chat_id, &content)? chat_id.clone(),
{ runtime_session_id.to_string(),
let _ = session_guard sender.clone(),
.send(WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
content: command_response,
role: "assistant".to_string(),
})
.await;
return Ok(());
}
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?;
let history = session_guard.get_or_create_history(&chat_id).clone();
session_guard.record_skill_offer(&chat_id)?;
let live_emitter = Arc::new(WsToolCallEmitter {
sender: session_guard.user_tx.clone(),
show_tool_results: state.config.gateway.show_tool_results,
});
let agent = session_guard
.create_agent(&chat_id, Some(&sender_id), Some(&user_message_id))?
.with_emitted_message_handler(live_emitter);
(history, agent, session_guard.user_tx.clone())
};
match agent.process(history).await {
Ok(result) => {
let mut session_guard = session.lock().await;
session_guard
.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
for outbound in result
.emitted_messages
.iter()
.filter(|message| {
!message.is_assistant_tool_call_message()
&& should_display_message_to_user(
state.config.gateway.show_tool_results,
message,
) )
})
.flat_map(ws_outbound_from_chat_message)
{
let _ = session_guard.send(outbound).await;
}
drop(session_guard);
if let Err(error) =
schedule_background_history_compaction(session.clone(), chat_id.clone())
.await
{
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
}
}
Err(error) => {
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
let _ = user_tx
.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: error.to_string(),
})
.await; .await;
}
} state
.bus
.publish_inbound(InboundMessage {
channel: CLI_CHANNEL_NAME.to_string(),
sender_id,
chat_id,
content,
timestamp: current_timestamp(),
media: Vec::new(),
metadata: HashMap::new(),
forwarded_metadata: HashMap::new(),
})
.await
.map_err(|error| AgentError::Other(error.to_string()))?;
Ok(()) Ok(())
} }
@ -348,22 +180,37 @@ async fn handle_inbound(
let target = session_id let target = session_id
.or(chat_id) .or(chat_id)
.unwrap_or_else(|| current_session_id.clone()); .unwrap_or_else(|| current_session_id.clone());
state.session_manager.clear_session_messages(&target)?; state
.session_manager
.cli_sessions()
.clear_messages(&target)?;
let mut session_guard = session.lock().await; if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
session_guard.remove_history(&target); session.lock().await.remove_history(&target);
let _ = session_guard }
let _ = sender
.send(WsOutbound::HistoryCleared { session_id: target }) .send(WsOutbound::HistoryCleared { session_id: target })
.await; .await;
Ok(()) Ok(())
} }
WsInbound::CreateSession { title } => { WsInbound::CreateSession { title } => {
let record = state.session_manager.create_cli_session(title.as_deref())?; let record = state
.session_manager
.cli_sessions()
.create(title.as_deref())?;
*current_session_id = record.id.clone(); *current_session_id = record.id.clone();
let mut session_guard = session.lock().await; state
session_guard.ensure_chat_loaded(&record.id)?; .channel_manager
let _ = session_guard .cli_channel()
.register_connection(
record.id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
let _ = sender
.send(WsOutbound::SessionCreated { .send(WsOutbound::SessionCreated {
session_id: record.id, session_id: record.id,
title: record.title, title: record.title,
@ -372,11 +219,13 @@ async fn handle_inbound(
Ok(()) Ok(())
} }
WsInbound::ListSessions { include_archived } => { WsInbound::ListSessions { include_archived } => {
let records = state.session_manager.list_cli_sessions(include_archived)?; let records = state
.session_manager
.cli_sessions()
.list(include_archived)?;
let summaries = records.into_iter().map(to_session_summary).collect(); let summaries = records.into_iter().map(to_session_summary).collect();
let session_guard = session.lock().await; let _ = sender
let _ = session_guard
.send(WsOutbound::SessionList { .send(WsOutbound::SessionList {
sessions: summaries, sessions: summaries,
current_session_id: Some(current_session_id.clone()), current_session_id: Some(current_session_id.clone()),
@ -385,9 +234,8 @@ async fn handle_inbound(
Ok(()) Ok(())
} }
WsInbound::LoadSession { session_id } => { WsInbound::LoadSession { session_id } => {
let Some(record) = state.session_manager.get_session_record(&session_id)? else { let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
let session_guard = session.lock().await; let _ = sender
let _ = session_guard
.send(WsOutbound::Error { .send(WsOutbound::Error {
code: "SESSION_NOT_FOUND".to_string(), code: "SESSION_NOT_FOUND".to_string(),
message: format!("Session not found: {}", session_id), message: format!("Session not found: {}", session_id),
@ -397,9 +245,16 @@ async fn handle_inbound(
}; };
*current_session_id = record.id.clone(); *current_session_id = record.id.clone();
let mut session_guard = session.lock().await; state
session_guard.ensure_chat_loaded(&record.id)?; .channel_manager
let _ = session_guard .cli_channel()
.register_connection(
record.id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
let _ = sender
.send(WsOutbound::SessionLoaded { .send(WsOutbound::SessionLoaded {
session_id: record.id, session_id: record.id,
title: record.title, title: record.title,
@ -410,9 +265,11 @@ async fn handle_inbound(
} }
WsInbound::RenameSession { session_id, title } => { WsInbound::RenameSession { session_id, title } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone()); let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.rename_session(&target, &title)?; state
let session_guard = session.lock().await; .session_manager
let _ = session_guard .cli_sessions()
.rename(&target, &title)?;
let _ = sender
.send(WsOutbound::SessionRenamed { .send(WsOutbound::SessionRenamed {
session_id: target, session_id: target,
title, title,
@ -422,26 +279,27 @@ async fn handle_inbound(
} }
WsInbound::ArchiveSession { session_id } => { WsInbound::ArchiveSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone()); let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.archive_session(&target)?; state.session_manager.cli_sessions().archive(&target)?;
let session_guard = session.lock().await; let _ = sender
let _ = session_guard
.send(WsOutbound::SessionArchived { session_id: target }) .send(WsOutbound::SessionArchived { session_id: target })
.await; .await;
Ok(()) Ok(())
} }
WsInbound::DeleteSession { session_id } => { WsInbound::DeleteSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone()); let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.delete_session(&target)?; state.session_manager.cli_sessions().delete(&target)?;
let replacement = if target == *current_session_id { let replacement = if target == *current_session_id {
Some(state.session_manager.create_cli_session(None)?) Some(state.session_manager.cli_sessions().create(None)?)
} else { } else {
None None
}; };
let mut session_guard = session.lock().await; if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
session_guard.remove_history(&target); session.lock().await.remove_history(&target);
let _ = session_guard }
let _ = sender
.send(WsOutbound::SessionDeleted { .send(WsOutbound::SessionDeleted {
session_id: target.clone(), session_id: target.clone(),
}) })
@ -449,8 +307,16 @@ async fn handle_inbound(
if let Some(record) = replacement { if let Some(record) = replacement {
*current_session_id = record.id.clone(); *current_session_id = record.id.clone();
session_guard.ensure_chat_loaded(&record.id)?; state
let _ = session_guard .channel_manager
.cli_channel()
.register_connection(
record.id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
let _ = sender
.send(WsOutbound::SessionCreated { .send(WsOutbound::SessionCreated {
session_id: record.id, session_id: record.id,
title: record.title, title: record.title,
@ -461,13 +327,19 @@ async fn handle_inbound(
Ok(()) Ok(())
} }
WsInbound::Ping => { WsInbound::Ping => {
let session_guard = session.lock().await; let _ = sender.send(WsOutbound::Pong).await;
let _ = session_guard.send(WsOutbound::Pong).await;
Ok(()) Ok(())
} }
} }
} }
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String { fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
sender_id sender_id
.map(str::trim) .map(str::trim)
@ -478,106 +350,7 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ use super::resolve_ws_sender_id;
WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
ws_outbound_from_chat_message,
};
use crate::agent::EmittedMessageHandler;
use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::protocol::WsOutbound;
use crate::providers::ToolCall;
use serde_json::json;
use tokio::sync::mpsc;
#[test]
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
let message = ChatMessage::assistant_with_tool_calls(
"",
vec![ToolCall {
id: "call-1".to_string(),
name: "calculator".to_string(),
arguments: json!({"expression": "1 + 1"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall {
tool_call_id,
tool_name,
arguments,
content,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
}
#[test]
fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
let message = ChatMessage::assistant_with_tool_calls(
"日报已整理完成。",
vec![ToolCall {
id: "call-1".to_string(),
name: "memory_manage".to_string(),
arguments: json!({"action": "put"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 2);
assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_results() {
let message = ChatMessage::tool("call-1", "calculator", "2");
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
let message = ChatMessage::tool_with_state(
"call-1",
"bash",
"等待你完成授权后再继续。",
ToolMessageState::PendingUserAction,
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
}
#[test]
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
let completed = ChatMessage::tool("call-1", "calculator", "2");
let pending = ChatMessage::tool_with_state(
"call-2",
"bash",
"waiting",
ToolMessageState::PendingUserAction,
);
assert!(!should_display_message_to_user(false, &completed));
assert!(should_display_message_to_user(false, &pending));
assert!(should_display_message_to_user(true, &completed));
}
#[test] #[test]
fn test_resolve_ws_sender_id_prefers_inbound_sender() { fn test_resolve_ws_sender_id_prefers_inbound_sender() {
@ -596,23 +369,4 @@ mod tests {
assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1"); assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
} }
#[tokio::test]
async fn test_ws_tool_call_emitter_hides_completed_tool_results_when_disabled() {
let (sender, mut receiver) = mpsc::channel(4);
let emitter = WsToolCallEmitter {
sender,
show_tool_results: false,
};
emitter
.handle(ChatMessage::tool("call-1", "calculator", "2"))
.await;
assert!(
tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
.await
.is_err()
);
}
} }

View File

@ -4,6 +4,7 @@ pub mod channels;
pub mod cli; pub mod cli;
pub mod client; pub mod client;
pub mod config; pub mod config;
pub mod domain;
pub mod gateway; pub mod gateway;
pub mod logging; pub mod logging;
pub mod observability; pub mod observability;

View File

@ -1,3 +1,5 @@
pub mod ws_adapter;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

230
src/protocol/ws_adapter.rs Normal file
View File

@ -0,0 +1,230 @@
#[cfg(test)]
use crate::bus::ChatMessage;
use crate::bus::OutboundMessage;
use crate::bus::message::OutboundEventKind;
#[cfg(test)]
use crate::bus::message::{ToolMessageState, format_tool_call_content};
use super::WsOutbound;
const TOOL_PENDING_RESUME_HINT: &str = "完成外部操作后,直接发一条继续消息即可。";
#[cfg(test)]
pub(crate) fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
match message.role.as_str() {
"assistant" => {
if let Some(tool_calls) = &message.tool_calls {
let mut outbound = Vec::new();
if !message.content.trim().is_empty() {
outbound.push(WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
});
}
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
id: message.id.clone(),
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
role: message.role.clone(),
}));
outbound
} else {
vec![WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
}]
}
}
"tool" => match message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed)
{
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
}],
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
}],
},
_ => Vec::new(),
}
}
pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec<WsOutbound> {
match message.event_kind {
OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => {
vec![WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
content: message.content.clone(),
role: message.role.clone(),
}]
}
OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
arguments: message
.tool_arguments
.clone()
.unwrap_or(serde_json::Value::Null),
content: message.content.clone(),
role: message.role.clone(),
}],
OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
}],
OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
}],
OutboundEventKind::ErrorNotification => vec![WsOutbound::Error {
code: "AGENT_ERROR".to_string(),
message: message.content.clone(),
}],
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::messages::ToolCall;
use serde_json::json;
#[test]
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
let message = ChatMessage::assistant_with_tool_calls(
"",
vec![ToolCall {
id: "call-1".to_string(),
name: "calculator".to_string(),
arguments: json!({"expression": "1 + 1"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall {
tool_call_id,
tool_name,
arguments,
content,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
}
#[test]
fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
let message = ChatMessage::assistant_with_tool_calls(
"日报已整理完成。",
vec![ToolCall {
id: "call-1".to_string(),
name: "memory_manage".to_string(),
arguments: json!({"action": "put"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 2);
assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_results() {
let message = ChatMessage::tool("call-1", "calculator", "2");
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
let message = ChatMessage::tool_with_state(
"call-1",
"bash",
"等待你完成授权后再继续。",
ToolMessageState::PendingUserAction,
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
}
#[test]
fn test_ws_outbound_from_outbound_message_maps_tool_call() {
let message = OutboundMessage::tool_call(
"cli",
"session-1",
"call-1",
"calculator",
json!({"expression": "1 + 1"}),
None,
Default::default(),
);
let outbound = ws_outbound_from_outbound_message(&message);
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall {
tool_call_id,
tool_name,
arguments,
content,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
}
}

View File

@ -6,7 +6,7 @@ use std::time::Duration;
use super::traits::Usage; use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::bus::message::ContentBlock; use crate::domain::messages::ContentBlock;
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
let mut details = vec![error.to_string()]; let mut details = vec![error.to_string()];

View File

@ -6,10 +6,9 @@ pub use self::anthropic::AnthropicProvider;
pub use self::openai::OpenAIProvider; pub use self::openai::OpenAIProvider;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
pub use traits::{ pub use crate::domain::messages::ToolCall;
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, pub use crate::domain::tools::{Tool, ToolFunction};
ToolFunction, Usage, pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Usage};
};
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> { pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
match config.provider_type.as_str() { match config.provider_type.as_str() {

View File

@ -7,7 +7,7 @@ use std::time::Duration;
use super::traits::Usage; use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::bus::message::ContentBlock; use crate::domain::messages::ContentBlock;
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
@ -23,6 +23,23 @@ fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
details.join("\ncaused by: ") details.join("\ncaused by: ")
} }
fn format_transport_error_context(
provider_name: &str,
model_id: &str,
url: &str,
timeout_secs: u64,
error: &(dyn std::error::Error + 'static),
) -> String {
format!(
"transport error: provider={} model={} url={} timeout_secs={} details={}",
provider_name,
model_id,
url,
timeout_secs,
format_error_chain(error)
)
}
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1 { if blocks.len() == 1 {
if let ContentBlock::Text { text } = &blocks[0] { if let ContentBlock::Text { text } = &blocks[0] {
@ -294,7 +311,25 @@ impl LLMProvider for OpenAIProvider {
req_builder = req_builder.header(key.as_str(), value.as_str()); req_builder = req_builder.header(key.as_str(), value.as_str());
} }
let resp = req_builder.json(&body).send().await?; let resp = req_builder.json(&body).send().await.map_err(|err| {
let error_context = format_transport_error_context(
&self.name,
&self.model_id,
&url,
self.llm_timeout_secs,
&err,
);
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
base_url = %self.base_url,
timeout_secs = self.llm_timeout_secs,
error = %error_context,
"OpenAI-compatible API transport request failed"
);
error_context
})?;
let status = resp.status(); let status = resp.status();
let text = resp.text().await?; let text = resp.text().await?;

View File

@ -1,4 +1,5 @@
use crate::bus::message::ContentBlock; use crate::domain::messages::{ContentBlock, ToolCall};
use crate::domain::tools::Tool;
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -77,27 +78,6 @@ impl Message {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest { pub struct ChatCompletionRequest {
pub messages: Vec<Message>, pub messages: Vec<Message>,

View File

@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Duration as ChronoDuration, TimeZone, Utc}; use chrono::{DateTime, Duration as ChronoDuration, TimeZone, Utc};
use chrono_tz::Tz; use chrono_tz::Tz;
use tokio::sync::watch; use tokio::sync::watch;
@ -11,34 +12,81 @@ use crate::config::{
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget, SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget,
SchedulerMisfirePolicy, SchedulerSchedule, SchedulerMisfirePolicy, SchedulerSchedule,
}; };
use crate::gateway::session::ScheduledAgentTaskOptions;
use crate::gateway::session::SessionManager;
use crate::storage::{ use crate::storage::{
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore, SchedulerJobRecord, SchedulerJobRepository, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert,
}; };
#[derive(Debug, Clone, Default)]
pub struct ScheduledAgentTaskOptions {
pub sender_id: Option<String>,
pub fresh_session: bool,
pub system_prompt: Option<String>,
pub metadata: HashMap<String, String>,
pub agent: Option<String>,
}
#[derive(Debug, Clone)]
pub struct MaintenanceRunSummary {
pub scope_key: String,
pub user_facts: usize,
pub preferences: usize,
pub behavior_patterns: usize,
pub merges: usize,
pub conflicts: usize,
pub low_value: usize,
}
#[async_trait]
pub trait AgentTaskExecutor: Send + Sync {
async fn execute(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> anyhow::Result<Vec<OutboundMessage>>;
}
#[async_trait]
pub trait MaintenanceExecutor: Send + Sync {
async fn cleanup_expired_sessions(&self) -> usize;
async fn run_memory_maintenance_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> anyhow::Result<Vec<MaintenanceRunSummary>>;
}
pub struct Scheduler { pub struct Scheduler {
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
config: SchedulerConfig, config: SchedulerConfig,
timezone: Tz, timezone: Tz,
store: Arc<SessionStore>, jobs: Arc<dyn SchedulerJobRepository>,
session_manager: SessionManager, agent_task_executor: Arc<dyn AgentTaskExecutor>,
maintenance_executor: Arc<dyn MaintenanceExecutor>,
} }
impl Scheduler { impl Scheduler {
pub fn new( pub fn new<A, M>(
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
config: SchedulerConfig, config: SchedulerConfig,
timezone: Tz, timezone: Tz,
store: Arc<SessionStore>, jobs: Arc<dyn SchedulerJobRepository>,
session_manager: SessionManager, agent_task_executor: A,
) -> Self { maintenance_executor: M,
) -> Self
where
A: AgentTaskExecutor + 'static,
M: MaintenanceExecutor + 'static,
{
Self { Self {
bus, bus,
config, config,
timezone, timezone,
store, jobs,
session_manager, agent_task_executor: Arc::new(agent_task_executor),
maintenance_executor: Arc::new(maintenance_executor),
} }
} }
@ -81,14 +129,14 @@ impl Scheduler {
}) { }) {
let runtime = let runtime =
RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?; RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?;
self.store.upsert_scheduler_job(&runtime.to_upsert())?; self.jobs.upsert_scheduler_job(&runtime.to_upsert())?;
} }
Ok(()) Ok(())
} }
async fn process_tick(&self) -> anyhow::Result<()> { async fn process_tick(&self) -> anyhow::Result<()> {
let now = Utc::now(); let now = Utc::now();
let jobs = self.store.list_scheduler_jobs(true)?; let jobs = self.jobs.list_scheduler_jobs(true)?;
for record in jobs { for record in jobs {
let Some(mut job) = let Some(mut job) =
@ -98,7 +146,7 @@ impl Scheduler {
}; };
if record.next_fire_at.is_none() && job.next_fire_at.is_some() { if record.next_fire_at.is_none() && job.next_fire_at.is_some() {
self.store.update_scheduler_job_runtime( self.jobs.update_scheduler_job_runtime(
&job.id, &job.id,
job.state.clone(), job.state.clone(),
job.last_status.clone(), job.last_status.clone(),
@ -115,7 +163,7 @@ impl Scheduler {
continue; continue;
} }
self.store.update_scheduler_job_runtime( self.jobs.update_scheduler_job_runtime(
&job.id, &job.id,
SchedulerJobState::Running, SchedulerJobState::Running,
job.last_status.clone(), job.last_status.clone(),
@ -145,7 +193,7 @@ impl Scheduler {
tracing::error!(job_id = %job.id, error = %error, "Scheduler job failed"); tracing::error!(job_id = %job.id, error = %error, "Scheduler job failed");
} }
self.store.update_scheduler_job_runtime( self.jobs.update_scheduler_job_runtime(
&job.id, &job.id,
job.state.clone(), job.state.clone(),
status, status,
@ -168,11 +216,11 @@ impl Scheduler {
self.bus.publish_outbound(message).await?; self.bus.publish_outbound(message).await?;
} }
SchedulerJobKind::InternalEvent => { SchedulerJobKind::InternalEvent => {
execute_internal_event(&self.session_manager, job).await?; execute_internal_event(self.maintenance_executor.as_ref(), job).await?;
} }
SchedulerJobKind::AgentTask => { SchedulerJobKind::AgentTask => {
let outbound_messages = execute_agent_task( let outbound_messages = execute_agent_task(
&self.session_manager, self.agent_task_executor.as_ref(),
job, job,
required_notification_chat_id(job, "agent_task")?, required_notification_chat_id(job, "agent_task")?,
) )
@ -184,7 +232,8 @@ impl Scheduler {
SchedulerJobKind::SilentAgentTask => { SchedulerJobKind::SilentAgentTask => {
let execution_chat_id = resolve_execution_chat_id(job)?; let execution_chat_id = resolve_execution_chat_id(job)?;
if let Err(error) = if let Err(error) =
execute_agent_task(&self.session_manager, job, &execution_chat_id).await execute_agent_task(self.agent_task_executor.as_ref(), job, &execution_chat_id)
.await
{ {
if let Err(notify_error) = if let Err(notify_error) =
self.notify_silent_agent_task_failure(job, &error).await self.notify_silent_agent_task_failure(job, &error).await
@ -587,7 +636,7 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
} }
async fn execute_internal_event( async fn execute_internal_event(
session_manager: &SessionManager, maintenance_executor: &dyn MaintenanceExecutor,
job: &RuntimeJob, job: &RuntimeJob,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let event = job let event = job
@ -598,24 +647,24 @@ async fn execute_internal_event(
match event { match event {
"session_cleanup" => { "session_cleanup" => {
let removed = session_manager.cleanup_expired_sessions().await; let removed = maintenance_executor.cleanup_expired_sessions().await;
tracing::info!(job_id = %job.id, removed, "Scheduler session cleanup completed"); tracing::info!(job_id = %job.id, removed, "Scheduler session cleanup completed");
Ok(()) Ok(())
} }
"memory_maintenance" => { "memory_maintenance" => {
let results = session_manager let results = maintenance_executor
.run_memory_maintenance_for_all_scopes(job.last_fired_at) .run_memory_maintenance_for_all_scopes(job.last_fired_at)
.await?; .await?;
for result in &results { for result in &results {
tracing::info!( tracing::info!(
job_id = %job.id, job_id = %job.id,
scope_key = %result.scope_key, scope_key = %result.scope_key,
user_facts = result.output.user_facts.len(), user_facts = result.user_facts,
preferences = result.output.preferences.len(), preferences = result.preferences,
behavior_patterns = result.output.behavior_patterns.len(), behavior_patterns = result.behavior_patterns,
merges = result.output.merges.len(), merges = result.merges,
conflicts = result.output.conflicts.len(), conflicts = result.conflicts,
low_value = result.output.low_value_ids.len(), low_value = result.low_value,
"Scheduler completed memory maintenance model run" "Scheduler completed memory maintenance model run"
); );
} }
@ -627,7 +676,7 @@ async fn execute_internal_event(
} }
async fn execute_agent_task( async fn execute_agent_task(
session_manager: &SessionManager, agent_task_executor: &dyn AgentTaskExecutor,
job: &RuntimeJob, job: &RuntimeJob,
execution_chat_id: &str, execution_chat_id: &str,
) -> anyhow::Result<Vec<OutboundMessage>> { ) -> anyhow::Result<Vec<OutboundMessage>> {
@ -643,10 +692,9 @@ async fn execute_agent_task(
.ok_or_else(|| anyhow::anyhow!("agent_task payload.prompt must be a string"))?; .ok_or_else(|| anyhow::anyhow!("agent_task payload.prompt must be a string"))?;
let options = parse_scheduled_agent_task_options(job)?; let options = parse_scheduled_agent_task_options(job)?;
session_manager agent_task_executor
.run_scheduled_agent_task(channel_name, execution_chat_id, prompt, options) .execute(channel_name, execution_chat_id, prompt, options)
.await .await
.map_err(|error| anyhow::anyhow!(error.to_string()))
} }
fn required_notification_chat_id<'a>( fn required_notification_chat_id<'a>(
@ -963,43 +1011,44 @@ impl TryFrom<serde_json::Value> for SchedulerJobTarget {
mod tests { mod tests {
use super::*; use super::*;
use crate::bus::MessageBus; use crate::bus::MessageBus;
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig}; use crate::config::BUILTIN_MEMORY_MAINTENANCE_JOB_ID;
use crate::gateway::session::SessionManager;
use crate::skills::SkillRuntime;
use crate::storage::{SchedulerJobUpsert, SessionStore}; use crate::storage::{SchedulerJobUpsert, SessionStore};
use std::collections::HashMap;
fn test_provider_config() -> LLMProviderConfig { #[derive(Clone)]
LLMProviderConfig { struct TestAgentTaskExecutor;
provider_type: "openai".to_string(),
name: "default".to_string(), #[async_trait::async_trait]
base_url: "http://localhost".to_string(), impl AgentTaskExecutor for TestAgentTaskExecutor {
api_key: "test-key".to_string(), async fn execute(
extra_headers: HashMap::new(), &self,
llm_timeout_secs: 30, _channel_name: &str,
model_id: "test-model".to_string(), _chat_id: &str,
temperature: Some(0.0), _prompt: &str,
max_tokens: None, _options: ScheduledAgentTaskOptions,
context_window_tokens: None, ) -> anyhow::Result<Vec<OutboundMessage>> {
model_extra: HashMap::new(), Ok(Vec::new())
max_tool_iterations: 4,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
} }
} }
fn test_session_manager() -> SessionManager { #[derive(Clone)]
let provider_config = test_provider_config(); struct TestMaintenanceExecutor;
SessionManager::new(
4, #[async_trait::async_trait]
100, impl MaintenanceExecutor for TestMaintenanceExecutor {
false, async fn cleanup_expired_sessions(&self) -> usize {
"Asia/Shanghai".to_string(), 0
provider_config.clone(), }
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()), async fn run_memory_maintenance_for_all_scopes(
) &self,
.unwrap() _updated_since: Option<i64>,
) -> anyhow::Result<Vec<MaintenanceRunSummary>> {
Ok(Vec::new())
}
}
fn test_scheduler_services() -> (TestAgentTaskExecutor, TestMaintenanceExecutor) {
(TestAgentTaskExecutor, TestMaintenanceExecutor)
} }
#[test] #[test]
@ -1129,7 +1178,7 @@ mod tests {
}) })
.unwrap(); .unwrap();
let session_manager = test_session_manager(); let (agent_task_executor, maintenance_service) = test_scheduler_services();
let scheduler = Scheduler::new( let scheduler = Scheduler::new(
MessageBus::new(8), MessageBus::new(8),
SchedulerConfig { SchedulerConfig {
@ -1141,7 +1190,8 @@ mod tests {
}, },
chrono_tz::Asia::Shanghai, chrono_tz::Asia::Shanghai,
store.clone(), store.clone(),
session_manager, agent_task_executor,
maintenance_service,
); );
scheduler.process_tick().await.unwrap(); scheduler.process_tick().await.unwrap();
@ -1159,13 +1209,14 @@ mod tests {
fn sync_config_jobs_persists_builtin_memory_maintenance_job() { fn sync_config_jobs_persists_builtin_memory_maintenance_job() {
let store = Arc::new(SessionStore::in_memory().unwrap()); let store = Arc::new(SessionStore::in_memory().unwrap());
let session_manager = test_session_manager(); let (agent_task_executor, maintenance_service) = test_scheduler_services();
let scheduler = Scheduler::new( let scheduler = Scheduler::new(
MessageBus::new(8), MessageBus::new(8),
SchedulerConfig::default(), SchedulerConfig::default(),
chrono_tz::Asia::Shanghai, chrono_tz::Asia::Shanghai,
store.clone(), store.clone(),
session_manager, agent_task_executor,
maintenance_service,
); );
scheduler.sync_config_jobs().unwrap(); scheduler.sync_config_jobs().unwrap();
@ -1204,6 +1255,7 @@ mod tests {
async fn silent_agent_task_failure_notifies_primary_chat() { async fn silent_agent_task_failure_notifies_primary_chat() {
let store = Arc::new(SessionStore::in_memory().unwrap()); let store = Arc::new(SessionStore::in_memory().unwrap());
let bus = MessageBus::new(8); let bus = MessageBus::new(8);
let (agent_task_executor, maintenance_service) = test_scheduler_services();
let scheduler = Scheduler::new( let scheduler = Scheduler::new(
bus.clone(), bus.clone(),
SchedulerConfig { SchedulerConfig {
@ -1215,7 +1267,8 @@ mod tests {
}, },
chrono_tz::Asia::Shanghai, chrono_tz::Asia::Shanghai,
store, store,
test_session_manager(), agent_task_executor,
maintenance_service,
); );
let job = RuntimeJob { let job = RuntimeJob {

View File

@ -6,7 +6,6 @@ use std::path::{Path, PathBuf};
use std::sync::RwLock; use std::sync::RwLock;
use crate::config::SkillsConfig; use crate::config::SkillsConfig;
use crate::providers::{Tool, ToolFunction};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Skill { pub struct Skill {
@ -120,13 +119,6 @@ impl SkillRuntime {
.offered_event_payload() .offered_event_payload()
} }
pub fn skill_tool_definition(&self) -> Option<Tool> {
self.catalog
.read()
.expect("skills rwlock poisoned")
.skill_tool_definition()
}
pub fn activation_payload(&self, name: &str) -> Result<String, String> { pub fn activation_payload(&self, name: &str) -> Result<String, String> {
self.catalog self.catalog
.read() .read()
@ -230,6 +222,12 @@ impl SkillRuntime {
} }
} }
impl crate::agent::SkillProvider for SkillRuntime {
fn system_index_prompt(&self) -> Option<String> {
SkillRuntime::system_index_prompt(self)
}
}
impl SkillSource { impl SkillSource {
fn as_str(&self) -> &'static str { fn as_str(&self) -> &'static str {
match self { match self {
@ -344,30 +342,6 @@ impl SkillCatalog {
self.catalog_event_payload() self.catalog_event_payload()
} }
pub fn skill_tool_definition(&self) -> Option<Tool> {
if self.skills.is_empty() {
return None;
}
Some(Tool {
tool_type: "function".to_string(),
function: ToolFunction {
name: "skill_activate".to_string(),
description: "Load detailed instructions for a named skill discovered from SKILL.md files. Use when a task matches a listed skill description.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Skill name from the available skills list"
}
},
"required": ["name"]
}),
},
})
}
pub fn activation_payload(&self, name: &str) -> Result<String, String> { pub fn activation_payload(&self, name: &str) -> Result<String, String> {
let skill = self let skill = self
.find_skill(name) .find_skill(name)
@ -679,31 +653,6 @@ mod tests {
assert!(err.contains("description")); assert!(err.contains("description"));
} }
#[test]
fn test_skill_tool_definition_exists_when_skills_present() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path().join(".picobot").join("skills").join("demo");
fs::create_dir_all(&root).unwrap();
fs::write(
root.join("SKILL.md"),
"---\ndescription: demo skill\n---\nDo demo",
)
.unwrap();
let skills = load_skills_from_root(
&dir.path().join(".picobot").join("skills"),
SkillSource::Project,
);
let catalog = SkillCatalog {
skills,
max_index_chars: 4000,
max_listed_skills: 10,
};
let tool = catalog.skill_tool_definition().unwrap();
assert_eq!(tool.function.name, "skill_activate");
}
#[test] #[test]
fn test_activation_payload_contains_body() { fn test_activation_payload_contains_body() {
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();

9
src/storage/error.rs Normal file
View File

@ -0,0 +1,9 @@
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
}

View File

@ -3,193 +3,28 @@ use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use rusqlite::{Connection, OptionalExtension, params}; use rusqlite::{Connection, OptionalExtension, params};
use serde::{Deserialize, Serialize};
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
#[derive(Debug, Clone, Serialize, Deserialize)] pub mod error;
pub struct SkillEventRecord { pub mod ports;
pub id: String, pub mod records;
pub session_id: Option<String>,
pub event_type: String,
pub skill_name: Option<String>,
pub payload: serde_json::Value,
pub created_at: i64,
}
#[derive(Debug, thiserror::Error)] pub use error::StorageError;
pub enum StorageError { pub use ports::{
#[error("database error: {0}")] ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
Database(#[from] rusqlite::Error), SkillEventRepository,
#[error("io error: {0}")] };
Io(#[from] std::io::Error), pub use records::{
#[error("serialization error: {0}")] MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
Serialization(#[from] serde_json::Error), SchedulerJobUpsert, SessionRecord, SkillEventRecord,
} };
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRecord {
pub id: String,
pub title: String,
pub channel_name: String,
pub chat_id: String,
pub summary: Option<String>,
pub created_at: i64,
pub updated_at: i64,
pub last_active_at: i64,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>,
pub message_count: i64,
pub reset_cutoff_seq: i64,
pub user_turn_count: i64,
pub agent_prompt_reinjection_count: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRecord {
pub id: String,
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct MemoryUpsert {
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobState {
Scheduled,
Running,
Paused,
Completed,
}
impl SchedulerJobState {
pub fn as_str(&self) -> &'static str {
match self {
SchedulerJobState::Scheduled => "scheduled",
SchedulerJobState::Running => "running",
SchedulerJobState::Paused => "paused",
SchedulerJobState::Completed => "completed",
}
}
pub fn from_str(value: &str) -> Option<Self> {
match value {
"scheduled" => Some(Self::Scheduled),
"running" => Some(Self::Running),
"paused" => Some(Self::Paused),
"completed" => Some(Self::Completed),
_ => None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobStatus {
Ok,
Error,
Skipped,
}
impl SchedulerJobStatus {
pub fn as_str(&self) -> &'static str {
match self {
SchedulerJobStatus::Ok => "ok",
SchedulerJobStatus::Error => "error",
SchedulerJobStatus::Skipped => "skipped",
}
}
pub fn from_str(value: &str) -> Option<Self> {
match value {
"ok" => Some(Self::Ok),
"error" => Some(Self::Error),
"skipped" => Some(Self::Skipped),
_ => None,
}
}
}
impl Default for SchedulerJobState {
fn default() -> Self {
Self::Scheduled
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct SessionStore { pub struct SessionStore {
conn: Arc<Mutex<Connection>>, conn: Arc<Mutex<Connection>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerJobRecord {
pub id: String,
pub kind: String,
pub schedule: serde_json::Value,
pub interval_secs: i64,
pub startup_delay_secs: i64,
pub target: serde_json::Value,
pub payload: serde_json::Value,
pub enabled: bool,
pub state: SchedulerJobState,
pub last_status: Option<SchedulerJobStatus>,
pub last_error: Option<String>,
pub run_count: i64,
pub max_runs: Option<i64>,
pub last_fired_at: Option<i64>,
pub next_fire_at: Option<i64>,
pub paused_at: Option<i64>,
pub completed_at: Option<i64>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct SchedulerJobUpsert {
pub id: String,
pub kind: String,
pub schedule: serde_json::Value,
pub interval_secs: i64,
pub startup_delay_secs: i64,
pub target: serde_json::Value,
pub payload: serde_json::Value,
pub enabled: bool,
pub state: SchedulerJobState,
pub last_status: Option<SchedulerJobStatus>,
pub last_error: Option<String>,
pub run_count: i64,
pub max_runs: Option<i64>,
pub last_fired_at: Option<i64>,
pub next_fire_at: Option<i64>,
pub paused_at: Option<i64>,
pub completed_at: Option<i64>,
}
impl SessionStore { impl SessionStore {
#[cfg(test)] #[cfg(test)]
pub fn new() -> Result<Self, StorageError> { pub fn new() -> Result<Self, StorageError> {
@ -1802,7 +1637,7 @@ fn quote_fts_or_query(queries: &[String]) -> String {
mod tests { mod tests {
use super::*; use super::*;
use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT; use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT;
use crate::providers::ToolCall; use crate::domain::messages::ToolCall;
#[test] #[test]
fn test_persistent_session_id_for_cli_and_channel() { fn test_persistent_session_id_for_cli_and_channel() {

304
src/storage/ports.rs Normal file
View File

@ -0,0 +1,304 @@
use super::{
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError,
};
use crate::bus::ChatMessage;
pub trait ConversationRepository: Send + Sync + 'static {
fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError>;
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError>;
fn reset_session(&self, session_id: &str) -> Result<(), StorageError>;
fn compact_active_history(
&self,
session_id: &str,
expected_reset_cutoff_seq: i64,
snapshot_end_seq: i64,
preserved_system_messages: &[ChatMessage],
summary_message: &ChatMessage,
preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError>;
}
pub trait PromptInjectionRepository: Send + Sync + 'static {
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError>;
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError>;
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError>;
}
pub trait MemoryRepository: Send + Sync + 'static {
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError>;
fn update_memory(&self, input: &MemoryUpsert) -> Result<Option<MemoryRecord>, StorageError>;
fn delete_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<bool, StorageError>;
fn get_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<Option<MemoryRecord>, StorageError>;
fn list_memories(
&self,
scope_kind: &str,
scope_key: &str,
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError>;
fn search_memories_any(
&self,
scope_kind: &str,
scope_key: &str,
queries: &[String],
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError>;
}
pub trait SchedulerJobRepository: Send + Sync + 'static {
fn upsert_scheduler_job(
&self,
input: &SchedulerJobUpsert,
) -> Result<SchedulerJobRecord, StorageError>;
fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError>;
fn list_scheduler_jobs(
&self,
enabled_only: bool,
) -> Result<Vec<SchedulerJobRecord>, StorageError>;
fn delete_scheduler_job(&self, job_id: &str) -> Result<(), StorageError>;
fn update_scheduler_job_runtime(
&self,
job_id: &str,
state: SchedulerJobState,
last_status: Option<SchedulerJobStatus>,
last_error: Option<&str>,
run_count: i64,
last_fired_at: Option<i64>,
next_fire_at: Option<i64>,
paused_at: Option<i64>,
completed_at: Option<i64>,
) -> Result<(), StorageError>;
}
pub trait SkillEventRepository: Send + Sync + 'static {
fn append_skill_event(
&self,
session_id: Option<&str>,
event_type: &str,
skill_name: Option<&str>,
payload: &serde_json::Value,
) -> Result<(), StorageError>;
fn list_skill_events(
&self,
session_id: Option<&str>,
) -> Result<Vec<SkillEventRecord>, StorageError>;
}
impl ConversationRepository for super::SessionStore {
fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError> {
super::SessionStore::ensure_channel_session(self, channel_name, chat_id)
}
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
super::SessionStore::load_messages(self, session_id)
}
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
super::SessionStore::append_message(self, session_id, message)
}
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::clear_messages(self, session_id)
}
fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::reset_session(self, session_id)
}
fn compact_active_history(
&self,
session_id: &str,
expected_reset_cutoff_seq: i64,
snapshot_end_seq: i64,
preserved_system_messages: &[ChatMessage],
summary_message: &ChatMessage,
preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError> {
super::SessionStore::compact_active_history(
self,
session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
preserved_system_messages,
summary_message,
preserved_messages,
)
}
}
impl PromptInjectionRepository for super::SessionStore {
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
super::SessionStore::get_session(self, session_id)
}
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
super::SessionStore::count_active_user_messages(self, session_id)
}
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::mark_agent_prompt_reinjected(self, session_id)
}
}
impl MemoryRepository for super::SessionStore {
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
super::SessionStore::put_memory(self, input)
}
fn update_memory(&self, input: &MemoryUpsert) -> Result<Option<MemoryRecord>, StorageError> {
super::SessionStore::update_memory(self, input)
}
fn delete_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<bool, StorageError> {
super::SessionStore::delete_memory(self, scope_kind, scope_key, namespace, memory_key)
}
fn get_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<Option<MemoryRecord>, StorageError> {
super::SessionStore::get_memory(self, scope_kind, scope_key, namespace, memory_key)
}
fn list_memories(
&self,
scope_kind: &str,
scope_key: &str,
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError> {
super::SessionStore::list_memories(self, scope_kind, scope_key, namespace, limit)
}
fn search_memories_any(
&self,
scope_kind: &str,
scope_key: &str,
queries: &[String],
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError> {
super::SessionStore::search_memories_any(
self, scope_kind, scope_key, queries, namespace, limit,
)
}
}
impl SchedulerJobRepository for super::SessionStore {
fn upsert_scheduler_job(
&self,
input: &SchedulerJobUpsert,
) -> Result<SchedulerJobRecord, StorageError> {
super::SessionStore::upsert_scheduler_job(self, input)
}
fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError> {
super::SessionStore::get_scheduler_job(self, job_id)
}
fn list_scheduler_jobs(
&self,
enabled_only: bool,
) -> Result<Vec<SchedulerJobRecord>, StorageError> {
super::SessionStore::list_scheduler_jobs(self, enabled_only)
}
fn delete_scheduler_job(&self, job_id: &str) -> Result<(), StorageError> {
super::SessionStore::delete_scheduler_job(self, job_id)
}
fn update_scheduler_job_runtime(
&self,
job_id: &str,
state: SchedulerJobState,
last_status: Option<SchedulerJobStatus>,
last_error: Option<&str>,
run_count: i64,
last_fired_at: Option<i64>,
next_fire_at: Option<i64>,
paused_at: Option<i64>,
completed_at: Option<i64>,
) -> Result<(), StorageError> {
super::SessionStore::update_scheduler_job_runtime(
self,
job_id,
state,
last_status,
last_error,
run_count,
last_fired_at,
next_fire_at,
paused_at,
completed_at,
)
}
}
impl SkillEventRepository for super::SessionStore {
fn append_skill_event(
&self,
session_id: Option<&str>,
event_type: &str,
skill_name: Option<&str>,
payload: &serde_json::Value,
) -> Result<(), StorageError> {
super::SessionStore::append_skill_event(self, session_id, event_type, skill_name, payload)
}
fn list_skill_events(
&self,
session_id: Option<&str>,
) -> Result<Vec<SkillEventRecord>, StorageError> {
super::SessionStore::list_skill_events(self, session_id)
}
}

169
src/storage/records.rs Normal file
View File

@ -0,0 +1,169 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillEventRecord {
pub id: String,
pub session_id: Option<String>,
pub event_type: String,
pub skill_name: Option<String>,
pub payload: serde_json::Value,
pub created_at: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRecord {
pub id: String,
pub title: String,
pub channel_name: String,
pub chat_id: String,
pub summary: Option<String>,
pub created_at: i64,
pub updated_at: i64,
pub last_active_at: i64,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>,
pub message_count: i64,
pub reset_cutoff_seq: i64,
pub user_turn_count: i64,
pub agent_prompt_reinjection_count: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRecord {
pub id: String,
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct MemoryUpsert {
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobState {
Scheduled,
Running,
Paused,
Completed,
}
impl SchedulerJobState {
pub fn as_str(&self) -> &'static str {
match self {
SchedulerJobState::Scheduled => "scheduled",
SchedulerJobState::Running => "running",
SchedulerJobState::Paused => "paused",
SchedulerJobState::Completed => "completed",
}
}
pub fn from_str(value: &str) -> Option<Self> {
match value {
"scheduled" => Some(Self::Scheduled),
"running" => Some(Self::Running),
"paused" => Some(Self::Paused),
"completed" => Some(Self::Completed),
_ => None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobStatus {
Ok,
Error,
Skipped,
}
impl SchedulerJobStatus {
pub fn as_str(&self) -> &'static str {
match self {
SchedulerJobStatus::Ok => "ok",
SchedulerJobStatus::Error => "error",
SchedulerJobStatus::Skipped => "skipped",
}
}
pub fn from_str(value: &str) -> Option<Self> {
match value {
"ok" => Some(Self::Ok),
"error" => Some(Self::Error),
"skipped" => Some(Self::Skipped),
_ => None,
}
}
}
impl Default for SchedulerJobState {
fn default() -> Self {
Self::Scheduled
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerJobRecord {
pub id: String,
pub kind: String,
pub schedule: serde_json::Value,
pub interval_secs: i64,
pub startup_delay_secs: i64,
pub target: serde_json::Value,
pub payload: serde_json::Value,
pub enabled: bool,
pub state: SchedulerJobState,
pub last_status: Option<SchedulerJobStatus>,
pub last_error: Option<String>,
pub run_count: i64,
pub max_runs: Option<i64>,
pub last_fired_at: Option<i64>,
pub next_fire_at: Option<i64>,
pub paused_at: Option<i64>,
pub completed_at: Option<i64>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct SchedulerJobUpsert {
pub id: String,
pub kind: String,
pub schedule: serde_json::Value,
pub interval_secs: i64,
pub startup_delay_secs: i64,
pub target: serde_json::Value,
pub payload: serde_json::Value,
pub enabled: bool,
pub state: SchedulerJobState,
pub last_status: Option<SchedulerJobStatus>,
pub last_error: Option<String>,
pub run_count: i64,
pub max_runs: Option<i64>,
pub last_fired_at: Option<i64>,
pub next_fire_at: Option<i64>,
pub paused_at: Option<i64>,
pub completed_at: Option<i64>,
}

View File

@ -3,16 +3,16 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::json; use serde_json::json;
use crate::storage::{MemoryRecord, MemoryUpsert, SessionStore}; use crate::storage::{MemoryRecord, MemoryRepository, MemoryUpsert};
use crate::tools::traits::{Tool, ToolContext, ToolResult}; use crate::tools::traits::{Tool, ToolContext, ToolResult};
pub struct MemoryManageTool { pub struct MemoryManageTool {
store: Arc<SessionStore>, memories: Arc<dyn MemoryRepository>,
} }
impl MemoryManageTool { impl MemoryManageTool {
pub fn new(store: Arc<SessionStore>) -> Self { pub fn new(memories: Arc<dyn MemoryRepository>) -> Self {
Self { store } Self { memories }
} }
} }
@ -23,7 +23,7 @@ impl Tool for MemoryManageTool {
} }
fn description(&self) -> &str { fn description(&self) -> &str {
"Create, update, or delete long-term user memories stored in SQLite. Supports actions: put, update, delete. Use memory_search as the default retrieval path before answering most requests, and use memory_search for all retrieval actions including search, get, and list. Only call this tool when you have determined that a high-value long-term memory should be created, overwritten, updated, or deleted. Memories are scoped to the current channel and sender, and record the originating session/message when available." "Create, update, or delete long-term user memories in the configured memory repository. Supports actions: put, update, delete. Use memory_search as the default retrieval path before answering most requests, and use memory_search for all retrieval actions including search, get, and list. Only call this tool when you have determined that a high-value long-term memory should be created, overwritten, updated, or deleted. Memories are scoped to the current channel and sender, and record the originating session/message when available."
} }
fn parameters_schema(&self) -> serde_json::Value { fn parameters_schema(&self) -> serde_json::Value {
@ -80,7 +80,7 @@ impl Tool for MemoryManageTool {
Ok(input) => input, Ok(input) => input,
Err(result) => return Ok(result), Err(result) => return Ok(result),
}; };
memory_to_json(self.store.put_memory(&input)?) memory_to_json(self.memories.put_memory(&input)?)
} }
"update" => { "update" => {
let input = match build_memory_upsert(context, &scope_key, &args, false) { let input = match build_memory_upsert(context, &scope_key, &args, false) {
@ -88,7 +88,7 @@ impl Tool for MemoryManageTool {
Err(result) => return Ok(result), Err(result) => return Ok(result),
}; };
match self.store.update_memory(&input)? { match self.memories.update_memory(&input)? {
Some(memory) => memory_to_json(memory), Some(memory) => memory_to_json(memory),
None => { None => {
return Ok(error_result(&format!( return Ok(error_result(&format!(
@ -109,7 +109,7 @@ impl Tool for MemoryManageTool {
}; };
let deleted = self let deleted = self
.store .memories
.delete_memory("user", &scope_key, namespace, key)?; .delete_memory("user", &scope_key, namespace, key)?;
if !deleted { if !deleted {
return Ok(error_result(&format!( return Ok(error_result(&format!(
@ -219,6 +219,7 @@ fn error_result(message: &str) -> ToolResult {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::storage::SessionStore;
#[tokio::test] #[tokio::test]
async fn test_memory_manage_put_returns_saved_memory() { async fn test_memory_manage_put_returns_saved_memory() {

View File

@ -3,16 +3,16 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::json; use serde_json::json;
use crate::storage::{MemoryRecord, SessionStore}; use crate::storage::{MemoryRecord, MemoryRepository};
use crate::tools::traits::{Tool, ToolContext, ToolResult}; use crate::tools::traits::{Tool, ToolContext, ToolResult};
pub struct MemorySearchTool { pub struct MemorySearchTool {
store: Arc<SessionStore>, memories: Arc<dyn MemoryRepository>,
} }
impl MemorySearchTool { impl MemorySearchTool {
pub fn new(store: Arc<SessionStore>) -> Self { pub fn new(memories: Arc<dyn MemoryRepository>) -> Self {
Self { store } Self { memories }
} }
} }
@ -23,7 +23,7 @@ impl Tool for MemorySearchTool {
} }
fn description(&self) -> &str { fn description(&self) -> &str {
"Search and read long-term user memories stored in SQLite. This is the default entry point for memory retrieval and should usually be the first memory tool you call at the start of a request, unless the request is clearly a simple greeting, a one-off calculation, or a direct fact question that does not depend on user history. Use it to recall prior preferences, stable facts, historical decisions, and ongoing task context. If the request also needs other independent read-only tools, you may call memory_search in the same round alongside them. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage whenever you only need to retrieve memory." "Search and read long-term user memories from the configured memory repository. This is the default entry point for memory retrieval and should usually be the first memory tool you call at the start of a request, unless the request is clearly a simple greeting, a one-off calculation, or a direct fact question that does not depend on user history. Use it to recall prior preferences, stable facts, historical decisions, and ongoing task context. If the request also needs other independent read-only tools, you may call memory_search in the same round alongside them. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage whenever you only need to retrieve memory."
} }
fn parameters_schema(&self) -> serde_json::Value { fn parameters_schema(&self) -> serde_json::Value {
@ -91,7 +91,7 @@ impl Tool for MemorySearchTool {
.and_then(|value| value.as_u64()) .and_then(|value| value.as_u64())
.unwrap_or(10) as usize; .unwrap_or(10) as usize;
let memories = self let memories = self
.store .memories
.list_memories("user", &scope_key, namespace, limit)?; .list_memories("user", &scope_key, namespace, limit)?;
json!({ json!({
"count": memories.len(), "count": memories.len(),
@ -117,7 +117,7 @@ impl Tool for MemorySearchTool {
.and_then(|value| value.as_u64()) .and_then(|value| value.as_u64())
.unwrap_or(10) as usize; .unwrap_or(10) as usize;
let memories = self let memories = self
.store .memories
.search_memories_any("user", &scope_key, &queries, namespace, limit)?; .search_memories_any("user", &scope_key, &queries, namespace, limit)?;
json!({ json!({
"queries": queries, "queries": queries,
@ -135,7 +135,10 @@ impl Tool for MemorySearchTool {
None => return Ok(error_result("Missing required parameter: key")), None => return Ok(error_result("Missing required parameter: key")),
}; };
match self.store.get_memory("user", &scope_key, namespace, key)? { match self
.memories
.get_memory("user", &scope_key, namespace, key)?
{
Some(memory) => memory_to_json(memory), Some(memory) => memory_to_json(memory),
None => { None => {
return Ok(error_result(&format!( return Ok(error_result(&format!(
@ -202,6 +205,7 @@ fn error_result(message: &str) -> ToolResult {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::storage::SessionStore;
#[tokio::test] #[tokio::test]
async fn test_memory_search_search_and_get() { async fn test_memory_search_search_and_get() {

View File

@ -9,6 +9,7 @@ pub mod memory_search;
pub mod registry; pub mod registry;
pub mod scheduler_manage; pub mod scheduler_manage;
pub mod schema; pub mod schema;
pub mod skill_activate;
pub mod skill_manage; pub mod skill_manage;
pub mod time; pub mod time;
pub mod traits; pub mod traits;
@ -25,6 +26,7 @@ pub use memory_search::MemorySearchTool;
pub use registry::ToolRegistry; pub use registry::ToolRegistry;
pub use scheduler_manage::SchedulerManageTool; pub use scheduler_manage::SchedulerManageTool;
pub use schema::{CleaningStrategy, SchemaCleanr}; pub use schema::{CleaningStrategy, SchemaCleanr};
pub use skill_activate::SkillActivateTool;
pub use skill_manage::{SkillListTool, SkillManageTool}; pub use skill_manage::{SkillListTool, SkillManageTool};
pub use time::TimeTool; pub use time::TimeTool;
pub use traits::{Tool, ToolContext, ToolResult}; pub use traits::{Tool, ToolContext, ToolResult};

View File

@ -1,6 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::providers::{Tool, ToolFunction}; use crate::domain::tools::{Tool, ToolFunction};
use super::traits::Tool as ToolTrait; use super::traits::Tool as ToolTrait;

View File

@ -5,18 +5,20 @@ use async_trait::async_trait;
use serde_json::json; use serde_json::json;
use crate::config::SchedulerSchedule; use crate::config::SchedulerSchedule;
use crate::storage::{SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore}; use crate::storage::{
SchedulerJobRecord, SchedulerJobRepository, SchedulerJobState, SchedulerJobUpsert,
};
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
pub struct SchedulerManageTool { pub struct SchedulerManageTool {
store: Arc<SessionStore>, jobs: Arc<dyn SchedulerJobRepository>,
known_agents: Arc<HashSet<String>>, known_agents: Arc<HashSet<String>>,
} }
impl SchedulerManageTool { impl SchedulerManageTool {
pub fn new(store: Arc<SessionStore>, known_agents: HashSet<String>) -> Self { pub fn new(jobs: Arc<dyn SchedulerJobRepository>, known_agents: HashSet<String>) -> Self {
Self { Self {
store, jobs,
known_agents: Arc::new(known_agents), known_agents: Arc::new(known_agents),
} }
} }
@ -29,7 +31,7 @@ impl Tool for SchedulerManageTool {
} }
fn description(&self) -> &str { fn description(&self) -> &str {
"Manage DB-backed scheduled jobs. Supports actions: list, get, put, delete, pause, resume. Jobs persist in SQLite and are executed by the scheduler runtime. When creating agent_task or silent_agent_task jobs, keep prompt/system_prompt focused on the work to perform; do not restate execution times unless the task logic truly depends on them, because the trigger already controls timing." "Manage repository-backed scheduled jobs. Supports actions: list, get, put, delete, pause, resume. Jobs are persisted by the configured scheduler job repository and executed by the scheduler runtime. When creating agent_task or silent_agent_task jobs, keep prompt/system_prompt focused on the work to perform; do not restate execution times unless the task logic truly depends on them, because the trigger already controls timing."
} }
fn parameters_schema(&self) -> serde_json::Value { fn parameters_schema(&self) -> serde_json::Value {
@ -116,30 +118,30 @@ impl Tool for SchedulerManageTool {
.get("enabled_only") .get("enabled_only")
.and_then(|value| value.as_bool()) .and_then(|value| value.as_bool())
.unwrap_or(false); .unwrap_or(false);
let jobs = self.store.list_scheduler_jobs(enabled_only)?; let jobs = self.jobs.list_scheduler_jobs(enabled_only)?;
json!(jobs.iter().map(record_to_json).collect::<Vec<_>>()) json!(jobs.iter().map(record_to_json).collect::<Vec<_>>())
} }
"get" => { "get" => {
let id = require_str(&args, "id")?; let id = require_str(&args, "id")?;
match self.store.get_scheduler_job(id)? { match self.jobs.get_scheduler_job(id)? {
Some(record) => record_to_json(&record), Some(record) => record_to_json(&record),
None => return Ok(error_result(&format!("scheduler job '{}' not found", id))), None => return Ok(error_result(&format!("scheduler job '{}' not found", id))),
} }
} }
"put" => { "put" => {
let input = build_upsert(context, &args, &self.known_agents)?; let input = build_upsert(context, &args, &self.known_agents)?;
let record = self.store.upsert_scheduler_job(&input)?; let record = self.jobs.upsert_scheduler_job(&input)?;
record_to_json(&record) record_to_json(&record)
} }
"delete" => { "delete" => {
let id = require_str(&args, "id")?; let id = require_str(&args, "id")?;
self.store.delete_scheduler_job(id)?; self.jobs.delete_scheduler_job(id)?;
json!({"status": "deleted", "id": id}) json!({"status": "deleted", "id": id})
} }
"pause" => { "pause" => {
let id = require_str(&args, "id")?; let id = require_str(&args, "id")?;
let record = self let record = self
.store .jobs
.get_scheduler_job(id)? .get_scheduler_job(id)?
.ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?; .ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?;
let mut input = record_to_upsert(&record); let mut input = record_to_upsert(&record);
@ -147,13 +149,13 @@ impl Tool for SchedulerManageTool {
input.state = SchedulerJobState::Paused; input.state = SchedulerJobState::Paused;
input.paused_at = Some(current_timestamp()); input.paused_at = Some(current_timestamp());
input.next_fire_at = None; input.next_fire_at = None;
let saved = self.store.upsert_scheduler_job(&input)?; let saved = self.jobs.upsert_scheduler_job(&input)?;
record_to_json(&saved) record_to_json(&saved)
} }
"resume" => { "resume" => {
let id = require_str(&args, "id")?; let id = require_str(&args, "id")?;
let record = self let record = self
.store .jobs
.get_scheduler_job(id)? .get_scheduler_job(id)?
.ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?; .ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?;
let mut input = record_to_upsert(&record); let mut input = record_to_upsert(&record);
@ -162,7 +164,7 @@ impl Tool for SchedulerManageTool {
input.paused_at = None; input.paused_at = None;
input.completed_at = None; input.completed_at = None;
input.next_fire_at = None; input.next_fire_at = None;
let saved = self.store.upsert_scheduler_job(&input)?; let saved = self.jobs.upsert_scheduler_job(&input)?;
record_to_json(&saved) record_to_json(&saved)
} }
_ => return Ok(error_result("Unsupported action")), _ => return Ok(error_result("Unsupported action")),
@ -431,6 +433,7 @@ fn current_timestamp() -> i64 {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::storage::SessionStore;
#[tokio::test] #[tokio::test]
async fn test_scheduler_manage_put_and_get() { async fn test_scheduler_manage_put_and_get() {

151
src/tools/skill_activate.rs Normal file
View File

@ -0,0 +1,151 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::skills::SkillRuntime;
use crate::storage::SkillEventRepository;
use crate::tools::traits::{Tool, ToolContext, ToolResult};
pub struct SkillActivateTool {
skills: Arc<SkillRuntime>,
events: Arc<dyn SkillEventRepository>,
}
impl SkillActivateTool {
pub fn new(skills: Arc<SkillRuntime>, events: Arc<dyn SkillEventRepository>) -> Self {
Self { skills, events }
}
fn record_event(
&self,
context: &ToolContext,
event_type: &str,
skill_name: Option<&str>,
payload: &serde_json::Value,
) {
if let Err(err) = self.events.append_skill_event(
context.session_id.as_deref(),
event_type,
skill_name,
payload,
) {
tracing::warn!(error = %err, event_type, skill_name, "Failed to record skill activation event");
}
}
}
#[async_trait]
impl Tool for SkillActivateTool {
fn name(&self) -> &str {
"skill_activate"
}
fn description(&self) -> &str {
"Load detailed instructions for a named skill discovered from SKILL.md files. Use when a task matches a listed skill description."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Skill name from the available skills list"
}
},
"required": ["name"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
self.execute_with_context(&ToolContext::default(), args)
.await
}
async fn execute_with_context(
&self,
context: &ToolContext,
args: serde_json::Value,
) -> anyhow::Result<ToolResult> {
let skill_name = match args.get("name").and_then(|value| value.as_str()) {
Some(name) if !name.trim().is_empty() => name,
_ => {
self.record_event(
context,
"activation_failed",
None,
&json!({
"reason": "missing_name",
"arguments": args,
}),
);
return Ok(error_result("Missing required parameter: name"));
}
};
match self.skills.activation_payload(skill_name) {
Ok(output) => {
if let Ok(payload) = self.skills.activation_event_payload(skill_name) {
self.record_event(context, "activated", Some(skill_name), &payload);
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
Err(err) => {
self.record_event(
context,
"activation_failed",
Some(skill_name),
&json!({
"reason": err,
"arguments": args,
}),
);
Ok(error_result(&err))
}
}
}
}
fn error_result(message: &str) -> ToolResult {
ToolResult {
success: false,
output: String::new(),
error: Some(message.to_string()),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::SessionStore;
#[tokio::test]
async fn test_skill_activate_records_failed_activation_event() {
let skills = Arc::new(SkillRuntime::default());
let store = Arc::new(SessionStore::in_memory().unwrap());
store.ensure_channel_session("feishu", "chat-1").unwrap();
let tool = SkillActivateTool::new(skills, store.clone());
let context = ToolContext {
session_id: Some("feishu:chat-1".to_string()),
..ToolContext::default()
};
let result = tool
.execute_with_context(&context, json!({ "name": "demo" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("not found"));
let events = store.list_skill_events(Some("feishu:chat-1")).unwrap();
assert_eq!(events.len(), 1);
assert_eq!(events[0].event_type, "activation_failed");
assert_eq!(events[0].skill_name.as_deref(), Some("demo"));
}
}