2651 lines
94 KiB
Rust
2651 lines
94 KiB
Rust
#[cfg(not(test))]
|
||
use std::path::{Path, PathBuf};
|
||
use std::sync::{Arc, Mutex};
|
||
|
||
use rusqlite::{Connection, OptionalExtension, params};
|
||
|
||
use crate::bus::ChatMessage;
|
||
|
||
pub mod error;
|
||
pub mod ports;
|
||
pub mod records;
|
||
|
||
pub use error::StorageError;
|
||
pub use ports::{
|
||
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
||
SkillEventRepository,
|
||
};
|
||
pub use records::{
|
||
allowed_namespace_names, get_namespace_description, is_valid_namespace,
|
||
ALLOWED_MEMORY_NAMESPACES, GLOBAL_SCOPE_KEY, MemoryRecord, MemoryUpsert, SchedulerJobRecord,
|
||
SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord,
|
||
TopicRecord,
|
||
};
|
||
|
||
#[derive(Clone)]
|
||
pub struct SessionStore {
|
||
conn: Arc<Mutex<Connection>>,
|
||
}
|
||
|
||
impl SessionStore {
|
||
/// Test-build constructor: every store gets its own empty in-memory database.
#[cfg(test)]
pub fn new() -> Result<Self, StorageError> {
    let conn = Connection::open_in_memory()?;
    Self::from_connection(conn)
}
|
||
|
||
#[cfg(not(test))]
|
||
pub fn new() -> Result<Self, StorageError> {
|
||
let db_path = default_session_db_path()?;
|
||
Self::open_at_path(&db_path)
|
||
}
|
||
|
||
#[cfg(not(test))]
|
||
fn open_at_path(path: &Path) -> Result<Self, StorageError> {
|
||
if let Some(parent) = path.parent() {
|
||
std::fs::create_dir_all(parent)?;
|
||
}
|
||
|
||
let conn = Connection::open(path)?;
|
||
Self::from_connection(conn)
|
||
}
|
||
|
||
fn from_connection(conn: Connection) -> Result<Self, StorageError> {
|
||
conn.busy_timeout(std::time::Duration::from_secs(5))?;
|
||
conn.execute_batch(
|
||
"
|
||
PRAGMA journal_mode = WAL;
|
||
PRAGMA foreign_keys = ON;
|
||
|
||
CREATE TABLE IF NOT EXISTS sessions (
|
||
id TEXT PRIMARY KEY,
|
||
title TEXT NOT NULL,
|
||
channel_name TEXT NOT NULL,
|
||
chat_id TEXT NOT NULL,
|
||
summary TEXT,
|
||
created_at INTEGER NOT NULL,
|
||
updated_at INTEGER NOT NULL,
|
||
last_active_at INTEGER NOT NULL,
|
||
archived_at INTEGER,
|
||
deleted_at INTEGER,
|
||
message_count INTEGER NOT NULL DEFAULT 0,
|
||
user_turn_count INTEGER NOT NULL DEFAULT 0,
|
||
agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0
|
||
);
|
||
|
||
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
|
||
ON sessions(channel_name, archived_at, last_active_at DESC);
|
||
CREATE INDEX IF NOT EXISTS idx_sessions_updated_at
|
||
ON sessions(updated_at DESC);
|
||
|
||
CREATE TABLE IF NOT EXISTS messages (
|
||
id TEXT PRIMARY KEY,
|
||
session_id TEXT NOT NULL,
|
||
topic_id TEXT,
|
||
seq INTEGER NOT NULL,
|
||
role TEXT NOT NULL,
|
||
content TEXT NOT NULL,
|
||
system_context TEXT,
|
||
reasoning_content TEXT,
|
||
media_refs_json TEXT NOT NULL,
|
||
tool_call_id TEXT,
|
||
tool_name TEXT,
|
||
tool_calls_json TEXT,
|
||
created_at INTEGER NOT NULL,
|
||
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
|
||
FOREIGN KEY(topic_id) REFERENCES topics(id) ON DELETE SET NULL,
|
||
UNIQUE(session_id, seq)
|
||
);
|
||
|
||
CREATE INDEX IF NOT EXISTS idx_messages_session_seq
|
||
ON messages(session_id, seq);
|
||
CREATE INDEX IF NOT EXISTS idx_messages_session_created
|
||
ON messages(session_id, created_at);
|
||
|
||
CREATE TABLE IF NOT EXISTS topics (
|
||
id TEXT PRIMARY KEY,
|
||
session_id TEXT NOT NULL,
|
||
title TEXT NOT NULL,
|
||
description TEXT,
|
||
created_at INTEGER NOT NULL,
|
||
updated_at INTEGER NOT NULL,
|
||
last_active_at INTEGER NOT NULL,
|
||
message_count INTEGER NOT NULL DEFAULT 0,
|
||
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE
|
||
);
|
||
|
||
CREATE INDEX IF NOT EXISTS idx_topics_session
|
||
ON topics(session_id, last_active_at DESC);
|
||
|
||
CREATE TABLE IF NOT EXISTS skill_events (
|
||
id TEXT PRIMARY KEY,
|
||
session_id TEXT,
|
||
event_type TEXT NOT NULL,
|
||
skill_name TEXT,
|
||
payload_json TEXT NOT NULL,
|
||
created_at INTEGER NOT NULL,
|
||
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE
|
||
);
|
||
|
||
CREATE INDEX IF NOT EXISTS idx_skill_events_session_created
|
||
ON skill_events(session_id, created_at DESC);
|
||
CREATE INDEX IF NOT EXISTS idx_skill_events_type_created
|
||
ON skill_events(event_type, created_at DESC);
|
||
|
||
CREATE TABLE IF NOT EXISTS memories (
|
||
id TEXT PRIMARY KEY,
|
||
scope_kind TEXT NOT NULL,
|
||
scope_key TEXT NOT NULL,
|
||
namespace TEXT NOT NULL,
|
||
memory_key TEXT NOT NULL,
|
||
content TEXT NOT NULL,
|
||
source_type TEXT NOT NULL,
|
||
source_session_id TEXT,
|
||
source_message_id TEXT,
|
||
source_message_seq INTEGER,
|
||
source_channel_name TEXT,
|
||
source_chat_id TEXT,
|
||
created_at INTEGER NOT NULL,
|
||
updated_at INTEGER NOT NULL,
|
||
UNIQUE(scope_kind, scope_key, namespace, memory_key)
|
||
);
|
||
|
||
CREATE INDEX IF NOT EXISTS idx_memories_scope_updated
|
||
ON memories(scope_kind, scope_key, updated_at DESC);
|
||
CREATE INDEX IF NOT EXISTS idx_memories_scope_namespace_updated
|
||
ON memories(scope_kind, scope_key, namespace, updated_at DESC);
|
||
CREATE INDEX IF NOT EXISTS idx_memories_source_session
|
||
ON memories(source_session_id, updated_at DESC);
|
||
|
||
CREATE TABLE IF NOT EXISTS scheduler_jobs (
|
||
id TEXT PRIMARY KEY,
|
||
kind TEXT NOT NULL,
|
||
schedule_json TEXT NOT NULL DEFAULT '{}',
|
||
interval_secs INTEGER NOT NULL DEFAULT 0,
|
||
startup_delay_secs INTEGER NOT NULL DEFAULT 0,
|
||
target_json TEXT NOT NULL,
|
||
payload_json TEXT NOT NULL,
|
||
enabled INTEGER NOT NULL DEFAULT 1,
|
||
state TEXT NOT NULL DEFAULT 'scheduled',
|
||
last_status TEXT,
|
||
last_error TEXT,
|
||
run_count INTEGER NOT NULL DEFAULT 0,
|
||
max_runs INTEGER,
|
||
last_fired_at INTEGER,
|
||
next_fire_at INTEGER,
|
||
paused_at INTEGER,
|
||
completed_at INTEGER,
|
||
created_at INTEGER NOT NULL,
|
||
updated_at INTEGER NOT NULL
|
||
);
|
||
|
||
CREATE INDEX IF NOT EXISTS idx_scheduler_jobs_enabled_next_fire
|
||
ON scheduler_jobs(enabled, state, next_fire_at ASC);
|
||
|
||
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
|
||
namespace,
|
||
memory_key,
|
||
content,
|
||
content='memories',
|
||
content_rowid='rowid'
|
||
);
|
||
|
||
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
|
||
INSERT INTO memories_fts(rowid, namespace, memory_key, content)
|
||
VALUES (new.rowid, new.namespace, new.memory_key, new.content);
|
||
END;
|
||
|
||
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
|
||
INSERT INTO memories_fts(memories_fts, rowid, namespace, memory_key, content)
|
||
VALUES ('delete', old.rowid, old.namespace, old.memory_key, old.content);
|
||
END;
|
||
|
||
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
|
||
INSERT INTO memories_fts(memories_fts, rowid, namespace, memory_key, content)
|
||
VALUES ('delete', old.rowid, old.namespace, old.memory_key, old.content);
|
||
INSERT INTO memories_fts(rowid, namespace, memory_key, content)
|
||
VALUES (new.rowid, new.namespace, new.memory_key, new.content);
|
||
END;
|
||
",
|
||
)?;
|
||
|
||
ensure_sessions_schema(&conn)?;
|
||
ensure_messages_schema(&conn)?;
|
||
ensure_scheduler_schema(&conn)?;
|
||
ensure_memory_scope_key_migration(&conn)?;
|
||
|
||
Ok(Self {
|
||
conn: Arc::new(Mutex::new(conn)),
|
||
})
|
||
}
|
||
|
||
/// Test-only helper that builds a store over a private in-memory database.
#[cfg(test)]
pub(crate) fn in_memory() -> Result<Self, StorageError> {
    Connection::open_in_memory()
        .map_err(StorageError::from)
        .and_then(Self::from_connection)
}
|
||
|
||
pub fn create_session(
|
||
&self,
|
||
channel_name: &str,
|
||
title: Option<&str>,
|
||
) -> Result<SessionRecord, StorageError> {
|
||
let now = current_timestamp();
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
// 统一使用 persistent_session_id 格式
|
||
let session_id = persistent_session_id(channel_name, &id);
|
||
let title = title
|
||
.map(str::trim)
|
||
.filter(|value| !value.is_empty())
|
||
.map(ToOwned::to_owned)
|
||
.unwrap_or_else(|| {
|
||
if channel_name == "cli" {
|
||
format!("CLI Session {}", &id[..8])
|
||
} else {
|
||
format!("Session {}", &id[..8])
|
||
}
|
||
});
|
||
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"
|
||
INSERT INTO sessions (
|
||
id, title, channel_name, chat_id, summary,
|
||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
|
||
user_turn_count, agent_prompt_reinjection_count
|
||
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0)
|
||
",
|
||
params![&session_id, title, channel_name, id, now],
|
||
)?;
|
||
|
||
drop(conn);
|
||
self.get_session(&session_id)?
|
||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||
}
|
||
|
||
/// Shorthand for `create_session` on the `"cli"` channel.
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> {
    const CLI_CHANNEL: &str = "cli";
    self.create_session(CLI_CHANNEL, title)
}
|
||
|
||
pub fn ensure_channel_session(
|
||
&self,
|
||
channel_name: &str,
|
||
chat_id: &str,
|
||
) -> Result<SessionRecord, StorageError> {
|
||
let session_id = persistent_session_id(channel_name, chat_id);
|
||
self.ensure_session(&session_id, channel_name, chat_id, &format!("{}:{}", channel_name, chat_id))
|
||
}
|
||
|
||
/// 确保指定 session_id 的会话存在(如果不存在则创建)
|
||
pub fn ensure_session(
|
||
&self,
|
||
session_id: &str,
|
||
channel_name: &str,
|
||
chat_id: &str,
|
||
title: &str,
|
||
) -> Result<SessionRecord, StorageError> {
|
||
if let Some(record) = self.get_session(session_id)? {
|
||
return Ok(record);
|
||
}
|
||
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"
|
||
INSERT INTO sessions (
|
||
id, title, channel_name, chat_id, summary,
|
||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
|
||
user_turn_count, agent_prompt_reinjection_count
|
||
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0)
|
||
",
|
||
params![session_id, title, channel_name, chat_id, now],
|
||
)?;
|
||
drop(conn);
|
||
|
||
self.get_session(session_id)?
|
||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||
}
|
||
|
||
pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, title, channel_name, chat_id, summary,
|
||
created_at, updated_at, last_active_at,
|
||
archived_at, deleted_at, message_count,
|
||
user_turn_count, agent_prompt_reinjection_count
|
||
FROM sessions
|
||
WHERE id = ?1 AND deleted_at IS NULL
|
||
",
|
||
)?;
|
||
|
||
stmt.query_row(params![session_id], map_session_record)
|
||
.optional()
|
||
.map_err(StorageError::from)
|
||
}
|
||
|
||
/// Find sessions whose id ends with the given suffix (used for task session lookup)
|
||
pub fn find_sessions_by_id_suffix(
|
||
&self,
|
||
suffix: &str,
|
||
) -> Result<Vec<SessionRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let pattern = format!("%{}", suffix);
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, title, channel_name, chat_id, summary,
|
||
created_at, updated_at, last_active_at,
|
||
archived_at, deleted_at, message_count,
|
||
user_turn_count, agent_prompt_reinjection_count
|
||
FROM sessions
|
||
WHERE id LIKE ?1 AND deleted_at IS NULL
|
||
ORDER BY last_active_at DESC
|
||
",
|
||
)?;
|
||
|
||
let rows = stmt.query_map(params![pattern], map_session_record)?;
|
||
let mut sessions = Vec::new();
|
||
for row in rows {
|
||
sessions.push(row?);
|
||
}
|
||
Ok(sessions)
|
||
}
|
||
|
||
pub fn list_sessions(
|
||
&self,
|
||
channel_name: &str,
|
||
include_archived: bool,
|
||
) -> Result<Vec<SessionRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut sql = String::from(
|
||
"
|
||
SELECT id, title, channel_name, chat_id, summary,
|
||
created_at, updated_at, last_active_at,
|
||
archived_at, deleted_at, message_count,
|
||
user_turn_count, agent_prompt_reinjection_count
|
||
FROM sessions
|
||
WHERE channel_name = ?1
|
||
AND deleted_at IS NULL
|
||
",
|
||
);
|
||
|
||
if !include_archived {
|
||
sql.push_str(" AND archived_at IS NULL");
|
||
}
|
||
|
||
sql.push_str(" ORDER BY last_active_at DESC, created_at DESC");
|
||
|
||
let mut stmt = conn.prepare(&sql)?;
|
||
let rows = stmt.query_map(params![channel_name], map_session_record)?;
|
||
let mut sessions = Vec::new();
|
||
for row in rows {
|
||
sessions.push(row?);
|
||
}
|
||
Ok(sessions)
|
||
}
|
||
|
||
pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"UPDATE sessions SET title = ?2, updated_at = ?3 WHERE id = ?1 AND deleted_at IS NULL",
|
||
params![session_id, title.trim(), now],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn archive_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"UPDATE sessions SET archived_at = ?2, updated_at = ?2 WHERE id = ?1 AND deleted_at IS NULL",
|
||
params![session_id, now],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"DELETE FROM messages WHERE session_id = ?1",
|
||
params![session_id],
|
||
)?;
|
||
conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
|
||
Ok(())
|
||
}
|
||
|
||
// ==================== Topic Methods ====================
|
||
|
||
pub fn create_topic(
|
||
&self,
|
||
session_id: &str,
|
||
title: &str,
|
||
description: Option<&str>,
|
||
) -> Result<TopicRecord, StorageError> {
|
||
let now = current_timestamp();
|
||
let id = format!("topic:{}", uuid::Uuid::new_v4());
|
||
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"INSERT INTO topics (id, session_id, title, description, created_at, updated_at, last_active_at, message_count) VALUES (?1, ?2, ?3, ?4, ?5, ?5, ?5, 0)",
|
||
params![&id, session_id, title, description.unwrap_or(""), now],
|
||
)?;
|
||
drop(conn);
|
||
|
||
self.get_topic(&id)?
|
||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||
}
|
||
|
||
pub fn get_topic(&self, topic_id: &str) -> Result<Option<TopicRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut stmt = conn.prepare(
|
||
"SELECT id, session_id, title, description, created_at, updated_at, last_active_at, message_count FROM topics WHERE id = ?1",
|
||
)?;
|
||
|
||
stmt.query_row(params![topic_id], |row| {
|
||
Ok(TopicRecord {
|
||
id: row.get(0)?,
|
||
session_id: row.get(1)?,
|
||
title: row.get(2)?,
|
||
description: row.get(3)?,
|
||
created_at: row.get(4)?,
|
||
updated_at: row.get(5)?,
|
||
last_active_at: row.get(6)?,
|
||
message_count: row.get(7)?,
|
||
})
|
||
})
|
||
.optional()
|
||
.map_err(StorageError::from)
|
||
}
|
||
|
||
pub fn list_topics(&self, session_id: &str) -> Result<Vec<TopicRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut stmt = conn.prepare(
|
||
"SELECT id, session_id, title, description, created_at, updated_at, last_active_at, message_count FROM topics WHERE session_id = ?1 ORDER BY last_active_at DESC"
|
||
)?;
|
||
|
||
let rows = stmt.query_map(params![session_id], |row| {
|
||
Ok(TopicRecord {
|
||
id: row.get(0)?,
|
||
session_id: row.get(1)?,
|
||
title: row.get(2)?,
|
||
description: row.get(3)?,
|
||
created_at: row.get(4)?,
|
||
updated_at: row.get(5)?,
|
||
last_active_at: row.get(6)?,
|
||
message_count: row.get(7)?,
|
||
})
|
||
})?;
|
||
|
||
let mut topics = Vec::new();
|
||
for row in rows {
|
||
topics.push(row?);
|
||
}
|
||
Ok(topics)
|
||
}
|
||
|
||
pub fn update_topic_title(&self, topic_id: &str, title: &str) -> Result<(), StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"UPDATE topics SET title = ?2, updated_at = ?3 WHERE id = ?1",
|
||
params![topic_id, title.trim(), now],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn update_topic_description(&self, topic_id: &str, description: &str) -> Result<(), StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"UPDATE topics SET description = ?2, updated_at = ?3 WHERE id = ?1",
|
||
params![topic_id, description, now],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn delete_topic(&self, topic_id: &str) -> Result<(), StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
// Messages 的 topic_id 会被设为 NULL(ON DELETE SET NULL)
|
||
conn.execute("DELETE FROM topics WHERE id = ?1", params![topic_id])?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn touch_topic(&self, topic_id: &str) -> Result<(), StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"UPDATE topics SET last_active_at = ?2 WHERE id = ?1",
|
||
params![topic_id, now],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"DELETE FROM messages WHERE session_id = ?1",
|
||
params![session_id],
|
||
)?;
|
||
conn.execute(
|
||
"
|
||
UPDATE sessions
|
||
SET message_count = 0,
|
||
updated_at = ?2,
|
||
last_active_at = ?2,
|
||
user_turn_count = 0,
|
||
agent_prompt_reinjection_count = 0
|
||
WHERE id = ?1 AND deleted_at IS NULL
|
||
",
|
||
params![session_id, now],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn append_message(
|
||
&self,
|
||
session_id: &str,
|
||
message: &ChatMessage,
|
||
) -> Result<(), StorageError> {
|
||
self.append_message_with_topic(session_id, None, message)
|
||
}
|
||
|
||
pub fn append_message_with_topic(
|
||
&self,
|
||
session_id: &str,
|
||
topic_id: Option<&str>,
|
||
message: &ChatMessage,
|
||
) -> Result<(), StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let tx = conn.unchecked_transaction()?;
|
||
|
||
let seq: i64 = tx.query_row(
|
||
"SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE session_id = ?1",
|
||
params![session_id],
|
||
|row| row.get(0),
|
||
)?;
|
||
|
||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||
let tool_calls_json = message
|
||
.tool_calls
|
||
.as_ref()
|
||
.map(serde_json::to_string)
|
||
.transpose()?;
|
||
tx.execute(
|
||
"
|
||
INSERT INTO messages (
|
||
id, session_id, topic_id, seq, role, content,
|
||
system_context, reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, tool_duration_ms, created_at
|
||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)
|
||
",
|
||
params![
|
||
message.id,
|
||
session_id,
|
||
topic_id,
|
||
seq,
|
||
message.role,
|
||
message.content,
|
||
message.system_context,
|
||
message.reasoning_content,
|
||
media_refs_json,
|
||
message.tool_call_id,
|
||
message.tool_name,
|
||
tool_calls_json,
|
||
message.tool_duration_ms.map(|v| v as i64),
|
||
message.timestamp,
|
||
],
|
||
)?;
|
||
|
||
let now = current_timestamp();
|
||
let is_user_message = message.role == "user";
|
||
tx.execute(
|
||
"
|
||
UPDATE sessions
|
||
SET message_count = message_count + 1,
|
||
user_turn_count = user_turn_count + ?3,
|
||
updated_at = ?2,
|
||
last_active_at = ?2,
|
||
archived_at = NULL
|
||
WHERE id = ?1 AND deleted_at IS NULL
|
||
",
|
||
params![session_id, now, if is_user_message { 1 } else { 0 }],
|
||
)?;
|
||
|
||
if let Some(tid) = topic_id {
|
||
tx.execute(
|
||
"UPDATE topics SET message_count = message_count + 1, last_active_at = ?2 WHERE id = ?1",
|
||
params![tid, now],
|
||
)?;
|
||
}
|
||
|
||
tx.commit()?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn append_messages_batch(
|
||
&self,
|
||
session_id: &str,
|
||
topic_id: Option<&str>,
|
||
messages: &[ChatMessage],
|
||
) -> Result<(), StorageError> {
|
||
if messages.is_empty() {
|
||
return Ok(());
|
||
}
|
||
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let tx = conn.unchecked_transaction()?;
|
||
|
||
let mut seq: i64 = tx.query_row(
|
||
"SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE session_id = ?1",
|
||
params![session_id],
|
||
|row| row.get(0),
|
||
)?;
|
||
|
||
for message in messages {
|
||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||
let tool_calls_json = message
|
||
.tool_calls
|
||
.as_ref()
|
||
.map(serde_json::to_string)
|
||
.transpose()?;
|
||
tx.execute(
|
||
"
|
||
INSERT INTO messages (
|
||
id, session_id, topic_id, seq, role, content,
|
||
system_context, reasoning_content, media_refs_json,
|
||
tool_call_id, tool_name, tool_calls_json, tool_duration_ms, created_at
|
||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)
|
||
",
|
||
params![
|
||
message.id,
|
||
session_id,
|
||
topic_id,
|
||
seq,
|
||
message.role,
|
||
message.content,
|
||
message.system_context,
|
||
message.reasoning_content,
|
||
media_refs_json,
|
||
message.tool_call_id,
|
||
message.tool_name,
|
||
tool_calls_json,
|
||
message.tool_duration_ms.map(|v| v as i64),
|
||
message.timestamp,
|
||
],
|
||
)?;
|
||
seq += 1;
|
||
}
|
||
|
||
let now = current_timestamp();
|
||
let user_msg_count: i64 = messages
|
||
.iter()
|
||
.filter(|m| m.role == "user")
|
||
.count()
|
||
.try_into()
|
||
.unwrap_or(0);
|
||
let msg_count: i64 = messages.len() as i64;
|
||
|
||
tx.execute(
|
||
"
|
||
UPDATE sessions
|
||
SET message_count = message_count + ?2,
|
||
user_turn_count = user_turn_count + ?3,
|
||
updated_at = ?4,
|
||
last_active_at = ?4,
|
||
archived_at = NULL
|
||
WHERE id = ?1 AND deleted_at IS NULL
|
||
",
|
||
params![session_id, msg_count, user_msg_count, now],
|
||
)?;
|
||
|
||
if let Some(tid) = topic_id {
|
||
tx.execute(
|
||
"UPDATE topics SET message_count = message_count + ?2, last_active_at = ?3 WHERE id = ?1",
|
||
params![tid, msg_count, now],
|
||
)?;
|
||
}
|
||
|
||
tx.commit()?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn compact_active_history(
|
||
&self,
|
||
session_id: &str,
|
||
snapshot_end_seq: i64,
|
||
preserved_system_messages: &[ChatMessage],
|
||
summary_message: &ChatMessage,
|
||
preserved_messages: &[ChatMessage],
|
||
) -> Result<bool, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let tx = conn.unchecked_transaction()?;
|
||
|
||
let current_max_seq: i64 = tx.query_row(
|
||
"SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1",
|
||
params![session_id],
|
||
|row| row.get(0),
|
||
)?;
|
||
|
||
if snapshot_end_seq > current_max_seq {
|
||
return Ok(false);
|
||
}
|
||
|
||
let delta_messages =
|
||
load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;
|
||
let now = current_timestamp();
|
||
|
||
// Collect all new messages first, then sanitize incomplete tool call
|
||
// sequences before writing to DB. This prevents orphaned tool_calls
|
||
// (without corresponding tool results) from being persisted permanently
|
||
// when compaction preserves an incomplete sequence from the snapshot or
|
||
// captures a partial sequence from delta messages.
|
||
let mut new_messages: Vec<ChatMessage> = Vec::new();
|
||
|
||
for message in preserved_system_messages {
|
||
new_messages.push(clone_message_for_compaction(message, message.timestamp));
|
||
}
|
||
|
||
new_messages.push(clone_message_for_compaction(summary_message, now));
|
||
|
||
for message in preserved_messages.iter().chain(delta_messages.iter()) {
|
||
new_messages.push(clone_message_for_compaction(message, message.timestamp));
|
||
}
|
||
|
||
let removed =
|
||
crate::bus::message::sanitize_incomplete_tool_call_sequences(&mut new_messages);
|
||
if removed > 0 {
|
||
tracing::warn!(
|
||
removed_count = removed,
|
||
session_id = %session_id,
|
||
"Compaction removed incomplete tool call sequences from new history"
|
||
);
|
||
}
|
||
|
||
// Write sanitized messages to DB
|
||
let mut next_seq = current_max_seq + 1;
|
||
let mut inserted_count = 0_i64;
|
||
let mut active_user_turn_count = 0_i64;
|
||
|
||
for message in &new_messages {
|
||
if message.role == "user" {
|
||
active_user_turn_count += 1;
|
||
}
|
||
insert_message_with_seq(&tx, session_id, next_seq, message)?;
|
||
next_seq += 1;
|
||
inserted_count += 1;
|
||
}
|
||
|
||
// Delete all old messages (including delta messages that were just re-inserted)
|
||
tx.execute(
|
||
"DELETE FROM messages WHERE session_id = ?1 AND seq <= ?2",
|
||
params![session_id, current_max_seq],
|
||
)?;
|
||
|
||
tx.execute(
|
||
"
|
||
UPDATE sessions
|
||
SET message_count = ?2,
|
||
user_turn_count = ?3,
|
||
updated_at = ?4,
|
||
last_active_at = ?4,
|
||
archived_at = NULL
|
||
WHERE id = ?1 AND deleted_at IS NULL
|
||
",
|
||
params![
|
||
session_id,
|
||
inserted_count,
|
||
active_user_turn_count,
|
||
now,
|
||
],
|
||
)?;
|
||
|
||
tx.commit()?;
|
||
Ok(true)
|
||
}
|
||
|
||
pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"
|
||
UPDATE sessions
|
||
SET agent_prompt_reinjection_count = agent_prompt_reinjection_count + 1,
|
||
updated_at = ?2,
|
||
last_active_at = ?2,
|
||
archived_at = NULL
|
||
WHERE id = ?1 AND deleted_at IS NULL
|
||
",
|
||
params![session_id, now],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn append_skill_event(
|
||
&self,
|
||
session_id: Option<&str>,
|
||
event_type: &str,
|
||
skill_name: Option<&str>,
|
||
payload: &serde_json::Value,
|
||
) -> Result<(), StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"
|
||
INSERT INTO skill_events (
|
||
id, session_id, event_type, skill_name, payload_json, created_at
|
||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||
",
|
||
params![
|
||
uuid::Uuid::new_v4().to_string(),
|
||
session_id,
|
||
event_type,
|
||
skill_name,
|
||
serde_json::to_string(payload)?,
|
||
current_timestamp(),
|
||
],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn list_skill_events(
|
||
&self,
|
||
session_id: Option<&str>,
|
||
) -> Result<Vec<SkillEventRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let sql = if session_id.is_some() {
|
||
"
|
||
SELECT id, session_id, event_type, skill_name, payload_json, created_at
|
||
FROM skill_events
|
||
WHERE session_id = ?1
|
||
ORDER BY created_at ASC
|
||
"
|
||
} else {
|
||
"
|
||
SELECT id, session_id, event_type, skill_name, payload_json, created_at
|
||
FROM skill_events
|
||
WHERE session_id IS NULL
|
||
ORDER BY created_at ASC
|
||
"
|
||
};
|
||
|
||
let mut stmt = conn.prepare(sql)?;
|
||
let rows = if let Some(session_id) = session_id {
|
||
stmt.query_map(params![session_id], map_skill_event_record)?
|
||
} else {
|
||
stmt.query_map([], map_skill_event_record)?
|
||
};
|
||
|
||
let mut events = Vec::new();
|
||
for row in rows {
|
||
events.push(row?);
|
||
}
|
||
Ok(events)
|
||
}
|
||
|
||
pub fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let tx = conn.unchecked_transaction()?;
|
||
|
||
let existing: Option<(String, i64)> = tx
|
||
.query_row(
|
||
"
|
||
SELECT id, created_at
|
||
FROM memories
|
||
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
|
||
",
|
||
params![
|
||
input.scope_kind,
|
||
input.scope_key,
|
||
input.namespace,
|
||
input.memory_key,
|
||
],
|
||
|row| Ok((row.get(0)?, row.get(1)?)),
|
||
)
|
||
.optional()?;
|
||
|
||
let (id, created_at) = existing.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
|
||
|
||
tx.execute(
|
||
"
|
||
INSERT INTO memories (
|
||
id, scope_kind, scope_key, namespace, memory_key, content,
|
||
source_type, source_session_id, source_message_id, source_message_seq,
|
||
source_channel_name, source_chat_id, created_at, updated_at
|
||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)
|
||
ON CONFLICT(scope_kind, scope_key, namespace, memory_key) DO UPDATE SET
|
||
content = excluded.content,
|
||
source_type = excluded.source_type,
|
||
source_session_id = excluded.source_session_id,
|
||
source_message_id = excluded.source_message_id,
|
||
source_message_seq = excluded.source_message_seq,
|
||
source_channel_name = excluded.source_channel_name,
|
||
source_chat_id = excluded.source_chat_id,
|
||
updated_at = excluded.updated_at
|
||
",
|
||
params![
|
||
id,
|
||
input.scope_kind,
|
||
input.scope_key,
|
||
input.namespace,
|
||
input.memory_key,
|
||
input.content,
|
||
input.source_type,
|
||
input.source_session_id,
|
||
input.source_message_id,
|
||
input.source_message_seq,
|
||
input.source_channel_name,
|
||
input.source_chat_id,
|
||
created_at,
|
||
now,
|
||
],
|
||
)?;
|
||
|
||
tx.commit()?;
|
||
drop(conn);
|
||
|
||
self.get_memory(
|
||
&input.scope_kind,
|
||
&input.scope_key,
|
||
&input.namespace,
|
||
&input.memory_key,
|
||
)?
|
||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||
}
|
||
|
||
pub fn get_memory(
|
||
&self,
|
||
scope_kind: &str,
|
||
scope_key: &str,
|
||
namespace: &str,
|
||
memory_key: &str,
|
||
) -> Result<Option<MemoryRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
|
||
source_type, source_session_id, source_message_id, source_message_seq,
|
||
source_channel_name, source_chat_id, created_at, updated_at
|
||
FROM memories
|
||
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
|
||
",
|
||
)?;
|
||
|
||
stmt.query_row(
|
||
params![scope_kind, scope_key, namespace, memory_key],
|
||
map_memory_record,
|
||
)
|
||
.optional()
|
||
.map_err(StorageError::from)
|
||
}
|
||
|
||
pub fn list_memories(
|
||
&self,
|
||
scope_kind: &str,
|
||
scope_key: &str,
|
||
namespace: Option<&str>,
|
||
limit: usize,
|
||
) -> Result<Vec<MemoryRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let limit = limit.max(1) as i64;
|
||
let mut memories = Vec::new();
|
||
|
||
if let Some(namespace) = namespace {
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
|
||
source_type, source_session_id, source_message_id, source_message_seq,
|
||
source_channel_name, source_chat_id, created_at, updated_at
|
||
FROM memories
|
||
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3
|
||
ORDER BY updated_at DESC, created_at DESC
|
||
LIMIT ?4
|
||
",
|
||
)?;
|
||
let rows = stmt.query_map(
|
||
params![scope_kind, scope_key, namespace, limit],
|
||
map_memory_record,
|
||
)?;
|
||
for row in rows {
|
||
memories.push(row?);
|
||
}
|
||
} else {
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
|
||
source_type, source_session_id, source_message_id, source_message_seq,
|
||
source_channel_name, source_chat_id, created_at, updated_at
|
||
FROM memories
|
||
WHERE scope_kind = ?1 AND scope_key = ?2
|
||
ORDER BY updated_at DESC, created_at DESC
|
||
LIMIT ?3
|
||
",
|
||
)?;
|
||
let rows = stmt.query_map(params![scope_kind, scope_key, limit], map_memory_record)?;
|
||
for row in rows {
|
||
memories.push(row?);
|
||
}
|
||
}
|
||
|
||
Ok(memories)
|
||
}
|
||
|
||
pub fn list_memory_scope_keys(&self, scope_kind: &str) -> Result<Vec<String>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT DISTINCT scope_key
|
||
FROM memories
|
||
WHERE scope_kind = ?1
|
||
ORDER BY scope_key ASC
|
||
",
|
||
)?;
|
||
|
||
let rows = stmt.query_map(params![scope_kind], |row| row.get::<_, String>(0))?;
|
||
let mut scope_keys = Vec::new();
|
||
for row in rows {
|
||
scope_keys.push(row?);
|
||
}
|
||
Ok(scope_keys)
|
||
}
|
||
|
||
pub fn list_memories_for_scope(
|
||
&self,
|
||
scope_kind: &str,
|
||
scope_key: &str,
|
||
) -> Result<Vec<MemoryRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
|
||
source_type, source_session_id, source_message_id, source_message_seq,
|
||
source_channel_name, source_chat_id, created_at, updated_at
|
||
FROM memories
|
||
WHERE scope_kind = ?1 AND scope_key = ?2
|
||
ORDER BY updated_at DESC, namespace ASC, memory_key ASC
|
||
",
|
||
)?;
|
||
|
||
let rows = stmt.query_map(params![scope_kind, scope_key], map_memory_record)?;
|
||
let mut memories = Vec::new();
|
||
for row in rows {
|
||
memories.push(row?);
|
||
}
|
||
Ok(memories)
|
||
}
|
||
|
||
pub fn update_memory(
|
||
&self,
|
||
input: &MemoryUpsert,
|
||
) -> Result<Option<MemoryRecord>, StorageError> {
|
||
if self
|
||
.get_memory(
|
||
&input.scope_kind,
|
||
&input.scope_key,
|
||
&input.namespace,
|
||
&input.memory_key,
|
||
)?
|
||
.is_none()
|
||
{
|
||
return Ok(None);
|
||
}
|
||
|
||
self.put_memory(input).map(Some)
|
||
}
|
||
|
||
pub fn delete_memory(
|
||
&self,
|
||
scope_kind: &str,
|
||
scope_key: &str,
|
||
namespace: &str,
|
||
memory_key: &str,
|
||
) -> Result<bool, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let changed = conn.execute(
|
||
"
|
||
DELETE FROM memories
|
||
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
|
||
",
|
||
params![scope_kind, scope_key, namespace, memory_key],
|
||
)?;
|
||
Ok(changed > 0)
|
||
}
|
||
|
||
pub fn upsert_scheduler_job(
|
||
&self,
|
||
input: &SchedulerJobUpsert,
|
||
) -> Result<SchedulerJobRecord, StorageError> {
|
||
let now = current_timestamp();
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"
|
||
INSERT INTO scheduler_jobs (
|
||
id, kind, schedule_json, interval_secs, startup_delay_secs,
|
||
target_json, payload_json, enabled, state, last_status, last_error,
|
||
run_count, max_runs, last_fired_at, next_fire_at, paused_at, completed_at,
|
||
created_at, updated_at
|
||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?18)
|
||
ON CONFLICT(id) DO UPDATE SET
|
||
kind = excluded.kind,
|
||
schedule_json = excluded.schedule_json,
|
||
interval_secs = excluded.interval_secs,
|
||
startup_delay_secs = excluded.startup_delay_secs,
|
||
target_json = excluded.target_json,
|
||
payload_json = excluded.payload_json,
|
||
enabled = excluded.enabled,
|
||
state = excluded.state,
|
||
last_status = excluded.last_status,
|
||
last_error = excluded.last_error,
|
||
run_count = excluded.run_count,
|
||
max_runs = excluded.max_runs,
|
||
last_fired_at = excluded.last_fired_at,
|
||
next_fire_at = excluded.next_fire_at,
|
||
paused_at = excluded.paused_at,
|
||
completed_at = excluded.completed_at,
|
||
updated_at = excluded.updated_at
|
||
",
|
||
params![
|
||
input.id,
|
||
input.kind,
|
||
serde_json::to_string(&input.schedule)?,
|
||
input.interval_secs,
|
||
input.startup_delay_secs,
|
||
serde_json::to_string(&input.target)?,
|
||
serde_json::to_string(&input.payload)?,
|
||
if input.enabled { 1 } else { 0 },
|
||
input.state.as_str(),
|
||
input.last_status.as_ref().map(SchedulerJobStatus::as_str),
|
||
input.last_error,
|
||
input.run_count,
|
||
input.max_runs,
|
||
input.last_fired_at,
|
||
input.next_fire_at,
|
||
input.paused_at,
|
||
input.completed_at,
|
||
now,
|
||
],
|
||
)?;
|
||
drop(conn);
|
||
|
||
self.get_scheduler_job(&input.id)?
|
||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||
}
|
||
|
||
pub fn get_scheduler_job(
|
||
&self,
|
||
job_id: &str,
|
||
) -> Result<Option<SchedulerJobRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, kind, schedule_json, interval_secs, startup_delay_secs,
|
||
target_json, payload_json, enabled, state, last_status, last_error,
|
||
run_count, max_runs, last_fired_at, next_fire_at, paused_at, completed_at,
|
||
created_at, updated_at
|
||
FROM scheduler_jobs
|
||
WHERE id = ?1
|
||
",
|
||
)?;
|
||
|
||
stmt.query_row(params![job_id], map_scheduler_job_record)
|
||
.optional()
|
||
.map_err(StorageError::from)
|
||
}
|
||
|
||
pub fn list_scheduler_jobs(
|
||
&self,
|
||
enabled_only: bool,
|
||
) -> Result<Vec<SchedulerJobRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let sql = if enabled_only {
|
||
"
|
||
SELECT id, kind, schedule_json, interval_secs, startup_delay_secs,
|
||
target_json, payload_json, enabled, state, last_status, last_error,
|
||
run_count, max_runs, last_fired_at, next_fire_at, paused_at, completed_at,
|
||
created_at, updated_at
|
||
FROM scheduler_jobs
|
||
WHERE enabled = 1
|
||
ORDER BY COALESCE(next_fire_at, created_at) ASC, id ASC
|
||
"
|
||
} else {
|
||
"
|
||
SELECT id, kind, schedule_json, interval_secs, startup_delay_secs,
|
||
target_json, payload_json, enabled, state, last_status, last_error,
|
||
run_count, max_runs, last_fired_at, next_fire_at, paused_at, completed_at,
|
||
created_at, updated_at
|
||
FROM scheduler_jobs
|
||
ORDER BY COALESCE(next_fire_at, created_at) ASC, id ASC
|
||
"
|
||
};
|
||
|
||
let mut stmt = conn.prepare(sql)?;
|
||
let rows = stmt.query_map([], map_scheduler_job_record)?;
|
||
let mut jobs = Vec::new();
|
||
for row in rows {
|
||
jobs.push(row?);
|
||
}
|
||
Ok(jobs)
|
||
}
|
||
|
||
pub fn list_running_scheduler_jobs(&self) -> Result<Vec<SchedulerJobRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let sql = "
|
||
SELECT id, kind, schedule_json, interval_secs, startup_delay_secs,
|
||
target_json, payload_json, enabled, state, last_status, last_error,
|
||
run_count, max_runs, last_fired_at, next_fire_at, paused_at, completed_at,
|
||
created_at, updated_at
|
||
FROM scheduler_jobs
|
||
WHERE state = 'running'
|
||
ORDER BY COALESCE(next_fire_at, created_at) ASC, id ASC
|
||
";
|
||
|
||
let mut stmt = conn.prepare(sql)?;
|
||
let rows = stmt.query_map([], map_scheduler_job_record)?;
|
||
let mut jobs = Vec::new();
|
||
for row in rows {
|
||
jobs.push(row?);
|
||
}
|
||
Ok(jobs)
|
||
}
|
||
|
||
pub fn delete_scheduler_job(&self, job_id: &str) -> Result<(), StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute("DELETE FROM scheduler_jobs WHERE id = ?1", params![job_id])?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn update_scheduler_job_runtime(
|
||
&self,
|
||
job_id: &str,
|
||
state: SchedulerJobState,
|
||
last_status: Option<SchedulerJobStatus>,
|
||
last_error: Option<&str>,
|
||
run_count: i64,
|
||
last_fired_at: Option<i64>,
|
||
next_fire_at: Option<i64>,
|
||
paused_at: Option<i64>,
|
||
completed_at: Option<i64>,
|
||
) -> Result<(), StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.execute(
|
||
"
|
||
UPDATE scheduler_jobs
|
||
SET state = ?2,
|
||
last_status = ?3,
|
||
last_error = ?4,
|
||
run_count = ?5,
|
||
last_fired_at = ?6,
|
||
next_fire_at = ?7,
|
||
paused_at = ?8,
|
||
completed_at = ?9,
|
||
updated_at = ?10
|
||
WHERE id = ?1
|
||
",
|
||
params![
|
||
job_id,
|
||
state.as_str(),
|
||
last_status.as_ref().map(SchedulerJobStatus::as_str),
|
||
last_error,
|
||
run_count,
|
||
last_fired_at,
|
||
next_fire_at,
|
||
paused_at,
|
||
completed_at,
|
||
current_timestamp(),
|
||
],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
pub fn search_memories(
|
||
&self,
|
||
scope_kind: &str,
|
||
scope_key: &str,
|
||
query: &str,
|
||
namespace: Option<&str>,
|
||
limit: usize,
|
||
) -> Result<Vec<MemoryRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let limit = limit.max(1) as i64;
|
||
let query = quote_fts_query(query);
|
||
let mut memories = Vec::new();
|
||
|
||
if let Some(namespace) = namespace {
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
|
||
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
|
||
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
|
||
FROM memories_fts f
|
||
JOIN memories m ON m.rowid = f.rowid
|
||
WHERE memories_fts MATCH ?1
|
||
AND m.scope_kind = ?2
|
||
AND m.scope_key = ?3
|
||
AND m.namespace = ?4
|
||
ORDER BY bm25(memories_fts), m.updated_at DESC
|
||
LIMIT ?5
|
||
",
|
||
)?;
|
||
let rows = stmt.query_map(
|
||
params![query, scope_kind, scope_key, namespace, limit],
|
||
map_memory_record,
|
||
)?;
|
||
for row in rows {
|
||
memories.push(row?);
|
||
}
|
||
} else {
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
|
||
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
|
||
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
|
||
FROM memories_fts f
|
||
JOIN memories m ON m.rowid = f.rowid
|
||
WHERE memories_fts MATCH ?1
|
||
AND m.scope_kind = ?2
|
||
AND m.scope_key = ?3
|
||
ORDER BY bm25(memories_fts), m.updated_at DESC
|
||
LIMIT ?4
|
||
",
|
||
)?;
|
||
let rows = stmt.query_map(
|
||
params![query, scope_kind, scope_key, limit],
|
||
map_memory_record,
|
||
)?;
|
||
for row in rows {
|
||
memories.push(row?);
|
||
}
|
||
}
|
||
|
||
Ok(memories)
|
||
}
|
||
|
||
pub fn search_memories_any(
|
||
&self,
|
||
scope_kind: &str,
|
||
scope_key: &str,
|
||
queries: &[String],
|
||
namespace: Option<&str>,
|
||
limit: usize,
|
||
) -> Result<Vec<MemoryRecord>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let limit = limit.max(1) as i64;
|
||
let query = quote_fts_or_query(queries);
|
||
if query.is_empty() {
|
||
return Ok(Vec::new());
|
||
}
|
||
|
||
let mut memories = Vec::new();
|
||
|
||
if let Some(namespace) = namespace {
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
|
||
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
|
||
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
|
||
FROM memories_fts f
|
||
JOIN memories m ON m.rowid = f.rowid
|
||
WHERE memories_fts MATCH ?1
|
||
AND m.scope_kind = ?2
|
||
AND m.scope_key = ?3
|
||
AND m.namespace = ?4
|
||
ORDER BY bm25(memories_fts), m.updated_at DESC
|
||
LIMIT ?5
|
||
",
|
||
)?;
|
||
let rows = stmt.query_map(
|
||
params![query, scope_kind, scope_key, namespace, limit],
|
||
map_memory_record,
|
||
)?;
|
||
for row in rows {
|
||
memories.push(row?);
|
||
}
|
||
} else {
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
|
||
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
|
||
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
|
||
FROM memories_fts f
|
||
JOIN memories m ON m.rowid = f.rowid
|
||
WHERE memories_fts MATCH ?1
|
||
AND m.scope_kind = ?2
|
||
AND m.scope_key = ?3
|
||
ORDER BY bm25(memories_fts), m.updated_at DESC
|
||
LIMIT ?4
|
||
",
|
||
)?;
|
||
let rows = stmt.query_map(
|
||
params![query, scope_kind, scope_key, limit],
|
||
map_memory_record,
|
||
)?;
|
||
for row in rows {
|
||
memories.push(row?);
|
||
}
|
||
}
|
||
|
||
Ok(memories)
|
||
}
|
||
|
||
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
load_messages_after(&conn, session_id, 0)
|
||
}
|
||
|
||
pub fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
||
FROM messages
|
||
WHERE topic_id = ?1
|
||
ORDER BY seq ASC
|
||
",
|
||
)?;
|
||
|
||
let rows = stmt.query_map(params![topic_id], |row| {
|
||
let media_refs_json: String = row.get(5)?;
|
||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(
|
||
media_refs_json.len(),
|
||
rusqlite::types::Type::Text,
|
||
Box::new(err),
|
||
)
|
||
})?;
|
||
|
||
let tool_calls_json: Option<String> = row.get(9)?;
|
||
let tool_calls = tool_calls_json
|
||
.as_deref()
|
||
.map(serde_json::from_str)
|
||
.transpose()
|
||
.map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(
|
||
9,
|
||
rusqlite::types::Type::Text,
|
||
Box::new(err),
|
||
)
|
||
})?;
|
||
|
||
Ok(ChatMessage {
|
||
id: row.get(0)?,
|
||
role: row.get(1)?,
|
||
content: row.get(2)?,
|
||
system_context: row.get(3)?,
|
||
reasoning_content: row.get(4)?,
|
||
media_refs,
|
||
timestamp: row.get(6)?,
|
||
tool_call_id: row.get(7)?,
|
||
tool_name: row.get(8)?,
|
||
tool_state: None,
|
||
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
|
||
tool_calls,
|
||
})
|
||
})?;
|
||
|
||
let mut messages = Vec::new();
|
||
for row in rows {
|
||
messages.push(row?);
|
||
}
|
||
Ok(messages)
|
||
}
|
||
|
||
/// 获取指定话题的消息数量(动态计算,确保准确)
|
||
pub fn get_topic_message_count(&self, topic_id: &str) -> Result<usize, StorageError> {
|
||
self.load_messages_for_topic(topic_id).map(|msgs| msgs.len())
|
||
}
|
||
|
||
/// Loads the full message history of `session_id` in ascending `seq` order.
///
/// Currently identical to `load_messages`, to which it delegates.
pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
    self.load_messages(session_id)
}
|
||
|
||
pub fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
|
||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||
conn.query_row(
|
||
"
|
||
SELECT COUNT(*)
|
||
FROM messages
|
||
WHERE session_id = ?1 AND role = 'user'
|
||
",
|
||
params![session_id],
|
||
|row| row.get(0),
|
||
)
|
||
.map_err(StorageError::from)
|
||
}
|
||
}
|
||
|
||
/// Derives the stable session id for a `(channel_name, chat_id)` pair.
///
/// The `"cli"` channel uses `chat_id` unchanged. Every other channel is namespaced
/// as `"<channel_name>:<chat_id>"`.
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
    match channel_name {
        "cli" => chat_id.to_owned(),
        channel => format!("{}:{}", channel, chat_id),
    }
}
|
||
|
||
#[cfg(not(test))]
|
||
fn default_session_db_path() -> Result<PathBuf, std::io::Error> {
|
||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||
Ok(home.join(".picobot").join("storage").join("sessions.db"))
|
||
}
|
||
|
||
fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord> {
|
||
Ok(SessionRecord {
|
||
id: row.get(0)?,
|
||
title: row.get(1)?,
|
||
channel_name: row.get(2)?,
|
||
chat_id: row.get(3)?,
|
||
summary: row.get(4)?,
|
||
created_at: row.get(5)?,
|
||
updated_at: row.get(6)?,
|
||
last_active_at: row.get(7)?,
|
||
archived_at: row.get(8)?,
|
||
deleted_at: row.get(9)?,
|
||
message_count: row.get(10)?,
|
||
user_turn_count: row.get(11)?,
|
||
agent_prompt_reinjection_count: row.get(12)?,
|
||
})
|
||
}
|
||
|
||
fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEventRecord> {
|
||
let payload_json: String = row.get(4)?;
|
||
let payload = serde_json::from_str(&payload_json).map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(err))
|
||
})?;
|
||
|
||
Ok(SkillEventRecord {
|
||
id: row.get(0)?,
|
||
session_id: row.get(1)?,
|
||
event_type: row.get(2)?,
|
||
skill_name: row.get(3)?,
|
||
payload,
|
||
created_at: row.get(5)?,
|
||
})
|
||
}
|
||
|
||
fn map_memory_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<MemoryRecord> {
|
||
Ok(MemoryRecord {
|
||
id: row.get(0)?,
|
||
scope_kind: row.get(1)?,
|
||
scope_key: row.get(2)?,
|
||
namespace: row.get(3)?,
|
||
memory_key: row.get(4)?,
|
||
content: row.get(5)?,
|
||
source_type: row.get(6)?,
|
||
source_session_id: row.get(7)?,
|
||
source_message_id: row.get(8)?,
|
||
source_message_seq: row.get(9)?,
|
||
source_channel_name: row.get(10)?,
|
||
source_chat_id: row.get(11)?,
|
||
created_at: row.get(12)?,
|
||
updated_at: row.get(13)?,
|
||
})
|
||
}
|
||
|
||
fn map_scheduler_job_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SchedulerJobRecord> {
|
||
let schedule_json: String = row.get(2)?;
|
||
let target_json: String = row.get(5)?;
|
||
let payload_json: String = row.get(6)?;
|
||
let state: String = row.get(8)?;
|
||
let last_status: Option<String> = row.get(9)?;
|
||
|
||
let schedule = serde_json::from_str(&schedule_json).map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(2, rusqlite::types::Type::Text, Box::new(err))
|
||
})?;
|
||
let target = serde_json::from_str(&target_json).map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(5, rusqlite::types::Type::Text, Box::new(err))
|
||
})?;
|
||
let payload = serde_json::from_str(&payload_json).map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(err))
|
||
})?;
|
||
|
||
Ok(SchedulerJobRecord {
|
||
id: row.get(0)?,
|
||
kind: row.get(1)?,
|
||
schedule,
|
||
interval_secs: row.get(3)?,
|
||
startup_delay_secs: row.get(4)?,
|
||
target,
|
||
payload,
|
||
enabled: row.get::<_, i64>(7)? != 0,
|
||
state: SchedulerJobState::from_str(&state).ok_or_else(|| {
|
||
rusqlite::Error::FromSqlConversionFailure(
|
||
8,
|
||
rusqlite::types::Type::Text,
|
||
format!("invalid scheduler job state: {}", state).into(),
|
||
)
|
||
})?,
|
||
last_status: last_status.and_then(|value| SchedulerJobStatus::from_str(&value)),
|
||
last_error: row.get(10)?,
|
||
run_count: row.get(11)?,
|
||
max_runs: row.get(12)?,
|
||
last_fired_at: row.get(13)?,
|
||
next_fire_at: row.get(14)?,
|
||
paused_at: row.get(15)?,
|
||
completed_at: row.get(16)?,
|
||
created_at: row.get(17)?,
|
||
updated_at: row.get(18)?,
|
||
})
|
||
}
|
||
|
||
fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
|
||
if !has_column(conn, "sessions", "user_turn_count")? {
|
||
add_column_if_missing(
|
||
conn,
|
||
"ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0",
|
||
)?;
|
||
}
|
||
|
||
if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? {
|
||
add_column_if_missing(
|
||
conn,
|
||
"ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0",
|
||
)?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> {
|
||
if !has_column(conn, "messages", "system_context")? {
|
||
add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN system_context TEXT")?;
|
||
}
|
||
|
||
if !has_column(conn, "messages", "reasoning_content")? {
|
||
add_column_if_missing(
|
||
conn,
|
||
"ALTER TABLE messages ADD COLUMN reasoning_content TEXT",
|
||
)?;
|
||
}
|
||
|
||
if !has_column(conn, "messages", "topic_id")? {
|
||
add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN topic_id TEXT")?;
|
||
// 添加外键约束(SQLite 不支持 ALTER TABLE ADD FOREIGN KEY,需要重建表)
|
||
// 这里只添加列,外键约束由应用层保证
|
||
}
|
||
|
||
if !has_column(conn, "messages", "tool_duration_ms")? {
|
||
add_column_if_missing(
|
||
conn,
|
||
"ALTER TABLE messages ADD COLUMN tool_duration_ms INTEGER",
|
||
)?;
|
||
}
|
||
|
||
// 创建 topic_id 索引(如果不存在)
|
||
conn.execute(
|
||
"CREATE INDEX IF NOT EXISTS idx_messages_topic_seq ON messages(topic_id, seq) WHERE topic_id IS NOT NULL",
|
||
[],
|
||
)?;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
||
if !has_column(conn, "scheduler_jobs", "schedule_json")? {
|
||
conn.execute(
|
||
"ALTER TABLE scheduler_jobs ADD COLUMN schedule_json TEXT NOT NULL DEFAULT '{}'",
|
||
[],
|
||
)?;
|
||
}
|
||
|
||
if !has_column(conn, "scheduler_jobs", "state")? {
|
||
conn.execute(
|
||
"ALTER TABLE scheduler_jobs ADD COLUMN state TEXT NOT NULL DEFAULT 'scheduled'",
|
||
[],
|
||
)?;
|
||
}
|
||
|
||
if !has_column(conn, "scheduler_jobs", "last_status")? {
|
||
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT", [])?;
|
||
}
|
||
|
||
if !has_column(conn, "scheduler_jobs", "last_error")? {
|
||
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT", [])?;
|
||
}
|
||
|
||
if !has_column(conn, "scheduler_jobs", "run_count")? {
|
||
conn.execute(
|
||
"ALTER TABLE scheduler_jobs ADD COLUMN run_count INTEGER NOT NULL DEFAULT 0",
|
||
[],
|
||
)?;
|
||
}
|
||
|
||
if !has_column(conn, "scheduler_jobs", "max_runs")? {
|
||
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER", [])?;
|
||
}
|
||
|
||
if !has_column(conn, "scheduler_jobs", "paused_at")? {
|
||
conn.execute(
|
||
"ALTER TABLE scheduler_jobs ADD COLUMN paused_at INTEGER",
|
||
[],
|
||
)?;
|
||
}
|
||
|
||
if !has_column(conn, "scheduler_jobs", "completed_at")? {
|
||
conn.execute(
|
||
"ALTER TABLE scheduler_jobs ADD COLUMN completed_at INTEGER",
|
||
[],
|
||
)?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// Collapses every memory onto the shared `'default'` scope key.
///
/// Runs two statements on `conn`:
/// 1. De-duplicate. Rows that differ only in `scope_key` would violate the unique
///    constraint once every key becomes `'default'`. For each
///    `(scope_kind, namespace, memory_key)` group, only the most recently updated row
///    is kept. Ties on `updated_at` are broken by the highest `rowid`, so the
///    surviving row is deterministic instead of depending on scan order.
/// 2. Rewrite every remaining non-`'default'` scope key to `'default'`.
///
/// NOTE(review): the two statements are not wrapped in a transaction here. If
/// step 2 fails, the deletions from step 1 have already been applied.
fn ensure_memory_scope_key_migration(conn: &Connection) -> Result<(), StorageError> {
    // Step 1: keep one row per (scope_kind, namespace, memory_key).
    conn.execute(
        "
        DELETE FROM memories
        WHERE rowid NOT IN (
            SELECT rowid FROM (
                SELECT rowid, ROW_NUMBER() OVER (
                    PARTITION BY scope_kind, namespace, memory_key
                    ORDER BY updated_at DESC, rowid DESC
                ) AS rn
                FROM memories
            )
            WHERE rn = 1
        )
        ",
        [],
    )?;

    // Step 2: unify scope_key.
    conn.execute(
        "UPDATE memories SET scope_key = 'default' WHERE scope_key != 'default'",
        [],
    )?;
    Ok(())
}
|
||
|
||
fn has_column(
|
||
conn: &Connection,
|
||
table_name: &str,
|
||
column_name: &str,
|
||
) -> Result<bool, StorageError> {
|
||
let pragma = format!("PRAGMA table_info({})", table_name);
|
||
let mut stmt = conn.prepare(&pragma)?;
|
||
let mut rows = stmt.query([])?;
|
||
|
||
while let Some(row) = rows.next()? {
|
||
let existing_name: String = row.get(1)?;
|
||
if existing_name == column_name {
|
||
return Ok(true);
|
||
}
|
||
}
|
||
|
||
Ok(false)
|
||
}
|
||
|
||
fn add_column_if_missing(conn: &Connection, sql: &str) -> Result<(), StorageError> {
|
||
match conn.execute(sql, []) {
|
||
Ok(_) => Ok(()),
|
||
Err(rusqlite::Error::SqliteFailure(_, Some(message)))
|
||
if message.contains("duplicate column name") =>
|
||
{
|
||
Ok(())
|
||
}
|
||
Err(error) => Err(StorageError::Database(error)),
|
||
}
|
||
}
|
||
|
||
fn insert_message_with_seq(
|
||
conn: &rusqlite::Transaction<'_>,
|
||
session_id: &str,
|
||
seq: i64,
|
||
message: &ChatMessage,
|
||
) -> Result<(), StorageError> {
|
||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||
let tool_calls_json = message
|
||
.tool_calls
|
||
.as_ref()
|
||
.map(serde_json::to_string)
|
||
.transpose()?;
|
||
conn.execute(
|
||
"
|
||
INSERT INTO messages (
|
||
id, session_id, seq, role, content,
|
||
system_context, reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, tool_duration_ms, created_at
|
||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)
|
||
",
|
||
params![
|
||
message.id,
|
||
session_id,
|
||
seq,
|
||
message.role,
|
||
message.content,
|
||
message.system_context,
|
||
message.reasoning_content,
|
||
media_refs_json,
|
||
message.tool_call_id,
|
||
message.tool_name,
|
||
tool_calls_json,
|
||
message.tool_duration_ms.map(|v| v as i64),
|
||
message.timestamp,
|
||
],
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
fn clone_message_for_compaction(message: &ChatMessage, timestamp: i64) -> ChatMessage {
|
||
ChatMessage {
|
||
id: uuid::Uuid::new_v4().to_string(),
|
||
role: message.role.clone(),
|
||
content: message.content.clone(),
|
||
media_refs: message.media_refs.clone(),
|
||
timestamp,
|
||
system_context: message.system_context.clone(),
|
||
reasoning_content: message.reasoning_content.clone(),
|
||
tool_call_id: message.tool_call_id.clone(),
|
||
tool_name: message.tool_name.clone(),
|
||
tool_state: message.tool_state.clone(),
|
||
tool_duration_ms: message.tool_duration_ms,
|
||
tool_calls: message.tool_calls.clone(),
|
||
}
|
||
}
|
||
|
||
fn load_messages_between(
|
||
conn: &rusqlite::Transaction<'_>,
|
||
session_id: &str,
|
||
start_seq_exclusive: i64,
|
||
end_seq_inclusive: i64,
|
||
) -> Result<Vec<ChatMessage>, StorageError> {
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
||
FROM messages
|
||
WHERE session_id = ?1 AND seq > ?2 AND seq <= ?3
|
||
ORDER BY seq ASC
|
||
",
|
||
)?;
|
||
|
||
let rows = stmt.query_map(
|
||
params![session_id, start_seq_exclusive, end_seq_inclusive],
|
||
|row| {
|
||
let media_refs_json: String = row.get(5)?;
|
||
let media_refs: Vec<String> =
|
||
serde_json::from_str(&media_refs_json).map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(
|
||
media_refs_json.len(),
|
||
rusqlite::types::Type::Text,
|
||
Box::new(err),
|
||
)
|
||
})?;
|
||
|
||
let tool_calls_json: Option<String> = row.get(9)?;
|
||
let tool_calls = tool_calls_json
|
||
.as_deref()
|
||
.map(serde_json::from_str)
|
||
.transpose()
|
||
.map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(
|
||
9,
|
||
rusqlite::types::Type::Text,
|
||
Box::new(err),
|
||
)
|
||
})?;
|
||
|
||
Ok(ChatMessage {
|
||
id: row.get(0)?,
|
||
role: row.get(1)?,
|
||
content: row.get(2)?,
|
||
system_context: row.get(3)?,
|
||
reasoning_content: row.get(4)?,
|
||
media_refs,
|
||
timestamp: row.get(6)?,
|
||
tool_call_id: row.get(7)?,
|
||
tool_name: row.get(8)?,
|
||
tool_state: None,
|
||
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
|
||
tool_calls,
|
||
})
|
||
},
|
||
)?;
|
||
|
||
let mut messages = Vec::new();
|
||
for row in rows {
|
||
messages.push(row?);
|
||
}
|
||
Ok(messages)
|
||
}
|
||
|
||
fn load_messages_after(
|
||
conn: &Connection,
|
||
session_id: &str,
|
||
cutoff_seq: i64,
|
||
) -> Result<Vec<ChatMessage>, StorageError> {
|
||
let mut stmt = conn.prepare(
|
||
"
|
||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
||
FROM messages
|
||
WHERE session_id = ?1 AND seq > ?2
|
||
ORDER BY seq ASC
|
||
",
|
||
)?;
|
||
|
||
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| {
|
||
let media_refs_json: String = row.get(5)?;
|
||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(
|
||
media_refs_json.len(),
|
||
rusqlite::types::Type::Text,
|
||
Box::new(err),
|
||
)
|
||
})?;
|
||
|
||
let tool_calls_json: Option<String> = row.get(9)?;
|
||
let tool_calls = tool_calls_json
|
||
.as_deref()
|
||
.map(serde_json::from_str)
|
||
.transpose()
|
||
.map_err(|err| {
|
||
rusqlite::Error::FromSqlConversionFailure(
|
||
9,
|
||
rusqlite::types::Type::Text,
|
||
Box::new(err),
|
||
)
|
||
})?;
|
||
|
||
Ok(ChatMessage {
|
||
id: row.get(0)?,
|
||
role: row.get(1)?,
|
||
content: row.get(2)?,
|
||
system_context: row.get(3)?,
|
||
reasoning_content: row.get(4)?,
|
||
media_refs,
|
||
timestamp: row.get(6)?,
|
||
tool_call_id: row.get(7)?,
|
||
tool_name: row.get(8)?,
|
||
tool_state: None,
|
||
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
|
||
tool_calls,
|
||
})
|
||
})?;
|
||
|
||
let mut messages = Vec::new();
|
||
for row in rows {
|
||
messages.push(row?);
|
||
}
|
||
Ok(messages)
|
||
}
|
||
|
||
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn current_timestamp() -> i64 {
    let since_epoch = std::time::UNIX_EPOCH
        .elapsed()
        .expect("system clock before unix epoch");
    since_epoch.as_millis() as i64
}
|
||
|
||
/// Wraps `query` as a single FTS5 string (phrase), doubling any embedded `"` so the
/// text is matched literally rather than parsed as FTS query syntax.
fn quote_fts_query(query: &str) -> String {
    let mut quoted = String::with_capacity(query.len() + 2);
    quoted.push('"');
    for ch in query.chars() {
        if ch == '"' {
            quoted.push_str("\"\"");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
|
||
|
||
fn quote_fts_or_query(queries: &[String]) -> String {
|
||
queries
|
||
.iter()
|
||
.map(|query| query.trim())
|
||
.filter(|query| !query.is_empty())
|
||
.map(quote_fts_query)
|
||
.collect::<Vec<_>>()
|
||
.join(" OR ")
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT;
|
||
use crate::domain::messages::ToolCall;
|
||
|
||
const TEST_CHANNEL: &str = "test-channel";
|
||
|
||
#[test]
|
||
fn test_persistent_session_id_for_cli_and_channel() {
|
||
assert_eq!(persistent_session_id("cli", "abc"), "abc");
|
||
assert_eq!(persistent_session_id(TEST_CHANNEL, "abc"), "test-channel:abc");
|
||
}
|
||
|
||
#[test]
|
||
fn test_session_store_roundtrip_and_lifecycle() {
|
||
let store = SessionStore::in_memory().unwrap();
|
||
|
||
let session = store.create_cli_session(Some("demo")).unwrap();
|
||
assert_eq!(session.title, "demo");
|
||
assert_eq!(session.channel_name, "cli");
|
||
assert_eq!(session.chat_id, session.id);
|
||
assert_eq!(session.message_count, 0);
|
||
assert_eq!(session.user_turn_count, 0);
|
||
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
||
|
||
let first = ChatMessage::user("hello");
|
||
let second = ChatMessage::assistant("world");
|
||
store.append_message(&session.id, &first).unwrap();
|
||
store.append_message(&session.id, &second).unwrap();
|
||
|
||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||
assert_eq!(stored.message_count, 2);
|
||
assert!(stored.archived_at.is_none());
|
||
assert_eq!(stored.user_turn_count, 1);
|
||
assert_eq!(stored.agent_prompt_reinjection_count, 0);
|
||
|
||
let messages = store.load_messages(&session.id).unwrap();
|
||
assert_eq!(messages.len(), 2);
|
||
assert_eq!(messages[0].role, "user");
|
||
assert_eq!(messages[0].content, "hello");
|
||
assert_eq!(messages[1].role, "assistant");
|
||
assert_eq!(messages[1].content, "world");
|
||
|
||
store.rename_session(&session.id, "renamed").unwrap();
|
||
let renamed = store.get_session(&session.id).unwrap().unwrap();
|
||
assert_eq!(renamed.title, "renamed");
|
||
|
||
store.archive_session(&session.id).unwrap();
|
||
let archived = store.get_session(&session.id).unwrap().unwrap();
|
||
assert!(archived.archived_at.is_some());
|
||
|
||
let active_only = store.list_sessions("cli", false).unwrap();
|
||
assert!(active_only.is_empty());
|
||
|
||
let including_archived = store.list_sessions("cli", true).unwrap();
|
||
assert_eq!(including_archived.len(), 1);
|
||
|
||
store.clear_messages(&session.id).unwrap();
|
||
let cleared = store.load_messages(&session.id).unwrap();
|
||
assert!(cleared.is_empty());
|
||
let cleared_session = store.get_session(&session.id).unwrap().unwrap();
|
||
assert_eq!(cleared_session.message_count, 0);
|
||
assert_eq!(cleared_session.user_turn_count, 0);
|
||
assert_eq!(cleared_session.agent_prompt_reinjection_count, 0);
|
||
|
||
store.delete_session(&session.id).unwrap();
|
||
assert!(store.get_session(&session.id).unwrap().is_none());
|
||
}
|
||
|
||
#[test]
|
||
fn test_ensure_channel_session_is_stable() {
|
||
let store = SessionStore::in_memory().unwrap();
|
||
|
||
let first = store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
|
||
let second = store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
|
||
|
||
assert_eq!(first.id, second.id);
|
||
assert_eq!(first.chat_id, "chat-1");
|
||
assert_eq!(second.channel_name, TEST_CHANNEL);
|
||
}
|
||
|
||
#[test]
fn test_assistant_tool_calls_roundtrip() {
    let store = SessionStore::in_memory().unwrap();
    let session = store.create_cli_session(Some("tools")).unwrap();

    let assistant = ChatMessage::assistant_with_tool_calls(
        "calling tool",
        vec![ToolCall {
            id: "call_1".to_string(),
            name: "calculator".to_string(),
            arguments: serde_json::json!({ "expression": "3*7" }),
        }],
    );
    store.append_message(&session.id, &assistant).unwrap();

    let messages = store.load_messages(&session.id).unwrap();
    assert_eq!(messages.len(), 1);
    assert_eq!(messages[0].role, "assistant");

    // Bind the persisted tool calls once instead of re-unwrapping for every
    // assertion, so a missing value fails with a descriptive message.
    let tool_calls = messages[0]
        .tool_calls
        .as_ref()
        .expect("assistant message should keep its tool calls after reload");
    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0].id, "call_1");
    assert_eq!(tool_calls[0].name, "calculator");
    // The JSON arguments payload must survive the storage round-trip as well.
    assert_eq!(tool_calls[0].arguments["expression"], "3*7");
}
|
||
|
||
#[test]
fn test_assistant_reasoning_content_roundtrip() {
    // Reasoning text stored alongside an assistant reply must be reloaded intact.
    let store = SessionStore::in_memory().unwrap();
    let session = store.create_cli_session(Some("reasoning")).unwrap();

    store
        .append_message(
            &session.id,
            &ChatMessage::assistant_with_reasoning("final answer", "hidden reasoning"),
        )
        .unwrap();

    let loaded = store.load_messages(&session.id).unwrap();
    assert_eq!(loaded.len(), 1);

    let only = &loaded[0];
    assert_eq!(only.content, "final answer");
    assert_eq!(only.reasoning_content.as_deref(), Some("hidden reasoning"));
}
|
||
|
||
#[test]
fn test_schema_migration_adds_user_turn_and_reinjection_columns() {
    // Legacy schema: `sessions` has no user-turn / prompt-reinjection columns.
    const LEGACY_SCHEMA: &str = "
            CREATE TABLE sessions (
                id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                channel_name TEXT NOT NULL,
                chat_id TEXT NOT NULL,
                summary TEXT,
                created_at INTEGER NOT NULL,
                updated_at INTEGER NOT NULL,
                last_active_at INTEGER NOT NULL,
                archived_at INTEGER,
                deleted_at INTEGER,
                message_count INTEGER NOT NULL DEFAULT 0
            );

            CREATE TABLE messages (
                id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL,
                seq INTEGER NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                media_refs_json TEXT NOT NULL,
                tool_call_id TEXT,
                tool_name TEXT,
                tool_calls_json TEXT,
                created_at INTEGER NOT NULL,
                FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
                UNIQUE(session_id, seq)
            );
            ";

    let legacy = Connection::open_in_memory().unwrap();
    legacy.execute_batch(LEGACY_SCHEMA).unwrap();

    // Wrapping the legacy connection runs the store's schema setup; a session
    // created afterwards must expose zeroed counters.
    let store = SessionStore::from_connection(legacy).unwrap();
    let created = store.create_cli_session(Some("migrated")).unwrap();

    assert_eq!(created.user_turn_count, 0);
    assert_eq!(created.agent_prompt_reinjection_count, 0);
}
|
||
|
||
#[test]
fn test_schema_migration_adds_reasoning_content_column_to_messages() {
    let conn = Connection::open_in_memory().unwrap();
    // Legacy schema: `messages` has no `reasoning_content` column.
    conn.execute_batch(
        "
            CREATE TABLE sessions (
                id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                channel_name TEXT NOT NULL,
                chat_id TEXT NOT NULL,
                summary TEXT,
                created_at INTEGER NOT NULL,
                updated_at INTEGER NOT NULL,
                last_active_at INTEGER NOT NULL,
                archived_at INTEGER,
                deleted_at INTEGER,
                message_count INTEGER NOT NULL DEFAULT 0
            );

            CREATE TABLE messages (
                id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL,
                seq INTEGER NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                media_refs_json TEXT NOT NULL,
                tool_call_id TEXT,
                tool_name TEXT,
                tool_calls_json TEXT,
                created_at INTEGER NOT NULL,
                FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
                UNIQUE(session_id, seq)
            );
            ",
    )
    .unwrap();

    // `from_connection` is expected to bring the legacy schema up to date.
    // The store binding is used below, so it must not carry a leading `_`:
    // that prefix conventionally marks a deliberately unused binding.
    let store = SessionStore::from_connection(conn).unwrap();
    let migrated = store.conn.lock().unwrap();

    assert!(has_column(&migrated, "messages", "reasoning_content").unwrap());
}
|
||
|
||
#[test]
fn test_compact_active_history_rebuilds_active_segment_with_delta_messages() {
    let store = SessionStore::in_memory().unwrap();
    let session = store.create_cli_session(Some("compact-history")).unwrap();

    let agent_prompt = ChatMessage::system_with_context(
        "agent",
        Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
    );
    let seed_messages = vec![
        agent_prompt.clone(),
        ChatMessage::user("u1"),
        ChatMessage::assistant("a1"),
        ChatMessage::user("u2"),
        ChatMessage::assistant("a2"),
        ChatMessage::user("u3"),
        ChatMessage::assistant("a3"),
        ChatMessage::user("u4"),
        ChatMessage::assistant("a4"),
    ];
    for message in &seed_messages {
        store.append_message(&session.id, message).unwrap();
    }

    // Snapshot boundary that the compaction is computed against.
    let snapshot_end_seq = store
        .get_session(&session.id)
        .unwrap()
        .unwrap()
        .message_count;
    // Keep everything after [agent prompt, u1, a1] verbatim, i.e. u2..a4.
    let preserved_messages = store.load_messages(&session.id).unwrap()[3..].to_vec();
    let preserved_system_messages = vec![agent_prompt];

    // Messages appended after the snapshot (the delta) must survive compaction.
    store
        .append_message(&session.id, &ChatMessage::user("u5"))
        .unwrap();
    store
        .append_message(&session.id, &ChatMessage::assistant("a5"))
        .unwrap();

    let summary_message = ChatMessage::system("[Compressed History]\n\nsummary");
    let compacted = store
        .compact_active_history(
            &session.id,
            snapshot_end_seq,
            &preserved_system_messages,
            &summary_message,
            &preserved_messages,
        )
        .unwrap();
    assert!(compacted);

    let active_messages = store.load_messages(&session.id).unwrap();
    // Expected layout: system prompt, summary, preserved turns, then the delta.
    let expected_contents = [
        "agent",
        "[Compressed History]\n\nsummary",
        "u2",
        "a2",
        "u3",
        "a3",
        "u4",
        "a4",
        "u5",
        "a5",
    ];
    assert_eq!(active_messages.len(), expected_contents.len());
    // Verify every position in order, not just a few sampled indices.
    for (index, (message, expected)) in active_messages
        .iter()
        .zip(expected_contents.iter().copied())
        .enumerate()
    {
        assert_eq!(
            message.content, expected,
            "unexpected content at active position {}",
            index
        );
    }
    assert_eq!(active_messages[0].role, "system");
    assert_eq!(
        active_messages[0].system_context.as_deref(),
        Some(SYSTEM_CONTEXT_AGENT_PROMPT)
    );
    assert_eq!(active_messages[1].role, "system");

    let stored = store.get_session(&session.id).unwrap().unwrap();
    assert_eq!(stored.user_turn_count, 4);

    let all_messages = store.load_all_messages(&session.id).unwrap();
    assert_eq!(all_messages.len(), 10);
}
|
||
|
||
#[test]
fn test_mark_agent_prompt_reinjected_increments_counter() {
    let store = SessionStore::in_memory().unwrap();
    let session = store.create_cli_session(Some("prompt")).unwrap();

    // Two reinjection marks should leave the per-session counter at two.
    for _ in 0..2 {
        store.mark_agent_prompt_reinjected(&session.id).unwrap();
    }

    let reloaded = store.get_session(&session.id).unwrap().unwrap();
    assert_eq!(reloaded.agent_prompt_reinjection_count, 2);
}
|
||
|
||
#[test]
fn test_tool_result_roundtrip() {
    let store = SessionStore::in_memory().unwrap();
    let session = store.create_cli_session(Some("tool-result")).unwrap();

    store
        .append_message(
            &session.id,
            &ChatMessage::tool("call_9", "write", "saved to /tmp/output.txt"),
        )
        .unwrap();

    let loaded = store.load_messages(&session.id).unwrap();
    assert_eq!(loaded.len(), 1);

    // A tool result keeps its call id and tool name but carries no tool calls.
    let result = &loaded[0];
    assert_eq!(result.role, "tool");
    assert_eq!(result.content, "saved to /tmp/output.txt");
    assert_eq!(result.tool_call_id.as_deref(), Some("call_9"));
    assert_eq!(result.tool_name.as_deref(), Some("write"));
    assert!(result.tool_calls.is_none());
}
|
||
|
||
#[test]
fn test_skill_events_roundtrip() {
    let store = SessionStore::in_memory().unwrap();
    let session = store.create_cli_session(Some("skill-events")).unwrap();

    // One event without a session and one bound to the CLI session.
    let global_payload = serde_json::json!({"count": 2});
    store
        .append_skill_event(None, "discovered", None, &global_payload)
        .unwrap();

    let session_payload = serde_json::json!({"source": "project"});
    store
        .append_skill_event(
            Some(&session.id),
            "activated",
            Some("code-review"),
            &session_payload,
        )
        .unwrap();

    // Listing with `None` yields only the session-less event.
    let global = store.list_skill_events(None).unwrap();
    assert_eq!(global.len(), 1);
    let discovered = &global[0];
    assert_eq!(discovered.event_type, "discovered");
    assert_eq!(discovered.payload["count"], 2);

    // Listing by session id yields only the event bound to that session.
    let scoped = store.list_skill_events(Some(&session.id)).unwrap();
    assert_eq!(scoped.len(), 1);
    let activated = &scoped[0];
    assert_eq!(activated.event_type, "activated");
    assert_eq!(activated.skill_name.as_deref(), Some("code-review"));
    assert_eq!(activated.payload["source"], "project");
}
|
||
|
||
#[test]
fn test_memory_roundtrip_with_source_fields() {
    let store = SessionStore::in_memory().unwrap();
    // Derive identifiers from TEST_CHANNEL so the assertions stay in sync with
    // the inputs instead of hard-coding the constant's current value.
    let scope_key = format!("{}:user-1", TEST_CHANNEL);
    let source_session_id = format!("{}:chat-1", TEST_CHANNEL);

    let saved = store
        .put_memory(&MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: scope_key.clone(),
            namespace: "user".to_string(),
            memory_key: "language".to_string(),
            content: "Rust".to_string(),
            source_type: "message".to_string(),
            source_session_id: Some(source_session_id.clone()),
            source_message_id: Some("msg-1".to_string()),
            source_message_seq: Some(7),
            source_channel_name: Some(TEST_CHANNEL.to_string()),
            source_chat_id: Some("chat-1".to_string()),
        })
        .unwrap();

    assert_eq!(saved.content, "Rust");
    assert_eq!(saved.source_type, "message");
    assert_eq!(
        saved.source_session_id.as_deref(),
        Some(source_session_id.as_str())
    );
    assert_eq!(saved.source_message_id.as_deref(), Some("msg-1"));
    assert_eq!(saved.source_message_seq, Some(7));

    // Reading back by (scope kind, scope key, namespace, key) returns the same row.
    let fetched = store
        .get_memory("user", &scope_key, "user", "language")
        .unwrap()
        .unwrap();
    assert_eq!(fetched.id, saved.id);
    assert_eq!(fetched.source_chat_id.as_deref(), Some("chat-1"));
}
|
||
|
||
#[test]
fn test_memory_fts_tracks_upsert_and_delete() {
    let store = SessionStore::in_memory().unwrap();
    // Derived from TEST_CHANNEL so searches target exactly the scope we write to.
    let scope_key = format!("{}:user-1", TEST_CHANNEL);

    // Both writes share (scope, namespace, key); only content and provenance
    // differ, so the second write replaces the first.
    let editor_memory = |content: &str, chat_id: &str, message_id: &str, seq| MemoryUpsert {
        scope_kind: "user".to_string(),
        scope_key: scope_key.clone(),
        namespace: "user".to_string(),
        memory_key: "editor".to_string(),
        content: content.to_string(),
        source_type: "message".to_string(),
        source_session_id: Some(format!("{}:{}", TEST_CHANNEL, chat_id)),
        source_message_id: Some(message_id.to_string()),
        source_message_seq: Some(seq),
        source_channel_name: Some(TEST_CHANNEL.to_string()),
        source_chat_id: Some(chat_id.to_string()),
    };
    let search = |query: &str| {
        store
            .search_memories("user", &scope_key, query, None, 10)
            .unwrap()
    };

    store
        .put_memory(&editor_memory(
            "Prefers rust-analyzer and cargo test output",
            "chat-2",
            "msg-2",
            3,
        ))
        .unwrap();
    let hits = search("rust-analyzer");
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].memory_key, "editor");

    store
        .put_memory(&editor_memory(
            "Prefers clippy diagnostics",
            "chat-3",
            "msg-3",
            4,
        ))
        .unwrap();
    // The full-text index must drop the replaced content and index the new one.
    assert!(search("rust-analyzer").is_empty());
    assert_eq!(search("clippy").len(), 1);

    // Deleting the memory must also remove it from search results.
    let deleted = store
        .delete_memory("user", &scope_key, "user", "editor")
        .unwrap();
    assert!(deleted);
    assert!(search("clippy").is_empty());
}
|
||
|
||
#[test]
fn test_memory_search_matches_memory_key_field() {
    let store = SessionStore::in_memory().unwrap();
    // Build the scope key from TEST_CHANNEL, as the upsert does, rather than
    // hard-coding the constant's current value in the search call.
    let scope_key = format!("{}:user-1", TEST_CHANNEL);

    // The content does not contain the key text, so a hit can only come from
    // the `memory_key` field being searchable.
    store
        .put_memory(&MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: scope_key.clone(),
            namespace: "user".to_string(),
            memory_key: "email_folder_preference".to_string(),
            content: "用户提到邮件时默认查看代收邮箱。".to_string(),
            source_type: "message".to_string(),
            source_session_id: Some(format!("{}:chat-8", TEST_CHANNEL)),
            source_message_id: Some("msg-8".to_string()),
            source_message_seq: Some(8),
            source_channel_name: Some(TEST_CHANNEL.to_string()),
            source_chat_id: Some("chat-8".to_string()),
        })
        .unwrap();

    let hits = store
        .search_memories("user", &scope_key, "email_folder_preference", None, 10)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].memory_key, "email_folder_preference");
}
|
||
|
||
#[test]
fn test_search_memories_any_matches_multiple_keywords_once() {
    let store = SessionStore::in_memory().unwrap();
    // Derived from TEST_CHANNEL so the search targets exactly the written scope.
    let scope_key = format!("{}:user-1", TEST_CHANNEL);

    // Two memories in different namespaces, each matching exactly one keyword.
    store
        .put_memory(&MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: scope_key.clone(),
            namespace: "user".to_string(),
            memory_key: "editor".to_string(),
            content: "Prefers rust-analyzer and cargo test output".to_string(),
            source_type: "message".to_string(),
            source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)),
            source_message_id: Some("msg-2".to_string()),
            source_message_seq: Some(3),
            source_channel_name: Some(TEST_CHANNEL.to_string()),
            source_chat_id: Some("chat-2".to_string()),
        })
        .unwrap();
    store
        .put_memory(&MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: scope_key.clone(),
            namespace: "episodic".to_string(),
            memory_key: "quality".to_string(),
            content: "Tracks clippy warnings before release".to_string(),
            source_type: "message".to_string(),
            source_session_id: Some(format!("{}:chat-3", TEST_CHANNEL)),
            source_message_id: Some("msg-3".to_string()),
            source_message_seq: Some(4),
            source_channel_name: Some(TEST_CHANNEL.to_string()),
            source_chat_id: Some("chat-3".to_string()),
        })
        .unwrap();

    let hits = store
        .search_memories_any(
            "user",
            &scope_key,
            &["rust-analyzer".to_string(), "clippy".to_string()],
            None,
            10,
        )
        .unwrap();

    // Both memories are returned, and each appears exactly once.
    assert_eq!(hits.len(), 2);
    assert!(hits.iter().any(|memory| memory.memory_key == "editor"));
    assert!(hits.iter().any(|memory| memory.memory_key == "quality"));
}
|
||
|
||
#[test]
fn test_memory_scope_listing_and_full_scope_read() {
    let store = SessionStore::in_memory().unwrap();
    // Derive both scope keys from TEST_CHANNEL so the expectations below cannot
    // drift from the inputs if the constant changes.
    let first_user = format!("{}:user-1", TEST_CHANNEL);
    let second_user = format!("{}:user-2", TEST_CHANNEL);

    // user-2 is written first; user-1 gets two memories in different namespaces.
    store
        .put_memory(&MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: second_user.clone(),
            namespace: "user".to_string(),
            memory_key: "style".to_string(),
            content: "偏好简洁表达".to_string(),
            source_type: "message".to_string(),
            source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)),
            source_message_id: Some("msg-2".to_string()),
            source_message_seq: Some(2),
            source_channel_name: Some(TEST_CHANNEL.to_string()),
            source_chat_id: Some("chat-2".to_string()),
        })
        .unwrap();
    store
        .put_memory(&MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: first_user.clone(),
            namespace: "user".to_string(),
            memory_key: "work".to_string(),
            content: "用户在做AI产品".to_string(),
            source_type: "message".to_string(),
            source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
            source_message_id: Some("msg-1".to_string()),
            source_message_seq: Some(1),
            source_channel_name: Some(TEST_CHANNEL.to_string()),
            source_chat_id: Some("chat-1".to_string()),
        })
        .unwrap();
    store
        .put_memory(&MemoryUpsert {
            scope_kind: "user".to_string(),
            scope_key: first_user.clone(),
            namespace: "patterns".to_string(),
            memory_key: "workflow".to_string(),
            content: "习惯先问方案再要代码".to_string(),
            source_type: "message".to_string(),
            source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
            source_message_id: Some("msg-3".to_string()),
            source_message_seq: Some(3),
            source_channel_name: Some(TEST_CHANNEL.to_string()),
            source_chat_id: Some("chat-1".to_string()),
        })
        .unwrap();

    // Each scope key is listed once; user-1 precedes user-2 even though user-2
    // was inserted first.
    let scope_keys = store.list_memory_scope_keys("user").unwrap();
    assert_eq!(scope_keys, vec![first_user.clone(), second_user.clone()]);

    // Reading a whole scope returns every namespace in it, and nothing outside it.
    let full_scope = store.list_memories_for_scope("user", &first_user).unwrap();
    assert_eq!(full_scope.len(), 2);
    assert!(full_scope.iter().all(|memory| memory.scope_key == first_user));
    assert!(full_scope.iter().any(|memory| memory.memory_key == "work"));
    assert!(full_scope.iter().any(|memory| memory.memory_key == "workflow"));
}
|
||
|
||
#[test]
fn test_scheduler_job_roundtrip_and_runtime_update() {
    let store = SessionStore::in_memory().unwrap();

    // An enabled interval job capped at three runs, initially scheduled.
    let job = SchedulerJobUpsert {
        id: "heartbeat".to_string(),
        kind: "outbound_message".to_string(),
        schedule: serde_json::json!({
            "type": "interval",
            "seconds": 300,
            "startup_delay_secs": 10,
        }),
        interval_secs: 300,
        startup_delay_secs: 10,
        target: serde_json::json!({
            "channel": "test-channel",
            "chat_id": "oc_demo",
        }),
        payload: serde_json::json!({ "content": "heartbeat" }),
        enabled: true,
        state: SchedulerJobState::Scheduled,
        last_status: None,
        last_error: None,
        run_count: 0,
        max_runs: Some(3),
        last_fired_at: None,
        next_fire_at: Some(1_700_000_000_000),
        paused_at: None,
        completed_at: None,
    };
    let inserted = store.upsert_scheduler_job(&job).unwrap();

    assert_eq!(inserted.id, "heartbeat");
    assert_eq!(inserted.kind, "outbound_message");
    assert_eq!(inserted.state, SchedulerJobState::Scheduled);
    assert_eq!(inserted.max_runs, Some(3));

    // Move the job to Completed with an Ok status and a run count of 1.
    // NOTE(review): the meaning of the remaining positional arguments is not
    // visible here; they are assumed to be error/timestamp fields.
    store
        .update_scheduler_job_runtime(
            "heartbeat",
            SchedulerJobState::Completed,
            Some(SchedulerJobStatus::Ok),
            None,
            1,
            Some(1_700_000_000_000),
            None,
            None,
            Some(1_700_000_000_100),
        )
        .unwrap();

    let reloaded = store.get_scheduler_job("heartbeat").unwrap().unwrap();
    assert_eq!(reloaded.state, SchedulerJobState::Completed);
    assert_eq!(reloaded.last_status, Some(SchedulerJobStatus::Ok));
    assert_eq!(reloaded.run_count, 1);
    assert_eq!(reloaded.completed_at, Some(1_700_000_000_100));
}
|
||
|
||
}
|