Compare commits

..

No commits in common. "04fc2c0710401e3baf00e539fbd31262cf994aad" and "fa3354db9ccb9844bbf155410911a4f9e00e653c" have entirely different histories.

55 changed files with 1842 additions and 3438 deletions

View File

@ -18,7 +18,7 @@ PicoBot 是一个用 Rust 构建的多通道 Agent 网关。它把消息接入
PicoBot 的设计目标不是“只会聊天”的单进程 Bot而是一个可持续运行的 Agent 基础设施:
- 消息从不同渠道进入统一总线
- SessionManager 负责会话路由和运行时服务编排AgentExecutionService 负责上下文准备、AgentLoop 执行、结果持久化和回复生成
- SessionManager 负责会话路由、上下文恢复、工具执行和回复生成
- SQLite 作为事实来源保存跨重启状态
- Agent 在每轮推理时可以读取文件、执行命令、发 HTTP 请求、读写记忆、管理技能和调度任务
@ -30,20 +30,20 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot而是一个
1. Channel 接收外部消息
2. MessageBus 将消息送入统一的 inbound 队列
3. Gateway 启动的 InboundProcessor 调用 SessionManager 定位目标 Session
4. AgentExecutionService 准备上下文、运行 AgentLoop、执行工具调用并收集结果
5. 生成的 user / assistant / tool / system 消息按真实顺序写入 SQLite
3. Gateway 启动的 inbound processor 调用 SessionManager 处理消息
4. SessionManager 加载持久化历史、注入系统提示、运行 AgentLoop、执行工具调用
5. 生成的 assistant / tool / system 消息写入 SQLite
6. OutboundDispatcher 将结果投递到目标通道
主要模块如下:
- src/gateway网关入口、HTTP 健康检查、WebSocket 控制面、Session 池、CLI 会话服务、OutboundDispatcher 与 Agent 执行编排
- src/bus消息总线队列与消息结构定义,不包含渠道投递逻辑
- src/gateway网关入口、HTTP 健康检查、WebSocket 服务、Session 管理
- src/bus消息总线与消息结构定义
- src/agentAgentLoop 与上下文压缩器
- src/providers不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic
- src/tools内置工具集合
- src/storageSQLite 持久化实现
- src/channels渠道适配层当前已有 CLI 与飞书通道
- src/channels渠道适配层当前已有飞书通道
- src/scheduler数据库驱动的计划任务调度器
- src/skills技能发现、加载与运行时管理
- src/client / src/cli本地 CLI 客户端和交互命令
@ -549,8 +549,7 @@ CLI 中已实现的交互命令包括:
"models": {
"default": {
"model_id": "<OPENAI_MODEL_NAME>",
"temperature": 0.2,
"context_window_tokens": 128000
"temperature": 0.2
}
},
"agents": {
@ -632,11 +631,11 @@ PicoBot/
├── src/
│ ├── agent/ # AgentLoop、上下文压缩
│ ├── bus/ # 消息总线与消息结构
│ ├── channels/ # CLI / 飞书等渠道适配
│ ├── channels/ # 渠道适配
│ ├── cli/ # CLI 输入命令
│ ├── client/ # WebSocket CLI 客户端
│ ├── config/ # 配置解析
│ ├── gateway/ # Gateway、Session 编排、WS/HTTP 控制面
│ ├── gateway/ # Gateway、SessionManager、WS/HTTP
│ ├── providers/ # OpenAI / Anthropic Provider
│ ├── scheduler/ # 定时任务系统
│ ├── skills/ # 技能运行时
@ -656,7 +655,7 @@ PicoBot/
建议维护时重点关注:
- docs/PERSISTENCE.md持久化结构是否与代码一致
- src/gateway/session.rs会话状态、会话路由和运行时服务编排
- src/gateway/session.rs消息流、工具注册、记忆维护、会话恢复主逻辑
- src/storage/mod.rsSQLite schema 变更
- src/config/mod.rs配置项变更是否同步到 README

View File

@ -10,8 +10,7 @@
"models": {
"default": {
"model_id": "<OPENAI_MODEL_NAME>",
"temperature": 0.7,
"context_window_tokens": 128000
"temperature": 0.2
}
},
"agents": {

View File

@ -1,11 +1,13 @@
use crate::bus::ChatMessage;
use crate::bus::message::ContentBlock;
use crate::bus::message::ToolMessageState;
use crate::config::LLMProviderConfig;
use crate::domain::messages::{ContentBlock, ToolCall};
use crate::observability::{
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
use crate::tools::{ToolContext, ToolRegistry};
use async_trait::async_trait;
@ -295,7 +297,9 @@ pub struct AgentLoop {
provider_config: LLMProviderConfig,
provider: Box<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
skills: Arc<dyn SkillProvider>,
skills: Arc<SkillRuntime>,
skill_event_store: Option<Arc<SessionStore>>,
skill_event_session_id: Option<String>,
tool_context: ToolContext,
observer: Option<Arc<dyn Observer>>,
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
@ -313,19 +317,6 @@ pub trait EmittedMessageHandler: Send + Sync + 'static {
async fn handle(&self, message: ChatMessage);
}
pub trait SkillProvider: Send + Sync + 'static {
fn system_index_prompt(&self) -> Option<String>;
}
#[derive(Default)]
struct EmptySkillProvider;
impl SkillProvider for EmptySkillProvider {
fn system_index_prompt(&self) -> Option<String> {
None
}
}
impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
@ -336,7 +327,9 @@ impl AgentLoop {
provider_config,
provider,
tools: Arc::new(ToolRegistry::new()),
skills: Arc::new(EmptySkillProvider),
skills: Arc::new(SkillRuntime::default()),
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
@ -356,7 +349,9 @@ impl AgentLoop {
provider_config,
provider,
tools,
skills: Arc::new(EmptySkillProvider),
skills: Arc::new(SkillRuntime::default()),
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
@ -364,10 +359,10 @@ impl AgentLoop {
})
}
pub fn with_tools_and_skill_provider(
pub fn with_tools_and_skills(
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
skills: Arc<dyn SkillProvider>,
skills: Arc<SkillRuntime>,
) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config.clone())
@ -378,6 +373,8 @@ impl AgentLoop {
provider,
tools,
skills,
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
@ -385,6 +382,12 @@ impl AgentLoop {
})
}
pub fn with_skill_event_store(mut self, store: Arc<SessionStore>, session_id: String) -> Self {
self.skill_event_store = Some(store);
self.skill_event_session_id = Some(session_id);
self
}
pub fn with_tool_context(mut self, context: ToolContext) -> Self {
self.tool_context = context;
self
@ -440,7 +443,10 @@ impl AgentLoop {
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
// Build request
let tool_defs = self.tools.get_definitions();
let mut tool_defs = self.tools.get_definitions();
if let Some(skill_tool) = self.skills.skill_tool_definition() {
tool_defs.push(skill_tool);
}
let tools = if tool_defs.is_empty() {
None
} else {
@ -776,6 +782,46 @@ impl AgentLoop {
async fn execute_tool_internal(&self, tool_call: &ToolCall) -> ToolExecutionOutcome {
let normalized_arguments = normalize_tool_arguments(&tool_call.arguments);
if tool_call.name == "skill_activate" {
let skill_name = match normalized_arguments.get("name").and_then(|v| v.as_str()) {
Some(name) if !name.trim().is_empty() => name,
_ => {
self.record_skill_event(
"activation_failed",
None,
serde_json::json!({
"reason": "missing_name",
"arguments": normalized_arguments,
}),
);
return ToolExecutionOutcome::failure(
"Error: Missing required parameter: name".to_string(),
Some("Missing required parameter: name".to_string()),
);
}
};
return match self.skills.activation_payload(skill_name) {
Ok(output) => {
if let Ok(payload) = self.skills.activation_event_payload(skill_name) {
self.record_skill_event("activated", Some(skill_name), payload);
}
ToolExecutionOutcome::success(output)
}
Err(err) => {
self.record_skill_event(
"activation_failed",
Some(skill_name),
serde_json::json!({
"reason": err,
"arguments": normalized_arguments,
}),
);
ToolExecutionOutcome::failure(format!("Error: {}", err), Some(err))
}
};
}
let tool = match self.tools.get(&tool_call.name) {
Some(t) => t,
None => {
@ -824,6 +870,26 @@ impl AgentLoop {
}
}
}
fn record_skill_event(
&self,
event_type: &str,
skill_name: Option<&str>,
payload: serde_json::Value,
) {
let (Some(store), Some(session_id)) = (
self.skill_event_store.as_ref(),
self.skill_event_session_id.as_ref(),
) else {
return;
};
if let Err(err) =
store.append_skill_event(Some(session_id), event_type, skill_name, &payload)
{
tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event");
}
}
}
#[cfg(test)]

View File

@ -1,7 +1,5 @@
pub mod agent_loop;
pub mod context_compressor;
pub use agent_loop::{
AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler, SkillProvider,
};
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler};
pub use context_compressor::ContextCompressor;

View File

@ -1,12 +1,12 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
/// Consumes outbound messages from MessageBus and dispatches them to channels.
/// OutboundDispatcher consumes outbound messages from the MessageBus
/// and dispatches them to the appropriate Channel
pub struct OutboundDispatcher {
bus: Arc<MessageBus>,
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
@ -20,6 +20,7 @@ impl OutboundDispatcher {
}
}
/// Register a channel with the dispatcher
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
self.channels
.write()
@ -27,6 +28,7 @@ impl OutboundDispatcher {
.insert(name.to_string(), channel);
}
/// Run the dispatcher loop - consumes from bus and dispatches to channels
pub async fn run(&self) {
tracing::info!("OutboundDispatcher started");
@ -45,8 +47,8 @@ impl OutboundDispatcher {
match channel {
Some(ch) => {
if let Err(error) = self.send_with_retry(&*ch, msg).await {
tracing::error!(channel = %channel_name, error = %error, "Failed to send message after retries");
if let Err(e) = self.send_with_retry(&*ch, msg).await {
tracing::error!(channel = %channel_name, error = %e, "Failed to send message after retries");
}
}
None => {
@ -56,6 +58,7 @@ impl OutboundDispatcher {
}
}
/// Send a message with exponential retry
async fn send_with_retry(
&self,
channel: &dyn Channel,
@ -63,22 +66,21 @@ impl OutboundDispatcher {
) -> Result<(), ChannelError> {
const DELAYS: [u64; 3] = [1, 2, 4];
for (attempt_index, delay) in DELAYS.iter().enumerate() {
for (i, delay) in DELAYS.iter().enumerate() {
match channel.send(msg.clone()).await {
Ok(()) => return Ok(()),
Err(error) if attempt_index < DELAYS.len() - 1 => {
Err(e) if i < DELAYS.len() - 1 => {
tracing::warn!(
attempt = attempt_index + 1,
attempt = i + 1,
delay = delay,
error = %error,
error = %e,
"Send failed, retrying"
);
tokio::time::sleep(tokio::time::Duration::from_secs(*delay)).await;
}
Err(error) => return Err(error),
Err(e) => return Err(e),
}
}
unreachable!()
}
}

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::domain::messages::ToolCall;
use crate::providers::ToolCall;
pub const SYSTEM_CONTEXT_AGENT_PROMPT: &str = "agent_prompt";
pub const SYSTEM_CONTEXT_SCHEDULED_PROMPT: &str = "scheduled_system_prompt";
@ -14,6 +14,38 @@ pub enum ToolMessageState {
PendingUserAction,
}
// ============================================================================
// ContentBlock - Multimodal content representation (OpenAI-style)
// ============================================================================
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrlBlock },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrlBlock {
pub url: String,
}
impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self {
Self::Text {
text: content.into(),
}
}
pub fn image_url(url: impl Into<String>) -> Self {
Self::ImageUrl {
image_url: ImageUrlBlock { url: url.into() },
}
}
}
// ============================================================================
// MediaItem - Media metadata for messages
// ============================================================================
@ -534,7 +566,7 @@ fn current_timestamp() -> i64 {
#[cfg(test)]
mod tests {
use super::{ChatMessage, OutboundEventKind, OutboundMessage, ToolMessageState};
use crate::domain::messages::ToolCall;
use crate::providers::ToolCall;
use serde_json::json;
use std::collections::HashMap;

View File

@ -1,16 +1,18 @@
pub mod dispatcher;
pub mod message;
pub use crate::domain::messages::ContentBlock;
pub use dispatcher::OutboundDispatcher;
pub use message::{
ChatMessage, InboundMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_AGENT_PROMPT,
SYSTEM_CONTEXT_HISTORY_COMPACTION, SYSTEM_CONTEXT_SCHEDULED_PROMPT,
ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage,
SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION,
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
// ============================================================================
// MessageBus - async inbound/outbound queues
// MessageBus - Async message queue for Channel <-> Agent communication
// ============================================================================
pub struct MessageBus {
@ -33,7 +35,7 @@ impl MessageBus {
})
}
/// Publish a message to the inbound queue
/// Publish an inbound message (Channel -> Bus)
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> {
#[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, content_len = %msg.content.len(), media_count = %msg.media.len(), "Bus: publishing inbound message");
@ -43,7 +45,7 @@ impl MessageBus {
.map_err(|_| BusError::Closed)
}
/// Consume a message from the inbound queue
/// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self
.inbound_rx
@ -57,7 +59,7 @@ impl MessageBus {
msg
}
/// Publish a message to the outbound queue
/// Publish an outbound message (Agent -> Bus)
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> {
#[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, chat_id = %msg.chat_id, content_len = %msg.content.len(), "Bus: publishing outbound message");
@ -67,7 +69,7 @@ impl MessageBus {
.map_err(|_| BusError::Closed)
}
/// Consume an outbound message from the outbound queue
/// Consume an outbound message (Dispatcher -> Bus)
pub async fn consume_outbound(&self) -> OutboundMessage {
self.outbound_rx
.lock()

View File

@ -43,7 +43,7 @@ pub trait Channel: Send + Sync + 'static {
/// Stop the channel
async fn stop(&self) -> Result<(), ChannelError>;
/// Send a message to the channel (called by gateway outbound dispatch)
/// Send a message to the channel (called by OutboundDispatcher)
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError>;
/// Send a streaming delta (optional, for channels that support it)

View File

@ -1,155 +0,0 @@
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use crate::bus::{MessageBus, OutboundMessage};
use crate::protocol::WsOutbound;
use crate::protocol::ws_adapter::ws_outbound_from_outbound_message;
use super::base::{Channel, ChannelError};
#[derive(Clone)]
struct CliConnection {
connection_id: String,
sender: mpsc::Sender<WsOutbound>,
}
#[derive(Clone)]
pub struct CliChannel {
connections: Arc<RwLock<HashMap<String, CliConnection>>>,
}
impl CliChannel {
pub fn new() -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register_connection(
&self,
session_id: impl Into<String>,
connection_id: impl Into<String>,
sender: mpsc::Sender<WsOutbound>,
) {
let session_id = session_id.into();
let connection_id = connection_id.into();
let previous = self.connections.write().await.insert(
session_id.clone(),
CliConnection {
connection_id: connection_id.clone(),
sender,
},
);
if previous.is_some() {
tracing::info!(session_id = %session_id, connection_id = %connection_id, "CLI session sender replaced");
}
}
pub async fn unregister_connection(&self, connection_id: &str) {
self.connections
.write()
.await
.retain(|_, connection| connection.connection_id != connection_id);
}
}
impl Default for CliChannel {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Channel for CliChannel {
fn name(&self) -> &str {
"cli"
}
fn is_running(&self) -> bool {
true
}
async fn start(&self, _bus: Arc<MessageBus>) -> Result<(), ChannelError> {
Ok(())
}
async fn stop(&self) -> Result<(), ChannelError> {
self.connections.write().await.clear();
Ok(())
}
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let connection = self.connections.read().await.get(&msg.chat_id).cloned();
let Some(connection) = connection else {
return Err(ChannelError::SendError(format!(
"No active CLI connection for session {}",
msg.chat_id
)));
};
for outbound in ws_outbound_from_outbound_message(&msg) {
connection
.sender
.send(outbound)
.await
.map_err(|_| ChannelError::SendError("CLI websocket sender closed".to_string()))?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::OutboundMessage;
#[tokio::test]
async fn test_cli_channel_sends_to_registered_session() {
let channel = CliChannel::new();
let (sender, mut receiver) = mpsc::channel(4);
channel
.register_connection("session-1", "conn-1", sender)
.await;
channel
.send(OutboundMessage::assistant(
"cli",
"session-1",
"hello",
None,
HashMap::new(),
))
.await
.unwrap();
let outbound = receiver.recv().await.unwrap();
assert!(matches!(outbound, WsOutbound::AssistantResponse { .. }));
}
#[tokio::test]
async fn test_cli_channel_unregisters_connection_sessions() {
let channel = CliChannel::new();
let (sender, _receiver) = mpsc::channel(4);
channel
.register_connection("session-1", "conn-1", sender)
.await;
channel.unregister_connection("conn-1").await;
let error = channel
.send(OutboundMessage::assistant(
"cli",
"session-1",
"hello",
None,
HashMap::new(),
))
.await
.unwrap_err();
assert!(error.to_string().contains("No active CLI connection"));
}
}

View File

@ -4,7 +4,6 @@ use tokio::sync::RwLock;
use crate::bus::MessageBus;
use crate::channels::base::{Channel, ChannelError};
use crate::channels::cli::CliChannel;
use crate::channels::feishu::FeishuChannel;
use crate::config::Config;
@ -13,19 +12,13 @@ use crate::config::Config;
pub struct ChannelManager {
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
bus: Arc<MessageBus>,
cli_channel: Arc<CliChannel>,
}
impl ChannelManager {
pub fn new() -> Self {
let cli_channel = Arc::new(CliChannel::new());
let mut channels: HashMap<String, Arc<dyn Channel + Send + Sync>> = HashMap::new();
channels.insert("cli".to_string(), cli_channel.clone());
Self {
channels: Arc::new(RwLock::new(channels)),
channels: Arc::new(RwLock::new(HashMap::new())),
bus: MessageBus::new(100),
cli_channel,
}
}
@ -34,10 +27,6 @@ impl ChannelManager {
self.bus.clone()
}
pub fn cli_channel(&self) -> Arc<CliChannel> {
self.cli_channel.clone()
}
/// Initialize all Channel instances from config
pub async fn init(
&self,

View File

@ -1,9 +1,7 @@
pub mod base;
pub mod cli;
pub mod feishu;
pub mod manager;
pub use base::{Channel, ChannelError};
pub use cli::CliChannel;
pub use feishu::FeishuChannel;
pub use manager::ChannelManager;

View File

@ -1,36 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrlBlock },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrlBlock {
pub url: String,
}
impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self {
Self::Text {
text: content.into(),
}
}
pub fn image_url(url: impl Into<String>) -> Self {
Self::ImageUrl {
image_url: ImageUrlBlock { url: url.into() },
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}

View File

@ -1,2 +0,0 @@
pub mod messages;
pub mod tools;

View File

@ -1,15 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}

View File

@ -1,45 +0,0 @@
use std::sync::Arc;
use crate::agent::{AgentError, AgentLoop, SkillProvider};
use crate::config::LLMProviderConfig;
use crate::storage::persistent_session_id;
use crate::tools::{ToolContext, ToolRegistry};
#[derive(Clone)]
pub(crate) struct AgentFactory {
tools: Arc<ToolRegistry>,
skills: Arc<dyn SkillProvider>,
}
pub(crate) struct AgentBuildRequest<'a> {
pub(crate) channel_name: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) sender_id: Option<&'a str>,
pub(crate) message_id: Option<&'a str>,
pub(crate) provider_config: LLMProviderConfig,
}
impl AgentFactory {
pub(crate) fn new(tools: Arc<ToolRegistry>, skills: Arc<dyn SkillProvider>) -> Self {
Self { tools, skills }
}
pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
let session_id = persistent_session_id(request.channel_name, request.chat_id);
AgentLoop::with_tools_and_skill_provider(
request.provider_config,
self.tools.clone(),
self.skills.clone(),
)
.map(|agent| {
agent.with_tool_context(ToolContext {
channel_name: Some(request.channel_name.to_string()),
sender_id: request.sender_id.map(str::to_string),
chat_id: Some(request.chat_id.to_string()),
session_id: Some(session_id),
message_id: request.message_id.map(str::to_string),
message_seq: None,
})
})
}
}

View File

@ -1,102 +0,0 @@
use crate::agent::AgentError;
use crate::bus::OutboundMessage;
use crate::scheduler::{
AgentTaskExecutor as SchedulerAgentTaskExecutor, MaintenanceExecutor, MaintenanceRunSummary,
ScheduledAgentTaskOptions,
};
use async_trait::async_trait;
use super::memory_maintenance::MemoryMaintenanceScopeResult;
use super::session::SessionManager;
#[derive(Clone)]
pub struct AgentTaskExecutor {
session_manager: SessionManager,
}
impl AgentTaskExecutor {
pub fn new(session_manager: SessionManager) -> Self {
Self { session_manager }
}
async fn execute_agent_task(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> Result<Vec<OutboundMessage>, AgentError> {
self.session_manager
.run_scheduled_agent_task(channel_name, chat_id, prompt, options)
.await
}
}
#[async_trait]
impl SchedulerAgentTaskExecutor for AgentTaskExecutor {
async fn execute(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> anyhow::Result<Vec<OutboundMessage>> {
self.execute_agent_task(channel_name, chat_id, prompt, options)
.await
.map_err(|error| anyhow::anyhow!(error.to_string()))
}
}
#[derive(Clone)]
pub struct SchedulerMaintenanceService {
session_manager: SessionManager,
}
impl SchedulerMaintenanceService {
pub fn new(session_manager: SessionManager) -> Self {
Self { session_manager }
}
async fn cleanup_sessions(&self) -> usize {
self.session_manager.cleanup_expired_sessions().await
}
async fn run_memory_maintenance(
&self,
updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
self.session_manager
.run_memory_maintenance_for_all_scopes(updated_since)
.await
}
}
#[async_trait]
impl MaintenanceExecutor for SchedulerMaintenanceService {
async fn cleanup_expired_sessions(&self) -> usize {
self.cleanup_sessions().await
}
async fn run_memory_maintenance_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> anyhow::Result<Vec<MaintenanceRunSummary>> {
self.run_memory_maintenance(updated_since)
.await
.map(|results| {
results
.into_iter()
.map(|result| MaintenanceRunSummary {
scope_key: result.scope_key,
user_facts: result.output.user_facts.len(),
preferences: result.output.preferences.len(),
behavior_patterns: result.output.behavior_patterns.len(),
merges: result.output.merges.len(),
conflicts: result.output.conflicts.len(),
low_value: result.output.low_value_ids.len(),
})
.collect()
})
.map_err(|error| anyhow::anyhow!(error.to_string()))
}
}

View File

@ -1,57 +0,0 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::storage::{SessionRecord, SessionStore};
#[derive(Clone)]
pub(crate) struct CliSessionService {
store: Arc<SessionStore>,
}
impl CliSessionService {
pub(crate) fn new(store: Arc<SessionStore>) -> Self {
Self { store }
}
pub(crate) fn create(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
self.store
.create_cli_session(title)
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
pub(crate) fn get(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
self.store
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
}
pub(crate) fn list(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> {
self.store
.list_sessions("cli", include_archived)
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
}
pub(crate) fn rename(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
self.store
.rename_session(session_id, title)
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
}
pub(crate) fn archive(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.archive_session(session_id)
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
}
pub(crate) fn delete(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.delete_session(session_id)
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
}
pub(crate) fn clear_messages(&self, session_id: &str) -> Result<(), AgentError> {
self.store
.clear_messages(session_id)
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
}
}

View File

@ -1,159 +0,0 @@
use crate::agent::AgentError;
use super::session::Session;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InChatCommand {
FreshConversation,
}
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
match content.trim() {
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
_ => None,
}
}
pub(crate) fn handle_in_chat_command(
session: &mut Session,
chat_id: &str,
content: &str,
) -> Result<Option<String>, AgentError> {
match parse_in_chat_command(content) {
Some(InChatCommand::FreshConversation) => {
session.reset_chat_context(chat_id)?;
Ok(Some("Started a fresh conversation.".to_string()))
}
None => Ok(None),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::ToolRegistry;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
fn test_provider_config() -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: "test".to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_parse_in_chat_command_aliases() {
assert_eq!(
parse_in_chat_command("/new"),
Some(InChatCommand::FreshConversation)
);
assert_eq!(
parse_in_chat_command(" /reset \n"),
Some(InChatCommand::FreshConversation)
);
assert_eq!(parse_in_chat_command("/new planning"), None);
assert_eq!(parse_in_chat_command("please /reset"), None);
}
#[tokio::test]
async fn test_handle_in_chat_command_resets_active_history_only() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(ToolRegistry::new());
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
session
.append_persisted_message("chat-1", ChatMessage::user("hello"))
.unwrap();
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
.unwrap()
.unwrap();
assert_eq!(response, "Started a fresh conversation.");
assert!(session.get_history("chat-1").unwrap().is_empty());
assert!(
store
.load_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.is_empty()
);
assert_eq!(
store
.load_all_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.len(),
2,
);
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
}
#[tokio::test]
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(ToolRegistry::new());
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store,
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
session
.append_persisted_message("chat-1", ChatMessage::user("hello"))
.unwrap();
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
session
.ensure_agent_prompt_before_user_message("chat-1")
.unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
}
}

View File

@ -1,105 +0,0 @@
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::agent::AgentError;
use super::session::Session;
pub(crate) async fn schedule_background_history_compaction(
session: Arc<Mutex<Session>>,
chat_id: impl Into<String>,
) -> Result<(), AgentError> {
let chat_id = chat_id.into();
let snapshot = {
let mut session_guard = session.lock().await;
let session_record = session_guard.ensure_persistent_session(&chat_id)?;
session_guard.ensure_chat_loaded(&chat_id)?;
let history = session_guard.get_or_create_history(&chat_id).clone();
let compressor = session_guard.compressor().clone();
if !compressor.should_compress(&history) {
return Ok(());
}
if !session_guard.try_start_background_compaction(&chat_id) {
return Ok(());
}
(
session_guard.store(),
session_guard.persistent_session_id(&chat_id),
session_record.reset_cutoff_seq,
session_record.message_count,
history,
compressor,
session_guard.provider_config().clone(),
)
};
let (
store,
session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
history,
compressor,
provider_config,
) = snapshot;
let session_for_task = session.clone();
let chat_id_for_task = chat_id.clone();
tokio::spawn(async move {
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction");
let compaction_result = compressor
.build_compaction_plan(&history, &provider_config)
.await;
let mut committed = false;
match compaction_result {
Ok(Some(plan)) => match store.compact_active_history(
&session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
&plan.preserved_system_messages,
&plan.summary_message,
&plan.preserved_messages,
) {
Ok(true) => {
committed = true;
tracing::info!(
chat_id = %chat_id_for_task,
snapshot_end_seq,
compressed_turns = plan.compressed_turns,
preserved_turns = plan.preserved_turns,
"Background history compaction committed"
);
}
Ok(false) => {
tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Background history compaction skipped due to stale snapshot");
}
Err(error) => {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction commit failed");
}
},
Ok(None) => {
tracing::debug!(chat_id = %chat_id_for_task, "Background history compaction not needed after snapshot analysis");
}
Err(error) => {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Background history compaction build failed");
}
}
let mut session_guard = session_for_task.lock().await;
if committed {
if let Err(error) = session_guard.reload_chat_history(&chat_id_for_task) {
tracing::warn!(chat_id = %chat_id_for_task, error = %error, "Failed to reload history after background compaction");
}
}
session_guard.finish_background_compaction(&chat_id_for_task);
});
Ok(())
}

View File

@ -1,16 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::{AgentError, AgentProcessResult, EmittedMessageHandler};
use crate::agent::{AgentError, AgentProcessResult};
use crate::bus::message::ToolMessageState;
use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDULED_PROMPT};
use crate::bus::{ChatMessage, OutboundMessage};
use crate::config::LLMProviderConfig;
use tokio::sync::Mutex;
use super::command::handle_in_chat_command;
use super::compaction::schedule_background_history_compaction;
use super::message_prepare::enrich_user_content_with_media_refs;
use super::session::Session;
use super::session::{Session, schedule_background_history_compaction};
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器否则不要调用任何定时任务管理工具像“每小时”、“每天”、“cron”、“定时”等词只应视为任务背景不应再解释为新的建任务请求。";
@ -27,6 +24,19 @@ pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>)
}
}
pub(crate) fn select_provider_config(
default_provider_config: &LLMProviderConfig,
provider_configs: &HashMap<String, LLMProviderConfig>,
agent_name: Option<&str>,
) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(default_provider_config.clone()),
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
pub(crate) struct AgentExecutionService {
show_tool_results: bool,
}
@ -46,28 +56,6 @@ pub(crate) struct FinalizedAgentResult {
pub(crate) should_schedule_compaction: bool,
}
pub(crate) struct MessageExecutionRequest<'a> {
pub(crate) session: Arc<Mutex<Session>>,
pub(crate) channel_name: &'a str,
pub(crate) sender_id: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) content: &'a str,
pub(crate) media: Vec<MediaItem>,
pub(crate) live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
}
pub(crate) struct ScheduledExecutionRequest<'a> {
pub(crate) session: Arc<Mutex<Session>>,
pub(crate) channel_name: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) prompt: &'a str,
pub(crate) sender_id: &'a str,
pub(crate) provider_config: LLMProviderConfig,
pub(crate) fresh_session: bool,
pub(crate) system_prompt: Option<&'a str>,
pub(crate) metadata: &'a HashMap<String, String>,
}
impl AgentExecutionService {
pub(crate) fn new(show_tool_results: bool) -> Self {
Self { show_tool_results }
@ -127,136 +115,6 @@ impl AgentExecutionService {
})
}
pub(crate) async fn prepare_and_execute_message(
&self,
request: MessageExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let (history, agent, user_message) = {
let mut session_guard = request.session.lock().await;
session_guard.ensure_persistent_session(request.chat_id)?;
session_guard.ensure_chat_loaded(request.chat_id)?;
if let Some(command_response) =
handle_in_chat_command(&mut session_guard, request.chat_id, request.content)?
{
return Ok(vec![OutboundMessage::assistant(
request.channel_name.to_string(),
request.chat_id.to_string(),
command_response,
None,
HashMap::new(),
)]);
}
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
let media_refs: Vec<String> = request
.media
.iter()
.map(|media| media.path.clone())
.collect();
#[cfg(debug_assertions)]
if !media_refs.is_empty() {
tracing::debug!(media_count = %request.media.len(), media_refs = ?media_refs, "Adding user message with media");
}
let enriched_content =
enrich_user_content_with_media_refs(request.content, &media_refs)?;
let user_message = session_guard.create_user_message(&enriched_content, media_refs);
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
let history = session_guard.get_or_create_history(request.chat_id).clone();
session_guard.record_skill_offer(request.chat_id)?;
let mut agent = session_guard.create_agent(
request.chat_id,
Some(request.sender_id),
Some(&user_message.id),
)?;
if let Some(handler) = request.live_emitter.clone() {
agent = agent.with_emitted_message_handler(handler);
}
(history, agent, user_message)
};
let result = agent.process(history).await?;
let metadata = HashMap::new();
self.finalize_result_and_schedule_compaction(
request.session.clone(),
FinalizeAgentResultRequest {
channel_name: request.channel_name,
chat_id: request.chat_id,
user_message: &user_message,
result,
metadata: &metadata,
suppress_live_tool_calls: request.live_emitter.is_some(),
execution_kind: "message",
},
)
.await
}
pub(crate) async fn prepare_and_execute_scheduled_task(
&self,
request: ScheduledExecutionRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let (history, agent, user_message) = {
let mut session_guard = request.session.lock().await;
session_guard.ensure_persistent_session(request.chat_id)?;
if request.fresh_session {
session_guard.reset_chat_context(request.chat_id)?;
}
session_guard.ensure_chat_loaded(request.chat_id)?;
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
let scheduled_system_prompt =
compose_scheduled_task_system_prompt(request.system_prompt);
session_guard.append_persisted_message(
request.chat_id,
ChatMessage::system_with_context(
&scheduled_system_prompt,
Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
),
)?;
let user_message = session_guard.create_user_message(request.prompt, Vec::new());
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
let history = session_guard.get_or_create_history(request.chat_id).clone();
session_guard.record_skill_offer(request.chat_id)?;
let agent = session_guard.create_agent_with_provider_config(
request.chat_id,
Some(request.sender_id),
Some(&user_message.id),
request.provider_config.clone(),
)?;
(history, agent, user_message)
};
let result = agent.process(history).await?;
self.finalize_result_and_schedule_compaction(
request.session.clone(),
FinalizeAgentResultRequest {
channel_name: request.channel_name,
chat_id: request.chat_id,
user_message: &user_message,
result,
metadata: request.metadata,
suppress_live_tool_calls: false,
execution_kind: "scheduled_task",
},
)
.await
}
pub(crate) async fn finalize_result_and_schedule_compaction(
&self,
session: Arc<Mutex<Session>>,
@ -312,6 +170,50 @@ mod tests {
use super::*;
use crate::bus::ChatMessage;
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_provider_config_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]);
let selected =
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_select_provider_config_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::new();
let selected =
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
#[test]
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));

View File

@ -238,18 +238,8 @@ impl MemoryMaintenanceService {
let mut results = Vec::new();
for scope_key in scope_keys {
let output = match self.run_for_scope(&scope_key).await {
Ok(Some(output)) => output,
Ok(None) => continue,
Err(error) if is_recoverable_maintenance_scope_error(&error) => {
tracing::warn!(
scope_key = %scope_key,
error = %error,
"Memory maintenance skipped scope after recoverable model failure"
);
continue;
}
Err(error) => return Err(error),
let Some(output) = self.run_for_scope(&scope_key).await? else {
continue;
};
results.push(MemoryMaintenanceScopeResult { scope_key, output });
@ -329,10 +319,6 @@ pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|| normalized.contains("timeout")
}
fn is_recoverable_maintenance_scope_error(error: &AgentError) -> bool {
is_recoverable_maintenance_llm_error(&error.to_string())
}
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
let trimmed = content.trim();
if let Some(rest) = trimmed.strip_prefix("```json") {

View File

@ -1,46 +0,0 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::storage::SessionStore;
use super::memory_maintenance::{
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
};
use super::provider_config_service::ProviderConfigService;
#[derive(Clone)]
pub(crate) struct MemoryMaintenanceCoordinator {
store: Arc<SessionStore>,
provider_configs: ProviderConfigService,
}
impl MemoryMaintenanceCoordinator {
pub(crate) fn new(store: Arc<SessionStore>, provider_configs: ProviderConfigService) -> Self {
Self {
store,
provider_configs,
}
}
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) async fn summarize_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
self.service()?.summarize_for_scope(scope_key).await
}
pub(crate) async fn run_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
self.service()?.run_for_all_scopes(updated_since).await
}
fn service(&self) -> Result<MemoryMaintenanceService, AgentError> {
Ok(MemoryMaintenanceService::new(
self.store.clone(),
self.provider_configs.default_provider_config(),
))
}
}

View File

@ -1,39 +0,0 @@
use crate::agent::AgentError;
pub(crate) fn enrich_user_content_with_media_refs(
content: &str,
media_refs: &[String],
) -> Result<String, AgentError> {
if media_refs.is_empty() {
return Ok(content.to_string());
}
let media_refs_json = serde_json::to_string(media_refs)
.map_err(|err| AgentError::Other(format!("serialize media refs error: {}", err)))?;
Ok(format!("{content}\n\nmedia_refs_json: {media_refs_json}"))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
assert_eq!(
enriched,
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
);
}
#[test]
fn test_enrich_user_content_with_media_refs_keeps_plain_text_without_media() {
let enriched = enrich_user_content_with_media_refs("hello", &[]).unwrap();
assert_eq!(enriched, "hello");
}
}

View File

@ -1,26 +1,9 @@
pub mod agent_factory;
pub mod agent_task_executor;
pub mod cli_session;
pub mod command;
pub mod compaction;
pub mod execution;
pub mod http;
pub mod memory_maintenance;
pub mod memory_maintenance_coordinator;
pub mod message_prepare;
pub mod outbound_dispatcher;
pub mod processor;
pub mod prompt;
pub mod prompt_injector;
pub mod provider_config_service;
pub mod scheduled_agent_task_service;
pub mod session;
pub mod session_factory;
pub mod session_history;
pub mod session_lifecycle;
pub mod session_message_service;
pub mod session_pool;
pub mod tool_registry_factory;
pub mod ws;
use axum::{Router, routing};
@ -28,15 +11,13 @@ use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpListener;
use crate::bus::MessageBus;
use crate::bus::{MessageBus, OutboundDispatcher};
use crate::channels::ChannelManager;
use crate::config::Config;
use crate::config::LLMProviderConfig;
use crate::logging;
use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime;
use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
use outbound_dispatcher::OutboundDispatcher;
use processor::InboundProcessor;
use session::SessionManager;
@ -138,8 +119,7 @@ pub async fn run(
state.config.scheduler.clone(),
timezone,
state.session_manager.store(),
AgentTaskExecutor::new(state.session_manager.clone()),
SchedulerMaintenanceService::new(state.session_manager.clone()),
state.session_manager.clone(),
);
tokio::spawn(async move {

View File

@ -1,6 +1,6 @@
use std::sync::Arc;
use crate::bus::{MessageBus, OutboundMessage};
use crate::bus::MessageBus;
use super::session::{BusToolCallEmitter, SessionManager};
@ -70,21 +70,6 @@ impl InboundProcessor {
}
Err(error) => {
tracing::error!(error = %error, "Failed to handle message");
let mut metadata = inbound.forwarded_metadata.clone();
metadata.insert("error_kind".to_string(), "agent_execution".to_string());
if let Err(publish_error) = self
.bus
.publish_outbound(OutboundMessage::error_notification(
inbound.channel,
inbound.chat_id,
error.to_string(),
None,
metadata,
))
.await
{
tracing::error!(error = %publish_error, "Failed to publish execution error outbound");
}
}
}
}

View File

@ -1,86 +0,0 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::{ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT};
use crate::storage::PromptInjectionRepository;
use super::prompt::load_agent_prompt;
#[derive(Clone)]
pub(crate) struct PromptInjector {
repository: Arc<dyn PromptInjectionRepository>,
reinject_every: i64,
}
impl PromptInjector {
pub(crate) fn new(repository: Arc<dyn PromptInjectionRepository>, reinject_every: u64) -> Self {
Self {
repository,
reinject_every: reinject_every as i64,
}
}
pub(crate) fn ensure_initial_prompt<F>(
&self,
history_is_empty: bool,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
if !history_is_empty {
return Ok(());
}
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
}
Ok(())
}
pub(crate) fn ensure_reinjected_prompt<F>(
&self,
session_id: &str,
mut append_message: F,
) -> Result<(), AgentError>
where
F: FnMut(ChatMessage) -> Result<(), AgentError>,
{
let session_record = self
.repository
.get_session(session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns = self
.repository
.count_active_user_messages(session_id)
.map_err(|err| {
AgentError::Other(format!("count active user messages error: {}", err))
})?;
if self.reinject_every > 0
&& active_user_turns > 0
&& active_user_turns / self.reinject_every
> session_record.agent_prompt_reinjection_count
{
if let Some(agent_prompt) = load_agent_prompt()? {
append_message(Self::agent_prompt_message(agent_prompt))?;
self.repository
.mark_agent_prompt_reinjected(session_id)
.map_err(|err| {
AgentError::Other(format!("mark agent prompt reinjection error: {}", err))
})?;
}
}
Ok(())
}
fn agent_prompt_message(agent_prompt: String) -> ChatMessage {
ChatMessage::system_with_context(
agent_prompt,
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
)
}
}

View File

@ -1,90 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
#[derive(Clone)]
pub(crate) struct ProviderConfigService {
default_provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
}
impl ProviderConfigService {
pub(crate) fn new(
default_provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
) -> Self {
Self {
default_provider_config,
provider_configs: Arc::new(provider_configs),
}
}
pub(crate) fn select(&self, agent_name: Option<&str>) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(self.default_provider_config.clone()),
Some(agent_name) => self
.provider_configs
.get(agent_name)
.cloned()
.ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
self.default_provider_config.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(
default_provider,
HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]),
);
let selected = service.select(Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_select_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(default_provider, HashMap::new());
let selected = service.select(Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
}

View File

@ -1,57 +0,0 @@
use crate::agent::AgentError;
use crate::bus::OutboundMessage;
use crate::scheduler::ScheduledAgentTaskOptions;
use super::execution::{AgentExecutionService, ScheduledExecutionRequest};
use super::provider_config_service::ProviderConfigService;
use super::session_lifecycle::SessionLifecycleService;
#[derive(Clone)]
pub(crate) struct ScheduledAgentTaskService {
lifecycle: SessionLifecycleService,
provider_configs: ProviderConfigService,
show_tool_results: bool,
}
impl ScheduledAgentTaskService {
pub(crate) fn new(
lifecycle: SessionLifecycleService,
provider_configs: ProviderConfigService,
show_tool_results: bool,
) -> Self {
Self {
lifecycle,
provider_configs,
show_tool_results,
}
}
pub(crate) async fn run(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> Result<Vec<OutboundMessage>, AgentError> {
let session = self.lifecycle.active_session(channel_name).await?;
let sender_id = options
.sender_id
.clone()
.unwrap_or_else(|| "scheduler".to_string());
let provider_config = self.provider_configs.select(options.agent.as_deref())?;
AgentExecutionService::new(self.show_tool_results)
.prepare_and_execute_scheduled_task(ScheduledExecutionRequest {
session,
channel_name,
chat_id,
prompt,
sender_id: &sender_id,
provider_config,
fresh_session: options.fresh_session,
system_prompt: options.system_prompt.as_deref(),
metadata: &options.metadata,
})
.await
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,61 +0,0 @@
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, SkillEventRepository};
use super::agent_factory::AgentFactory;
use super::prompt_injector::PromptInjector;
use super::session::Session;
#[derive(Clone)]
pub(crate) struct SessionFactory {
provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
}
impl SessionFactory {
pub(crate) fn new(
provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>,
agent_factory: AgentFactory,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Self {
Self {
provider_config,
skills,
agent_factory,
prompt_injector,
conversations,
skill_events,
}
}
pub(crate) async fn create(
&self,
channel_name: impl Into<String>,
user_tx: mpsc::Sender<WsOutbound>,
) -> Result<Session, AgentError> {
Session::with_factories(
channel_name.into(),
self.provider_config.clone(),
user_tx,
self.skills.clone(),
self.agent_factory.clone(),
self.prompt_injector.clone(),
self.conversations.clone(),
self.skill_events.clone(),
)
.await
}
}

View File

@ -1,267 +0,0 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::agent::AgentError;
use crate::bus::ChatMessage;
use crate::storage::{
ConversationRepository, SessionRecord, SkillEventRepository, persistent_session_id,
};
use super::prompt_injector::PromptInjector;
fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>();
if content.chars().count() > max_chars {
preview.push_str("...");
}
preview.replace('\n', "\\n")
}
pub(crate) struct SessionHistory {
channel_name: String,
chat_histories: HashMap<String, Vec<ChatMessage>>,
compression_in_flight: HashSet<String>,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
}
impl SessionHistory {
pub(crate) fn new(
channel_name: impl Into<String>,
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
) -> Self {
Self {
channel_name: channel_name.into(),
chat_histories: HashMap::new(),
compression_in_flight: HashSet::new(),
prompt_injector,
conversations,
skill_events,
}
}
pub(crate) fn persistent_session_id(&self, chat_id: &str) -> String {
persistent_session_id(&self.channel_name, chat_id)
}
pub(crate) fn ensure_persistent_session(
&self,
chat_id: &str,
) -> Result<SessionRecord, AgentError> {
self.conversations
.ensure_channel_session(&self.channel_name, chat_id)
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
}
pub(crate) fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
if self.chat_histories.contains_key(chat_id) {
return self.ensure_initial_agent_prompt(chat_id);
}
let history = self
.conversations
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
self.ensure_initial_agent_prompt(chat_id)?;
Ok(())
}
pub(crate) fn ensure_agent_prompt_before_user_message(
&mut self,
chat_id: &str,
) -> Result<(), AgentError> {
self.ensure_chat_loaded(chat_id)?;
let session_id = self.persistent_session_id(chat_id);
let prompt_injector = self.prompt_injector.clone();
prompt_injector.ensure_reinjected_prompt(&session_id, |message| {
self.append_persisted_message(chat_id, message)
})
}
pub(crate) fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
self.chat_histories.entry(chat_id.to_string()).or_default()
}
pub(crate) fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
self.chat_histories.get(chat_id)
}
pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
self.get_or_create_history(chat_id).push(message);
}
pub(crate) fn remove_history(&mut self, chat_id: &str) {
self.chat_histories.remove(chat_id);
self.compression_in_flight.remove(chat_id);
}
pub(crate) fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
let len = history.len();
history.clear();
#[cfg(debug_assertions)]
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
}
self.conversations
.clear_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
}
pub(crate) fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
let len = history.len();
history.clear();
#[cfg(debug_assertions)]
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
}
self.conversations
.reset_session(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
}
pub(crate) fn append_persisted_message(
&mut self,
chat_id: &str,
message: ChatMessage,
) -> Result<(), AgentError> {
let session_id = self.persistent_session_id(chat_id);
self.conversations
.append_message(&session_id, &message)
.map_err(|err| {
AgentError::Other(format!("append message persistence error: {}", err))
})?;
self.add_message(chat_id, message);
Ok(())
}
pub(crate) fn append_persisted_messages<I>(
&mut self,
chat_id: &str,
messages: I,
) -> Result<(), AgentError>
where
I: IntoIterator<Item = ChatMessage>,
{
for message in messages {
self.append_persisted_message(chat_id, message)?;
}
Ok(())
}
pub(crate) fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
self.get_history(chat_id)
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
}
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
self.latest_user_message(chat_id)
.map(|current| {
current.id == message.id
|| (current.content == message.content
&& current.timestamp == message.timestamp
&& current.media_refs == message.media_refs)
})
.unwrap_or(false)
}
pub(crate) fn stale_result_diagnostics(
&self,
chat_id: &str,
) -> (Option<&str>, Option<String>, bool, usize) {
let latest_user = self.latest_user_message(chat_id);
let latest_user_id = latest_user.map(|message| message.id.as_str());
let latest_user_preview = latest_user.map(|message| preview_text(&message.content, 80));
let compression_in_flight = self.compression_in_flight.contains(chat_id);
let history_len = self
.get_history(chat_id)
.map(|history| history.len())
.unwrap_or(0);
(
latest_user_id,
latest_user_preview,
compression_in_flight,
history_len,
)
}
pub(crate) fn clear_all_history(&mut self) -> Result<(), AgentError> {
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
self.chat_histories.clear();
self.compression_in_flight.clear();
#[cfg(debug_assertions)]
tracing::debug!(previous_total = total, "All chat histories cleared");
for chat_id in chat_ids {
self.conversations
.clear_messages(&self.persistent_session_id(&chat_id))
.map_err(|err| {
AgentError::Other(format!("clear history persistence error: {}", err))
})?;
}
Ok(())
}
pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
self.compression_in_flight.insert(chat_id.to_string())
}
pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
self.compression_in_flight.remove(chat_id);
}
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history = self
.conversations
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
Ok(())
}
pub(crate) fn conversations(&self) -> Arc<dyn ConversationRepository> {
self.conversations.clone()
}
pub(crate) fn append_skill_event(
&self,
chat_id: &str,
event_type: &str,
skill_name: Option<&str>,
payload: &serde_json::Value,
) -> Result<(), AgentError> {
self.skill_events
.append_skill_event(
Some(&self.persistent_session_id(chat_id)),
event_type,
skill_name,
payload,
)
.map_err(|err| AgentError::Other(format!("append skill event error: {}", err)))
}
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history_is_empty = self
.get_history(chat_id)
.map(|history| history.is_empty())
.unwrap_or(true);
if !history_is_empty {
return Ok(());
}
let prompt_injector = self.prompt_injector.clone();
prompt_injector.ensure_initial_prompt(history_is_empty, |message| {
self.append_persisted_message(chat_id, message)
})
}
}

