Fix CLI session routing and dialog controls
This commit is contained in:
parent
0d66536e90
commit
c6f4392e63
359
docs/CODE_QUALITY_ANALYSIS.md
Normal file
359
docs/CODE_QUALITY_ANALYSIS.md
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
# PicoBot 代码质量分析报告
|
||||||
|
|
||||||
|
审查日期:2026-06-15
|
||||||
|
|
||||||
|
## 结论摘要
|
||||||
|
|
||||||
|
PicoBot 的总体架构方向是清晰的:Gateway 负责装配,Channel 只做收发,MessageBus 解耦输入输出,SessionManager 管理会话,AgentLoop 保持无状态并执行工具,Storage 统一持久化。这条主线是成立的,也已经具备较完整的 AI 助手运行时能力。
|
||||||
|
|
||||||
|
当前主要质量风险集中在三类:
|
||||||
|
|
||||||
|
1. 会话/CLI 路由语义不一致,导致多客户端隔离、加载会话、当前会话追踪不可靠。
|
||||||
|
2. 若干公开控制接口是空实现或弱实现,协议层暴露的能力和后端实际行为不匹配。
|
||||||
|
3. 工具和后台任务的资源边界偏弱,文件、shell、HTTP、长期任务在异常情况下容易突破预期的安全或稳定性边界。
|
||||||
|
|
||||||
|
如果只安排一轮修复,优先处理会话路由和控制接口。这些问题会直接影响用户看到的行为;工具安全和大模块拆分可以作为第二阶段。
|
||||||
|
|
||||||
|
## 修复状态
|
||||||
|
|
||||||
|
- 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。
|
||||||
|
- 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
|
||||||
|
- 待处理:工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
|
||||||
|
|
||||||
|
## 主要发现
|
||||||
|
|
||||||
|
### 已修复:CLI 会话路由会破坏会话连续性和多客户端隔离
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/channels/cli_chat.rs:113-126`
|
||||||
|
- `src/channels/cli_chat.rs:160-164`
|
||||||
|
- `src/channels/cli_chat.rs:225-249`
|
||||||
|
- `src/channels/cli_chat.rs:479-494`
|
||||||
|
- `src/session/session.rs:1305-1310`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
`Client.current_session_id` 存的是完整 session id,但 CLI channel 在多个地方把它当作 `chat_id` 使用。普通用户输入如果没有显式传 `chat_id`,会在 `src/channels/cli_chat.rs:119` 生成新的短 ID,而不是复用当前 client 的 chat scope。`CreateSession` 又把当前完整 session id 当成新会话的 chat_id。`LoadSession` 解析了传入 session id,但随后调用 `GetCurrentDialog`,而后端 `get_current_dialog()` 固定返回 `None`。
|
||||||
|
|
||||||
|
同时,`send()` 会把所有 `OutboundMessage` 广播给所有 CLI WebSocket client,没有按 `msg.chat_id` 或 client 当前会话过滤。这意味着一个客户端的回复可能出现在另一个客户端里。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- CLI 多轮对话可能落入不同 chat scope。
|
||||||
|
- 创建/列出/加载会话得到的结果可能不符合 UI 预期。
|
||||||
|
- 多个 CLI 客户端同时连接时存在串话。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 将 client 状态拆成 `chat_id` 和 `current_session_id`,不要混用。
|
||||||
|
- 注册 client 时生成稳定 `chat_id`,后续 `UserInput` 默认复用它。
|
||||||
|
- `send()` 按 `OutboundMessage.chat_id` 精确投递;必要时维护 `chat_id -> clients` 映射。
|
||||||
|
- `LoadSession` 应直接切换到指定 session,或通过 `SwitchDialog` 使用其中的 `dialog_id`。
|
||||||
|
- 为 CLI WebSocket 增加多客户端路由测试。
|
||||||
|
|
||||||
|
### 已修复:Dialog 控制接口与协议承诺不一致
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/session/session.rs:996-997`
|
||||||
|
- `src/session/session.rs:1305-1310`
|
||||||
|
- `src/session/session.rs:1329-1349`
|
||||||
|
- `src/session/session.rs:1378-1384`
|
||||||
|
- `src/channels/cli_chat.rs:128-158`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
后端暴露了 create/list/load/rename/archive/delete/clear 等 dialog 操作,但部分行为是空实现或语义错位:
|
||||||
|
|
||||||
|
- `/delete` 只创建新 session,并没有删除当前 session。
|
||||||
|
- `get_current_dialog()` 固定返回 `Ok(None)`。
|
||||||
|
- `list_dialogs()` 忽略 `include_archived`,且总是返回 `current_dialog_id = None`。
|
||||||
|
- `archive_dialog()` 是空操作。
|
||||||
|
- `clear_dialog_history()` 直接返回不可用,但 WebSocket 协议仍暴露 `clear_history`。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
用户通过 slash command 和 WebSocket 调用同一类能力时,会得到不一致结果。前端难以基于协议实现可靠状态同步。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 明确“archive/clear 是否支持”。不支持就从协议和命令列表移除;支持就实现到底。
|
||||||
|
- `/delete` 应调用 `delete_dialog(current_session_id)`,再创建一个新的 current session。
|
||||||
|
- `get_current_dialog()` 应读取 `current_sessions[channel:chat_id]` 并解析为 `UnifiedSessionId`。
|
||||||
|
- `list_dialogs()` 返回真实 current dialog,并补上 archived 模型或移除 archived 参数。
|
||||||
|
|
||||||
|
### 高优先级:工具文件边界不符合“工作目录内工具”的架构约束
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/tools/mod.rs:56-62`
|
||||||
|
- `src/tools/path_utils.rs:3-23`
|
||||||
|
- `src/tools/bash.rs:146-185`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
文件工具默认通过 `FileReadTool::new()`、`FileWriteTool::new()` 等注册,没有传入 workspace allowlist。`resolve_path()` 对绝对路径直接放行;即使传入 allowlist,也只是做 `Path::starts_with()` 的词法判断,没有 canonicalize,不能防御 `..`、符号链接等路径逃逸。
|
||||||
|
|
||||||
|
`bash` 默认工作目录是 `"."`,Gateway 启动时切到 workspace,这对相对路径有效,但 shell 命令仍然可以访问绝对路径。当前 denylist 只挡少数危险模式,不构成权限边界。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“工作目录内操作”不一致。对于个人助手这可能是有意设计,但如果未来接入外部渠道、多用户或 MCP,风险会放大。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 工具注册时传入 `workspace_dir`,默认所有文件工具限制在 workspace。
|
||||||
|
- `resolve_path()` 使用 `std::fs::canonicalize` 或 `path_absolutize` 风格逻辑,并处理目标文件不存在时的父目录 canonicalize。
|
||||||
|
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
|
||||||
|
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
|
||||||
|
|
||||||
|
### 中高优先级:Session 锁内执行过多异步操作
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/session/session.rs:1001-1018`
|
||||||
|
- `src/session/session.rs:1604-1711`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
`/compact` 在持有 session mutex 时执行压缩和持久化。agent worker 的 Phase 1 也在持有 session mutex 时执行用户消息落库、memory recall、上下文压缩、session meta 持久化和 agent 创建。其中 `compress_if_needed()` 可能触发 LLM 摘要,属于慢操作。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 同一 session 的 slash command、stop、消息排队、状态查询会被慢操作阻塞。
|
||||||
|
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
|
||||||
|
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 锁内只做内存状态快照和必要的状态标记。
|
||||||
|
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
|
||||||
|
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
|
||||||
|
|
||||||
|
### 中优先级:Bash 超时不会显式终止子进程
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/tools/bash.rs:150-174`
|
||||||
|
- `src/tools/bash.rs:180-207`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
`timeout()` 包裹的是 `run_command()` future。超时后 future 被取消,但代码没有持有 child 句柄并显式 `kill()` / `wait()`。对于已经启动的长运行命令或子进程树,可能留下后台进程。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 使用 `tokio::process::Child` 的 `kill_on_drop(true)`。
|
||||||
|
- 超时分支显式 kill child 并 wait。
|
||||||
|
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
|
||||||
|
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
|
||||||
|
|
||||||
|
### 中优先级:文件读取对大二进制文件没有输出上限
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/tools/file_read.rs:121-131`
|
||||||
|
- `src/tools/file_read.rs:214-229`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
`file_read` 先 `std::fs::read()` 读取整个文件。文本路径有 `MAX_CHARS` 截断,但二进制路径会完整 base64 编码后返回,没有大小限制。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 先检查 metadata size,超过阈值直接返回提示。
|
||||||
|
- 二进制文件默认只返回 mime、大小和建议操作;需要内容时提供显式 `max_bytes` 参数。
|
||||||
|
- 对文本读取也改成流式按行读取,而不是整文件读入。
|
||||||
|
|
||||||
|
### 中优先级:HTTP 私网防护只检查字面 host,未做 DNS 解析校验
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/tools/http_request.rs:31-59`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
`http_request` 阻止 localhost、私网 IP 字面量和 `.local`,但普通域名不会解析后检查最终 IP。DNS rebinding 或内网域名解析到私网地址时,当前校验拦不住。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 请求前解析域名,拒绝私网、loopback、link-local、multicast、unspecified 地址。
|
||||||
|
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
|
||||||
|
- 对 `http_request` 和 `web_fetch` 复用同一套 URL 安全策略。
|
||||||
|
|
||||||
|
### 中优先级:后台任务和主循环缺少监督与优雅关闭
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/bus/mod.rs:51-99`
|
||||||
|
- `src/gateway/mod.rs:187-244`
|
||||||
|
- `src/gateway/mod.rs:247-266`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHandle,也没有统一 cancellation token。MessageBus 的 `consume_*()` 在 channel 关闭时使用 `expect()` panic。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 某个后台 loop 异常退出后,Gateway 不一定能发现。
|
||||||
|
- 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
|
||||||
|
- bus channel 关闭时更像崩溃,而不是可恢复状态。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 引入 runtime supervisor,保存 JoinHandle 并集中处理退出原因。
|
||||||
|
- 用 `CancellationToken` 贯穿 Gateway 子任务。
|
||||||
|
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
|
||||||
|
|
||||||
|
### 中低优先级:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/scheduler/mod.rs:18-40`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
`next_run_for_schedule(schedule, from)` 的注释说基于 `from` 计算,但 cron 分支创建了 `from_dt` 后没有传给 `cron_schedule`,实际使用的是 `upcoming(Utc)` 或 `upcoming(tz)` 的当前时间。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
|
||||||
|
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
|
||||||
|
- 增加固定时间输入的单元测试,避免受系统时间影响。
|
||||||
|
|
||||||
|
### 中低优先级:存在未接入或半接入代码,增加维护噪音
|
||||||
|
|
||||||
|
位置:
|
||||||
|
|
||||||
|
- `src/tools/pty.rs`
|
||||||
|
- `src/tools/mod.rs:1-20`
|
||||||
|
- `src/tools/mod.rs:49-88`
|
||||||
|
|
||||||
|
问题:
|
||||||
|
|
||||||
|
仓库里有完整 `pty.rs`,但 `tools/mod.rs` 没有声明 `pub mod pty`,`create_default_tools()` 也没有注册 PTY 工具。类似情况会让文档、计划和实现状态难以判断。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
|
||||||
|
- 若暂不发布:移动到设计文档或 feature branch,避免主干保留死代码。
|
||||||
|
|
||||||
|
## 架构评价
|
||||||
|
|
||||||
|
### 做得好的地方
|
||||||
|
|
||||||
|
- 模块分层方向清楚:Channel、Bus、Session、Agent、Provider、Tool、Storage 边界基本可理解。
|
||||||
|
- AgentLoop 设计为无状态,历史由 SessionManager 管理,这一点利于恢复、压缩和测试。
|
||||||
|
- Provider 抽象简单直接,OpenAI-compatible 与 Anthropic 的差异被限制在 provider 层。
|
||||||
|
- Storage 集中初始化 schema,便于部署单二进制应用。
|
||||||
|
- Skill、memory、MCP、delegate 这几条扩展线已经形成统一的 ToolRegistry 接入点。
|
||||||
|
|
||||||
|
### 主要架构债务
|
||||||
|
|
||||||
|
- SessionManager 承担过多职责:会话生命周期、命令解析、memory recall、压缩、agent worker、任务取消、send_message 目标解析都在一个 2000 行文件内。
|
||||||
|
- Channel 和 Session 对 chat_id/session_id/dialog_id 的边界没有类型保护,导致 CLI 层混用字符串。
|
||||||
|
- Tool 权限模型不够显式:工具是否能访问全文件系统、是否能联网、是否能修改状态主要靠工具自身约定。
|
||||||
|
- 后台任务生命周期分散:gateway loop、agent worker、notification publisher、scheduler、sub-agent task 各自 spawn,缺少统一管理。
|
||||||
|
|
||||||
|
## 模块级分析
|
||||||
|
|
||||||
|
### gateway
|
||||||
|
|
||||||
|
`GatewayState::new()` 是清晰的装配中心:配置、workspace、storage、memory、bus、session manager、channels、MCP、scheduler 都在这里接线。问题是启动后任务监督不足,且 scheduler 默认 `unwrap_or_default()` 会在省略 `gateway.scheduler` 时启用调度器,这和“省略配置是否代表开启”需要产品层确认。
|
||||||
|
|
||||||
|
### channels
|
||||||
|
|
||||||
|
Feishu channel 功能较厚,单文件接近 2000 行,建议后续按 API client、message parsing、media handling、outbound rendering 拆分。CLI channel 目前是质量风险最高的 channel,核心问题是会话身份混用和广播投递。
|
||||||
|
|
||||||
|
### bus
|
||||||
|
|
||||||
|
MessageBus 简洁,但当前消费者 API 通过 mutex 包住 receiver 并 `expect()`,更像“单消费者内部队列”。这没问题,但应该把“只能有一个 consumer”写进类型/文档,并把关闭作为正常状态处理。
|
||||||
|
|
||||||
|
### session
|
||||||
|
|
||||||
|
这是系统核心,也是债务最集中的模块。建议把 `session.rs` 拆成:
|
||||||
|
|
||||||
|
- `manager.rs`:SessionManager 状态和 dialog 生命周期
|
||||||
|
- `worker.rs`:per-session agent worker 和 cancellation
|
||||||
|
- `commands.rs`:slash command 执行
|
||||||
|
- `outbound.rs`:OutboundMessenger 实现
|
||||||
|
- `restore.rs`:storage 恢复与 tool call chain repair
|
||||||
|
|
||||||
|
拆分之前,先补行为测试,尤其是 CLI/WS session lifecycle。
|
||||||
|
|
||||||
|
### agent
|
||||||
|
|
||||||
|
AgentLoop 的职责相对聚焦:请求模型、执行工具、回填 tool result、循环直到 final response。需要关注的是工具并发的语义:`read_only()` 目前是工具自己声明,副作用工具不能错标。LoopDetector 有帮助,但属于 runtime guard,不应替代工具层的资源限制。
|
||||||
|
|
||||||
|
### providers
|
||||||
|
|
||||||
|
Provider 层整体可维护。OpenAI/Anthropic 的请求构造逻辑可以继续保留在 provider 内。建议补充请求脱敏策略:当前 debug log 和 `llm_calls` 会持久化完整 request/response,可能包含用户隐私、API 返回内容和文件内容。
|
||||||
|
|
||||||
|
### tools
|
||||||
|
|
||||||
|
工具体系覆盖面很强,但需要明确权限模型。建议新增统一的 `ToolExecutionContext`,包含 workspace、channel、session_id、权限策略、网络策略、输出预算。现在很多策略散落在各工具构造函数里,默认值容易失控。
|
||||||
|
|
||||||
|
### storage
|
||||||
|
|
||||||
|
Storage schema 初始化实用,但迁移方式是“CREATE IF NOT EXISTS + ALTER IGNORE”,适合早期迭代,不适合长期演进。建议引入 schema version 表或 sqlx migrations,至少把每次迁移记录下来。
|
||||||
|
|
||||||
|
### skills
|
||||||
|
|
||||||
|
Skill 加载优先级清晰,内置 skill 打包也实用。需要注意 `SkillsLoader` 使用同步文件系统扫描和 `std::sync::Mutex`,在请求路径频繁 `reload_if_changed()` 时可能造成阻塞。短期可以接受,长期建议缓存刷新放到后台 watcher。
|
||||||
|
|
||||||
|
## 建议修复路线
|
||||||
|
|
||||||
|
### P0:先修会话正确性
|
||||||
|
|
||||||
|
1. 修正 CLI `chat_id/current_session_id` 数据模型。
|
||||||
|
2. 修正 CLI 出站按 client/chat_id 投递。
|
||||||
|
3. 实现 `get_current_dialog()`、`list_dialogs()` current 返回。
|
||||||
|
4. 修正 `/delete`、`clear_history`、`archive` 的真实行为或从协议移除。
|
||||||
|
5. 增加 WebSocket session lifecycle 测试。
|
||||||
|
|
||||||
|
### P1:收紧工具和资源边界
|
||||||
|
|
||||||
|
1. 文件工具默认限制 workspace,路径 canonicalize。
|
||||||
|
2. bash 超时杀进程,必要时引入进程组。
|
||||||
|
3. file_read 增加文件大小上限和二进制输出上限。
|
||||||
|
4. HTTP/web 工具增加 DNS 解析后的私网校验和重定向校验。
|
||||||
|
5. 明确高危工具的配置开关。
|
||||||
|
|
||||||
|
### P2:降低架构复杂度
|
||||||
|
|
||||||
|
1. 拆分 `session.rs`、`feishu.rs`、`storage/mod.rs`、`browser.rs`。
|
||||||
|
2. 引入任务 supervisor 和统一 shutdown token。
|
||||||
|
3. 引入正式数据库迁移。
|
||||||
|
4. 增加工具注册快照测试,避免死代码和文档漂移。
|
||||||
|
|
||||||
|
## 建议测试补充
|
||||||
|
|
||||||
|
- CLI 多客户端并发:两个 WebSocket client 同时发消息,互不串话。
|
||||||
|
- CLI 不传 chat_id 的连续对话:所有消息应进入同一 session。
|
||||||
|
- Load/switch/list/delete/clear 的完整 WebSocket 流程。
|
||||||
|
- `/delete` 后旧 session 软删除、新 session 成为 current。
|
||||||
|
- 文件路径逃逸:`../`、绝对路径、符号链接、workspace 前缀欺骗。
|
||||||
|
- bash timeout 后检查子进程不存在。
|
||||||
|
- cron `next_run_for_schedule()` 使用固定 `from` 的 deterministic 测试。
|
||||||
|
- HTTP 工具对 DNS 解析到 `127.0.0.1` / `10.0.0.0/8` 的域名拒绝测试。
|
||||||
@ -1,10 +1,10 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::{Mutex, mpsc};
|
||||||
|
|
||||||
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
|
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
|
||||||
|
use crate::protocol::{SlashCommandInfo, WsInbound, WsOutbound, parse_inbound};
|
||||||
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
|
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
|
||||||
use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo};
|
|
||||||
|
|
||||||
use super::base::{Channel, ChannelError};
|
use super::base::{Channel, ChannelError};
|
||||||
|
|
||||||
@ -14,6 +14,7 @@ use super::base::{Channel, ChannelError};
|
|||||||
|
|
||||||
pub(crate) struct Client {
|
pub(crate) struct Client {
|
||||||
sender: mpsc::Sender<WsOutbound>,
|
sender: mpsc::Sender<WsOutbound>,
|
||||||
|
chat_id: String,
|
||||||
current_session_id: Mutex<Option<String>>,
|
current_session_id: Mutex<Option<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,23 +42,28 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Register a new client connection, returns (session_id, client)
|
/// Register a new client connection, returns (session_id, client)
|
||||||
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) {
|
pub(crate) async fn register_client(
|
||||||
// Generate connection ID (used as chat_id) - use short ID
|
&self,
|
||||||
let connection_id = crate::util::short_id();
|
sender: mpsc::Sender<WsOutbound>,
|
||||||
|
) -> (String, Arc<Client>) {
|
||||||
|
// Each WebSocket connection gets a stable chat scope. All user input and
|
||||||
|
// dialog controls for this client stay inside that scope unless the
|
||||||
|
// protocol explicitly carries a full session id.
|
||||||
|
let chat_id = crate::util::short_id();
|
||||||
|
|
||||||
let client = Arc::new(Client {
|
let client = Arc::new(Client {
|
||||||
sender,
|
sender,
|
||||||
|
chat_id: chat_id.clone(),
|
||||||
current_session_id: Mutex::new(None),
|
current_session_id: Mutex::new(None),
|
||||||
});
|
});
|
||||||
self.clients.lock().await.push(client.clone());
|
self.clients.lock().await.push(client.clone());
|
||||||
|
|
||||||
// Create initial session via control message
|
// Create initial session via control message
|
||||||
let session_id = match self.create_session_via_control(&connection_id, None).await {
|
let session_id = match self.create_session_via_control(&chat_id, None).await {
|
||||||
Ok(id) => id,
|
Ok((id, _title)) => id,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(error = %e, "Failed to create initial session");
|
tracing::error!(error = %e, "Failed to create initial session");
|
||||||
// Fall back to old format for backward compatibility
|
UnifiedSessionId::new("cli_chat", &chat_id, &crate::util::short_id()).to_string()
|
||||||
connection_id.clone()
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -73,8 +79,7 @@ impl CliChatChannel {
|
|||||||
/// Handle an inbound message from a client
|
/// Handle an inbound message from a client
|
||||||
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
|
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
|
||||||
match parse_inbound(raw_msg) {
|
match parse_inbound(raw_msg) {
|
||||||
Ok(inbound) => {
|
Ok(inbound) => match self.handle_ws_inbound(client.clone(), inbound).await {
|
||||||
match self.handle_ws_inbound(client.clone(), inbound).await {
|
|
||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to handle inbound message");
|
tracing::warn!(error = %e, "Failed to handle inbound message");
|
||||||
@ -86,8 +91,7 @@ impl CliChatChannel {
|
|||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||||||
let _ = client
|
let _ = client
|
||||||
@ -101,22 +105,30 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_ws_inbound(&self, client: Arc<Client>, inbound: WsInbound) -> Result<(), ChannelError> {
|
async fn handle_ws_inbound(
|
||||||
|
&self,
|
||||||
|
client: Arc<Client>,
|
||||||
|
inbound: WsInbound,
|
||||||
|
) -> Result<(), ChannelError> {
|
||||||
let bus = {
|
let bus = {
|
||||||
let guard = self.bus.lock().unwrap();
|
let guard = self.bus.lock().unwrap();
|
||||||
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
guard
|
||||||
|
.clone()
|
||||||
|
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut current_session_guard = client.current_session_id.lock().await;
|
let mut current_session_guard = client.current_session_id.lock().await;
|
||||||
|
|
||||||
match inbound {
|
match inbound {
|
||||||
WsInbound::UserInput { content, chat_id, .. } => {
|
WsInbound::UserInput {
|
||||||
|
content, chat_id, ..
|
||||||
|
} => {
|
||||||
// All messages (including slash commands) go through the normal inbound flow
|
// All messages (including slash commands) go through the normal inbound flow
|
||||||
// SessionManager handles session creation/reuse internally
|
// SessionManager handles session creation/reuse internally
|
||||||
let msg = InboundMessage {
|
let msg = InboundMessage {
|
||||||
channel: self.name().to_string(),
|
channel: self.name().to_string(),
|
||||||
sender_id: "cli".to_string(),
|
sender_id: "cli".to_string(),
|
||||||
chat_id: chat_id.unwrap_or_else(crate::util::short_id),
|
chat_id: chat_id.unwrap_or_else(|| client.chat_id.clone()),
|
||||||
content,
|
content,
|
||||||
timestamp: crate::bus::message::current_timestamp(),
|
timestamp: crate::bus::message::current_timestamp(),
|
||||||
media: Vec::new(),
|
media: Vec::new(),
|
||||||
@ -125,19 +137,56 @@ impl CliChatChannel {
|
|||||||
};
|
};
|
||||||
bus.publish_inbound(msg).await?;
|
bus.publish_inbound(msg).await?;
|
||||||
}
|
}
|
||||||
WsInbound::ClearHistory { chat_id, session_id } => {
|
WsInbound::ClearHistory {
|
||||||
let target = session_id
|
chat_id,
|
||||||
.or(chat_id)
|
session_id,
|
||||||
.or(current_session_guard.clone())
|
} => {
|
||||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let session_id = UnifiedSessionId::parse(&target)
|
let session_id = if let Some(session_id) = session_id {
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
UnifiedSessionId::parse(&session_id).ok_or_else(|| {
|
||||||
|
ChannelError::Other("Invalid session ID format".to_string())
|
||||||
|
})?
|
||||||
|
} else if let Some(chat_id) = chat_id {
|
||||||
|
let (current_tx, mut current_rx) = mpsc::channel(1);
|
||||||
|
bus.publish_control(ControlMessage {
|
||||||
|
op: SessionCommand::GetCurrentDialog {
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id,
|
||||||
|
},
|
||||||
|
reply_tx: current_tx,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
match current_rx.recv().await {
|
||||||
|
Some(Ok(SessionEvent::CurrentDialog {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
})) => session_id,
|
||||||
|
Some(Ok(SessionEvent::CurrentDialog { session_id: None })) => {
|
||||||
|
return Err(ChannelError::Other("No active session".to_string()));
|
||||||
|
}
|
||||||
|
Some(Ok(_)) => {
|
||||||
|
return Err(ChannelError::Other(
|
||||||
|
"Unexpected response type".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Some(Err(e)) => return Err(e),
|
||||||
|
None => {
|
||||||
|
return Err(ChannelError::Other("Control channel closed".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let target = current_session_guard
|
||||||
|
.clone()
|
||||||
|
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||||
|
UnifiedSessionId::parse(&target).ok_or_else(|| {
|
||||||
|
ChannelError::Other("Invalid session ID format".to_string())
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
let target = session_id.to_string();
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::ClearHistory { session_id },
|
op: SessionCommand::ClearHistory { session_id },
|
||||||
reply_tx,
|
reply_tx,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::HistoryCleared { .. })) => {
|
Some(Ok(SessionEvent::HistoryCleared { .. })) => {
|
||||||
@ -158,24 +207,21 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::CreateSession { title } => {
|
WsInbound::CreateSession { title } => {
|
||||||
// Use current session's chat_id if available, otherwise generate new one
|
let (new_id, created_title) = self
|
||||||
let chat_id = current_session_guard.clone()
|
.create_session_via_control(&client.chat_id, title.as_deref())
|
||||||
.unwrap_or_else(crate::util::short_id);
|
.await?;
|
||||||
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
|
|
||||||
*current_session_guard = Some(new_id.clone());
|
*current_session_guard = Some(new_id.clone());
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionCreated {
|
.send(WsOutbound::SessionCreated {
|
||||||
session_id: new_id,
|
session_id: new_id,
|
||||||
title: title.unwrap_or_default(),
|
title: created_title,
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
WsInbound::ListSessions { include_archived } => {
|
WsInbound::ListSessions { include_archived } => {
|
||||||
// List dialogs for the current chat
|
// List dialogs for the current chat
|
||||||
let chat_id = current_session_guard.clone()
|
let chat_id = client.chat_id.clone();
|
||||||
.unwrap_or_else(|| "".to_string());
|
|
||||||
let chat_id_for_response = chat_id.clone();
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::ListDialogs {
|
op: SessionCommand::ListDialogs {
|
||||||
@ -184,13 +230,18 @@ impl CliChatChannel {
|
|||||||
include_archived,
|
include_archived,
|
||||||
},
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => {
|
Some(Ok(SessionEvent::DialogList {
|
||||||
|
dialogs,
|
||||||
|
current_dialog_id,
|
||||||
|
})) => {
|
||||||
// Convert DialogInfo to SessionSummary for backward compatibility
|
// Convert DialogInfo to SessionSummary for backward compatibility
|
||||||
let sessions: Vec<crate::protocol::SessionSummary> = dialogs.into_iter().map(|d| {
|
let sessions: Vec<crate::protocol::SessionSummary> = dialogs
|
||||||
crate::protocol::SessionSummary {
|
.into_iter()
|
||||||
|
.map(|d| crate::protocol::SessionSummary {
|
||||||
session_id: d.session_id.to_string(),
|
session_id: d.session_id.to_string(),
|
||||||
title: d.title,
|
title: d.title,
|
||||||
channel_name: d.session_id.channel.clone(),
|
channel_name: d.session_id.channel.clone(),
|
||||||
@ -198,11 +249,14 @@ impl CliChatChannel {
|
|||||||
message_count: d.message_count,
|
message_count: d.message_count,
|
||||||
last_active_at: d.last_active_at,
|
last_active_at: d.last_active_at,
|
||||||
archived_at: d.archived_at,
|
archived_at: d.archived_at,
|
||||||
}
|
})
|
||||||
}).collect();
|
.collect();
|
||||||
let current_session_id = current_dialog_id.map(|did| {
|
let current_session_id = current_dialog_id.map(|did| {
|
||||||
UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string()
|
UnifiedSessionId::new("cli_chat", &client.chat_id, &did).to_string()
|
||||||
});
|
});
|
||||||
|
if let Some(ref session_id) = current_session_id {
|
||||||
|
*current_session_guard = Some(session_id.clone());
|
||||||
|
}
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionList {
|
.send(WsOutbound::SessionList {
|
||||||
@ -223,39 +277,35 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::LoadSession { session_id } => {
|
WsInbound::LoadSession { session_id } => {
|
||||||
// LoadSession: parse the session_id and get current dialog info
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let unified_id = UnifiedSessionId::parse(&session_id)
|
let unified_id = UnifiedSessionId::parse(&session_id)
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
|
if unified_id.channel != "cli_chat" || unified_id.chat_id != client.chat_id {
|
||||||
|
return Err(ChannelError::Other(
|
||||||
|
"Session does not belong to this client".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::GetCurrentDialog {
|
op: SessionCommand::SwitchDialog {
|
||||||
channel: unified_id.channel.clone(),
|
channel: unified_id.channel.clone(),
|
||||||
chat_id: unified_id.chat_id.clone(),
|
chat_id: unified_id.chat_id.clone(),
|
||||||
|
dialog_id: unified_id.dialog_id.clone(),
|
||||||
},
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => {
|
Some(Ok(SessionEvent::DialogSwitched { session_id })) => {
|
||||||
if let Some(current_session_id) = current_session_id_opt {
|
*current_session_guard = Some(session_id.to_string());
|
||||||
*current_session_guard = Some(current_session_id.to_string());
|
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionLoaded {
|
.send(WsOutbound::SessionLoaded {
|
||||||
session_id: current_session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
title: "Session".to_string(), // TODO: get actual title
|
title: "Session".to_string(),
|
||||||
message_count: 0, // TODO: get actual count
|
message_count: 0,
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
} else {
|
|
||||||
let _ = client
|
|
||||||
.sender
|
|
||||||
.send(WsOutbound::Error {
|
|
||||||
code: "NO_CURRENT_DIALOG".to_string(),
|
|
||||||
message: "No current dialog".to_string(),
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Some(Ok(_)) => {
|
Some(Ok(_)) => {
|
||||||
// Unexpected response type
|
// Unexpected response type
|
||||||
@ -275,23 +325,30 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::RenameSession { session_id, title } => {
|
WsInbound::RenameSession { session_id, title } => {
|
||||||
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
let target = session_id
|
||||||
ChannelError::Other("No active session".to_string())
|
.or(current_session_guard.clone())
|
||||||
})?;
|
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let unified_id = UnifiedSessionId::parse(&target)
|
let unified_id = UnifiedSessionId::parse(&target)
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() },
|
op: SessionCommand::RenameDialog {
|
||||||
|
session_id: unified_id,
|
||||||
|
title: title.clone(),
|
||||||
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
|
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title })
|
.send(WsOutbound::SessionRenamed {
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
title,
|
||||||
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
Some(Ok(_)) => {
|
Some(Ok(_)) => {
|
||||||
@ -306,24 +363,43 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::ArchiveSession { session_id } => {
|
WsInbound::ArchiveSession { session_id } => {
|
||||||
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
let target = session_id
|
||||||
ChannelError::Other("No active session".to_string())
|
.or(current_session_guard.clone())
|
||||||
})?;
|
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||||
|
let was_current = current_session_guard.as_deref() == Some(&target);
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let unified_id = UnifiedSessionId::parse(&target)
|
let unified_id = UnifiedSessionId::parse(&target)
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::ArchiveDialog { session_id: unified_id },
|
op: SessionCommand::ArchiveDialog {
|
||||||
|
session_id: unified_id,
|
||||||
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogArchived { session_id })) => {
|
Some(Ok(SessionEvent::DialogArchived { session_id })) => {
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionArchived { session_id: session_id.to_string() })
|
.send(WsOutbound::SessionArchived {
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
})
|
||||||
.await;
|
.await;
|
||||||
|
if was_current {
|
||||||
|
let (new_id, title) = self
|
||||||
|
.create_session_via_control(&client.chat_id, None)
|
||||||
|
.await?;
|
||||||
|
*current_session_guard = Some(new_id.clone());
|
||||||
|
let _ = client
|
||||||
|
.sender
|
||||||
|
.send(WsOutbound::SessionCreated {
|
||||||
|
session_id: new_id,
|
||||||
|
title,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Some(Ok(_)) => {
|
Some(Ok(_)) => {
|
||||||
// Unexpected response type
|
// Unexpected response type
|
||||||
@ -337,35 +413,42 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::DeleteSession { session_id } => {
|
WsInbound::DeleteSession { session_id } => {
|
||||||
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
let target = session_id
|
||||||
ChannelError::Other("No active session".to_string())
|
.or(current_session_guard.clone())
|
||||||
})?;
|
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let unified_id = UnifiedSessionId::parse(&target)
|
let unified_id = UnifiedSessionId::parse(&target)
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::DeleteDialog { session_id: unified_id },
|
op: SessionCommand::DeleteDialog {
|
||||||
|
session_id: unified_id,
|
||||||
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
|
Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionDeleted { session_id: session_id.to_string() })
|
.send(WsOutbound::SessionDeleted {
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// If deleting current session, create a new one
|
// If deleting current session, create a new one
|
||||||
if current_session_guard.as_deref() == Some(&target) {
|
if current_session_guard.as_deref() == Some(&target) {
|
||||||
drop(reply_rx);
|
drop(reply_rx);
|
||||||
if let Ok(new_id) = self.create_session_via_control(&target, None).await {
|
if let Ok((new_id, title)) =
|
||||||
|
self.create_session_via_control(&client.chat_id, None).await
|
||||||
|
{
|
||||||
*current_session_guard = Some(new_id.clone());
|
*current_session_guard = Some(new_id.clone());
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionCreated {
|
.send(WsOutbound::SessionCreated {
|
||||||
session_id: new_id,
|
session_id: new_id,
|
||||||
title: String::new(),
|
title,
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
@ -388,32 +471,45 @@ impl CliChatChannel {
|
|||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::GetSlashCommands {
|
op: SessionCommand::GetSlashCommands {
|
||||||
channel: "cli_chat".to_string(),
|
channel: "cli_chat".to_string(),
|
||||||
chat_id: "".to_string(),
|
chat_id: client.chat_id.clone(),
|
||||||
},
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
if let Some(result) = reply_rx.recv().await {
|
if let Some(result) = reply_rx.recv().await {
|
||||||
match result {
|
match result {
|
||||||
Ok(SessionEvent::SlashCommandsList { commands }) => {
|
Ok(SessionEvent::SlashCommandsList { commands }) => {
|
||||||
// Convert to SlashCommand to SlashCommandInfo
|
// Convert to SlashCommand to SlashCommandInfo
|
||||||
let command_infos: Vec<SlashCommandInfo> = commands.into_iter().map(|cmd| {
|
let command_infos: Vec<SlashCommandInfo> = commands
|
||||||
SlashCommandInfo {
|
.into_iter()
|
||||||
|
.map(|cmd| SlashCommandInfo {
|
||||||
name: cmd.name.to_string(),
|
name: cmd.name.to_string(),
|
||||||
description: cmd.description.to_string(),
|
description: cmd.description.to_string(),
|
||||||
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
|
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
|
||||||
}
|
})
|
||||||
}).collect();
|
.collect();
|
||||||
let _ = client.sender.send(WsOutbound::SlashCommandsList { commands: command_infos }).await;
|
let _ = client
|
||||||
|
.sender
|
||||||
|
.send(WsOutbound::SlashCommandsList {
|
||||||
|
commands: command_infos,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
Ok(SessionEvent::Error { code, message }) => {
|
Ok(SessionEvent::Error { code, message }) => {
|
||||||
let _ = client.sender.send(WsOutbound::Error { code, message }).await;
|
let _ = client
|
||||||
|
.sender
|
||||||
|
.send(WsOutbound::Error { code, message })
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let _ = client.sender.send(WsOutbound::Error {
|
let _ = client
|
||||||
|
.sender
|
||||||
|
.send(WsOutbound::Error {
|
||||||
code: "GET_COMMANDS_ERROR".to_string(),
|
code: "GET_COMMANDS_ERROR".to_string(),
|
||||||
message: e.to_string()
|
message: e.to_string(),
|
||||||
}).await;
|
})
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
@ -427,29 +523,34 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a session via control message and return the session_id
|
/// Create a session via control message and return the session_id
|
||||||
async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result<String, ChannelError> {
|
async fn create_session_via_control(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
title: Option<&str>,
|
||||||
|
) -> Result<(String, String), ChannelError> {
|
||||||
let bus = {
|
let bus = {
|
||||||
let guard = self.bus.lock().unwrap();
|
let guard = self.bus.lock().unwrap();
|
||||||
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
guard
|
||||||
|
.clone()
|
||||||
|
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
||||||
};
|
};
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::CreateDialog {
|
op: SessionCommand::CreateDialog {
|
||||||
channel: "cli_chat".to_string(),
|
channel: "cli_chat".to_string(),
|
||||||
chat_id: connection_id.to_string(),
|
chat_id: chat_id.to_string(),
|
||||||
title: title.map(String::from),
|
title: title.map(String::from),
|
||||||
},
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => {
|
Some(Ok(SessionEvent::DialogCreated { session_id, title })) => {
|
||||||
Ok(session_id.to_string())
|
Ok((session_id.to_string(), title))
|
||||||
}
|
|
||||||
Some(Ok(_)) => {
|
|
||||||
Err(ChannelError::Other("Unexpected response type".to_string()))
|
|
||||||
}
|
}
|
||||||
|
Some(Ok(_)) => Err(ChannelError::Other("Unexpected response type".to_string())),
|
||||||
Some(Err(e)) => Err(e),
|
Some(Err(e)) => Err(e),
|
||||||
None => Err(ChannelError::Other("Control channel closed".to_string())),
|
None => Err(ChannelError::Other("Control channel closed".to_string())),
|
||||||
}
|
}
|
||||||
@ -479,7 +580,11 @@ impl Channel for CliChatChannel {
|
|||||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||||
let clients = self.clients.lock().await.clone();
|
let clients = self.clients.lock().await.clone();
|
||||||
for client in clients {
|
for client in clients {
|
||||||
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") {
|
if client.chat_id != msg.chat_id {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification")
|
||||||
|
{
|
||||||
WsOutbound::SystemNotification {
|
WsOutbound::SystemNotification {
|
||||||
content: msg.content.clone(),
|
content: msg.content.clone(),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,19 +1,19 @@
|
|||||||
pub mod http;
|
pub mod http;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
|
use axum::{Router, routing};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use axum::{routing, Router};
|
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
|
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
|
||||||
use crate::channels::{ChannelManager, CliChatChannel};
|
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
use crate::config::{Config, expand_path, ensure_workspace_dir};
|
use crate::channels::{ChannelManager, CliChatChannel};
|
||||||
|
use crate::config::{Config, ensure_workspace_dir, expand_path};
|
||||||
use crate::logging;
|
use crate::logging;
|
||||||
use crate::mcp;
|
use crate::mcp;
|
||||||
use crate::memory::MemoryManager;
|
use crate::memory::MemoryManager;
|
||||||
use crate::session::SessionManager;
|
|
||||||
use crate::scheduler::Scheduler;
|
use crate::scheduler::Scheduler;
|
||||||
|
use crate::session::SessionManager;
|
||||||
|
|
||||||
pub struct GatewayState {
|
pub struct GatewayState {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
@ -32,8 +32,13 @@ impl GatewayState {
|
|||||||
let workspace_path = ensure_workspace_dir(&workspace_path)?;
|
let workspace_path = ensure_workspace_dir(&workspace_path)?;
|
||||||
|
|
||||||
// Switch current working directory to workspace
|
// Switch current working directory to workspace
|
||||||
std::env::set_current_dir(&workspace_path)
|
std::env::set_current_dir(&workspace_path).map_err(|e| {
|
||||||
.map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?;
|
format!(
|
||||||
|
"Failed to switch to workspace directory {}: {}",
|
||||||
|
workspace_path.display(),
|
||||||
|
e
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
tracing::info!("Using workspace directory: {}", workspace_path.display());
|
tracing::info!("Using workspace directory: {}", workspace_path.display());
|
||||||
|
|
||||||
@ -52,8 +57,9 @@ impl GatewayState {
|
|||||||
workspace_path.join("picobot.db")
|
workspace_path.join("picobot.db")
|
||||||
};
|
};
|
||||||
let storage = Arc::new(
|
let storage = Arc::new(
|
||||||
crate::storage::Storage::new(&db_path).await
|
crate::storage::Storage::new(&db_path)
|
||||||
.map_err(|e| format!("failed to initialize session storage: {}", e))?
|
.await
|
||||||
|
.map_err(|e| format!("failed to initialize session storage: {}", e))?,
|
||||||
);
|
);
|
||||||
tracing::info!("Session storage: {}", db_path.display());
|
tracing::info!("Session storage: {}", db_path.display());
|
||||||
|
|
||||||
@ -98,7 +104,9 @@ impl GatewayState {
|
|||||||
// Create ChannelManager and init channels
|
// Create ChannelManager and init channels
|
||||||
let cli_chat_channel = Arc::new(CliChatChannel::new());
|
let cli_chat_channel = Arc::new(CliChatChannel::new());
|
||||||
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
|
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
|
||||||
channel_manager.init(&config, workspace_path.clone()).await
|
channel_manager
|
||||||
|
.init(&config, workspace_path.clone())
|
||||||
|
.await
|
||||||
.map_err(|e| format!("Failed to init channels: {}", e))?;
|
.map_err(|e| format!("Failed to init channels: {}", e))?;
|
||||||
|
|
||||||
// Register send_message tool with available channel names
|
// Register send_message tool with available channel names
|
||||||
@ -107,9 +115,12 @@ impl GatewayState {
|
|||||||
session_manager.register_outbound_tool(available_channels);
|
session_manager.register_outbound_tool(available_channels);
|
||||||
|
|
||||||
// Register chat_manager tool
|
// Register chat_manager tool
|
||||||
session_manager.tools().register(
|
session_manager
|
||||||
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()),
|
.tools()
|
||||||
);
|
.register(crate::tools::ChatManagerTool::new(
|
||||||
|
storage.clone(),
|
||||||
|
valid_channels.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
// Initialize MCP servers — connect and register discovered tools
|
// Initialize MCP servers — connect and register discovered tools
|
||||||
if !config.mcp.servers.is_empty() {
|
if !config.mcp.servers.is_empty() {
|
||||||
@ -130,24 +141,27 @@ impl GatewayState {
|
|||||||
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
|
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
|
||||||
if scheduler_config.enabled {
|
if scheduler_config.enabled {
|
||||||
// Register cron tools
|
// Register cron tools
|
||||||
session_manager.tools().register(
|
session_manager
|
||||||
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels),
|
.tools()
|
||||||
);
|
.register(crate::tools::cron::CronAddTool::new(
|
||||||
session_manager.tools().register(
|
storage.clone(),
|
||||||
crate::tools::cron::CronListTool::new(storage.clone()),
|
valid_channels,
|
||||||
);
|
));
|
||||||
session_manager.tools().register(
|
session_manager
|
||||||
crate::tools::cron::CronRemoveTool::new(storage.clone()),
|
.tools()
|
||||||
);
|
.register(crate::tools::cron::CronListTool::new(storage.clone()));
|
||||||
session_manager.tools().register(
|
session_manager
|
||||||
crate::tools::cron::CronEnableTool::new(storage.clone()),
|
.tools()
|
||||||
);
|
.register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
|
||||||
session_manager.tools().register(
|
session_manager
|
||||||
crate::tools::cron::CronDisableTool::new(storage.clone()),
|
.tools()
|
||||||
);
|
.register(crate::tools::cron::CronEnableTool::new(storage.clone()));
|
||||||
session_manager.tools().register(
|
session_manager
|
||||||
crate::tools::cron::CronUpdateTool::new(storage.clone()),
|
.tools()
|
||||||
);
|
.register(crate::tools::cron::CronDisableTool::new(storage.clone()));
|
||||||
|
session_manager
|
||||||
|
.tools()
|
||||||
|
.register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
|
||||||
tracing::info!("Cron tools registered");
|
tracing::info!("Cron tools registered");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -268,71 +282,103 @@ impl GatewayState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handle control messages (session management operations)
|
/// Handle control messages (session management operations)
|
||||||
async fn handle_control_message(
|
async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
|
||||||
session_manager: &SessionManager,
|
|
||||||
msg: ControlMessage,
|
|
||||||
) {
|
|
||||||
use crate::session::{SessionCommand::*, SessionEvent};
|
use crate::session::{SessionCommand::*, SessionEvent};
|
||||||
|
|
||||||
let reply_tx = msg.reply_tx;
|
let reply_tx = msg.reply_tx;
|
||||||
let result: Result<SessionEvent, ChannelError> = match msg.op {
|
let result: Result<SessionEvent, ChannelError> = match msg.op {
|
||||||
CreateDialog { channel, chat_id, title } => {
|
CreateDialog {
|
||||||
session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await
|
channel,
|
||||||
|
chat_id,
|
||||||
|
title,
|
||||||
|
} => session_manager
|
||||||
|
.create_dialog(&channel, &chat_id, title.as_deref())
|
||||||
|
.await
|
||||||
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
|
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
}
|
ListDialogs {
|
||||||
ListDialogs { channel, chat_id, include_archived } => {
|
channel,
|
||||||
session_manager.list_dialogs(&channel, &chat_id, include_archived).await
|
chat_id,
|
||||||
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id })
|
include_archived,
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
} => session_manager
|
||||||
}
|
.list_dialogs(&channel, &chat_id, include_archived)
|
||||||
GetCurrentDialog { channel, chat_id } => {
|
.await
|
||||||
session_manager.get_current_dialog(&channel, &chat_id).await
|
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
|
||||||
|
dialogs,
|
||||||
|
current_dialog_id,
|
||||||
|
})
|
||||||
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
|
GetCurrentDialog { channel, chat_id } => session_manager
|
||||||
|
.get_current_dialog(&channel, &chat_id)
|
||||||
|
.await
|
||||||
.map(|session_id| SessionEvent::CurrentDialog { session_id })
|
.map(|session_id| SessionEvent::CurrentDialog { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
}
|
SwitchDialog {
|
||||||
SwitchDialog { channel, chat_id, dialog_id } => {
|
channel,
|
||||||
session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await
|
chat_id,
|
||||||
|
dialog_id,
|
||||||
|
} => session_manager
|
||||||
|
.switch_dialog(&channel, &chat_id, &dialog_id)
|
||||||
|
.await
|
||||||
.map(|session_id| SessionEvent::DialogSwitched { session_id })
|
.map(|session_id| SessionEvent::DialogSwitched { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
}
|
RenameDialog { session_id, title } => session_manager
|
||||||
RenameDialog { session_id, title } => {
|
.rename_dialog(&session_id, &title)
|
||||||
session_manager.rename_dialog(&session_id, &title).await
|
.await
|
||||||
.map(|()| SessionEvent::DialogRenamed { session_id, title })
|
.map(|()| SessionEvent::DialogRenamed { session_id, title })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
}
|
ArchiveDialog { session_id } => session_manager
|
||||||
ArchiveDialog { session_id } => {
|
.archive_dialog(&session_id)
|
||||||
session_manager.archive_dialog(&session_id)
|
.await
|
||||||
.map(|()| SessionEvent::DialogArchived { session_id })
|
.map(|()| SessionEvent::DialogArchived { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
}
|
DeleteDialog { session_id } => session_manager
|
||||||
DeleteDialog { session_id } => {
|
.delete_dialog(&session_id)
|
||||||
session_manager.delete_dialog(&session_id).await
|
.await
|
||||||
.map(|()| SessionEvent::DialogDeleted { session_id })
|
.map(|()| SessionEvent::DialogDeleted { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
}
|
ClearHistory { session_id } => session_manager
|
||||||
ClearHistory { session_id } => {
|
.clear_dialog_history(&session_id)
|
||||||
session_manager.clear_dialog_history(&session_id)
|
.await
|
||||||
.map(|()| SessionEvent::HistoryCleared { session_id })
|
.map(|()| SessionEvent::HistoryCleared { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
}
|
GetSlashCommands {
|
||||||
GetSlashCommands { channel: _, chat_id: _ } => {
|
channel: _,
|
||||||
|
chat_id: _,
|
||||||
|
} => {
|
||||||
let commands = session_manager.get_slash_commands().to_vec();
|
let commands = session_manager.get_slash_commands().to_vec();
|
||||||
Ok(SessionEvent::SlashCommandsList { commands })
|
Ok(SessionEvent::SlashCommandsList { commands })
|
||||||
}
|
}
|
||||||
ExecuteSlashCommand { command, args, channel, chat_id, current_session_id } => {
|
ExecuteSlashCommand {
|
||||||
session_manager.execute_slash_command(&command, args.as_deref(), &channel, &chat_id, current_session_id.as_ref())
|
command,
|
||||||
|
args,
|
||||||
|
channel,
|
||||||
|
chat_id,
|
||||||
|
current_session_id,
|
||||||
|
} => session_manager
|
||||||
|
.execute_slash_command(
|
||||||
|
&command,
|
||||||
|
args.as_deref(),
|
||||||
|
&channel,
|
||||||
|
&chat_id,
|
||||||
|
current_session_id.as_ref(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { new_session_id: new_id, message: msg })
|
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
new_session_id: new_id,
|
||||||
}
|
message: msg,
|
||||||
|
})
|
||||||
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||||
};
|
};
|
||||||
|
|
||||||
let _ = reply_tx.send(result).await;
|
let _ = reply_tx.send(result).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn run(
|
||||||
|
host: Option<String>,
|
||||||
|
port: Option<u16>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// Initialize logging
|
// Initialize logging
|
||||||
logging::init_logging();
|
logging::init_logging();
|
||||||
tracing::info!("Starting PicoBot Gateway");
|
tracing::info!("Starting PicoBot Gateway");
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -1,17 +1,17 @@
|
|||||||
|
pub mod background_task;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
pub mod message;
|
pub mod message;
|
||||||
pub mod background_task;
|
|
||||||
pub mod scheduler;
|
pub mod scheduler;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
|
|
||||||
pub use error::StorageError;
|
|
||||||
pub use background_task::BackgroundTask;
|
pub use background_task::BackgroundTask;
|
||||||
|
pub use error::StorageError;
|
||||||
pub use scheduler::{JobRun, ScheduledJob};
|
pub use scheduler::{JobRun, ScheduledJob};
|
||||||
|
|
||||||
use sqlx::{Pool, Row, Sqlite, SqlitePool};
|
use sqlx::{Pool, Row, Sqlite, SqlitePool};
|
||||||
use tokio::time::{sleep, Duration};
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use tokio::time::{Duration, sleep};
|
||||||
|
|
||||||
pub struct Storage {
|
pub struct Storage {
|
||||||
pub(crate) pool: Pool<Sqlite>,
|
pub(crate) pool: Pool<Sqlite>,
|
||||||
@ -42,6 +42,7 @@ impl Storage {
|
|||||||
last_active_at INTEGER NOT NULL,
|
last_active_at INTEGER NOT NULL,
|
||||||
message_count INTEGER DEFAULT 0,
|
message_count INTEGER DEFAULT 0,
|
||||||
routing_info TEXT,
|
routing_info TEXT,
|
||||||
|
archived_at INTEGER,
|
||||||
deleted_at INTEGER,
|
deleted_at INTEGER,
|
||||||
last_consolidated_at INTEGER,
|
last_consolidated_at INTEGER,
|
||||||
last_compressed_message_at INTEGER,
|
last_compressed_message_at INTEGER,
|
||||||
@ -92,17 +93,13 @@ impl Storage {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Migration: add source column if upgrading from older schema
|
// Migration: add source column if upgrading from older schema
|
||||||
sqlx::query(
|
sqlx::query(r#"ALTER TABLE messages ADD COLUMN source TEXT"#)
|
||||||
r#"ALTER TABLE messages ADD COLUMN source TEXT"#,
|
|
||||||
)
|
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await
|
.await
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
// Migration: add reasoning_content column if upgrading from older schema
|
// Migration: add reasoning_content column if upgrading from older schema
|
||||||
sqlx::query(
|
sqlx::query(r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#)
|
||||||
r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#,
|
|
||||||
)
|
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await
|
.await
|
||||||
.ok();
|
.ok();
|
||||||
@ -216,12 +213,20 @@ impl Storage {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Rebuild FTS5 index for any existing records
|
// Rebuild FTS5 index for any existing records
|
||||||
sqlx::query(
|
sqlx::query("INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')")
|
||||||
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')",
|
|
||||||
)
|
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Migration: add last_consolidated_at column if not exists
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
ALTER TABLE sessions ADD COLUMN archived_at INTEGER
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
|
||||||
// Migration: add last_consolidated_at column if not exists
|
// Migration: add last_consolidated_at column if not exists
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
@ -260,7 +265,10 @@ impl Storage {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
|
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
|
||||||
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e);
|
tracing::warn!(
|
||||||
|
"Failed to init scheduler schema (tables may already exist): {}",
|
||||||
|
e
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -374,16 +382,20 @@ impl Storage {
|
|||||||
&self.pool
|
&self.pool
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> {
|
pub async fn upsert_session(
|
||||||
|
&self,
|
||||||
|
meta: &crate::storage::session::SessionMeta,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at)
|
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT(id) DO UPDATE SET
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
title = excluded.title,
|
title = excluded.title,
|
||||||
last_active_at = excluded.last_active_at,
|
last_active_at = excluded.last_active_at,
|
||||||
message_count = excluded.message_count,
|
message_count = excluded.message_count,
|
||||||
routing_info = excluded.routing_info,
|
routing_info = excluded.routing_info,
|
||||||
|
archived_at = excluded.archived_at,
|
||||||
deleted_at = excluded.deleted_at,
|
deleted_at = excluded.deleted_at,
|
||||||
last_consolidated_at = excluded.last_consolidated_at,
|
last_consolidated_at = excluded.last_consolidated_at,
|
||||||
last_compressed_message_at = excluded.last_compressed_message_at
|
last_compressed_message_at = excluded.last_compressed_message_at
|
||||||
@ -398,6 +410,7 @@ impl Storage {
|
|||||||
.bind(meta.last_active_at)
|
.bind(meta.last_active_at)
|
||||||
.bind(meta.message_count)
|
.bind(meta.message_count)
|
||||||
.bind(&meta.routing_info)
|
.bind(&meta.routing_info)
|
||||||
|
.bind(meta.archived_at)
|
||||||
.bind(meta.deleted_at)
|
.bind(meta.deleted_at)
|
||||||
.bind(meta.last_consolidated_at)
|
.bind(meta.last_consolidated_at)
|
||||||
.bind(meta.last_compressed_message_at)
|
.bind(meta.last_compressed_message_at)
|
||||||
@ -407,10 +420,13 @@ impl Storage {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
pub async fn get_session(
|
||||||
|
&self,
|
||||||
|
id: &str,
|
||||||
|
) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
||||||
let row = sqlx::query(
|
let row = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||||
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
@ -429,6 +445,7 @@ impl Storage {
|
|||||||
last_active_at: row.get("last_active_at"),
|
last_active_at: row.get("last_active_at"),
|
||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
|
archived_at: row.get("archived_at"),
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
last_consolidated_at: row.get("last_consolidated_at"),
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
@ -440,18 +457,21 @@ impl Storage {
|
|||||||
channel: &str,
|
channel: &str,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
limit: i64,
|
limit: i64,
|
||||||
|
include_archived: bool,
|
||||||
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
||||||
let rows = sqlx::query(
|
let rows = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||||
|
AND (? OR archived_at IS NULL)
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.bind(channel)
|
.bind(channel)
|
||||||
.bind(chat_id)
|
.bind(chat_id)
|
||||||
|
.bind(include_archived)
|
||||||
.bind(limit)
|
.bind(limit)
|
||||||
.fetch_all(self.pool())
|
.fetch_all(self.pool())
|
||||||
.await?;
|
.await?;
|
||||||
@ -468,6 +488,7 @@ impl Storage {
|
|||||||
last_active_at: row.get("last_active_at"),
|
last_active_at: row.get("last_active_at"),
|
||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
|
archived_at: row.get("archived_at"),
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
last_consolidated_at: row.get("last_consolidated_at"),
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
@ -498,9 +519,18 @@ impl Storage {
|
|||||||
|
|
||||||
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
|
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
sqlx::query(
|
sqlx::query(r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#)
|
||||||
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
|
.bind(now)
|
||||||
)
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn archive_session(&self, id: &str) -> Result<(), StorageError> {
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
sqlx::query(r#"UPDATE sessions SET archived_at = ? WHERE id = ? AND deleted_at IS NULL"#)
|
||||||
.bind(now)
|
.bind(now)
|
||||||
.bind(id)
|
.bind(id)
|
||||||
.execute(self.pool())
|
.execute(self.pool())
|
||||||
@ -516,9 +546,9 @@ impl Storage {
|
|||||||
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
|
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
|
||||||
let row = sqlx::query(
|
let row = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND archived_at IS NULL
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
"#,
|
"#,
|
||||||
@ -539,6 +569,7 @@ impl Storage {
|
|||||||
last_active_at: row.get("last_active_at"),
|
last_active_at: row.get("last_active_at"),
|
||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
|
archived_at: row.get("archived_at"),
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
last_consolidated_at: row.get("last_consolidated_at"),
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
@ -547,7 +578,11 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> {
|
pub async fn append_message(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
msg: &crate::storage::message::MessageMeta,
|
||||||
|
) -> Result<i64, StorageError> {
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
|
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
|
||||||
@ -674,16 +709,15 @@ impl Storage {
|
|||||||
offset: i64,
|
offset: i64,
|
||||||
limit: i64,
|
limit: i64,
|
||||||
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
|
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
|
||||||
let count_row = sqlx::query(
|
let count_row =
|
||||||
"SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL",
|
sqlx::query("SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL")
|
||||||
)
|
|
||||||
.fetch_one(self.pool())
|
.fetch_one(self.pool())
|
||||||
.await?;
|
.await?;
|
||||||
let total: i64 = count_row.get("total");
|
let total: i64 = count_row.get("total");
|
||||||
|
|
||||||
let rows = sqlx::query(
|
let rows = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
@ -707,6 +741,7 @@ impl Storage {
|
|||||||
last_active_at: row.get("last_active_at"),
|
last_active_at: row.get("last_active_at"),
|
||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
|
archived_at: row.get("archived_at"),
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
last_consolidated_at: row.get("last_consolidated_at"),
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
@ -772,7 +807,10 @@ impl Storage {
|
|||||||
where_extra.push_str(" AND created_at > ?");
|
where_extra.push_str(" AND created_at > ?");
|
||||||
}
|
}
|
||||||
|
|
||||||
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra);
|
let count_sql = format!(
|
||||||
|
"SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}",
|
||||||
|
where_extra
|
||||||
|
);
|
||||||
let select_sql = format!(
|
let select_sql = format!(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||||
@ -1030,6 +1068,7 @@ mod tests {
|
|||||||
last_active_at: 1000,
|
last_active_at: 1000,
|
||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -1066,6 +1105,7 @@ mod tests {
|
|||||||
last_active_at: i as i64 * 1000,
|
last_active_at: i as i64 * 1000,
|
||||||
message_count: i,
|
message_count: i,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -1073,7 +1113,10 @@ mod tests {
|
|||||||
storage.upsert_session(&meta).await.unwrap();
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
|
let sessions = storage
|
||||||
|
.list_sessions("cli_chat", "sid123", 10, false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert_eq!(sessions.len(), 5);
|
assert_eq!(sessions.len(), 5);
|
||||||
// 按 last_active_at DESC 排序
|
// 按 last_active_at DESC 排序
|
||||||
assert_eq!(sessions[0].dialog_id, "dialog4");
|
assert_eq!(sessions[0].dialog_id, "dialog4");
|
||||||
@ -1093,6 +1136,7 @@ mod tests {
|
|||||||
last_active_at: 1000,
|
last_active_at: 1000,
|
||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -1120,6 +1164,7 @@ mod tests {
|
|||||||
last_active_at: 1000,
|
last_active_at: 1000,
|
||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -1141,7 +1186,10 @@ mod tests {
|
|||||||
created_at: 1000,
|
created_at: 1000,
|
||||||
};
|
};
|
||||||
|
|
||||||
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
|
let seq = storage
|
||||||
|
.append_message(&session_meta.id, &msg)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert_eq!(seq, 1);
|
assert_eq!(seq, 1);
|
||||||
|
|
||||||
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
|
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
|
||||||
@ -1163,6 +1211,7 @@ mod tests {
|
|||||||
last_active_at: 1000,
|
last_active_at: 1000,
|
||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
|
|||||||
@ -11,6 +11,7 @@ pub struct SessionMeta {
|
|||||||
pub last_active_at: i64,
|
pub last_active_at: i64,
|
||||||
pub message_count: i64,
|
pub message_count: i64,
|
||||||
pub routing_info: Option<String>,
|
pub routing_info: Option<String>,
|
||||||
|
pub archived_at: Option<i64>,
|
||||||
pub deleted_at: Option<i64>,
|
pub deleted_at: Option<i64>,
|
||||||
pub last_consolidated_at: Option<i64>,
|
pub last_consolidated_at: Option<i64>,
|
||||||
pub last_compressed_message_at: Option<i64>,
|
pub last_compressed_message_at: Option<i64>,
|
||||||
|
|||||||
@ -126,7 +126,10 @@ impl ChatManagerTool {
|
|||||||
let start_num = offset + 1;
|
let start_num = offset + 1;
|
||||||
let end_num = offset + sessions.len() as i64;
|
let end_num = offset + sessions.len() as i64;
|
||||||
|
|
||||||
let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num);
|
let mut output = format!(
|
||||||
|
"全部会话 (共 {} 个,第 {}-{} 个):\n",
|
||||||
|
total, start_num, end_num
|
||||||
|
);
|
||||||
|
|
||||||
for s in &sessions {
|
for s in &sessions {
|
||||||
let ago = format_duration_ago(now_ms - s.last_active_at);
|
let ago = format_duration_ago(now_ms - s.last_active_at);
|
||||||
@ -300,6 +303,7 @@ mod tests {
|
|||||||
last_active_at: now - i * 3600_000,
|
last_active_at: now - i * 3600_000,
|
||||||
message_count: i * 5,
|
message_count: i * 5,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -335,6 +339,7 @@ mod tests {
|
|||||||
last_active_at: now,
|
last_active_at: now,
|
||||||
message_count: 3,
|
message_count: 3,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -346,7 +351,11 @@ mod tests {
|
|||||||
id: format!("msg{}", i),
|
id: format!("msg{}", i),
|
||||||
session_id: session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
seq: i as i64 + 1,
|
seq: i as i64 + 1,
|
||||||
role: if i == 0 { "user".to_string() } else { "assistant".to_string() },
|
role: if i == 0 {
|
||||||
|
"user".to_string()
|
||||||
|
} else {
|
||||||
|
"assistant".to_string()
|
||||||
|
},
|
||||||
content: format!("消息内容 {}", i),
|
content: format!("消息内容 {}", i),
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
media_refs: None,
|
media_refs: None,
|
||||||
@ -392,6 +401,7 @@ mod tests {
|
|||||||
last_active_at: now,
|
last_active_at: now,
|
||||||
message_count: 5,
|
message_count: 5,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -403,7 +413,11 @@ mod tests {
|
|||||||
id: format!("msg{}", i),
|
id: format!("msg{}", i),
|
||||||
session_id: session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
seq: i as i64 + 1,
|
seq: i as i64 + 1,
|
||||||
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
|
role: if i % 2 == 0 {
|
||||||
|
"user".to_string()
|
||||||
|
} else {
|
||||||
|
"assistant".to_string()
|
||||||
|
},
|
||||||
content: format!("消息内容 {}", i),
|
content: format!("消息内容 {}", i),
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
media_refs: None,
|
media_refs: None,
|
||||||
@ -447,6 +461,7 @@ mod tests {
|
|||||||
last_active_at: now,
|
last_active_at: now,
|
||||||
message_count: 5,
|
message_count: 5,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
|
archived_at: None,
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -492,10 +507,7 @@ mod tests {
|
|||||||
let (storage, _dir) = create_test_storage().await;
|
let (storage, _dir) = create_test_storage().await;
|
||||||
let tool = ChatManagerTool::new(storage, vec![]);
|
let tool = ChatManagerTool::new(storage, vec![]);
|
||||||
|
|
||||||
let result = tool
|
let result = tool.execute(json!({ "action": "unknown" })).await.unwrap();
|
||||||
.execute(json!({ "action": "unknown" }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("Unknown action"));
|
assert!(result.error.unwrap().contains("Unknown action"));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use picobot::providers::{ChatCompletionRequest, Message};
|
|
||||||
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
||||||
|
use picobot::providers::{ChatCompletionRequest, Message};
|
||||||
|
|
||||||
/// Test that message with special characters is properly escaped
|
/// Test that message with special characters is properly escaped
|
||||||
#[test]
|
#[test]
|
||||||
@ -19,7 +19,9 @@ fn test_message_special_characters() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_multiline_system_prompt() {
|
fn test_multiline_system_prompt() {
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
|
Message::system(
|
||||||
|
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
|
||||||
|
),
|
||||||
Message::user("Hi"),
|
Message::user("Hi"),
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -33,10 +35,7 @@ fn test_multiline_system_prompt() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_chat_request_serialization() {
|
fn test_chat_request_serialization() {
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![
|
messages: vec![Message::system("You are helpful"), Message::user("Hello")],
|
||||||
Message::system("You are helpful"),
|
|
||||||
Message::user("Hello"),
|
|
||||||
],
|
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
tools: None,
|
tools: None,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user