xiaoxixi 86dea0f874 Refactor session management to support dialog-based architecture
- Removed InputHandler and related input event handling code.
- Updated GatewayState to handle new session commands for dialogs.
- Introduced UnifiedSessionId for managing session identifiers across channels and chats.
- Refactored Session and SessionManager to manage dialogs instead of sessions.
- Added methods for creating, listing, switching, renaming, archiving, and deleting dialogs.
- Updated storage functions to accommodate dialog IDs in persistent session management.
- Enhanced tests to cover new dialog functionalities and ensure stability.
2026-04-26 20:59:54 +08:00

651 lines
23 KiB
Rust

use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use rusqlite::{Connection, OptionalExtension, params};
use serde::{Deserialize, Serialize};
use crate::bus::ChatMessage;
/// Error type returned by every fallible operation of the session store.
///
/// Each variant wraps the underlying library error and is produced
/// automatically by `?` through the generated `From` impls.
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
    /// A SQLite call failed. This also carries row-decoding failures, which
    /// are surfaced as `rusqlite::Error::FromSqlConversionFailure`.
    #[error("database error: {0}")]
    Database(#[from] rusqlite::Error),

    /// A filesystem operation failed, e.g. creating the database directory.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),

    /// A `serde_json` encode/decode call propagated with `?` failed.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
}
/// One row of the `sessions` table, as returned by the store's read methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRecord {
    /// Primary key. A random UUID v4 for CLI sessions, or
    /// `"<channel>:<chat>:<dialog>"` for channel-bound sessions.
    pub id: String,
    /// Human-readable title (renamable).
    pub title: String,
    /// Originating channel; `"cli_chat"` for sessions created from the CLI.
    pub channel_name: String,
    /// Chat identifier within the channel; equals `id` for CLI sessions.
    pub chat_id: String,
    /// Optional summary text; the insert paths in this module leave it NULL.
    pub summary: Option<String>,
    /// Creation time from `current_timestamp()` (units: presumably Unix
    /// time — TODO confirm).
    pub created_at: i64,
    /// Time of the last metadata or content modification.
    pub updated_at: i64,
    /// Time of the last user-visible activity; list ordering key.
    pub last_active_at: i64,
    /// Set when archived; cleared again by appending a message or resetting.
    pub archived_at: Option<i64>,
    /// Soft-delete marker; rows with it set are hidden from every read.
    pub deleted_at: Option<i64>,
    /// Number of messages appended since creation or the last clear
    /// (includes messages hidden by a reset).
    pub message_count: i64,
    /// Messages with `seq <= reset_cutoff_seq` are excluded from the active
    /// history; 0 means nothing is hidden.
    pub reset_cutoff_seq: i64,
}
/// SQLite-backed persistence for sessions and their chat messages.
///
/// Cloning is cheap: all clones share the same connection through an
/// `Arc<Mutex<_>>`, so access is serialized across clones.
#[derive(Clone)]
pub struct SessionStore {
    /// The single shared connection, guarded by a mutex.
    conn: Arc<Mutex<Connection>>,
}
/// Storage operations over the `sessions` and `messages` tables.
///
/// Every public method holds the connection mutex for its whole duration.
/// Operations that run more than one modifying statement do so inside a
/// single SQLite transaction, so they apply atomically.
impl SessionStore {
    /// Opens the session database at its default location (see
    /// `default_session_db_path`), creating the parent directory and the
    /// schema when missing.
    ///
    /// # Errors
    /// Fails if the directory cannot be created, the database cannot be
    /// opened, or schema initialization/migration fails.
    pub fn new() -> Result<Self, StorageError> {
        let db_path = default_session_db_path()?;
        Self::open_at_path(&db_path)
    }