View File

@ -1,49 +0,0 @@
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::agent::AgentError;
use super::session::Session;
use super::session_factory::SessionFactory;
use super::session_pool::SessionPool;
#[derive(Clone)]
pub(crate) struct SessionLifecycleService {
session_pool: SessionPool,
}
impl SessionLifecycleService {
pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self {
Self {
session_pool: SessionPool::new(session_ttl_hours, session_factory),
}
}
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
self.session_pool.ensure_session(channel_name).await
}
pub(crate) async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
self.session_pool.get(channel_name).await
}
pub(crate) async fn touch(&self, channel_name: &str) {
self.session_pool.touch(channel_name).await;
}
pub(crate) async fn active_session(
&self,
channel_name: &str,
) -> Result<Arc<Mutex<Session>>, AgentError> {
self.ensure_session(channel_name).await?;
self.touch(channel_name).await;
self.get(channel_name)
.await
.ok_or_else(|| AgentError::Other("Session not found".to_string()))
}
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
self.session_pool.cleanup_expired_sessions().await
}
}

View File

@ -1,69 +0,0 @@
use std::sync::Arc;
use crate::agent::{AgentError, EmittedMessageHandler};
use crate::bus::{MediaItem, OutboundMessage};
use super::execution::{AgentExecutionService, MessageExecutionRequest};
use super::session_lifecycle::SessionLifecycleService;
#[derive(Clone)]
pub(crate) struct SessionMessageService {
lifecycle: SessionLifecycleService,
show_tool_results: bool,
}
impl SessionMessageService {
pub(crate) fn new(lifecycle: SessionLifecycleService, show_tool_results: bool) -> Self {
Self {
lifecycle,
show_tool_results,
}
}
pub(crate) async fn handle_message(
&self,
channel_name: &str,
sender_id: &str,
chat_id: &str,
content: &str,
media: Vec<MediaItem>,
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
) -> Result<Vec<OutboundMessage>, AgentError> {
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
content_len = content.len(),
media_count = %media.len(),
"Routing message to agent"
);
for (i, m) in media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
}
}
let session = self.lifecycle.active_session(channel_name).await?;
let outbound_messages = AgentExecutionService::new(self.show_tool_results)
.prepare_and_execute_message(MessageExecutionRequest {
session,
channel_name,
sender_id,
chat_id,
content,
media,
live_emitter,
})
.await?;
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
outbound_count = outbound_messages.len(),
"Agent response sequence received"
);
Ok(outbound_messages)
}
}

