feat(agent_profile): 实现代理配置文件的注入与周期性重注入机制

This commit is contained in:
ooodc 2026-04-22 09:45:19 +08:00
parent 4725b5406e
commit 0dfa615ca9
7 changed files with 363 additions and 17 deletions

View File

@ -2,6 +2,23 @@ PicoBot
Skills (initial implementation) Skills (initial implementation)
Agent profile injection
PicoBot maintains a persistent agent profile file at ~/.picobot/agent/AGENT.md.
Behavior:
- The directory ~/.picobot/agent is created automatically when needed.
- If AGENT.md does not exist yet, PicoBot creates it with a default profile.
- When the active conversation is empty or has just been reset, AGENT.md is loaded as the first system message in the active history.
- After every configured number of user turns, PicoBot injects the latest AGENT.md content again before the next user message is appended.
Config:
- Set gateway.agent_prompt_reinject_every in ~/.picobot/config.json.
- Default value is 100.
- Set it to 0 to disable periodic reinjection.
This profile is persisted in session history, while the skills index system prompt is still injected dynamically by AgentLoop.
PicoBot now supports filesystem skills. PicoBot now supports filesystem skills.
Skill discovery locations: Skill discovery locations:

View File

@ -48,9 +48,11 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
2. `SessionManager` 定位到对应 channel 的运行时 `Session` 2. `SessionManager` 定位到对应 channel 的运行时 `Session`
3. `Session::ensure_persistent_session(chat_id)` 确保数据库里有对应会话记录。 3. `Session::ensure_persistent_session(chat_id)` 确保数据库里有对应会话记录。
4. `Session::ensure_chat_loaded(chat_id)` 在内存中没有历史时,从 `messages` 表加载该会话全部历史。 4. `Session::ensure_chat_loaded(chat_id)` 在内存中没有历史时,从 `messages` 表加载该会话全部历史。
5. 新的用户消息先写入 `messages`,再放入内存历史。 5. 如果当前活动段历史为空,系统会从 `~/.picobot/agent/AGENT.md` 读取 Agent 基本设定,并先写入一条 `system` 消息。
6. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages` 6. 在新的用户消息进入前,系统会检查当前活动段的 `user_turn_count` 是否刚跨过配置项 `gateway.agent_prompt_reinject_every` 指定的下一轮阈值;如果跨过,就再次把 `AGENT.md` 写入一条新的 `system` 消息。
7. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。 7. 新的用户消息先写入 `messages`,再放入内存历史。
8. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`
9. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。
## 3. 会话标识规则 ## 3. 会话标识规则
@ -86,6 +88,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
| `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 | | `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 |
| `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 | | `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 |
| `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 | | `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 |
| `user_turn_count` | `INTEGER NOT NULL DEFAULT 0` | 当前活动段用户轮次数 | 只在追加 `role = user` 消息时递增,清空历史和 `/reset` 时归零 |
| `agent_prompt_reinjection_count` | `INTEGER NOT NULL DEFAULT 0` | AGENT.md 周期重注入次数 | 每完成一次“达到配置阈值后的下一轮前注入”就递增,清空历史和 `/reset` 时归零 |
索引: 索引:
@ -171,6 +175,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
其中第 6 步很重要:归档会话一旦收到新消息,会自动恢复为活跃态。 其中第 6 步很重要:归档会话一旦收到新消息,会自动恢复为活跃态。
另外,只有 `role = user` 的消息会递增 `user_turn_count``system``assistant``tool` 消息不会影响周期注入阈值的判定。
### 6.3 读取历史 ### 6.3 读取历史
`load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是: `load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是:
@ -236,6 +242,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
- 删除该会话在 `messages` 中的所有记录 - 删除该会话在 `messages` 中的所有记录
- 将 `sessions.message_count` 重置为 0 - 将 `sessions.message_count` 重置为 0
- 将 `sessions.reset_cutoff_seq` 重置为 0 - 将 `sessions.reset_cutoff_seq` 重置为 0
- 将 `sessions.user_turn_count` 重置为 0
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
- 更新 `updated_at``last_active_at` - 更新 `updated_at``last_active_at`
- 保留会话本身 - 保留会话本身
@ -247,11 +255,15 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
- 不删除 `messages` 中的任何记录 - 不删除 `messages` 中的任何记录
- 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq` - 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq`
- 将 `sessions.user_turn_count` 重置为 0
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
- 更新 `updated_at``last_active_at` - 更新 `updated_at``last_active_at`
- 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息 - 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息
这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。 这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。
由于 AGENT.md 注入消息也会持久化,`/reset` 前的 Agent 设定消息仍会保留在完整历史中,但不会继续出现在新的活动段。下一次活动段首次加载时,系统会重新读取当前版本的 `~/.picobot/agent/AGENT.md`,并把它作为新的首条系统消息写入活动段。
### 8.5 删除会话 ### 8.5 删除会话
`delete_session(session_id)` `delete_session(session_id)`

