feat(session): 添加逻辑重置功能,优化会话历史管理
This commit is contained in:
parent
eb0f6c0bc7
commit
393d980742
@ -85,6 +85,7 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
|||||||
| `archived_at` | `INTEGER` | 归档时间 | 非空表示会话已归档 |
|
| `archived_at` | `INTEGER` | 归档时间 | 非空表示会话已归档 |
|
||||||
| `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 |
|
| `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 |
|
||||||
| `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 |
|
| `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 |
|
||||||
|
| `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 |
|
||||||
|
|
||||||
索引:
|
索引:
|
||||||
|
|
||||||
@ -172,9 +173,14 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
|||||||
|
|
||||||
### 6.3 读取历史
|
### 6.3 读取历史
|
||||||
|
|
||||||
`load_messages(session_id)` 会按 `seq ASC` 读取整个消息历史,并把 JSON 字段反序列化回 `ChatMessage`。
|
`load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是:
|
||||||
|
|
||||||
因此它恢复的是“逻辑顺序”,而不是简单按创建时间排序。只要 `seq` 连续,重放顺序就稳定。
|
- 只返回 `seq > sessions.reset_cutoff_seq` 的消息
|
||||||
|
- 因此 `/reset` 之后,旧消息仍然保留在数据库中,但不会默认回灌到运行时上下文
|
||||||
|
|
||||||
|
如果需要审计、导出或查看完整历史,应使用全量读取接口 `load_all_messages(session_id)`。
|
||||||
|
|
||||||
|
因此运行态恢复的是“当前活动段的逻辑顺序”,而不是简单按创建时间排序。只要 `seq` 连续,重放顺序就稳定。
|
||||||
|
|
||||||
## 7. 典型时序
|
## 7. 典型时序
|
||||||
|
|
||||||
@ -229,12 +235,24 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
|||||||
|
|
||||||
- 删除该会话在 `messages` 中的所有记录
|
- 删除该会话在 `messages` 中的所有记录
|
||||||
- 将 `sessions.message_count` 重置为 0
|
- 将 `sessions.message_count` 重置为 0
|
||||||
|
- 将 `sessions.reset_cutoff_seq` 重置为 0
|
||||||
- 更新 `updated_at` 和 `last_active_at`
|
- 更新 `updated_at` 和 `last_active_at`
|
||||||
- 保留会话本身
|
- 保留会话本身
|
||||||
|
|
||||||
这适合“保留会话入口,但丢弃聊天内容”的场景。
|
这适合“保留会话入口,但丢弃聊天内容”的场景。
|
||||||
|
|
||||||
### 8.4 删除会话
|
### 8.4 逻辑重置
|
||||||
|
|
||||||
|
`reset_session(session_id)`:
|
||||||
|
|
||||||
|
- 不删除 `messages` 中的任何记录
|
||||||
|
- 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq`
|
||||||
|
- 更新 `updated_at` 和 `last_active_at`
|
||||||
|
- 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息
|
||||||
|
|
||||||
|
这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。
|
||||||
|
|
||||||
|
### 8.5 删除会话
|
||||||
|
|
||||||
`delete_session(session_id)`:
|
`delete_session(session_id)`:
|
||||||
|
|
||||||
@ -276,6 +294,9 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
|||||||
- `sessions.deleted_at`
|
- `sessions.deleted_at`
|
||||||
- 当前查询逻辑兼容软删除
|
- 当前查询逻辑兼容软删除
|
||||||
- 当前删除实现仍然是物理删除
|
- 当前删除实现仍然是物理删除
|
||||||
|
- `sessions.reset_cutoff_seq`
|
||||||
|
- 当前已用于实现 `/reset` 的非破坏性逻辑重置
|
||||||
|
- 只影响默认恢复的活动段,不影响数据库中的全量历史
|
||||||
|
|
||||||
这说明当前 schema 已经为“会话摘要”和“软删除”预留了演进空间,但并未完全落地。
|
这说明当前 schema 已经为“会话摘要”和“软删除”预留了演进空间,但并未完全落地。
|
||||||
|
|
||||||
@ -285,6 +306,7 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
|||||||
|
|
||||||
- 会话查不到:先看 `persistent_session_id` 是否和实际 `channel_name/chat_id` 一致
|
- 会话查不到:先看 `persistent_session_id` 是否和实际 `channel_name/chat_id` 一致
|
||||||
- 重启后没历史:检查 `ensure_chat_loaded()` 调用链,以及数据库文件路径是否正确
|
- 重启后没历史:检查 `ensure_chat_loaded()` 调用链,以及数据库文件路径是否正确
|
||||||
|
- `/reset` 后重启又带回旧上下文:检查 `sessions.reset_cutoff_seq` 是否已写入,以及恢复路径是否走了活动段读取而不是全量读取
|
||||||
- 消息顺序不对:检查 `messages.seq`
|
- 消息顺序不对:检查 `messages.seq`
|
||||||
- 工具调用上下文异常:同时检查 `tool_calls_json` 和 `tool_call_id`
|
- 工具调用上下文异常:同时检查 `tool_calls_json` 和 `tool_call_id`
|
||||||
- 会话列表里看不到记录:检查 `archived_at` 和 `include_archived` 参数
|
- 会话列表里看不到记录:检查 `archived_at` 和 `include_archived` 参数
|
||||||
|
|||||||
@ -48,7 +48,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let mut input = InputHandler::new();
|
let mut input = InputHandler::new();
|
||||||
let mut current_session_id: Option<String> = None;
|
let mut current_session_id: Option<String> = None;
|
||||||
input.write_output("picobot CLI - Commands: /new [title], /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
|
input.write_output("picobot CLI - Commands: /new [title], /reset, /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
|
||||||
|
|
||||||
// Main loop: poll both stdin and WebSocket
|
// Main loop: poll both stdin and WebSocket
|
||||||
loop {
|
loop {
|
||||||
|
|||||||
@ -105,6 +105,19 @@ impl Session {
|
|||||||
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
|
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
|
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||||
|
let len = history.len();
|
||||||
|
history.clear();
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
self.store
|
||||||
|
.reset_session(&self.persistent_session_id(chat_id))
|
||||||
|
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
/// 将消息写入内存与持久化层
|
/// 将消息写入内存与持久化层
|
||||||
pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> {
|
pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> {
|
||||||
let session_id = self.persistent_session_id(chat_id);
|
let session_id = self.persistent_session_id(chat_id);
|
||||||
@ -202,6 +215,32 @@ fn default_tools() -> ToolRegistry {
|
|||||||
registry
|
registry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum InChatCommand {
|
||||||
|
FreshConversation,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
|
||||||
|
match content.trim() {
|
||||||
|
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn handle_in_chat_command(
|
||||||
|
session: &mut Session,
|
||||||
|
chat_id: &str,
|
||||||
|
content: &str,
|
||||||
|
) -> Result<Option<String>, AgentError> {
|
||||||
|
match parse_in_chat_command(content) {
|
||||||
|
Some(InChatCommand::FreshConversation) => {
|
||||||
|
session.reset_chat_context(chat_id)?;
|
||||||
|
Ok(Some("Started a fresh conversation.".to_string()))
|
||||||
|
}
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl SessionManager {
|
impl SessionManager {
|
||||||
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||||
let store = Arc::new(
|
let store = Arc::new(
|
||||||
@ -372,6 +411,10 @@ impl SessionManager {
|
|||||||
session_guard.ensure_persistent_session(chat_id)?;
|
session_guard.ensure_persistent_session(chat_id)?;
|
||||||
session_guard.ensure_chat_loaded(chat_id)?;
|
session_guard.ensure_chat_loaded(chat_id)?;
|
||||||
|
|
||||||
|
if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? {
|
||||||
|
return Ok(command_response);
|
||||||
|
}
|
||||||
|
|
||||||
// 添加用户消息到历史
|
// 添加用户消息到历史
|
||||||
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -419,3 +462,74 @@ impl SessionManager {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
fn test_provider_config() -> LLMProviderConfig {
|
||||||
|
LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: "test".to_string(),
|
||||||
|
base_url: "http://localhost".to_string(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
model_id: "test-model".to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(32),
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
max_tool_iterations: 1,
|
||||||
|
token_limit: 4096,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_in_chat_command_aliases() {
|
||||||
|
assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation));
|
||||||
|
assert_eq!(parse_in_chat_command(" /reset \n"), Some(InChatCommand::FreshConversation));
|
||||||
|
assert_eq!(parse_in_chat_command("/new planning"), None);
|
||||||
|
assert_eq!(parse_in_chat_command("please /reset"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_handle_in_chat_command_resets_active_history_only() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
|
let tools = Arc::new(default_tools());
|
||||||
|
let mut session = Session::new(
|
||||||
|
"feishu".to_string(),
|
||||||
|
test_provider_config(),
|
||||||
|
user_tx,
|
||||||
|
tools,
|
||||||
|
store.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
session.ensure_persistent_session("chat-1").unwrap();
|
||||||
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
|
session
|
||||||
|
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(response, "Started a fresh conversation.");
|
||||||
|
assert!(session.get_history("chat-1").unwrap().is_empty());
|
||||||
|
assert!(store
|
||||||
|
.load_messages(&session.persistent_session_id("chat-1"))
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
|
assert_eq!(
|
||||||
|
store
|
||||||
|
.load_all_messages(&session.persistent_session_id("chat-1"))
|
||||||
|
.unwrap()
|
||||||
|
.len(),
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -5,7 +5,7 @@ use axum::response::Response;
|
|||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
||||||
use super::{GatewayState, session::Session};
|
use super::{GatewayState, session::{Session, handle_in_chat_command}};
|
||||||
|
|
||||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||||
ws.on_upgrade(|socket| async {
|
ws.on_upgrade(|socket| async {
|
||||||
@ -153,6 +153,17 @@ async fn handle_inbound(
|
|||||||
session_guard.ensure_persistent_session(&chat_id)?;
|
session_guard.ensure_persistent_session(&chat_id)?;
|
||||||
session_guard.ensure_chat_loaded(&chat_id)?;
|
session_guard.ensure_chat_loaded(&chat_id)?;
|
||||||
|
|
||||||
|
if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? {
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::AssistantResponse {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
content: command_response,
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
let user_message = session_guard.create_user_message(&content, Vec::new());
|
let user_message = session_guard.create_user_message(&content, Vec::new());
|
||||||
session_guard.append_persisted_message(&chat_id, user_message)?;
|
session_guard.append_persisted_message(&chat_id, user_message)?;
|
||||||
|
|
||||||
|
|||||||
@ -29,6 +29,7 @@ pub struct SessionRecord {
|
|||||||
pub archived_at: Option<i64>,
|
pub archived_at: Option<i64>,
|
||||||
pub deleted_at: Option<i64>,
|
pub deleted_at: Option<i64>,
|
||||||
pub message_count: i64,
|
pub message_count: i64,
|
||||||
|
pub reset_cutoff_seq: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -68,7 +69,8 @@ impl SessionStore {
|
|||||||
last_active_at INTEGER NOT NULL,
|
last_active_at INTEGER NOT NULL,
|
||||||
archived_at INTEGER,
|
archived_at INTEGER,
|
||||||
deleted_at INTEGER,
|
deleted_at INTEGER,
|
||||||
message_count INTEGER NOT NULL DEFAULT 0
|
message_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
|
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
|
||||||
@ -98,13 +100,15 @@ impl SessionStore {
|
|||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
ensure_sessions_schema(&conn)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conn: Arc::new(Mutex::new(conn)),
|
conn: Arc::new(Mutex::new(conn)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn in_memory() -> Result<Self, StorageError> {
|
pub(crate) fn in_memory() -> Result<Self, StorageError> {
|
||||||
Self::from_connection(Connection::open_in_memory()?)
|
Self::from_connection(Connection::open_in_memory()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,7 +169,7 @@ impl SessionStore {
|
|||||||
"
|
"
|
||||||
SELECT id, title, channel_name, chat_id, summary,
|
SELECT id, title, channel_name, chat_id, summary,
|
||||||
created_at, updated_at, last_active_at,
|
created_at, updated_at, last_active_at,
|
||||||
archived_at, deleted_at, message_count
|
archived_at, deleted_at, message_count, reset_cutoff_seq
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE id = ?1 AND deleted_at IS NULL
|
WHERE id = ?1 AND deleted_at IS NULL
|
||||||
",
|
",
|
||||||
@ -186,7 +190,7 @@ impl SessionStore {
|
|||||||
"
|
"
|
||||||
SELECT id, title, channel_name, chat_id, summary,
|
SELECT id, title, channel_name, chat_id, summary,
|
||||||
created_at, updated_at, last_active_at,
|
created_at, updated_at, last_active_at,
|
||||||
archived_at, deleted_at, message_count
|
archived_at, deleted_at, message_count, reset_cutoff_seq
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE channel_name = ?1
|
WHERE channel_name = ?1
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
@ -242,7 +246,7 @@ impl SessionStore {
|
|||||||
conn.execute(
|
conn.execute(
|
||||||
"
|
"
|
||||||
UPDATE sessions
|
UPDATE sessions
|
||||||
SET message_count = 0, updated_at = ?2, last_active_at = ?2
|
SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0
|
||||||
WHERE id = ?1 AND deleted_at IS NULL
|
WHERE id = ?1 AND deleted_at IS NULL
|
||||||
",
|
",
|
||||||
params![session_id, now],
|
params![session_id, now],
|
||||||
@ -250,6 +254,33 @@ impl SessionStore {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
|
let now = current_timestamp();
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
let tx = conn.unchecked_transaction()?;
|
||||||
|
|
||||||
|
let cutoff_seq: i64 = tx.query_row(
|
||||||
|
"SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1",
|
||||||
|
params![session_id],
|
||||||
|
|row| row.get(0),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
tx.execute(
|
||||||
|
"
|
||||||
|
UPDATE sessions
|
||||||
|
SET reset_cutoff_seq = ?2,
|
||||||
|
updated_at = ?3,
|
||||||
|
last_active_at = ?3,
|
||||||
|
archived_at = NULL
|
||||||
|
WHERE id = ?1 AND deleted_at IS NULL
|
||||||
|
",
|
||||||
|
params![session_id, cutoff_seq, now],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
tx.commit()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
let tx = conn.unchecked_transaction()?;
|
let tx = conn.unchecked_transaction()?;
|
||||||
@ -302,16 +333,99 @@ impl SessionStore {
|
|||||||
|
|
||||||
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
|
||||||
|
load_messages_after(&conn, session_id, cutoff_seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
load_messages_after(&conn, session_id, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
|
||||||
|
if channel_name == "cli" {
|
||||||
|
chat_id.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{}:{}", channel_name, chat_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_session_db_path() -> Result<PathBuf, std::io::Error> {
|
||||||
|
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||||
|
Ok(home.join(".picobot").join("storage").join("sessions.db"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord> {
|
||||||
|
Ok(SessionRecord {
|
||||||
|
id: row.get(0)?,
|
||||||
|
title: row.get(1)?,
|
||||||
|
channel_name: row.get(2)?,
|
||||||
|
chat_id: row.get(3)?,
|
||||||
|
summary: row.get(4)?,
|
||||||
|
created_at: row.get(5)?,
|
||||||
|
updated_at: row.get(6)?,
|
||||||
|
last_active_at: row.get(7)?,
|
||||||
|
archived_at: row.get(8)?,
|
||||||
|
deleted_at: row.get(9)?,
|
||||||
|
message_count: row.get(10)?,
|
||||||
|
reset_cutoff_seq: row.get(11)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||||
|
if !has_column(conn, "sessions", "reset_cutoff_seq")? {
|
||||||
|
conn.execute(
|
||||||
|
"ALTER TABLE sessions ADD COLUMN reset_cutoff_seq INTEGER NOT NULL DEFAULT 0",
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result<bool, StorageError> {
|
||||||
|
let pragma = format!("PRAGMA table_info({})", table_name);
|
||||||
|
let mut stmt = conn.prepare(&pragma)?;
|
||||||
|
let mut rows = stmt.query([])?;
|
||||||
|
|
||||||
|
while let Some(row) = rows.next()? {
|
||||||
|
let existing_name: String = row.get(1)?;
|
||||||
|
if existing_name == column_name {
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, StorageError> {
|
||||||
|
let cutoff = conn
|
||||||
|
.query_row(
|
||||||
|
"SELECT reset_cutoff_seq FROM sessions WHERE id = ?1 AND deleted_at IS NULL",
|
||||||
|
params![session_id],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.optional()?;
|
||||||
|
|
||||||
|
Ok(cutoff.unwrap_or(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_messages_after(
|
||||||
|
conn: &Connection,
|
||||||
|
session_id: &str,
|
||||||
|
cutoff_seq: i64,
|
||||||
|
) -> Result<Vec<ChatMessage>, StorageError> {
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"
|
"
|
||||||
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE session_id = ?1
|
WHERE session_id = ?1 AND seq > ?2
|
||||||
ORDER BY seq ASC
|
ORDER BY seq ASC
|
||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let rows = stmt.query_map(params![session_id], |row| {
|
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| {
|
||||||
let media_refs_json: String = row.get(3)?;
|
let media_refs_json: String = row.get(3)?;
|
||||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
rusqlite::Error::FromSqlConversionFailure(
|
||||||
@ -352,36 +466,6 @@ impl SessionStore {
|
|||||||
}
|
}
|
||||||
Ok(messages)
|
Ok(messages)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
|
|
||||||
if channel_name == "cli" {
|
|
||||||
chat_id.to_string()
|
|
||||||
} else {
|
|
||||||
format!("{}:{}", channel_name, chat_id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_session_db_path() -> Result<PathBuf, std::io::Error> {
|
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
|
||||||
Ok(home.join(".picobot").join("storage").join("sessions.db"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord> {
|
|
||||||
Ok(SessionRecord {
|
|
||||||
id: row.get(0)?,
|
|
||||||
title: row.get(1)?,
|
|
||||||
channel_name: row.get(2)?,
|
|
||||||
chat_id: row.get(3)?,
|
|
||||||
summary: row.get(4)?,
|
|
||||||
created_at: row.get(5)?,
|
|
||||||
updated_at: row.get(6)?,
|
|
||||||
last_active_at: row.get(7)?,
|
|
||||||
archived_at: row.get(8)?,
|
|
||||||
deleted_at: row.get(9)?,
|
|
||||||
message_count: row.get(10)?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn current_timestamp() -> i64 {
|
fn current_timestamp() -> i64 {
|
||||||
std::time::SystemTime::now()
|
std::time::SystemTime::now()
|
||||||
@ -410,6 +494,7 @@ mod tests {
|
|||||||
assert_eq!(session.channel_name, "cli");
|
assert_eq!(session.channel_name, "cli");
|
||||||
assert_eq!(session.chat_id, session.id);
|
assert_eq!(session.chat_id, session.id);
|
||||||
assert_eq!(session.message_count, 0);
|
assert_eq!(session.message_count, 0);
|
||||||
|
assert_eq!(session.reset_cutoff_seq, 0);
|
||||||
|
|
||||||
let first = ChatMessage::user("hello");
|
let first = ChatMessage::user("hello");
|
||||||
let second = ChatMessage::assistant("world");
|
let second = ChatMessage::assistant("world");
|
||||||
@ -419,6 +504,7 @@ mod tests {
|
|||||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||||
assert_eq!(stored.message_count, 2);
|
assert_eq!(stored.message_count, 2);
|
||||||
assert!(stored.archived_at.is_none());
|
assert!(stored.archived_at.is_none());
|
||||||
|
assert_eq!(stored.reset_cutoff_seq, 0);
|
||||||
|
|
||||||
let messages = store.load_messages(&session.id).unwrap();
|
let messages = store.load_messages(&session.id).unwrap();
|
||||||
assert_eq!(messages.len(), 2);
|
assert_eq!(messages.len(), 2);
|
||||||
@ -487,6 +573,74 @@ mod tests {
|
|||||||
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_reset_session_preserves_full_history_and_hides_active_history() {
|
||||||
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
let session = store.create_cli_session(Some("reset")).unwrap();
|
||||||
|
|
||||||
|
store.append_message(&session.id, &ChatMessage::user("before")).unwrap();
|
||||||
|
store.append_message(&session.id, &ChatMessage::assistant("context")).unwrap();
|
||||||
|
store.reset_session(&session.id).unwrap();
|
||||||
|
|
||||||
|
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||||
|
assert_eq!(stored.reset_cutoff_seq, 2);
|
||||||
|
|
||||||
|
let active_messages = store.load_messages(&session.id).unwrap();
|
||||||
|
assert!(active_messages.is_empty());
|
||||||
|
|
||||||
|
let all_messages = store.load_all_messages(&session.id).unwrap();
|
||||||
|
assert_eq!(all_messages.len(), 2);
|
||||||
|
assert_eq!(all_messages[0].content, "before");
|
||||||
|
assert_eq!(all_messages[1].content, "context");
|
||||||
|
|
||||||
|
store.append_message(&session.id, &ChatMessage::user("after")).unwrap();
|
||||||
|
let active_messages = store.load_messages(&session.id).unwrap();
|
||||||
|
assert_eq!(active_messages.len(), 1);
|
||||||
|
assert_eq!(active_messages[0].content, "after");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_schema_migration_adds_reset_cutoff_column() {
|
||||||
|
let conn = Connection::open_in_memory().unwrap();
|
||||||
|
conn.execute_batch(
|
||||||
|
"
|
||||||
|
CREATE TABLE sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
channel_name TEXT NOT NULL,
|
||||||
|
chat_id TEXT NOT NULL,
|
||||||
|
summary TEXT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
updated_at INTEGER NOT NULL,
|
||||||
|
last_active_at INTEGER NOT NULL,
|
||||||
|
archived_at INTEGER,
|
||||||
|
deleted_at INTEGER,
|
||||||
|
message_count INTEGER NOT NULL DEFAULT 0
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE messages (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
seq INTEGER NOT NULL,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
media_refs_json TEXT NOT NULL,
|
||||||
|
tool_call_id TEXT,
|
||||||
|
tool_name TEXT,
|
||||||
|
tool_calls_json TEXT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
|
||||||
|
UNIQUE(session_id, seq)
|
||||||
|
);
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let store = SessionStore::from_connection(conn).unwrap();
|
||||||
|
let session = store.create_cli_session(Some("migrated")).unwrap();
|
||||||
|
assert_eq!(session.reset_cutoff_seq, 0);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tool_result_roundtrip() {
|
fn test_tool_result_roundtrip() {
|
||||||
let store = SessionStore::in_memory().unwrap();
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user