diff --git a/docs/CODE_QUALITY_ANALYSIS.md b/docs/CODE_QUALITY_ANALYSIS.md new file mode 100644 index 0000000..41ef4b7 --- /dev/null +++ b/docs/CODE_QUALITY_ANALYSIS.md @@ -0,0 +1,359 @@ +# PicoBot 代码质量分析报告 + +审查日期:2026-06-15 + +## 结论摘要 + +PicoBot 的总体架构方向是清晰的:Gateway 负责装配,Channel 只做收发,MessageBus 解耦输入输出,SessionManager 管理会话,AgentLoop 保持无状态并执行工具,Storage 统一持久化。这条主线是成立的,也已经具备较完整的 AI 助手运行时能力。 + +当前主要质量风险集中在三类: + +1. 会话/CLI 路由语义不一致,导致多客户端隔离、加载会话、当前会话追踪不可靠。 +2. 若干公开控制接口是空实现或弱实现,协议层暴露的能力和后端实际行为不匹配。 +3. 工具和后台任务的资源边界偏弱,文件、shell、HTTP、长期任务在异常情况下容易突破预期的安全或稳定性边界。 + +如果只安排一轮修复,优先处理会话路由和控制接口。这些问题会直接影响用户看到的行为;工具安全和大模块拆分可以作为第二阶段。 + +## 修复状态 + +- 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。 +- 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。 +- 待处理:工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。 + +## 主要发现 + +### 已修复:CLI 会话路由会破坏会话连续性和多客户端隔离 + +位置: + +- `src/channels/cli_chat.rs:113-126` +- `src/channels/cli_chat.rs:160-164` +- `src/channels/cli_chat.rs:225-249` +- `src/channels/cli_chat.rs:479-494` +- `src/session/session.rs:1305-1310` + +问题: + +`Client.current_session_id` 存的是完整 session id,但 CLI channel 在多个地方把它当作 `chat_id` 使用。普通用户输入如果没有显式传 `chat_id`,会在 `src/channels/cli_chat.rs:119` 生成新的短 ID,而不是复用当前 client 的 chat scope。`CreateSession` 又把当前完整 session id 当成新会话的 chat_id。`LoadSession` 解析了传入 session id,但随后调用 `GetCurrentDialog`,而后端 `get_current_dialog()` 固定返回 `None`。 + +同时,`send()` 会把所有 `OutboundMessage` 广播给所有 CLI WebSocket client,没有按 `msg.chat_id` 或 client 当前会话过滤。这意味着一个客户端的回复可能出现在另一个客户端里。 + +影响: + +- CLI 多轮对话可能落入不同 chat scope。 +- 创建/列出/加载会话得到的结果可能不符合 UI 预期。 +- 多个 CLI 客户端同时连接时存在串话。 + +建议: + +- 将 client 状态拆成 `chat_id` 和 `current_session_id`,不要混用。 +- 注册 client 时生成稳定 `chat_id`,后续 `UserInput` 默认复用它。 +- `send()` 按 `OutboundMessage.chat_id` 精确投递;必要时维护 `chat_id -> clients` 映射。 +- `LoadSession` 应直接切换到指定 session,或通过 `SwitchDialog` 使用其中的 `dialog_id`。 +- 为 CLI WebSocket 增加多客户端路由测试。 + +### 已修复:Dialog 控制接口与协议承诺不一致 + +位置: + +- `src/session/session.rs:996-997` +- `src/session/session.rs:1305-1310` +- `src/session/session.rs:1329-1349` +- `src/session/session.rs:1378-1384` +- `src/channels/cli_chat.rs:128-158` + +问题: + +后端暴露了 create/list/load/rename/archive/delete/clear 等 dialog 操作,但部分行为是空实现或语义错位: + +- `/delete` 只创建新 session,并没有删除当前 session。 +- `get_current_dialog()` 固定返回 `Ok(None)`。 +- `list_dialogs()` 忽略 `include_archived`,且总是返回 `current_dialog_id = None`。 +- `archive_dialog()` 是空操作。 +- `clear_dialog_history()` 直接返回不可用,但 WebSocket 协议仍暴露 `clear_history`。 + +影响: + +用户通过 slash command 和 WebSocket 调用同一类能力时,会得到不一致结果。前端难以基于协议实现可靠状态同步。 + +建议: + +- 明确“archive/clear 是否支持”。不支持就从协议和命令列表移除;支持就实现到底。 +- `/delete` 应调用 `delete_dialog(current_session_id)`,再创建一个新的 current session。 +- `get_current_dialog()` 应读取 `current_sessions[channel:chat_id]` 并解析为 `UnifiedSessionId`。 +- `list_dialogs()` 返回真实 current dialog,并补上 archived 模型或移除 archived 参数。 + +### 高优先级:工具文件边界不符合“工作目录内工具”的架构约束 + +位置: + +- `src/tools/mod.rs:56-62` +- `src/tools/path_utils.rs:3-23` +- `src/tools/bash.rs:146-185` + +问题: + +文件工具默认通过 `FileReadTool::new()`、`FileWriteTool::new()` 等注册,没有传入 workspace allowlist。`resolve_path()` 对绝对路径直接放行;即使传入 allowlist,也只是做 `Path::starts_with()` 的词法判断,没有 canonicalize,不能防御 `..`、符号链接等路径逃逸。 + +`bash` 默认工作目录是 `"."`,Gateway 启动时切到 workspace,这对相对路径有效,但 shell 命令仍然可以访问绝对路径。当前 denylist 只挡少数危险模式,不构成权限边界。 + +影响: + +Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“工作目录内操作”不一致。对于个人助手这可能是有意设计,但如果未来接入外部渠道、多用户或 MCP,风险会放大。 + +建议: + +- 工具注册时传入 `workspace_dir`,默认所有文件工具限制在 workspace。 +- `resolve_path()` 使用 `std::fs::canonicalize` 或 `path_absolutize` 风格逻辑,并处理目标文件不存在时的父目录 canonicalize。 +- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。 +- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。 + +### 中高优先级:Session 锁内执行过多异步操作 + +位置: + +- `src/session/session.rs:1001-1018` +- `src/session/session.rs:1604-1711` + +问题: + +`/compact` 在持有 session mutex 时执行压缩和持久化。agent worker 的 Phase 1 也在持有 session mutex 时执行用户消息落库、memory recall、上下文压缩、session meta 持久化和 agent 创建。其中 `compress_if_needed()` 可能触发 LLM 摘要,属于慢操作。 + +影响: + +- 同一 session 的 slash command、stop、消息排队、状态查询会被慢操作阻塞。 +- 当压缩或存储出现抖动时,用户感觉像“卡死”。 +- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。 + +建议: + +- 锁内只做内存状态快照和必要的状态标记。 +- 将 memory recall、压缩、LLM 摘要放到锁外执行。 +- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。 + +### 中优先级:Bash 超时不会显式终止子进程 + +位置: + +- `src/tools/bash.rs:150-174` +- `src/tools/bash.rs:180-207` + +问题: + +`timeout()` 包裹的是 `run_command()` future。超时后 future 被取消,但代码没有持有 child 句柄并显式 `kill()` / `wait()`。对于已经启动的长运行命令或子进程树,可能留下后台进程。 + +影响: + +长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。 + +建议: + +- 使用 `tokio::process::Child` 的 `kill_on_drop(true)`。 +- 超时分支显式 kill child 并 wait。 +- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。 +- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。 + +### 中优先级:文件读取对大二进制文件没有输出上限 + +位置: + +- `src/tools/file_read.rs:121-131` +- `src/tools/file_read.rs:214-229` + +问题: + +`file_read` 先 `std::fs::read()` 读取整个文件。文本路径有 `MAX_CHARS` 截断,但二进制路径会完整 base64 编码后返回,没有大小限制。 + +影响: + +读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。 + +建议: + +- 先检查 metadata size,超过阈值直接返回提示。 +- 二进制文件默认只返回 mime、大小和建议操作;需要内容时提供显式 `max_bytes` 参数。 +- 对文本读取也改成流式按行读取,而不是整文件读入。 + +### 中优先级:HTTP 私网防护只检查字面 host,未做 DNS 解析校验 + +位置: + +- `src/tools/http_request.rs:31-59` + +问题: + +`http_request` 阻止 localhost、私网 IP 字面量和 `.local`,但普通域名不会解析后检查最终 IP。DNS rebinding 或内网域名解析到私网地址时,当前校验拦不住。 + +影响: + +如果该工具暴露给非完全可信输入,存在 SSRF 风险。 + +建议: + +- 请求前解析域名,拒绝私网、loopback、link-local、multicast、unspecified 地址。 +- 禁止或限制重定向,重定向后的每个 URL 重新校验。 +- 对 `http_request` 和 `web_fetch` 复用同一套 URL 安全策略。 + +### 中优先级:后台任务和主循环缺少监督与优雅关闭 + +位置: + +- `src/bus/mod.rs:51-99` +- `src/gateway/mod.rs:187-244` +- `src/gateway/mod.rs:247-266` + +问题: + +Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHandle,也没有统一 cancellation token。MessageBus 的 `consume_*()` 在 channel 关闭时使用 `expect()` panic。 + +影响: + +- 某个后台 loop 异常退出后,Gateway 不一定能发现。 +- 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。 +- bus channel 关闭时更像崩溃,而不是可恢复状态。 + +建议: + +- 引入 runtime supervisor,保存 JoinHandle 并集中处理退出原因。 +- 用 `CancellationToken` 贯穿 Gateway 子任务。 +- `consume_*()` 返回 `Result>`,由调用方决定退出或重启。 + +### 中低优先级:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间 + +位置: + +- `src/scheduler/mod.rs:18-40` + +问题: + +`next_run_for_schedule(schedule, from)` 的注释说基于 `from` 计算,但 cron 分支创建了 `from_dt` 后没有传给 `cron_schedule`,实际使用的是 `upcoming(Utc)` 或 `upcoming(tz)` 的当前时间。 + +影响: + +单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。 + +建议: + +- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。 +- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。 +- 增加固定时间输入的单元测试,避免受系统时间影响。 + +### 中低优先级:存在未接入或半接入代码,增加维护噪音 + +位置: + +- `src/tools/pty.rs` +- `src/tools/mod.rs:1-20` +- `src/tools/mod.rs:49-88` + +问题: + +仓库里有完整 `pty.rs`,但 `tools/mod.rs` 没有声明 `pub mod pty`,`create_default_tools()` 也没有注册 PTY 工具。类似情况会让文档、计划和实现状态难以判断。 + +影响: + +维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。 + +建议: + +- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。 +- 若暂不发布:移动到设计文档或 feature branch,避免主干保留死代码。 + +## 架构评价 + +### 做得好的地方 + +- 模块分层方向清楚:Channel、Bus、Session、Agent、Provider、Tool、Storage 边界基本可理解。 +- AgentLoop 设计为无状态,历史由 SessionManager 管理,这一点利于恢复、压缩和测试。 +- Provider 抽象简单直接,OpenAI-compatible 与 Anthropic 的差异被限制在 provider 层。 +- Storage 集中初始化 schema,便于部署单二进制应用。 +- Skill、memory、MCP、delegate 这几条扩展线已经形成统一的 ToolRegistry 接入点。 + +### 主要架构债务 + +- SessionManager 承担过多职责:会话生命周期、命令解析、memory recall、压缩、agent worker、任务取消、send_message 目标解析都在一个 2000 行文件内。 +- Channel 和 Session 对 chat_id/session_id/dialog_id 的边界没有类型保护,导致 CLI 层混用字符串。 +- Tool 权限模型不够显式:工具是否能访问全文件系统、是否能联网、是否能修改状态主要靠工具自身约定。 +- 后台任务生命周期分散:gateway loop、agent worker、notification publisher、scheduler、sub-agent task 各自 spawn,缺少统一管理。 + +## 模块级分析 + +### gateway + +`GatewayState::new()` 是清晰的装配中心:配置、workspace、storage、memory、bus、session manager、channels、MCP、scheduler 都在这里接线。问题是启动后任务监督不足,且 scheduler 默认 `unwrap_or_default()` 会在省略 `gateway.scheduler` 时启用调度器,这和“省略配置是否代表开启”需要产品层确认。 + +### channels + +Feishu channel 功能较厚,单文件接近 2000 行,建议后续按 API client、message parsing、media handling、outbound rendering 拆分。CLI channel 目前是质量风险最高的 channel,核心问题是会话身份混用和广播投递。 + +### bus + +MessageBus 简洁,但当前消费者 API 通过 mutex 包住 receiver 并 `expect()`,更像“单消费者内部队列”。这没问题,但应该把“只能有一个 consumer”写进类型/文档,并把关闭作为正常状态处理。 + +### session + +这是系统核心,也是债务最集中的模块。建议把 `session.rs` 拆成: + +- `manager.rs`:SessionManager 状态和 dialog 生命周期 +- `worker.rs`:per-session agent worker 和 cancellation +- `commands.rs`:slash command 执行 +- `outbound.rs`:OutboundMessenger 实现 +- `restore.rs`:storage 恢复与 tool call chain repair + +拆分之前,先补行为测试,尤其是 CLI/WS session lifecycle。 + +### agent + +AgentLoop 的职责相对聚焦:请求模型、执行工具、回填 tool result、循环直到 final response。需要关注的是工具并发的语义:`read_only()` 目前是工具自己声明,副作用工具不能错标。LoopDetector 有帮助,但属于 runtime guard,不应替代工具层的资源限制。 + +### providers + +Provider 层整体可维护。OpenAI/Anthropic 的请求构造逻辑可以继续保留在 provider 内。建议补充请求脱敏策略:当前 debug log 和 `llm_calls` 会持久化完整 request/response,可能包含用户隐私、API 返回内容和文件内容。 + +### tools + +工具体系覆盖面很强,但需要明确权限模型。建议新增统一的 `ToolExecutionContext`,包含 workspace、channel、session_id、权限策略、网络策略、输出预算。现在很多策略散落在各工具构造函数里,默认值容易失控。 + +### storage + +Storage schema 初始化实用,但迁移方式是“CREATE IF NOT EXISTS + ALTER IGNORE”,适合早期迭代,不适合长期演进。建议引入 schema version 表或 sqlx migrations,至少把每次迁移记录下来。 + +### skills + +Skill 加载优先级清晰,内置 skill 打包也实用。需要注意 `SkillsLoader` 使用同步文件系统扫描和 `std::sync::Mutex`,在请求路径频繁 `reload_if_changed()` 时可能造成阻塞。短期可以接受,长期建议缓存刷新放到后台 watcher。 + +## 建议修复路线 + +### P0:先修会话正确性 + +1. 修正 CLI `chat_id/current_session_id` 数据模型。 +2. 修正 CLI 出站按 client/chat_id 投递。 +3. 实现 `get_current_dialog()`、`list_dialogs()` current 返回。 +4. 修正 `/delete`、`clear_history`、`archive` 的真实行为或从协议移除。 +5. 增加 WebSocket session lifecycle 测试。 + +### P1:收紧工具和资源边界 + +1. 文件工具默认限制 workspace,路径 canonicalize。 +2. bash 超时杀进程,必要时引入进程组。 +3. file_read 增加文件大小上限和二进制输出上限。 +4. HTTP/web 工具增加 DNS 解析后的私网校验和重定向校验。 +5. 明确高危工具的配置开关。 + +### P2:降低架构复杂度 + +1. 拆分 `session.rs`、`feishu.rs`、`storage/mod.rs`、`browser.rs`。 +2. 引入任务 supervisor 和统一 shutdown token。 +3. 引入正式数据库迁移。 +4. 增加工具注册快照测试,避免死代码和文档漂移。 + +## 建议测试补充 + +- CLI 多客户端并发:两个 WebSocket client 同时发消息,互不串话。 +- CLI 不传 chat_id 的连续对话:所有消息应进入同一 session。 +- Load/switch/list/delete/clear 的完整 WebSocket 流程。 +- `/delete` 后旧 session 软删除、新 session 成为 current。 +- 文件路径逃逸:`../`、绝对路径、符号链接、workspace 前缀欺骗。 +- bash timeout 后检查子进程不存在。 +- cron `next_run_for_schedule()` 使用固定 `from` 的 deterministic 测试。 +- HTTP 工具对 DNS 解析到 `127.0.0.1` / `10.0.0.0/8` 的域名拒绝测试。 diff --git a/src/channels/cli_chat.rs b/src/channels/cli_chat.rs index 7f30ae6..6e88027 100644 --- a/src/channels/cli_chat.rs +++ b/src/channels/cli_chat.rs @@ -1,10 +1,10 @@ -use std::sync::Arc; use async_trait::async_trait; -use tokio::sync::{mpsc, Mutex}; +use std::sync::Arc; +use tokio::sync::{Mutex, mpsc}; use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage}; +use crate::protocol::{SlashCommandInfo, WsInbound, WsOutbound, parse_inbound}; use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId}; -use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo}; use super::base::{Channel, ChannelError}; @@ -14,6 +14,7 @@ use super::base::{Channel, ChannelError}; pub(crate) struct Client { sender: mpsc::Sender, + chat_id: String, current_session_id: Mutex>, } @@ -41,23 +42,28 @@ impl CliChatChannel { } /// Register a new client connection, returns (session_id, client) - pub(crate) async fn register_client(&self, sender: mpsc::Sender) -> (String, Arc) { - // Generate connection ID (used as chat_id) - use short ID - let connection_id = crate::util::short_id(); + pub(crate) async fn register_client( + &self, + sender: mpsc::Sender, + ) -> (String, Arc) { + // Each WebSocket connection gets a stable chat scope. All user input and + // dialog controls for this client stay inside that scope unless the + // protocol explicitly carries a full session id. + let chat_id = crate::util::short_id(); let client = Arc::new(Client { sender, + chat_id: chat_id.clone(), current_session_id: Mutex::new(None), }); self.clients.lock().await.push(client.clone()); // Create initial session via control message - let session_id = match self.create_session_via_control(&connection_id, None).await { - Ok(id) => id, + let session_id = match self.create_session_via_control(&chat_id, None).await { + Ok((id, _title)) => id, Err(e) => { tracing::error!(error = %e, "Failed to create initial session"); - // Fall back to old format for backward compatibility - connection_id.clone() + UnifiedSessionId::new("cli_chat", &chat_id, &crate::util::short_id()).to_string() } }; @@ -73,21 +79,19 @@ impl CliChatChannel { /// Handle an inbound message from a client pub(crate) async fn handle_inbound(&self, client: Arc, raw_msg: &str) { match parse_inbound(raw_msg) { - Ok(inbound) => { - match self.handle_ws_inbound(client.clone(), inbound).await { - Ok(()) => {} - Err(e) => { - tracing::warn!(error = %e, "Failed to handle inbound message"); - let _ = client - .sender - .send(WsOutbound::Error { - code: "INTERNAL_ERROR".to_string(), - message: e.to_string(), - }) - .await; - } + Ok(inbound) => match self.handle_ws_inbound(client.clone(), inbound).await { + Ok(()) => {} + Err(e) => { + tracing::warn!(error = %e, "Failed to handle inbound message"); + let _ = client + .sender + .send(WsOutbound::Error { + code: "INTERNAL_ERROR".to_string(), + message: e.to_string(), + }) + .await; } - } + }, Err(e) => { tracing::warn!(error = %e, "Failed to parse inbound message"); let _ = client @@ -101,22 +105,30 @@ impl CliChatChannel { } } - async fn handle_ws_inbound(&self, client: Arc, inbound: WsInbound) -> Result<(), ChannelError> { + async fn handle_ws_inbound( + &self, + client: Arc, + inbound: WsInbound, + ) -> Result<(), ChannelError> { let bus = { let guard = self.bus.lock().unwrap(); - guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? + guard + .clone() + .ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? }; let mut current_session_guard = client.current_session_id.lock().await; match inbound { - WsInbound::UserInput { content, chat_id, .. } => { + WsInbound::UserInput { + content, chat_id, .. + } => { // All messages (including slash commands) go through the normal inbound flow // SessionManager handles session creation/reuse internally let msg = InboundMessage { channel: self.name().to_string(), sender_id: "cli".to_string(), - chat_id: chat_id.unwrap_or_else(crate::util::short_id), + chat_id: chat_id.unwrap_or_else(|| client.chat_id.clone()), content, timestamp: crate::bus::message::current_timestamp(), media: Vec::new(), @@ -125,19 +137,56 @@ impl CliChatChannel { }; bus.publish_inbound(msg).await?; } - WsInbound::ClearHistory { chat_id, session_id } => { - let target = session_id - .or(chat_id) - .or(current_session_guard.clone()) - .ok_or_else(|| ChannelError::Other("No active session".to_string()))?; - + WsInbound::ClearHistory { + chat_id, + session_id, + } => { let (reply_tx, mut reply_rx) = mpsc::channel(1); - let session_id = UnifiedSessionId::parse(&target) - .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; + let session_id = if let Some(session_id) = session_id { + UnifiedSessionId::parse(&session_id).ok_or_else(|| { + ChannelError::Other("Invalid session ID format".to_string()) + })? + } else if let Some(chat_id) = chat_id { + let (current_tx, mut current_rx) = mpsc::channel(1); + bus.publish_control(ControlMessage { + op: SessionCommand::GetCurrentDialog { + channel: "cli_chat".to_string(), + chat_id, + }, + reply_tx: current_tx, + }) + .await?; + match current_rx.recv().await { + Some(Ok(SessionEvent::CurrentDialog { + session_id: Some(session_id), + })) => session_id, + Some(Ok(SessionEvent::CurrentDialog { session_id: None })) => { + return Err(ChannelError::Other("No active session".to_string())); + } + Some(Ok(_)) => { + return Err(ChannelError::Other( + "Unexpected response type".to_string(), + )); + } + Some(Err(e)) => return Err(e), + None => { + return Err(ChannelError::Other("Control channel closed".to_string())); + } + } + } else { + let target = current_session_guard + .clone() + .ok_or_else(|| ChannelError::Other("No active session".to_string()))?; + UnifiedSessionId::parse(&target).ok_or_else(|| { + ChannelError::Other("Invalid session ID format".to_string()) + })? + }; + let target = session_id.to_string(); bus.publish_control(ControlMessage { op: SessionCommand::ClearHistory { session_id }, reply_tx, - }).await?; + }) + .await?; match reply_rx.recv().await { Some(Ok(SessionEvent::HistoryCleared { .. })) => { @@ -158,24 +207,21 @@ impl CliChatChannel { } } WsInbound::CreateSession { title } => { - // Use current session's chat_id if available, otherwise generate new one - let chat_id = current_session_guard.clone() - .unwrap_or_else(crate::util::short_id); - let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?; + let (new_id, created_title) = self + .create_session_via_control(&client.chat_id, title.as_deref()) + .await?; *current_session_guard = Some(new_id.clone()); let _ = client .sender .send(WsOutbound::SessionCreated { session_id: new_id, - title: title.unwrap_or_default(), + title: created_title, }) .await; } WsInbound::ListSessions { include_archived } => { // List dialogs for the current chat - let chat_id = current_session_guard.clone() - .unwrap_or_else(|| "".to_string()); - let chat_id_for_response = chat_id.clone(); + let chat_id = client.chat_id.clone(); let (reply_tx, mut reply_rx) = mpsc::channel(1); bus.publish_control(ControlMessage { op: SessionCommand::ListDialogs { @@ -184,13 +230,18 @@ impl CliChatChannel { include_archived, }, reply_tx, - }).await?; + }) + .await?; match reply_rx.recv().await { - Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => { + Some(Ok(SessionEvent::DialogList { + dialogs, + current_dialog_id, + })) => { // Convert DialogInfo to SessionSummary for backward compatibility - let sessions: Vec = dialogs.into_iter().map(|d| { - crate::protocol::SessionSummary { + let sessions: Vec = dialogs + .into_iter() + .map(|d| crate::protocol::SessionSummary { session_id: d.session_id.to_string(), title: d.title, channel_name: d.session_id.channel.clone(), @@ -198,11 +249,14 @@ impl CliChatChannel { message_count: d.message_count, last_active_at: d.last_active_at, archived_at: d.archived_at, - } - }).collect(); + }) + .collect(); let current_session_id = current_dialog_id.map(|did| { - UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string() + UnifiedSessionId::new("cli_chat", &client.chat_id, &did).to_string() }); + if let Some(ref session_id) = current_session_id { + *current_session_guard = Some(session_id.clone()); + } let _ = client .sender .send(WsOutbound::SessionList { @@ -223,39 +277,35 @@ impl CliChatChannel { } } WsInbound::LoadSession { session_id } => { - // LoadSession: parse the session_id and get current dialog info let (reply_tx, mut reply_rx) = mpsc::channel(1); let unified_id = UnifiedSessionId::parse(&session_id) .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; + if unified_id.channel != "cli_chat" || unified_id.chat_id != client.chat_id { + return Err(ChannelError::Other( + "Session does not belong to this client".to_string(), + )); + } bus.publish_control(ControlMessage { - op: SessionCommand::GetCurrentDialog { + op: SessionCommand::SwitchDialog { channel: unified_id.channel.clone(), chat_id: unified_id.chat_id.clone(), + dialog_id: unified_id.dialog_id.clone(), }, reply_tx, - }).await?; + }) + .await?; match reply_rx.recv().await { - Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => { - if let Some(current_session_id) = current_session_id_opt { - *current_session_guard = Some(current_session_id.to_string()); - let _ = client - .sender - .send(WsOutbound::SessionLoaded { - session_id: current_session_id.to_string(), - title: "Session".to_string(), // TODO: get actual title - message_count: 0, // TODO: get actual count - }) - .await; - } else { - let _ = client - .sender - .send(WsOutbound::Error { - code: "NO_CURRENT_DIALOG".to_string(), - message: "No current dialog".to_string(), - }) - .await; - } + Some(Ok(SessionEvent::DialogSwitched { session_id })) => { + *current_session_guard = Some(session_id.to_string()); + let _ = client + .sender + .send(WsOutbound::SessionLoaded { + session_id: session_id.to_string(), + title: "Session".to_string(), + message_count: 0, + }) + .await; } Some(Ok(_)) => { // Unexpected response type @@ -275,23 +325,30 @@ impl CliChatChannel { } } WsInbound::RenameSession { session_id, title } => { - let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { - ChannelError::Other("No active session".to_string()) - })?; + let target = session_id + .or(current_session_guard.clone()) + .ok_or_else(|| ChannelError::Other("No active session".to_string()))?; let (reply_tx, mut reply_rx) = mpsc::channel(1); let unified_id = UnifiedSessionId::parse(&target) .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; bus.publish_control(ControlMessage { - op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() }, + op: SessionCommand::RenameDialog { + session_id: unified_id, + title: title.clone(), + }, reply_tx, - }).await?; + }) + .await?; match reply_rx.recv().await { Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => { let _ = client .sender - .send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title }) + .send(WsOutbound::SessionRenamed { + session_id: session_id.to_string(), + title, + }) .await; } Some(Ok(_)) => { @@ -306,24 +363,43 @@ impl CliChatChannel { } } WsInbound::ArchiveSession { session_id } => { - let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { - ChannelError::Other("No active session".to_string()) - })?; + let target = session_id + .or(current_session_guard.clone()) + .ok_or_else(|| ChannelError::Other("No active session".to_string()))?; + let was_current = current_session_guard.as_deref() == Some(&target); let (reply_tx, mut reply_rx) = mpsc::channel(1); let unified_id = UnifiedSessionId::parse(&target) .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; bus.publish_control(ControlMessage { - op: SessionCommand::ArchiveDialog { session_id: unified_id }, + op: SessionCommand::ArchiveDialog { + session_id: unified_id, + }, reply_tx, - }).await?; + }) + .await?; match reply_rx.recv().await { Some(Ok(SessionEvent::DialogArchived { session_id })) => { let _ = client .sender - .send(WsOutbound::SessionArchived { session_id: session_id.to_string() }) + .send(WsOutbound::SessionArchived { + session_id: session_id.to_string(), + }) .await; + if was_current { + let (new_id, title) = self + .create_session_via_control(&client.chat_id, None) + .await?; + *current_session_guard = Some(new_id.clone()); + let _ = client + .sender + .send(WsOutbound::SessionCreated { + session_id: new_id, + title, + }) + .await; + } } Some(Ok(_)) => { // Unexpected response type @@ -337,35 +413,42 @@ impl CliChatChannel { } } WsInbound::DeleteSession { session_id } => { - let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { - ChannelError::Other("No active session".to_string()) - })?; + let target = session_id + .or(current_session_guard.clone()) + .ok_or_else(|| ChannelError::Other("No active session".to_string()))?; let (reply_tx, mut reply_rx) = mpsc::channel(1); let unified_id = UnifiedSessionId::parse(&target) .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; bus.publish_control(ControlMessage { - op: SessionCommand::DeleteDialog { session_id: unified_id }, + op: SessionCommand::DeleteDialog { + session_id: unified_id, + }, reply_tx, - }).await?; + }) + .await?; match reply_rx.recv().await { Some(Ok(SessionEvent::DialogDeleted { session_id })) => { let _ = client .sender - .send(WsOutbound::SessionDeleted { session_id: session_id.to_string() }) + .send(WsOutbound::SessionDeleted { + session_id: session_id.to_string(), + }) .await; // If deleting current session, create a new one if current_session_guard.as_deref() == Some(&target) { drop(reply_rx); - if let Ok(new_id) = self.create_session_via_control(&target, None).await { + if let Ok((new_id, title)) = + self.create_session_via_control(&client.chat_id, None).await + { *current_session_guard = Some(new_id.clone()); let _ = client .sender .send(WsOutbound::SessionCreated { session_id: new_id, - title: String::new(), + title, }) .await; } @@ -388,32 +471,45 @@ impl CliChatChannel { bus.publish_control(ControlMessage { op: SessionCommand::GetSlashCommands { channel: "cli_chat".to_string(), - chat_id: "".to_string(), + chat_id: client.chat_id.clone(), }, reply_tx, - }).await?; + }) + .await?; if let Some(result) = reply_rx.recv().await { match result { Ok(SessionEvent::SlashCommandsList { commands }) => { // Convert to SlashCommand to SlashCommandInfo - let command_infos: Vec = commands.into_iter().map(|cmd| { - SlashCommandInfo { + let command_infos: Vec = commands + .into_iter() + .map(|cmd| SlashCommandInfo { name: cmd.name.to_string(), description: cmd.description.to_string(), aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(), - } - }).collect(); - let _ = client.sender.send(WsOutbound::SlashCommandsList { commands: command_infos }).await; + }) + .collect(); + let _ = client + .sender + .send(WsOutbound::SlashCommandsList { + commands: command_infos, + }) + .await; } Ok(SessionEvent::Error { code, message }) => { - let _ = client.sender.send(WsOutbound::Error { code, message }).await; + let _ = client + .sender + .send(WsOutbound::Error { code, message }) + .await; } Err(e) => { - let _ = client.sender.send(WsOutbound::Error { - code: "GET_COMMANDS_ERROR".to_string(), - message: e.to_string() - }).await; + let _ = client + .sender + .send(WsOutbound::Error { + code: "GET_COMMANDS_ERROR".to_string(), + message: e.to_string(), + }) + .await; } _ => {} } @@ -427,29 +523,34 @@ impl CliChatChannel { } /// Create a session via control message and return the session_id - async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result { + async fn create_session_via_control( + &self, + chat_id: &str, + title: Option<&str>, + ) -> Result<(String, String), ChannelError> { let bus = { let guard = self.bus.lock().unwrap(); - guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? + guard + .clone() + .ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? }; let (reply_tx, mut reply_rx) = mpsc::channel(1); bus.publish_control(ControlMessage { op: SessionCommand::CreateDialog { channel: "cli_chat".to_string(), - chat_id: connection_id.to_string(), + chat_id: chat_id.to_string(), title: title.map(String::from), }, reply_tx, - }).await?; + }) + .await?; match reply_rx.recv().await { - Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => { - Ok(session_id.to_string()) - } - Some(Ok(_)) => { - Err(ChannelError::Other("Unexpected response type".to_string())) + Some(Ok(SessionEvent::DialogCreated { session_id, title })) => { + Ok((session_id.to_string(), title)) } + Some(Ok(_)) => Err(ChannelError::Other("Unexpected response type".to_string())), Some(Err(e)) => Err(e), None => Err(ChannelError::Other("Control channel closed".to_string())), } @@ -479,7 +580,11 @@ impl Channel for CliChatChannel { async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { let clients = self.clients.lock().await.clone(); for client in clients { - let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") { + if client.chat_id != msg.chat_id { + continue; + } + let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") + { WsOutbound::SystemNotification { content: msg.content.clone(), } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index a641827..65ad775 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,19 +1,19 @@ pub mod http; pub mod ws; +use axum::{Router, routing}; use std::sync::Arc; -use axum::{routing, Router}; use tokio::net::TcpListener; use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher}; -use crate::channels::{ChannelManager, CliChatChannel}; use crate::channels::base::{Channel, ChannelError}; -use crate::config::{Config, expand_path, ensure_workspace_dir}; +use crate::channels::{ChannelManager, CliChatChannel}; +use crate::config::{Config, ensure_workspace_dir, expand_path}; use crate::logging; use crate::mcp; use crate::memory::MemoryManager; -use crate::session::SessionManager; use crate::scheduler::Scheduler; +use crate::session::SessionManager; pub struct GatewayState { pub config: Config, @@ -32,8 +32,13 @@ impl GatewayState { let workspace_path = ensure_workspace_dir(&workspace_path)?; // Switch current working directory to workspace - std::env::set_current_dir(&workspace_path) - .map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?; + std::env::set_current_dir(&workspace_path).map_err(|e| { + format!( + "Failed to switch to workspace directory {}: {}", + workspace_path.display(), + e + ) + })?; tracing::info!("Using workspace directory: {}", workspace_path.display()); @@ -52,8 +57,9 @@ impl GatewayState { workspace_path.join("picobot.db") }; let storage = Arc::new( - crate::storage::Storage::new(&db_path).await - .map_err(|e| format!("failed to initialize session storage: {}", e))? + crate::storage::Storage::new(&db_path) + .await + .map_err(|e| format!("failed to initialize session storage: {}", e))?, ); tracing::info!("Session storage: {}", db_path.display()); @@ -98,7 +104,9 @@ impl GatewayState { // Create ChannelManager and init channels let cli_chat_channel = Arc::new(CliChatChannel::new()); let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus); - channel_manager.init(&config, workspace_path.clone()).await + channel_manager + .init(&config, workspace_path.clone()) + .await .map_err(|e| format!("Failed to init channels: {}", e))?; // Register send_message tool with available channel names @@ -107,9 +115,12 @@ impl GatewayState { session_manager.register_outbound_tool(available_channels); // Register chat_manager tool - session_manager.tools().register( - crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()), - ); + session_manager + .tools() + .register(crate::tools::ChatManagerTool::new( + storage.clone(), + valid_channels.clone(), + )); // Initialize MCP servers — connect and register discovered tools if !config.mcp.servers.is_empty() { @@ -130,24 +141,27 @@ impl GatewayState { let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default(); if scheduler_config.enabled { // Register cron tools - session_manager.tools().register( - crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels), - ); - session_manager.tools().register( - crate::tools::cron::CronListTool::new(storage.clone()), - ); - session_manager.tools().register( - crate::tools::cron::CronRemoveTool::new(storage.clone()), - ); - session_manager.tools().register( - crate::tools::cron::CronEnableTool::new(storage.clone()), - ); - session_manager.tools().register( - crate::tools::cron::CronDisableTool::new(storage.clone()), - ); - session_manager.tools().register( - crate::tools::cron::CronUpdateTool::new(storage.clone()), - ); + session_manager + .tools() + .register(crate::tools::cron::CronAddTool::new( + storage.clone(), + valid_channels, + )); + session_manager + .tools() + .register(crate::tools::cron::CronListTool::new(storage.clone())); + session_manager + .tools() + .register(crate::tools::cron::CronRemoveTool::new(storage.clone())); + session_manager + .tools() + .register(crate::tools::cron::CronEnableTool::new(storage.clone())); + session_manager + .tools() + .register(crate::tools::cron::CronDisableTool::new(storage.clone())); + session_manager + .tools() + .register(crate::tools::cron::CronUpdateTool::new(storage.clone())); tracing::info!("Cron tools registered"); } @@ -268,71 +282,103 @@ impl GatewayState { } /// Handle control messages (session management operations) - async fn handle_control_message( - session_manager: &SessionManager, - msg: ControlMessage, - ) { + async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) { use crate::session::{SessionCommand::*, SessionEvent}; let reply_tx = msg.reply_tx; let result: Result = match msg.op { - CreateDialog { channel, chat_id, title } => { - session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await - .map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title }) - .map_err(|e| ChannelError::Other(e.to_string())) - } - ListDialogs { channel, chat_id, include_archived } => { - session_manager.list_dialogs(&channel, &chat_id, include_archived).await - .map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id }) - .map_err(|e| ChannelError::Other(e.to_string())) - } - GetCurrentDialog { channel, chat_id } => { - session_manager.get_current_dialog(&channel, &chat_id).await - .map(|session_id| SessionEvent::CurrentDialog { session_id }) - .map_err(|e| ChannelError::Other(e.to_string())) - } - SwitchDialog { channel, chat_id, dialog_id } => { - session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await - .map(|session_id| SessionEvent::DialogSwitched { session_id }) - .map_err(|e| ChannelError::Other(e.to_string())) - } - RenameDialog { session_id, title } => { - session_manager.rename_dialog(&session_id, &title).await - .map(|()| SessionEvent::DialogRenamed { session_id, title }) - .map_err(|e| ChannelError::Other(e.to_string())) - } - ArchiveDialog { session_id } => { - session_manager.archive_dialog(&session_id) - .map(|()| SessionEvent::DialogArchived { session_id }) - .map_err(|e| ChannelError::Other(e.to_string())) - } - DeleteDialog { session_id } => { - session_manager.delete_dialog(&session_id).await - .map(|()| SessionEvent::DialogDeleted { session_id }) - .map_err(|e| ChannelError::Other(e.to_string())) - } - ClearHistory { session_id } => { - session_manager.clear_dialog_history(&session_id) - .map(|()| SessionEvent::HistoryCleared { session_id }) - .map_err(|e| ChannelError::Other(e.to_string())) - } - GetSlashCommands { channel: _, chat_id: _ } => { + CreateDialog { + channel, + chat_id, + title, + } => session_manager + .create_dialog(&channel, &chat_id, title.as_deref()) + .await + .map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title }) + .map_err(|e| ChannelError::Other(e.to_string())), + ListDialogs { + channel, + chat_id, + include_archived, + } => session_manager + .list_dialogs(&channel, &chat_id, include_archived) + .await + .map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { + dialogs, + current_dialog_id, + }) + .map_err(|e| ChannelError::Other(e.to_string())), + GetCurrentDialog { channel, chat_id } => session_manager + .get_current_dialog(&channel, &chat_id) + .await + .map(|session_id| SessionEvent::CurrentDialog { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())), + SwitchDialog { + channel, + chat_id, + dialog_id, + } => session_manager + .switch_dialog(&channel, &chat_id, &dialog_id) + .await + .map(|session_id| SessionEvent::DialogSwitched { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())), + RenameDialog { session_id, title } => session_manager + .rename_dialog(&session_id, &title) + .await + .map(|()| SessionEvent::DialogRenamed { session_id, title }) + .map_err(|e| ChannelError::Other(e.to_string())), + ArchiveDialog { session_id } => session_manager + .archive_dialog(&session_id) + .await + .map(|()| SessionEvent::DialogArchived { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())), + DeleteDialog { session_id } => session_manager + .delete_dialog(&session_id) + .await + .map(|()| SessionEvent::DialogDeleted { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())), + ClearHistory { session_id } => session_manager + .clear_dialog_history(&session_id) + .await + .map(|()| SessionEvent::HistoryCleared { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())), + GetSlashCommands { + channel: _, + chat_id: _, + } => { let commands = session_manager.get_slash_commands().to_vec(); Ok(SessionEvent::SlashCommandsList { commands }) } - ExecuteSlashCommand { command, args, channel, chat_id, current_session_id } => { - session_manager.execute_slash_command(&command, args.as_deref(), &channel, &chat_id, current_session_id.as_ref()) - .await - .map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { new_session_id: new_id, message: msg }) - .map_err(|e| ChannelError::Other(e.to_string())) - } + ExecuteSlashCommand { + command, + args, + channel, + chat_id, + current_session_id, + } => session_manager + .execute_slash_command( + &command, + args.as_deref(), + &channel, + &chat_id, + current_session_id.as_ref(), + ) + .await + .map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { + new_session_id: new_id, + message: msg, + }) + .map_err(|e| ChannelError::Other(e.to_string())), }; let _ = reply_tx.send(result).await; } } -pub async fn run(host: Option, port: Option) -> Result<(), Box> { +pub async fn run( + host: Option, + port: Option, +) -> Result<(), Box> { // Initialize logging logging::init_logging(); tracing::info!("Starting PicoBot Gateway"); diff --git a/src/session/session.rs b/src/session/session.rs index 3b863ec..5d90819 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -21,12 +21,12 @@ pub enum HandleResult { /// Agent processing spawned in background; response will be sent via bus AgentProcessing, } -use crate::channels::slash_command::parse_slash_command; -use crate::config::LLMProviderConfig; -use crate::config::BrowserConfig; -use crate::agent::{AgentLoop, AgentError, ContextCompressor}; -use crate::agent::system_prompt::build_system_prompt; use crate::agent::context_compressor::ContextCompressionConfig; +use crate::agent::system_prompt::build_system_prompt; +use crate::agent::{AgentError, AgentLoop, ContextCompressor}; +use crate::channels::slash_command::parse_slash_command; +use crate::config::BrowserConfig; +use crate::config::LLMProviderConfig; /// Check if an LLM error message indicates a context window overflow. fn is_context_overflow_error(msg: &str) -> bool { @@ -39,14 +39,14 @@ fn is_context_overflow_error(msg: &str) -> bool { || lower.contains("prompt is too long") || lower.contains("input is too long") } -use crate::providers::{create_provider, LLMProvider}; -use crate::session::session_id::UnifiedSessionId; -use crate::session::events::DialogInfo; -use crate::skills::SkillsLoader; -use crate::tools::{ToolRegistry, create_default_tools}; use crate::bus::MessageBus; +use crate::providers::{LLMProvider, create_provider}; +use crate::session::events::DialogInfo; +use crate::session::session_id::UnifiedSessionId; +use crate::skills::SkillsLoader; use crate::tools::OutboundMessenger; use crate::tools::SendMessageTool; +use crate::tools::{ToolRegistry, create_default_tools}; /// Session = 一个 dialog /// 每个 Session 对应一个 UnifiedSessionId,有独立的 messages history @@ -68,6 +68,7 @@ pub struct Session { storage: Option>, routing_info: String, + archived_at: Option, /// Timestamp (Unix ms) of the last consolidation. /// Messages before this time have been compressed into memory. pub last_consolidated_at: Option, @@ -113,7 +114,12 @@ impl Session { ..Default::default() }; - let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone()); + let mut compressor = ContextCompressor::with_config( + provider.clone(), + provider_config.token_limit, + compressor_config, + memory_manager.clone(), + ); compressor.set_session_id(Some(id.to_string())); let now = chrono::Utc::now().timestamp_millis(); @@ -133,6 +139,7 @@ impl Session { compressor, storage, routing_info, + archived_at: None, last_consolidated_at: None, last_compressed_message_at: None, memory_manager, @@ -150,8 +157,9 @@ impl Session { storage: StdArc, memory_manager: Arc, ) -> Result { - let session_meta = storage.get_session(&id.to_string()).await - .map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?; + let session_meta = storage.get_session(&id.to_string()).await.map_err(|e| { + AgentError::Other(format!("failed to load session from storage: {}", e)) + })?; let mut provider_box = create_provider(provider_config.clone()) .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; @@ -163,7 +171,12 @@ impl Session { ..Default::default() }; - let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone()); + let mut compressor = ContextCompressor::with_config( + provider.clone(), + provider_config.token_limit, + compressor_config, + memory_manager.clone(), + ); compressor.set_session_id(Some(id.to_string())); let mut chat_messages: Vec = Vec::new(); @@ -184,14 +197,15 @@ impl Session { if has_more_timelines { chat_messages.push(ChatMessage::user( "[Earlier conversation summaries exist. \ - Use `timeline_recall` to search if needed.]" + Use `timeline_recall` to search if needed.]", )); } // Insert latest 3 timelines as context (reversed: oldest first) for tl in timelines.iter().take(3).rev() { chat_messages.push(ChatMessage::user(format!( - "[Previous Context]\n{}", tl.content + "[Previous Context]\n{}", + tl.content ))); } @@ -204,53 +218,73 @@ impl Session { Vec::new() }); - let mut tail_msgs: Vec = tail.into_iter().map(|m| { - ChatMessage { + let mut tail_msgs: Vec = tail + .into_iter() + .map(|m| ChatMessage { id: m.id, role: m.role, content: m.content, reasoning_content: m.reasoning_content, - media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(), + media_refs: m + .media_refs + .map(|refs| serde_json::from_str(&refs).unwrap_or_default()) + .unwrap_or_default(), timestamp: m.created_at, tool_call_id: m.tool_call_id, tool_name: m.tool_name, - tool_calls: m.tool_calls - .and_then(|tc| serde_json::from_str::>(&tc).ok()) + tool_calls: m + .tool_calls + .and_then(|tc| { + serde_json::from_str::>(&tc).ok() + }) .filter(|v| !v.is_empty()), source: m.source.and_then(|s| serde_json::from_str(&s).ok()), - } - }).collect(); + }) + .collect(); repair_tool_call_chains(&mut tail_msgs); chat_messages.extend(tail_msgs); } else { // No prior compression — load all messages - let messages = storage.load_messages(&id.to_string(), 0).await - .map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?; + let messages = storage + .load_messages(&id.to_string(), 0) + .await + .map_err(|e| { + AgentError::Other(format!("failed to load messages from storage: {}", e)) + })?; - chat_messages = messages.into_iter().map(|m| { - ChatMessage { + chat_messages = messages + .into_iter() + .map(|m| ChatMessage { id: m.id, role: m.role, content: m.content, reasoning_content: m.reasoning_content, - media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(), + media_refs: m + .media_refs + .map(|refs| serde_json::from_str(&refs).unwrap_or_default()) + .unwrap_or_default(), timestamp: m.created_at, tool_call_id: m.tool_call_id, tool_name: m.tool_name, - tool_calls: m.tool_calls - .and_then(|tc| serde_json::from_str::>(&tc).ok()) + tool_calls: m + .tool_calls + .and_then(|tc| { + serde_json::from_str::>(&tc).ok() + }) .filter(|v| !v.is_empty()), source: m.source.and_then(|s| serde_json::from_str(&s).ok()), - } - }).collect(); + }) + .collect(); repair_tool_call_chains(&mut chat_messages); } // Compress loaded history if it exceeds budget if !chat_messages.is_empty() { - let result = compressor.compress_if_needed(chat_messages).await + let result = compressor + .compress_if_needed(chat_messages) + .await .map_err(|e| AgentError::Other(format!("compression during restore: {}", e)))?; if result.created_timelines { restored_compressed_at = Some(chrono::Utc::now().timestamp_millis()); @@ -281,6 +315,7 @@ impl Session { compressor, storage: Some(storage), routing_info: session_meta.routing_info.unwrap_or_default(), + archived_at: session_meta.archived_at, last_consolidated_at: session_meta.last_consolidated_at, last_compressed_message_at: restored_compressed_at, memory_manager, @@ -297,7 +332,11 @@ impl Session { /// 添加消息到历史并持久化到 Storage /// 如果 `persist` 为 false,只更新内存(用于 compaction 场景) - pub async fn add_message(&mut self, message: ChatMessage, persist: bool) -> Result<(), StorageError> { + pub async fn add_message( + &mut self, + message: ChatMessage, + persist: bool, + ) -> Result<(), StorageError> { let is_user = message.role == "user"; let now = chrono::Utc::now().timestamp_millis(); @@ -306,28 +345,35 @@ impl Session { self.seq_counter += 1; // Persist to Storage - if persist - && let Some(ref storage) = self.storage { - let msg_meta = crate::storage::message::MessageMeta { - id: message.id.clone(), - session_id: self.id.to_string(), - seq, - role: message.role.clone(), - content: message.content.clone(), - reasoning_content: message.reasoning_content.clone(), - media_refs: if message.media_refs.is_empty() { - None - } else { - Some(serde_json::to_string(&message.media_refs).unwrap_or_default()) - }, - tool_call_id: message.tool_call_id.clone(), - tool_name: message.tool_name.clone(), - tool_calls: message.tool_calls.as_ref().and_then(|tc| serde_json::to_string(tc).ok()), - source: message.source.as_ref().map(|s| serde_json::to_string(s).unwrap_or_default()), - created_at: now, - }; - storage.append_message_with_retry(&self.id.to_string(), &msg_meta).await?; - } + if persist && let Some(ref storage) = self.storage { + let msg_meta = crate::storage::message::MessageMeta { + id: message.id.clone(), + session_id: self.id.to_string(), + seq, + role: message.role.clone(), + content: message.content.clone(), + reasoning_content: message.reasoning_content.clone(), + media_refs: if message.media_refs.is_empty() { + None + } else { + Some(serde_json::to_string(&message.media_refs).unwrap_or_default()) + }, + tool_call_id: message.tool_call_id.clone(), + tool_name: message.tool_name.clone(), + tool_calls: message + .tool_calls + .as_ref() + .and_then(|tc| serde_json::to_string(tc).ok()), + source: message + .source + .as_ref() + .map(|s| serde_json::to_string(s).unwrap_or_default()), + created_at: now, + }; + storage + .append_message_with_retry(&self.id.to_string(), &msg_meta) + .await?; + } // Update in-memory state self.messages.push(message); @@ -413,6 +459,7 @@ impl Session { } else { Some(self.routing_info.clone()) }, + archived_at: self.archived_at, deleted_at: None, last_consolidated_at: self.last_consolidated_at, last_compressed_message_at: self.last_compressed_message_at, @@ -438,7 +485,8 @@ impl Session { 历史: {}"#, - self.messages.iter() + self.messages + .iter() .filter(|m| m.role == "user" || m.role == "assistant") .take(20) .map(|m| format!("[{}]: {}", m.role, m.content)) @@ -463,15 +511,16 @@ impl Session { use crate::providers::{ChatCompletionRequest, ChatCompletionResponse, Message}; let request = ChatCompletionRequest { - messages: vec![ - Message::user(prompt.to_string()) - ], + messages: vec![Message::user(prompt.to_string())], temperature: Some(0.3), max_tokens: Some(20), tools: None, }; - let response: ChatCompletionResponse = self.provider.chat(request).await + let response: ChatCompletionResponse = self + .provider + .chat(request) + .await .map_err(|e| AgentError::Other(format!("LLM call failed: {}", e)))?; Ok(response.content.trim().to_string()) @@ -501,7 +550,8 @@ impl Session { self.provider_config.model_id.clone(), self.provider_config.workspace_dir.clone(), self.provider_config.input_types.clone(), - ).with_context_window(self.provider_config.token_limit)) + ) + .with_context_window(self.provider_config.token_limit)) } /// 创建一个附通知通道的 AgentLoop 实例 @@ -567,7 +617,10 @@ impl Session { md.push_str(&format!("- **Chat ID**: `{}`\n", self.id.chat_id)); md.push_str(&format!("- **Dialog ID**: `{}`\n", self.id.dialog_id)); md.push_str(&format!("- **Message Count**: {}\n", self.messages.len())); - md.push_str(&format!("- **Model**: `{}`\n", self.provider_config.model_id)); + md.push_str(&format!( + "- **Model**: `{}`\n", + self.provider_config.model_id + )); md.push_str(&format!("- **Exported At**: {}\n", now)); md.push_str("\n---\n\n"); @@ -584,7 +637,11 @@ impl Session { let timestamp = if msg.timestamp > 0 { DateTime::from_timestamp_millis(msg.timestamp) - .map(|dt| dt.with_timezone(&Local).format("%Y-%m-%d %H:%M:%S").to_string()) + .map(|dt| { + dt.with_timezone(&Local) + .format("%Y-%m-%d %H:%M:%S") + .to_string() + }) .unwrap_or_default() } else { String::new() @@ -632,7 +689,10 @@ impl Session { md.push_str(&format!("- **Chat ID**: `{}`\n", self.id.chat_id)); md.push_str(&format!("- **Dialog ID**: `{}`\n", self.id.dialog_id)); md.push_str(&format!("- **Message Count**: {}\n", self.messages.len())); - md.push_str(&format!("- **Model**: `{}`\n", self.provider_config.model_id)); + md.push_str(&format!( + "- **Model**: `{}`\n", + self.provider_config.model_id + )); md.push_str(&format!("- **Exported At**: {}\n", now)); md.push_str("\n---\n\n"); @@ -656,7 +716,11 @@ impl Session { let timestamp = if msg.timestamp > 0 { DateTime::from_timestamp_millis(msg.timestamp) - .map(|dt| dt.with_timezone(&Local).format("%Y-%m-%d %H:%M:%S").to_string()) + .map(|dt| { + dt.with_timezone(&Local) + .format("%Y-%m-%d %H:%M:%S") + .to_string() + }) .unwrap_or_default() } else { String::new() @@ -712,7 +776,8 @@ fn repair_tool_call_chains(messages: &mut Vec) { } // Collect expected tool call IDs - let expected_ids: std::collections::HashSet<&str> = calls.iter().map(|c| c.id.as_str()).collect(); + let expected_ids: std::collections::HashSet<&str> = + calls.iter().map(|c| c.id.as_str()).collect(); let expected_count = expected_ids.len(); // Check following messages for matching tool results (same tool_call_id) @@ -721,9 +786,10 @@ fn repair_tool_call_chains(messages: &mut Vec) { while j < messages.len() && found < expected_count { if messages[j].role == "tool" { if let Some(ref tc_id) = messages[j].tool_call_id - && expected_ids.contains(tc_id.as_str()) { - found += 1; - } + && expected_ids.contains(tc_id.as_str()) + { + found += 1; + } } else if messages[j].role == "user" || messages[j].role == "assistant" { // Next user/assistant message — stop scanning, chain is broken break; @@ -743,7 +809,11 @@ fn repair_tool_call_chains(messages: &mut Vec) { "{}\n\n[Tool calls ({}): {} — execution interrupted by gateway restart]", old_content, expected_count, - calls.iter().map(|c| c.name.as_str()).collect::>().join(", ") + calls + .iter() + .map(|c| c.name.as_str()) + .collect::>() + .join(", ") ); messages[i].tool_calls = None; } @@ -772,8 +842,6 @@ struct SessionManagerInner { current_sessions: HashMap, } - - /// 斜杠命令定义 #[derive(Debug, Clone)] pub struct SlashCommand { @@ -789,7 +857,9 @@ impl SlashCommand { /// 检查给定内容是否匹配此命令 pub fn matches(&self, content: &str) -> bool { let trimmed = content.trim(); - self.aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias))) + self.aliases + .iter() + .any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias))) } } @@ -869,7 +939,7 @@ impl SessionManager { let tools = Arc::new(create_default_tools( skills_loader.clone(), memory_manager.clone(), - None, // SubAgentManager created below + None, // SubAgentManager created below browser_config.as_ref(), )); @@ -889,11 +959,8 @@ impl SessionManager { let sm_bus = bus.clone(); tokio::spawn(async move { while let Some(notif) = notify_rx.recv().await { - let content = format_task_notification( - ¬if.task_id, - ¬if.status, - ¬if.result_summary, - ); + let content = + format_task_notification(¬if.task_id, ¬if.status, ¬if.result_summary); let outbound = OutboundMessage { channel: notif.channel, chat_id: notif.chat_id, @@ -943,7 +1010,8 @@ impl SessionManager { /// Register the send_message tool (requires self in Arc) pub fn register_outbound_tool(self: &Arc, available_channels: Vec) { let messenger: Arc = self.clone(); - self.tools.register(SendMessageTool::new(messenger, available_channels)); + self.tools + .register(SendMessageTool::new(messenger, available_channels)); } pub fn tools(&self) -> Arc { @@ -961,7 +1029,8 @@ impl SessionManager { self.provider_config.model_id.clone(), self.provider_config.workspace_dir.clone(), self.provider_config.input_types.clone(), - ).with_context_window(self.provider_config.token_limit)) + ) + .with_context_window(self.provider_config.token_limit)) } /// 获取所有可用的斜杠命令 @@ -989,11 +1058,18 @@ impl SessionManager { match cmd.name { "new" => { let title = args.map(|s| s.to_string()); - let (new_id, title) = self.create_session(channel, chat_id, title.as_deref(), String::new()).await?; + let (new_id, title) = self + .create_session(channel, chat_id, title.as_deref(), String::new()) + .await?; Ok((Some(new_id), format!("新对话 '{}' 已创建。", title))) } "delete" => { - let (new_id, _title) = self.create_session(channel, chat_id, None, String::new()).await?; + if let Some(sid) = current_session_id { + self.delete_dialog(sid).await?; + } + let (new_id, _title) = self + .create_session(channel, chat_id, None, String::new()) + .await?; Ok((Some(new_id), "对话已删除。新对话已创建。".to_string())) } "compact" => { @@ -1002,25 +1078,29 @@ impl SessionManager { let mut session_guard = session.lock().await; let original_count = session_guard.get_history().len(); let history = session_guard.get_history().to_vec(); - let result = session_guard.compressor - .compress_if_needed(history) - .await?; + let result = session_guard.compressor.compress_if_needed(history).await?; let compressed_count = result.history.len(); if result.created_timelines { - session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis()); + session_guard.last_compressed_message_at = + Some(chrono::Utc::now().timestamp_millis()); if let Err(e) = session_guard.persist_session_meta().await { tracing::warn!(error = %e, "Failed to persist compression marker after /compact"); } } session_guard.clear_history(); for msg in result.history { - session_guard.add_message(msg, false).await + session_guard + .add_message(msg, false) + .await .map_err(|e| AgentError::Other(format!("persist error: {}", e)))?; } - Ok((None, format!( - "Context compressed: {} → {} messages.", - original_count, compressed_count - ))) + Ok(( + None, + format!( + "Context compressed: {} → {} messages.", + original_count, compressed_count + ), + )) } else { Ok((None, "No active conversation to compress.".to_string())) } @@ -1034,26 +1114,49 @@ impl SessionManager { let session_id_str = session_guard.session_id(); let title = &session_guard.title; let model_name = &session_guard.provider_config.name; - let created_at = chrono::DateTime::from_timestamp_millis(session_guard.created_at) - .map(|dt| dt.with_timezone(&chrono::Local).format("%Y-%m-%d %H:%M:%S").to_string()) - .unwrap_or_default(); - let last_active_at = chrono::DateTime::from_timestamp_millis(session_guard.last_active_at) - .map(|dt| dt.with_timezone(&chrono::Local).format("%Y-%m-%d %H:%M:%S").to_string()) - .unwrap_or_default(); + let created_at = + chrono::DateTime::from_timestamp_millis(session_guard.created_at) + .map(|dt| { + dt.with_timezone(&chrono::Local) + .format("%Y-%m-%d %H:%M:%S") + .to_string() + }) + .unwrap_or_default(); + let last_active_at = + chrono::DateTime::from_timestamp_millis(session_guard.last_active_at) + .map(|dt| { + dt.with_timezone(&chrono::Local) + .format("%Y-%m-%d %H:%M:%S") + .to_string() + }) + .unwrap_or_default(); let token_info = session_guard.compressor.token_info(history); let cache_info = if token_info.cache_active { - format!("API精确: {} tokens", token_info.last_api_tokens.unwrap_or(0)) + format!( + "API精确: {} tokens", + token_info.last_api_tokens.unwrap_or(0) + ) } else { "无API精确缓存".to_string() }; let threshold_pct = if token_info.context_window > 0 { - (token_info.threshold as f64 / token_info.context_window as f64 * 100.0) as usize - } else { 0 }; + (token_info.threshold as f64 / token_info.context_window as f64 * 100.0) + as usize + } else { + 0 + }; let usage_pct = if token_info.context_window > 0 { - (token_info.estimated_tokens as f64 / token_info.context_window as f64 * 100.0).min(100.0) as usize - } else { 0 }; + (token_info.estimated_tokens as f64 / token_info.context_window as f64 + * 100.0) + .min(100.0) as usize + } else { + 0 + }; let usage_bar = if token_info.context_window > 0 { - format!("{}/{} tokens ({}%)", token_info.estimated_tokens, token_info.context_window, usage_pct) + format!( + "{}/{} tokens ({}%)", + token_info.estimated_tokens, token_info.context_window, usage_pct + ) } else { "未设置".to_string() }; @@ -1071,11 +1174,20 @@ impl SessionManager { compression_status, cache_info, ); - Ok((None, format!( - "对话标题: {}\nSession ID: {}\n模型: {}\n用户消息: {} / 总消息: {}\n创建时间: {}\n最后活跃: {}\n\n上下文: {}", - title, session_id_str, model_name, session_guard.message_count, message_count, - created_at, last_active_at, ctx_info, - ))) + Ok(( + None, + format!( + "对话标题: {}\nSession ID: {}\n模型: {}\n用户消息: {} / 总消息: {}\n创建时间: {}\n最后活跃: {}\n\n上下文: {}", + title, + session_id_str, + model_name, + session_guard.message_count, + message_count, + created_at, + last_active_at, + ctx_info, + ), + )) } else { Ok((None, "No active session.".to_string())) } @@ -1089,7 +1201,8 @@ impl SessionManager { let skills_prompt = self.skills_loader.build_skills_prompt(); let system_prompt = session_guard.build_system_prompt(&skills_prompt, None); - let filepath = session_guard.dump_to_file(&system_prompt) + let filepath = session_guard + .dump_to_file(&system_prompt) .map_err(|e| AgentError::Other(format!("Failed to save dump: {}", e)))?; Ok((None, format!("Session dump saved to: {}", filepath))) } else { @@ -1101,24 +1214,43 @@ impl SessionManager { if dialogs.is_empty() { Ok((None, "暂无对话记录。".to_string())) } else { - let lines: Vec = dialogs.iter().map(|d| { - let current = if current_session_id.map(|s| s.dialog_id == d.session_id.dialog_id).unwrap_or(false) { - " [当前]" - } else { - "" - }; - format!("- {} ({}){} — {}", d.session_id.dialog_id, d.title, current, chrono::DateTime::from_timestamp_millis(d.last_active_at).map(|dt| dt.with_timezone(&chrono::Local).format("%m-%d %H:%M").to_string()).unwrap_or_default()) - }).collect(); + let lines: Vec = dialogs + .iter() + .map(|d| { + let current = if current_session_id + .map(|s| s.dialog_id == d.session_id.dialog_id) + .unwrap_or(false) + { + " [当前]" + } else { + "" + }; + format!( + "- {} ({}){} — {}", + d.session_id.dialog_id, + d.title, + current, + chrono::DateTime::from_timestamp_millis(d.last_active_at) + .map(|dt| dt + .with_timezone(&chrono::Local) + .format("%m-%d %H:%M") + .to_string()) + .unwrap_or_default() + ) + }) + .collect(); Ok((None, format!("最近对话:\n{}", lines.join("\n")))) } } "switch" => { - let dialog_id = args.ok_or_else(|| AgentError::Other("Usage: /switch ".to_string()))?; + let dialog_id = args + .ok_or_else(|| AgentError::Other("Usage: /switch ".to_string()))?; let new_id = self.switch_dialog(channel, chat_id, dialog_id).await?; Ok((None, format!("已切换到对话:{}", new_id.dialog_id))) } "rename" => { - let title = args.ok_or_else(|| AgentError::Other("Usage: /rename <新标题>".to_string()))?; + let title = + args.ok_or_else(|| AgentError::Other("Usage: /rename <新标题>".to_string()))?; if let Some(sid) = current_session_id { self.rename_dialog(sid, title).await?; Ok((None, format!("对话已重命名为:{}", title))) @@ -1127,9 +1259,10 @@ impl SessionManager { } } "?" | "help" => { - let lines: Vec = SLASH_COMMANDS.iter().map(|c| { - format!(" {} - {}", c.aliases.join(", "), c.description) - }).collect(); + let lines: Vec = SLASH_COMMANDS + .iter() + .map(|c| format!(" {} - {}", c.aliases.join(", "), c.description)) + .collect(); Ok((None, format!("可用命令:\n{}", lines.join("\n")))) } "mcp" => { @@ -1137,27 +1270,34 @@ impl SessionManager { if servers.is_empty() { return Ok((None, "未配置 MCP 服务。".to_string())); } - let lines: Vec = servers.iter().map(|s| { - let status = if s.connected { - format!("✅ 已连接 ({})", s.transport) - } else { - format!("❌ 连接失败: {}", s.error.as_deref().unwrap_or("未知错误")) - }; - let tool_lines: Vec = s.tools.iter().map(|t| { - let desc = if t.description.is_empty() { - "无描述".to_string() + let lines: Vec = servers + .iter() + .map(|s| { + let status = if s.connected { + format!("✅ 已连接 ({})", s.transport) } else { - t.description.chars().take(60).collect::() + format!("❌ 连接失败: {}", s.error.as_deref().unwrap_or("未知错误")) }; - format!(" - {}: {}", t.name, desc) - }).collect(); - let tools_section = if tool_lines.is_empty() { - String::new() - } else { - format!("\n{}", tool_lines.join("\n")) - }; - format!("{} {}{}", s.name, status, tools_section) - }).collect(); + let tool_lines: Vec = s + .tools + .iter() + .map(|t| { + let desc = if t.description.is_empty() { + "无描述".to_string() + } else { + t.description.chars().take(60).collect::() + }; + format!(" - {}: {}", t.name, desc) + }) + .collect(); + let tools_section = if tool_lines.is_empty() { + String::new() + } else { + format!("\n{}", tool_lines.join("\n")) + }; + format!("{} {}{}", s.name, status, tools_section) + }) + .collect(); Ok((None, format!("MCP 服务:\n\n{}", lines.join("\n\n")))) } "stop" => { @@ -1174,7 +1314,9 @@ impl SessionManager { } guard.worker_generation = guard.worker_generation.wrapping_add(1); // Cancel all running background sub-agent tasks for this session - self.sub_agent_manager.cancel_by_session(&sid.to_string()).await; + self.sub_agent_manager + .cancel_by_session(&sid.to_string()) + .await; let resp = if msgs.is_empty() { "没有正在执行的任务或队列。".to_string() } else { @@ -1182,7 +1324,10 @@ impl SessionManager { }; Ok((None, resp)) } - _ => Err(AgentError::Other(format!("未知命令:/{}。输入 /? 获取帮助。", cmd.name))), + _ => Err(AgentError::Other(format!( + "未知命令:/{}。输入 /? 获取帮助。", + cmd.name + ))), } } @@ -1214,13 +1359,19 @@ impl SessionManager { created_at: now, last_active_at: now, message_count: 0, - routing_info: if routing_info.is_empty() { None } else { Some(routing_info.clone()) }, + routing_info: if routing_info.is_empty() { + None + } else { + Some(routing_info.clone()) + }, + archived_at: None, deleted_at: None, last_consolidated_at: None, last_compressed_message_at: None, }; - self.storage.upsert_session(&meta).await - .map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?; + self.storage.upsert_session(&meta).await.map_err(|e| { + AgentError::Other(format!("failed to create session in storage: {}", e)) + })?; let session = Session::new( unified_id.clone(), @@ -1230,7 +1381,8 @@ impl SessionManager { routing_info, title.clone(), self.memory_manager.clone(), - ).await?; + ) + .await?; let arc = Arc::new(Mutex::new(session)); let inner = &mut *self.inner.lock().await; @@ -1242,7 +1394,10 @@ impl SessionManager { Ok((unified_id, title)) } - pub async fn get_or_create_session(&self, unified_id: &UnifiedSessionId) -> Result>, AgentError> { + pub async fn get_or_create_session( + &self, + unified_id: &UnifiedSessionId, + ) -> Result>, AgentError> { let session_id_str = unified_id.to_string(); let inner = &mut *self.inner.lock().await; @@ -1260,7 +1415,8 @@ impl SessionManager { self.tools.clone(), self.storage.clone(), self.memory_manager.clone(), - ).await?; + ) + .await?; let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(session_id_str.clone(), arc.clone()); @@ -1283,7 +1439,8 @@ impl SessionManager { String::new(), "新对话".to_string(), self.memory_manager.clone(), - ).await?; + ) + .await?; let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(session_id_str.clone(), arc.clone()); @@ -1299,15 +1456,34 @@ impl SessionManager { chat_id: &str, title: Option<&str>, ) -> Result<(UnifiedSessionId, String), AgentError> { - self.create_session(channel, chat_id, title, String::new()).await + self.create_session(channel, chat_id, title, String::new()) + .await } pub async fn get_current_dialog( &self, - _channel: &str, - _chat_id: &str, + channel: &str, + chat_id: &str, ) -> Result, AgentError> { - Ok(None) + let chat_scope = format!("{}:{}", channel, chat_id); + let current = { + self.inner + .lock() + .await + .current_sessions + .get(&chat_scope) + .cloned() + }; + + let Some(current) = current else { + return Ok(None); + }; + + match self.storage.get_session(¤t).await { + Ok(_) => Ok(UnifiedSessionId::parse(¤t)), + Err(StorageError::NotFound(_)) => Ok(None), + Err(e) => Err(AgentError::Other(format!("storage error: {}", e))), + } } pub async fn switch_dialog( @@ -1322,7 +1498,9 @@ impl SessionManager { // Update current session tracking let mut inner = self.inner.lock().await; let chat_scope = format!("{}:{}", channel, chat_id); - inner.current_sessions.insert(chat_scope, unified_id.to_string()); + inner + .current_sessions + .insert(chat_scope, unified_id.to_string()); Ok(unified_id) } @@ -1330,31 +1508,45 @@ impl SessionManager { &self, channel: &str, chat_id: &str, - _include_archived: bool, + include_archived: bool, ) -> Result<(Vec, Option), AgentError> { - let metas = self.storage.list_sessions(channel, chat_id, 10).await + let metas = self + .storage + .list_sessions(channel, chat_id, 10, include_archived) + .await .map_err(|e| AgentError::Other(format!("failed to list dialogs: {}", e)))?; + let current_dialog_id = self + .get_current_dialog(channel, chat_id) + .await? + .map(|sid| sid.dialog_id); - let dialogs: Vec = metas.into_iter().map(|meta| { - DialogInfo { + let dialogs: Vec = metas + .into_iter() + .map(|meta| DialogInfo { session_id: UnifiedSessionId::new(channel, chat_id, &meta.dialog_id), title: meta.title, created_at: meta.created_at, last_active_at: meta.last_active_at, message_count: meta.message_count, - archived_at: None, - } - }).collect(); + archived_at: meta.archived_at, + }) + .collect(); - Ok((dialogs, None)) + Ok((dialogs, current_dialog_id)) } - pub async fn rename_dialog(&self, session_id: &UnifiedSessionId, title: &str) -> Result<(), AgentError> { + pub async fn rename_dialog( + &self, + session_id: &UnifiedSessionId, + title: &str, + ) -> Result<(), AgentError> { // Update in-memory session let session = self.get_or_create_session(session_id).await?; let mut session_guard = session.lock().await; session_guard.title = title.to_string(); - session_guard.persist_session_meta().await + session_guard + .persist_session_meta() + .await .map_err(|e| AgentError::Other(format!("failed to rename dialog: {}", e)))?; Ok(()) } @@ -1363,7 +1555,9 @@ impl SessionManager { let session_id_str = session_id.to_string(); // Soft delete from Storage - self.storage.soft_delete_session(&session_id_str).await + self.storage + .soft_delete_session(&session_id_str) + .await .map_err(|e| AgentError::Other(format!("failed to delete dialog: {}", e)))?; // Remove from memory and current sessions @@ -1375,13 +1569,32 @@ impl SessionManager { Ok(()) } - pub fn archive_dialog(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> { - // Archive concept removed - just return OK + pub async fn archive_dialog(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> { + let session_id_str = session_id.to_string(); + self.storage + .archive_session(&session_id_str) + .await + .map_err(|e| AgentError::Other(format!("failed to archive dialog: {}", e)))?; + + let mut inner = self.inner.lock().await; + inner.sessions.remove(&session_id_str); + let chat_scope = format!("{}:{}", session_id.channel, session_id.chat_id); + if inner + .current_sessions + .get(&chat_scope) + .is_some_and(|id| id == &session_id_str) + { + inner.current_sessions.remove(&chat_scope); + } + Ok(()) } - pub fn clear_dialog_history(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> { - Err(AgentError::Other("clear_dialog_history not available".to_string())) + pub async fn clear_dialog_history( + &self, + session_id: &UnifiedSessionId, + ) -> Result<(), AgentError> { + self.clear_session_history(session_id).await } /// Get or activate a specific session by its full UnifiedSessionId. @@ -1395,9 +1608,10 @@ impl SessionManager { let session_id_str = unified_id.to_string(); match self.storage.get_session(&session_id_str).await { Ok(_) => self.get_or_create_session(unified_id).await, - Err(StorageError::NotFound(_)) => { - Err(AgentError::Other(format!("session not found: {}", unified_id))) - } + Err(StorageError::NotFound(_)) => Err(AgentError::Other(format!( + "session not found: {}", + unified_id + ))), Err(e) => Err(AgentError::Other(format!("storage error: {}", e))), } } @@ -1409,21 +1623,32 @@ impl SessionManager { ) -> Result { let chat_scope = format!("{}:{}", channel, chat_id); let current_id = { - self.inner.lock().await.current_sessions.get(&chat_scope).cloned() + self.inner + .lock() + .await + .current_sessions + .get(&chat_scope) + .cloned() }; if let Some(ref current_id) = current_id - && let Ok(_) = self.storage.get_session(current_id).await { - let parts: Vec<&str> = current_id.split(':').collect(); - if parts.len() == 3 { - return Ok(UnifiedSessionId::new(channel, chat_id, parts[2])); - } + && let Ok(_) = self.storage.get_session(current_id).await + { + if let Some(parsed) = UnifiedSessionId::parse(current_id) { + return Ok(parsed); } + } - match self.storage.find_most_recent_session(channel, chat_id).await { + match self + .storage + .find_most_recent_session(channel, chat_id) + .await + { Ok(Some(meta)) => Ok(UnifiedSessionId::new(channel, chat_id, &meta.dialog_id)), _ => { - let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?; + let (new_id, _) = self + .create_session(channel, chat_id, None, String::new()) + .await?; Ok(new_id) } } @@ -1456,7 +1681,9 @@ impl SessionManager { task_id: task_id.map(|s| s.to_string()), }; let msg = ChatMessage::assistant_with_source(content, source); - guard.add_message(msg, true).await + guard + .add_message(msg, true) + .await .map_err(|e| AgentError::Other(format!("persist error: {}", e)))?; } @@ -1468,7 +1695,9 @@ impl SessionManager { media: vec![], metadata: HashMap::new(), }; - self.bus.publish_outbound(outbound).await + self.bus + .publish_outbound(outbound) + .await .map_err(|e| AgentError::Other(format!("bus publish error: {}", e)))?; Ok(()) @@ -1486,78 +1715,93 @@ impl SessionManager { tracing::debug!(unified_id = %unified_id, "handle_message resolved unified_id"); let session = self.get_or_create_session(&unified_id).await?; - CURRENT_SOURCE_SESSION.scope(Some(unified_id.to_string()), async { - // Check for slash command - if let Some((cmd_name, cmd_args)) = parse_slash_command(content) { - let result = self.execute_slash_command( - cmd_name, - if cmd_args.is_empty() { None } else { Some(cmd_args) }, - channel, - chat_id, - Some(&unified_id), - ).await; + CURRENT_SOURCE_SESSION + .scope(Some(unified_id.to_string()), async { + // Check for slash command + if let Some((cmd_name, cmd_args)) = parse_slash_command(content) { + let result = self + .execute_slash_command( + cmd_name, + if cmd_args.is_empty() { + None + } else { + Some(cmd_args) + }, + channel, + chat_id, + Some(&unified_id), + ) + .await; - return match result { - Ok((_new_session_id, response)) => Ok(HandleResult::CommandOutput(response)), - Err(e) => Ok(HandleResult::CommandOutput(e.to_string())), + return match result { + Ok((_new_session_id, response)) => { + Ok(HandleResult::CommandOutput(response)) + } + Err(e) => Ok(HandleResult::CommandOutput(e.to_string())), + }; + } + + // Normal message: enqueue to per-session worker for serial processing. + let task = AgentTask { + channel: channel.to_string(), + chat_id: chat_id.to_string(), + content: content.to_string(), + media, }; - } - - // Normal message: enqueue to per-session worker for serial processing. - let task = AgentTask { - channel: channel.to_string(), - chat_id: chat_id.to_string(), - content: content.to_string(), - media, - }; - let session_clone = session.clone(); - let unified_str = unified_id.to_string(); - { - let mut guard = session_clone.lock().await; - let needs_spawn = - guard.agent_tx.is_none() || guard.agent_tx.as_ref().is_some_and(|tx| tx.is_closed()); - if needs_spawn { - guard.agent_tx = None; - guard.current_cancel = None; - guard.worker_generation = guard.worker_generation.wrapping_add(1); - let generation = guard.worker_generation; - let (tx, rx) = mpsc::unbounded_channel(); - guard.agent_tx = Some(tx); - spawn_agent_worker( - rx, - session_clone.clone(), - self.bus.clone(), - self.memory_manager.clone(), - self.skills_loader.clone(), - generation, - unified_str.clone(), - ); + let session_clone = session.clone(); + let unified_str = unified_id.to_string(); + { + let mut guard = session_clone.lock().await; + let needs_spawn = guard.agent_tx.is_none() + || guard.agent_tx.as_ref().is_some_and(|tx| tx.is_closed()); + if needs_spawn { + guard.agent_tx = None; + guard.current_cancel = None; + guard.worker_generation = guard.worker_generation.wrapping_add(1); + let generation = guard.worker_generation; + let (tx, rx) = mpsc::unbounded_channel(); + guard.agent_tx = Some(tx); + spawn_agent_worker( + rx, + session_clone.clone(), + self.bus.clone(), + self.memory_manager.clone(), + self.skills_loader.clone(), + generation, + unified_str.clone(), + ); + } + if let Err(e) = guard.agent_tx.as_ref().unwrap().send(task) { + // Worker died after we just spawned it — respawn with the recovered task + let task = e.0; + guard.agent_tx = None; + guard.current_cancel = None; + guard.worker_generation = guard.worker_generation.wrapping_add(1); + let generation = guard.worker_generation; + let (tx, rx) = mpsc::unbounded_channel(); + guard.agent_tx = Some(tx); + spawn_agent_worker( + rx, + session_clone.clone(), + self.bus.clone(), + self.memory_manager.clone(), + self.skills_loader.clone(), + generation, + unified_str.clone(), + ); + guard + .agent_tx + .as_ref() + .unwrap() + .send(task) + .unwrap_or_else(|_| { + tracing::error!("Agent worker spawn+send failed irrecoverably"); + }); + } } - if let Err(e) = guard.agent_tx.as_ref().unwrap().send(task) { - // Worker died after we just spawned it — respawn with the recovered task - let task = e.0; - guard.agent_tx = None; - guard.current_cancel = None; - guard.worker_generation = guard.worker_generation.wrapping_add(1); - let generation = guard.worker_generation; - let (tx, rx) = mpsc::unbounded_channel(); - guard.agent_tx = Some(tx); - spawn_agent_worker( - rx, - session_clone.clone(), - self.bus.clone(), - self.memory_manager.clone(), - self.skills_loader.clone(), - generation, - unified_str.clone(), - ); - guard.agent_tx.as_ref().unwrap().send(task).unwrap_or_else(|_| { - tracing::error!("Agent worker spawn+send failed irrecoverably"); - }); - } - } - Ok(HandleResult::AgentProcessing) - }).await + Ok(HandleResult::AgentProcessing) + }) + .await } } @@ -1889,7 +2133,8 @@ impl SessionManager { - 可以调用其他工具收集信息、处理任务,但最终消息必须通过 send_message 发送\n\ - 只输出最终消息内容,不要输出中间思考过程或分析!" ); - let full_system_prompt = format!("{}\n\n{}\n\n{}", base_prompt, skills_prompt, cron_context); + let full_system_prompt = + format!("{}\n\n{}\n\n{}", base_prompt, skills_prompt, cron_context); let history = vec![ ChatMessage::system(full_system_prompt), @@ -1898,18 +2143,20 @@ impl SessionManager { let agent = self.create_cron_agent()?; let source_session = format!("cron:{}", job_name); - let result = CURRENT_SOURCE_SESSION.scope(Some(source_session), async { - agent.process(history).await - }) - .await - .inspect_err(|e| { - tracing::error!(error = %e, job_id = %job_id, "Cron agent processing error"); - })?; + let result = CURRENT_SOURCE_SESSION + .scope(Some(source_session), async { agent.process(history).await }) + .await + .inspect_err(|e| { + tracing::error!(error = %e, job_id = %job_id, "Cron agent processing error"); + })?; Ok(HandleResult::AgentResponse(result.final_response.content)) } - pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> { + pub async fn clear_session_history( + &self, + unified_id: &UnifiedSessionId, + ) -> Result<(), AgentError> { let session = self.get_or_create_session(unified_id).await?; let mut session_guard = session.lock().await; // Clear in-memory @@ -1917,11 +2164,19 @@ impl SessionManager { session_guard.seq_counter = 1; session_guard.total_message_count = 0; session_guard.message_count = 0; + session_guard.last_consolidated_at = None; + session_guard.last_compressed_message_at = None; // Clear Storage if let Some(ref storage) = session_guard.storage { - storage.clear_messages(&session_guard.id.to_string()).await + storage + .clear_messages(&session_guard.id.to_string()) + .await .map_err(|e| AgentError::Other(format!("failed to clear messages: {}", e)))?; } + session_guard + .persist_session_meta() + .await + .map_err(|e| AgentError::Other(format!("failed to persist cleared session: {}", e)))?; Ok(()) } } @@ -1939,18 +2194,27 @@ impl OutboundMessenger for SessionManager { ) -> Result<(), String> { // Fill origin from current source session if not provided if source.from_session.is_none() { - source.from_session = CURRENT_SOURCE_SESSION.try_with(|v| v.clone()).ok().flatten(); + source.from_session = CURRENT_SOURCE_SESSION + .try_with(|v| v.clone()) + .ok() + .flatten(); } let (target_sid, session) = if let Some(did) = dialog_id { let sid = UnifiedSessionId::new(channel, chat_id, did); - let session = self.get_or_activate_session(&sid).await + let session = self + .get_or_activate_session(&sid) + .await .map_err(|e| e.to_string())?; (sid, session) } else { - let sid = self.resolve_dialog_id(channel, chat_id).await + let sid = self + .resolve_dialog_id(channel, chat_id) + .await .map_err(|e| e.to_string())?; - let session = self.get_or_create_session(&sid).await + let session = self + .get_or_create_session(&sid) + .await .map_err(|e| e.to_string())?; (sid, session) }; @@ -1959,7 +2223,9 @@ impl OutboundMessenger for SessionManager { // Skip prefix for pure file messages to the same session (no cross-session redirect). let origin = source.from_session.as_deref().unwrap_or("unknown"); let origin_id = source.from_session.clone(); - let same_session = source.from_session.as_deref() + let same_session = source + .from_session + .as_deref() .map(|src| src == target_sid.to_string().as_str()) .unwrap_or(false); let marked_content = if content.trim().is_empty() && !media.is_empty() && same_session { @@ -1972,16 +2238,26 @@ impl OutboundMessenger for SessionManager { { let mut guard = session.lock().await; let msg = ChatMessage::assistant_with_source(marked_content.clone(), source); - guard.add_message(msg, true).await + guard + .add_message(msg, true) + .await .map_err(|e| e.to_string())?; } // Restore active dialog if source and target share channel:chat_id but differ in dialog_id if let Some(ref origin_id) = origin_id { let parts: Vec<&str> = origin_id.split(':').collect(); - if parts.len() == 3 && parts[0] == channel && parts[1] == chat_id && parts[2] != target_sid.dialog_id { + if parts.len() == 3 + && parts[0] == channel + && parts[1] == chat_id + && parts[2] != target_sid.dialog_id + { let scope = format!("{}:{}", channel, chat_id); - self.inner.lock().await.current_sessions.insert(scope, origin_id.clone()); + self.inner + .lock() + .await + .current_sessions + .insert(scope, origin_id.clone()); } } @@ -1994,7 +2270,9 @@ impl OutboundMessenger for SessionManager { media, metadata: HashMap::new(), }; - self.bus.publish_outbound(outbound).await + self.bus + .publish_outbound(outbound) + .await .map_err(|e| e.to_string())?; Ok(()) @@ -2026,23 +2304,20 @@ mod tests { } } -fn format_task_notification(task_id: &str, status: &crate::agent::TaskStatus, summary: &str) -> String { +fn format_task_notification( + task_id: &str, + status: &crate::agent::TaskStatus, + summary: &str, +) -> String { match status { crate::agent::TaskStatus::Completed => format!( "📋 后台任务完成\n\n任务 ID: {}\n\n结果:\n{}", task_id, summary ), - crate::agent::TaskStatus::Failed(err) => format!( - "📋 后台任务失败\n\n任务 ID: {}\n错误: {}", - task_id, err - ), - crate::agent::TaskStatus::Cancelled => format!( - "📋 后台任务已取消\n\n任务 ID: {}", - task_id - ), - crate::agent::TaskStatus::TimedOut => format!( - "📋 后台任务超时\n\n任务 ID: {}", - task_id - ), + crate::agent::TaskStatus::Failed(err) => { + format!("📋 后台任务失败\n\n任务 ID: {}\n错误: {}", task_id, err) + } + crate::agent::TaskStatus::Cancelled => format!("📋 后台任务已取消\n\n任务 ID: {}", task_id), + crate::agent::TaskStatus::TimedOut => format!("📋 后台任务超时\n\n任务 ID: {}", task_id), } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 1a08a71..10f4a5a 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,17 +1,17 @@ +pub mod background_task; pub mod error; pub mod memory; pub mod message; -pub mod background_task; pub mod scheduler; pub mod session; -pub use error::StorageError; pub use background_task::BackgroundTask; +pub use error::StorageError; pub use scheduler::{JobRun, ScheduledJob}; use sqlx::{Pool, Row, Sqlite, SqlitePool}; -use tokio::time::{sleep, Duration}; use std::path::Path; +use tokio::time::{Duration, sleep}; pub struct Storage { pub(crate) pool: Pool, @@ -42,6 +42,7 @@ impl Storage { last_active_at INTEGER NOT NULL, message_count INTEGER DEFAULT 0, routing_info TEXT, + archived_at INTEGER, deleted_at INTEGER, last_consolidated_at INTEGER, last_compressed_message_at INTEGER, @@ -92,20 +93,16 @@ impl Storage { .await?; // Migration: add source column if upgrading from older schema - sqlx::query( - r#"ALTER TABLE messages ADD COLUMN source TEXT"#, - ) - .execute(&self.pool) - .await - .ok(); + sqlx::query(r#"ALTER TABLE messages ADD COLUMN source TEXT"#) + .execute(&self.pool) + .await + .ok(); // Migration: add reasoning_content column if upgrading from older schema - sqlx::query( - r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#, - ) - .execute(&self.pool) - .await - .ok(); + sqlx::query(r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#) + .execute(&self.pool) + .await + .ok(); // Background tasks table — for async sub-agent tasks. // Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL). @@ -216,11 +213,19 @@ impl Storage { .await?; // Rebuild FTS5 index for any existing records + sqlx::query("INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')") + .execute(&self.pool) + .await?; + + // Migration: add last_consolidated_at column if not exists sqlx::query( - "INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')", + r#" + ALTER TABLE sessions ADD COLUMN archived_at INTEGER + "#, ) .execute(&self.pool) - .await?; + .await + .ok(); // Migration: add last_consolidated_at column if not exists sqlx::query( @@ -260,7 +265,10 @@ impl Storage { .await?; if let Err(e) = Self::init_scheduler_schema(&self.pool).await { - tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e); + tracing::warn!( + "Failed to init scheduler schema (tables may already exist): {}", + e + ); } Ok(()) @@ -374,16 +382,20 @@ impl Storage { &self.pool } - pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> { + pub async fn upsert_session( + &self, + meta: &crate::storage::session::SessionMeta, + ) -> Result<(), StorageError> { sqlx::query( r#" - INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET title = excluded.title, last_active_at = excluded.last_active_at, message_count = excluded.message_count, routing_info = excluded.routing_info, + archived_at = excluded.archived_at, deleted_at = excluded.deleted_at, last_consolidated_at = excluded.last_consolidated_at, last_compressed_message_at = excluded.last_compressed_message_at @@ -398,6 +410,7 @@ impl Storage { .bind(meta.last_active_at) .bind(meta.message_count) .bind(&meta.routing_info) + .bind(meta.archived_at) .bind(meta.deleted_at) .bind(meta.last_consolidated_at) .bind(meta.last_compressed_message_at) @@ -407,10 +420,13 @@ impl Storage { Ok(()) } - pub async fn get_session(&self, id: &str) -> Result { + pub async fn get_session( + &self, + id: &str, + ) -> Result { let row = sqlx::query( r#" - SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at + SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at FROM sessions WHERE id = ? AND deleted_at IS NULL "#, ) @@ -429,6 +445,7 @@ impl Storage { last_active_at: row.get("last_active_at"), message_count: row.get("message_count"), routing_info: row.get("routing_info"), + archived_at: row.get("archived_at"), deleted_at: row.get("deleted_at"), last_consolidated_at: row.get("last_consolidated_at"), last_compressed_message_at: row.get("last_compressed_message_at"), @@ -440,18 +457,21 @@ impl Storage { channel: &str, chat_id: &str, limit: i64, + include_archived: bool, ) -> Result, StorageError> { let rows = sqlx::query( r#" - SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at + SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at FROM sessions WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL + AND (? OR archived_at IS NULL) ORDER BY last_active_at DESC LIMIT ? "#, ) .bind(channel) .bind(chat_id) + .bind(include_archived) .bind(limit) .fetch_all(self.pool()) .await?; @@ -468,6 +488,7 @@ impl Storage { last_active_at: row.get("last_active_at"), message_count: row.get("message_count"), routing_info: row.get("routing_info"), + archived_at: row.get("archived_at"), deleted_at: row.get("deleted_at"), last_consolidated_at: row.get("last_consolidated_at"), last_compressed_message_at: row.get("last_compressed_message_at"), @@ -498,13 +519,22 @@ impl Storage { pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> { let now = chrono::Utc::now().timestamp_millis(); - sqlx::query( - r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#, - ) - .bind(now) - .bind(id) - .execute(self.pool()) - .await?; + sqlx::query(r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#) + .bind(now) + .bind(id) + .execute(self.pool()) + .await?; + + Ok(()) + } + + pub async fn archive_session(&self, id: &str) -> Result<(), StorageError> { + let now = chrono::Utc::now().timestamp_millis(); + sqlx::query(r#"UPDATE sessions SET archived_at = ? WHERE id = ? AND deleted_at IS NULL"#) + .bind(now) + .bind(id) + .execute(self.pool()) + .await?; Ok(()) } @@ -516,9 +546,9 @@ impl Storage { ) -> Result, StorageError> { let row = sqlx::query( r#" - SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at + SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at FROM sessions - WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL + WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND archived_at IS NULL ORDER BY last_active_at DESC LIMIT 1 "#, @@ -539,6 +569,7 @@ impl Storage { last_active_at: row.get("last_active_at"), message_count: row.get("message_count"), routing_info: row.get("routing_info"), + archived_at: row.get("archived_at"), deleted_at: row.get("deleted_at"), last_consolidated_at: row.get("last_consolidated_at"), last_compressed_message_at: row.get("last_compressed_message_at"), @@ -547,7 +578,11 @@ impl Storage { } } - pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result { + pub async fn append_message( + &self, + session_id: &str, + msg: &crate::storage::message::MessageMeta, + ) -> Result { sqlx::query( r#" INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at) @@ -674,16 +709,15 @@ impl Storage { offset: i64, limit: i64, ) -> Result<(Vec, i64), StorageError> { - let count_row = sqlx::query( - "SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL", - ) - .fetch_one(self.pool()) - .await?; + let count_row = + sqlx::query("SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL") + .fetch_one(self.pool()) + .await?; let total: i64 = count_row.get("total"); let rows = sqlx::query( r#" - SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at + SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at FROM sessions WHERE deleted_at IS NULL ORDER BY last_active_at DESC @@ -707,6 +741,7 @@ impl Storage { last_active_at: row.get("last_active_at"), message_count: row.get("message_count"), routing_info: row.get("routing_info"), + archived_at: row.get("archived_at"), deleted_at: row.get("deleted_at"), last_consolidated_at: row.get("last_consolidated_at"), last_compressed_message_at: row.get("last_compressed_message_at"), @@ -772,7 +807,10 @@ impl Storage { where_extra.push_str(" AND created_at > ?"); } - let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra); + let count_sql = format!( + "SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", + where_extra + ); let select_sql = format!( r#" SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at @@ -1030,6 +1068,7 @@ mod tests { last_active_at: 1000, message_count: 0, routing_info: Some(r#"{"type":"cli"}"#.to_string()), + archived_at: None, deleted_at: None, last_consolidated_at: None, last_compressed_message_at: None, @@ -1066,14 +1105,18 @@ mod tests { last_active_at: i as i64 * 1000, message_count: i, routing_info: None, + archived_at: None, deleted_at: None, - last_consolidated_at: None, - last_compressed_message_at: None, + last_consolidated_at: None, + last_compressed_message_at: None, }; storage.upsert_session(&meta).await.unwrap(); } - let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap(); + let sessions = storage + .list_sessions("cli_chat", "sid123", 10, false) + .await + .unwrap(); assert_eq!(sessions.len(), 5); // 按 last_active_at DESC 排序 assert_eq!(sessions[0].dialog_id, "dialog4"); @@ -1093,6 +1136,7 @@ mod tests { last_active_at: 1000, message_count: 0, routing_info: None, + archived_at: None, deleted_at: None, last_consolidated_at: None, last_compressed_message_at: None, @@ -1120,6 +1164,7 @@ mod tests { last_active_at: 1000, message_count: 0, routing_info: None, + archived_at: None, deleted_at: None, last_consolidated_at: None, last_compressed_message_at: None, @@ -1141,7 +1186,10 @@ mod tests { created_at: 1000, }; - let seq = storage.append_message(&session_meta.id, &msg).await.unwrap(); + let seq = storage + .append_message(&session_meta.id, &msg) + .await + .unwrap(); assert_eq!(seq, 1); let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap(); @@ -1163,6 +1211,7 @@ mod tests { last_active_at: 1000, message_count: 0, routing_info: None, + archived_at: None, deleted_at: None, last_consolidated_at: None, last_compressed_message_at: None, diff --git a/src/storage/session.rs b/src/storage/session.rs index b9a1408..25e05d7 100644 --- a/src/storage/session.rs +++ b/src/storage/session.rs @@ -11,6 +11,7 @@ pub struct SessionMeta { pub last_active_at: i64, pub message_count: i64, pub routing_info: Option, + pub archived_at: Option, pub deleted_at: Option, pub last_consolidated_at: Option, pub last_compressed_message_at: Option, diff --git a/src/tools/chat_manager.rs b/src/tools/chat_manager.rs index 6737115..ff465d3 100644 --- a/src/tools/chat_manager.rs +++ b/src/tools/chat_manager.rs @@ -126,7 +126,10 @@ impl ChatManagerTool { let start_num = offset + 1; let end_num = offset + sessions.len() as i64; - let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num); + let mut output = format!( + "全部会话 (共 {} 个,第 {}-{} 个):\n", + total, start_num, end_num + ); for s in &sessions { let ago = format_duration_ago(now_ms - s.last_active_at); @@ -300,9 +303,10 @@ mod tests { last_active_at: now - i * 3600_000, message_count: i * 5, routing_info: None, + archived_at: None, deleted_at: None, - last_consolidated_at: None, - last_compressed_message_at: None, + last_consolidated_at: None, + last_compressed_message_at: None, }; storage.upsert_session(&meta).await.unwrap(); } @@ -335,6 +339,7 @@ mod tests { last_active_at: now, message_count: 3, routing_info: None, + archived_at: None, deleted_at: None, last_consolidated_at: None, last_compressed_message_at: None, @@ -346,7 +351,11 @@ mod tests { id: format!("msg{}", i), session_id: session_id.to_string(), seq: i as i64 + 1, - role: if i == 0 { "user".to_string() } else { "assistant".to_string() }, + role: if i == 0 { + "user".to_string() + } else { + "assistant".to_string() + }, content: format!("消息内容 {}", i), reasoning_content: None, media_refs: None, @@ -392,6 +401,7 @@ mod tests { last_active_at: now, message_count: 5, routing_info: None, + archived_at: None, deleted_at: None, last_consolidated_at: None, last_compressed_message_at: None, @@ -403,7 +413,11 @@ mod tests { id: format!("msg{}", i), session_id: session_id.to_string(), seq: i as i64 + 1, - role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() }, + role: if i % 2 == 0 { + "user".to_string() + } else { + "assistant".to_string() + }, content: format!("消息内容 {}", i), reasoning_content: None, media_refs: None, @@ -447,6 +461,7 @@ mod tests { last_active_at: now, message_count: 5, routing_info: None, + archived_at: None, deleted_at: None, last_consolidated_at: None, last_compressed_message_at: None, @@ -492,10 +507,7 @@ mod tests { let (storage, _dir) = create_test_storage().await; let tool = ChatManagerTool::new(storage, vec![]); - let result = tool - .execute(json!({ "action": "unknown" })) - .await - .unwrap(); + let result = tool.execute(json!({ "action": "unknown" })).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("Unknown action")); } diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index d73ce37..750fc4a 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -1,5 +1,5 @@ -use picobot::providers::{ChatCompletionRequest, Message}; use picobot::protocol::{SessionSummary, WsInbound, WsOutbound}; +use picobot::providers::{ChatCompletionRequest, Message}; /// Test that message with special characters is properly escaped #[test] @@ -19,7 +19,9 @@ fn test_message_special_characters() { #[test] fn test_multiline_system_prompt() { let messages = vec![ - Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"), + Message::system( + "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate", + ), Message::user("Hi"), ]; @@ -33,10 +35,7 @@ fn test_multiline_system_prompt() { #[test] fn test_chat_request_serialization() { let request = ChatCompletionRequest { - messages: vec![ - Message::system("You are helpful"), - Message::user("Hello"), - ], + messages: vec![Message::system("You are helpful"), Message::user("Hello")], temperature: Some(0.7), max_tokens: Some(100), tools: None,