Fix CLI session routing and dialog controls

This commit is contained in:
xiaoski 2026-06-15 23:44:41 +08:00
parent 0d66536e90
commit c6f4392e63
8 changed files with 1391 additions and 545 deletions

View File

@ -0,0 +1,359 @@
# PicoBot 代码质量分析报告
审查日期2026-06-15
## 结论摘要
PicoBot 的总体架构方向是清晰的Gateway 负责装配Channel 只做收发MessageBus 解耦输入输出SessionManager 管理会话AgentLoop 保持无状态并执行工具Storage 统一持久化。这条主线是成立的,也已经具备较完整的 AI 助手运行时能力。
当前主要质量风险集中在三类:
1. 会话/CLI 路由语义不一致,导致多客户端隔离、加载会话、当前会话追踪不可靠。
2. 若干公开控制接口是空实现或弱实现,协议层暴露的能力和后端实际行为不匹配。
3. 工具和后台任务的资源边界偏弱文件、shell、HTTP、长期任务在异常情况下容易突破预期的安全或稳定性边界。
如果只安排一轮修复,优先处理会话路由和控制接口。这些问题会直接影响用户看到的行为;工具安全和大模块拆分可以作为第二阶段。
## 修复状态
- 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id`
- 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
- 待处理工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
## 主要发现
### 已修复CLI 会话路由会破坏会话连续性和多客户端隔离
位置:
- `src/channels/cli_chat.rs:113-126`
- `src/channels/cli_chat.rs:160-164`
- `src/channels/cli_chat.rs:225-249`
- `src/channels/cli_chat.rs:479-494`
- `src/session/session.rs:1305-1310`
问题:
`Client.current_session_id` 存的是完整 session id但 CLI channel 在多个地方把它当作 `chat_id` 使用。普通用户输入如果没有显式传 `chat_id`,会在 `src/channels/cli_chat.rs:119` 生成新的短 ID而不是复用当前 client 的 chat scope。`CreateSession` 又把当前完整 session id 当成新会话的 chat_id。`LoadSession` 解析了传入 session id但随后调用 `GetCurrentDialog`,而后端 `get_current_dialog()` 固定返回 `None`
同时,`send()` 会把所有 `OutboundMessage` 广播给所有 CLI WebSocket client没有按 `msg.chat_id` 或 client 当前会话过滤。这意味着一个客户端的回复可能出现在另一个客户端里。
影响:
- CLI 多轮对话可能落入不同 chat scope。
- 创建/列出/加载会话得到的结果可能不符合 UI 预期。
- 多个 CLI 客户端同时连接时存在串话。
建议:
- 将 client 状态拆成 `chat_id``current_session_id`,不要混用。
- 注册 client 时生成稳定 `chat_id`,后续 `UserInput` 默认复用它。
- `send()``OutboundMessage.chat_id` 精确投递;必要时维护 `chat_id -> clients` 映射。
- `LoadSession` 应直接切换到指定 session或通过 `SwitchDialog` 使用其中的 `dialog_id`
- 为 CLI WebSocket 增加多客户端路由测试。
### 已修复Dialog 控制接口与协议承诺不一致
位置:
- `src/session/session.rs:996-997`
- `src/session/session.rs:1305-1310`
- `src/session/session.rs:1329-1349`
- `src/session/session.rs:1378-1384`
- `src/channels/cli_chat.rs:128-158`
问题:
后端暴露了 create/list/load/rename/archive/delete/clear 等 dialog 操作,但部分行为是空实现或语义错位:
- `/delete` 只创建新 session并没有删除当前 session。
- `get_current_dialog()` 固定返回 `Ok(None)`
- `list_dialogs()` 忽略 `include_archived`,且总是返回 `current_dialog_id = None`
- `archive_dialog()` 是空操作。
- `clear_dialog_history()` 直接返回不可用,但 WebSocket 协议仍暴露 `clear_history`
影响:
用户通过 slash command 和 WebSocket 调用同一类能力时,会得到不一致结果。前端难以基于协议实现可靠状态同步。
建议:
- 明确“archive/clear 是否支持”。不支持就从协议和命令列表移除;支持就实现到底。
- `/delete` 应调用 `delete_dialog(current_session_id)`,再创建一个新的 current session。
- `get_current_dialog()` 应读取 `current_sessions[channel:chat_id]` 并解析为 `UnifiedSessionId`
- `list_dialogs()` 返回真实 current dialog并补上 archived 模型或移除 archived 参数。
### 高优先级:工具文件边界不符合“工作目录内工具”的架构约束
位置:
- `src/tools/mod.rs:56-62`
- `src/tools/path_utils.rs:3-23`
- `src/tools/bash.rs:146-185`
问题:
文件工具默认通过 `FileReadTool::new()``FileWriteTool::new()` 等注册,没有传入 workspace allowlist。`resolve_path()` 对绝对路径直接放行;即使传入 allowlist也只是做 `Path::starts_with()` 的词法判断,没有 canonicalize不能防御 `..`、符号链接等路径逃逸。
`bash` 默认工作目录是 `"."`Gateway 启动时切到 workspace这对相对路径有效但 shell 命令仍然可以访问绝对路径。当前 denylist 只挡少数危险模式,不构成权限边界。
影响:
Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“工作目录内操作”不一致。对于个人助手这可能是有意设计,但如果未来接入外部渠道、多用户或 MCP风险会放大。
建议:
- 工具注册时传入 `workspace_dir`,默认所有文件工具限制在 workspace。
- `resolve_path()` 使用 `std::fs::canonicalize``path_absolutize` 风格逻辑,并处理目标文件不存在时的父目录 canonicalize。
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
### 中高优先级Session 锁内执行过多异步操作
位置:
- `src/session/session.rs:1001-1018`
- `src/session/session.rs:1604-1711`
问题:
`/compact` 在持有 session mutex 时执行压缩和持久化。agent worker 的 Phase 1 也在持有 session mutex 时执行用户消息落库、memory recall、上下文压缩、session meta 持久化和 agent 创建。其中 `compress_if_needed()` 可能触发 LLM 摘要,属于慢操作。
影响:
- 同一 session 的 slash command、stop、消息排队、状态查询会被慢操作阻塞。
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
建议:
- 锁内只做内存状态快照和必要的状态标记。
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
### 中优先级Bash 超时不会显式终止子进程
位置:
- `src/tools/bash.rs:150-174`
- `src/tools/bash.rs:180-207`
问题:
`timeout()` 包裹的是 `run_command()` future。超时后 future 被取消,但代码没有持有 child 句柄并显式 `kill()` / `wait()`。对于已经启动的长运行命令或子进程树,可能留下后台进程。
影响:
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
建议:
- 使用 `tokio::process::Child``kill_on_drop(true)`
- 超时分支显式 kill child 并 wait。
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
### 中优先级:文件读取对大二进制文件没有输出上限
位置:
- `src/tools/file_read.rs:121-131`
- `src/tools/file_read.rs:214-229`
问题:
`file_read``std::fs::read()` 读取整个文件。文本路径有 `MAX_CHARS` 截断,但二进制路径会完整 base64 编码后返回,没有大小限制。
影响:
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
建议:
- 先检查 metadata size超过阈值直接返回提示。
- 二进制文件默认只返回 mime、大小和建议操作需要内容时提供显式 `max_bytes` 参数。
- 对文本读取也改成流式按行读取,而不是整文件读入。
### 中优先级HTTP 私网防护只检查字面 host未做 DNS 解析校验
位置:
- `src/tools/http_request.rs:31-59`
问题:
`http_request` 阻止 localhost、私网 IP 字面量和 `.local`,但普通域名不会解析后检查最终 IP。DNS rebinding 或内网域名解析到私网地址时,当前校验拦不住。
影响:
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
建议:
- 请求前解析域名拒绝私网、loopback、link-local、multicast、unspecified 地址。
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
- 对 `http_request``web_fetch` 复用同一套 URL 安全策略。
### 中优先级:后台任务和主循环缺少监督与优雅关闭
位置:
- `src/bus/mod.rs:51-99`
- `src/gateway/mod.rs:187-244`
- `src/gateway/mod.rs:247-266`
问题:
Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHandle也没有统一 cancellation token。MessageBus 的 `consume_*()` 在 channel 关闭时使用 `expect()` panic。
影响:
- 某个后台 loop 异常退出后Gateway 不一定能发现。
- 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
- bus channel 关闭时更像崩溃,而不是可恢复状态。
建议:
- 引入 runtime supervisor保存 JoinHandle 并集中处理退出原因。
- 用 `CancellationToken` 贯穿 Gateway 子任务。
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
### 中低优先级Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
位置:
- `src/scheduler/mod.rs:18-40`
问题:
`next_run_for_schedule(schedule, from)` 的注释说基于 `from` 计算,但 cron 分支创建了 `from_dt` 后没有传给 `cron_schedule`,实际使用的是 `upcoming(Utc)``upcoming(tz)` 的当前时间。
影响:
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。
建议:
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
- 增加固定时间输入的单元测试,避免受系统时间影响。
### 中低优先级:存在未接入或半接入代码,增加维护噪音
位置:
- `src/tools/pty.rs`
- `src/tools/mod.rs:1-20`
- `src/tools/mod.rs:49-88`
问题:
仓库里有完整 `pty.rs`,但 `tools/mod.rs` 没有声明 `pub mod pty``create_default_tools()` 也没有注册 PTY 工具。类似情况会让文档、计划和实现状态难以判断。
影响:
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
建议:
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
- 若暂不发布:移动到设计文档或 feature branch避免主干保留死代码。
## 架构评价
### 做得好的地方
- 模块分层方向清楚Channel、Bus、Session、Agent、Provider、Tool、Storage 边界基本可理解。
- AgentLoop 设计为无状态,历史由 SessionManager 管理,这一点利于恢复、压缩和测试。
- Provider 抽象简单直接OpenAI-compatible 与 Anthropic 的差异被限制在 provider 层。
- Storage 集中初始化 schema便于部署单二进制应用。
- Skill、memory、MCP、delegate 这几条扩展线已经形成统一的 ToolRegistry 接入点。
### 主要架构债务
- SessionManager 承担过多职责会话生命周期、命令解析、memory recall、压缩、agent worker、任务取消、send_message 目标解析都在一个 2000 行文件内。
- Channel 和 Session 对 chat_id/session_id/dialog_id 的边界没有类型保护,导致 CLI 层混用字符串。
- Tool 权限模型不够显式:工具是否能访问全文件系统、是否能联网、是否能修改状态主要靠工具自身约定。
- 后台任务生命周期分散gateway loop、agent worker、notification publisher、scheduler、sub-agent task 各自 spawn缺少统一管理。
## 模块级分析
### gateway
`GatewayState::new()` 是清晰的装配中心配置、workspace、storage、memory、bus、session manager、channels、MCP、scheduler 都在这里接线。问题是启动后任务监督不足,且 scheduler 默认 `unwrap_or_default()` 会在省略 `gateway.scheduler` 时启用调度器,这和“省略配置是否代表开启”需要产品层确认。
### channels
Feishu channel 功能较厚,单文件接近 2000 行,建议后续按 API client、message parsing、media handling、outbound rendering 拆分。CLI channel 目前是质量风险最高的 channel核心问题是会话身份混用和广播投递。
### bus
MessageBus 简洁,但当前消费者 API 通过 mutex 包住 receiver 并 `expect()`,更像“单消费者内部队列”。这没问题,但应该把“只能有一个 consumer”写进类型/文档,并把关闭作为正常状态处理。
### session
这是系统核心,也是债务最集中的模块。建议把 `session.rs` 拆成:
- `manager.rs`SessionManager 状态和 dialog 生命周期
- `worker.rs`per-session agent worker 和 cancellation
- `commands.rs`slash command 执行
- `outbound.rs`OutboundMessenger 实现
- `restore.rs`storage 恢复与 tool call chain repair
拆分之前,先补行为测试,尤其是 CLI/WS session lifecycle。
### agent
AgentLoop 的职责相对聚焦:请求模型、执行工具、回填 tool result、循环直到 final response。需要关注的是工具并发的语义`read_only()` 目前是工具自己声明副作用工具不能错标。LoopDetector 有帮助,但属于 runtime guard不应替代工具层的资源限制。
### providers
Provider 层整体可维护。OpenAI/Anthropic 的请求构造逻辑可以继续保留在 provider 内。建议补充请求脱敏策略:当前 debug log 和 `llm_calls` 会持久化完整 request/response可能包含用户隐私、API 返回内容和文件内容。
### tools
工具体系覆盖面很强,但需要明确权限模型。建议新增统一的 `ToolExecutionContext`,包含 workspace、channel、session_id、权限策略、网络策略、输出预算。现在很多策略散落在各工具构造函数里默认值容易失控。
### storage
Storage schema 初始化实用但迁移方式是“CREATE IF NOT EXISTS + ALTER IGNORE”适合早期迭代不适合长期演进。建议引入 schema version 表或 sqlx migrations至少把每次迁移记录下来。
### skills
Skill 加载优先级清晰,内置 skill 打包也实用。需要注意 `SkillsLoader` 使用同步文件系统扫描和 `std::sync::Mutex`,在请求路径频繁 `reload_if_changed()` 时可能造成阻塞。短期可以接受,长期建议缓存刷新放到后台 watcher。
## 建议修复路线
### P0先修会话正确性
1. 修正 CLI `chat_id/current_session_id` 数据模型。
2. 修正 CLI 出站按 client/chat_id 投递。
3. 实现 `get_current_dialog()``list_dialogs()` current 返回。
4. 修正 `/delete``clear_history``archive` 的真实行为或从协议移除。
5. 增加 WebSocket session lifecycle 测试。
### P1收紧工具和资源边界
1. 文件工具默认限制 workspace路径 canonicalize。
2. bash 超时杀进程,必要时引入进程组。
3. file_read 增加文件大小上限和二进制输出上限。
4. HTTP/web 工具增加 DNS 解析后的私网校验和重定向校验。
5. 明确高危工具的配置开关。
### P2降低架构复杂度
1. 拆分 `session.rs``feishu.rs``storage/mod.rs``browser.rs`
2. 引入任务 supervisor 和统一 shutdown token。
3. 引入正式数据库迁移。
4. 增加工具注册快照测试,避免死代码和文档漂移。
## 建议测试补充
- CLI 多客户端并发:两个 WebSocket client 同时发消息,互不串话。
- CLI 不传 chat_id 的连续对话:所有消息应进入同一 session。
- Load/switch/list/delete/clear 的完整 WebSocket 流程。
- `/delete` 后旧 session 软删除、新 session 成为 current。
- 文件路径逃逸:`../`、绝对路径、符号链接、workspace 前缀欺骗。
- bash timeout 后检查子进程不存在。
- cron `next_run_for_schedule()` 使用固定 `from` 的 deterministic 测试。
- HTTP 工具对 DNS 解析到 `127.0.0.1` / `10.0.0.0/8` 的域名拒绝测试。

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::{mpsc, Mutex};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
use crate::protocol::{SlashCommandInfo, WsInbound, WsOutbound, parse_inbound};
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo};
use super::base::{Channel, ChannelError};
@ -14,6 +14,7 @@ use super::base::{Channel, ChannelError};
pub(crate) struct Client {
sender: mpsc::Sender<WsOutbound>,
chat_id: String,
current_session_id: Mutex<Option<String>>,
}
@ -41,23 +42,28 @@ impl CliChatChannel {
}
/// Register a new client connection, returns (session_id, client)
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) {
// Generate connection ID (used as chat_id) - use short ID
let connection_id = crate::util::short_id();
pub(crate) async fn register_client(
&self,
sender: mpsc::Sender<WsOutbound>,
) -> (String, Arc<Client>) {
// Each WebSocket connection gets a stable chat scope. All user input and
// dialog controls for this client stay inside that scope unless the
// protocol explicitly carries a full session id.
let chat_id = crate::util::short_id();
let client = Arc::new(Client {
sender,
chat_id: chat_id.clone(),
current_session_id: Mutex::new(None),
});
self.clients.lock().await.push(client.clone());
// Create initial session via control message
let session_id = match self.create_session_via_control(&connection_id, None).await {
Ok(id) => id,
let session_id = match self.create_session_via_control(&chat_id, None).await {
Ok((id, _title)) => id,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial session");
// Fall back to old format for backward compatibility
connection_id.clone()
UnifiedSessionId::new("cli_chat", &chat_id, &crate::util::short_id()).to_string()
}
};
@ -73,21 +79,19 @@ impl CliChatChannel {
/// Handle an inbound message from a client
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
match parse_inbound(raw_msg) {
Ok(inbound) => {
match self.handle_ws_inbound(client.clone(), inbound).await {
Ok(()) => {}
Err(e) => {
tracing::warn!(error = %e, "Failed to handle inbound message");
let _ = client
.sender
.send(WsOutbound::Error {
code: "INTERNAL_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
Ok(inbound) => match self.handle_ws_inbound(client.clone(), inbound).await {
Ok(()) => {}
Err(e) => {
tracing::warn!(error = %e, "Failed to handle inbound message");
let _ = client
.sender
.send(WsOutbound::Error {
code: "INTERNAL_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
},
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = client
@ -101,22 +105,30 @@ impl CliChatChannel {
}
}
async fn handle_ws_inbound(&self, client: Arc<Client>, inbound: WsInbound) -> Result<(), ChannelError> {
async fn handle_ws_inbound(
&self,
client: Arc<Client>,
inbound: WsInbound,
) -> Result<(), ChannelError> {
let bus = {
let guard = self.bus.lock().unwrap();
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
guard
.clone()
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
};
let mut current_session_guard = client.current_session_id.lock().await;
match inbound {
WsInbound::UserInput { content, chat_id, .. } => {
WsInbound::UserInput {
content, chat_id, ..
} => {
// All messages (including slash commands) go through the normal inbound flow
// SessionManager handles session creation/reuse internally
let msg = InboundMessage {
channel: self.name().to_string(),
sender_id: "cli".to_string(),
chat_id: chat_id.unwrap_or_else(crate::util::short_id),
chat_id: chat_id.unwrap_or_else(|| client.chat_id.clone()),
content,
timestamp: crate::bus::message::current_timestamp(),
media: Vec::new(),
@ -125,19 +137,56 @@ impl CliChatChannel {
};
bus.publish_inbound(msg).await?;
}
WsInbound::ClearHistory { chat_id, session_id } => {
let target = session_id
.or(chat_id)
.or(current_session_guard.clone())
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
WsInbound::ClearHistory {
chat_id,
session_id,
} => {
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let session_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
let session_id = if let Some(session_id) = session_id {
UnifiedSessionId::parse(&session_id).ok_or_else(|| {
ChannelError::Other("Invalid session ID format".to_string())
})?
} else if let Some(chat_id) = chat_id {
let (current_tx, mut current_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage {
op: SessionCommand::GetCurrentDialog {
channel: "cli_chat".to_string(),
chat_id,
},
reply_tx: current_tx,
})
.await?;
match current_rx.recv().await {
Some(Ok(SessionEvent::CurrentDialog {
session_id: Some(session_id),
})) => session_id,
Some(Ok(SessionEvent::CurrentDialog { session_id: None })) => {
return Err(ChannelError::Other("No active session".to_string()));
}
Some(Ok(_)) => {
return Err(ChannelError::Other(
"Unexpected response type".to_string(),
));
}
Some(Err(e)) => return Err(e),
None => {
return Err(ChannelError::Other("Control channel closed".to_string()));
}
}
} else {
let target = current_session_guard
.clone()
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
UnifiedSessionId::parse(&target).ok_or_else(|| {
ChannelError::Other("Invalid session ID format".to_string())
})?
};
let target = session_id.to_string();
bus.publish_control(ControlMessage {
op: SessionCommand::ClearHistory { session_id },
reply_tx,
}).await?;
})
.await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::HistoryCleared { .. })) => {
@ -158,24 +207,21 @@ impl CliChatChannel {
}
}
WsInbound::CreateSession { title } => {
// Use current session's chat_id if available, otherwise generate new one
let chat_id = current_session_guard.clone()
.unwrap_or_else(crate::util::short_id);
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
let (new_id, created_title) = self
.create_session_via_control(&client.chat_id, title.as_deref())
.await?;
*current_session_guard = Some(new_id.clone());
let _ = client
.sender
.send(WsOutbound::SessionCreated {
session_id: new_id,
title: title.unwrap_or_default(),
title: created_title,
})
.await;
}
WsInbound::ListSessions { include_archived } => {
// List dialogs for the current chat
let chat_id = current_session_guard.clone()
.unwrap_or_else(|| "".to_string());
let chat_id_for_response = chat_id.clone();
let chat_id = client.chat_id.clone();
let (reply_tx, mut reply_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage {
op: SessionCommand::ListDialogs {
@ -184,13 +230,18 @@ impl CliChatChannel {
include_archived,
},
reply_tx,
}).await?;
})
.await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => {
Some(Ok(SessionEvent::DialogList {
dialogs,
current_dialog_id,
})) => {
// Convert DialogInfo to SessionSummary for backward compatibility
let sessions: Vec<crate::protocol::SessionSummary> = dialogs.into_iter().map(|d| {
crate::protocol::SessionSummary {
let sessions: Vec<crate::protocol::SessionSummary> = dialogs
.into_iter()
.map(|d| crate::protocol::SessionSummary {
session_id: d.session_id.to_string(),
title: d.title,
channel_name: d.session_id.channel.clone(),
@ -198,11 +249,14 @@ impl CliChatChannel {
message_count: d.message_count,
last_active_at: d.last_active_at,
archived_at: d.archived_at,
}
}).collect();
})
.collect();
let current_session_id = current_dialog_id.map(|did| {
UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string()
UnifiedSessionId::new("cli_chat", &client.chat_id, &did).to_string()
});
if let Some(ref session_id) = current_session_id {
*current_session_guard = Some(session_id.clone());
}
let _ = client
.sender
.send(WsOutbound::SessionList {
@ -223,39 +277,35 @@ impl CliChatChannel {
}
}
WsInbound::LoadSession { session_id } => {
// LoadSession: parse the session_id and get current dialog info
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&session_id)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
if unified_id.channel != "cli_chat" || unified_id.chat_id != client.chat_id {
return Err(ChannelError::Other(
"Session does not belong to this client".to_string(),
));
}
bus.publish_control(ControlMessage {
op: SessionCommand::GetCurrentDialog {
op: SessionCommand::SwitchDialog {
channel: unified_id.channel.clone(),
chat_id: unified_id.chat_id.clone(),
dialog_id: unified_id.dialog_id.clone(),
},
reply_tx,
}).await?;
})
.await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => {
if let Some(current_session_id) = current_session_id_opt {
*current_session_guard = Some(current_session_id.to_string());
let _ = client
.sender
.send(WsOutbound::SessionLoaded {
session_id: current_session_id.to_string(),
title: "Session".to_string(), // TODO: get actual title
message_count: 0, // TODO: get actual count
})
.await;
} else {
let _ = client
.sender
.send(WsOutbound::Error {
code: "NO_CURRENT_DIALOG".to_string(),
message: "No current dialog".to_string(),
})
.await;
}
Some(Ok(SessionEvent::DialogSwitched { session_id })) => {
*current_session_guard = Some(session_id.to_string());
let _ = client
.sender
.send(WsOutbound::SessionLoaded {
session_id: session_id.to_string(),
title: "Session".to_string(),
message_count: 0,
})
.await;
}
Some(Ok(_)) => {
// Unexpected response type
@ -275,23 +325,30 @@ impl CliChatChannel {
}
}
WsInbound::RenameSession { session_id, title } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
ChannelError::Other("No active session".to_string())
})?;
let target = session_id
.or(current_session_guard.clone())
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage {
op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() },
op: SessionCommand::RenameDialog {
session_id: unified_id,
title: title.clone(),
},
reply_tx,
}).await?;
})
.await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
let _ = client
.sender
.send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title })
.send(WsOutbound::SessionRenamed {
session_id: session_id.to_string(),
title,
})
.await;
}
Some(Ok(_)) => {
@ -306,24 +363,43 @@ impl CliChatChannel {
}
}
WsInbound::ArchiveSession { session_id } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
ChannelError::Other("No active session".to_string())
})?;
let target = session_id
.or(current_session_guard.clone())
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let was_current = current_session_guard.as_deref() == Some(&target);
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage {
op: SessionCommand::ArchiveDialog { session_id: unified_id },
op: SessionCommand::ArchiveDialog {
session_id: unified_id,
},
reply_tx,
}).await?;
})
.await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogArchived { session_id })) => {
let _ = client
.sender
.send(WsOutbound::SessionArchived { session_id: session_id.to_string() })
.send(WsOutbound::SessionArchived {
session_id: session_id.to_string(),
})
.await;
if was_current {
let (new_id, title) = self
.create_session_via_control(&client.chat_id, None)
.await?;
*current_session_guard = Some(new_id.clone());
let _ = client
.sender
.send(WsOutbound::SessionCreated {
session_id: new_id,
title,
})
.await;
}
}
Some(Ok(_)) => {
// Unexpected response type
@ -337,35 +413,42 @@ impl CliChatChannel {
}
}
WsInbound::DeleteSession { session_id } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
ChannelError::Other("No active session".to_string())
})?;
let target = session_id
.or(current_session_guard.clone())
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage {
op: SessionCommand::DeleteDialog { session_id: unified_id },
op: SessionCommand::DeleteDialog {
session_id: unified_id,
},
reply_tx,
}).await?;
})
.await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
let _ = client
.sender
.send(WsOutbound::SessionDeleted { session_id: session_id.to_string() })
.send(WsOutbound::SessionDeleted {
session_id: session_id.to_string(),
})
.await;
// If deleting current session, create a new one
if current_session_guard.as_deref() == Some(&target) {
drop(reply_rx);
if let Ok(new_id) = self.create_session_via_control(&target, None).await {
if let Ok((new_id, title)) =
self.create_session_via_control(&client.chat_id, None).await
{
*current_session_guard = Some(new_id.clone());
let _ = client
.sender
.send(WsOutbound::SessionCreated {
session_id: new_id,
title: String::new(),
title,
})
.await;
}
@ -388,32 +471,45 @@ impl CliChatChannel {
bus.publish_control(ControlMessage {
op: SessionCommand::GetSlashCommands {
channel: "cli_chat".to_string(),
chat_id: "".to_string(),
chat_id: client.chat_id.clone(),
},
reply_tx,
}).await?;
})
.await?;
if let Some(result) = reply_rx.recv().await {
match result {
Ok(SessionEvent::SlashCommandsList { commands }) => {
// Convert to SlashCommand to SlashCommandInfo
let command_infos: Vec<SlashCommandInfo> = commands.into_iter().map(|cmd| {
SlashCommandInfo {
let command_infos: Vec<SlashCommandInfo> = commands
.into_iter()
.map(|cmd| SlashCommandInfo {
name: cmd.name.to_string(),
description: cmd.description.to_string(),
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
}
}).collect();
let _ = client.sender.send(WsOutbound::SlashCommandsList { commands: command_infos }).await;
})
.collect();
let _ = client
.sender
.send(WsOutbound::SlashCommandsList {
commands: command_infos,
})
.await;
}
Ok(SessionEvent::Error { code, message }) => {
let _ = client.sender.send(WsOutbound::Error { code, message }).await;
let _ = client
.sender
.send(WsOutbound::Error { code, message })
.await;
}
Err(e) => {
let _ = client.sender.send(WsOutbound::Error {
code: "GET_COMMANDS_ERROR".to_string(),
message: e.to_string()
}).await;
let _ = client
.sender
.send(WsOutbound::Error {
code: "GET_COMMANDS_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
_ => {}
}
@ -427,29 +523,34 @@ impl CliChatChannel {
}
/// Create a session via control message and return the session_id
async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result<String, ChannelError> {
async fn create_session_via_control(
&self,
chat_id: &str,
title: Option<&str>,
) -> Result<(String, String), ChannelError> {
let bus = {
let guard = self.bus.lock().unwrap();
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
guard
.clone()
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
};
let (reply_tx, mut reply_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage {
op: SessionCommand::CreateDialog {
channel: "cli_chat".to_string(),
chat_id: connection_id.to_string(),
chat_id: chat_id.to_string(),
title: title.map(String::from),
},
reply_tx,
}).await?;
})
.await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => {
Ok(session_id.to_string())
}
Some(Ok(_)) => {
Err(ChannelError::Other("Unexpected response type".to_string()))
Some(Ok(SessionEvent::DialogCreated { session_id, title })) => {
Ok((session_id.to_string(), title))
}
Some(Ok(_)) => Err(ChannelError::Other("Unexpected response type".to_string())),
Some(Err(e)) => Err(e),
None => Err(ChannelError::Other("Control channel closed".to_string())),
}
@ -479,7 +580,11 @@ impl Channel for CliChatChannel {
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let clients = self.clients.lock().await.clone();
for client in clients {
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") {
if client.chat_id != msg.chat_id {
continue;
}
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification")
{
WsOutbound::SystemNotification {
content: msg.content.clone(),
}

View File

@ -1,19 +1,19 @@
pub mod http;
pub mod ws;
use axum::{Router, routing};
use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
use crate::channels::{ChannelManager, CliChatChannel};
use crate::channels::base::{Channel, ChannelError};
use crate::config::{Config, expand_path, ensure_workspace_dir};
use crate::channels::{ChannelManager, CliChatChannel};
use crate::config::{Config, ensure_workspace_dir, expand_path};
use crate::logging;
use crate::mcp;
use crate::memory::MemoryManager;
use crate::session::SessionManager;
use crate::scheduler::Scheduler;
use crate::session::SessionManager;
pub struct GatewayState {
pub config: Config,
@ -32,8 +32,13 @@ impl GatewayState {
let workspace_path = ensure_workspace_dir(&workspace_path)?;
// Switch current working directory to workspace
std::env::set_current_dir(&workspace_path)
.map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?;
std::env::set_current_dir(&workspace_path).map_err(|e| {
format!(
"Failed to switch to workspace directory {}: {}",
workspace_path.display(),
e
)
})?;
tracing::info!("Using workspace directory: {}", workspace_path.display());
@ -52,8 +57,9 @@ impl GatewayState {
workspace_path.join("picobot.db")
};
let storage = Arc::new(
crate::storage::Storage::new(&db_path).await
.map_err(|e| format!("failed to initialize session storage: {}", e))?
crate::storage::Storage::new(&db_path)
.await
.map_err(|e| format!("failed to initialize session storage: {}", e))?,
);
tracing::info!("Session storage: {}", db_path.display());
@ -98,7 +104,9 @@ impl GatewayState {
// Create ChannelManager and init channels
let cli_chat_channel = Arc::new(CliChatChannel::new());
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
channel_manager.init(&config, workspace_path.clone()).await
channel_manager
.init(&config, workspace_path.clone())
.await
.map_err(|e| format!("Failed to init channels: {}", e))?;
// Register send_message tool with available channel names
@ -107,9 +115,12 @@ impl GatewayState {
session_manager.register_outbound_tool(available_channels);
// Register chat_manager tool
session_manager.tools().register(
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()),
);
session_manager
.tools()
.register(crate::tools::ChatManagerTool::new(
storage.clone(),
valid_channels.clone(),
));
// Initialize MCP servers — connect and register discovered tools
if !config.mcp.servers.is_empty() {
@ -130,24 +141,27 @@ impl GatewayState {
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
if scheduler_config.enabled {
// Register cron tools
session_manager.tools().register(
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels),
);
session_manager.tools().register(
crate::tools::cron::CronListTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronRemoveTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronEnableTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronDisableTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronUpdateTool::new(storage.clone()),
);
session_manager
.tools()
.register(crate::tools::cron::CronAddTool::new(
storage.clone(),
valid_channels,
));
session_manager
.tools()
.register(crate::tools::cron::CronListTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronEnableTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronDisableTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
tracing::info!("Cron tools registered");
}
@ -268,71 +282,103 @@ impl GatewayState {
}
/// Handle control messages (session management operations)
async fn handle_control_message(
session_manager: &SessionManager,
msg: ControlMessage,
) {
async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
use crate::session::{SessionCommand::*, SessionEvent};
let reply_tx = msg.reply_tx;
let result: Result<SessionEvent, ChannelError> = match msg.op {
CreateDialog { channel, chat_id, title } => {
session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string()))
}
ListDialogs { channel, chat_id, include_archived } => {
session_manager.list_dialogs(&channel, &chat_id, include_archived).await
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
GetCurrentDialog { channel, chat_id } => {
session_manager.get_current_dialog(&channel, &chat_id).await
.map(|session_id| SessionEvent::CurrentDialog { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
SwitchDialog { channel, chat_id, dialog_id } => {
session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await
.map(|session_id| SessionEvent::DialogSwitched { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
RenameDialog { session_id, title } => {
session_manager.rename_dialog(&session_id, &title).await
.map(|()| SessionEvent::DialogRenamed { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string()))
}
ArchiveDialog { session_id } => {
session_manager.archive_dialog(&session_id)
.map(|()| SessionEvent::DialogArchived { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
DeleteDialog { session_id } => {
session_manager.delete_dialog(&session_id).await
.map(|()| SessionEvent::DialogDeleted { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
ClearHistory { session_id } => {
session_manager.clear_dialog_history(&session_id)
.map(|()| SessionEvent::HistoryCleared { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
GetSlashCommands { channel: _, chat_id: _ } => {
CreateDialog {
channel,
chat_id,
title,
} => session_manager
.create_dialog(&channel, &chat_id, title.as_deref())
.await
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string())),
ListDialogs {
channel,
chat_id,
include_archived,
} => session_manager
.list_dialogs(&channel, &chat_id, include_archived)
.await
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
dialogs,
current_dialog_id,
})
.map_err(|e| ChannelError::Other(e.to_string())),
GetCurrentDialog { channel, chat_id } => session_manager
.get_current_dialog(&channel, &chat_id)
.await
.map(|session_id| SessionEvent::CurrentDialog { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
SwitchDialog {
channel,
chat_id,
dialog_id,
} => session_manager
.switch_dialog(&channel, &chat_id, &dialog_id)
.await
.map(|session_id| SessionEvent::DialogSwitched { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
RenameDialog { session_id, title } => session_manager
.rename_dialog(&session_id, &title)
.await
.map(|()| SessionEvent::DialogRenamed { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string())),
ArchiveDialog { session_id } => session_manager
.archive_dialog(&session_id)
.await
.map(|()| SessionEvent::DialogArchived { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
DeleteDialog { session_id } => session_manager
.delete_dialog(&session_id)
.await
.map(|()| SessionEvent::DialogDeleted { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
ClearHistory { session_id } => session_manager
.clear_dialog_history(&session_id)
.await
.map(|()| SessionEvent::HistoryCleared { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
GetSlashCommands {
channel: _,
chat_id: _,
} => {
let commands = session_manager.get_slash_commands().to_vec();
Ok(SessionEvent::SlashCommandsList { commands })
}
ExecuteSlashCommand { command, args, channel, chat_id, current_session_id } => {
session_manager.execute_slash_command(&command, args.as_deref(), &channel, &chat_id, current_session_id.as_ref())
.await
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { new_session_id: new_id, message: msg })
.map_err(|e| ChannelError::Other(e.to_string()))
}
ExecuteSlashCommand {
command,
args,
channel,
chat_id,
current_session_id,
} => session_manager
.execute_slash_command(
&command,
args.as_deref(),
&channel,
&chat_id,
current_session_id.as_ref(),
)
.await
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
new_session_id: new_id,
message: msg,
})
.map_err(|e| ChannelError::Other(e.to_string())),
};
let _ = reply_tx.send(result).await;
}
}
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
pub async fn run(
host: Option<String>,
port: Option<u16>,
) -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging
logging::init_logging();
tracing::info!("Starting PicoBot Gateway");

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,17 @@
pub mod background_task;
pub mod error;
pub mod memory;
pub mod message;
pub mod background_task;
pub mod scheduler;
pub mod session;
pub use error::StorageError;
pub use background_task::BackgroundTask;
pub use error::StorageError;
pub use scheduler::{JobRun, ScheduledJob};
use sqlx::{Pool, Row, Sqlite, SqlitePool};
use tokio::time::{sleep, Duration};
use std::path::Path;
use tokio::time::{Duration, sleep};
pub struct Storage {
pub(crate) pool: Pool<Sqlite>,
@ -42,6 +42,7 @@ impl Storage {
last_active_at INTEGER NOT NULL,
message_count INTEGER DEFAULT 0,
routing_info TEXT,
archived_at INTEGER,
deleted_at INTEGER,
last_consolidated_at INTEGER,
last_compressed_message_at INTEGER,
@ -92,20 +93,16 @@ impl Storage {
.await?;
// Migration: add source column if upgrading from older schema
sqlx::query(
r#"ALTER TABLE messages ADD COLUMN source TEXT"#,
)
.execute(&self.pool)
.await
.ok();
sqlx::query(r#"ALTER TABLE messages ADD COLUMN source TEXT"#)
.execute(&self.pool)
.await
.ok();
// Migration: add reasoning_content column if upgrading from older schema
sqlx::query(
r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#,
)
.execute(&self.pool)
.await
.ok();
sqlx::query(r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#)
.execute(&self.pool)
.await
.ok();
// Background tasks table — for async sub-agent tasks.
// Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL).
@ -216,11 +213,19 @@ impl Storage {
.await?;
// Rebuild FTS5 index for any existing records
sqlx::query("INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')")
.execute(&self.pool)
.await?;
// Migration: add last_consolidated_at column if not exists
sqlx::query(
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')",
r#"
ALTER TABLE sessions ADD COLUMN archived_at INTEGER
"#,
)
.execute(&self.pool)
.await?;
.await
.ok();
// Migration: add last_consolidated_at column if not exists
sqlx::query(
@ -260,7 +265,10 @@ impl Storage {
.await?;
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e);
tracing::warn!(
"Failed to init scheduler schema (tables may already exist): {}",
e
);
}
Ok(())
@ -374,16 +382,20 @@ impl Storage {
&self.pool
}
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> {
pub async fn upsert_session(
&self,
meta: &crate::storage::session::SessionMeta,
) -> Result<(), StorageError> {
sqlx::query(
r#"
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
title = excluded.title,
last_active_at = excluded.last_active_at,
message_count = excluded.message_count,
routing_info = excluded.routing_info,
archived_at = excluded.archived_at,
deleted_at = excluded.deleted_at,
last_consolidated_at = excluded.last_consolidated_at,
last_compressed_message_at = excluded.last_compressed_message_at
@ -398,6 +410,7 @@ impl Storage {
.bind(meta.last_active_at)
.bind(meta.message_count)
.bind(&meta.routing_info)
.bind(meta.archived_at)
.bind(meta.deleted_at)
.bind(meta.last_consolidated_at)
.bind(meta.last_compressed_message_at)
@ -407,10 +420,13 @@ impl Storage {
Ok(())
}
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> {
pub async fn get_session(
&self,
id: &str,
) -> Result<crate::storage::session::SessionMeta, StorageError> {
let row = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions WHERE id = ? AND deleted_at IS NULL
"#,
)
@ -429,6 +445,7 @@ impl Storage {
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"),
@ -440,18 +457,21 @@ impl Storage {
channel: &str,
chat_id: &str,
limit: i64,
include_archived: bool,
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
AND (? OR archived_at IS NULL)
ORDER BY last_active_at DESC
LIMIT ?
"#,
)
.bind(channel)
.bind(chat_id)
.bind(include_archived)
.bind(limit)
.fetch_all(self.pool())
.await?;
@ -468,6 +488,7 @@ impl Storage {
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"),
@ -498,13 +519,22 @@ impl Storage {
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
sqlx::query(r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn archive_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(r#"UPDATE sessions SET archived_at = ? WHERE id = ? AND deleted_at IS NULL"#)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
@ -516,9 +546,9 @@ impl Storage {
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
let row = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND archived_at IS NULL
ORDER BY last_active_at DESC
LIMIT 1
"#,
@ -539,6 +569,7 @@ impl Storage {
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"),
@ -547,7 +578,11 @@ impl Storage {
}
}
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> {
pub async fn append_message(
&self,
session_id: &str,
msg: &crate::storage::message::MessageMeta,
) -> Result<i64, StorageError> {
sqlx::query(
r#"
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
@ -674,16 +709,15 @@ impl Storage {
offset: i64,
limit: i64,
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
let count_row = sqlx::query(
"SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL",
)
.fetch_one(self.pool())
.await?;
let count_row =
sqlx::query("SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL")
.fetch_one(self.pool())
.await?;
let total: i64 = count_row.get("total");
let rows = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions
WHERE deleted_at IS NULL
ORDER BY last_active_at DESC
@ -707,6 +741,7 @@ impl Storage {
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"),
@ -772,7 +807,10 @@ impl Storage {
where_extra.push_str(" AND created_at > ?");
}
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra);
let count_sql = format!(
"SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}",
where_extra
);
let select_sql = format!(
r#"
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
@ -1030,6 +1068,7 @@ mod tests {
last_active_at: 1000,
message_count: 0,
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -1066,14 +1105,18 @@ mod tests {
last_active_at: i as i64 * 1000,
message_count: i,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
};
storage.upsert_session(&meta).await.unwrap();
}
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
let sessions = storage
.list_sessions("cli_chat", "sid123", 10, false)
.await
.unwrap();
assert_eq!(sessions.len(), 5);
// 按 last_active_at DESC 排序
assert_eq!(sessions[0].dialog_id, "dialog4");
@ -1093,6 +1136,7 @@ mod tests {
last_active_at: 1000,
message_count: 0,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -1120,6 +1164,7 @@ mod tests {
last_active_at: 1000,
message_count: 0,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -1141,7 +1186,10 @@ mod tests {
created_at: 1000,
};
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
let seq = storage
.append_message(&session_meta.id, &msg)
.await
.unwrap();
assert_eq!(seq, 1);
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
@ -1163,6 +1211,7 @@ mod tests {
last_active_at: 1000,
message_count: 0,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,

View File

@ -11,6 +11,7 @@ pub struct SessionMeta {
pub last_active_at: i64,
pub message_count: i64,
pub routing_info: Option<String>,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>,
pub last_consolidated_at: Option<i64>,
pub last_compressed_message_at: Option<i64>,

View File

@ -126,7 +126,10 @@ impl ChatManagerTool {
let start_num = offset + 1;
let end_num = offset + sessions.len() as i64;
let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num);
let mut output = format!(
"全部会话 (共 {} 个,第 {}-{} 个):\n",
total, start_num, end_num
);
for s in &sessions {
let ago = format_duration_ago(now_ms - s.last_active_at);
@ -300,9 +303,10 @@ mod tests {
last_active_at: now - i * 3600_000,
message_count: i * 5,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
};
storage.upsert_session(&meta).await.unwrap();
}
@ -335,6 +339,7 @@ mod tests {
last_active_at: now,
message_count: 3,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -346,7 +351,11 @@ mod tests {
id: format!("msg{}", i),
session_id: session_id.to_string(),
seq: i as i64 + 1,
role: if i == 0 { "user".to_string() } else { "assistant".to_string() },
role: if i == 0 {
"user".to_string()
} else {
"assistant".to_string()
},
content: format!("消息内容 {}", i),
reasoning_content: None,
media_refs: None,
@ -392,6 +401,7 @@ mod tests {
last_active_at: now,
message_count: 5,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -403,7 +413,11 @@ mod tests {
id: format!("msg{}", i),
session_id: session_id.to_string(),
seq: i as i64 + 1,
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
role: if i % 2 == 0 {
"user".to_string()
} else {
"assistant".to_string()
},
content: format!("消息内容 {}", i),
reasoning_content: None,
media_refs: None,
@ -447,6 +461,7 @@ mod tests {
last_active_at: now,
message_count: 5,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -492,10 +507,7 @@ mod tests {
let (storage, _dir) = create_test_storage().await;
let tool = ChatManagerTool::new(storage, vec![]);
let result = tool
.execute(json!({ "action": "unknown" }))
.await
.unwrap();
let result = tool.execute(json!({ "action": "unknown" })).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown action"));
}

View File

@ -1,5 +1,5 @@
use picobot::providers::{ChatCompletionRequest, Message};
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
use picobot::providers::{ChatCompletionRequest, Message};
/// Test that message with special characters is properly escaped
#[test]
@ -19,7 +19,9 @@ fn test_message_special_characters() {
#[test]
fn test_multiline_system_prompt() {
let messages = vec![
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
Message::system(
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
),
Message::user("Hi"),
];
@ -33,10 +35,7 @@ fn test_multiline_system_prompt() {
#[test]
fn test_chat_request_serialization() {
let request = ChatCompletionRequest {
messages: vec![
Message::system("You are helpful"),
Message::user("Hello"),
],
messages: vec![Message::system("You are helpful"), Message::user("Hello")],
temperature: Some(0.7),
max_tokens: Some(100),
tools: None,