View File

@ -136,6 +136,8 @@ pub struct GatewayConfig {
pub port: u16, pub port: u16,
#[serde(default, rename = "session_ttl_hours")] #[serde(default, rename = "session_ttl_hours")]
pub session_ttl_hours: Option<u64>, pub session_ttl_hours: Option<u64>,
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
pub agent_prompt_reinject_every: u64,
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
@ -156,12 +158,17 @@ fn default_gateway_url() -> String {
"ws://127.0.0.1:19876/ws".to_string() "ws://127.0.0.1:19876/ws".to_string()
} }
fn default_agent_prompt_reinject_every() -> u64 {
100
}
impl Default for GatewayConfig { impl Default for GatewayConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
host: default_gateway_host(), host: default_gateway_host(),
port: default_gateway_port(), port: default_gateway_port(),
session_ttl_hours: None, session_ttl_hours: None,
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
} }
} }
} }
@ -344,7 +351,8 @@ mod tests {
}, },
"gateway": { "gateway": {
"host": "0.0.0.0", "host": "0.0.0.0",
"port": 19876 "port": 19876,
"agent_prompt_reinject_every": 120
} }
}"#, }"#,
) )
@ -387,5 +395,39 @@ mod tests {
let config = Config::load(file.path().to_str().unwrap()).unwrap(); let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.gateway.host, "0.0.0.0"); assert_eq!(config.gateway.host, "0.0.0.0");
assert_eq!(config.gateway.port, 19876); assert_eq!(config.gateway.port, 19876);
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
}
#[test]
fn test_gateway_config_defaults_agent_prompt_reinject_every() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
} }
} }

View File