    /// Opens (or creates) the database file at `path`, creating parent
    /// directories first.
    fn open_at_path(path: &Path) -> Result<Self, StorageError> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let conn = Connection::open(path)?;
        Self::from_connection(conn)
    }

    /// Configures `conn` (requests WAL journaling, enables foreign keys),
    /// creates the tables and indexes if absent, and applies column
    /// migrations for databases created by older versions.
    fn from_connection(conn: Connection) -> Result<Self, StorageError> {
        conn.execute_batch(
            "
            PRAGMA journal_mode = WAL;
            PRAGMA foreign_keys = ON;
            CREATE TABLE IF NOT EXISTS sessions (
                id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                channel_name TEXT NOT NULL,
                chat_id TEXT NOT NULL,
                summary TEXT,
                created_at INTEGER NOT NULL,
                updated_at INTEGER NOT NULL,
                last_active_at INTEGER NOT NULL,
                archived_at INTEGER,
                deleted_at INTEGER,
                message_count INTEGER NOT NULL DEFAULT 0,
                reset_cutoff_seq INTEGER NOT NULL DEFAULT 0
            );
            CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
                ON sessions(channel_name, archived_at, last_active_at DESC);
            CREATE INDEX IF NOT EXISTS idx_sessions_updated_at
                ON sessions(updated_at DESC);
            CREATE TABLE IF NOT EXISTS messages (
                id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL,
                seq INTEGER NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                media_refs_json TEXT NOT NULL,
                tool_call_id TEXT,
                tool_name TEXT,
                tool_calls_json TEXT,
                created_at INTEGER NOT NULL,
                FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
                UNIQUE(session_id, seq)
            );
            CREATE INDEX IF NOT EXISTS idx_messages_session_seq
                ON messages(session_id, seq);
            CREATE INDEX IF NOT EXISTS idx_messages_session_created
                ON messages(session_id, created_at);
            ",
        )?;
        ensure_sessions_schema(&conn)?;
        Ok(Self {
            conn: Arc::new(Mutex::new(conn)),
        })
    }

    /// Creates a store backed by a private in-memory database (tests only).
    #[cfg(test)]
    pub(crate) fn in_memory() -> Result<Self, StorageError> {
        Self::from_connection(Connection::open_in_memory()?)
    }

    /// Acquires the shared connection.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned, i.e. a previous holder panicked.
    fn lock_conn(&self) -> std::sync::MutexGuard<'_, Connection> {
        self.conn.lock().expect("session db mutex poisoned")
    }

    /// Looks up a non-deleted session on an already-locked connection.
    ///
    /// Taking `&Connection` lets callers combine a lookup with writes under
    /// one lock acquisition (the mutex is not reentrant).
    fn fetch_session(
        conn: &Connection,
        session_id: &str,
    ) -> Result<Option<SessionRecord>, StorageError> {
        conn.query_row(
            "
            SELECT id, title, channel_name, chat_id, summary,
                   created_at, updated_at, last_active_at,
                   archived_at, deleted_at, message_count, reset_cutoff_seq
            FROM sessions
            WHERE id = ?1 AND deleted_at IS NULL
            ",
            params![session_id],
            map_session_record,
        )
        .optional()
        .map_err(StorageError::from)
    }

    /// Creates a new CLI session on channel `"cli_chat"` with a random UUID v4
    /// id, which also serves as its `chat_id`.
    ///
    /// `title` is trimmed. When it is `None` or blank, the title defaults to
    /// `"CLI Session <first 8 chars of id>"`.
    ///
    /// # Errors
    /// Returns a database error if the insert or the read-back fails.
    pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> {
        let now = crate::bus::message::current_timestamp();
        let id = uuid::Uuid::new_v4().to_string();
        let title = title
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .map(ToOwned::to_owned)
            .unwrap_or_else(|| format!("CLI Session {}", &id[..8]));

        let conn = self.lock_conn();
        conn.execute(
            "
            INSERT INTO sessions (
                id, title, channel_name, chat_id, summary,
                created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
            ) VALUES (?1, ?2, 'cli_chat', ?1, NULL, ?3, ?3, ?3, NULL, NULL, 0)
            ",
            params![id, title, now],
        )?;
        // Read back under the same lock so the caller sees the row just written.
        Self::fetch_session(&conn, &id)?
            .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
    }

    /// Returns the session for (`channel_name`, `chat_id`, `dialog_id`),
    /// creating it on first use.
    ///
    /// The id is `persistent_session_id(channel_name, chat_id, dialog_id)`.
    /// A newly created session is titled `"<channel_name>:<chat_id>"`.
    ///
    /// # Errors
    /// Returns a database error on query failure. It also errors if a
    /// soft-deleted row already occupies the id, because such rows are
    /// invisible to reads but still hold the primary key.
    pub fn ensure_channel_session(
        &self,
        channel_name: &str,
        chat_id: &str,
        dialog_id: &str,
    ) -> Result<SessionRecord, StorageError> {
        let session_id = persistent_session_id(channel_name, chat_id, dialog_id);

        // The lookup and the insert share one lock acquisition. Otherwise two
        // callers on cloned stores could both see "missing" and the second
        // insert would fail on the primary key.
        let conn = self.lock_conn();
        if let Some(record) = Self::fetch_session(&conn, &session_id)? {
            return Ok(record);
        }

        let now = crate::bus::message::current_timestamp();
        let title = format!("{}:{}", channel_name, chat_id);
        // `OR IGNORE` also tolerates a row inserted by another process that
        // shares the database file between the lookup above and this insert.
        conn.execute(
            "
            INSERT OR IGNORE INTO sessions (
                id, title, channel_name, chat_id, summary,
                created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
            ) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0)
            ",
            params![session_id, title, channel_name, chat_id, now],
        )?;
        Self::fetch_session(&conn, &session_id)?
            .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
    }

    /// Fetches a session by id. Returns `None` if it does not exist or has
    /// `deleted_at` set.
    pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
        let conn = self.lock_conn();
        Self::fetch_session(&conn, session_id)
    }

    /// Lists the non-deleted sessions of `channel_name`, most recently active
    /// first (ties broken by newest creation).
    ///
    /// Archived sessions are included only when `include_archived` is true.
    pub fn list_sessions(
        &self,
        channel_name: &str,
        include_archived: bool,
    ) -> Result<Vec<SessionRecord>, StorageError> {
        let conn = self.lock_conn();
        let mut sql = String::from(
            "
            SELECT id, title, channel_name, chat_id, summary,
                   created_at, updated_at, last_active_at,
                   archived_at, deleted_at, message_count, reset_cutoff_seq
            FROM sessions
            WHERE channel_name = ?1
              AND deleted_at IS NULL
            ",
        );
        if !include_archived {
            sql.push_str(" AND archived_at IS NULL");
        }
        sql.push_str(" ORDER BY last_active_at DESC, created_at DESC");

        let mut stmt = conn.prepare(&sql)?;
        let sessions = stmt
            .query_map(params![channel_name], map_session_record)?
            .collect::<Result<Vec<_>, _>>()?;
        Ok(sessions)
    }

    /// Sets the title of a live session to `title.trim()`.
    ///
    /// Succeeds without effect when no live session has that id.
    pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), StorageError> {
        let now = crate::bus::message::current_timestamp();
        let conn = self.lock_conn();
        conn.execute(
            "UPDATE sessions SET title = ?2, updated_at = ?3 WHERE id = ?1 AND deleted_at IS NULL",
            params![session_id, title.trim(), now],
        )?;
        Ok(())
    }

    /// Marks a live session as archived.
    ///
    /// `append_message` and `reset_session` clear the archive mark again.
    /// Succeeds without effect when no live session has that id.
    pub fn archive_session(&self, session_id: &str) -> Result<(), StorageError> {
        let now = crate::bus::message::current_timestamp();
        let conn = self.lock_conn();
        conn.execute(
            "UPDATE sessions SET archived_at = ?2, updated_at = ?2 WHERE id = ?1 AND deleted_at IS NULL",
            params![session_id, now],
        )?;
        Ok(())
    }

    /// Permanently removes a session and all of its messages, atomically.
    pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> {
        let conn = self.lock_conn();
        let tx = conn.unchecked_transaction()?;
        // Explicit child delete keeps this correct even where foreign-key
        // enforcement (and thus ON DELETE CASCADE) is not active.
        tx.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
        tx.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
        tx.commit()?;
        Ok(())
    }

    /// Deletes every message of a session and zeroes its message count and
    /// reset cutoff, atomically.
    pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
        let now = crate::bus::message::current_timestamp();
        let conn = self.lock_conn();
        let tx = conn.unchecked_transaction()?;
        tx.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
        tx.execute(
            "
            UPDATE sessions
            SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0
            WHERE id = ?1 AND deleted_at IS NULL
            ",
            params![session_id, now],
        )?;
        tx.commit()?;
        Ok(())
    }

    /// Hides the current history from `load_messages` without deleting it.
    ///
    /// The session's `reset_cutoff_seq` is set to the highest existing
    /// message `seq` (0 if none). `load_all_messages` still returns
    /// everything. The session is also unarchived.
    pub fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
        let now = crate::bus::message::current_timestamp();
        let conn = self.lock_conn();
        let tx = conn.unchecked_transaction()?;
        let cutoff_seq: i64 = tx.query_row(
            "SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1",
            params![session_id],
            |row| row.get(0),
        )?;
        tx.execute(
            "
            UPDATE sessions
            SET reset_cutoff_seq = ?2,
                updated_at = ?3,
                last_active_at = ?3,
                archived_at = NULL
            WHERE id = ?1 AND deleted_at IS NULL
            ",
            params![session_id, cutoff_seq, now],
        )?;
        tx.commit()?;
        Ok(())
    }

    /// Appends `message` to a session and updates the session row, atomically.
    ///
    /// The message receives the next per-session `seq` (max + 1, starting at
    /// 1). The session's message count is bumped, its activity timestamps are
    /// refreshed, and it is unarchived.
    ///
    /// # Errors
    /// - Serialization error if `media_refs` or `tool_calls` cannot be encoded.
    /// - Database error if the insert fails, e.g. on a foreign-key violation
    ///   when the session row does not exist.
    pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
        // Encode before taking the lock so a serialization failure never opens
        // a transaction or blocks other callers.
        let media_refs_json = serde_json::to_string(&message.media_refs)?;
        let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;

        let conn = self.lock_conn();
        let tx = conn.unchecked_transaction()?;
        let seq: i64 = tx.query_row(
            "SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE session_id = ?1",
            params![session_id],
            |row| row.get(0),
        )?;
        tx.execute(
            "
            INSERT INTO messages (
                id, session_id, seq, role, content,
                media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
            ",
            params![
                message.id,
                session_id,
                seq,
                message.role,
                message.content,
                media_refs_json,
                message.tool_call_id,
                message.tool_name,
                tool_calls_json,
                message.timestamp,
            ],
        )?;
        let now = crate::bus::message::current_timestamp();
        tx.execute(
            "
            UPDATE sessions
            SET message_count = message_count + 1,
                updated_at = ?2,
                last_active_at = ?2,
                archived_at = NULL
            WHERE id = ?1 AND deleted_at IS NULL
            ",
            params![session_id, now],
        )?;
        tx.commit()?;
        Ok(())
    }

    /// Loads the active history: messages after the session's reset cutoff,
    /// in `seq` order. An unknown or deleted session uses cutoff 0.
    pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
        let conn = self.lock_conn();
        let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
        load_messages_after(&conn, session_id, cutoff_seq)
    }

    /// Loads every stored message of the session in `seq` order, ignoring any
    /// reset cutoff.
    pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
        let conn = self.lock_conn();
        load_messages_after(&conn, session_id, 0)
    }
}
/// Builds the stable primary key of a channel-bound dialog session.
///
/// The three components are joined with `:` in order, so
/// `("feishu", "abc", "default")` becomes `"feishu:abc:default"`.
/// NOTE(review): components are not escaped, so a `:` inside any of them can
/// make two different triples collide — confirm callers never pass one.
pub fn persistent_session_id(channel_name: &str, chat_id: &str, dialog_id: &str) -> String {
    [channel_name, chat_id, dialog_id].join(":")
}
/// Resolves the default database location, `<home>/.picobot/storage/sessions.db`.
///
/// `<home>` falls back to the current directory (`.`) when no home directory
/// can be determined. This never fails today; the `Result` is kept so callers
/// can propagate it with `?`.
fn default_session_db_path() -> Result<PathBuf, std::io::Error> {
    let mut path = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    path.push(".picobot");
    path.push("storage");
    path.push("sessions.db");
    Ok(path)
}
/// Converts a `sessions` row into a [`SessionRecord`].
///
/// The row must select columns in this order: `id, title, channel_name,
/// chat_id, summary, created_at, updated_at, last_active_at, archived_at,
/// deleted_at, message_count, reset_cutoff_seq`.
fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord> {
    let id: String = row.get(0)?;
    let title: String = row.get(1)?;
    let channel_name: String = row.get(2)?;
    let chat_id: String = row.get(3)?;
    let summary: Option<String> = row.get(4)?;
    let created_at: i64 = row.get(5)?;
    let updated_at: i64 = row.get(6)?;
    let last_active_at: i64 = row.get(7)?;
    let archived_at: Option<i64> = row.get(8)?;
    let deleted_at: Option<i64> = row.get(9)?;
    let message_count: i64 = row.get(10)?;
    let reset_cutoff_seq: i64 = row.get(11)?;

    Ok(SessionRecord {
        id,
        title,
        channel_name,
        chat_id,
        summary,
        created_at,
        updated_at,
        last_active_at,
        archived_at,
        deleted_at,
        message_count,
        reset_cutoff_seq,
    })
}
/// Migrates a `sessions` table created before `reset_cutoff_seq` existed by
/// adding that column with default 0. Does nothing if the column is present.
fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
    if has_column(conn, "sessions", "reset_cutoff_seq")? {
        return Ok(());
    }
    conn.execute(
        "ALTER TABLE sessions ADD COLUMN reset_cutoff_seq INTEGER NOT NULL DEFAULT 0",
        [],
    )?;
    Ok(())
}
/// Reports whether `table_name` has a column named `column_name`.
///
/// Uses `PRAGMA table_info`, whose second result column (index 1) is the
/// column name. `table_name` is interpolated into the SQL unquoted, so it must
/// be a trusted identifier (the only visible caller passes a literal).
fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result<bool, StorageError> {
    let sql = format!("PRAGMA table_info({})", table_name);
    let mut stmt = conn.prepare(&sql)?;
    let names = stmt.query_map([], |row| row.get::<_, String>(1))?;
    for name in names {
        if name? == column_name {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Returns the `reset_cutoff_seq` of a live session, or 0 when the session is
/// missing or soft-deleted.
fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, StorageError> {
    let found: Option<i64> = conn
        .query_row(
            "SELECT reset_cutoff_seq FROM sessions WHERE id = ?1 AND deleted_at IS NULL",
            params![session_id],
            |row| row.get(0),
        )
        .optional()?;
    match found {
        Some(seq) => Ok(seq),
        None => Ok(0),
    }
}
fn load_messages_after(
conn: &Connection,
session_id: &str,
cutoff_seq: i64,
) -> Result<Vec<ChatMessage>, StorageError> {
let mut stmt = conn.prepare(
"
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
FROM messages
WHERE session_id = ?1 AND seq > ?2
ORDER BY seq ASC
",
)?;
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| {
let media_refs_json: String = row.get(3)?;
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(),
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
let tool_calls_json: Option<String> = row.get(7)?;
let tool_calls = tool_calls_json
.as_deref()
.map(serde_json::from_str)
.transpose()
.map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
7,
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
Ok(ChatMessage {
id: row.get(0)?,
role: row.get(1)?,
content: row.get(2)?,
media_refs,
timestamp: row.get(4)?,
tool_call_id: row.get(5)?,
tool_name: row.get(6)?,
tool_calls,
})
})?;
let mut messages = Vec::new();
for row in rows {
messages.push(row?);
}
Ok(messages)
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SessionStore`, run against private in-memory databases.

    use super::*;
    use crate::providers::ToolCall;

    #[test]
    fn test_persistent_session_id_for_cli_and_channel() {
        assert_eq!(persistent_session_id("cli", "abc", "default"), "cli:abc:default");
        assert_eq!(persistent_session_id("cli_chat", "abc", "default"), "cli_chat:abc:default");
        assert_eq!(persistent_session_id("feishu", "abc", "default"), "feishu:abc:default");
    }

    /// Walks one CLI session through append, rename, archive, clear and delete.
    #[test]
    fn test_session_store_roundtrip_and_lifecycle() {
        let store = SessionStore::in_memory().unwrap();
        let session = store.create_cli_session(Some("demo")).unwrap();
        assert_eq!(session.title, "demo");
        assert_eq!(session.channel_name, "cli_chat");
        assert_eq!(session.chat_id, session.id);
        assert_eq!(session.message_count, 0);
        assert_eq!(session.reset_cutoff_seq, 0);

        let first = ChatMessage::user("hello");
        let second = ChatMessage::assistant("world");
        store.append_message(&session.id, &first).unwrap();
        store.append_message(&session.id, &second).unwrap();

        let stored = store.get_session(&session.id).unwrap().unwrap();
        assert_eq!(stored.message_count, 2);
        assert!(stored.archived_at.is_none());
        assert_eq!(stored.reset_cutoff_seq, 0);

        let messages = store.load_messages(&session.id).unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "hello");
        assert_eq!(messages[1].role, "assistant");
        assert_eq!(messages[1].content, "world");

        store.rename_session(&session.id, "renamed").unwrap();
        let renamed = store.get_session(&session.id).unwrap().unwrap();
        assert_eq!(renamed.title, "renamed");

        store.archive_session(&session.id).unwrap();
        let archived = store.get_session(&session.id).unwrap().unwrap();
        assert!(archived.archived_at.is_some());
        let active_only = store.list_sessions("cli_chat", false).unwrap();
        assert!(active_only.is_empty());
        let including_archived = store.list_sessions("cli_chat", true).unwrap();
        assert_eq!(including_archived.len(), 1);

        store.clear_messages(&session.id).unwrap();
        let cleared = store.load_messages(&session.id).unwrap();
        assert!(cleared.is_empty());
        let cleared_session = store.get_session(&session.id).unwrap().unwrap();
        assert_eq!(cleared_session.message_count, 0);

        store.delete_session(&session.id).unwrap();
        assert!(store.get_session(&session.id).unwrap().is_none());
    }

    #[test]
    fn test_ensure_channel_session_is_stable() {
        let store = SessionStore::in_memory().unwrap();
        let first = store.ensure_channel_session("feishu", "chat-1", "default").unwrap();
        let second = store.ensure_channel_session("feishu", "chat-1", "default").unwrap();
        assert_eq!(first.id, second.id);
        assert_eq!(first.chat_id, "chat-1");
        assert_eq!(second.channel_name, "feishu");
    }

    /// Distinct dialogs in the same chat must get distinct sessions and
    /// histories.
    #[test]
    fn test_channel_sessions_are_isolated_per_dialog() {
        let store = SessionStore::in_memory().unwrap();
        let default = store.ensure_channel_session("feishu", "chat-1", "default").unwrap();
        let work = store.ensure_channel_session("feishu", "chat-1", "work").unwrap();
        assert_ne!(default.id, work.id);
        assert_eq!(work.id, "feishu:chat-1:work");

        store.append_message(&work.id, &ChatMessage::user("only in work")).unwrap();
        assert!(store.load_messages(&default.id).unwrap().is_empty());
        assert_eq!(store.load_messages(&work.id).unwrap().len(), 1);
        assert_eq!(store.list_sessions("feishu", false).unwrap().len(), 2);
    }

    #[test]
    fn test_assistant_tool_calls_roundtrip() {
        let store = SessionStore::in_memory().unwrap();
        let session = store.create_cli_session(Some("tools")).unwrap();
        let assistant = ChatMessage::assistant_with_tool_calls(
            "calling tool",
            vec![ToolCall {
                id: "call_1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({ "expression": "3*7" }),
            }],
        );
        store.append_message(&session.id, &assistant).unwrap();

        let messages = store.load_messages(&session.id).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");
        assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1);
        assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1");
        assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
    }

    #[test]
    fn test_reset_session_preserves_full_history_and_hides_active_history() {
        let store = SessionStore::in_memory().unwrap();
        let session = store.create_cli_session(Some("reset")).unwrap();
        store.append_message(&session.id, &ChatMessage::user("before")).unwrap();
        store.append_message(&session.id, &ChatMessage::assistant("context")).unwrap();

        store.reset_session(&session.id).unwrap();
        let stored = store.get_session(&session.id).unwrap().unwrap();
        assert_eq!(stored.reset_cutoff_seq, 2);

        let active_messages = store.load_messages(&session.id).unwrap();
        assert!(active_messages.is_empty());
        let all_messages = store.load_all_messages(&session.id).unwrap();
        assert_eq!(all_messages.len(), 2);
        assert_eq!(all_messages[0].content, "before");
        assert_eq!(all_messages[1].content, "context");

        store.append_message(&session.id, &ChatMessage::user("after")).unwrap();
        let active_messages = store.load_messages(&session.id).unwrap();
        assert_eq!(active_messages.len(), 1);
        assert_eq!(active_messages[0].content, "after");
    }

    /// Clearing after a reset must zero the cutoff so new messages are visible.
    #[test]
    fn test_clear_messages_after_reset_restarts_sequence() {
        let store = SessionStore::in_memory().unwrap();
        let session = store.create_cli_session(Some("clear")).unwrap();
        store.append_message(&session.id, &ChatMessage::user("one")).unwrap();
        store.append_message(&session.id, &ChatMessage::user("two")).unwrap();
        store.reset_session(&session.id).unwrap();

        store.clear_messages(&session.id).unwrap();
        let cleared = store.get_session(&session.id).unwrap().unwrap();
        assert_eq!(cleared.reset_cutoff_seq, 0);
        assert_eq!(cleared.message_count, 0);

        store.append_message(&session.id, &ChatMessage::user("fresh")).unwrap();
        let active = store.load_messages(&session.id).unwrap();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].content, "fresh");
    }

    /// Deleting a session must also drop its stored messages.
    #[test]
    fn test_delete_session_removes_messages() {
        let store = SessionStore::in_memory().unwrap();
        let session = store.create_cli_session(Some("doomed")).unwrap();
        store.append_message(&session.id, &ChatMessage::user("bye")).unwrap();

        store.delete_session(&session.id).unwrap();
        assert!(store.get_session(&session.id).unwrap().is_none());
        assert!(store.load_all_messages(&session.id).unwrap().is_empty());
    }

    /// A pre-`reset_cutoff_seq` schema must be migrated on open.
    #[test]
    fn test_schema_migration_adds_reset_cutoff_column() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "
            CREATE TABLE sessions (
                id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                channel_name TEXT NOT NULL,
                chat_id TEXT NOT NULL,
                summary TEXT,
                created_at INTEGER NOT NULL,
                updated_at INTEGER NOT NULL,
                last_active_at INTEGER NOT NULL,
                archived_at INTEGER,
                deleted_at INTEGER,
                message_count INTEGER NOT NULL DEFAULT 0
            );
            CREATE TABLE messages (
                id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL,
                seq INTEGER NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                media_refs_json TEXT NOT NULL,
                tool_call_id TEXT,
                tool_name TEXT,
                tool_calls_json TEXT,
                created_at INTEGER NOT NULL,
                FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
                UNIQUE(session_id, seq)
            );
            ",
        )
        .unwrap();

        let store = SessionStore::from_connection(conn).unwrap();
        let session = store.create_cli_session(Some("migrated")).unwrap();
        assert_eq!(session.reset_cutoff_seq, 0);
    }

    #[test]
    fn test_tool_result_roundtrip() {
        let store = SessionStore::in_memory().unwrap();
        let session = store.create_cli_session(Some("tool-result")).unwrap();
        let tool_message = ChatMessage::tool("call_9", "file_write", "saved to /tmp/output.txt");
        store.append_message(&session.id, &tool_message).unwrap();

        let messages = store.load_messages(&session.id).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "tool");
        assert_eq!(messages[0].content, "saved to /tmp/output.txt");
        assert_eq!(messages[0].tool_call_id.as_deref(), Some("call_9"));
        assert_eq!(messages[0].tool_name.as_deref(), Some("file_write"));
        assert!(messages[0].tool_calls.is_none());
    }
}