View File

@ -1,109 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use crate::agent::AgentError;
use crate::protocol::WsOutbound;
use super::session::Session;
use super::session_factory::SessionFactory;
#[derive(Clone)]
pub(crate) struct SessionPool {
inner: Arc<Mutex<SessionPoolInner>>,
session_factory: SessionFactory,
}
struct SessionPoolInner {
sessions: HashMap<String, Arc<Mutex<Session>>>,
session_timestamps: HashMap<String, Instant>,
session_ttl: Duration,
}
impl SessionPool {
pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self {
Self {
inner: Arc::new(Mutex::new(SessionPoolInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
session_factory,
}
}
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await;
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name)
{
let elapsed = last_active.elapsed();
if elapsed > inner.session_ttl {
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
true
} else {
false
}
} else {
#[cfg(debug_assertions)]
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
if should_recreate {
inner.sessions.remove(channel_name);
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = self
.session_factory
.create(channel_name.to_string(), user_tx)
.await?;
inner
.sessions
.insert(channel_name.to_string(), Arc::new(Mutex::new(session)));
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
Ok(())
}
pub(crate) async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
self.inner.lock().await.sessions.get(channel_name).cloned()
}
pub(crate) async fn touch(&self, channel_name: &str) {
self.inner
.lock()
.await
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
let mut inner = self.inner.lock().await;
let now = Instant::now();
let expired_channels: Vec<String> = inner
.session_timestamps
.iter()
.filter_map(|(channel_name, last_active)| {
if now.duration_since(*last_active) > inner.session_ttl {
Some(channel_name.clone())
} else {
None
}
})
.collect();
for channel_name in &expired_channels {
inner.sessions.remove(channel_name);
inner.session_timestamps.remove(channel_name);
}
expired_channels.len()
}
}

View File

@ -1,63 +0,0 @@
use std::collections::HashSet;
use std::sync::Arc;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillActivateTool, SkillListTool,
SkillManageTool, TimeTool, ToolRegistry, WebFetchTool,
};
pub(crate) struct ToolRegistryFactory {
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
known_agents: HashSet<String>,
default_timezone: String,
}
impl ToolRegistryFactory {
pub(crate) fn new(
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
known_agents: HashSet<String>,
default_timezone: String,
) -> Self {
Self {
skills,
store,
known_agents,
default_timezone,
}
}
pub(crate) fn build(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(TimeTool::new(self.default_timezone.clone()));
registry.register(FileReadTool::new());
registry.register(FileWriteTool::new());
registry.register(FileEditTool::new());
registry.register(MemorySearchTool::new(self.store.clone()));
registry.register(MemoryManageTool::new(self.store.clone()));
registry.register(SchedulerManageTool::new(
self.store.clone(),
self.known_agents.clone(),
));
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.store.clone(),
));
registry.register(SkillListTool::new(self.skills.clone()));
registry.register(SkillManageTool::new(self.skills.clone()));
registry.register(BashTool::new());
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
registry.register(WebFetchTool::new(50_000, 30));
registry
}
}