@ -29,10 +29,16 @@ impl GatewayState {
// Session TTL from config (default 4 hours) // Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone())); let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
let session_manager = SessionManager::new(session_ttl_hours, provider_config, skills)?; let session_manager = SessionManager::new(
session_ttl_hours,
agent_prompt_reinject_every,
provider_config,
skills,
)?;
let channel_manager = ChannelManager::new(); let channel_manager = ChannelManager::new();
let bus = channel_manager.bus(); let bus = channel_manager.bus();

View File

@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fs;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use async_trait::async_trait; use async_trait::async_trait;
@ -16,6 +17,8 @@ use crate::tools::{
WebFetchTool, WebFetchTool,
}; };
const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n";
/// Session 按 channel 隔离,每个 channel 一个 Session /// Session 按 channel 隔离,每个 channel 一个 Session
/// History 按 chat_id 隔离,由 Session 统一管理 /// History 按 chat_id 隔离,由 Session 统一管理
pub struct Session { pub struct Session {
@ -29,6 +32,7 @@ pub struct Session {
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
compressor: ContextCompressor, compressor: ContextCompressor,
store: Arc<SessionStore>, store: Arc<SessionStore>,
agent_prompt_reinject_every: i64,
} }
pub struct BusToolCallEmitter { pub struct BusToolCallEmitter {
@ -79,6 +83,7 @@ impl Session {
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
store: Arc<SessionStore>, store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
Ok(Self { Ok(Self {
id: Uuid::new_v4(), id: Uuid::new_v4(),
@ -90,6 +95,7 @@ impl Session {
skills, skills,
compressor: ContextCompressor::new(provider_config.token_limit), compressor: ContextCompressor::new(provider_config.token_limit),
store, store,
agent_prompt_reinject_every: agent_prompt_reinject_every as i64,
}) })
} }
@ -113,6 +119,33 @@ impl Session {
.load_messages(&self.persistent_session_id(chat_id)) .load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?; .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history); self.chat_histories.insert(chat_id.to_string(), history);
self.ensure_initial_agent_prompt(chat_id)?;
Ok(())
}
pub fn ensure_agent_prompt_before_user_message(&mut self, chat_id: &str) -> Result<(), AgentError> {
self.ensure_chat_loaded(chat_id)?;
let session_id = self.persistent_session_id(chat_id);
let session_record = self
.store
.get_session(&session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
if self.agent_prompt_reinject_every > 0
&& session_record.user_turn_count > 0
&& session_record.user_turn_count / self.agent_prompt_reinject_every
> session_record.agent_prompt_reinjection_count
{
if let Some(agent_prompt) = load_agent_prompt()? {
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
self.store
.mark_agent_prompt_reinjected(&session_id)
.map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?;
}
}
Ok(()) Ok(())
} }
@ -264,6 +297,51 @@ impl Session {
}) })
}) })
} }
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history_is_empty = self
.get_history(chat_id)
.map(|history| history.is_empty())
.unwrap_or(true);
if !history_is_empty {
return Ok(());
}
if let Some(agent_prompt) = load_agent_prompt()? {
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
}
Ok(())
}
}
fn load_agent_prompt() -> Result<Option<String>, AgentError> {
let path = agent_prompt_path()?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
if !path.exists() {
fs::write(&path, DEFAULT_AGENT_PROMPT)
.map_err(|err| AgentError::Other(format!("create agent prompt file error: {}", err)))?;
}
let content = fs::read_to_string(&path)
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
let trimmed = content.trim();
if trimmed.is_empty() {
return Ok(None);
}
Ok(Some(trimmed.to_string()))
}
fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
let home = dirs::home_dir()
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
} }
/// SessionManager 管理所有 Session按 channel_name 路由 /// SessionManager 管理所有 Session按 channel_name 路由
@ -274,6 +352,7 @@ pub struct SessionManager {
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
store: Arc<SessionStore>, store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
} }
struct SessionManagerInner { struct SessionManagerInner {
@ -331,6 +410,7 @@ pub(crate) fn handle_in_chat_command(
impl SessionManager { impl SessionManager {
pub fn new( pub fn new(
session_ttl_hours: u64, session_ttl_hours: u64,
agent_prompt_reinject_every: u64,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>, skills: Arc<SkillRuntime>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
@ -353,6 +433,7 @@ impl SessionManager {
tools: Arc::new(default_tools(skills.clone(), store.clone())), tools: Arc::new(default_tools(skills.clone(), store.clone())),
skills, skills,
store, store,
agent_prompt_reinject_every,
}) })
} }
@ -447,6 +528,7 @@ impl SessionManager {
self.tools.clone(), self.tools.clone(),
self.skills.clone(), self.skills.clone(),
self.store.clone(), self.store.clone(),
self.agent_prompt_reinject_every,
) )
.await?; .await?;
let arc = Arc::new(Mutex::new(session)); let arc = Arc::new(Mutex::new(session));
@ -523,6 +605,8 @@ impl SessionManager {
)]); )]);
} }
session_guard.ensure_agent_prompt_before_user_message(chat_id)?;
// 添加用户消息到历史 // 添加用户消息到历史
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect(); let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -633,6 +717,7 @@ mod tests {
tools, tools,
skills, skills,
store.clone(), store.clone(),
100,
) )
.await .await
.unwrap(); .unwrap();
@ -658,7 +743,113 @@ mod tests {
.load_all_messages(&session.persistent_session_id("chat-1")) .load_all_messages(&session.persistent_session_id("chat-1"))
.unwrap() .unwrap()
.len(), .len(),
1, 2,
); );
} }
#[tokio::test]
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
assert!(history[0].content.contains("PicoBot 代理配置"));
}
#[tokio::test]
async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
for turn in 0..100 {
session
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
.unwrap();
}
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 2);
let stored = store
.get_session(&session.persistent_session_id("chat-1"))
.unwrap()
.unwrap();
assert_eq!(stored.agent_prompt_reinjection_count, 1);
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 2);
}
#[tokio::test]
async fn test_agent_prompt_reinjection_can_be_disabled_by_config() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
0,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
for turn in 0..100 {
session
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
.unwrap();
}
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 1);
}
} }

View File

@ -59,6 +59,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
state.session_manager.tools(), state.session_manager.tools(),
state.session_manager.skills(), state.session_manager.skills(),
state.session_manager.store(), state.session_manager.store(),
state.config.gateway.agent_prompt_reinject_every,
) )
.await .await
{ {
@ -210,6 +211,8 @@ async fn handle_inbound(
return Ok(()); return Ok(());
} }
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new()); let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone(); let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?; session_guard.append_persisted_message(&chat_id, user_message)?;

View File

