From 2e13f6932ca73c02ea87a33732fd0fdd753a6b29 Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Fri, 15 May 2026 15:01:58 +0800 Subject: [PATCH] feat: Enhance session management with topic support - Added topic management capabilities, allowing users to create, switch, and query topics within sessions. - Updated command structure to include new commands: SwitchSession and GetCurrentSession. - Introduced TopicRecord for managing topic data in the storage layer. - Modified session handlers to accommodate topic operations, including listing and loading topics. - Enhanced database schema to support topics, including new tables and relationships. - Updated input adapters to recognize new commands and handle topic-related actions. - Improved logging for session and topic operations to aid in debugging and monitoring. --- README.md | 272 +++++++++++++--------- src/command/adapters/channel.rs | 9 +- src/command/adapters/cli.rs | 9 +- src/command/context.rs | 18 ++ src/command/handlers/save_session.rs | 28 +++ src/command/handlers/session.rs | 74 +++--- src/command/handlers/session_query.rs | 313 +++++++++++++++++++------- src/command/mod.rs | 14 +- src/gateway/processor.rs | 17 +- src/gateway/ws.rs | 25 +- src/storage/mod.rs | 204 ++++++++++++++++- src/storage/records.rs | 12 + 12 files changed, 756 insertions(+), 239 deletions(-) diff --git a/README.md b/README.md index c25ed5c..0b786fe 100644 --- a/README.md +++ b/README.md @@ -4,14 +4,15 @@ PicoBot 是一个用 Rust 构建的多通道 Agent 网关。它把消息接入 当前代码库已经实现以下核心能力: -- 基于 Gateway 的统一消息入口,支持 WebSocket CLI 与飞书通道 +- 基于 Gateway 的统一消息入口,支持 WebSocket CLI、飞书通道和微信通道 - 面向工具调用的 Agent 循环,支持多轮 tool calling - SQLite 持久化会话、消息、长期记忆、技能事件和调度任务 - 基于用户维度的长期记忆检索与写入机制 -- 基于 SKILL.md 的项目级 / 用户级技能加载与运行时管理 +- 基于 SKILL.md 的项目级 / 用户级 / OpenClaw 级技能加载与运行时管理 - 定时任务系统,支持延迟、周期、绝对时间和 cron 调度 - 超长上下文压缩与历史摘要 - 持久化 Agent 配置文件注入与周期性重注入 +- 会话管理支持按通道查询和切换 ## 1. 项目定位 @@ -37,110 +38,116 @@ PicoBot 的设计目标不是“只会聊天”的单进程 Bot,而是一个 ### 2.1 消息流转图 -```mermaid -sequenceDiagram - autonumber - participant U as User / External Chat - participant C as Channel - participant B as MessageBus - participant P as InboundProcessor - participant SM as SessionManager - participant SES as Session - participant AES as AgentExecutionService - participant AL as AgentLoop - participant T as ToolRegistry - participant DB as SQLite - participant OD as OutboundDispatcher +```text +用户消息输入 + │ + ▼ +┌─────────────┐ +│ Channel │ ◄── WebSocket CLI / 飞书等 +└──────┬──────┘ + │ publish_inbound + ▼ +┌─────────────┐ +│ MessageBus │ ◄── 统一消息总线 +└──────┬──────┘ + │ consume_inbound + ▼ +┌──────────────────┐ +│ InboundProcessor │ ◄── 入站处理器 +└────────┬─────────┘ + │ handle_message + ▼ +┌────────────────┐ +│ SessionManager │ ◄── 会话管理器 +└────────┬───────┘ + │ 定位/创建 Session + ▼ +┌────────────────────────┐ +│ AgentExecutionService │ ◄── 准备上下文、执行 Agent +└────────┬───────────────┘ + │ + ┌────┴────┐ + ▼ ▼ +┌───────┐ ┌──────────┐ +│SQLite │ │ AgentLoop│ ◄── 调用 LLM、工具执行 +└───────┘ └────┬─────┘ + ▲ │ + │ ▼ + │ ┌──────────┐ + └────┤ToolRegistry│ ◄── 工具注册表 + └──────────┘ + │ + ▼ + ┌─────────────────┐ + │ OutboundDispatcher │ ◄── 出站分发 + └────────┬────────┘ + │ dispatch + ▼ + ┌─────────────────┐ + │ Channel │ + └────────┬────────┘ + │ 最终回复 + ▼ + 用户 - U->>C: 输入消息 - C->>B: publish_inbound - B->>P: consume_inbound - P->>SM: handle_message(channel, sender, chat, content) - SM->>SES: active_session(channel) - AES->>SES: ensure_persistent_session / ensure_chat_loaded - AES->>DB: 追加 user / system 消息 - AES->>AL: process(history) - AL->>T: 调用工具 - T-->>AL: 返回 tool result - AL-->>AES: emitted_messages - AES->>DB: 按真实顺序持久化 assistant / tool / assistant - AES->>SES: 安排后台历史压缩 - AES-->>SM: outbound messages - SM-->>P: outbound messages - P->>B: publish_outbound - B->>OD: consume_outbound - OD->>C: dispatch - C-->>U: 最终回复 +关键步骤说明: +1. Channel 接收外部消息 → 2. MessageBus 统一路由 → 3. InboundProcessor 处理 +4. SessionManager 定位 Session → 5. AgentExecutionService 执行 +6. 消息持久化到 SQLite → 7. AgentLoop 推理/工具调用 +8. 结果经 OutboundDispatcher 返回 Channel ``` ### 2.2 项目架构图 -```mermaid -flowchart TB - subgraph Edge[接入层] - CLI[CLI Client / WebSocket] - FEI[Feishu Channel] - HTTP[HTTP Health / WS Gateway] - end +```text +接入层 (Edge) +│ +├── CLI Client / WebSocket ──┐ +├── Feishu Channel ─────────┼──► 网关与运行时编排 (Gateway) +└── HTTP Health / WS ────────┘ │ + ├── ChannelManager ◄──┐ + ├── MessageBus ────────┼── 双向通信 + ├── InboundProcessor ──┘ + ├── OutboundDispatcher + ├── SessionManager + ├── SessionLifecycle / Message / ScheduledTask Services + └── AgentExecutionService ◄── 调用 Agent 执行层 + │ +Agent 执行层 (Agent) ◄─────────────────────────────────────────┘ +│ +├── AgentLoop ──► ToolRegistry (工具调用) +├── ContextCompressor (上下文压缩) +├── SkillRuntime (技能系统) +└── LLM Providers (OpenAI / Anthropic) - subgraph Gateway[网关与运行时编排] - CM[ChannelManager] - BUS[MessageBus] - IP[InboundProcessor] - OD[OutboundDispatcher] - SSM[SessionManager] - SVC[SessionLifecycle / Message / ScheduledTask Services] - AES[AgentExecutionService] - end +持久化与后台能力 (Runtime) +│ +├── SessionStore / SQLite (会话/消息/记忆存储) +├── Scheduler (定时任务调度) +└── Memory Maintenance (记忆维护) - subgraph Agent[Agent 执行层] - LOOP[AgentLoop] - COMP[ContextCompressor] - SK[SkillRuntime] - TOOLS[ToolRegistry] - PROV[LLM Providers] - end - - subgraph Runtime[持久化与后台能力] - STORE[SessionStore / SQLite] - SCH[Scheduler] - MEM[Memory Maintenance] - end - - CLI --> HTTP - FEI --> CM - HTTP --> CM - CM --> BUS - BUS --> IP - IP --> SSM - SSM --> SVC - SVC --> AES - AES --> LOOP - LOOP --> TOOLS - LOOP --> SK - LOOP --> PROV - AES --> COMP - SSM --> STORE - AES --> STORE - SCH --> BUS - SCH --> SSM - MEM --> STORE - BUS --> OD - OD --> CM +数据流向: +- 接入层 ◄──► 网关 ◄──► Agent 执行层 +- 网关 ◄──► 持久化层 (SQLite) +- Scheduler ◄──► 总线 ◄──► SessionManager ``` 主要模块如下: -- src/gateway:网关生命周期、InboundProcessor、OutboundDispatcher、SessionManager,以及消息执行、调度任务执行、Prompt 注入、历史压缩和记忆维护编排 +- src/agent:AgentLoop、上下文压缩器、运行时配置、系统提示构建 - src/bus:消息总线队列与消息结构定义,不包含渠道投递逻辑 -- src/agent:AgentLoop 与上下文压缩器 +- src/channels:渠道适配层,当前已有 CLI、飞书、微信通道 +- src/cli:本地 CLI 客户端、输入命令解析 +- src/client:WebSocket CLI 客户端实现 +- src/command:命令系统,包括处理器、适配器、上下文和响应处理 +- src/config:配置解析与默认值定义 +- src/domain:领域模型,包含消息和工具定义 +- src/gateway:网关生命周期、InboundProcessor、OutboundDispatcher、SessionManager,以及消息执行、调度任务执行、Prompt 注入、历史压缩和记忆维护编排 - src/providers:不同 LLM Provider 的统一抽象,当前支持 openai 和 anthropic - src/tools:内置工具集合与 ToolRegistry - src/storage:SQLite 持久化实现 -- src/channels:渠道适配层,当前已有 CLI 与飞书通道 - src/scheduler:数据库驱动的计划任务调度器 - src/skills:技能发现、加载与运行时管理 -- src/client / src/cli:本地 CLI 客户端、输入命令解析与会话交互 - src/protocol:WebSocket 入站 / 出站协议结构 ## 3. 消息机制 @@ -379,8 +386,10 @@ PicoBot 支持基于文件系统的技能系统,用来给 Agent 注入某一 - 用户级技能:~/.picobot/skills/*/SKILL.md - 用户 Agent 级技能:~/.agents/skills/*/SKILL.md +- 用户 OpenClaw 级技能:~/.openclaw/skills/*/SKILL.md - 项目级技能:.picobot/skills/*/SKILL.md - 项目 Agent 级技能:.agents/skills/*/SKILL.md +- 项目 OpenClaw 级技能:.openclaw/skills/*/SKILL.md ### 7.2 最小 SKILL.md 格式 @@ -459,7 +468,7 @@ skills 配置示例: { "skills": { "enabled": true, - "sources": ["user", "user_agent", "project", "project_agent"], + "sources": ["user", "user_agent", "user_openclaw", "project", "project_agent", "project_openclaw"], "max_index_chars": 4000, "max_listed_skills": 32 } @@ -654,6 +663,7 @@ silent_agent_task 和 agent_task 使用同一套 Agent 执行能力,但路由 - WebSocket CLI 客户端 - 飞书通道 +- 微信通道 ### 10.2 Gateway 接口 @@ -678,13 +688,15 @@ cargo run -- agent CLI 中已实现的交互命令包括: -- /new [title] -- /reset -- /sessions -- /use -- /rename -- /archive -- /delete +- /new [title] - 创建新会话 +- /reset - 重置当前会话上下文 +- /sessions - 列出当前通道的所有会话(支持跨通道隔离) +- /use <session> - 切换到指定会话 +- /rename <title> - 重命名当前会话 +- /archive - 归档当前会话 +- /delete - 删除指定会话 +- /clear - 清屏 +- /quit - 退出 CLI - /clear - /quit @@ -707,7 +719,8 @@ CLI 中已实现的交互命令包括: "base_url": "<OPENAI_BASE_URL>", "api_key": "<OPENAI_API_KEY>", "extra_headers": {}, - "llm_timeout_secs": 120 + "llm_timeout_secs": 120, + "memory_maintenance_timeout_secs": 600 } }, "models": { @@ -744,12 +757,12 @@ CLI 中已实现的交互命令包括: 常用配置项: -- providers:Provider 连接信息 +- providers:Provider 连接信息,包含 llm_timeout_secs(LLM 调用超时,默认 120 秒)和 memory_maintenance_timeout_secs(记忆维护超时,默认 600 秒) - models:模型参数,包括上下文窗口估算所用的 context_window_tokens - agents:Agent 级别的工具轮次、工具结果裁剪与上下文裁剪参数 - gateway:监听地址、端口、工具结果展示、会话 TTL、Prompt 重注入策略 - scheduler:调度器开关、worker 队列容量、误触发策略和任务列表 -- channels:飞书等通道配置 +- channels:飞书、微信等通道配置 - skills:技能来源与索引限制 - tools:工具启用/禁用配置(通过 disabled 列表指定禁用的工具) - time.timezone:时区,默认应使用 IANA 时区名,例如 Asia/Shanghai @@ -761,6 +774,46 @@ CLI 中已实现的交互命令包括: 1. 复制并修改 config.json,或把配置放到 ~/.picobot/config.json 2. 配置好 Provider 的 base_url、api_key、model_id 3. 如果要接飞书,再补充 channels.feishu 配置 +4. 如果要接微信,再补充 channels.wechat 配置 + +飞书通道配置示例: + +```json +{ + "channels": { + "feishu": { + "type": "feishu", + "app_id": "<FEISHU_APP_ID>", + "app_secret": "<FEISHU_APP_SECRET>", + "enabled": true, + "allow_from": ["*"], + "agent": "default", + "media_dir": "~/.picobot/media/feishu", + "reaction_emoji": "Typing", + "max_message_chars": 20000, + "reply_context_max_chars": 20000 + } + } +} +``` + +微信通道配置示例: + +```json +{ + "channels": { + "wechat": { + "type": "wechat", + "enabled": true, + "allow_from": ["*"], + "agent": "default", + "base_url": "https://ilinkai.weixin.qq.com", + "cred_path": "~/.picobot/wechat/credentials.json", + "force_login": false + } + } +} +``` ### 12.2 启动网关 @@ -797,17 +850,23 @@ curl http://127.0.0.1:19876/health ```text PicoBot/ ├── src/ -│ ├── agent/ # AgentLoop、上下文压缩 +│ ├── agent/ # AgentLoop、上下文压缩、运行时配置、系统提示 │ ├── bus/ # 消息总线与消息结构 -│ ├── channels/ # CLI / 飞书等渠道适配 -│ ├── cli/ # CLI 输入命令 +│ ├── channels/ # CLI / 飞书 / 微信等渠道适配 +│ ├── cli/ # CLI 输入命令与通道实现 │ ├── client/ # WebSocket CLI 客户端 +│ ├── command/ # 命令系统(处理器、适配器、上下文) │ ├── config/ # 配置解析 -│ ├── gateway/ # Gateway、Session 编排、WS/HTTP 控制面 +│ ├── domain/ # 领域模型(消息、工具定义) +│ ├── gateway/ # Gateway、Session 编排、WS/HTTP 控制面、执行服务 +│ ├── logging/ # 日志配置 +│ ├── observability/ # 可观测性支持 +│ ├── platform/ # 平台抽象 │ ├── providers/ # OpenAI / Anthropic Provider │ ├── scheduler/ # 定时任务系统 │ ├── skills/ # 技能运行时 -│ ├── storage/ # SQLite 持久化 +│ ├── storage/ # SQLite 持久化(存储、端口、记录、错误) +│ ├── text/ # 文本处理工具 │ └── tools/ # 内置工具集合 ├── docs/ │ ├── IMPLEMENTATION_LOG.md @@ -824,8 +883,11 @@ PicoBot/ - docs/PERSISTENCE.md:持久化结构是否与代码一致 - src/gateway/session.rs:会话状态、会话路由和运行时服务编排 +- src/gateway/execution.rs:Agent 执行服务 - src/storage/mod.rs:SQLite schema 变更 - src/config/mod.rs:配置项变更是否同步到 README +- src/bus/message.rs:消息结构变更(如 OutboundMessage 新增 session_id) +- src/command/handlers/:命令处理器实现 ## 15. 总结 diff --git a/src/command/adapters/channel.rs b/src/command/adapters/channel.rs index cd9e55b..44fccf1 100644 --- a/src/command/adapters/channel.rs +++ b/src/command/adapters/channel.rs @@ -78,14 +78,19 @@ impl InputAdapter for ChannelInputAdapter { })); } - // 解析 /use 命令 + // 解析 /use 命令 - 切换会话(支持 session_id 或序号) if let Some(session_id) = trimmed.strip_prefix("/use ") { let session_id = session_id.trim(); - return Ok(Some(Command::LoadSession { + return Ok(Some(Command::SwitchSession { session_id: session_id.to_string(), })); } + // 解析 /current 命令 - 获取当前会话信息 + if trimmed == "/current" { + return Ok(Some(Command::GetCurrentSession)); + } + // 不是命令,返回 None Ok(None) } diff --git a/src/command/adapters/cli.rs b/src/command/adapters/cli.rs index 809286b..f279bff 100644 --- a/src/command/adapters/cli.rs +++ b/src/command/adapters/cli.rs @@ -79,14 +79,19 @@ impl InputAdapter for CliInputAdapter { })); } - // 解析 /use 命令 + // 解析 /use 命令 - 切换会话(支持 session_id 或序号) if let Some(session_id) = trimmed.strip_prefix("/use ") { let session_id = session_id.trim(); - return Ok(Some(Command::LoadSession { + return Ok(Some(Command::SwitchSession { session_id: session_id.to_string(), })); } + // 解析 /current 命令 - 获取当前会话信息 + if trimmed == "/current" { + return Ok(Some(Command::GetCurrentSession)); + } + // 不是命令,返回 None Ok(None) } diff --git a/src/command/context.rs b/src/command/context.rs index 9602bfa..7687df5 100644 --- a/src/command/context.rs +++ b/src/command/context.rs @@ -8,6 +8,8 @@ pub struct CommandContext { pub request_id: Uuid, /// 当前会话ID pub session_id: Option<String>, + /// 当前话题ID + pub topic_id: Option<String>, /// 当前聊天ID pub chat_id: Option<String>, /// 发送者ID @@ -24,6 +26,7 @@ impl CommandContext { Self { request_id: Uuid::new_v4(), session_id: None, + topic_id: None, chat_id: None, sender_id: sender_id.into(), channel_name: channel_name.into(), @@ -37,6 +40,12 @@ impl CommandContext { self } + /// 设置话题ID + pub fn with_topic_id(mut self, topic_id: impl Into<String>) -> Self { + self.topic_id = Some(topic_id.into()); + self + } + /// 设置聊天ID pub fn with_chat_id(mut self, chat_id: impl Into<String>) -> Self { self.chat_id = Some(chat_id.into()); @@ -55,6 +64,8 @@ impl CommandContext { pub struct AdapterContext { /// 当前会话ID pub session_id: Option<String>, + /// 当前话题ID + pub topic_id: Option<String>, /// 当前聊天ID pub chat_id: Option<String>, /// 发送者ID @@ -66,6 +77,7 @@ impl AdapterContext { pub fn new(sender_id: impl Into<String>) -> Self { Self { session_id: None, + topic_id: None, chat_id: None, sender_id: sender_id.into(), } @@ -77,6 +89,12 @@ impl AdapterContext { self } + /// 设置话题ID + pub fn with_topic_id(mut self, topic_id: impl Into<String>) -> Self { + self.topic_id = Some(topic_id.into()); + self + } + /// 设置聊天ID pub fn with_chat_id(mut self, chat_id: impl Into<String>) -> Self { self.chat_id = Some(chat_id.into()); diff --git a/src/command/handlers/save_session.rs b/src/command/handlers/save_session.rs index 75c2c88..2f77192 100644 --- a/src/command/handlers/save_session.rs +++ b/src/command/handlers/save_session.rs @@ -128,11 +128,39 @@ async fn handle_save_session( include_all: bool, ctx: CommandContext, ) -> Result<CommandResponse, CommandError> { + tracing::debug!( + ctx_session_id = ?ctx.session_id, + ctx_chat_id = ?ctx.chat_id, + channel = %ctx.channel_name, + "SaveSession command received" + ); + let session_id = ctx .session_id .as_deref() .ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?; + tracing::debug!(session_id = %session_id, "Attempting to save session"); + + // 先检查会话是否存在 + match handler.store.get_session(session_id) { + Ok(Some(record)) => { + tracing::debug!( + session_id = %session_id, + title = %record.title, + chat_id = %record.chat_id, + message_count = record.message_count, + "Session found for saving" + ); + } + Ok(None) => { + tracing::warn!(session_id = %session_id, "Session not found in store"); + } + Err(e) => { + tracing::error!(session_id = %session_id, error = %e, "Error querying session"); + } + } + // 调用公共函数 let output_path = save_session_to_file( session_id, diff --git a/src/command/handlers/session.rs b/src/command/handlers/session.rs index a4840e9..9c58d34 100644 --- a/src/command/handlers/session.rs +++ b/src/command/handlers/session.rs @@ -2,23 +2,24 @@ use crate::command::context::CommandContext; use crate::command::handler::CommandHandler; use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; -use crate::gateway::cli_session::CliSessionService; +use crate::storage::SessionStore; use async_trait::async_trait; +use std::sync::Arc; /// 会话命令处理器 /// /// 处理与会话管理相关的命令 pub struct SessionCommandHandler { - cli_sessions: CliSessionService, + store: Arc<SessionStore>, } impl SessionCommandHandler { /// 创建新的会话命令处理器 /// /// # Arguments - /// * `cli_sessions` - CLI 会话服务 - pub(crate) fn new(cli_sessions: CliSessionService) -> Self { - Self { cli_sessions } + /// * `store` - Session 存储 + pub(crate) fn new(store: Arc<SessionStore>) -> Self { + Self { store } } } @@ -47,37 +48,49 @@ async fn handle_create_session( title: Option<String>, ctx: CommandContext, ) -> Result<CommandResponse, CommandError> { - let record = handler - .cli_sessions - .create_with_channel(&ctx.channel_name, title.as_deref()) - .map_err(|e| CommandError::new("CREATE_SESSION_ERROR", e.to_string()))?; + // 获取当前 session_id,如果没有则报错 + let session_id = ctx.session_id.as_deref() + .ok_or_else(|| CommandError::new("NO_SESSION", "No active session. Please ensure a session exists first."))?; + + // 创建新话题(在同一个 Session 内) + let topic_title = title.unwrap_or_else(|| { + format!("Topic {}", &uuid::Uuid::new_v4().to_string()[..8]) + }); + + let topic = handler + .store + .create_topic(session_id, &topic_title, None) + .map_err(|e| CommandError::new("CREATE_TOPIC_ERROR", e.to_string()))?; Ok(CommandResponse::success(ctx.request_id) - .with_message(MessageKind::Notification, &record.title) - .with_metadata("session_id", &record.id) - .with_metadata("channel_name", &record.channel_name) - .with_metadata("message_count", &record.message_count.to_string())) + .with_message(MessageKind::Notification, &topic.title) + .with_metadata("topic_id", &topic.id) + .with_metadata("session_id", &topic.session_id) + .with_metadata("message_count", &topic.message_count.to_string())) } #[cfg(test)] mod tests { use super::*; - use crate::command::response::MessageKind; - use crate::storage::{SessionRecord, SessionStore}; + use crate::storage::SessionStore; use std::sync::Arc; - fn create_test_service() -> CliSessionService { + fn create_test_handler() -> SessionCommandHandler { let store = Arc::new(SessionStore::in_memory().unwrap()); - CliSessionService::new(store) + SessionCommandHandler::new(store) } #[tokio::test] async fn test_create_session_with_title() { - let service = create_test_service(); - let handler = SessionCommandHandler::new(service); - let ctx = CommandContext::new("test", "test"); + let handler = create_test_handler(); + // 需要先创建一个 session + let store = handler.store.clone(); + let session = store.create_session("cli", Some("test session")).unwrap(); + + let ctx = CommandContext::new("test", "cli") + .with_session_id(&session.id); let cmd = Command::CreateSession { - title: Some("my session".to_string()), + title: Some("my topic".to_string()), }; let result = handler.handle(cmd, ctx).await; @@ -85,16 +98,17 @@ mod tests { assert!(result.is_ok()); let resp = result.unwrap(); assert!(resp.success); - assert_eq!(resp.messages.len(), 1); - assert_eq!(resp.messages[0].content, "my session"); - assert!(resp.metadata.contains_key("session_id")); + assert!(resp.metadata.contains_key("topic_id")); } #[tokio::test] async fn test_create_session_without_title() { - let service = create_test_service(); - let handler = SessionCommandHandler::new(service); - let ctx = CommandContext::new("test", "test"); + let handler = create_test_handler(); + let store = handler.store.clone(); + let session = store.create_session("cli", Some("test session")).unwrap(); + + let ctx = CommandContext::new("test", "cli") + .with_session_id(&session.id); let cmd = Command::CreateSession { title: None }; let result = handler.handle(cmd, ctx).await; @@ -102,15 +116,11 @@ mod tests { assert!(result.is_ok()); let resp = result.unwrap(); assert!(resp.success); - assert_eq!(resp.messages.len(), 1); - // 自动生成的标题 - assert!(!resp.messages[0].content.is_empty()); } #[test] fn test_can_handle() { - let service = create_test_service(); - let handler = SessionCommandHandler::new(service); + let handler = create_test_handler(); assert!(handler.can_handle(&Command::CreateSession { title: None })); } diff --git a/src/command/handlers/session_query.rs b/src/command/handlers/session_query.rs index fced38c..704c639 100644 --- a/src/command/handlers/session_query.rs +++ b/src/command/handlers/session_query.rs @@ -2,28 +2,39 @@ use crate::command::context::CommandContext; use crate::command::handler::CommandHandler; use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; -use crate::gateway::cli_session::CliSessionService; -use crate::protocol::SessionSummary; +use crate::gateway::session::SessionManager; +use crate::storage::{SessionStore, TopicRecord}; use async_trait::async_trait; +use std::sync::Arc; /// 会话查询命令处理器 /// -/// 处理 ListSessions 和 LoadSession 命令 +/// 处理 ListSessions、LoadSession 和 SwitchSession 命令(现在操作 Topic) pub struct SessionQueryCommandHandler { - cli_sessions: CliSessionService, + store: Arc<SessionStore>, + session_manager: Option<SessionManager>, } impl SessionQueryCommandHandler { /// 创建新的会话查询命令处理器 - pub fn new(cli_sessions: CliSessionService) -> Self { - Self { cli_sessions } + pub fn new(store: Arc<SessionStore>) -> Self { + Self { + store, + session_manager: None, + } + } + + /// 设置 SessionManager(用于 SwitchSession 命令) + pub fn with_session_manager(mut self, session_manager: SessionManager) -> Self { + self.session_manager = Some(session_manager); + self } } #[async_trait] impl CommandHandler for SessionQueryCommandHandler { fn can_handle(&self, cmd: &Command) -> bool { - matches!(cmd, Command::ListSessions { .. } | Command::LoadSession { .. }) + matches!(cmd, Command::ListSessions { .. } | Command::LoadSession { .. } | Command::SwitchSession { .. } | Command::GetCurrentSession) } async fn handle( @@ -38,84 +49,217 @@ impl CommandHandler for SessionQueryCommandHandler { Command::LoadSession { session_id } => { handle_load_session(self, session_id, ctx).await } + Command::SwitchSession { session_id } => { + handle_switch_session(self, session_id, ctx).await + } + Command::GetCurrentSession => { + handle_get_current_session(self, ctx).await + } _ => unreachable!(), } } } -/// 处理列出会话命令 +/// 处理列出话题命令 async fn handle_list_sessions( handler: &SessionQueryCommandHandler, - include_archived: bool, + _include_archived: bool, ctx: CommandContext, ) -> Result<CommandResponse, CommandError> { - // 使用当前通道名称查询会话,而不是硬编码 "cli" - let channel_name = &ctx.channel_name; - let records = handler - .cli_sessions - .list_by_channel(channel_name, include_archived) - .map_err(|e| CommandError::new("LIST_SESSIONS_ERROR", e.to_string()))?; + // 获取当前 session_id + let session_id = ctx.session_id.as_deref() + .ok_or_else(|| CommandError::new("NO_SESSION", "No active session"))?; - let summaries: Vec<SessionSummary> = records - .into_iter() - .map(|r| SessionSummary { - session_id: r.id, - title: r.title, - channel_name: r.channel_name, - chat_id: r.chat_id, - message_count: r.message_count, - last_active_at: r.last_active_at, - archived_at: r.archived_at, - }) - .collect(); + // 查询该 session 的所有 topic + let topics = handler + .store + .list_topics(session_id) + .map_err(|e| CommandError::new("LIST_TOPICS_ERROR", e.to_string()))?; - // 将会话列表序列化为 JSON 存储在 metadata 中 - let sessions_json = - serde_json::to_string(&summaries).map_err(|e| CommandError::new("SERIALIZE_ERROR", e.to_string()))?; + // 获取当前 topic ID + let current_topic_id = ctx.topic_id.as_deref().unwrap_or(""); - // 构建可读的会话列表消息 - let message = if summaries.is_empty() { - "No sessions found.".to_string() + // 构建表格格式的话题列表消息 + let message = if topics.is_empty() { + "No topics found. Use /new <title> to create a topic.".to_string() } else { - let mut lines = vec![format!("Found {} session(s):", summaries.len())]; - for summary in &summaries { - let archived_info = summary - .archived_at - .map(|_| " [archived]") - .unwrap_or(""); + let mut lines = vec![format!("Found {} topic(s):", topics.len())]; + lines.push(String::new()); + + // 表格头部 + lines.push("┌────┬─────────────────┬──────────────────────┬──────────┬─────────────────┐".to_string()); + lines.push("│ No │ Topic ID │ Title │ Messages │ Last Active │".to_string()); + lines.push("├────┼─────────────────┼──────────────────────┼──────────┼─────────────────┤".to_string()); + + // 表格内容 + for (idx, topic) in topics.iter().enumerate() { + let row_num = idx + 1; + let is_current = topic.id == current_topic_id; + let num_marker = if is_current { " * ".to_string() } else { format!(" {:<2}", row_num) }; + + // 截断过长的字段 + let topic_id_display = if topic.id.len() > 15 { + format!("{}...", &topic.id[..12]) + } else { + topic.id.clone() + }; + let title_display = if topic.title.len() > 20 { + format!("{}...", &topic.title[..17]) + } else { + topic.title.clone() + }; + + let last_active = format_time_ago(topic.last_active_at); + lines.push(format!( - " - {}: {}{}", - summary.session_id, summary.title, archived_info + "│{}│ {:<15} │ {:<20} │ {:<8} │ {:<15} │", + num_marker, + topic_id_display, + title_display, + topic.message_count, + last_active )); } - lines.push("".to_string()); - lines.push("Use /use <session_id> to switch to a session".to_string()); + + // 表格底部 + lines.push("└────┴─────────────────┴──────────────────────┴──────────┴─────────────────┘".to_string()); + lines.push(String::new()); + lines.push("* = current topic".to_string()); + lines.push("Use /use <number> or /use <topic_id> to switch".to_string()); + lines.join("\n") }; + let topics_json = serde_json::to_string(&topics) + .map_err(|e| CommandError::new("SERIALIZE_ERROR", e.to_string()))?; + Ok(CommandResponse::success(ctx.request_id) .with_message(MessageKind::Notification, &message) - .with_metadata("sessions", &sessions_json) - .with_metadata("count", &summaries.len().to_string())) + .with_metadata("topics", &topics_json) + .with_metadata("count", &topics.len().to_string()) + .with_metadata("current_topic_id", current_topic_id)) } -/// 处理加载会话命令 +/// 处理加载话题命令 async fn handle_load_session( handler: &SessionQueryCommandHandler, - session_id: String, + topic_id: String, ctx: CommandContext, ) -> Result<CommandResponse, CommandError> { - let record = handler - .cli_sessions - .get(&session_id) - .map_err(|e| CommandError::new("LOAD_SESSION_ERROR", e.to_string()))? - .ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", format!("Session not found: {}", session_id)))?; + let topic = handler + .store + .get_topic(&topic_id) + .map_err(|e| CommandError::new("LOAD_TOPIC_ERROR", e.to_string()))? + .ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", topic_id)))?; Ok(CommandResponse::success(ctx.request_id) - .with_message(MessageKind::Notification, &record.title) - .with_metadata("session_id", &record.id) - .with_metadata("title", &record.title) - .with_metadata("message_count", &record.message_count.to_string())) + .with_message(MessageKind::Notification, &topic.title) + .with_metadata("topic_id", &topic.id) + .with_metadata("title", &topic.title) + .with_metadata("message_count", &topic.message_count.to_string())) +} + +/// 格式化时间为相对时间(如 "2 mins ago") +fn format_time_ago(timestamp_ms: i64) -> String { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as i64; + + let diff_ms = now - timestamp_ms; + let diff_secs = diff_ms / 1000; + + if diff_secs < 60 { + "just now".to_string() + } else if diff_secs < 3600 { + format!("{} mins ago", diff_secs / 60) + } else if diff_secs < 86400 { + format!("{} hours ago", diff_secs / 3600) + } else { + format!("{} days ago", diff_secs / 86400) + } +} + +/// 处理获取当前话题命令 +async fn handle_get_current_session( + handler: &SessionQueryCommandHandler, + ctx: CommandContext, +) -> Result<CommandResponse, CommandError> { + let topic_id = ctx.topic_id.as_deref() + .ok_or_else(|| CommandError::new("NO_CURRENT_TOPIC", "No current topic"))?; + + let topic = handler + .store + .get_topic(topic_id) + .map_err(|e| CommandError::new("GET_TOPIC_ERROR", e.to_string()))? + .ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", topic_id)))?; + + let last_active = format_time_ago(topic.last_active_at); + let created_at = format_time_ago(topic.created_at); + + let message = format!( + "Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Created: {}\n Last Active: {}", + topic.id, + topic.title, + topic.message_count, + created_at, + last_active + ); + + Ok(CommandResponse::success(ctx.request_id) + .with_message(MessageKind::Notification, &message) + .with_metadata("topic_id", &topic.id) + .with_metadata("title", &topic.title) + .with_metadata("message_count", &topic.message_count.to_string())) +} + +/// 处理切换话题命令 +async fn handle_switch_session( + handler: &SessionQueryCommandHandler, + topic_id: String, + ctx: CommandContext, +) -> Result<CommandResponse, CommandError> { + // 获取当前 session_id + let session_id = ctx.session_id.as_deref() + .ok_or_else(|| CommandError::new("NO_SESSION", "No active session"))?; + + // 尝试解析为序号 + let target_topic_id = if let Ok(index) = topic_id.parse::<usize>() { + let topics = handler + .store + .list_topics(session_id) + .map_err(|e| CommandError::new("LIST_TOPICS_ERROR", e.to_string()))?; + + let index = index.saturating_sub(1); + if index >= topics.len() { + return Err(CommandError::new( + "INVALID_TOPIC_INDEX", + format!("Topic index {} is out of range (1-{})", index + 1, topics.len()) + )); + } + topics[index].id.clone() + } else { + topic_id + }; + + // 验证目标话题存在 + let topic = handler + .store + .get_topic(&target_topic_id) + .map_err(|e| CommandError::new("SWITCH_TOPIC_ERROR", e.to_string()))? + .ok_or_else(|| CommandError::new("TOPIC_NOT_FOUND", format!("Topic not found: {}", target_topic_id)))?; + + // 返回切换成功响应 + let message = format!( + "✓ Switched to topic: {} ({} messages)", + topic.title, topic.message_count + ); + + Ok(CommandResponse::success(ctx.request_id) + .with_message(MessageKind::Notification, &message) + .with_metadata("topic_id", &topic.id) + .with_metadata("title", &topic.title) + .with_metadata("message_count", &topic.message_count.to_string())) } #[cfg(test)] @@ -124,17 +268,20 @@ mod tests { use crate::storage::SessionStore; use std::sync::Arc; - fn create_test_service() -> CliSessionService { + fn create_test_handler() -> SessionQueryCommandHandler { let store = Arc::new(SessionStore::in_memory().unwrap()); - CliSessionService::new(store) + SessionQueryCommandHandler::new(store) } #[tokio::test] async fn test_list_sessions_empty() { - let service = create_test_service(); - let handler = SessionQueryCommandHandler::new(service); - // 使用 "cli" 通道,与 CliSessionService::create 一致 - let ctx = CommandContext::new("test", "cli"); + let handler = create_test_handler(); + // 需要先创建一个 session 和 topic + let store = handler.store.clone(); + let session = store.create_session("cli", Some("test")).unwrap(); + + let ctx = CommandContext::new("test", "cli") + .with_session_id(&session.id); let cmd = Command::ListSessions { include_archived: false, }; @@ -144,19 +291,20 @@ mod tests { assert!(result.is_ok()); let resp = result.unwrap(); assert!(resp.success); - assert!(resp.messages[0].content.contains("No sessions")); + assert!(resp.messages[0].content.contains("No topics")); } #[tokio::test] async fn test_list_sessions_with_items() { - let service = create_test_service(); - let handler = SessionQueryCommandHandler::new(service.clone()); + let handler = create_test_handler(); + let store = handler.store.clone(); + let session = store.create_session("cli", Some("test")).unwrap(); - // 创建一些会话(使用 cli 通道) - service.create(Some("test session")).unwrap(); + // 创建一个 topic + store.create_topic(&session.id, "Test Topic", None).unwrap(); - // 使用 "cli" 通道查询,与创建会话的通道一致 - let ctx = CommandContext::new("test", "cli"); + let ctx = CommandContext::new("test", "cli") + .with_session_id(&session.id); let cmd = Command::ListSessions { include_archived: false, }; @@ -166,14 +314,17 @@ mod tests { assert!(result.is_ok()); let resp = result.unwrap(); assert!(resp.success); - assert!(resp.metadata.contains_key("sessions")); + assert!(resp.metadata.contains_key("topics")); } #[tokio::test] async fn test_load_session_not_found() { - let service = create_test_service(); - let handler = SessionQueryCommandHandler::new(service); - let ctx = CommandContext::new("test", "test"); + let handler = create_test_handler(); + let store = handler.store.clone(); + let session = store.create_session("cli", Some("test")).unwrap(); + + let ctx = CommandContext::new("test", "test") + .with_session_id(&session.id); let cmd = Command::LoadSession { session_id: "nonexistent".to_string(), }; @@ -185,15 +336,15 @@ mod tests { #[tokio::test] async fn test_load_session_success() { - let service = create_test_service(); - let handler = SessionQueryCommandHandler::new(service.clone()); + let handler = create_test_handler(); + let store = handler.store.clone(); + let session = store.create_session("cli", Some("test")).unwrap(); + let topic = store.create_topic(&session.id, "Test Topic", None).unwrap(); - // 创建会话 - let record = service.create(Some("test session")).unwrap(); - - let ctx = CommandContext::new("test", "test"); + let ctx = CommandContext::new("test", "test") + .with_session_id(&session.id); let cmd = Command::LoadSession { - session_id: record.id.clone(), + session_id: topic.id.clone(), }; let result = handler.handle(cmd, ctx).await; @@ -201,6 +352,6 @@ mod tests { assert!(result.is_ok()); let resp = result.unwrap(); assert!(resp.success); - assert_eq!(resp.metadata.get("session_id").unwrap(), &record.id); + assert_eq!(resp.metadata.get("topic_id").unwrap(), &topic.id); } } diff --git a/src/command/mod.rs b/src/command/mod.rs index 935abc5..f2867d1 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -11,17 +11,21 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Command { - /// 创建新会话 + /// 创建新话题(在同一个 Session 内) CreateSession { title: Option<String> }, - /// 保存会话内容到 Markdown 文件 + /// 保存话题内容到 Markdown 文件 SaveSession { filepath: Option<String>, include_all: bool, }, - /// 列出会话 + /// 列出当前 Session 的所有话题 ListSessions { include_archived: bool }, - /// 加载指定会话 + /// 加载指定话题 LoadSession { session_id: String }, + /// 切换到指定话题(清理当前历史并加载新话题) + SwitchSession { session_id: String }, + /// 获取当前话题信息 + GetCurrentSession, } impl Command { @@ -32,6 +36,8 @@ impl Command { Command::SaveSession { .. } => "save_session", Command::ListSessions { .. } => "list_sessions", Command::LoadSession { .. } => "load_session", + Command::SwitchSession { .. } => "switch_session", + Command::GetCurrentSession => "get_current_session", } } } diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 9710153..5a86971 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -14,6 +14,7 @@ use crate::command::Command; use crate::config::LLMProviderConfig; use crate::gateway::agent_prompt_provider::AgentPromptProvider; use crate::skills::SkillPromptProvider; +use crate::storage::persistent_session_id; use super::session::{BusToolCallEmitter, SessionManager}; @@ -35,13 +36,15 @@ impl InboundProcessor { ) -> Self { // 创建命令路由器并注册处理器 let mut command_router = CommandRouter::new(); + let store = session_manager.store(); // 注册 Session 处理器 - let cli_sessions = session_manager.cli_sessions(); - command_router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone()))); + command_router.register(Box::new(SessionCommandHandler::new(store.clone()))); // 注册 session_query 处理器 - command_router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions))); + let session_query_handler = SessionQueryCommandHandler::new(store) + .with_session_manager(session_manager.clone()); + command_router.register(Box::new(session_query_handler)); // 注册 save_session 处理器 let store = session_manager.store(); @@ -115,15 +118,19 @@ impl InboundProcessor { } async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> { + // 计算正确的 session_id(根据 channel_name 和 chat_id) + let session_id = persistent_session_id(&inbound.channel, &inbound.chat_id); + // 使用 ChannelInputAdapter 尝试解析命令 let adapter = ChannelInputAdapter::new(); let ctx = crate::command::context::AdapterContext::new(&inbound.channel) - .with_session_id(&inbound.chat_id); + .with_session_id(&session_id); if let Ok(Some(cmd)) = adapter.try_parse(&inbound.content, ctx) { // 使用命令路由器处理 let cmd_ctx = crate::command::context::CommandContext::new(&inbound.channel, &inbound.channel) - .with_session_id(&inbound.chat_id); + .with_session_id(&session_id) + .with_chat_id(&inbound.chat_id); // 记录是否是创建会话命令(用于后续处理) let _is_create_session = matches!(cmd, Command::CreateSession { .. }); diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index a193a7e..6310dc9 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -221,16 +221,24 @@ async fn handle_inbound( ])); let mut router = CommandRouter::new(); - router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone()))); - router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions))); + router.register(Box::new(SessionCommandHandler::new(store.clone()))); + // 修复:添加 SessionManager 到 SessionQueryCommandHandler + let session_query_handler = SessionQueryCommandHandler::new(store.clone()) + .with_session_manager(state.session_manager.clone()); + router.register(Box::new(session_query_handler)); router.register(Box::new(SaveSessionCommandHandler::new( store, system_prompt_provider, ))); // 构建命令上下文 + tracing::debug!( + current_session_id = %current_session_id, + "Building CommandContext for WebSocket command" + ); let cmd_ctx = CommandContext::new("websocket", "cli") - .with_session_id(current_session_id.as_str()); + .with_session_id(current_session_id.as_str()) + .with_chat_id(current_session_id.as_str()); // 执行命令 let response = router.dispatch_with_response(cmd, cmd_ctx).await; @@ -239,6 +247,11 @@ async fn handle_inbound( if response.success { // 更新当前会话 ID(如果是创建会话) if let Some(session_id) = response.metadata.get("session_id") { + tracing::info!( + old_session_id = %current_session_id, + new_session_id = %session_id, + "Updating current_session_id" + ); *current_session_id = session_id.clone(); state .channel_manager @@ -250,6 +263,12 @@ async fn handle_inbound( ) .await; } + } else if let Some(ref error) = response.error { + tracing::warn!( + error_code = %error.code, + error_message = %error.message, + "Command failed" + ); } // 适配并发送响应 diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 13f9515..2d49256 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -17,7 +17,7 @@ pub use ports::{ }; pub use records::{ MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, - SchedulerJobUpsert, SessionRecord, SkillEventRecord, + SchedulerJobUpsert, SessionRecord, SkillEventRecord, TopicRecord, }; #[derive(Clone)] @@ -79,6 +79,7 @@ impl SessionStore { CREATE TABLE IF NOT EXISTS messages ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, + topic_id TEXT, seq INTEGER NOT NULL, role TEXT NOT NULL, content TEXT NOT NULL, @@ -90,6 +91,7 @@ impl SessionStore { tool_calls_json TEXT, created_at INTEGER NOT NULL, FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE, + FOREIGN KEY(topic_id) REFERENCES topics(id) ON DELETE SET NULL, UNIQUE(session_id, seq) ); @@ -98,6 +100,21 @@ impl SessionStore { CREATE INDEX IF NOT EXISTS idx_messages_session_created ON messages(session_id, created_at); + CREATE TABLE IF NOT EXISTS topics ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + title TEXT NOT NULL, + description TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + last_active_at INTEGER NOT NULL, + message_count INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_topics_session + ON topics(session_id, last_active_at DESC); + CREATE TABLE IF NOT EXISTS skill_events ( id TEXT PRIMARY KEY, session_id TEXT, @@ -211,6 +228,8 @@ impl SessionStore { ) -> Result<SessionRecord, StorageError> { let now = current_timestamp(); let id = uuid::Uuid::new_v4().to_string(); + // 统一使用 persistent_session_id 格式 + let session_id = persistent_session_id(channel_name, &id); let title = title .map(str::trim) .filter(|value| !value.is_empty()) @@ -232,11 +251,11 @@ impl SessionStore { reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count ) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0) ", - params![id, title, channel_name, id, now], + params![&session_id, title, channel_name, id, now], )?; drop(conn); - self.get_session(&id)? + self.get_session(&session_id)? .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) } @@ -354,6 +373,103 @@ impl SessionStore { Ok(()) } + // ==================== Topic Methods ==================== + + pub fn create_topic( + &self, + session_id: &str, + title: &str, + description: Option<&str>, + ) -> Result<TopicRecord, StorageError> { + let now = current_timestamp(); + let id = format!("topic:{}", uuid::Uuid::new_v4()); + + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute( + "INSERT INTO topics (id, session_id, title, description, created_at, updated_at, last_active_at, message_count) VALUES (?1, ?2, ?3, ?4, ?5, ?5, ?5, 0)", + params![&id, session_id, title, description.unwrap_or(""), now], + )?; + drop(conn); + + self.get_topic(&id)? + .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) + } + + pub fn get_topic(&self, topic_id: &str) -> Result<Option<TopicRecord>, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let mut stmt = conn.prepare( + "SELECT id, session_id, title, description, created_at, updated_at, last_active_at, message_count FROM topics WHERE id = ?1", + )?; + + stmt.query_row(params![topic_id], |row| { + Ok(TopicRecord { + id: row.get(0)?, + session_id: row.get(1)?, + title: row.get(2)?, + description: row.get(3)?, + created_at: row.get(4)?, + updated_at: row.get(5)?, + last_active_at: row.get(6)?, + message_count: row.get(7)?, + }) + }) + .optional() + .map_err(StorageError::from) + } + + pub fn list_topics(&self, session_id: &str) -> Result<Vec<TopicRecord>, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let mut stmt = conn.prepare( + "SELECT id, session_id, title, description, created_at, updated_at, last_active_at, message_count FROM topics WHERE session_id = ?1 ORDER BY last_active_at DESC" + )?; + + let rows = stmt.query_map(params![session_id], |row| { + Ok(TopicRecord { + id: row.get(0)?, + session_id: row.get(1)?, + title: row.get(2)?, + description: row.get(3)?, + created_at: row.get(4)?, + updated_at: row.get(5)?, + last_active_at: row.get(6)?, + message_count: row.get(7)?, + }) + })?; + + let mut topics = Vec::new(); + for row in rows { + topics.push(row?); + } + Ok(topics) + } + + pub fn update_topic_title(&self, topic_id: &str, title: &str) -> Result<(), StorageError> { + let now = current_timestamp(); + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute( + "UPDATE topics SET title = ?2, updated_at = ?3 WHERE id = ?1", + params![topic_id, title.trim(), now], + )?; + Ok(()) + } + + pub fn delete_topic(&self, topic_id: &str) -> Result<(), StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + // Messages 的 topic_id 会被设为 NULL(ON DELETE SET NULL) + conn.execute("DELETE FROM topics WHERE id = ?1", params![topic_id])?; + Ok(()) + } + + pub fn touch_topic(&self, topic_id: &str) -> Result<(), StorageError> { + let now = current_timestamp(); + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute( + "UPDATE topics SET last_active_at = ?2 WHERE id = ?1", + params![topic_id, now], + )?; + Ok(()) + } + pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> { let now = current_timestamp(); let conn = self.conn.lock().expect("session db mutex poisoned"); @@ -410,6 +526,15 @@ impl SessionStore { &self, session_id: &str, message: &ChatMessage, + ) -> Result<(), StorageError> { + self.append_message_with_topic(session_id, None, message) + } + + pub fn append_message_with_topic( + &self, + session_id: &str, + topic_id: Option<&str>, + message: &ChatMessage, ) -> Result<(), StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let tx = conn.unchecked_transaction()?; @@ -429,13 +554,14 @@ impl SessionStore { tx.execute( " INSERT INTO messages ( - id, session_id, seq, role, content, + id, session_id, topic_id, seq, role, content, system_context, reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12) + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13) ", params![ message.id, session_id, + topic_id, seq, message.role, message.content, @@ -1177,6 +1303,62 @@ impl SessionStore { load_messages_after(&conn, session_id, cutoff_seq) } + pub fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let mut stmt = conn.prepare( + " + SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json + FROM messages + WHERE topic_id = ?1 + ORDER BY seq ASC + ", + )?; + + let rows = stmt.query_map(params![topic_id], |row| { + let media_refs_json: String = row.get(5)?; + let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + media_refs_json.len(), + rusqlite::types::Type::Text, + Box::new(err), + ) + })?; + + let tool_calls_json: Option<String> = row.get(9)?; + let tool_calls = tool_calls_json + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + 9, + rusqlite::types::Type::Text, + Box::new(err), + ) + })?; + + Ok(ChatMessage { + id: row.get(0)?, + role: row.get(1)?, + content: row.get(2)?, + system_context: row.get(3)?, + reasoning_content: row.get(4)?, + media_refs, + timestamp: row.get(6)?, + tool_call_id: row.get(7)?, + tool_name: row.get(8)?, + tool_state: None, + tool_calls, + }) + })?; + + let mut messages = Vec::new(); + for row in rows { + messages.push(row?); + } + Ok(messages) + } + pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); load_messages_after(&conn, session_id, 0) @@ -1349,6 +1531,18 @@ fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> { )?; } + if !has_column(conn, "messages", "topic_id")? { + add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN topic_id TEXT")?; + // 添加外键约束(SQLite 不支持 ALTER TABLE ADD FOREIGN KEY,需要重建表) + // 这里只添加列,外键约束由应用层保证 + } + + // 创建 topic_id 索引(如果不存在) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_messages_topic_seq ON messages(topic_id, seq) WHERE topic_id IS NOT NULL", + [], + )?; + Ok(()) } diff --git a/src/storage/records.rs b/src/storage/records.rs index 6e01384..9fa852e 100644 --- a/src/storage/records.rs +++ b/src/storage/records.rs @@ -28,6 +28,18 @@ pub struct SessionRecord { pub agent_prompt_reinjection_count: i64, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TopicRecord { + pub id: String, + pub session_id: String, + pub title: String, + pub description: Option<String>, + pub created_at: i64, + pub updated_at: i64, + pub last_active_at: i64, + pub message_count: i64, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MemoryRecord { pub id: String,