Fix CLI session routing and dialog controls

This commit is contained in:
xiaoski 2026-06-15 23:44:41 +08:00
parent 0d66536e90
commit c6f4392e63
8 changed files with 1391 additions and 545 deletions

View File

@ -0,0 +1,359 @@
# PicoBot 代码质量分析报告
审查日期2026-06-15
## 结论摘要
PicoBot 的总体架构方向是清晰的Gateway 负责装配Channel 只做收发MessageBus 解耦输入输出SessionManager 管理会话AgentLoop 保持无状态并执行工具Storage 统一持久化。这条主线是成立的,也已经具备较完整的 AI 助手运行时能力。
当前主要质量风险集中在三类:
1. 会话/CLI 路由语义不一致,导致多客户端隔离、加载会话、当前会话追踪不可靠。
2. 若干公开控制接口是空实现或弱实现,协议层暴露的能力和后端实际行为不匹配。
3. 工具和后台任务的资源边界偏弱文件、shell、HTTP、长期任务在异常情况下容易突破预期的安全或稳定性边界。
如果只安排一轮修复,优先处理会话路由和控制接口。这些问题会直接影响用户看到的行为;工具安全和大模块拆分可以作为第二阶段。
## 修复状态
- 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id`
- 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
- 待处理工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
## 主要发现
### 已修复CLI 会话路由会破坏会话连续性和多客户端隔离
位置:
- `src/channels/cli_chat.rs:113-126`
- `src/channels/cli_chat.rs:160-164`
- `src/channels/cli_chat.rs:225-249`
- `src/channels/cli_chat.rs:479-494`
- `src/session/session.rs:1305-1310`
问题:
`Client.current_session_id` 存的是完整 session id但 CLI channel 在多个地方把它当作 `chat_id` 使用。普通用户输入如果没有显式传 `chat_id`,会在 `src/channels/cli_chat.rs:119` 生成新的短 ID而不是复用当前 client 的 chat scope。`CreateSession` 又把当前完整 session id 当成新会话的 chat_id。`LoadSession` 解析了传入 session id但随后调用 `GetCurrentDialog`,而后端 `get_current_dialog()` 固定返回 `None`
同时,`send()` 会把所有 `OutboundMessage` 广播给所有 CLI WebSocket client没有按 `msg.chat_id` 或 client 当前会话过滤。这意味着一个客户端的回复可能出现在另一个客户端里。
影响:
- CLI 多轮对话可能落入不同 chat scope。
- 创建/列出/加载会话得到的结果可能不符合 UI 预期。
- 多个 CLI 客户端同时连接时存在串话。
建议:
- 将 client 状态拆成 `chat_id``current_session_id`,不要混用。
- 注册 client 时生成稳定 `chat_id`,后续 `UserInput` 默认复用它。
- `send()``OutboundMessage.chat_id` 精确投递;必要时维护 `chat_id -> clients` 映射。
- `LoadSession` 应直接切换到指定 session或通过 `SwitchDialog` 使用其中的 `dialog_id`
- 为 CLI WebSocket 增加多客户端路由测试。
### 已修复Dialog 控制接口与协议承诺不一致
位置:
- `src/session/session.rs:996-997`
- `src/session/session.rs:1305-1310`
- `src/session/session.rs:1329-1349`
- `src/session/session.rs:1378-1384`
- `src/channels/cli_chat.rs:128-158`
问题:
后端暴露了 create/list/load/rename/archive/delete/clear 等 dialog 操作,但部分行为是空实现或语义错位:
- `/delete` 只创建新 session并没有删除当前 session。
- `get_current_dialog()` 固定返回 `Ok(None)`
- `list_dialogs()` 忽略 `include_archived`,且总是返回 `current_dialog_id = None`
- `archive_dialog()` 是空操作。
- `clear_dialog_history()` 直接返回不可用,但 WebSocket 协议仍暴露 `clear_history`
影响:
用户通过 slash command 和 WebSocket 调用同一类能力时,会得到不一致结果。前端难以基于协议实现可靠状态同步。
建议:
- 明确“archive/clear 是否支持”。不支持就从协议和命令列表移除;支持就实现到底。
- `/delete` 应调用 `delete_dialog(current_session_id)`,再创建一个新的 current session。
- `get_current_dialog()` 应读取 `current_sessions[channel:chat_id]` 并解析为 `UnifiedSessionId`
- `list_dialogs()` 返回真实 current dialog并补上 archived 模型或移除 archived 参数。
### 高优先级:工具文件边界不符合“工作目录内工具”的架构约束
位置:
- `src/tools/mod.rs:56-62`
- `src/tools/path_utils.rs:3-23`
- `src/tools/bash.rs:146-185`
问题:
文件工具默认通过 `FileReadTool::new()``FileWriteTool::new()` 等注册,没有传入 workspace allowlist。`resolve_path()` 对绝对路径直接放行;即使传入 allowlist也只是做 `Path::starts_with()` 的词法判断,没有 canonicalize不能防御 `..`、符号链接等路径逃逸。
`bash` 默认工作目录是 `"."`Gateway 启动时切到 workspace这对相对路径有效但 shell 命令仍然可以访问绝对路径。当前 denylist 只挡少数危险模式,不构成权限边界。
影响:
Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“工作目录内操作”不一致。对于个人助手这可能是有意设计,但如果未来接入外部渠道、多用户或 MCP风险会放大。
建议:
- 工具注册时传入 `workspace_dir`,默认所有文件工具限制在 workspace。
- `resolve_path()` 使用 `std::fs::canonicalize``path_absolutize` 风格逻辑,并处理目标文件不存在时的父目录 canonicalize。
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
### 中高优先级Session 锁内执行过多异步操作
位置:
- `src/session/session.rs:1001-1018`
- `src/session/session.rs:1604-1711`
问题:
`/compact` 在持有 session mutex 时执行压缩和持久化。agent worker 的 Phase 1 也在持有 session mutex 时执行用户消息落库、memory recall、上下文压缩、session meta 持久化和 agent 创建。其中 `compress_if_needed()` 可能触发 LLM 摘要,属于慢操作。
影响:
- 同一 session 的 slash command、stop、消息排队、状态查询会被慢操作阻塞。
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
建议:
- 锁内只做内存状态快照和必要的状态标记。
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
### 中优先级Bash 超时不会显式终止子进程
位置:
- `src/tools/bash.rs:150-174`
- `src/tools/bash.rs:180-207`
问题:
`timeout()` 包裹的是 `run_command()` future。超时后 future 被取消,但代码没有持有 child 句柄并显式 `kill()` / `wait()`。对于已经启动的长运行命令或子进程树,可能留下后台进程。
影响:
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
建议:
- 使用 `tokio::process::Child``kill_on_drop(true)`
- 超时分支显式 kill child 并 wait。
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
### 中优先级:文件读取对大二进制文件没有输出上限
位置:
- `src/tools/file_read.rs:121-131`
- `src/tools/file_read.rs:214-229`
问题:
`file_read``std::fs::read()` 读取整个文件。文本路径有 `MAX_CHARS` 截断,但二进制路径会完整 base64 编码后返回,没有大小限制。
影响:
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
建议:
- 先检查 metadata size超过阈值直接返回提示。
- 二进制文件默认只返回 mime、大小和建议操作需要内容时提供显式 `max_bytes` 参数。
- 对文本读取也改成流式按行读取,而不是整文件读入。
### 中优先级HTTP 私网防护只检查字面 host未做 DNS 解析校验
位置:
- `src/tools/http_request.rs:31-59`
问题:
`http_request` 阻止 localhost、私网 IP 字面量和 `.local`,但普通域名不会解析后检查最终 IP。DNS rebinding 或内网域名解析到私网地址时,当前校验拦不住。
影响:
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
建议:
- 请求前解析域名拒绝私网、loopback、link-local、multicast、unspecified 地址。
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
- 对 `http_request``web_fetch` 复用同一套 URL 安全策略。
### 中优先级:后台任务和主循环缺少监督与优雅关闭
位置:
- `src/bus/mod.rs:51-99`
- `src/gateway/mod.rs:187-244`
- `src/gateway/mod.rs:247-266`
问题:
Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHandle也没有统一 cancellation token。MessageBus 的 `consume_*()` 在 channel 关闭时使用 `expect()` panic。
影响:
- 某个后台 loop 异常退出后Gateway 不一定能发现。
- 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
- bus channel 关闭时更像崩溃,而不是可恢复状态。
建议:
- 引入 runtime supervisor保存 JoinHandle 并集中处理退出原因。
- 用 `CancellationToken` 贯穿 Gateway 子任务。
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
### 中低优先级Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
位置:
- `src/scheduler/mod.rs:18-40`
问题:
`next_run_for_schedule(schedule, from)` 的注释说基于 `from` 计算,但 cron 分支创建了 `from_dt` 后没有传给 `cron_schedule`,实际使用的是 `upcoming(Utc)``upcoming(tz)` 的当前时间。
影响:
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。
建议:
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
- 增加固定时间输入的单元测试,避免受系统时间影响。
### 中低优先级:存在未接入或半接入代码,增加维护噪音
位置:
- `src/tools/pty.rs`
- `src/tools/mod.rs:1-20`
- `src/tools/mod.rs:49-88`
问题:
仓库里有完整 `pty.rs`,但 `tools/mod.rs` 没有声明 `pub mod pty``create_default_tools()` 也没有注册 PTY 工具。类似情况会让文档、计划和实现状态难以判断。
影响:
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
建议:
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
- 若暂不发布:移动到设计文档或 feature branch避免主干保留死代码。
## 架构评价
### 做得好的地方
- 模块分层方向清楚Channel、Bus、Session、Agent、Provider、Tool、Storage 边界基本可理解。
- AgentLoop 设计为无状态,历史由 SessionManager 管理,这一点利于恢复、压缩和测试。
- Provider 抽象简单直接OpenAI-compatible 与 Anthropic 的差异被限制在 provider 层。
- Storage 集中初始化 schema便于部署单二进制应用。
- Skill、memory、MCP、delegate 这几条扩展线已经形成统一的 ToolRegistry 接入点。
### 主要架构债务
- SessionManager 承担过多职责会话生命周期、命令解析、memory recall、压缩、agent worker、任务取消、send_message 目标解析都在一个 2000 行文件内。
- Channel 和 Session 对 chat_id/session_id/dialog_id 的边界没有类型保护,导致 CLI 层混用字符串。
- Tool 权限模型不够显式:工具是否能访问全文件系统、是否能联网、是否能修改状态主要靠工具自身约定。
- 后台任务生命周期分散gateway loop、agent worker、notification publisher、scheduler、sub-agent task 各自 spawn缺少统一管理。
## 模块级分析
### gateway
`GatewayState::new()` 是清晰的装配中心配置、workspace、storage、memory、bus、session manager、channels、MCP、scheduler 都在这里接线。问题是启动后任务监督不足,且 scheduler 默认 `unwrap_or_default()` 会在省略 `gateway.scheduler` 时启用调度器,这和“省略配置是否代表开启”需要产品层确认。
### channels
Feishu channel 功能较厚,单文件接近 2000 行,建议后续按 API client、message parsing、media handling、outbound rendering 拆分。CLI channel 目前是质量风险最高的 channel核心问题是会话身份混用和广播投递。
### bus
MessageBus 简洁,但当前消费者 API 通过 mutex 包住 receiver 并 `expect()`,更像“单消费者内部队列”。这没问题,但应该把“只能有一个 consumer”写进类型/文档,并把关闭作为正常状态处理。
### session
这是系统核心,也是债务最集中的模块。建议把 `session.rs` 拆成:
- `manager.rs`SessionManager 状态和 dialog 生命周期
- `worker.rs`per-session agent worker 和 cancellation
- `commands.rs`slash command 执行
- `outbound.rs`OutboundMessenger 实现
- `restore.rs`storage 恢复与 tool call chain repair
拆分之前,先补行为测试,尤其是 CLI/WS session lifecycle。
### agent
AgentLoop 的职责相对聚焦:请求模型、执行工具、回填 tool result、循环直到 final response。需要关注的是工具并发的语义`read_only()` 目前是工具自己声明副作用工具不能错标。LoopDetector 有帮助,但属于 runtime guard不应替代工具层的资源限制。
### providers
Provider 层整体可维护。OpenAI/Anthropic 的请求构造逻辑可以继续保留在 provider 内。建议补充请求脱敏策略:当前 debug log 和 `llm_calls` 会持久化完整 request/response可能包含用户隐私、API 返回内容和文件内容。
### tools
工具体系覆盖面很强,但需要明确权限模型。建议新增统一的 `ToolExecutionContext`,包含 workspace、channel、session_id、权限策略、网络策略、输出预算。现在很多策略散落在各工具构造函数里默认值容易失控。
### storage
Storage schema 初始化实用但迁移方式是“CREATE IF NOT EXISTS + ALTER IGNORE”适合早期迭代不适合长期演进。建议引入 schema version 表或 sqlx migrations至少把每次迁移记录下来。
### skills
Skill 加载优先级清晰,内置 skill 打包也实用。需要注意 `SkillsLoader` 使用同步文件系统扫描和 `std::sync::Mutex`,在请求路径频繁 `reload_if_changed()` 时可能造成阻塞。短期可以接受,长期建议缓存刷新放到后台 watcher。
## 建议修复路线
### P0先修会话正确性
1. 修正 CLI `chat_id/current_session_id` 数据模型。
2. 修正 CLI 出站按 client/chat_id 投递。
3. 实现 `get_current_dialog()``list_dialogs()` current 返回。
4. 修正 `/delete``clear_history``archive` 的真实行为或从协议移除。
5. 增加 WebSocket session lifecycle 测试。
### P1收紧工具和资源边界
1. 文件工具默认限制 workspace路径 canonicalize。
2. bash 超时杀进程,必要时引入进程组。
3. file_read 增加文件大小上限和二进制输出上限。
4. HTTP/web 工具增加 DNS 解析后的私网校验和重定向校验。
5. 明确高危工具的配置开关。
### P2降低架构复杂度
1. 拆分 `session.rs``feishu.rs``storage/mod.rs``browser.rs`
2. 引入任务 supervisor 和统一 shutdown token。
3. 引入正式数据库迁移。
4. 增加工具注册快照测试,避免死代码和文档漂移。
## 建议测试补充
- CLI 多客户端并发:两个 WebSocket client 同时发消息,互不串话。
- CLI 不传 chat_id 的连续对话:所有消息应进入同一 session。
- Load/switch/list/delete/clear 的完整 WebSocket 流程。
- `/delete` 后旧 session 软删除、新 session 成为 current。
- 文件路径逃逸:`../`、绝对路径、符号链接、workspace 前缀欺骗。
- bash timeout 后检查子进程不存在。
- cron `next_run_for_schedule()` 使用固定 `from` 的 deterministic 测试。
- HTTP 工具对 DNS 解析到 `127.0.0.1` / `10.0.0.0/8` 的域名拒绝测试。

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use tokio::sync::{mpsc, Mutex}; use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage}; use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
use crate::protocol::{SlashCommandInfo, WsInbound, WsOutbound, parse_inbound};
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId}; use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo};
use super::base::{Channel, ChannelError}; use super::base::{Channel, ChannelError};
@ -14,6 +14,7 @@ use super::base::{Channel, ChannelError};
pub(crate) struct Client { pub(crate) struct Client {
sender: mpsc::Sender<WsOutbound>, sender: mpsc::Sender<WsOutbound>,
chat_id: String,
current_session_id: Mutex<Option<String>>, current_session_id: Mutex<Option<String>>,
} }
@ -41,23 +42,28 @@ impl CliChatChannel {
} }
/// Register a new client connection, returns (session_id, client) /// Register a new client connection, returns (session_id, client)
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) { pub(crate) async fn register_client(
// Generate connection ID (used as chat_id) - use short ID &self,
let connection_id = crate::util::short_id(); sender: mpsc::Sender<WsOutbound>,
) -> (String, Arc<Client>) {
// Each WebSocket connection gets a stable chat scope. All user input and
// dialog controls for this client stay inside that scope unless the
// protocol explicitly carries a full session id.
let chat_id = crate::util::short_id();
let client = Arc::new(Client { let client = Arc::new(Client {
sender, sender,
chat_id: chat_id.clone(),
current_session_id: Mutex::new(None), current_session_id: Mutex::new(None),
}); });
self.clients.lock().await.push(client.clone()); self.clients.lock().await.push(client.clone());
// Create initial session via control message // Create initial session via control message
let session_id = match self.create_session_via_control(&connection_id, None).await { let session_id = match self.create_session_via_control(&chat_id, None).await {
Ok(id) => id, Ok((id, _title)) => id,
Err(e) => { Err(e) => {
tracing::error!(error = %e, "Failed to create initial session"); tracing::error!(error = %e, "Failed to create initial session");
// Fall back to old format for backward compatibility UnifiedSessionId::new("cli_chat", &chat_id, &crate::util::short_id()).to_string()
connection_id.clone()
} }
}; };
@ -73,21 +79,19 @@ impl CliChatChannel {
/// Handle an inbound message from a client /// Handle an inbound message from a client
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) { pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
match parse_inbound(raw_msg) { match parse_inbound(raw_msg) {
Ok(inbound) => { Ok(inbound) => match self.handle_ws_inbound(client.clone(), inbound).await {
match self.handle_ws_inbound(client.clone(), inbound).await { Ok(()) => {}
Ok(()) => {} Err(e) => {
Err(e) => { tracing::warn!(error = %e, "Failed to handle inbound message");
tracing::warn!(error = %e, "Failed to handle inbound message"); let _ = client
let _ = client .sender
.sender .send(WsOutbound::Error {
.send(WsOutbound::Error { code: "INTERNAL_ERROR".to_string(),
code: "INTERNAL_ERROR".to_string(), message: e.to_string(),
message: e.to_string(), })
}) .await;
.await;
}
} }
} },
Err(e) => { Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message"); tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = client let _ = client
@ -101,22 +105,30 @@ impl CliChatChannel {
} }
} }
async fn handle_ws_inbound(&self, client: Arc<Client>, inbound: WsInbound) -> Result<(), ChannelError> { async fn handle_ws_inbound(
&self,
client: Arc<Client>,
inbound: WsInbound,
) -> Result<(), ChannelError> {
let bus = { let bus = {
let guard = self.bus.lock().unwrap(); let guard = self.bus.lock().unwrap();
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? guard
.clone()
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
}; };
let mut current_session_guard = client.current_session_id.lock().await; let mut current_session_guard = client.current_session_id.lock().await;
match inbound { match inbound {
WsInbound::UserInput { content, chat_id, .. } => { WsInbound::UserInput {
content, chat_id, ..
} => {
// All messages (including slash commands) go through the normal inbound flow // All messages (including slash commands) go through the normal inbound flow
// SessionManager handles session creation/reuse internally // SessionManager handles session creation/reuse internally
let msg = InboundMessage { let msg = InboundMessage {
channel: self.name().to_string(), channel: self.name().to_string(),
sender_id: "cli".to_string(), sender_id: "cli".to_string(),
chat_id: chat_id.unwrap_or_else(crate::util::short_id), chat_id: chat_id.unwrap_or_else(|| client.chat_id.clone()),
content, content,
timestamp: crate::bus::message::current_timestamp(), timestamp: crate::bus::message::current_timestamp(),
media: Vec::new(), media: Vec::new(),
@ -125,19 +137,56 @@ impl CliChatChannel {
}; };
bus.publish_inbound(msg).await?; bus.publish_inbound(msg).await?;
} }
WsInbound::ClearHistory { chat_id, session_id } => { WsInbound::ClearHistory {
let target = session_id chat_id,
.or(chat_id) session_id,
.or(current_session_guard.clone()) } => {
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let session_id = UnifiedSessionId::parse(&target) let session_id = if let Some(session_id) = session_id {
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; UnifiedSessionId::parse(&session_id).ok_or_else(|| {
ChannelError::Other("Invalid session ID format".to_string())
})?
} else if let Some(chat_id) = chat_id {
let (current_tx, mut current_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage {
op: SessionCommand::GetCurrentDialog {
channel: "cli_chat".to_string(),
chat_id,
},
reply_tx: current_tx,
})
.await?;
match current_rx.recv().await {
Some(Ok(SessionEvent::CurrentDialog {
session_id: Some(session_id),
})) => session_id,
Some(Ok(SessionEvent::CurrentDialog { session_id: None })) => {
return Err(ChannelError::Other("No active session".to_string()));
}
Some(Ok(_)) => {
return Err(ChannelError::Other(
"Unexpected response type".to_string(),
));
}
Some(Err(e)) => return Err(e),
None => {
return Err(ChannelError::Other("Control channel closed".to_string()));
}
}
} else {
let target = current_session_guard
.clone()
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
UnifiedSessionId::parse(&target).ok_or_else(|| {
ChannelError::Other("Invalid session ID format".to_string())
})?
};
let target = session_id.to_string();
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::ClearHistory { session_id }, op: SessionCommand::ClearHistory { session_id },
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::HistoryCleared { .. })) => { Some(Ok(SessionEvent::HistoryCleared { .. })) => {
@ -158,24 +207,21 @@ impl CliChatChannel {
} }
} }
WsInbound::CreateSession { title } => { WsInbound::CreateSession { title } => {
// Use current session's chat_id if available, otherwise generate new one let (new_id, created_title) = self
let chat_id = current_session_guard.clone() .create_session_via_control(&client.chat_id, title.as_deref())
.unwrap_or_else(crate::util::short_id); .await?;
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
*current_session_guard = Some(new_id.clone()); *current_session_guard = Some(new_id.clone());
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionCreated { .send(WsOutbound::SessionCreated {
session_id: new_id, session_id: new_id,
title: title.unwrap_or_default(), title: created_title,
}) })
.await; .await;
} }
WsInbound::ListSessions { include_archived } => { WsInbound::ListSessions { include_archived } => {
// List dialogs for the current chat // List dialogs for the current chat
let chat_id = current_session_guard.clone() let chat_id = client.chat_id.clone();
.unwrap_or_else(|| "".to_string());
let chat_id_for_response = chat_id.clone();
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::ListDialogs { op: SessionCommand::ListDialogs {
@ -184,13 +230,18 @@ impl CliChatChannel {
include_archived, include_archived,
}, },
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => { Some(Ok(SessionEvent::DialogList {
dialogs,
current_dialog_id,
})) => {
// Convert DialogInfo to SessionSummary for backward compatibility // Convert DialogInfo to SessionSummary for backward compatibility
let sessions: Vec<crate::protocol::SessionSummary> = dialogs.into_iter().map(|d| { let sessions: Vec<crate::protocol::SessionSummary> = dialogs
crate::protocol::SessionSummary { .into_iter()
.map(|d| crate::protocol::SessionSummary {
session_id: d.session_id.to_string(), session_id: d.session_id.to_string(),
title: d.title, title: d.title,
channel_name: d.session_id.channel.clone(), channel_name: d.session_id.channel.clone(),
@ -198,11 +249,14 @@ impl CliChatChannel {
message_count: d.message_count, message_count: d.message_count,
last_active_at: d.last_active_at, last_active_at: d.last_active_at,
archived_at: d.archived_at, archived_at: d.archived_at,
} })
}).collect(); .collect();
let current_session_id = current_dialog_id.map(|did| { let current_session_id = current_dialog_id.map(|did| {
UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string() UnifiedSessionId::new("cli_chat", &client.chat_id, &did).to_string()
}); });
if let Some(ref session_id) = current_session_id {
*current_session_guard = Some(session_id.clone());
}
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionList { .send(WsOutbound::SessionList {
@ -223,39 +277,35 @@ impl CliChatChannel {
} }
} }
WsInbound::LoadSession { session_id } => { WsInbound::LoadSession { session_id } => {
// LoadSession: parse the session_id and get current dialog info
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&session_id) let unified_id = UnifiedSessionId::parse(&session_id)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
if unified_id.channel != "cli_chat" || unified_id.chat_id != client.chat_id {
return Err(ChannelError::Other(
"Session does not belong to this client".to_string(),
));
}
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::GetCurrentDialog { op: SessionCommand::SwitchDialog {
channel: unified_id.channel.clone(), channel: unified_id.channel.clone(),
chat_id: unified_id.chat_id.clone(), chat_id: unified_id.chat_id.clone(),
dialog_id: unified_id.dialog_id.clone(),
}, },
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => { Some(Ok(SessionEvent::DialogSwitched { session_id })) => {
if let Some(current_session_id) = current_session_id_opt { *current_session_guard = Some(session_id.to_string());
*current_session_guard = Some(current_session_id.to_string()); let _ = client
let _ = client .sender
.sender .send(WsOutbound::SessionLoaded {
.send(WsOutbound::SessionLoaded { session_id: session_id.to_string(),
session_id: current_session_id.to_string(), title: "Session".to_string(),
title: "Session".to_string(), // TODO: get actual title message_count: 0,
message_count: 0, // TODO: get actual count })
}) .await;
.await;
} else {
let _ = client
.sender
.send(WsOutbound::Error {
code: "NO_CURRENT_DIALOG".to_string(),
message: "No current dialog".to_string(),
})
.await;
}
} }
Some(Ok(_)) => { Some(Ok(_)) => {
// Unexpected response type // Unexpected response type
@ -275,23 +325,30 @@ impl CliChatChannel {
} }
} }
WsInbound::RenameSession { session_id, title } => { WsInbound::RenameSession { session_id, title } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { let target = session_id
ChannelError::Other("No active session".to_string()) .or(current_session_guard.clone())
})?; .ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target) let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() }, op: SessionCommand::RenameDialog {
session_id: unified_id,
title: title.clone(),
},
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => { Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title }) .send(WsOutbound::SessionRenamed {
session_id: session_id.to_string(),
title,
})
.await; .await;
} }
Some(Ok(_)) => { Some(Ok(_)) => {
@ -306,24 +363,43 @@ impl CliChatChannel {
} }
} }
WsInbound::ArchiveSession { session_id } => { WsInbound::ArchiveSession { session_id } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { let target = session_id
ChannelError::Other("No active session".to_string()) .or(current_session_guard.clone())
})?; .ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let was_current = current_session_guard.as_deref() == Some(&target);
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target) let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::ArchiveDialog { session_id: unified_id }, op: SessionCommand::ArchiveDialog {
session_id: unified_id,
},
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogArchived { session_id })) => { Some(Ok(SessionEvent::DialogArchived { session_id })) => {
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionArchived { session_id: session_id.to_string() }) .send(WsOutbound::SessionArchived {
session_id: session_id.to_string(),
})
.await; .await;
if was_current {
let (new_id, title) = self
.create_session_via_control(&client.chat_id, None)
.await?;
*current_session_guard = Some(new_id.clone());
let _ = client
.sender
.send(WsOutbound::SessionCreated {
session_id: new_id,
title,
})
.await;
}
} }
Some(Ok(_)) => { Some(Ok(_)) => {
// Unexpected response type // Unexpected response type
@ -337,35 +413,42 @@ impl CliChatChannel {
} }
} }
WsInbound::DeleteSession { session_id } => { WsInbound::DeleteSession { session_id } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { let target = session_id
ChannelError::Other("No active session".to_string()) .or(current_session_guard.clone())
})?; .ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target) let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::DeleteDialog { session_id: unified_id }, op: SessionCommand::DeleteDialog {
session_id: unified_id,
},
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogDeleted { session_id })) => { Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionDeleted { session_id: session_id.to_string() }) .send(WsOutbound::SessionDeleted {
session_id: session_id.to_string(),
})
.await; .await;
// If deleting current session, create a new one // If deleting current session, create a new one
if current_session_guard.as_deref() == Some(&target) { if current_session_guard.as_deref() == Some(&target) {
drop(reply_rx); drop(reply_rx);
if let Ok(new_id) = self.create_session_via_control(&target, None).await { if let Ok((new_id, title)) =
self.create_session_via_control(&client.chat_id, None).await
{
*current_session_guard = Some(new_id.clone()); *current_session_guard = Some(new_id.clone());
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionCreated { .send(WsOutbound::SessionCreated {
session_id: new_id, session_id: new_id,
title: String::new(), title,
}) })
.await; .await;
} }
@ -388,32 +471,45 @@ impl CliChatChannel {
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::GetSlashCommands { op: SessionCommand::GetSlashCommands {
channel: "cli_chat".to_string(), channel: "cli_chat".to_string(),
chat_id: "".to_string(), chat_id: client.chat_id.clone(),
}, },
reply_tx, reply_tx,
}).await?; })
.await?;
if let Some(result) = reply_rx.recv().await { if let Some(result) = reply_rx.recv().await {
match result { match result {
Ok(SessionEvent::SlashCommandsList { commands }) => { Ok(SessionEvent::SlashCommandsList { commands }) => {
// Convert to SlashCommand to SlashCommandInfo // Convert to SlashCommand to SlashCommandInfo
let command_infos: Vec<SlashCommandInfo> = commands.into_iter().map(|cmd| { let command_infos: Vec<SlashCommandInfo> = commands
SlashCommandInfo { .into_iter()
.map(|cmd| SlashCommandInfo {
name: cmd.name.to_string(), name: cmd.name.to_string(),
description: cmd.description.to_string(), description: cmd.description.to_string(),
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(), aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
} })
}).collect(); .collect();
let _ = client.sender.send(WsOutbound::SlashCommandsList { commands: command_infos }).await; let _ = client
.sender
.send(WsOutbound::SlashCommandsList {
commands: command_infos,
})
.await;
} }
Ok(SessionEvent::Error { code, message }) => { Ok(SessionEvent::Error { code, message }) => {
let _ = client.sender.send(WsOutbound::Error { code, message }).await; let _ = client
.sender
.send(WsOutbound::Error { code, message })
.await;
} }
Err(e) => { Err(e) => {
let _ = client.sender.send(WsOutbound::Error { let _ = client
code: "GET_COMMANDS_ERROR".to_string(), .sender
message: e.to_string() .send(WsOutbound::Error {
}).await; code: "GET_COMMANDS_ERROR".to_string(),
message: e.to_string(),
})
.await;
} }
_ => {} _ => {}
} }
@ -427,29 +523,34 @@ impl CliChatChannel {
} }
/// Create a session via control message and return the session_id /// Create a session via control message and return the session_id
async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result<String, ChannelError> { async fn create_session_via_control(
&self,
chat_id: &str,
title: Option<&str>,
) -> Result<(String, String), ChannelError> {
let bus = { let bus = {
let guard = self.bus.lock().unwrap(); let guard = self.bus.lock().unwrap();
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? guard
.clone()
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
}; };
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::CreateDialog { op: SessionCommand::CreateDialog {
channel: "cli_chat".to_string(), channel: "cli_chat".to_string(),
chat_id: connection_id.to_string(), chat_id: chat_id.to_string(),
title: title.map(String::from), title: title.map(String::from),
}, },
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => { Some(Ok(SessionEvent::DialogCreated { session_id, title })) => {
Ok(session_id.to_string()) Ok((session_id.to_string(), title))
}
Some(Ok(_)) => {
Err(ChannelError::Other("Unexpected response type".to_string()))
} }
Some(Ok(_)) => Err(ChannelError::Other("Unexpected response type".to_string())),
Some(Err(e)) => Err(e), Some(Err(e)) => Err(e),
None => Err(ChannelError::Other("Control channel closed".to_string())), None => Err(ChannelError::Other("Control channel closed".to_string())),
} }
@ -479,7 +580,11 @@ impl Channel for CliChatChannel {
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let clients = self.clients.lock().await.clone(); let clients = self.clients.lock().await.clone();
for client in clients { for client in clients {
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") { if client.chat_id != msg.chat_id {
continue;
}
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification")
{
WsOutbound::SystemNotification { WsOutbound::SystemNotification {
content: msg.content.clone(), content: msg.content.clone(),
} }

View File

@ -1,19 +1,19 @@
pub mod http; pub mod http;
pub mod ws; pub mod ws;
use axum::{Router, routing};
use std::sync::Arc; use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher}; use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
use crate::channels::{ChannelManager, CliChatChannel};
use crate::channels::base::{Channel, ChannelError}; use crate::channels::base::{Channel, ChannelError};
use crate::config::{Config, expand_path, ensure_workspace_dir}; use crate::channels::{ChannelManager, CliChatChannel};
use crate::config::{Config, ensure_workspace_dir, expand_path};
use crate::logging; use crate::logging;
use crate::mcp; use crate::mcp;
use crate::memory::MemoryManager; use crate::memory::MemoryManager;
use crate::session::SessionManager;
use crate::scheduler::Scheduler; use crate::scheduler::Scheduler;
use crate::session::SessionManager;
pub struct GatewayState { pub struct GatewayState {
pub config: Config, pub config: Config,
@ -32,8 +32,13 @@ impl GatewayState {
let workspace_path = ensure_workspace_dir(&workspace_path)?; let workspace_path = ensure_workspace_dir(&workspace_path)?;
// Switch current working directory to workspace // Switch current working directory to workspace
std::env::set_current_dir(&workspace_path) std::env::set_current_dir(&workspace_path).map_err(|e| {
.map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?; format!(
"Failed to switch to workspace directory {}: {}",
workspace_path.display(),
e
)
})?;
tracing::info!("Using workspace directory: {}", workspace_path.display()); tracing::info!("Using workspace directory: {}", workspace_path.display());
@ -52,8 +57,9 @@ impl GatewayState {
workspace_path.join("picobot.db") workspace_path.join("picobot.db")
}; };
let storage = Arc::new( let storage = Arc::new(
crate::storage::Storage::new(&db_path).await crate::storage::Storage::new(&db_path)
.map_err(|e| format!("failed to initialize session storage: {}", e))? .await
.map_err(|e| format!("failed to initialize session storage: {}", e))?,
); );
tracing::info!("Session storage: {}", db_path.display()); tracing::info!("Session storage: {}", db_path.display());
@ -98,7 +104,9 @@ impl GatewayState {
// Create ChannelManager and init channels // Create ChannelManager and init channels
let cli_chat_channel = Arc::new(CliChatChannel::new()); let cli_chat_channel = Arc::new(CliChatChannel::new());
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus); let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
channel_manager.init(&config, workspace_path.clone()).await channel_manager
.init(&config, workspace_path.clone())
.await
.map_err(|e| format!("Failed to init channels: {}", e))?; .map_err(|e| format!("Failed to init channels: {}", e))?;
// Register send_message tool with available channel names // Register send_message tool with available channel names
@ -107,9 +115,12 @@ impl GatewayState {
session_manager.register_outbound_tool(available_channels); session_manager.register_outbound_tool(available_channels);
// Register chat_manager tool // Register chat_manager tool
session_manager.tools().register( session_manager
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()), .tools()
); .register(crate::tools::ChatManagerTool::new(
storage.clone(),
valid_channels.clone(),
));
// Initialize MCP servers — connect and register discovered tools // Initialize MCP servers — connect and register discovered tools
if !config.mcp.servers.is_empty() { if !config.mcp.servers.is_empty() {
@ -130,24 +141,27 @@ impl GatewayState {
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default(); let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
if scheduler_config.enabled { if scheduler_config.enabled {
// Register cron tools // Register cron tools
session_manager.tools().register( session_manager
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels), .tools()
); .register(crate::tools::cron::CronAddTool::new(
session_manager.tools().register( storage.clone(),
crate::tools::cron::CronListTool::new(storage.clone()), valid_channels,
); ));
session_manager.tools().register( session_manager
crate::tools::cron::CronRemoveTool::new(storage.clone()), .tools()
); .register(crate::tools::cron::CronListTool::new(storage.clone()));
session_manager.tools().register( session_manager
crate::tools::cron::CronEnableTool::new(storage.clone()), .tools()
); .register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
session_manager.tools().register( session_manager
crate::tools::cron::CronDisableTool::new(storage.clone()), .tools()
); .register(crate::tools::cron::CronEnableTool::new(storage.clone()));
session_manager.tools().register( session_manager
crate::tools::cron::CronUpdateTool::new(storage.clone()), .tools()
); .register(crate::tools::cron::CronDisableTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
tracing::info!("Cron tools registered"); tracing::info!("Cron tools registered");
} }
@ -268,71 +282,103 @@ impl GatewayState {
} }
/// Handle control messages (session management operations) /// Handle control messages (session management operations)
async fn handle_control_message( async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
session_manager: &SessionManager,
msg: ControlMessage,
) {
use crate::session::{SessionCommand::*, SessionEvent}; use crate::session::{SessionCommand::*, SessionEvent};
let reply_tx = msg.reply_tx; let reply_tx = msg.reply_tx;
let result: Result<SessionEvent, ChannelError> = match msg.op { let result: Result<SessionEvent, ChannelError> = match msg.op {
CreateDialog { channel, chat_id, title } => { CreateDialog {
session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await channel,
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title }) chat_id,
.map_err(|e| ChannelError::Other(e.to_string())) title,
} } => session_manager
ListDialogs { channel, chat_id, include_archived } => { .create_dialog(&channel, &chat_id, title.as_deref())
session_manager.list_dialogs(&channel, &chat_id, include_archived).await .await
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id }) .map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string())) .map_err(|e| ChannelError::Other(e.to_string())),
} ListDialogs {
GetCurrentDialog { channel, chat_id } => { channel,
session_manager.get_current_dialog(&channel, &chat_id).await chat_id,
.map(|session_id| SessionEvent::CurrentDialog { session_id }) include_archived,
.map_err(|e| ChannelError::Other(e.to_string())) } => session_manager
} .list_dialogs(&channel, &chat_id, include_archived)
SwitchDialog { channel, chat_id, dialog_id } => { .await
session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await .map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
.map(|session_id| SessionEvent::DialogSwitched { session_id }) dialogs,
.map_err(|e| ChannelError::Other(e.to_string())) current_dialog_id,
} })
RenameDialog { session_id, title } => { .map_err(|e| ChannelError::Other(e.to_string())),
session_manager.rename_dialog(&session_id, &title).await GetCurrentDialog { channel, chat_id } => session_manager
.map(|()| SessionEvent::DialogRenamed { session_id, title }) .get_current_dialog(&channel, &chat_id)
.map_err(|e| ChannelError::Other(e.to_string())) .await
} .map(|session_id| SessionEvent::CurrentDialog { session_id })
ArchiveDialog { session_id } => { .map_err(|e| ChannelError::Other(e.to_string())),
session_manager.archive_dialog(&session_id) SwitchDialog {
.map(|()| SessionEvent::DialogArchived { session_id }) channel,
.map_err(|e| ChannelError::Other(e.to_string())) chat_id,
} dialog_id,
DeleteDialog { session_id } => { } => session_manager
session_manager.delete_dialog(&session_id).await .switch_dialog(&channel, &chat_id, &dialog_id)
.map(|()| SessionEvent::DialogDeleted { session_id }) .await
.map_err(|e| ChannelError::Other(e.to_string())) .map(|session_id| SessionEvent::DialogSwitched { session_id })
} .map_err(|e| ChannelError::Other(e.to_string())),
ClearHistory { session_id } => { RenameDialog { session_id, title } => session_manager
session_manager.clear_dialog_history(&session_id) .rename_dialog(&session_id, &title)
.map(|()| SessionEvent::HistoryCleared { session_id }) .await
.map_err(|e| ChannelError::Other(e.to_string())) .map(|()| SessionEvent::DialogRenamed { session_id, title })
} .map_err(|e| ChannelError::Other(e.to_string())),
GetSlashCommands { channel: _, chat_id: _ } => { ArchiveDialog { session_id } => session_manager
.archive_dialog(&session_id)
.await
.map(|()| SessionEvent::DialogArchived { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
DeleteDialog { session_id } => session_manager
.delete_dialog(&session_id)
.await
.map(|()| SessionEvent::DialogDeleted { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
ClearHistory { session_id } => session_manager
.clear_dialog_history(&session_id)
.await
.map(|()| SessionEvent::HistoryCleared { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
GetSlashCommands {
channel: _,
chat_id: _,
} => {
let commands = session_manager.get_slash_commands().to_vec(); let commands = session_manager.get_slash_commands().to_vec();
Ok(SessionEvent::SlashCommandsList { commands }) Ok(SessionEvent::SlashCommandsList { commands })
} }
ExecuteSlashCommand { command, args, channel, chat_id, current_session_id } => { ExecuteSlashCommand {
session_manager.execute_slash_command(&command, args.as_deref(), &channel, &chat_id, current_session_id.as_ref()) command,
.await args,
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { new_session_id: new_id, message: msg }) channel,
.map_err(|e| ChannelError::Other(e.to_string())) chat_id,
} current_session_id,
} => session_manager
.execute_slash_command(
&command,
args.as_deref(),
&channel,
&chat_id,
current_session_id.as_ref(),
)
.await
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
new_session_id: new_id,
message: msg,
})
.map_err(|e| ChannelError::Other(e.to_string())),
}; };
let _ = reply_tx.send(result).await; let _ = reply_tx.send(result).await;
} }
} }
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> { pub async fn run(
host: Option<String>,
port: Option<u16>,
) -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging // Initialize logging
logging::init_logging(); logging::init_logging();
tracing::info!("Starting PicoBot Gateway"); tracing::info!("Starting PicoBot Gateway");

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,17 @@
pub mod background_task;
pub mod error; pub mod error;
pub mod memory; pub mod memory;
pub mod message; pub mod message;
pub mod background_task;
pub mod scheduler; pub mod scheduler;
pub mod session; pub mod session;
pub use error::StorageError;
pub use background_task::BackgroundTask; pub use background_task::BackgroundTask;
pub use error::StorageError;
pub use scheduler::{JobRun, ScheduledJob}; pub use scheduler::{JobRun, ScheduledJob};
use sqlx::{Pool, Row, Sqlite, SqlitePool}; use sqlx::{Pool, Row, Sqlite, SqlitePool};
use tokio::time::{sleep, Duration};
use std::path::Path; use std::path::Path;
use tokio::time::{Duration, sleep};
pub struct Storage { pub struct Storage {
pub(crate) pool: Pool<Sqlite>, pub(crate) pool: Pool<Sqlite>,
@ -42,6 +42,7 @@ impl Storage {
last_active_at INTEGER NOT NULL, last_active_at INTEGER NOT NULL,
message_count INTEGER DEFAULT 0, message_count INTEGER DEFAULT 0,
routing_info TEXT, routing_info TEXT,
archived_at INTEGER,
deleted_at INTEGER, deleted_at INTEGER,
last_consolidated_at INTEGER, last_consolidated_at INTEGER,
last_compressed_message_at INTEGER, last_compressed_message_at INTEGER,
@ -92,20 +93,16 @@ impl Storage {
.await?; .await?;
// Migration: add source column if upgrading from older schema // Migration: add source column if upgrading from older schema
sqlx::query( sqlx::query(r#"ALTER TABLE messages ADD COLUMN source TEXT"#)
r#"ALTER TABLE messages ADD COLUMN source TEXT"#, .execute(&self.pool)
) .await
.execute(&self.pool) .ok();
.await
.ok();
// Migration: add reasoning_content column if upgrading from older schema // Migration: add reasoning_content column if upgrading from older schema
sqlx::query( sqlx::query(r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#)
r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#, .execute(&self.pool)
) .await
.execute(&self.pool) .ok();
.await
.ok();
// Background tasks table — for async sub-agent tasks. // Background tasks table — for async sub-agent tasks.
// Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL). // Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL).
@ -216,11 +213,19 @@ impl Storage {
.await?; .await?;
// Rebuild FTS5 index for any existing records // Rebuild FTS5 index for any existing records
sqlx::query("INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')")
.execute(&self.pool)
.await?;
// Migration: add last_consolidated_at column if not exists
sqlx::query( sqlx::query(
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')", r#"
ALTER TABLE sessions ADD COLUMN archived_at INTEGER
"#,
) )
.execute(&self.pool) .execute(&self.pool)
.await?; .await
.ok();
// Migration: add last_consolidated_at column if not exists // Migration: add last_consolidated_at column if not exists
sqlx::query( sqlx::query(
@ -260,7 +265,10 @@ impl Storage {
.await?; .await?;
if let Err(e) = Self::init_scheduler_schema(&self.pool).await { if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e); tracing::warn!(
"Failed to init scheduler schema (tables may already exist): {}",
e
);
} }
Ok(()) Ok(())
@ -374,16 +382,20 @@ impl Storage {
&self.pool &self.pool
} }
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> { pub async fn upsert_session(
&self,
meta: &crate::storage::session::SessionMeta,
) -> Result<(), StorageError> {
sqlx::query( sqlx::query(
r#" r#"
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at) INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET ON CONFLICT(id) DO UPDATE SET
title = excluded.title, title = excluded.title,
last_active_at = excluded.last_active_at, last_active_at = excluded.last_active_at,
message_count = excluded.message_count, message_count = excluded.message_count,
routing_info = excluded.routing_info, routing_info = excluded.routing_info,
archived_at = excluded.archived_at,
deleted_at = excluded.deleted_at, deleted_at = excluded.deleted_at,
last_consolidated_at = excluded.last_consolidated_at, last_consolidated_at = excluded.last_consolidated_at,
last_compressed_message_at = excluded.last_compressed_message_at last_compressed_message_at = excluded.last_compressed_message_at
@ -398,6 +410,7 @@ impl Storage {
.bind(meta.last_active_at) .bind(meta.last_active_at)
.bind(meta.message_count) .bind(meta.message_count)
.bind(&meta.routing_info) .bind(&meta.routing_info)
.bind(meta.archived_at)
.bind(meta.deleted_at) .bind(meta.deleted_at)
.bind(meta.last_consolidated_at) .bind(meta.last_consolidated_at)
.bind(meta.last_compressed_message_at) .bind(meta.last_compressed_message_at)
@ -407,10 +420,13 @@ impl Storage {
Ok(()) Ok(())
} }
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> { pub async fn get_session(
&self,
id: &str,
) -> Result<crate::storage::session::SessionMeta, StorageError> {
let row = sqlx::query( let row = sqlx::query(
r#" r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions WHERE id = ? AND deleted_at IS NULL FROM sessions WHERE id = ? AND deleted_at IS NULL
"#, "#,
) )
@ -429,6 +445,7 @@ impl Storage {
last_active_at: row.get("last_active_at"), last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"), message_count: row.get("message_count"),
routing_info: row.get("routing_info"), routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"), deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"), last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"), last_compressed_message_at: row.get("last_compressed_message_at"),
@ -440,18 +457,21 @@ impl Storage {
channel: &str, channel: &str,
chat_id: &str, chat_id: &str,
limit: i64, limit: i64,
include_archived: bool,
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> { ) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
let rows = sqlx::query( let rows = sqlx::query(
r#" r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
AND (? OR archived_at IS NULL)
ORDER BY last_active_at DESC ORDER BY last_active_at DESC
LIMIT ? LIMIT ?
"#, "#,
) )
.bind(channel) .bind(channel)
.bind(chat_id) .bind(chat_id)
.bind(include_archived)
.bind(limit) .bind(limit)
.fetch_all(self.pool()) .fetch_all(self.pool())
.await?; .await?;
@ -468,6 +488,7 @@ impl Storage {
last_active_at: row.get("last_active_at"), last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"), message_count: row.get("message_count"),
routing_info: row.get("routing_info"), routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"), deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"), last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"), last_compressed_message_at: row.get("last_compressed_message_at"),
@ -498,13 +519,22 @@ impl Storage {
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> { pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis(); let now = chrono::Utc::now().timestamp_millis();
sqlx::query( sqlx::query(r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#)
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#, .bind(now)
) .bind(id)
.bind(now) .execute(self.pool())
.bind(id) .await?;
.execute(self.pool())
.await?; Ok(())
}
pub async fn archive_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(r#"UPDATE sessions SET archived_at = ? WHERE id = ? AND deleted_at IS NULL"#)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(()) Ok(())
} }
@ -516,9 +546,9 @@ impl Storage {
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> { ) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
let row = sqlx::query( let row = sqlx::query(
r#" r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND archived_at IS NULL
ORDER BY last_active_at DESC ORDER BY last_active_at DESC
LIMIT 1 LIMIT 1
"#, "#,
@ -539,6 +569,7 @@ impl Storage {
last_active_at: row.get("last_active_at"), last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"), message_count: row.get("message_count"),
routing_info: row.get("routing_info"), routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"), deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"), last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"), last_compressed_message_at: row.get("last_compressed_message_at"),
@ -547,7 +578,11 @@ impl Storage {
} }
} }
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> { pub async fn append_message(
&self,
session_id: &str,
msg: &crate::storage::message::MessageMeta,
) -> Result<i64, StorageError> {
sqlx::query( sqlx::query(
r#" r#"
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at) INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
@ -674,16 +709,15 @@ impl Storage {
offset: i64, offset: i64,
limit: i64, limit: i64,
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> { ) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
let count_row = sqlx::query( let count_row =
"SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL", sqlx::query("SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL")
) .fetch_one(self.pool())
.fetch_one(self.pool()) .await?;
.await?;
let total: i64 = count_row.get("total"); let total: i64 = count_row.get("total");
let rows = sqlx::query( let rows = sqlx::query(
r#" r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions FROM sessions
WHERE deleted_at IS NULL WHERE deleted_at IS NULL
ORDER BY last_active_at DESC ORDER BY last_active_at DESC
@ -707,6 +741,7 @@ impl Storage {
last_active_at: row.get("last_active_at"), last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"), message_count: row.get("message_count"),
routing_info: row.get("routing_info"), routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"), deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"), last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"), last_compressed_message_at: row.get("last_compressed_message_at"),
@ -772,7 +807,10 @@ impl Storage {
where_extra.push_str(" AND created_at > ?"); where_extra.push_str(" AND created_at > ?");
} }
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra); let count_sql = format!(
"SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}",
where_extra
);
let select_sql = format!( let select_sql = format!(
r#" r#"
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
@ -1030,6 +1068,7 @@ mod tests {
last_active_at: 1000, last_active_at: 1000,
message_count: 0, message_count: 0,
routing_info: Some(r#"{"type":"cli"}"#.to_string()), routing_info: Some(r#"{"type":"cli"}"#.to_string()),
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -1066,14 +1105,18 @@ mod tests {
last_active_at: i as i64 * 1000, last_active_at: i as i64 * 1000,
message_count: i, message_count: i,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
}; };
storage.upsert_session(&meta).await.unwrap(); storage.upsert_session(&meta).await.unwrap();
} }
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap(); let sessions = storage
.list_sessions("cli_chat", "sid123", 10, false)
.await
.unwrap();
assert_eq!(sessions.len(), 5); assert_eq!(sessions.len(), 5);
// 按 last_active_at DESC 排序 // 按 last_active_at DESC 排序
assert_eq!(sessions[0].dialog_id, "dialog4"); assert_eq!(sessions[0].dialog_id, "dialog4");
@ -1093,6 +1136,7 @@ mod tests {
last_active_at: 1000, last_active_at: 1000,
message_count: 0, message_count: 0,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -1120,6 +1164,7 @@ mod tests {
last_active_at: 1000, last_active_at: 1000,
message_count: 0, message_count: 0,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -1141,7 +1186,10 @@ mod tests {
created_at: 1000, created_at: 1000,
}; };
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap(); let seq = storage
.append_message(&session_meta.id, &msg)
.await
.unwrap();
assert_eq!(seq, 1); assert_eq!(seq, 1);
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap(); let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
@ -1163,6 +1211,7 @@ mod tests {
last_active_at: 1000, last_active_at: 1000,
message_count: 0, message_count: 0,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,

View File

@ -11,6 +11,7 @@ pub struct SessionMeta {
pub last_active_at: i64, pub last_active_at: i64,
pub message_count: i64, pub message_count: i64,
pub routing_info: Option<String>, pub routing_info: Option<String>,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>, pub deleted_at: Option<i64>,
pub last_consolidated_at: Option<i64>, pub last_consolidated_at: Option<i64>,
pub last_compressed_message_at: Option<i64>, pub last_compressed_message_at: Option<i64>,

View File

@ -126,7 +126,10 @@ impl ChatManagerTool {
let start_num = offset + 1; let start_num = offset + 1;
let end_num = offset + sessions.len() as i64; let end_num = offset + sessions.len() as i64;
let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num); let mut output = format!(
"全部会话 (共 {} 个,第 {}-{} 个):\n",
total, start_num, end_num
);
for s in &sessions { for s in &sessions {
let ago = format_duration_ago(now_ms - s.last_active_at); let ago = format_duration_ago(now_ms - s.last_active_at);
@ -300,9 +303,10 @@ mod tests {
last_active_at: now - i * 3600_000, last_active_at: now - i * 3600_000,
message_count: i * 5, message_count: i * 5,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
}; };
storage.upsert_session(&meta).await.unwrap(); storage.upsert_session(&meta).await.unwrap();
} }
@ -335,6 +339,7 @@ mod tests {
last_active_at: now, last_active_at: now,
message_count: 3, message_count: 3,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -346,7 +351,11 @@ mod tests {
id: format!("msg{}", i), id: format!("msg{}", i),
session_id: session_id.to_string(), session_id: session_id.to_string(),
seq: i as i64 + 1, seq: i as i64 + 1,
role: if i == 0 { "user".to_string() } else { "assistant".to_string() }, role: if i == 0 {
"user".to_string()
} else {
"assistant".to_string()
},
content: format!("消息内容 {}", i), content: format!("消息内容 {}", i),
reasoning_content: None, reasoning_content: None,
media_refs: None, media_refs: None,
@ -392,6 +401,7 @@ mod tests {
last_active_at: now, last_active_at: now,
message_count: 5, message_count: 5,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -403,7 +413,11 @@ mod tests {
id: format!("msg{}", i), id: format!("msg{}", i),
session_id: session_id.to_string(), session_id: session_id.to_string(),
seq: i as i64 + 1, seq: i as i64 + 1,
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() }, role: if i % 2 == 0 {
"user".to_string()
} else {
"assistant".to_string()
},
content: format!("消息内容 {}", i), content: format!("消息内容 {}", i),
reasoning_content: None, reasoning_content: None,
media_refs: None, media_refs: None,
@ -447,6 +461,7 @@ mod tests {
last_active_at: now, last_active_at: now,
message_count: 5, message_count: 5,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -492,10 +507,7 @@ mod tests {
let (storage, _dir) = create_test_storage().await; let (storage, _dir) = create_test_storage().await;
let tool = ChatManagerTool::new(storage, vec![]); let tool = ChatManagerTool::new(storage, vec![]);
let result = tool let result = tool.execute(json!({ "action": "unknown" })).await.unwrap();
.execute(json!({ "action": "unknown" }))
.await
.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown action")); assert!(result.error.unwrap().contains("Unknown action"));
} }

View File

@ -1,5 +1,5 @@
use picobot::providers::{ChatCompletionRequest, Message};
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound}; use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
use picobot::providers::{ChatCompletionRequest, Message};
/// Test that message with special characters is properly escaped /// Test that message with special characters is properly escaped
#[test] #[test]
@ -19,7 +19,9 @@ fn test_message_special_characters() {
#[test] #[test]
fn test_multiline_system_prompt() { fn test_multiline_system_prompt() {
let messages = vec![ let messages = vec![
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"), Message::system(
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
),
Message::user("Hi"), Message::user("Hi"),
]; ];
@ -33,10 +35,7 @@ fn test_multiline_system_prompt() {
#[test] #[test]
fn test_chat_request_serialization() { fn test_chat_request_serialization() {
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![ messages: vec![Message::system("You are helpful"), Message::user("Hello")],
Message::system("You are helpful"),
Message::user("Hello"),
],
temperature: Some(0.7), temperature: Some(0.7),
max_tokens: Some(100), max_tokens: Some(100),
tools: None, tools: None,