View File

@ -1,16 +1,36 @@
use super::GatewayState;
use crate::agent::AgentError;
use crate::bus::InboundMessage;
use super::{
GatewayState,
session::{Session, handle_in_chat_command, schedule_background_history_compaction},
};
use crate::agent::EmittedMessageHandler;
use crate::bus::ChatMessage;
use crate::bus::message::{ToolMessageState, format_tool_call_content};
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
use async_trait::async_trait;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::{Mutex, mpsc};
const CLI_CHANNEL_NAME: &str = "cli";
struct WsToolCallEmitter {
sender: mpsc::Sender<WsOutbound>,
show_tool_results: bool,
}
#[async_trait]
impl EmittedMessageHandler for WsToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
if !should_display_message_to_user(self.show_tool_results, &message) {
return;
}
for outbound in ws_outbound_from_chat_message(&message) {
let _ = self.sender.send(outbound).await;
}
}
}
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async {
@ -21,8 +41,15 @@ pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewaySta
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let cli_sessions = state.session_manager.cli_sessions();
let initial_record = match cli_sessions.create(None) {
let provider_config = match state.config.get_provider_config("default") {
Ok(cfg) => cfg,
Err(e) => {
tracing::error!(error = %e, "Failed to get provider config");
return;
}
};
let initial_record = match state.session_manager.create_cli_session(None) {
Ok(record) => record,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session");
@ -30,20 +57,39 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
}
};
let runtime_session_id = uuid::Uuid::new_v4().to_string();
let channel_name = "cli".to_string();
// 创建 CLI session
let session = match Session::new(
channel_name.clone(),
provider_config,
sender,
state.session_manager.tools(),
state.session_manager.skills(),
state.session_manager.store(),
state.config.gateway.agent_prompt_reinject_every,
)
.await
{
Ok(s) => Arc::new(Mutex::new(s)),
Err(e) => {
tracing::error!(error = %e, "Failed to create session");
return;
}
};
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
return;
}
let runtime_session_id = session.lock().await.id.to_string();
let mut current_session_id = initial_record.id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
current_session_id.clone(),
runtime_session_id.clone(),
sender.clone(),
)
.await;
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
let _ = sender
let _ = session
.lock()
.await
.send(WsOutbound::SessionEstablished {
session_id: current_session_id.clone(),
})
@ -73,7 +119,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
Ok(inbound) => {
if let Err(e) = handle_inbound(
&state,
&sender,
&session,
&runtime_session_id,
&mut current_session_id,
inbound,
@ -81,7 +127,9 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
.await
{
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = sender
let _ = session
.lock()
.await
.send(WsOutbound::Error {
code: "SESSION_ERROR".to_string(),
message: e.to_string(),
@ -91,7 +139,9 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = sender
let _ = session
.lock()
.await
.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
@ -109,11 +159,6 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
}
}
state
.channel_manager
.cli_channel()
.unregister_connection(&runtime_session_id)
.await;
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
}
@ -129,9 +174,79 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
}
}
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
match message.role.as_str() {
"assistant" => {
if let Some(tool_calls) = &message.tool_calls {
let mut outbound = Vec::new();
if !message.content.trim().is_empty() {
outbound.push(WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
});
}
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
id: message.id.clone(),
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
role: message.role.clone(),
}));
outbound
} else {
vec![WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
}]
}
}
"tool" => match message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed)
{
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
}],
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
}],
},
_ => Vec::new(),
}
}
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
if message.role != "tool" {
return true;
}
show_tool_results
|| matches!(
message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed),
ToolMessageState::PendingUserAction
)
}
async fn handle_inbound(
state: &Arc<GatewayState>,
sender: &mpsc::Sender<WsOutbound>,
session: &Arc<Mutex<Session>>,
runtime_session_id: &str,
current_session_id: &mut String,
inbound: WsInbound,
@ -145,31 +260,84 @@ async fn handle_inbound(
} => {
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
let (history, agent, user_tx) = {
let mut session_guard = session.lock().await;
state
.channel_manager
.cli_channel()
.register_connection(
chat_id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
session_guard.ensure_persistent_session(&chat_id)?;
session_guard.ensure_chat_loaded(&chat_id)?;
state
.bus
.publish_inbound(InboundMessage {
channel: CLI_CHANNEL_NAME.to_string(),
sender_id,
chat_id,
content,
timestamp: current_timestamp(),
media: Vec::new(),
metadata: HashMap::new(),
forwarded_metadata: HashMap::new(),
})
.await
.map_err(|error| AgentError::Other(error.to_string()))?;
if let Some(command_response) =
handle_in_chat_command(&mut session_guard, &chat_id, &content)?
{
let _ = session_guard
.send(WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
content: command_response,
role: "assistant".to_string(),
})
.await;
return Ok(());
}
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?;
let history = session_guard.get_or_create_history(&chat_id).clone();
session_guard.record_skill_offer(&chat_id)?;
let live_emitter = Arc::new(WsToolCallEmitter {
sender: session_guard.user_tx.clone(),
show_tool_results: state.config.gateway.show_tool_results,
});
let agent = session_guard
.create_agent(&chat_id, Some(&sender_id), Some(&user_message_id))?
.with_emitted_message_handler(live_emitter);
(history, agent, session_guard.user_tx.clone())
};
match agent.process(history).await {
Ok(result) => {
let mut session_guard = session.lock().await;
session_guard
.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
for outbound in result
.emitted_messages
.iter()
.filter(|message| {
!message.is_assistant_tool_call_message()
&& should_display_message_to_user(
state.config.gateway.show_tool_results,
message,
)
})
.flat_map(ws_outbound_from_chat_message)
{
let _ = session_guard.send(outbound).await;
}
drop(session_guard);
if let Err(error) =
schedule_background_history_compaction(session.clone(), chat_id.clone())
.await
{
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
}
}
Err(error) => {
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
let _ = user_tx
.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: error.to_string(),
})
.await;
}
}
Ok(())
}
@ -180,37 +348,22 @@ async fn handle_inbound(
let target = session_id
.or(chat_id)
.unwrap_or_else(|| current_session_id.clone());
state
.session_manager
.cli_sessions()
.clear_messages(&target)?;
state.session_manager.clear_session_messages(&target)?;
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
session.lock().await.remove_history(&target);
}
let _ = sender
let mut session_guard = session.lock().await;
session_guard.remove_history(&target);
let _ = session_guard
.send(WsOutbound::HistoryCleared { session_id: target })
.await;
Ok(())
}
WsInbound::CreateSession { title } => {
let record = state
.session_manager
.cli_sessions()
.create(title.as_deref())?;
let record = state.session_manager.create_cli_session(title.as_deref())?;
*current_session_id = record.id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
record.id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
let _ = sender
let mut session_guard = session.lock().await;
session_guard.ensure_chat_loaded(&record.id)?;
let _ = session_guard
.send(WsOutbound::SessionCreated {
session_id: record.id,
title: record.title,
@ -219,13 +372,11 @@ async fn handle_inbound(
Ok(())
}
WsInbound::ListSessions { include_archived } => {
let records = state
.session_manager
.cli_sessions()
.list(include_archived)?;
let records = state.session_manager.list_cli_sessions(include_archived)?;
let summaries = records.into_iter().map(to_session_summary).collect();
let _ = sender
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionList {
sessions: summaries,
current_session_id: Some(current_session_id.clone()),
@ -234,8 +385,9 @@ async fn handle_inbound(
Ok(())
}
WsInbound::LoadSession { session_id } => {
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
let _ = sender
let Some(record) = state.session_manager.get_session_record(&session_id)? else {
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::Error {
code: "SESSION_NOT_FOUND".to_string(),
message: format!("Session not found: {}", session_id),
@ -245,16 +397,9 @@ async fn handle_inbound(
};
*current_session_id = record.id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
record.id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
let _ = sender
let mut session_guard = session.lock().await;
session_guard.ensure_chat_loaded(&record.id)?;
let _ = session_guard
.send(WsOutbound::SessionLoaded {
session_id: record.id,
title: record.title,
@ -265,11 +410,9 @@ async fn handle_inbound(
}
WsInbound::RenameSession { session_id, title } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state
.session_manager
.cli_sessions()
.rename(&target, &title)?;
let _ = sender
state.session_manager.rename_session(&target, &title)?;
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionRenamed {
session_id: target,
title,
@ -279,27 +422,26 @@ async fn handle_inbound(
}
WsInbound::ArchiveSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.cli_sessions().archive(&target)?;
let _ = sender
state.session_manager.archive_session(&target)?;
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionArchived { session_id: target })
.await;
Ok(())
}
WsInbound::DeleteSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.cli_sessions().delete(&target)?;
state.session_manager.delete_session(&target)?;
let replacement = if target == *current_session_id {
Some(state.session_manager.cli_sessions().create(None)?)
Some(state.session_manager.create_cli_session(None)?)
} else {
None
};
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
session.lock().await.remove_history(&target);
}
let _ = sender
let mut session_guard = session.lock().await;
session_guard.remove_history(&target);
let _ = session_guard
.send(WsOutbound::SessionDeleted {
session_id: target.clone(),
})
@ -307,16 +449,8 @@ async fn handle_inbound(
if let Some(record) = replacement {
*current_session_id = record.id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
record.id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
let _ = sender
session_guard.ensure_chat_loaded(&record.id)?;
let _ = session_guard
.send(WsOutbound::SessionCreated {
session_id: record.id,
title: record.title,
@ -327,19 +461,13 @@ async fn handle_inbound(
Ok(())
}
WsInbound::Ping => {
let _ = sender.send(WsOutbound::Pong).await;
let session_guard = session.lock().await;
let _ = session_guard.send(WsOutbound::Pong).await;
Ok(())
}
}
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
sender_id
.map(str::trim)
@ -350,7 +478,106 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
#[cfg(test)]
mod tests {
use super::resolve_ws_sender_id;
use super::{
WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
ws_outbound_from_chat_message,
};
use crate::agent::EmittedMessageHandler;
use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::protocol::WsOutbound;
use crate::providers::ToolCall;
use serde_json::json;
use tokio::sync::mpsc;
#[test]
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
let message = ChatMessage::assistant_with_tool_calls(
"",
vec![ToolCall {
id: "call-1".to_string(),
name: "calculator".to_string(),
arguments: json!({"expression": "1 + 1"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall {
tool_call_id,
tool_name,
arguments,
content,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
}
#[test]
fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
let message = ChatMessage::assistant_with_tool_calls(
"日报已整理完成。",
vec![ToolCall {
id: "call-1".to_string(),
name: "memory_manage".to_string(),
arguments: json!({"action": "put"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 2);
assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_results() {
let message = ChatMessage::tool("call-1", "calculator", "2");
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
let message = ChatMessage::tool_with_state(
"call-1",
"bash",
"等待你完成授权后再继续。",
ToolMessageState::PendingUserAction,
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
}
#[test]
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
let completed = ChatMessage::tool("call-1", "calculator", "2");
let pending = ChatMessage::tool_with_state(
"call-2",
"bash",
"waiting",
ToolMessageState::PendingUserAction,
);
assert!(!should_display_message_to_user(false, &completed));
assert!(should_display_message_to_user(false, &pending));
assert!(should_display_message_to_user(true, &completed));
}
#[test]
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
@ -369,4 +596,23 @@ mod tests {
assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
}
#[tokio::test]
async fn test_ws_tool_call_emitter_hides_completed_tool_results_when_disabled() {
let (sender, mut receiver) = mpsc::channel(4);
let emitter = WsToolCallEmitter {
sender,
show_tool_results: false,
};
emitter
.handle(ChatMessage::tool("call-1", "calculator", "2"))
.await;
assert!(
tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
.await
.is_err()
);
}
}

View File

@ -4,7 +4,6 @@ pub mod channels;
pub mod cli;
pub mod client;
pub mod config;
pub mod domain;
pub mod gateway;
pub mod logging;
pub mod observability;

View File

@ -1,5 +1,3 @@
pub mod ws_adapter;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@ -1,230 +0,0 @@
#[cfg(test)]
use crate::bus::ChatMessage;
use crate::bus::OutboundMessage;
use crate::bus::message::OutboundEventKind;
#[cfg(test)]
use crate::bus::message::{ToolMessageState, format_tool_call_content};
use super::WsOutbound;
const TOOL_PENDING_RESUME_HINT: &str = "完成外部操作后,直接发一条继续消息即可。";
#[cfg(test)]
pub(crate) fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
match message.role.as_str() {
"assistant" => {
if let Some(tool_calls) = &message.tool_calls {
let mut outbound = Vec::new();
if !message.content.trim().is_empty() {
outbound.push(WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
});
}
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
id: message.id.clone(),
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
role: message.role.clone(),
}));
outbound
} else {
vec![WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
}]
}
}
"tool" => match message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed)
{
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
}],
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
}],
},
_ => Vec::new(),
}
}
pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Vec<WsOutbound> {
match message.event_kind {
OutboundEventKind::AssistantResponse | OutboundEventKind::SchedulerNotification => {
vec![WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
content: message.content.clone(),
role: message.role.clone(),
}]
}
OutboundEventKind::ToolCall => vec![WsOutbound::ToolCall {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
arguments: message
.tool_arguments
.clone()
.unwrap_or(serde_json::Value::Null),
content: message.content.clone(),
role: message.role.clone(),
}],
OutboundEventKind::ToolResult => vec![WsOutbound::ToolResult {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
}],
OutboundEventKind::ToolPending => vec![WsOutbound::ToolPending {
id: message
.tool_call_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
resume_hint: TOOL_PENDING_RESUME_HINT.to_string(),
}],
OutboundEventKind::ErrorNotification => vec![WsOutbound::Error {
code: "AGENT_ERROR".to_string(),
message: message.content.clone(),
}],
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::messages::ToolCall;
use serde_json::json;
#[test]
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
let message = ChatMessage::assistant_with_tool_calls(
"",
vec![ToolCall {
id: "call-1".to_string(),
name: "calculator".to_string(),
arguments: json!({"expression": "1 + 1"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall {
tool_call_id,
tool_name,
arguments,
content,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
}
#[test]
fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
let message = ChatMessage::assistant_with_tool_calls(
"日报已整理完成。",
vec![ToolCall {
id: "call-1".to_string(),
name: "memory_manage".to_string(),
arguments: json!({"action": "put"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 2);
assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_results() {
let message = ChatMessage::tool("call-1", "calculator", "2");
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
let message = ChatMessage::tool_with_state(
"call-1",
"bash",
"等待你完成授权后再继续。",
ToolMessageState::PendingUserAction,
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
}
#[test]
fn test_ws_outbound_from_outbound_message_maps_tool_call() {
let message = OutboundMessage::tool_call(
"cli",
"session-1",
"call-1",
"calculator",
json!({"expression": "1 + 1"}),
None,
Default::default(),
);
let outbound = ws_outbound_from_outbound_message(&message);
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall {
tool_call_id,
tool_name,
arguments,
content,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
}
}

View File

@ -6,7 +6,7 @@ use std::time::Duration;
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::domain::messages::ContentBlock;
use crate::bus::message::ContentBlock;
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
let mut details = vec![error.to_string()];

View File

@ -6,9 +6,10 @@ pub use self::anthropic::AnthropicProvider;
pub use self::openai::OpenAIProvider;
use crate::config::LLMProviderConfig;
pub use crate::domain::messages::ToolCall;
pub use crate::domain::tools::{Tool, ToolFunction};
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Usage};
pub use traits::{
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
ToolFunction, Usage,
};
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
match config.provider_type.as_str() {

View File

@ -7,7 +7,7 @@ use std::time::Duration;
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::domain::messages::ContentBlock;
use crate::bus::message::ContentBlock;
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
@ -23,23 +23,6 @@ fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
details.join("\ncaused by: ")
}
fn format_transport_error_context(
provider_name: &str,
model_id: &str,
url: &str,
timeout_secs: u64,
error: &(dyn std::error::Error + 'static),
) -> String {
format!(
"transport error: provider={} model={} url={} timeout_secs={} details={}",
provider_name,
model_id,
url,
timeout_secs,
format_error_chain(error)
)
}
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1 {
if let ContentBlock::Text { text } = &blocks[0] {
@ -311,25 +294,7 @@ impl LLMProvider for OpenAIProvider {
req_builder = req_builder.header(key.as_str(), value.as_str());
}
let resp = req_builder.json(&body).send().await.map_err(|err| {
let error_context = format_transport_error_context(
&self.name,
&self.model_id,
&url,
self.llm_timeout_secs,
&err,
);
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
base_url = %self.base_url,
timeout_secs = self.llm_timeout_secs,
error = %error_context,
"OpenAI-compatible API transport request failed"
);
error_context
})?;
let resp = req_builder.json(&body).send().await?;
let status = resp.status();
let text = resp.text().await?;

View File

@ -1,5 +1,4 @@
use crate::domain::messages::{ContentBlock, ToolCall};
use crate::domain::tools::Tool;
use crate::bus::message::ContentBlock;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
@ -78,6 +77,27 @@ impl Message {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub messages: Vec<Message>,

View File

@ -2,7 +2,6 @@ use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Duration as ChronoDuration, TimeZone, Utc};
use chrono_tz::Tz;
use tokio::sync::watch;
@ -12,81 +11,34 @@ use crate::config::{
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget,
SchedulerMisfirePolicy, SchedulerSchedule,
};
use crate::gateway::session::ScheduledAgentTaskOptions;
use crate::gateway::session::SessionManager;
use crate::storage::{
SchedulerJobRecord, SchedulerJobRepository, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert,
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore,
};
#[derive(Debug, Clone, Default)]
pub struct ScheduledAgentTaskOptions {
pub sender_id: Option<String>,
pub fresh_session: bool,
pub system_prompt: Option<String>,
pub metadata: HashMap<String, String>,
pub agent: Option<String>,
}
#[derive(Debug, Clone)]
pub struct MaintenanceRunSummary {
pub scope_key: String,
pub user_facts: usize,
pub preferences: usize,
pub behavior_patterns: usize,
pub merges: usize,
pub conflicts: usize,
pub low_value: usize,
}
#[async_trait]
pub trait AgentTaskExecutor: Send + Sync {
async fn execute(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> anyhow::Result<Vec<OutboundMessage>>;
}
#[async_trait]
pub trait MaintenanceExecutor: Send + Sync {
async fn cleanup_expired_sessions(&self) -> usize;
async fn run_memory_maintenance_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> anyhow::Result<Vec<MaintenanceRunSummary>>;
}
pub struct Scheduler {
bus: Arc<MessageBus>,
config: SchedulerConfig,
timezone: Tz,
jobs: Arc<dyn SchedulerJobRepository>,
agent_task_executor: Arc<dyn AgentTaskExecutor>,
maintenance_executor: Arc<dyn MaintenanceExecutor>,
store: Arc<SessionStore>,
session_manager: SessionManager,
}
impl Scheduler {
pub fn new<A, M>(
pub fn new(
bus: Arc<MessageBus>,
config: SchedulerConfig,
timezone: Tz,
jobs: Arc<dyn SchedulerJobRepository>,
agent_task_executor: A,
maintenance_executor: M,
) -> Self
where
A: AgentTaskExecutor + 'static,
M: MaintenanceExecutor + 'static,
{
store: Arc<SessionStore>,
session_manager: SessionManager,
) -> Self {
Self {
bus,
config,
timezone,
jobs,
agent_task_executor: Arc::new(agent_task_executor),
maintenance_executor: Arc::new(maintenance_executor),
store,
session_manager,
}
}
@ -129,14 +81,14 @@ impl Scheduler {
}) {
let runtime =
RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?;
self.jobs.upsert_scheduler_job(&runtime.to_upsert())?;
self.store.upsert_scheduler_job(&runtime.to_upsert())?;
}
Ok(())
}
async fn process_tick(&self) -> anyhow::Result<()> {
let now = Utc::now();
let jobs = self.jobs.list_scheduler_jobs(true)?;
let jobs = self.store.list_scheduler_jobs(true)?;
for record in jobs {
let Some(mut job) =
@ -146,7 +98,7 @@ impl Scheduler {
};
if record.next_fire_at.is_none() && job.next_fire_at.is_some() {
self.jobs.update_scheduler_job_runtime(
self.store.update_scheduler_job_runtime(
&job.id,
job.state.clone(),
job.last_status.clone(),
@ -163,7 +115,7 @@ impl Scheduler {
continue;
}
self.jobs.update_scheduler_job_runtime(
self.store.update_scheduler_job_runtime(
&job.id,
SchedulerJobState::Running,
job.last_status.clone(),
@ -193,7 +145,7 @@ impl Scheduler {
tracing::error!(job_id = %job.id, error = %error, "Scheduler job failed");
}
self.jobs.update_scheduler_job_runtime(
self.store.update_scheduler_job_runtime(
&job.id,
job.state.clone(),
status,
@ -216,11 +168,11 @@ impl Scheduler {
self.bus.publish_outbound(message).await?;
}
SchedulerJobKind::InternalEvent => {
execute_internal_event(self.maintenance_executor.as_ref(), job).await?;
execute_internal_event(&self.session_manager, job).await?;
}
SchedulerJobKind::AgentTask => {
let outbound_messages = execute_agent_task(
self.agent_task_executor.as_ref(),
&self.session_manager,
job,
required_notification_chat_id(job, "agent_task")?,
)
@ -232,8 +184,7 @@ impl Scheduler {
SchedulerJobKind::SilentAgentTask => {
let execution_chat_id = resolve_execution_chat_id(job)?;
if let Err(error) =
execute_agent_task(self.agent_task_executor.as_ref(), job, &execution_chat_id)
.await
execute_agent_task(&self.session_manager, job, &execution_chat_id).await
{
if let Err(notify_error) =
self.notify_silent_agent_task_failure(job, &error).await
@ -636,7 +587,7 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
}
async fn execute_internal_event(
maintenance_executor: &dyn MaintenanceExecutor,
session_manager: &SessionManager,
job: &RuntimeJob,
) -> anyhow::Result<()> {
let event = job
@ -647,24 +598,24 @@ async fn execute_internal_event(
match event {
"session_cleanup" => {
let removed = maintenance_executor.cleanup_expired_sessions().await;
let removed = session_manager.cleanup_expired_sessions().await;
tracing::info!(job_id = %job.id, removed, "Scheduler session cleanup completed");
Ok(())
}
"memory_maintenance" => {
let results = maintenance_executor
let results = session_manager
.run_memory_maintenance_for_all_scopes(job.last_fired_at)
.await?;
for result in &results {
tracing::info!(
job_id = %job.id,
scope_key = %result.scope_key,
user_facts = result.user_facts,
preferences = result.preferences,
behavior_patterns = result.behavior_patterns,
merges = result.merges,
conflicts = result.conflicts,
low_value = result.low_value,
user_facts = result.output.user_facts.len(),
preferences = result.output.preferences.len(),
behavior_patterns = result.output.behavior_patterns.len(),
merges = result.output.merges.len(),
conflicts = result.output.conflicts.len(),
low_value = result.output.low_value_ids.len(),
"Scheduler completed memory maintenance model run"
);
}
@ -676,7 +627,7 @@ async fn execute_internal_event(
}
async fn execute_agent_task(
agent_task_executor: &dyn AgentTaskExecutor,
session_manager: &SessionManager,
job: &RuntimeJob,
execution_chat_id: &str,
) -> anyhow::Result<Vec<OutboundMessage>> {
@ -692,9 +643,10 @@ async fn execute_agent_task(
.ok_or_else(|| anyhow::anyhow!("agent_task payload.prompt must be a string"))?;
let options = parse_scheduled_agent_task_options(job)?;
agent_task_executor
.execute(channel_name, execution_chat_id, prompt, options)
session_manager
.run_scheduled_agent_task(channel_name, execution_chat_id, prompt, options)
.await
.map_err(|error| anyhow::anyhow!(error.to_string()))
}
fn required_notification_chat_id<'a>(
@ -1011,44 +963,43 @@ impl TryFrom<serde_json::Value> for SchedulerJobTarget {
mod tests {
use super::*;
use crate::bus::MessageBus;
use crate::config::BUILTIN_MEMORY_MAINTENANCE_JOB_ID;
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig};
use crate::gateway::session::SessionManager;
use crate::skills::SkillRuntime;
use crate::storage::{SchedulerJobUpsert, SessionStore};
use std::collections::HashMap;
#[derive(Clone)]
struct TestAgentTaskExecutor;
#[async_trait::async_trait]
impl AgentTaskExecutor for TestAgentTaskExecutor {
async fn execute(
&self,
_channel_name: &str,
_chat_id: &str,
_prompt: &str,
_options: ScheduledAgentTaskOptions,
) -> anyhow::Result<Vec<OutboundMessage>> {
Ok(Vec::new())
fn test_provider_config() -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: "default".to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 30,
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: None,
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 4,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[derive(Clone)]
struct TestMaintenanceExecutor;
#[async_trait::async_trait]
impl MaintenanceExecutor for TestMaintenanceExecutor {
async fn cleanup_expired_sessions(&self) -> usize {
0
}
async fn run_memory_maintenance_for_all_scopes(
&self,
_updated_since: Option<i64>,
) -> anyhow::Result<Vec<MaintenanceRunSummary>> {
Ok(Vec::new())
}
}
fn test_scheduler_services() -> (TestAgentTaskExecutor, TestMaintenanceExecutor) {
(TestAgentTaskExecutor, TestMaintenanceExecutor)
fn test_session_manager() -> SessionManager {
let provider_config = test_provider_config();
SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
provider_config.clone(),
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
)
.unwrap()
}
#[test]
@ -1178,7 +1129,7 @@ mod tests {
})
.unwrap();
let (agent_task_executor, maintenance_service) = test_scheduler_services();
let session_manager = test_session_manager();
let scheduler = Scheduler::new(
MessageBus::new(8),
SchedulerConfig {
@ -1190,8 +1141,7 @@ mod tests {
},
chrono_tz::Asia::Shanghai,
store.clone(),
agent_task_executor,
maintenance_service,
session_manager,
);
scheduler.process_tick().await.unwrap();
@ -1209,14 +1159,13 @@ mod tests {
fn sync_config_jobs_persists_builtin_memory_maintenance_job() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (agent_task_executor, maintenance_service) = test_scheduler_services();
let session_manager = test_session_manager();
let scheduler = Scheduler::new(
MessageBus::new(8),
SchedulerConfig::default(),
chrono_tz::Asia::Shanghai,
store.clone(),
agent_task_executor,
maintenance_service,
session_manager,
);
scheduler.sync_config_jobs().unwrap();
@ -1255,7 +1204,6 @@ mod tests {
async fn silent_agent_task_failure_notifies_primary_chat() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let bus = MessageBus::new(8);
let (agent_task_executor, maintenance_service) = test_scheduler_services();
let scheduler = Scheduler::new(
bus.clone(),
SchedulerConfig {
@ -1267,8 +1215,7 @@ mod tests {
},
chrono_tz::Asia::Shanghai,
store,
agent_task_executor,
maintenance_service,
test_session_manager(),
);
let job = RuntimeJob {

View File

@ -6,6 +6,7 @@ use std::path::{Path, PathBuf};
use std::sync::RwLock;
use crate::config::SkillsConfig;
use crate::providers::{Tool, ToolFunction};
#[derive(Debug, Clone)]
pub struct Skill {
@ -119,6 +120,13 @@ impl SkillRuntime {
.offered_event_payload()
}
pub fn skill_tool_definition(&self) -> Option<Tool> {
self.catalog
.read()
.expect("skills rwlock poisoned")
.skill_tool_definition()
}
pub fn activation_payload(&self, name: &str) -> Result<String, String> {
self.catalog
.read()
@ -222,12 +230,6 @@ impl SkillRuntime {
}
}
impl crate::agent::SkillProvider for SkillRuntime {
fn system_index_prompt(&self) -> Option<String> {
SkillRuntime::system_index_prompt(self)
}
}
impl SkillSource {
fn as_str(&self) -> &'static str {
match self {
@ -342,6 +344,30 @@ impl SkillCatalog {
self.catalog_event_payload()
}
pub fn skill_tool_definition(&self) -> Option<Tool> {
if self.skills.is_empty() {
return None;
}
Some(Tool {
tool_type: "function".to_string(),
function: ToolFunction {
name: "skill_activate".to_string(),
description: "Load detailed instructions for a named skill discovered from SKILL.md files. Use when a task matches a listed skill description.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Skill name from the available skills list"
}
},
"required": ["name"]
}),
},
})
}
pub fn activation_payload(&self, name: &str) -> Result<String, String> {
let skill = self
.find_skill(name)
@ -653,6 +679,31 @@ mod tests {
assert!(err.contains("description"));
}
#[test]
fn test_skill_tool_definition_exists_when_skills_present() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path().join(".picobot").join("skills").join("demo");
fs::create_dir_all(&root).unwrap();
fs::write(
root.join("SKILL.md"),
"---\ndescription: demo skill\n---\nDo demo",
)
.unwrap();
let skills = load_skills_from_root(
&dir.path().join(".picobot").join("skills"),
SkillSource::Project,
);
let catalog = SkillCatalog {
skills,
max_index_chars: 4000,
max_listed_skills: 10,
};
let tool = catalog.skill_tool_definition().unwrap();
assert_eq!(tool.function.name, "skill_activate");
}
#[test]
fn test_activation_payload_contains_body() {
let dir = tempfile::tempdir().unwrap();

View File

@ -1,9 +0,0 @@
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
}

View File

@ -3,28 +3,193 @@ use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use rusqlite::{Connection, OptionalExtension, params};
use serde::{Deserialize, Serialize};
use crate::bus::ChatMessage;
pub mod error;
pub mod ports;
pub mod records;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillEventRecord {
pub id: String,
pub session_id: Option<String>,
pub event_type: String,
pub skill_name: Option<String>,
pub payload: serde_json::Value,
pub created_at: i64,
}
pub use error::StorageError;
pub use ports::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
SkillEventRepository,
};
pub use records::{
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert, SessionRecord, SkillEventRecord,
};
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRecord {
pub id: String,
pub title: String,
pub channel_name: String,
pub chat_id: String,
pub summary: Option<String>,
pub created_at: i64,
pub updated_at: i64,
pub last_active_at: i64,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>,
pub message_count: i64,
pub reset_cutoff_seq: i64,
pub user_turn_count: i64,
pub agent_prompt_reinjection_count: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRecord {
pub id: String,
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct MemoryUpsert {
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobState {
Scheduled,
Running,
Paused,
Completed,
}
impl SchedulerJobState {
pub fn as_str(&self) -> &'static str {
match self {
SchedulerJobState::Scheduled => "scheduled",
SchedulerJobState::Running => "running",
SchedulerJobState::Paused => "paused",
SchedulerJobState::Completed => "completed",
}
}
pub fn from_str(value: &str) -> Option<Self> {
match value {
"scheduled" => Some(Self::Scheduled),
"running" => Some(Self::Running),
"paused" => Some(Self::Paused),
"completed" => Some(Self::Completed),
_ => None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobStatus {
Ok,
Error,
Skipped,
}
impl SchedulerJobStatus {
pub fn as_str(&self) -> &'static str {
match self {
SchedulerJobStatus::Ok => "ok",
SchedulerJobStatus::Error => "error",
SchedulerJobStatus::Skipped => "skipped",
}
}
pub fn from_str(value: &str) -> Option<Self> {
match value {
"ok" => Some(Self::Ok),
"error" => Some(Self::Error),
"skipped" => Some(Self::Skipped),
_ => None,
}
}
}
impl Default for SchedulerJobState {
fn default() -> Self {
Self::Scheduled
}
}
#[derive(Clone)]
pub struct SessionStore {
conn: Arc<Mutex<Connection>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerJobRecord {
pub id: String,
pub kind: String,
pub schedule: serde_json::Value,
pub interval_secs: i64,
pub startup_delay_secs: i64,
pub target: serde_json::Value,
pub payload: serde_json::Value,
pub enabled: bool,
pub state: SchedulerJobState,
pub last_status: Option<SchedulerJobStatus>,
pub last_error: Option<String>,
pub run_count: i64,
pub max_runs: Option<i64>,
pub last_fired_at: Option<i64>,
pub next_fire_at: Option<i64>,
pub paused_at: Option<i64>,
pub completed_at: Option<i64>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct SchedulerJobUpsert {
pub id: String,
pub kind: String,
pub schedule: serde_json::Value,
pub interval_secs: i64,
pub startup_delay_secs: i64,
pub target: serde_json::Value,
pub payload: serde_json::Value,
pub enabled: bool,
pub state: SchedulerJobState,
pub last_status: Option<SchedulerJobStatus>,
pub last_error: Option<String>,
pub run_count: i64,
pub max_runs: Option<i64>,
pub last_fired_at: Option<i64>,
pub next_fire_at: Option<i64>,
pub paused_at: Option<i64>,
pub completed_at: Option<i64>,
}
impl SessionStore {
#[cfg(test)]
pub fn new() -> Result<Self, StorageError> {
@ -1637,7 +1802,7 @@ fn quote_fts_or_query(queries: &[String]) -> String {
mod tests {
use super::*;
use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT;
use crate::domain::messages::ToolCall;
use crate::providers::ToolCall;
#[test]
fn test_persistent_session_id_for_cli_and_channel() {

View File

@ -1,304 +0,0 @@
use super::{
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
SchedulerJobUpsert, SessionRecord, SkillEventRecord, StorageError,
};
use crate::bus::ChatMessage;
pub trait ConversationRepository: Send + Sync + 'static {
fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError>;
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError>;
fn reset_session(&self, session_id: &str) -> Result<(), StorageError>;
fn compact_active_history(
&self,
session_id: &str,
expected_reset_cutoff_seq: i64,
snapshot_end_seq: i64,
preserved_system_messages: &[ChatMessage],
summary_message: &ChatMessage,
preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError>;
}
pub trait PromptInjectionRepository: Send + Sync + 'static {
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError>;
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError>;
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError>;
}
pub trait MemoryRepository: Send + Sync + 'static {
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError>;
fn update_memory(&self, input: &MemoryUpsert) -> Result<Option<MemoryRecord>, StorageError>;
fn delete_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<bool, StorageError>;
fn get_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<Option<MemoryRecord>, StorageError>;
fn list_memories(
&self,
scope_kind: &str,
scope_key: &str,
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError>;
fn search_memories_any(
&self,
scope_kind: &str,
scope_key: &str,
queries: &[String],
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError>;
}
pub trait SchedulerJobRepository: Send + Sync + 'static {
fn upsert_scheduler_job(
&self,
input: &SchedulerJobUpsert,
) -> Result<SchedulerJobRecord, StorageError>;
fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError>;
fn list_scheduler_jobs(
&self,
enabled_only: bool,
) -> Result<Vec<SchedulerJobRecord>, StorageError>;
fn delete_scheduler_job(&self, job_id: &str) -> Result<(), StorageError>;
fn update_scheduler_job_runtime(
&self,
job_id: &str,
state: SchedulerJobState,
last_status: Option<SchedulerJobStatus>,
last_error: Option<&str>,
run_count: i64,
last_fired_at: Option<i64>,
next_fire_at: Option<i64>,
paused_at: Option<i64>,
completed_at: Option<i64>,
) -> Result<(), StorageError>;
}
pub trait SkillEventRepository: Send + Sync + 'static {
fn append_skill_event(
&self,
session_id: Option<&str>,
event_type: &str,
skill_name: Option<&str>,
payload: &serde_json::Value,
) -> Result<(), StorageError>;
fn list_skill_events(
&self,
session_id: Option<&str>,
) -> Result<Vec<SkillEventRecord>, StorageError>;
}
impl ConversationRepository for super::SessionStore {
fn ensure_channel_session(
&self,
channel_name: &str,
chat_id: &str,
) -> Result<SessionRecord, StorageError> {
super::SessionStore::ensure_channel_session(self, channel_name, chat_id)
}
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
super::SessionStore::load_messages(self, session_id)
}
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
super::SessionStore::append_message(self, session_id, message)
}
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::clear_messages(self, session_id)
}
fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::reset_session(self, session_id)
}
fn compact_active_history(
&self,
session_id: &str,
expected_reset_cutoff_seq: i64,
snapshot_end_seq: i64,
preserved_system_messages: &[ChatMessage],
summary_message: &ChatMessage,
preserved_messages: &[ChatMessage],
) -> Result<bool, StorageError> {
super::SessionStore::compact_active_history(
self,
session_id,
expected_reset_cutoff_seq,
snapshot_end_seq,
preserved_system_messages,
summary_message,
preserved_messages,
)
}
}
impl PromptInjectionRepository for super::SessionStore {
fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
super::SessionStore::get_session(self, session_id)
}
fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
super::SessionStore::count_active_user_messages(self, session_id)
}
fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
super::SessionStore::mark_agent_prompt_reinjected(self, session_id)
}
}
impl MemoryRepository for super::SessionStore {
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
super::SessionStore::put_memory(self, input)
}
fn update_memory(&self, input: &MemoryUpsert) -> Result<Option<MemoryRecord>, StorageError> {
super::SessionStore::update_memory(self, input)
}
fn delete_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<bool, StorageError> {
super::SessionStore::delete_memory(self, scope_kind, scope_key, namespace, memory_key)
}
fn get_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<Option<MemoryRecord>, StorageError> {
super::SessionStore::get_memory(self, scope_kind, scope_key, namespace, memory_key)
}
fn list_memories(
&self,
scope_kind: &str,
scope_key: &str,
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError> {
super::SessionStore::list_memories(self, scope_kind, scope_key, namespace, limit)
}
fn search_memories_any(
&self,
scope_kind: &str,
scope_key: &str,
queries: &[String],
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError> {
super::SessionStore::search_memories_any(
self, scope_kind, scope_key, queries, namespace, limit,
)
}
}
impl SchedulerJobRepository for super::SessionStore {
fn upsert_scheduler_job(
&self,
input: &SchedulerJobUpsert,
) -> Result<SchedulerJobRecord, StorageError> {
super::SessionStore::upsert_scheduler_job(self, input)
}
fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError> {
super::SessionStore::get_scheduler_job(self, job_id)
}
fn list_scheduler_jobs(
&self,
enabled_only: bool,
) -> Result<Vec<SchedulerJobRecord>, StorageError> {
super::SessionStore::list_scheduler_jobs(self, enabled_only)
}
fn delete_scheduler_job(&self, job_id: &str) -> Result<(), StorageError> {
super::SessionStore::delete_scheduler_job(self, job_id)
}
fn update_scheduler_job_runtime(
&self,
job_id: &str,
state: SchedulerJobState,
last_status: Option<SchedulerJobStatus>,
last_error: Option<&str>,
run_count: i64,
last_fired_at: Option<i64>,
next_fire_at: Option<i64>,
paused_at: Option<i64>,
completed_at: Option<i64>,
) -> Result<(), StorageError> {
super::SessionStore::update_scheduler_job_runtime(
self,
job_id,
state,
last_status,
last_error,
run_count,
last_fired_at,
next_fire_at,
paused_at,
completed_at,
)
}
}
impl SkillEventRepository for super::SessionStore {
fn append_skill_event(
&self,
session_id: Option<&str>,
event_type: &str,
skill_name: Option<&str>,
payload: &serde_json::Value,
) -> Result<(), StorageError> {
super::SessionStore::append_skill_event(self, session_id, event_type, skill_name, payload)
}
fn list_skill_events(
&self,
session_id: Option<&str>,
) -> Result<Vec<SkillEventRecord>, StorageError> {
super::SessionStore::list_skill_events(self, session_id)
}
}

View File

@ -1,169 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillEventRecord {
pub id: String,
pub session_id: Option<String>,
pub event_type: String,
pub skill_name: Option<String>,
pub payload: serde_json::Value,
pub created_at: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRecord {
pub id: String,
pub title: String,
pub channel_name: String,
pub chat_id: String,
pub summary: Option<String>,
pub created_at: i64,
pub updated_at: i64,
pub last_active_at: i64,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>,
pub message_count: i64,
pub reset_cutoff_seq: i64,
pub user_turn_count: i64,
pub agent_prompt_reinjection_count: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRecord {
pub id: String,
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct MemoryUpsert {
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobState {
Scheduled,
Running,
Paused,
Completed,
}
impl SchedulerJobState {
pub fn as_str(&self) -> &'static str {
match self {
SchedulerJobState::Scheduled => "scheduled",
SchedulerJobState::Running => "running",
SchedulerJobState::Paused => "paused",
SchedulerJobState::Completed => "completed",
}
}
pub fn from_str(value: &str) -> Option<Self> {
match value {
"scheduled" => Some(Self::Scheduled),
"running" => Some(Self::Running),
"paused" => Some(Self::Paused),
"completed" => Some(Self::Completed),
_ => None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobStatus {
Ok,
Error,
Skipped,
}
impl SchedulerJobStatus {
pub fn as_str(&self) -> &'static str {
match self {
SchedulerJobStatus::Ok => "ok",
SchedulerJobStatus::Error => "error",
SchedulerJobStatus::Skipped => "skipped",
}
}
pub fn from_str(value: &str) -> Option<Self> {
match value {
"ok" => Some(Self::Ok),
"error" => Some(Self::Error),
"skipped" => Some(Self::Skipped),
_ => None,
}
}
}
impl Default for SchedulerJobState {
fn default() -> Self {
Self::Scheduled
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerJobRecord {
pub id: String,
pub kind: String,
pub schedule: serde_json::Value,
pub interval_secs: i64,
pub startup_delay_secs: i64,
pub target: serde_json::Value,
pub payload: serde_json::Value,
pub enabled: bool,
pub state: SchedulerJobState,
pub last_status: Option<SchedulerJobStatus>,
pub last_error: Option<String>,
pub run_count: i64,
pub max_runs: Option<i64>,
pub last_fired_at: Option<i64>,
pub next_fire_at: Option<i64>,
pub paused_at: Option<i64>,
pub completed_at: Option<i64>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct SchedulerJobUpsert {
pub id: String,
pub kind: String,
pub schedule: serde_json::Value,
pub interval_secs: i64,
pub startup_delay_secs: i64,
pub target: serde_json::Value,
pub payload: serde_json::Value,
pub enabled: bool,
pub state: SchedulerJobState,
pub last_status: Option<SchedulerJobStatus>,
pub last_error: Option<String>,
pub run_count: i64,
pub max_runs: Option<i64>,
pub last_fired_at: Option<i64>,
pub next_fire_at: Option<i64>,
pub paused_at: Option<i64>,
pub completed_at: Option<i64>,
}

View File

@ -3,16 +3,16 @@ use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::storage::{MemoryRecord, MemoryRepository, MemoryUpsert};
use crate::storage::{MemoryRecord, MemoryUpsert, SessionStore};
use crate::tools::traits::{Tool, ToolContext, ToolResult};
pub struct MemoryManageTool {
memories: Arc<dyn MemoryRepository>,
store: Arc<SessionStore>,
}
impl MemoryManageTool {
pub fn new(memories: Arc<dyn MemoryRepository>) -> Self {
Self { memories }
pub fn new(store: Arc<SessionStore>) -> Self {
Self { store }
}
}
@ -23,7 +23,7 @@ impl Tool for MemoryManageTool {
}
fn description(&self) -> &str {
"Create, update, or delete long-term user memories in the configured memory repository. Supports actions: put, update, delete. Use memory_search as the default retrieval path before answering most requests, and use memory_search for all retrieval actions including search, get, and list. Only call this tool when you have determined that a high-value long-term memory should be created, overwritten, updated, or deleted. Memories are scoped to the current channel and sender, and record the originating session/message when available."
"Create, update, or delete long-term user memories stored in SQLite. Supports actions: put, update, delete. Use memory_search as the default retrieval path before answering most requests, and use memory_search for all retrieval actions including search, get, and list. Only call this tool when you have determined that a high-value long-term memory should be created, overwritten, updated, or deleted. Memories are scoped to the current channel and sender, and record the originating session/message when available."
}
fn parameters_schema(&self) -> serde_json::Value {
@ -80,7 +80,7 @@ impl Tool for MemoryManageTool {
Ok(input) => input,
Err(result) => return Ok(result),
};
memory_to_json(self.memories.put_memory(&input)?)
memory_to_json(self.store.put_memory(&input)?)
}
"update" => {
let input = match build_memory_upsert(context, &scope_key, &args, false) {
@ -88,7 +88,7 @@ impl Tool for MemoryManageTool {
Err(result) => return Ok(result),
};
match self.memories.update_memory(&input)? {
match self.store.update_memory(&input)? {
Some(memory) => memory_to_json(memory),
None => {
return Ok(error_result(&format!(
@ -109,7 +109,7 @@ impl Tool for MemoryManageTool {
};
let deleted = self
.memories
.store
.delete_memory("user", &scope_key, namespace, key)?;
if !deleted {
return Ok(error_result(&format!(
@ -219,7 +219,6 @@ fn error_result(message: &str) -> ToolResult {
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::SessionStore;
#[tokio::test]
async fn test_memory_manage_put_returns_saved_memory() {

View File

@ -3,16 +3,16 @@ use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::storage::{MemoryRecord, MemoryRepository};
use crate::storage::{MemoryRecord, SessionStore};
use crate::tools::traits::{Tool, ToolContext, ToolResult};
pub struct MemorySearchTool {
memories: Arc<dyn MemoryRepository>,
store: Arc<SessionStore>,
}
impl MemorySearchTool {
pub fn new(memories: Arc<dyn MemoryRepository>) -> Self {
Self { memories }
pub fn new(store: Arc<SessionStore>) -> Self {
Self { store }
}
}
@ -23,7 +23,7 @@ impl Tool for MemorySearchTool {
}
fn description(&self) -> &str {
"Search and read long-term user memories from the configured memory repository. This is the default entry point for memory retrieval and should usually be the first memory tool you call at the start of a request, unless the request is clearly a simple greeting, a one-off calculation, or a direct fact question that does not depend on user history. Use it to recall prior preferences, stable facts, historical decisions, and ongoing task context. If the request also needs other independent read-only tools, you may call memory_search in the same round alongside them. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage whenever you only need to retrieve memory."
"Search and read long-term user memories stored in SQLite. This is the default entry point for memory retrieval and should usually be the first memory tool you call at the start of a request, unless the request is clearly a simple greeting, a one-off calculation, or a direct fact question that does not depend on user history. Use it to recall prior preferences, stable facts, historical decisions, and ongoing task context. If the request also needs other independent read-only tools, you may call memory_search in the same round alongside them. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage whenever you only need to retrieve memory."
}
fn parameters_schema(&self) -> serde_json::Value {
@ -91,7 +91,7 @@ impl Tool for MemorySearchTool {
.and_then(|value| value.as_u64())
.unwrap_or(10) as usize;
let memories = self
.memories
.store
.list_memories("user", &scope_key, namespace, limit)?;
json!({
"count": memories.len(),
@ -117,7 +117,7 @@ impl Tool for MemorySearchTool {
.and_then(|value| value.as_u64())
.unwrap_or(10) as usize;
let memories = self
.memories
.store
.search_memories_any("user", &scope_key, &queries, namespace, limit)?;
json!({
"queries": queries,
@ -135,10 +135,7 @@ impl Tool for MemorySearchTool {
None => return Ok(error_result("Missing required parameter: key")),
};
match self
.memories
.get_memory("user", &scope_key, namespace, key)?
{
match self.store.get_memory("user", &scope_key, namespace, key)? {
Some(memory) => memory_to_json(memory),
None => {
return Ok(error_result(&format!(
@ -205,7 +202,6 @@ fn error_result(message: &str) -> ToolResult {
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::SessionStore;
#[tokio::test]
async fn test_memory_search_search_and_get() {

View File

@ -9,7 +9,6 @@ pub mod memory_search;
pub mod registry;
pub mod scheduler_manage;
pub mod schema;
pub mod skill_activate;
pub mod skill_manage;
pub mod time;
pub mod traits;
@ -26,7 +25,6 @@ pub use memory_search::MemorySearchTool;
pub use registry::ToolRegistry;
pub use scheduler_manage::SchedulerManageTool;
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use skill_activate::SkillActivateTool;
pub use skill_manage::{SkillListTool, SkillManageTool};
pub use time::TimeTool;
pub use traits::{Tool, ToolContext, ToolResult};

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use crate::domain::tools::{Tool, ToolFunction};
use crate::providers::{Tool, ToolFunction};
use super::traits::Tool as ToolTrait;

View File

@ -5,20 +5,18 @@ use async_trait::async_trait;
use serde_json::json;
use crate::config::SchedulerSchedule;
use crate::storage::{
SchedulerJobRecord, SchedulerJobRepository, SchedulerJobState, SchedulerJobUpsert,
};
use crate::storage::{SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore};
use crate::tools::traits::{Tool, ToolResult};
pub struct SchedulerManageTool {
jobs: Arc<dyn SchedulerJobRepository>,
store: Arc<SessionStore>,
known_agents: Arc<HashSet<String>>,
}
impl SchedulerManageTool {
pub fn new(jobs: Arc<dyn SchedulerJobRepository>, known_agents: HashSet<String>) -> Self {
pub fn new(store: Arc<SessionStore>, known_agents: HashSet<String>) -> Self {
Self {
jobs,
store,
known_agents: Arc::new(known_agents),
}
}
@ -31,7 +29,7 @@ impl Tool for SchedulerManageTool {
}
fn description(&self) -> &str {
"Manage repository-backed scheduled jobs. Supports actions: list, get, put, delete, pause, resume. Jobs are persisted by the configured scheduler job repository and executed by the scheduler runtime. When creating agent_task or silent_agent_task jobs, keep prompt/system_prompt focused on the work to perform; do not restate execution times unless the task logic truly depends on them, because the trigger already controls timing."
"Manage DB-backed scheduled jobs. Supports actions: list, get, put, delete, pause, resume. Jobs persist in SQLite and are executed by the scheduler runtime. When creating agent_task or silent_agent_task jobs, keep prompt/system_prompt focused on the work to perform; do not restate execution times unless the task logic truly depends on them, because the trigger already controls timing."
}
fn parameters_schema(&self) -> serde_json::Value {
@ -118,30 +116,30 @@ impl Tool for SchedulerManageTool {
.get("enabled_only")
.and_then(|value| value.as_bool())
.unwrap_or(false);
let jobs = self.jobs.list_scheduler_jobs(enabled_only)?;
let jobs = self.store.list_scheduler_jobs(enabled_only)?;
json!(jobs.iter().map(record_to_json).collect::<Vec<_>>())
}
"get" => {
let id = require_str(&args, "id")?;
match self.jobs.get_scheduler_job(id)? {
match self.store.get_scheduler_job(id)? {
Some(record) => record_to_json(&record),
None => return Ok(error_result(&format!("scheduler job '{}' not found", id))),
}
}
"put" => {
let input = build_upsert(context, &args, &self.known_agents)?;
let record = self.jobs.upsert_scheduler_job(&input)?;
let record = self.store.upsert_scheduler_job(&input)?;
record_to_json(&record)
}
"delete" => {
let id = require_str(&args, "id")?;
self.jobs.delete_scheduler_job(id)?;
self.store.delete_scheduler_job(id)?;
json!({"status": "deleted", "id": id})
}
"pause" => {
let id = require_str(&args, "id")?;
let record = self
.jobs
.store
.get_scheduler_job(id)?
.ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?;
let mut input = record_to_upsert(&record);
@ -149,13 +147,13 @@ impl Tool for SchedulerManageTool {
input.state = SchedulerJobState::Paused;
input.paused_at = Some(current_timestamp());
input.next_fire_at = None;
let saved = self.jobs.upsert_scheduler_job(&input)?;
let saved = self.store.upsert_scheduler_job(&input)?;
record_to_json(&saved)
}
"resume" => {
let id = require_str(&args, "id")?;
let record = self
.jobs
.store
.get_scheduler_job(id)?
.ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?;
let mut input = record_to_upsert(&record);
@ -164,7 +162,7 @@ impl Tool for SchedulerManageTool {
input.paused_at = None;
input.completed_at = None;
input.next_fire_at = None;
let saved = self.jobs.upsert_scheduler_job(&input)?;
let saved = self.store.upsert_scheduler_job(&input)?;
record_to_json(&saved)
}
_ => return Ok(error_result("Unsupported action")),
@ -433,7 +431,6 @@ fn current_timestamp() -> i64 {
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::SessionStore;
#[tokio::test]
async fn test_scheduler_manage_put_and_get() {

View File

@ -1,151 +0,0 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::skills::SkillRuntime;
use crate::storage::SkillEventRepository;
use crate::tools::traits::{Tool, ToolContext, ToolResult};
pub struct SkillActivateTool {
skills: Arc<SkillRuntime>,
events: Arc<dyn SkillEventRepository>,
}
impl SkillActivateTool {
pub fn new(skills: Arc<SkillRuntime>, events: Arc<dyn SkillEventRepository>) -> Self {
Self { skills, events }
}
fn record_event(
&self,
context: &ToolContext,
event_type: &str,
skill_name: Option<&str>,
payload: &serde_json::Value,
) {
if let Err(err) = self.events.append_skill_event(
context.session_id.as_deref(),
event_type,
skill_name,
payload,
) {
tracing::warn!(error = %err, event_type, skill_name, "Failed to record skill activation event");
}
}
}
#[async_trait]
impl Tool for SkillActivateTool {
fn name(&self) -> &str {
"skill_activate"
}
fn description(&self) -> &str {
"Load detailed instructions for a named skill discovered from SKILL.md files. Use when a task matches a listed skill description."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Skill name from the available skills list"
}
},
"required": ["name"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
self.execute_with_context(&ToolContext::default(), args)
.await
}
async fn execute_with_context(
&self,
context: &ToolContext,
args: serde_json::Value,
) -> anyhow::Result<ToolResult> {
let skill_name = match args.get("name").and_then(|value| value.as_str()) {
Some(name) if !name.trim().is_empty() => name,
_ => {
self.record_event(
context,
"activation_failed",
None,
&json!({
"reason": "missing_name",
"arguments": args,
}),
);
return Ok(error_result("Missing required parameter: name"));
}
};
match self.skills.activation_payload(skill_name) {
Ok(output) => {
if let Ok(payload) = self.skills.activation_event_payload(skill_name) {
self.record_event(context, "activated", Some(skill_name), &payload);
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
Err(err) => {
self.record_event(
context,
"activation_failed",
Some(skill_name),
&json!({
"reason": err,
"arguments": args,
}),
);
Ok(error_result(&err))
}
}
}
}
fn error_result(message: &str) -> ToolResult {
ToolResult {
success: false,
output: String::new(),
error: Some(message.to_string()),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::SessionStore;
#[tokio::test]
async fn test_skill_activate_records_failed_activation_event() {
let skills = Arc::new(SkillRuntime::default());
let store = Arc::new(SessionStore::in_memory().unwrap());
store.ensure_channel_session("feishu", "chat-1").unwrap();
let tool = SkillActivateTool::new(skills, store.clone());
let context = ToolContext {
session_id: Some("feishu:chat-1".to_string()),
..ToolContext::default()
};
let result = tool
.execute_with_context(&context, json!({ "name": "demo" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("not found"));
let events = store.list_skill_events(Some("feishu:chat-1")).unwrap();
assert_eq!(events.len(), 1);
assert_eq!(events[0].event_type, "activation_failed");
assert_eq!(events[0].skill_name.as_deref(), Some("demo"));
}
}