Compare commits
9 Commits
cb1140e9be
...
bafa7a606c
| Author | SHA1 | Date | |
|---|---|---|---|
| bafa7a606c | |||
| 11a8e93b77 | |||
| 8c0c76a232 | |||
| 709d70f828 | |||
| 25d37bcdc1 | |||
| 3d29854079 | |||
| e65130450e | |||
| 29543444da | |||
| d022e30943 |
@ -73,7 +73,7 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
||||
### Key Constraints
|
||||
|
||||
- Gateway **changes working directory** to workspace on startup (`src/gateway/mod.rs:31`)
|
||||
- Session/message persistence uses SQLite via `sqlx`; DB stored in workspace as `.picobot_sessions.db` by default
|
||||
- Session/message persistence uses SQLite via `sqlx`; DB stored in workspace as `picobot.db` by default
|
||||
- `ChannelManager` owns the `MessageBus` and all channel instances
|
||||
- `OutboundDispatcher` routes outbound messages to the correct channel via `ChannelManager`
|
||||
- Config `.env` loading uses `unsafe { env::set_var(...) }` — don't refactor to safer patterns without understanding side effects
|
||||
|
||||
@ -26,7 +26,7 @@ graph TB
|
||||
end
|
||||
|
||||
subgraph Storage
|
||||
SQLite[("SQLite<br/>.picobot_sessions.db")]
|
||||
SQLite[("SQLite<br/>picobot.db")]
|
||||
end
|
||||
|
||||
subgraph AI["AI Providers"]
|
||||
@ -236,7 +236,7 @@ The `.env` file in the working directory is loaded manually (not via dotenv crat
|
||||
| `port` | u16 | `19876` | Listen port |
|
||||
| `session_ttl_hours` | number | `4` | Inactive session expiration (hours) |
|
||||
| `cleanup_interval_minutes` | number | `60` | Session cleanup interval |
|
||||
| `session_db_path` | string | workspace `.picobot_sessions.db` | SQLite database path |
|
||||
| `session_db_path` | string | workspace `picobot.db` | SQLite database path |
|
||||
| `scheduler.enabled` | bool | `false` | Enable cron scheduler |
|
||||
|
||||
### Agent Config
|
||||
|
||||
90
docs/plans/2026-05-10-incremental-session-recovery-design.md
Normal file
90
docs/plans/2026-05-10-incremental-session-recovery-design.md
Normal file
@ -0,0 +1,90 @@
|
||||
# 启动增量恢复设计
|
||||
|
||||
## 问题
|
||||
|
||||
PicoBot 重启后,`Session::from_storage()` 全量加载 `messages` 表,恢复的 history 可能直接超出上下文窗口,首次 LLM 调用即触发压缩,浪费 token。
|
||||
|
||||
## 设计
|
||||
|
||||
### 核心思路
|
||||
|
||||
用 `last_compressed_message_at` 标记最后压缩时刻。恢复时:
|
||||
- 加载该标记之后的原始消息
|
||||
- 用该 session 的 Timeline 条目替代已压缩部分
|
||||
- `seq_counter` 统一从 SQLite 查 `MAX(seq) + 1`
|
||||
|
||||
```
|
||||
messages 表 memories(timeline)
|
||||
┌──────────────────────────┐ ┌───────────────────────────┐
|
||||
│ created_at = T1..T5 │ ← 跳过 │ session = feishu:oc:dialog │
|
||||
│ (压缩已覆盖,用Timeline替代)│ │ created_at 降序 │
|
||||
├──────────────────────────┤ ├───────────────────────────┤
|
||||
│ created_at > T6 │ ← 加载 │ 只取最近 3 条 │
|
||||
└──────────────────────────┘ └───────────────────────────┘
|
||||
```
|
||||
|
||||
### 数据变更
|
||||
|
||||
**`sessions` 表加列:**
|
||||
```sql
|
||||
last_compressed_message_at INTEGER
|
||||
```
|
||||
|
||||
**`SessionMeta` / `Session` 加字段:** `last_compressed_message_at: Option<i64>`
|
||||
|
||||
### Storage 层新增方法
|
||||
|
||||
| 方法 | SQL |
|
||||
|------|-----|
|
||||
| `get_max_message_seq(session_id)` | `SELECT MAX(seq) FROM messages WHERE session_id = ?` |
|
||||
| `load_messages_after_timestamp(session_id, after_ts)` | `WHERE created_at > ?` |
|
||||
| `load_session_timelines(session_id, limit)` | `WHERE session_id = ? AND category = 'timeline' ORDER BY created_at DESC LIMIT ?` |
|
||||
|
||||
### 压缩跟踪
|
||||
|
||||
`compress_if_needed()` 返回值改为 `CompressionResult { history, created_timelines: bool }`。
|
||||
`compress_once()` 中 LLM 摘要路径才置 `true`(Tier 2),Tier 1/3 不产生 Timeline。
|
||||
|
||||
**记录时机**(`handle_message` 正常流、溢出重试流、`/compact` 统一):
|
||||
```rust
|
||||
if result.created_timelines {
|
||||
session.last_compressed_message_at = Some(now());
|
||||
session.persist_session_meta().await;
|
||||
}
|
||||
```
|
||||
|
||||
### Session::from_storage() 恢复逻辑
|
||||
|
||||
有压缩标记时:
|
||||
1. `load_session_timelines(limit=4)` → 取 3 条给 LLM,第 4 条判"有更多"
|
||||
2. 有更多 → 插入提示 user 消息
|
||||
3. 逐条插入 Timeline 为 `[Previous Context]` user 消息
|
||||
4. `load_messages_after_timestamp(after_ts)` → 原始尾消息
|
||||
5. `repair_tool_call_chains`
|
||||
|
||||
无压缩标记 → 全量加载(现有行为)。
|
||||
|
||||
统一:`seq_counter = MAX(seq) + 1`
|
||||
|
||||
### 系统提示词
|
||||
|
||||
`Session.last_compressed_message_at` 非空时追加:
|
||||
```
|
||||
## 历史会话
|
||||
之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。
|
||||
```
|
||||
|
||||
## 改动清单
|
||||
|
||||
| # | 文件 | 改动 |
|
||||
|---|------|------|
|
||||
| 1 | `storage/session.rs` | `SessionMeta` 加 `last_compressed_message_at` |
|
||||
| 2 | `storage/mod.rs` | DDL migration + upsert/get_session 加列 |
|
||||
| 3 | `storage/mod.rs` | 新增 `get_max_message_seq`, `load_messages_after_timestamp` |
|
||||
| 4 | `storage/memory.rs` | 新增 `load_session_timelines` |
|
||||
| 5 | `agent/context_compressor.rs` | 返回值改为 `CompressionResult` 含 `created_timelines` |
|
||||
| 6 | `session/session.rs` | `Session` 加字段,`persist_session_meta` 加字段 |
|
||||
| 7 | `session/session.rs` | `from_storage()` 重写恢复逻辑 |
|
||||
| 8 | `session/session.rs` | `handle_message()` 压缩后记录标记 |
|
||||
| 9 | `session/session.rs` | `/compact` 命令压缩后记录标记 |
|
||||
| 10 | `session/session.rs` | `build_system_prompt()` 注入 `last_compressed_message_at` |
|
||||
674
docs/plans/2026-05-10-incremental-session-recovery.md
Normal file
674
docs/plans/2026-05-10-incremental-session-recovery.md
Normal file
@ -0,0 +1,674 @@
|
||||
# 启动增量恢复 Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** PicoBot 重启后不再全量加载 messages 表,而是基于 `last_compressed_message_at` 标记增量恢复,用 Timeline 替代已压缩部分。
|
||||
|
||||
**Architecture:** 在 `sessions` 表加 `last_compressed_message_at` 列,`compress_if_needed` 返回值增加 `created_timelines` 标志,恢复时按时间戳增量加载消息并用近 3 条 Timeline 前置,`seq_counter` 统一从 SQLite 查 MAX(seq)。
|
||||
|
||||
**Tech Stack:** Rust, sqlx (SQLite), tokio
|
||||
|
||||
---
|
||||
|
||||
### Task 1: SessionMeta 和数据库 DDL 加列
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/session.rs:15`
|
||||
- Modify: `src/storage/mod.rs:44-45` (DDL), `:172-180` (migration)
|
||||
- Modify: `src/storage/mod.rs:317-326` (upsert_session SQL + ON CONFLICT)
|
||||
- Modify: `src/storage/mod.rs:345-369` (get_session SELECT + struct)
|
||||
- Modify: `src/storage/mod.rs:380-406`, `:454-479`, `:564-588`, `:728`, `:754` (list_sessions 及测试 mock)
|
||||
|
||||
**Step 1: 在 `src/storage/session.rs` SessionMeta 加字段**
|
||||
|
||||
在 `last_consolidated_at: Option<i64>` 后加一行:
|
||||
```rust
|
||||
pub last_compressed_message_at: Option<i64>,
|
||||
```
|
||||
|
||||
**Step 2: DDL schema 加列**
|
||||
|
||||
在 `src/storage/mod.rs` 的 CREATE TABLE sessions 中 (line 44),`last_consolidated_at INTEGER` 后加逗号和:
|
||||
```sql
|
||||
last_compressed_message_at INTEGER
|
||||
```
|
||||
|
||||
**Step 3: migration 加列**
|
||||
|
||||
在 `src/storage/mod.rs` line 182 之后(现有 migration 的 `); .ok();` 之后),添加新 migration:
|
||||
```rust
|
||||
// Migration: add last_compressed_message_at column if not exists
|
||||
sqlx::query(
|
||||
r#"ALTER TABLE sessions ADD COLUMN last_compressed_message_at INTEGER"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
```
|
||||
|
||||
**Step 4: upsert_session SQL 加列**
|
||||
|
||||
`src/storage/mod.rs` line 317: INSERT 列列表加 `last_compressed_message_at`,VALUES 加 `?`,ON CONFLICT DO UPDATE SET 加 `last_compressed_message_at = excluded.last_compressed_message_at`。line 338 后加 `.bind(meta.last_compressed_message_at)`。
|
||||
|
||||
**Step 5: get_session SELECT 加列**
|
||||
|
||||
`src/storage/mod.rs` line 348: SELECT 列加 `last_compressed_message_at`。line 368 后加:
|
||||
```rust
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
```
|
||||
|
||||
**Step 6: 其他 SELECT sessions 的地方(list_sessions 多个变体)**
|
||||
|
||||
同样补 `last_compressed_message_at` 到 SELECT 列和 struct 构造。以及测试中的 mock SessionMeta 构造(line 728, 754 等)。
|
||||
|
||||
**Step 7: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 8: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/session.rs src/storage/mod.rs
|
||||
git commit -m "feat(storage): add last_compressed_message_at column to sessions"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Storage 新增加载方法
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/mod.rs` (在 load_messages 之后)
|
||||
- Modify: `src/storage/memory.rs` (在 cleanup_old_timelines 之后)
|
||||
|
||||
**Step 1: `get_max_message_seq`**
|
||||
|
||||
在 `src/storage/mod.rs` 中 `load_messages` 函数后面添加:
|
||||
```rust
|
||||
pub async fn get_max_message_seq(&self, session_id: &str) -> Result<i64, StorageError> {
|
||||
let row = sqlx::query(
|
||||
"SELECT COALESCE(MAX(seq), 0) as max_seq FROM messages WHERE session_id = ?",
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_one(self.pool())
|
||||
.await?;
|
||||
Ok(row.get::<i64, _>("max_seq"))
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: `load_messages_after_timestamp`**
|
||||
|
||||
```rust
|
||||
pub async fn load_messages_after_timestamp(
|
||||
&self,
|
||||
session_id: &str,
|
||||
after_ts: i64,
|
||||
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND created_at > ?
|
||||
ORDER BY seq ASC
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(after_ts)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(|row| crate::storage::message::MessageMeta {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
tool_calls: row.get("tool_calls"),
|
||||
source: row.get("source"),
|
||||
created_at: row.get("created_at"),
|
||||
}).collect())
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: `load_session_timelines`**
|
||||
|
||||
在 `src/storage/memory.rs` 的 `cleanup_old_timelines` 之后(line 252 的 `}` 之前)添加:
|
||||
```rust
|
||||
pub async fn load_session_timelines(
|
||||
&self,
|
||||
session_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, key, content, category, importance,
|
||||
session_id, created_at, updated_at
|
||||
FROM memories
|
||||
WHERE session_id = ? AND category = 'timeline'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(limit as i64)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
parse_memory_rows(&rows)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/mod.rs src/storage/memory.rs
|
||||
git commit -m "feat(storage): add load_messages_after_timestamp, load_session_timelines, get_max_message_seq"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: ContextCompressor 引入 CompressionResult
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agent/context_compressor.rs:172-274` (compress_if_needed)
|
||||
- Modify: `src/agent/context_compressor.rs:320-402` (compress_once)
|
||||
|
||||
**Step 1: 定义 CompressionResult**
|
||||
|
||||
在 context_compressor.rs 中 `ContextCompressor` struct 定义之后添加:
|
||||
```rust
|
||||
pub struct CompressionResult {
|
||||
pub history: Vec<ChatMessage>,
|
||||
pub created_timelines: bool,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 修改 compress_if_needed 签名和返回**
|
||||
|
||||
将 `pub async fn compress_if_needed(&self, mut history: Vec<ChatMessage>) -> Result<Vec<ChatMessage>, AgentError>` 改为:
|
||||
```rust
|
||||
pub async fn compress_if_needed(
|
||||
&self,
|
||||
mut history: Vec<ChatMessage>,
|
||||
) -> Result<CompressionResult, AgentError> {
|
||||
```
|
||||
|
||||
内部的 `return Ok(history)` 改为 `return Ok(CompressionResult { history, created_timelines: false })`(Tier 1 fast trim 和不需要压缩时)。
|
||||
|
||||
**Step 3: 修改 LLM summarization pass 部分**
|
||||
|
||||
在压缩循环中维护一个 `created_timelines` 标志:
|
||||
```rust
|
||||
let mut created_timelines = false;
|
||||
for pass in 0..self.config.max_passes {
|
||||
// ...
|
||||
match self.compress_once(...).await {
|
||||
Ok(Some(compressed)) => {
|
||||
current_history = compressed;
|
||||
created_timelines = true;
|
||||
}
|
||||
// ...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
最后返回:
|
||||
```rust
|
||||
Ok(CompressionResult { history: current_history, created_timelines })
|
||||
```
|
||||
|
||||
**Step 4: 更新所有 compress_if_needed 调用方**
|
||||
|
||||
所有 `compress_if_needed(history)` 改为 `compress_if_needed(history).await?.history`。在 `handle_message` 和 `/compact` 中还需要用到 `created_timelines`。
|
||||
|
||||
**Step 5: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add src/agent/context_compressor.rs src/session/session.rs
|
||||
git commit -m "feat(compressor): return CompressionResult with created_timelines flag"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Session 结构体和持久化
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/session/session.rs:52-74` (Session struct)
|
||||
- Modify: `src/session/session.rs:76-121` (Session::new)
|
||||
- Modify: `src/session/session.rs:298-320` (persist_session_meta)
|
||||
|
||||
**Step 1: Session struct 加字段**
|
||||
|
||||
在 `pub last_consolidated_at: Option<i64>` 后加:
|
||||
```rust
|
||||
pub last_compressed_message_at: Option<i64>,
|
||||
```
|
||||
|
||||
**Step 2: Session::new 初始化**
|
||||
|
||||
在 `last_consolidated_at: None` 后加:
|
||||
```rust
|
||||
last_compressed_message_at: None,
|
||||
```
|
||||
|
||||
**Step 3: persist_session_meta 加字段**
|
||||
|
||||
在 `last_consolidated_at: self.last_consolidated_at` 后加:
|
||||
```rust
|
||||
last_compressed_message_at: self.last_compressed_message_at,
|
||||
```
|
||||
|
||||
**Step 4: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add src/session/session.rs
|
||||
git commit -m "feat(session): add last_compressed_message_at field to Session and persist"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Session::from_storage() 增量恢复
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/session/session.rs:124-189` (from_storage)
|
||||
|
||||
**Step 1: 重写 from_storage**
|
||||
|
||||
替换现有实现为:
|
||||
|
||||
```rust
|
||||
pub async fn from_storage(
|
||||
id: UnifiedSessionId,
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
storage: StdArc<Storage>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let session_meta = storage.get_session(&id.to_string()).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
||||
|
||||
let mut provider_box = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||
provider_box.set_storage(storage.clone());
|
||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||
|
||||
let compressor_config = ContextCompressionConfig {
|
||||
protect_first_n: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
|
||||
compressor.set_session_id(Some(id.to_string()));
|
||||
|
||||
// Determine recovery strategy
|
||||
let mut chat_messages: Vec<ChatMessage> = Vec::new();
|
||||
|
||||
if let Some(after_ts) = session_meta.last_compressed_message_at {
|
||||
// Load last 4 timelines to determine if there are > 3
|
||||
let timelines = storage
|
||||
.load_session_timelines(&id.to_string(), 4)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let has_more_timelines = timelines.len() > 3;
|
||||
|
||||
// Insert hint if more timelines exist
|
||||
if has_more_timelines {
|
||||
chat_messages.push(ChatMessage::user(
|
||||
"[Earlier conversation summaries exist. \
|
||||
Use `timeline_recall` to search if needed.]"
|
||||
));
|
||||
}
|
||||
|
||||
// Insert latest 3 timelines as context (reversed: oldest first)
|
||||
for tl in timelines.iter().take(3).rev() {
|
||||
chat_messages.push(ChatMessage::user(format!(
|
||||
"[Previous Context]\n{}", tl.content
|
||||
)));
|
||||
}
|
||||
|
||||
// Load raw messages after compressed timestamp
|
||||
let tail = storage
|
||||
.load_messages_after_timestamp(&id.to_string(), after_ts)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut tail_msgs: Vec<ChatMessage> = tail.into_iter().map(|m| {
|
||||
ChatMessage {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_calls: m.tool_calls
|
||||
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||
.filter(|v| !v.is_empty()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
repair_tool_call_chains(&mut tail_msgs);
|
||||
chat_messages.extend(tail_msgs);
|
||||
} else {
|
||||
// No prior compression — load all messages (existing behavior)
|
||||
let messages = storage.load_messages(&id.to_string(), 0).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
|
||||
|
||||
chat_messages = messages.into_iter().map(|m| {
|
||||
ChatMessage {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_calls: m.tool_calls
|
||||
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||
.filter(|v| !v.is_empty()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
repair_tool_call_chains(&mut chat_messages);
|
||||
}
|
||||
|
||||
// seq_counter from actual DB max
|
||||
let max_seq = storage
|
||||
.get_max_message_seq(&id.to_string())
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let seq_counter = max_seq + 1;
|
||||
let total_message_count = session_meta.message_count;
|
||||
|
||||
Ok(Self {
|
||||
id: id.clone(),
|
||||
title: session_meta.title,
|
||||
created_at: session_meta.created_at,
|
||||
last_active_at: session_meta.last_active_at,
|
||||
message_count: session_meta.message_count,
|
||||
total_message_count,
|
||||
messages: chat_messages,
|
||||
seq_counter,
|
||||
provider_config: provider_config.clone(),
|
||||
provider: provider.clone(),
|
||||
tools,
|
||||
compressor,
|
||||
storage: Some(storage),
|
||||
routing_info: session_meta.routing_info.unwrap_or_default(),
|
||||
last_consolidated_at: session_meta.last_consolidated_at,
|
||||
last_compressed_message_at: session_meta.last_compressed_message_at,
|
||||
memory_manager,
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add src/session/session.rs
|
||||
git commit -m "feat(session): incremental recovery from storage using compressed timeline"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: 系统提示词加历史会话提示
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agent/system_prompt.rs:289-304` (MemorySection)
|
||||
- Modify: `src/agent/system_prompt.rs:16-23` (PromptContext)
|
||||
- Modify: `src/agent/system_prompt.rs:343-358` (build_system_prompt free function)
|
||||
- Modify: `src/session/session.rs:411-426` (build_system_prompt)
|
||||
|
||||
**Step 1: PromptContext 加 has_compressed_history 字段**
|
||||
|
||||
```rust
|
||||
pub struct PromptContext<'a> {
|
||||
pub workspace_dir: &'a Path,
|
||||
pub model_name: &'a str,
|
||||
pub tools: &'a ToolRegistry,
|
||||
pub session_id: Option<&'a str>,
|
||||
pub memory_context: Option<&'a str>,
|
||||
pub has_compressed_history: bool,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 加 HistorySection**
|
||||
|
||||
在 MemorySection 后面添加:
|
||||
```rust
|
||||
pub struct HistorySection;
|
||||
|
||||
impl PromptSection for HistorySection {
|
||||
fn name(&self) -> &str {
|
||||
"history"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
||||
if ctx.has_compressed_history {
|
||||
"## 历史会话\n之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。".to_string()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: 注册到 SystemPromptBuilder::with_defaults**
|
||||
|
||||
在 `with_defaults()` 的 sections vec 中 `Box::new(MemorySection)` 后加 `Box::new(HistorySection)`。
|
||||
|
||||
**Step 4: 更新 build_system_prompt 签名和调用**
|
||||
|
||||
```rust
|
||||
pub fn build_system_prompt(
|
||||
workspace_dir: &Path,
|
||||
model_name: &str,
|
||||
tools: &ToolRegistry,
|
||||
session_id: Option<&str>,
|
||||
memory_context: Option<&str>,
|
||||
has_compressed_history: bool,
|
||||
) -> String {
|
||||
let ctx = PromptContext {
|
||||
workspace_dir,
|
||||
model_name,
|
||||
tools,
|
||||
session_id,
|
||||
memory_context,
|
||||
has_compressed_history,
|
||||
};
|
||||
SystemPromptBuilder::with_defaults().build(&ctx)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 5: 更新 Session::build_system_prompt**
|
||||
|
||||
```rust
|
||||
pub fn build_system_prompt(&self, skills_prompt: &str, memory_context: Option<&str>) -> String {
|
||||
let base_prompt = build_system_prompt(
|
||||
&self.provider_config.workspace_dir,
|
||||
&self.provider_config.model_id,
|
||||
&self.tools,
|
||||
Some(&self.id.to_string()),
|
||||
memory_context,
|
||||
self.last_compressed_message_at.is_some(),
|
||||
);
|
||||
|
||||
if skills_prompt.trim().is_empty() {
|
||||
base_prompt
|
||||
} else {
|
||||
format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 6: 更新所有其他 build_system_prompt 调用方**
|
||||
|
||||
搜索 `build_system_prompt(` 的所有调用位置,每个都要加 `false` 参数。主要有 `agent/agent_loop.rs` 中的两个调用。
|
||||
|
||||
**Step 7: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 8: Commit**
|
||||
|
||||
```bash
|
||||
git add src/agent/system_prompt.rs src/session/session.rs src/agent/agent_loop.rs
|
||||
git commit -m "feat(system-prompt): add history section for archived conversation context"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: handle_message 和 /compact 记录压缩标记
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/session/session.rs:1339-1355` (handle_message 压缩后)
|
||||
- Modify: `src/session/session.rs:1372-1376` (handle_message 溢出重试)
|
||||
- Modify: `src/session/session.rs:851-872` (/compact 命令)
|
||||
|
||||
**Step 1: handle_message 正常流**
|
||||
|
||||
在 `compress_if_needed(history).await?` 之后(line 1346),改为:
|
||||
```rust
|
||||
let result = session_guard.compressor
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
if result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
if let Err(e) = session_guard.persist_session_meta().await {
|
||||
tracing::warn!(error = %e, "Failed to persist compressed message marker");
|
||||
}
|
||||
}
|
||||
let mut history = result.history;
|
||||
```
|
||||
|
||||
同时删除后面(line 1350-1355)单独的 `persist_session_meta` 调用(现在已合入上面的逻辑)。
|
||||
|
||||
**Step 2: handle_message 溢出重试流**
|
||||
|
||||
```rust
|
||||
let raw = session_guard.get_history().to_vec();
|
||||
let result = session_guard.compressor.compress_if_needed(raw).await?;
|
||||
if result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
let _ = session_guard.persist_session_meta().await;
|
||||
}
|
||||
let mut retry = result.history;
|
||||
retry.insert(0, ChatMessage::system(system_prompt));
|
||||
agent.process(retry).await?
|
||||
```
|
||||
|
||||
**Step 3: /compact 命令**
|
||||
|
||||
```rust
|
||||
let result = session_guard.compressor
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
let compressed_count = result.history.len();
|
||||
if result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
let _ = session_guard.persist_session_meta().await;
|
||||
}
|
||||
session_guard.clear_history();
|
||||
for msg in result.history {
|
||||
session_guard.add_message(msg, false).await
|
||||
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||
}
|
||||
```
|
||||
|
||||
同时确认 `compress_if_needed` 的 import 正常(已在 scope 中)。
|
||||
|
||||
**Step 4: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add src/session/session.rs
|
||||
git commit -m "feat(session): record last_compressed_message_at after compression"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 8: 全局编译和测试
|
||||
|
||||
**Step 1: 全局编译**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
修复所有编译错误,确保全部文件一致。
|
||||
|
||||
**Step 2: 运行单元测试**
|
||||
|
||||
```bash
|
||||
cargo test --lib 2>&1
|
||||
```
|
||||
|
||||
**Step 3: 测试通过后 commit**
|
||||
|
||||
```bash
|
||||
git add -A
|
||||
git commit -m "chore: fix remaining compilation and test issues for incremental recovery"
|
||||
```
|
||||
|
||||
**Step 4: 运行 lint**
|
||||
|
||||
```bash
|
||||
cargo clippy --lib 2>&1 | head -50
|
||||
```
|
||||
|
||||
修复任何 warning。
|
||||
|
||||
---
|
||||
|
||||
### Task 9: 验证 & 提交设计文档
|
||||
|
||||
**Step 1: 最终验证**
|
||||
|
||||
```bash
|
||||
cargo test --lib 2>&1
|
||||
```
|
||||
|
||||
**Step 2: Commit 设计文档**
|
||||
|
||||
```bash
|
||||
git add docs/plans/2026-05-10-incremental-session-recovery-design.md
|
||||
git commit -m "docs: add incremental session recovery design doc"
|
||||
```
|
||||
@ -383,7 +383,7 @@ impl AgentLoop {
|
||||
// Build and inject system prompt if not present
|
||||
let has_system = messages.first().is_some_and(|m| m.role == "system");
|
||||
if !has_system {
|
||||
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None);
|
||||
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false);
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
||||
messages.insert(0, ChatMessage::system(system_prompt));
|
||||
|
||||
@ -59,7 +59,7 @@ impl Default for ContextCompressionConfig {
|
||||
pub struct ContextCompressor {
|
||||
config: ContextCompressionConfig,
|
||||
context_window: usize,
|
||||
/// Threshold ratio to trigger compression (50% of context window)
|
||||
/// Threshold ratio to trigger compression (70% of context window)
|
||||
threshold_ratio: f64,
|
||||
/// Shared LLM provider for summarization
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
@ -70,6 +70,12 @@ pub struct ContextCompressor {
|
||||
session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of context compression.
|
||||
pub struct CompressionResult {
|
||||
pub history: Vec<ChatMessage>,
|
||||
pub created_timelines: bool,
|
||||
}
|
||||
|
||||
impl ContextCompressor {
|
||||
/// Create a new compressor with the given provider, context window size, and memory manager.
|
||||
pub fn new(
|
||||
@ -80,7 +86,7 @@ impl ContextCompressor {
|
||||
Self {
|
||||
config: ContextCompressionConfig::default(),
|
||||
context_window,
|
||||
threshold_ratio: 0.5,
|
||||
threshold_ratio: 0.7,
|
||||
provider,
|
||||
memory,
|
||||
session_id: None,
|
||||
@ -97,7 +103,7 @@ impl ContextCompressor {
|
||||
Self {
|
||||
config,
|
||||
context_window,
|
||||
threshold_ratio: 0.5,
|
||||
threshold_ratio: 0.7,
|
||||
provider,
|
||||
memory,
|
||||
session_id: None,
|
||||
@ -173,11 +179,11 @@ impl ContextCompressor {
|
||||
pub async fn compress_if_needed(
|
||||
&self,
|
||||
mut history: Vec<ChatMessage>,
|
||||
) -> Result<Vec<ChatMessage>, AgentError> {
|
||||
) -> Result<CompressionResult, AgentError> {
|
||||
// Check if compression is needed
|
||||
let tokens = estimate_tokens(&history);
|
||||
if tokens <= self.threshold() {
|
||||
return Ok(history);
|
||||
return Ok(CompressionResult { history, created_timelines: false });
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@ -200,11 +206,12 @@ impl ContextCompressor {
|
||||
);
|
||||
}
|
||||
if tokens_after <= self.threshold() {
|
||||
return Ok(history);
|
||||
return Ok(CompressionResult { history, created_timelines: false });
|
||||
}
|
||||
|
||||
// LLM summarization pass
|
||||
let mut current_history = history;
|
||||
let mut created_timelines = false;
|
||||
for pass in 0..self.config.max_passes {
|
||||
let tokens = estimate_tokens(¤t_history);
|
||||
if tokens <= self.threshold() {
|
||||
@ -221,6 +228,7 @@ impl ContextCompressor {
|
||||
match self.compress_once(¤t_history).await {
|
||||
Ok(Some(compressed)) => {
|
||||
current_history = compressed;
|
||||
created_timelines = true;
|
||||
}
|
||||
Ok(None) => {
|
||||
// No more compressible content
|
||||
@ -270,7 +278,7 @@ impl ContextCompressor {
|
||||
"Context compression completed"
|
||||
);
|
||||
|
||||
Ok(current_history)
|
||||
Ok(CompressionResult { history: current_history, created_timelines })
|
||||
}
|
||||
|
||||
/// Try to extract the actual context token limit from an LLM error message.
|
||||
@ -623,7 +631,7 @@ mod tests {
|
||||
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap();
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
|
||||
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
|
||||
assert!(
|
||||
@ -677,7 +685,7 @@ mod tests {
|
||||
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap();
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
|
||||
// B2A: "Q1" must appear exactly once
|
||||
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
|
||||
@ -721,7 +729,7 @@ mod tests {
|
||||
ChatMessage::tool("t3", "bash", &big),
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap();
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
|
||||
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
|
||||
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
|
||||
|
||||
@ -21,6 +21,8 @@ pub struct PromptContext<'a> {
|
||||
pub session_id: Option<&'a str>,
|
||||
/// Pre-fetched memory context string to inject.
|
||||
pub memory_context: Option<&'a str>,
|
||||
/// Whether this session has compressed history available via timeline_recall.
|
||||
pub has_compressed_history: bool,
|
||||
}
|
||||
|
||||
/// Trait for system prompt sections.
|
||||
@ -46,6 +48,7 @@ impl SystemPromptBuilder {
|
||||
Box::new(WorkspaceSection),
|
||||
Box::new(UserProfileSection),
|
||||
Box::new(MemorySection),
|
||||
Box::new(HistorySection),
|
||||
Box::new(DateTimeSection),
|
||||
Box::new(RuntimeSection),
|
||||
Box::new(CrossChannelSection),
|
||||
@ -256,9 +259,16 @@ impl PromptSection for CrossChannelSection {
|
||||
|
||||
### chat_manager 工具
|
||||
管理会话和查看消息。参数:
|
||||
- action = "list_sessions" — 列出最近活跃的会话
|
||||
- action = "list_sessions" — 列出全部会话,支持通过 offset/count 翻页
|
||||
- action = "list_channels" — 列出所有可用渠道
|
||||
- action = "list_messages" — 查看指定 session 的最近消息,需提供 session_id 和 count"#,
|
||||
- action = "list_messages" — 查看指定 session 的历史消息,支持以下参数:
|
||||
- session_id (必填): 会话 ID
|
||||
- count (可选): 返回数量,默认 20,最大 100
|
||||
- offset (可选): 跳过前 N 条,用于翻页查看更早历史,默认 0
|
||||
- before_time (可选): Unix 时间戳(秒),只返回该时间之前的消息
|
||||
- after_time (可选): Unix 时间戳(秒),只返回该时间之后的消息
|
||||
|
||||
当用户要求回顾历史、查找之前的消息、或你记不清之前的对话内容时,可以使用此工具的 list_messages 动作,通过调整 offset 或指定时间范围来查询具体的历史消息。"#,
|
||||
session_line
|
||||
)
|
||||
}
|
||||
@ -303,6 +313,23 @@ impl PromptSection for MemorySection {
|
||||
}
|
||||
}
|
||||
|
||||
/// Prompt agent to use timeline_recall if compressed history exists.
|
||||
pub struct HistorySection;
|
||||
|
||||
impl PromptSection for HistorySection {
|
||||
fn name(&self) -> &str {
|
||||
"history"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
||||
if ctx.has_compressed_history {
|
||||
"## 历史会话\n之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。".to_string()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Helper Functions ===
|
||||
|
||||
/// Get user config directory (~/.picobot/).
|
||||
@ -346,6 +373,7 @@ pub fn build_system_prompt(
|
||||
tools: &ToolRegistry,
|
||||
session_id: Option<&str>,
|
||||
memory_context: Option<&str>,
|
||||
has_compressed_history: bool,
|
||||
) -> String {
|
||||
let ctx = PromptContext {
|
||||
workspace_dir,
|
||||
@ -353,6 +381,7 @@ pub fn build_system_prompt(
|
||||
tools,
|
||||
session_id,
|
||||
memory_context,
|
||||
has_compressed_history,
|
||||
};
|
||||
SystemPromptBuilder::with_defaults().build(&ctx)
|
||||
}
|
||||
@ -373,6 +402,7 @@ mod tests {
|
||||
tools: &tools,
|
||||
session_id: None,
|
||||
memory_context: None,
|
||||
has_compressed_history: false,
|
||||
};
|
||||
|
||||
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
||||
@ -402,7 +432,7 @@ mod tests {
|
||||
let temp_dir = std::env::temp_dir();
|
||||
let tools = ToolRegistry::new();
|
||||
|
||||
let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None, None);
|
||||
let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None, None, false);
|
||||
|
||||
assert!(!prompt.is_empty());
|
||||
assert!(prompt.contains("test-model"));
|
||||
@ -419,6 +449,7 @@ mod tests {
|
||||
tools: &tools,
|
||||
session_id: None,
|
||||
memory_context: Some("- user_pref: Prefers Rust"),
|
||||
has_compressed_history: false,
|
||||
};
|
||||
|
||||
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
||||
@ -437,6 +468,7 @@ mod tests {
|
||||
tools: &tools,
|
||||
session_id: None,
|
||||
memory_context: None,
|
||||
has_compressed_history: false,
|
||||
};
|
||||
|
||||
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
||||
|
||||
@ -41,14 +41,11 @@ impl GatewayState {
|
||||
// Override workspace_dir with the ensured path
|
||||
provider_config.workspace_dir = workspace_path.clone();
|
||||
|
||||
// Session TTL from config (default 4 hours)
|
||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||
|
||||
// Initialize Storage
|
||||
let db_path = if let Some(ref path) = config.gateway.session_db_path {
|
||||
std::path::PathBuf::from(path)
|
||||
} else {
|
||||
workspace_path.join(".picobot_sessions.db")
|
||||
workspace_path.join("picobot.db")
|
||||
};
|
||||
let storage = Arc::new(
|
||||
crate::storage::Storage::new(&db_path).await
|
||||
@ -79,7 +76,6 @@ impl GatewayState {
|
||||
|
||||
// Create SessionManager with bus injection
|
||||
let session_manager = SessionManager::new(
|
||||
session_ttl_hours,
|
||||
provider_config.clone(),
|
||||
storage.clone(),
|
||||
bus.clone(),
|
||||
@ -87,11 +83,6 @@ impl GatewayState {
|
||||
)?;
|
||||
let session_manager = Arc::new(session_manager);
|
||||
|
||||
// Start background cleanup task (default 60 minutes)
|
||||
let cleanup_interval = config.gateway.cleanup_interval_minutes.unwrap_or(60);
|
||||
session_manager.clone().start_cleanup_task(cleanup_interval);
|
||||
tracing::info!("Session cleanup task started (interval: {} min)", cleanup_interval);
|
||||
|
||||
// Create ChannelManager and init channels
|
||||
let cli_chat_channel = Arc::new(CliChatChannel::new());
|
||||
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
@ -70,6 +69,7 @@ pub struct Session {
|
||||
/// Timestamp (Unix ms) of the last consolidation.
|
||||
/// Messages before this time have been compressed into memory.
|
||||
pub last_consolidated_at: Option<i64>,
|
||||
pub last_compressed_message_at: Option<i64>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
}
|
||||
|
||||
@ -116,6 +116,7 @@ impl Session {
|
||||
storage,
|
||||
routing_info,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
memory_manager,
|
||||
})
|
||||
}
|
||||
@ -131,9 +132,6 @@ impl Session {
|
||||
let session_meta = storage.get_session(&id.to_string()).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
||||
|
||||
let messages = storage.load_messages(&id.to_string(), 0).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
|
||||
|
||||
let mut provider_box = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||
provider_box.set_storage(storage.clone());
|
||||
@ -147,27 +145,92 @@ impl Session {
|
||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
|
||||
compressor.set_session_id(Some(id.to_string()));
|
||||
|
||||
// Convert MessageMeta to ChatMessage, then repair damaged tool call chains
|
||||
let mut chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
|
||||
ChatMessage {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_calls: m.tool_calls
|
||||
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||
.filter(|v| !v.is_empty()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
let mut chat_messages: Vec<ChatMessage> = Vec::new();
|
||||
|
||||
if let Some(after_ts) = session_meta.last_compressed_message_at {
|
||||
// Load last 4 timelines to detect if there are more than 3
|
||||
let timelines = storage
|
||||
.load_session_timelines(&id.to_string(), 4)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!(error = %e, "Failed to load session timelines");
|
||||
Vec::new()
|
||||
});
|
||||
|
||||
let has_more_timelines = timelines.len() > 3;
|
||||
|
||||
if has_more_timelines {
|
||||
chat_messages.push(ChatMessage::user(
|
||||
"[Earlier conversation summaries exist. \
|
||||
Use `timeline_recall` to search if needed.]"
|
||||
));
|
||||
}
|
||||
}).collect();
|
||||
|
||||
repair_tool_call_chains(&mut chat_messages);
|
||||
// Insert latest 3 timelines as context (reversed: oldest first)
|
||||
for tl in timelines.iter().take(3).rev() {
|
||||
chat_messages.push(ChatMessage::user(format!(
|
||||
"[Previous Context]\n{}", tl.content
|
||||
)));
|
||||
}
|
||||
|
||||
let seq_counter = chat_messages.len() as i64 + 1;
|
||||
let total_message_count = chat_messages.len() as i64;
|
||||
// Load raw messages after compressed timestamp
|
||||
let tail = storage
|
||||
.load_messages_after_timestamp(&id.to_string(), after_ts)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!(error = %e, "Failed to load messages after timestamp");
|
||||
Vec::new()
|
||||
});
|
||||
|
||||
let mut tail_msgs: Vec<ChatMessage> = tail.into_iter().map(|m| {
|
||||
ChatMessage {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_calls: m.tool_calls
|
||||
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||
.filter(|v| !v.is_empty()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
repair_tool_call_chains(&mut tail_msgs);
|
||||
chat_messages.extend(tail_msgs);
|
||||
} else {
|
||||
// No prior compression — load all messages (existing behavior)
|
||||
let messages = storage.load_messages(&id.to_string(), 0).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
|
||||
|
||||
chat_messages = messages.into_iter().map(|m| {
|
||||
ChatMessage {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_calls: m.tool_calls
|
||||
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||
.filter(|v| !v.is_empty()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
repair_tool_call_chains(&mut chat_messages);
|
||||
}
|
||||
|
||||
// seq_counter from actual DB max
|
||||
let max_seq = storage
|
||||
.get_max_message_seq(&id.to_string())
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let seq_counter = max_seq + 1;
|
||||
let total_message_count = session_meta.message_count;
|
||||
|
||||
Ok(Self {
|
||||
id: id.clone(),
|
||||
@ -185,6 +248,7 @@ impl Session {
|
||||
storage: Some(storage),
|
||||
routing_info: session_meta.routing_info.unwrap_or_default(),
|
||||
last_consolidated_at: session_meta.last_consolidated_at,
|
||||
last_compressed_message_at: session_meta.last_compressed_message_at,
|
||||
memory_manager,
|
||||
})
|
||||
}
|
||||
@ -313,6 +377,7 @@ impl Session {
|
||||
},
|
||||
deleted_at: None,
|
||||
last_consolidated_at: self.last_consolidated_at,
|
||||
last_compressed_message_at: self.last_compressed_message_at,
|
||||
};
|
||||
storage.upsert_session(&meta).await?;
|
||||
}
|
||||
@ -416,6 +481,7 @@ impl Session {
|
||||
&self.tools,
|
||||
Some(&self.id.to_string()),
|
||||
memory_context,
|
||||
self.last_compressed_message_at.is_some(),
|
||||
);
|
||||
|
||||
if skills_prompt.trim().is_empty() {
|
||||
@ -663,8 +729,6 @@ pub struct SessionManager {
|
||||
struct SessionManagerInner {
|
||||
/// Sessions keyed by UnifiedSessionId.to_string()
|
||||
sessions: HashMap<String, Arc<Mutex<Session>>>,
|
||||
session_timestamps: HashMap<String, Instant>,
|
||||
session_ttl: Duration,
|
||||
/// Current active session per channel:chat_id
|
||||
current_sessions: HashMap<String, String>,
|
||||
}
|
||||
@ -741,7 +805,6 @@ pub static SLASH_COMMANDS: &[SlashCommand] = &[
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(
|
||||
session_ttl_hours: u64,
|
||||
provider_config: LLMProviderConfig,
|
||||
storage: Arc<Storage>,
|
||||
bus: Arc<MessageBus>,
|
||||
@ -756,8 +819,6 @@ impl SessionManager {
|
||||
Ok(Self {
|
||||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
||||
sessions: HashMap::new(),
|
||||
session_timestamps: HashMap::new(),
|
||||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
||||
current_sessions: HashMap::new(),
|
||||
})),
|
||||
provider_config,
|
||||
@ -780,42 +841,6 @@ impl SessionManager {
|
||||
self.tools.clone()
|
||||
}
|
||||
|
||||
/// 启动后台 TTL 清理任务
|
||||
pub fn start_cleanup_task(self: Arc<Self>, interval_mins: u64) {
|
||||
let cleanup_interval = Duration::from_secs(interval_mins * 60);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(cleanup_interval).await;
|
||||
self.run_cleanup().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// 执行一次 TTL 清理:释放内存中过期的 session,Storage 记录保留
|
||||
async fn run_cleanup(&self) {
|
||||
let inner = self.inner.lock().await;
|
||||
let now = Instant::now();
|
||||
let ttl = inner.session_ttl;
|
||||
|
||||
let expired: Vec<String> = inner
|
||||
.session_timestamps
|
||||
.iter()
|
||||
.filter(|(_, last_touch)| now.duration_since(**last_touch) > ttl)
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect();
|
||||
|
||||
drop(inner);
|
||||
|
||||
if !expired.is_empty() {
|
||||
let mut inner = self.inner.lock().await;
|
||||
for id in &expired {
|
||||
inner.sessions.remove(id);
|
||||
inner.session_timestamps.remove(id);
|
||||
}
|
||||
tracing::debug!(count = expired.len(), "Cleaned up expired sessions");
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取所有可用的斜杠命令
|
||||
pub fn get_slash_commands(&self) -> &[SlashCommand] {
|
||||
SLASH_COMMANDS
|
||||
@ -854,12 +879,18 @@ impl SessionManager {
|
||||
let mut session_guard = session.lock().await;
|
||||
let original_count = session_guard.get_history().len();
|
||||
let history = session_guard.get_history().to_vec();
|
||||
let compressed = session_guard.compressor
|
||||
let result = session_guard.compressor
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
let compressed_count = compressed.len();
|
||||
let compressed_count = result.history.len();
|
||||
if result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
if let Err(e) = session_guard.persist_session_meta().await {
|
||||
tracing::warn!(error = %e, "Failed to persist compression marker after /compact");
|
||||
}
|
||||
}
|
||||
session_guard.clear_history();
|
||||
for msg in compressed {
|
||||
for msg in result.history {
|
||||
session_guard.add_message(msg, false).await
|
||||
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||
}
|
||||
@ -980,6 +1011,7 @@ impl SessionManager {
|
||||
routing_info: if routing_info.is_empty() { None } else { Some(routing_info.clone()) },
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
self.storage.upsert_session(&meta).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
|
||||
@ -997,7 +1029,6 @@ impl SessionManager {
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
let inner = &mut *self.inner.lock().await;
|
||||
inner.sessions.insert(session_id_str.clone(), arc.clone());
|
||||
inner.session_timestamps.insert(session_id_str.clone(), Instant::now());
|
||||
// Set as current session for this channel:chat_id
|
||||
let chat_scope = format!("{}:{}", channel, chat_id);
|
||||
inner.current_sessions.insert(chat_scope, session_id_str);
|
||||
@ -1010,7 +1041,6 @@ impl SessionManager {
|
||||
let inner = &mut *self.inner.lock().await;
|
||||
|
||||
if let Some(session) = inner.sessions.get(&session_id_str) {
|
||||
inner.session_timestamps.insert(session_id_str, Instant::now());
|
||||
return Ok(session.clone());
|
||||
}
|
||||
|
||||
@ -1028,7 +1058,6 @@ impl SessionManager {
|
||||
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
inner.sessions.insert(session_id_str.clone(), arc.clone());
|
||||
inner.session_timestamps.insert(session_id_str.clone(), Instant::now());
|
||||
// Set as current session
|
||||
let chat_scope = format!("{}:{}", unified_id.channel, unified_id.chat_id);
|
||||
inner.current_sessions.insert(chat_scope, session_id_str);
|
||||
@ -1052,7 +1081,6 @@ impl SessionManager {
|
||||
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
inner.sessions.insert(session_id_str.clone(), arc.clone());
|
||||
inner.session_timestamps.insert(session_id_str.clone(), Instant::now());
|
||||
// Set as current session
|
||||
let chat_scope = format!("{}:{}", unified_id.channel, unified_id.chat_id);
|
||||
inner.current_sessions.insert(chat_scope, session_id_str);
|
||||
@ -1135,7 +1163,6 @@ impl SessionManager {
|
||||
// Remove from memory and current sessions
|
||||
let mut inner = self.inner.lock().await;
|
||||
inner.sessions.remove(&session_id_str);
|
||||
inner.session_timestamps.remove(&session_id_str);
|
||||
let chat_scope = format!("{}:{}", session_id.channel, session_id.chat_id);
|
||||
inner.current_sessions.remove(&chat_scope);
|
||||
|
||||
@ -1188,8 +1215,7 @@ impl SessionManager {
|
||||
}
|
||||
}
|
||||
|
||||
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
|
||||
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
|
||||
match self.storage.find_most_recent_session(channel, chat_id).await {
|
||||
Ok(Some(meta)) => Ok(UnifiedSessionId::new(channel, chat_id, &meta.dialog_id)),
|
||||
_ => {
|
||||
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
|
||||
@ -1341,13 +1367,17 @@ impl SessionManager {
|
||||
// in context compression (system prompt is dynamic and should not be persisted).
|
||||
let system_prompt = session_guard.build_system_prompt(&skills_prompt, memory_context.as_deref());
|
||||
|
||||
let mut history = session_guard.compressor
|
||||
let result = session_guard.compressor
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
if result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
}
|
||||
let mut history = result.history;
|
||||
|
||||
history.insert(0, ChatMessage::system(system_prompt.clone()));
|
||||
|
||||
// Advance consolidation pointer — future compressions skip already-processed messages
|
||||
// Persist consolidation state
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
session_guard.last_consolidated_at = Some(now);
|
||||
if let Err(e) = session_guard.persist_session_meta().await {
|
||||
@ -1371,7 +1401,14 @@ impl SessionManager {
|
||||
);
|
||||
session_guard.compressor.set_context_window(new_window);
|
||||
let raw = session_guard.get_history().to_vec();
|
||||
let mut retry = session_guard.compressor.compress_if_needed(raw).await?;
|
||||
let retry_result = session_guard.compressor.compress_if_needed(raw).await?;
|
||||
if retry_result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
if let Err(e) = session_guard.persist_session_meta().await {
|
||||
tracing::warn!(error = %e, "Failed to persist compression marker on retry");
|
||||
}
|
||||
}
|
||||
let mut retry = retry_result.history;
|
||||
retry.insert(0, ChatMessage::system(system_prompt));
|
||||
agent.process(retry).await?
|
||||
}
|
||||
@ -1488,7 +1525,8 @@ impl SessionManager {
|
||||
// in context compression (system prompt is dynamic and should not be persisted).
|
||||
let mut history = session_guard.compressor
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
.await?
|
||||
.history;
|
||||
|
||||
history.insert(0, ChatMessage::system(full_system_prompt));
|
||||
|
||||
|
||||
@ -250,6 +250,30 @@ impl super::Storage {
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Load timeline entries for a specific session.
|
||||
pub async fn load_session_timelines(
|
||||
&self,
|
||||
session_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, key, content, category, importance,
|
||||
session_id, created_at, updated_at
|
||||
FROM memories
|
||||
WHERE session_id = ? AND category = 'timeline'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(limit as i64)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
parse_memory_rows(&rows)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_memory_rows(
|
||||
|
||||
@ -42,6 +42,7 @@ impl Storage {
|
||||
routing_info TEXT,
|
||||
deleted_at INTEGER,
|
||||
last_consolidated_at INTEGER,
|
||||
last_compressed_message_at INTEGER,
|
||||
UNIQUE(channel, chat_id, dialog_id)
|
||||
)
|
||||
"#,
|
||||
@ -179,6 +180,16 @@ impl Storage {
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Migration: add last_compressed_message_at column if not exists
|
||||
sqlx::query(
|
||||
r#"
|
||||
ALTER TABLE sessions ADD COLUMN last_compressed_message_at INTEGER
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS llm_calls (
|
||||
@ -314,15 +325,16 @@ impl Storage {
|
||||
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
title = excluded.title,
|
||||
last_active_at = excluded.last_active_at,
|
||||
message_count = excluded.message_count,
|
||||
routing_info = excluded.routing_info,
|
||||
deleted_at = excluded.deleted_at,
|
||||
last_consolidated_at = excluded.last_consolidated_at
|
||||
last_consolidated_at = excluded.last_consolidated_at,
|
||||
last_compressed_message_at = excluded.last_compressed_message_at
|
||||
"#,
|
||||
)
|
||||
.bind(&meta.id)
|
||||
@ -336,6 +348,7 @@ impl Storage {
|
||||
.bind(&meta.routing_info)
|
||||
.bind(meta.deleted_at)
|
||||
.bind(meta.last_consolidated_at)
|
||||
.bind(meta.last_compressed_message_at)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
@ -345,7 +358,7 @@ impl Storage {
|
||||
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
||||
"#,
|
||||
)
|
||||
@ -366,6 +379,7 @@ impl Storage {
|
||||
routing_info: row.get("routing_info"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
last_consolidated_at: row.get("last_consolidated_at"),
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
})
|
||||
}
|
||||
|
||||
@ -377,7 +391,7 @@ impl Storage {
|
||||
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
FROM sessions
|
||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||
ORDER BY last_active_at DESC
|
||||
@ -404,6 +418,7 @@ impl Storage {
|
||||
routing_info: row.get("routing_info"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
last_consolidated_at: row.get("last_consolidated_at"),
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
@ -442,25 +457,22 @@ impl Storage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn find_active_session(
|
||||
pub async fn find_most_recent_session(
|
||||
&self,
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
ttl_millis: i64,
|
||||
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
|
||||
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_millis;
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
FROM sessions
|
||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND last_active_at > ?
|
||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||
ORDER BY last_active_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.bind(channel)
|
||||
.bind(chat_id)
|
||||
.bind(cutoff)
|
||||
.fetch_optional(self.pool())
|
||||
.await?;
|
||||
|
||||
@ -477,6 +489,7 @@ impl Storage {
|
||||
routing_info: row.get("routing_info"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
last_consolidated_at: row.get("last_consolidated_at"),
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
})),
|
||||
None => Ok(None),
|
||||
}
|
||||
@ -555,24 +568,79 @@ impl Storage {
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn list_all_active_sessions(
|
||||
pub async fn get_max_message_seq(&self, session_id: &str) -> Result<i64, StorageError> {
|
||||
let row = sqlx::query(
|
||||
"SELECT COALESCE(MAX(seq), 0) as max_seq FROM messages WHERE session_id = ?",
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_one(self.pool())
|
||||
.await?;
|
||||
Ok(row.get::<i64, _>("max_seq"))
|
||||
}
|
||||
|
||||
pub async fn load_messages_after_timestamp(
|
||||
&self,
|
||||
limit: i64,
|
||||
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
||||
session_id: &str,
|
||||
after_ts: i64,
|
||||
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at
|
||||
FROM sessions
|
||||
WHERE deleted_at IS NULL
|
||||
ORDER BY last_active_at DESC
|
||||
LIMIT ?
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND created_at > ?
|
||||
ORDER BY seq ASC
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(session_id)
|
||||
.bind(after_ts)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|row| crate::storage::message::MessageMeta {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
tool_calls: row.get("tool_calls"),
|
||||
source: row.get("source"),
|
||||
created_at: row.get("created_at"),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn query_sessions_range(
|
||||
&self,
|
||||
offset: i64,
|
||||
limit: i64,
|
||||
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
|
||||
let count_row = sqlx::query(
|
||||
"SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL",
|
||||
)
|
||||
.fetch_one(self.pool())
|
||||
.await?;
|
||||
let total: i64 = count_row.get("total");
|
||||
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
FROM sessions
|
||||
WHERE deleted_at IS NULL
|
||||
ORDER BY last_active_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
let sessions: Vec<_> = rows
|
||||
.into_iter()
|
||||
.map(|row| crate::storage::session::SessionMeta {
|
||||
id: row.get("id"),
|
||||
@ -586,8 +654,11 @@ impl Storage {
|
||||
routing_info: row.get("routing_info"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
last_consolidated_at: row.get("last_consolidated_at"),
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
})
|
||||
.collect())
|
||||
.collect();
|
||||
|
||||
Ok((sessions, total))
|
||||
}
|
||||
|
||||
pub async fn list_recent_messages(
|
||||
@ -629,6 +700,77 @@ impl Storage {
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub async fn query_messages_range(
|
||||
&self,
|
||||
session_id: &str,
|
||||
before_time: Option<i64>,
|
||||
after_time: Option<i64>,
|
||||
offset: i64,
|
||||
limit: i64,
|
||||
) -> Result<(Vec<crate::storage::message::MessageMeta>, i64), StorageError> {
|
||||
let mut where_extra = String::new();
|
||||
if before_time.is_some() {
|
||||
where_extra.push_str(" AND created_at < ?");
|
||||
}
|
||||
if after_time.is_some() {
|
||||
where_extra.push_str(" AND created_at > ?");
|
||||
}
|
||||
|
||||
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra);
|
||||
let select_sql = format!(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ?{}
|
||||
ORDER BY seq ASC
|
||||
LIMIT ? OFFSET ?
|
||||
"#,
|
||||
where_extra
|
||||
);
|
||||
|
||||
let mut count_query = sqlx::query(&count_sql).bind(session_id);
|
||||
if let Some(bt) = before_time {
|
||||
count_query = count_query.bind(bt);
|
||||
}
|
||||
if let Some(at) = after_time {
|
||||
count_query = count_query.bind(at);
|
||||
}
|
||||
let count_row = count_query.fetch_one(self.pool()).await?;
|
||||
let total: i64 = count_row.get("total");
|
||||
|
||||
let mut select_query = sqlx::query(&select_sql).bind(session_id);
|
||||
if let Some(bt) = before_time {
|
||||
select_query = select_query.bind(bt);
|
||||
}
|
||||
if let Some(at) = after_time {
|
||||
select_query = select_query.bind(at);
|
||||
}
|
||||
let rows = select_query
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
let messages: Vec<_> = rows
|
||||
.into_iter()
|
||||
.map(|row| crate::storage::message::MessageMeta {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
tool_calls: row.get("tool_calls"),
|
||||
source: row.get("source"),
|
||||
created_at: row.get("created_at"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok((messages, total))
|
||||
}
|
||||
|
||||
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
|
||||
.bind(session_id)
|
||||
@ -691,6 +833,7 @@ mod tests {
|
||||
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
@ -726,6 +869,7 @@ mod tests {
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
}
|
||||
@ -752,6 +896,7 @@ mod tests {
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
@ -778,6 +923,7 @@ mod tests {
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&session_meta).await.unwrap();
|
||||
|
||||
@ -819,6 +965,7 @@ mod tests {
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
|
||||
|
||||
@ -13,4 +13,5 @@ pub struct SessionMeta {
|
||||
pub routing_info: Option<String>,
|
||||
pub deleted_at: Option<i64>,
|
||||
pub last_consolidated_at: Option<i64>,
|
||||
pub last_compressed_message_at: Option<i64>,
|
||||
}
|
||||
|
||||
@ -27,8 +27,8 @@ impl Tool for ChatManagerTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"聊天管理工具。可以列出当前活跃的 session、可用的 channel、以及查看指定 session 的最近消息内容。\
|
||||
action 可选值: list_sessions (列出最近活跃会话), list_channels (列出可用渠道), list_messages (查看最近消息)"
|
||||
"聊天管理工具。可以列出全部 session、可用的 channel,以及查看指定 session 的消息内容,支持时间范围筛选和分页翻页。\
|
||||
action 可选值: list_sessions (列出全部会话), list_channels (列出可用渠道), list_messages (查看消息)"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@ -38,7 +38,7 @@ action 可选值: list_sessions (列出最近活跃会话), list_channels (列
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["list_sessions", "list_channels", "list_messages"],
|
||||
"description": "操作类型: list_sessions 列出最近活跃会话, list_channels 列出可用渠道, list_messages 查看指定会话的最近消息"
|
||||
"description": "操作类型: list_sessions 列出全部会话, list_channels 列出可用渠道, list_messages 查看指定会话的消息"
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
@ -46,7 +46,19 @@ action 可选值: list_sessions (列出最近活跃会话), list_channels (列
|
||||
},
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "获取最近消息的数量,仅在 action 为 list_messages 时有效,默认 20"
|
||||
"description": "获取数量,在 action 为 list_sessions 或 list_messages 时有效,默认 20,最大 100"
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "跳过前 N 条(用于翻页),在 action 为 list_sessions 或 list_messages 时有效,默认 0"
|
||||
},
|
||||
"before_time": {
|
||||
"type": "integer",
|
||||
"description": "Unix 时间戳(秒),仅返回此时间之前的消息,仅在 action 为 list_messages 时有效"
|
||||
},
|
||||
"after_time": {
|
||||
"type": "integer",
|
||||
"description": "Unix 时间戳(秒),仅返回此时间之后的消息,仅在 action 为 list_messages 时有效"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
@ -68,7 +80,7 @@ action 可选值: list_sessions (列出最近活跃会话), list_channels (列
|
||||
|
||||
match action {
|
||||
"list_channels" => self.list_channels().await,
|
||||
"list_sessions" => self.list_sessions().await,
|
||||
"list_sessions" => self.list_sessions(&args).await,
|
||||
"list_messages" => self.list_messages(&args).await,
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
@ -92,23 +104,29 @@ impl ChatManagerTool {
|
||||
})
|
||||
}
|
||||
|
||||
async fn list_sessions(&self) -> anyhow::Result<ToolResult> {
|
||||
let sessions = self
|
||||
async fn list_sessions(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let count = args["count"].as_i64().unwrap_or(20).clamp(1, 100);
|
||||
let offset = args["offset"].as_i64().unwrap_or(0).max(0);
|
||||
|
||||
let (sessions, total) = self
|
||||
.storage
|
||||
.list_all_active_sessions(20)
|
||||
.query_sessions_range(offset, count)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to list sessions: {}", e))?;
|
||||
|
||||
if sessions.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: "当前没有活跃的会话".to_string(),
|
||||
output: "当前没有会话".to_string(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let now_ms = chrono::Utc::now().timestamp_millis();
|
||||
let mut output = format!("活跃会话 (共 {} 个):\n", sessions.len());
|
||||
let start_num = offset + 1;
|
||||
let end_num = offset + sessions.len() as i64;
|
||||
|
||||
let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num);
|
||||
|
||||
for s in &sessions {
|
||||
let ago = format_duration_ago(now_ms - s.last_active_at);
|
||||
@ -131,6 +149,10 @@ impl ChatManagerTool {
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: session_id"))?;
|
||||
|
||||
let count = args["count"].as_i64().unwrap_or(20).clamp(1, 100);
|
||||
let offset = args["offset"].as_i64().unwrap_or(0).max(0);
|
||||
|
||||
let before_time = args["before_time"].as_i64().map(|t| t * 1000);
|
||||
let after_time = args["after_time"].as_i64().map(|t| t * 1000);
|
||||
|
||||
let session = self
|
||||
.storage
|
||||
@ -138,15 +160,31 @@ impl ChatManagerTool {
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Session not found: {}", e))?;
|
||||
|
||||
let messages = self
|
||||
let (messages, total) = self
|
||||
.storage
|
||||
.list_recent_messages(session_id, count)
|
||||
.query_messages_range(session_id, before_time, after_time, offset, count)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load messages: {}", e))?;
|
||||
|
||||
let start_num = offset + 1;
|
||||
let end_num = offset + messages.len() as i64;
|
||||
|
||||
let range_desc = if messages.is_empty() {
|
||||
"无消息".to_string()
|
||||
} else {
|
||||
format!("第 {}-{} 条", start_num, end_num)
|
||||
};
|
||||
|
||||
let filter_desc = match (before_time, after_time) {
|
||||
(Some(_), Some(_)) => "(已按时间范围筛选)",
|
||||
(Some(_), None) => "(已按截止时间筛选)",
|
||||
(None, Some(_)) => "(已按起始时间筛选)",
|
||||
(None, None) => "",
|
||||
};
|
||||
|
||||
let mut output = format!(
|
||||
"会话: {} ({})\n--- 最近 {} 条消息 (共 {} 条) ---\n",
|
||||
session_id, session.title, messages.len(), session.message_count
|
||||
"会话: {} ({})\n--- 消息 {} / 共 {} 条 {} ---\n",
|
||||
session_id, session.title, range_desc, total, filter_desc
|
||||
);
|
||||
|
||||
if messages.is_empty() {
|
||||
@ -264,6 +302,7 @@ mod tests {
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
}
|
||||
@ -281,7 +320,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_messages() {
|
||||
async fn test_list_messages_default() {
|
||||
let (storage, _dir) = create_test_storage().await;
|
||||
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
@ -298,6 +337,7 @@ mod tests {
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
|
||||
@ -328,6 +368,120 @@ mod tests {
|
||||
assert!(result.output.contains("消息内容 0"));
|
||||
assert!(result.output.contains("消息内容 2"));
|
||||
assert!(result.output.contains("测试会话"));
|
||||
assert!(result.output.contains("共 3 条"));
|
||||
// Verify ascending order: seq 1 before seq 3
|
||||
let pos_0 = result.output.find("消息内容 0").unwrap();
|
||||
let pos_2 = result.output.find("消息内容 2").unwrap();
|
||||
assert!(pos_0 < pos_2, "Messages should be in ascending order");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_messages_with_pagination() {
|
||||
let (storage, _dir) = create_test_storage().await;
|
||||
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
let session_id = "cli_chat:sid0:dialog0";
|
||||
let meta = crate::storage::session::SessionMeta {
|
||||
id: session_id.to_string(),
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: "sid0".to_string(),
|
||||
dialog_id: "dialog0".to_string(),
|
||||
title: "分页测试".to_string(),
|
||||
created_at: now,
|
||||
last_active_at: now,
|
||||
message_count: 5,
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
|
||||
for i in 0..5 {
|
||||
let msg = crate::storage::message::MessageMeta {
|
||||
id: format!("msg{}", i),
|
||||
session_id: session_id.to_string(),
|
||||
seq: i as i64 + 1,
|
||||
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||
content: format!("消息内容 {}", i),
|
||||
media_refs: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
source: None,
|
||||
created_at: now + i * 1000,
|
||||
};
|
||||
storage.append_message(session_id, &msg).await.unwrap();
|
||||
}
|
||||
|
||||
let tool = ChatManagerTool::new(storage, vec![]);
|
||||
|
||||
// offset=2, count=2 => should return messages 2,3 (seq 3,4)
|
||||
let result = tool
|
||||
.execute(json!({ "action": "list_messages", "session_id": session_id, "offset": 2, "count": 2 }))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("第 3-4 条"));
|
||||
assert!(result.output.contains("消息内容 2"));
|
||||
assert!(result.output.contains("消息内容 3"));
|
||||
assert!(!result.output.contains("消息内容 0"));
|
||||
assert!(result.output.contains("共 5 条"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_messages_with_time_range() {
|
||||
let (storage, _dir) = create_test_storage().await;
|
||||
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
let session_id = "cli_chat:sid0:dialog0";
|
||||
let meta = crate::storage::session::SessionMeta {
|
||||
id: session_id.to_string(),
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: "sid0".to_string(),
|
||||
dialog_id: "dialog0".to_string(),
|
||||
title: "时间范围测试".to_string(),
|
||||
created_at: now,
|
||||
last_active_at: now,
|
||||
message_count: 5,
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
|
||||
for i in 0..5 {
|
||||
let msg = crate::storage::message::MessageMeta {
|
||||
id: format!("msg{}", i),
|
||||
session_id: session_id.to_string(),
|
||||
seq: i as i64 + 1,
|
||||
role: "user".to_string(),
|
||||
content: format!("消息内容 {}", i),
|
||||
media_refs: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
source: None,
|
||||
created_at: now + i * 1000,
|
||||
};
|
||||
storage.append_message(session_id, &msg).await.unwrap();
|
||||
}
|
||||
|
||||
let tool = ChatManagerTool::new(storage, vec![]);
|
||||
|
||||
// after_time: filter to messages after msg1's second boundary
|
||||
let after_ts = now / 1000 + 2;
|
||||
let result = tool
|
||||
.execute(json!({ "action": "list_messages", "session_id": session_id, "after_time": after_ts }))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("已按起始时间筛选"));
|
||||
assert!(!result.output.contains("消息内容 0"));
|
||||
assert!(!result.output.contains("消息内容 1"));
|
||||
assert!(result.output.contains("消息内容 3"));
|
||||
assert!(result.output.contains("消息内容 4"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user