@ -40,6 +40,8 @@ pub struct SessionRecord {
pub deleted_at: Option<i64>, pub deleted_at: Option<i64>,
pub message_count: i64, pub message_count: i64,
pub reset_cutoff_seq: i64, pub reset_cutoff_seq: i64,
pub user_turn_count: i64,
pub agent_prompt_reinjection_count: i64,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -113,7 +115,9 @@ impl SessionStore {
archived_at INTEGER, archived_at INTEGER,
deleted_at INTEGER, deleted_at INTEGER,
message_count INTEGER NOT NULL DEFAULT 0, message_count INTEGER NOT NULL DEFAULT 0,
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0 reset_cutoff_seq INTEGER NOT NULL DEFAULT 0,
user_turn_count INTEGER NOT NULL DEFAULT 0,
agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0
); );
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
@ -234,8 +238,9 @@ impl SessionStore {
" "
INSERT INTO sessions ( INSERT INTO sessions (
id, title, channel_name, chat_id, summary, id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0) reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0, 0, 0, 0)
", ",
params![id, title, id, now], params![id, title, id, now],
)?; )?;
@ -261,8 +266,9 @@ impl SessionStore {
" "
INSERT INTO sessions ( INSERT INTO sessions (
id, title, channel_name, chat_id, summary, id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0) reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0)
", ",
params![session_id, title, channel_name, chat_id, now], params![session_id, title, channel_name, chat_id, now],
)?; )?;
@ -277,7 +283,8 @@ impl SessionStore {
" "
SELECT id, title, channel_name, chat_id, summary, SELECT id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, created_at, updated_at, last_active_at,
archived_at, deleted_at, message_count, reset_cutoff_seq archived_at, deleted_at, message_count, reset_cutoff_seq,
user_turn_count, agent_prompt_reinjection_count
FROM sessions FROM sessions
WHERE id = ?1 AND deleted_at IS NULL WHERE id = ?1 AND deleted_at IS NULL
", ",
@ -298,7 +305,8 @@ impl SessionStore {
" "
SELECT id, title, channel_name, chat_id, summary, SELECT id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, created_at, updated_at, last_active_at,
archived_at, deleted_at, message_count, reset_cutoff_seq archived_at, deleted_at, message_count, reset_cutoff_seq,
user_turn_count, agent_prompt_reinjection_count
FROM sessions FROM sessions
WHERE channel_name = ?1 WHERE channel_name = ?1
AND deleted_at IS NULL AND deleted_at IS NULL
@ -354,7 +362,12 @@ impl SessionStore {
conn.execute( conn.execute(
" "
UPDATE sessions UPDATE sessions
SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0 SET message_count = 0,
updated_at = ?2,
last_active_at = ?2,
reset_cutoff_seq = 0,
user_turn_count = 0,
agent_prompt_reinjection_count = 0
WHERE id = ?1 AND deleted_at IS NULL WHERE id = ?1 AND deleted_at IS NULL
", ",
params![session_id, now], params![session_id, now],
@ -379,7 +392,9 @@ impl SessionStore {
SET reset_cutoff_seq = ?2, SET reset_cutoff_seq = ?2,
updated_at = ?3, updated_at = ?3,
last_active_at = ?3, last_active_at = ?3,
archived_at = NULL archived_at = NULL,
user_turn_count = 0,
agent_prompt_reinjection_count = 0
WHERE id = ?1 AND deleted_at IS NULL WHERE id = ?1 AND deleted_at IS NULL
", ",
params![session_id, cutoff_seq, now], params![session_id, cutoff_seq, now],
@ -423,10 +438,31 @@ impl SessionStore {
)?; )?;
let now = current_timestamp(); let now = current_timestamp();
let is_user_message = message.role == "user";
tx.execute( tx.execute(
" "
UPDATE sessions UPDATE sessions
SET message_count = message_count + 1, SET message_count = message_count + 1,
user_turn_count = user_turn_count + ?3,
updated_at = ?2,
last_active_at = ?2,
archived_at = NULL
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, now, if is_user_message { 1 } else { 0 }],
)?;
tx.commit()?;
Ok(())
}
pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
"
UPDATE sessions
SET agent_prompt_reinjection_count = agent_prompt_reinjection_count + 1,
updated_at = ?2, updated_at = ?2,
last_active_at = ?2, last_active_at = ?2,
archived_at = NULL archived_at = NULL
@ -434,8 +470,6 @@ impl SessionStore {
", ",
params![session_id, now], params![session_id, now],
)?; )?;
tx.commit()?;
Ok(()) Ok(())
} }
@ -779,6 +813,8 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord
deleted_at: row.get(9)?, deleted_at: row.get(9)?,
message_count: row.get(10)?, message_count: row.get(10)?,
reset_cutoff_seq: row.get(11)?, reset_cutoff_seq: row.get(11)?,
user_turn_count: row.get(12)?,
agent_prompt_reinjection_count: row.get(13)?,
}) })
} }
@ -829,6 +865,20 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
)?; )?;
} }
if !has_column(conn, "sessions", "user_turn_count")? {
conn.execute(
"ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0",
[],
)?;
}
if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? {
conn.execute(
"ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0",
[],
)?;
}
Ok(()) Ok(())
} }
@ -947,6 +997,8 @@ mod tests {
assert_eq!(session.chat_id, session.id); assert_eq!(session.chat_id, session.id);
assert_eq!(session.message_count, 0); assert_eq!(session.message_count, 0);
assert_eq!(session.reset_cutoff_seq, 0); assert_eq!(session.reset_cutoff_seq, 0);
assert_eq!(session.user_turn_count, 0);
assert_eq!(session.agent_prompt_reinjection_count, 0);
let first = ChatMessage::user("hello"); let first = ChatMessage::user("hello");
let second = ChatMessage::assistant("world"); let second = ChatMessage::assistant("world");
@ -957,6 +1009,8 @@ mod tests {
assert_eq!(stored.message_count, 2); assert_eq!(stored.message_count, 2);
assert!(stored.archived_at.is_none()); assert!(stored.archived_at.is_none());
assert_eq!(stored.reset_cutoff_seq, 0); assert_eq!(stored.reset_cutoff_seq, 0);
assert_eq!(stored.user_turn_count, 1);
assert_eq!(stored.agent_prompt_reinjection_count, 0);
let messages = store.load_messages(&session.id).unwrap(); let messages = store.load_messages(&session.id).unwrap();
assert_eq!(messages.len(), 2); assert_eq!(messages.len(), 2);
@ -984,6 +1038,8 @@ mod tests {
assert!(cleared.is_empty()); assert!(cleared.is_empty());
let cleared_session = store.get_session(&session.id).unwrap().unwrap(); let cleared_session = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(cleared_session.message_count, 0); assert_eq!(cleared_session.message_count, 0);
assert_eq!(cleared_session.user_turn_count, 0);
assert_eq!(cleared_session.agent_prompt_reinjection_count, 0);
store.delete_session(&session.id).unwrap(); store.delete_session(&session.id).unwrap();
assert!(store.get_session(&session.id).unwrap().is_none()); assert!(store.get_session(&session.id).unwrap().is_none());
@ -1036,6 +1092,8 @@ mod tests {
let stored = store.get_session(&session.id).unwrap().unwrap(); let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.reset_cutoff_seq, 2); assert_eq!(stored.reset_cutoff_seq, 2);
assert_eq!(stored.user_turn_count, 0);
assert_eq!(stored.agent_prompt_reinjection_count, 0);
let active_messages = store.load_messages(&session.id).unwrap(); let active_messages = store.load_messages(&session.id).unwrap();
assert!(active_messages.is_empty()); assert!(active_messages.is_empty());
@ -1049,6 +1107,9 @@ mod tests {
let active_messages = store.load_messages(&session.id).unwrap(); let active_messages = store.load_messages(&session.id).unwrap();
assert_eq!(active_messages.len(), 1); assert_eq!(active_messages.len(), 1);
assert_eq!(active_messages[0].content, "after"); assert_eq!(active_messages[0].content, "after");
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.user_turn_count, 1);
} }
#[test] #[test]
@ -1091,6 +1152,20 @@ mod tests {
let store = SessionStore::from_connection(conn).unwrap(); let store = SessionStore::from_connection(conn).unwrap();
let session = store.create_cli_session(Some("migrated")).unwrap(); let session = store.create_cli_session(Some("migrated")).unwrap();
assert_eq!(session.reset_cutoff_seq, 0); assert_eq!(session.reset_cutoff_seq, 0);
assert_eq!(session.user_turn_count, 0);
assert_eq!(session.agent_prompt_reinjection_count, 0);
}
#[test]
fn test_mark_agent_prompt_reinjected_increments_counter() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("prompt")).unwrap();
store.mark_agent_prompt_reinjected(&session.id).unwrap();
store.mark_agent_prompt_reinjected(&session.id).unwrap();
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.agent_prompt_reinjection_count, 2);
} }
#[test] #[test]