diff --git a/README.md b/README.md index d4c95f4..cfda533 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,23 @@ PicoBot Skills (initial implementation) +Agent profile injection + +PicoBot maintains a persistent agent profile file at ~/.picobot/agent/AGENT.md. + +Behavior: +- The directory ~/.picobot/agent is created automatically when needed. +- If AGENT.md does not exist yet, PicoBot creates it with a default profile. +- When the active conversation is empty or has just been reset, AGENT.md is loaded as the first system message in the active history. +- After every configured number of user turns, PicoBot injects the latest AGENT.md content again before the next user message is appended. + +Config: +- Set gateway.agent_prompt_reinject_every in ~/.picobot/config.json. +- Default value is 100. +- Set it to 0 to disable periodic reinjection. + +This profile is persisted in session history, while the skills index system prompt is still injected dynamically by AgentLoop. + PicoBot now supports filesystem skills. Skill discovery locations: diff --git a/docs/PERSISTENCE.md b/docs/PERSISTENCE.md index 5a15a8b..e615366 100644 --- a/docs/PERSISTENCE.md +++ b/docs/PERSISTENCE.md @@ -48,9 +48,11 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 2. `SessionManager` 定位到对应 channel 的运行时 `Session`。 3. `Session::ensure_persistent_session(chat_id)` 确保数据库里有对应会话记录。 4. `Session::ensure_chat_loaded(chat_id)` 在内存中没有历史时,从 `messages` 表加载该会话全部历史。 -5. 新的用户消息先写入 `messages`,再放入内存历史。 -6. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`。 -7. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。 +5. 如果当前活动段历史为空,系统会从 `~/.picobot/agent/AGENT.md` 读取 Agent 基本设定,并先写入一条 `system` 消息。 +6. 在新的用户消息进入前,系统会检查当前活动段的 `user_turn_count` 是否刚跨过配置项 `gateway.agent_prompt_reinject_every` 指定的下一轮阈值;如果跨过,就再次把 `AGENT.md` 写入一条新的 `system` 消息。 +7. 新的用户消息先写入 `messages`,再放入内存历史。 +8. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`。 +9. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。 ## 3. 会话标识规则 @@ -86,6 +88,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 | `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 | | `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 | | `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 | +| `user_turn_count` | `INTEGER NOT NULL DEFAULT 0` | 当前活动段用户轮次数 | 只在追加 `role = user` 消息时递增,清空历史和 `/reset` 时归零 | +| `agent_prompt_reinjection_count` | `INTEGER NOT NULL DEFAULT 0` | AGENT.md 周期重注入次数 | 每完成一次“达到配置阈值后的下一轮前注入”就递增,清空历史和 `/reset` 时归零 | 索引: @@ -171,6 +175,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 其中第 6 步很重要:归档会话一旦收到新消息,会自动恢复为活跃态。 +另外,只有 `role = user` 的消息会递增 `user_turn_count`;`system`、`assistant`、`tool` 消息不会影响周期注入阈值的判定。 + ### 6.3 读取历史 `load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是: @@ -236,6 +242,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 - 删除该会话在 `messages` 中的所有记录 - 将 `sessions.message_count` 重置为 0 - 将 `sessions.reset_cutoff_seq` 重置为 0 +- 将 `sessions.user_turn_count` 重置为 0 +- 将 `sessions.agent_prompt_reinjection_count` 重置为 0 - 更新 `updated_at` 和 `last_active_at` - 保留会话本身 @@ -247,11 +255,15 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 - 不删除 `messages` 中的任何记录 - 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq` +- 将 `sessions.user_turn_count` 重置为 0 +- 将 `sessions.agent_prompt_reinjection_count` 重置为 0 - 更新 `updated_at` 和 `last_active_at` - 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息 这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。 +由于 AGENT.md 注入消息也会持久化,`/reset` 前的 Agent 设定消息仍会保留在完整历史中,但不会继续出现在新的活动段。下一次活动段首次加载时,系统会重新读取当前版本的 `~/.picobot/agent/AGENT.md`,并把它作为新的首条系统消息写入活动段。 + ### 8.5 删除会话 `delete_session(session_id)`: diff --git a/src/config/mod.rs b/src/config/mod.rs index d4d3dd7..58104b0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -136,6 +136,8 @@ pub struct GatewayConfig { pub port: u16, #[serde(default, rename = "session_ttl_hours")] pub session_ttl_hours: Option, + #[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")] + pub agent_prompt_reinject_every: u64, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -156,12 +158,17 @@ fn default_gateway_url() -> String { "ws://127.0.0.1:19876/ws".to_string() } +fn default_agent_prompt_reinject_every() -> u64 { + 100 +} + impl Default for GatewayConfig { fn default() -> Self { Self { host: default_gateway_host(), port: default_gateway_port(), session_ttl_hours: None, + agent_prompt_reinject_every: default_agent_prompt_reinject_every(), } } } @@ -344,7 +351,8 @@ mod tests { }, "gateway": { "host": "0.0.0.0", - "port": 19876 + "port": 19876, + "agent_prompt_reinject_every": 120 } }"#, ) @@ -387,5 +395,39 @@ mod tests { let config = Config::load(file.path().to_str().unwrap()).unwrap(); assert_eq!(config.gateway.host, "0.0.0.0"); assert_eq!(config.gateway.port, 19876); + assert_eq!(config.gateway.agent_prompt_reinject_every, 120); + } + + #[test] + fn test_gateway_config_defaults_agent_prompt_reinject_every() { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus" + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + } +}"#, + ) + .unwrap(); + + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + assert_eq!(config.gateway.agent_prompt_reinject_every, 100); } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 4f0bba0..14b9ba3 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -29,10 +29,16 @@ impl GatewayState { // Session TTL from config (default 4 hours) let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); + let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every; let skills = Arc::new(SkillRuntime::from_config(config.skills.clone())); - let session_manager = SessionManager::new(session_ttl_hours, provider_config, skills)?; + let session_manager = SessionManager::new( + session_ttl_hours, + agent_prompt_reinject_every, + provider_config, + skills, + )?; let channel_manager = ChannelManager::new(); let bus = channel_manager.bus(); diff --git a/src/gateway/session.rs b/src/gateway/session.rs index ee575a6..684ec1b 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::fs; use std::sync::Arc; use std::time::{Duration, Instant}; use async_trait::async_trait; @@ -16,6 +17,8 @@ use crate::tools::{ WebFetchTool, }; +const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot,一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n"; + /// Session 按 channel 隔离,每个 channel 一个 Session /// History 按 chat_id 隔离,由 Session 统一管理 pub struct Session { @@ -29,6 +32,7 @@ pub struct Session { skills: Arc, compressor: ContextCompressor, store: Arc, + agent_prompt_reinject_every: i64, } pub struct BusToolCallEmitter { @@ -79,6 +83,7 @@ impl Session { tools: Arc, skills: Arc, store: Arc, + agent_prompt_reinject_every: u64, ) -> Result { Ok(Self { id: Uuid::new_v4(), @@ -90,6 +95,7 @@ impl Session { skills, compressor: ContextCompressor::new(provider_config.token_limit), store, + agent_prompt_reinject_every: agent_prompt_reinject_every as i64, }) } @@ -113,6 +119,33 @@ impl Session { .load_messages(&self.persistent_session_id(chat_id)) .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?; self.chat_histories.insert(chat_id.to_string(), history); + self.ensure_initial_agent_prompt(chat_id)?; + Ok(()) + } + + pub fn ensure_agent_prompt_before_user_message(&mut self, chat_id: &str) -> Result<(), AgentError> { + self.ensure_chat_loaded(chat_id)?; + + let session_id = self.persistent_session_id(chat_id); + let session_record = self + .store + .get_session(&session_id) + .map_err(|err| AgentError::Other(format!("get session error: {}", err)))? + .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; + + if self.agent_prompt_reinject_every > 0 + && session_record.user_turn_count > 0 + && session_record.user_turn_count / self.agent_prompt_reinject_every + > session_record.agent_prompt_reinjection_count + { + if let Some(agent_prompt) = load_agent_prompt()? { + self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?; + self.store + .mark_agent_prompt_reinjected(&session_id) + .map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?; + } + } + Ok(()) } @@ -264,6 +297,51 @@ impl Session { }) }) } + + fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> { + let history_is_empty = self + .get_history(chat_id) + .map(|history| history.is_empty()) + .unwrap_or(true); + + if !history_is_empty { + return Ok(()); + } + + if let Some(agent_prompt) = load_agent_prompt()? { + self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?; + } + + Ok(()) + } +} + +fn load_agent_prompt() -> Result, AgentError> { + let path = agent_prompt_path()?; + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?; + } + + if !path.exists() { + fs::write(&path, DEFAULT_AGENT_PROMPT) + .map_err(|err| AgentError::Other(format!("create agent prompt file error: {}", err)))?; + } + + let content = fs::read_to_string(&path) + .map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?; + let trimmed = content.trim(); + if trimmed.is_empty() { + return Ok(None); + } + + Ok(Some(trimmed.to_string())) +} + +fn agent_prompt_path() -> Result { + let home = dirs::home_dir() + .ok_or_else(|| AgentError::Other("home directory not found".to_string()))?; + Ok(home.join(".picobot").join("agent").join("AGENT.md")) } /// SessionManager 管理所有 Session,按 channel_name 路由 @@ -274,6 +352,7 @@ pub struct SessionManager { tools: Arc, skills: Arc, store: Arc, + agent_prompt_reinject_every: u64, } struct SessionManagerInner { @@ -331,6 +410,7 @@ pub(crate) fn handle_in_chat_command( impl SessionManager { pub fn new( session_ttl_hours: u64, + agent_prompt_reinject_every: u64, provider_config: LLMProviderConfig, skills: Arc, ) -> Result { @@ -353,6 +433,7 @@ impl SessionManager { tools: Arc::new(default_tools(skills.clone(), store.clone())), skills, store, + agent_prompt_reinject_every, }) } @@ -447,6 +528,7 @@ impl SessionManager { self.tools.clone(), self.skills.clone(), self.store.clone(), + self.agent_prompt_reinject_every, ) .await?; let arc = Arc::new(Mutex::new(session)); @@ -523,6 +605,8 @@ impl SessionManager { )]); } + session_guard.ensure_agent_prompt_before_user_message(chat_id)?; + // 添加用户消息到历史 let media_refs: Vec = media.iter().map(|m| m.path.clone()).collect(); #[cfg(debug_assertions)] @@ -633,6 +717,7 @@ mod tests { tools, skills, store.clone(), + 100, ) .await .unwrap(); @@ -658,7 +743,113 @@ mod tests { .load_all_messages(&session.persistent_session_id("chat-1")) .unwrap() .len(), - 1, + 2, ); } + + #[tokio::test] + async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let (user_tx, _user_rx) = mpsc::channel(4); + let skills = Arc::new(SkillRuntime::default()); + let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let mut session = Session::new( + "feishu".to_string(), + test_provider_config(), + user_tx, + tools, + skills, + store.clone(), + 100, + ) + .await + .unwrap(); + + session.ensure_persistent_session("chat-1").unwrap(); + session.ensure_chat_loaded("chat-1").unwrap(); + + let history = session.get_history("chat-1").unwrap(); + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, "system"); + assert!(history[0].content.contains("PicoBot 代理配置")); + } + + #[tokio::test] + async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let (user_tx, _user_rx) = mpsc::channel(4); + let skills = Arc::new(SkillRuntime::default()); + let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let mut session = Session::new( + "feishu".to_string(), + test_provider_config(), + user_tx, + tools, + skills, + store.clone(), + 100, + ) + .await + .unwrap(); + + session.ensure_persistent_session("chat-1").unwrap(); + session.ensure_chat_loaded("chat-1").unwrap(); + + for turn in 0..100 { + session + .append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}"))) + .unwrap(); + } + + session.ensure_agent_prompt_before_user_message("chat-1").unwrap(); + + let history = session.get_history("chat-1").unwrap(); + let system_messages = history.iter().filter(|message| message.role == "system").count(); + assert_eq!(system_messages, 2); + + let stored = store + .get_session(&session.persistent_session_id("chat-1")) + .unwrap() + .unwrap(); + assert_eq!(stored.agent_prompt_reinjection_count, 1); + + session.ensure_agent_prompt_before_user_message("chat-1").unwrap(); + let history = session.get_history("chat-1").unwrap(); + let system_messages = history.iter().filter(|message| message.role == "system").count(); + assert_eq!(system_messages, 2); + } + + #[tokio::test] + async fn test_agent_prompt_reinjection_can_be_disabled_by_config() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let (user_tx, _user_rx) = mpsc::channel(4); + let skills = Arc::new(SkillRuntime::default()); + let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let mut session = Session::new( + "feishu".to_string(), + test_provider_config(), + user_tx, + tools, + skills, + store.clone(), + 0, + ) + .await + .unwrap(); + + session.ensure_persistent_session("chat-1").unwrap(); + session.ensure_chat_loaded("chat-1").unwrap(); + + for turn in 0..100 { + session + .append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}"))) + .unwrap(); + } + + session.ensure_agent_prompt_before_user_message("chat-1").unwrap(); + + let history = session.get_history("chat-1").unwrap(); + let system_messages = history.iter().filter(|message| message.role == "system").count(); + assert_eq!(system_messages, 1); + } } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index b77cd2c..88dd308 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -59,6 +59,7 @@ async fn handle_socket(ws: WebSocket, state: Arc) { state.session_manager.tools(), state.session_manager.skills(), state.session_manager.store(), + state.config.gateway.agent_prompt_reinject_every, ) .await { @@ -210,6 +211,8 @@ async fn handle_inbound( return Ok(()); } + session_guard.ensure_agent_prompt_before_user_message(&chat_id)?; + let user_message = session_guard.create_user_message(&content, Vec::new()); let user_message_id = user_message.id.clone(); session_guard.append_persisted_message(&chat_id, user_message)?; diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 32b9802..38979fa 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -40,6 +40,8 @@ pub struct SessionRecord { pub deleted_at: Option, pub message_count: i64, pub reset_cutoff_seq: i64, + pub user_turn_count: i64, + pub agent_prompt_reinjection_count: i64, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -113,7 +115,9 @@ impl SessionStore { archived_at INTEGER, deleted_at INTEGER, message_count INTEGER NOT NULL DEFAULT 0, - reset_cutoff_seq INTEGER NOT NULL DEFAULT 0 + reset_cutoff_seq INTEGER NOT NULL DEFAULT 0, + user_turn_count INTEGER NOT NULL DEFAULT 0, + agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0 ); CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived @@ -234,8 +238,9 @@ impl SessionStore { " INSERT INTO sessions ( id, title, channel_name, chat_id, summary, - created_at, updated_at, last_active_at, archived_at, deleted_at, message_count - ) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0) + created_at, updated_at, last_active_at, archived_at, deleted_at, message_count, + reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count + ) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0, 0, 0, 0) ", params![id, title, id, now], )?; @@ -261,8 +266,9 @@ impl SessionStore { " INSERT INTO sessions ( id, title, channel_name, chat_id, summary, - created_at, updated_at, last_active_at, archived_at, deleted_at, message_count - ) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0) + created_at, updated_at, last_active_at, archived_at, deleted_at, message_count, + reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count + ) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0) ", params![session_id, title, channel_name, chat_id, now], )?; @@ -277,7 +283,8 @@ impl SessionStore { " SELECT id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, - archived_at, deleted_at, message_count, reset_cutoff_seq + archived_at, deleted_at, message_count, reset_cutoff_seq, + user_turn_count, agent_prompt_reinjection_count FROM sessions WHERE id = ?1 AND deleted_at IS NULL ", @@ -298,7 +305,8 @@ impl SessionStore { " SELECT id, title, channel_name, chat_id, summary, created_at, updated_at, last_active_at, - archived_at, deleted_at, message_count, reset_cutoff_seq + archived_at, deleted_at, message_count, reset_cutoff_seq, + user_turn_count, agent_prompt_reinjection_count FROM sessions WHERE channel_name = ?1 AND deleted_at IS NULL @@ -354,7 +362,12 @@ impl SessionStore { conn.execute( " UPDATE sessions - SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0 + SET message_count = 0, + updated_at = ?2, + last_active_at = ?2, + reset_cutoff_seq = 0, + user_turn_count = 0, + agent_prompt_reinjection_count = 0 WHERE id = ?1 AND deleted_at IS NULL ", params![session_id, now], @@ -379,7 +392,9 @@ impl SessionStore { SET reset_cutoff_seq = ?2, updated_at = ?3, last_active_at = ?3, - archived_at = NULL + archived_at = NULL, + user_turn_count = 0, + agent_prompt_reinjection_count = 0 WHERE id = ?1 AND deleted_at IS NULL ", params![session_id, cutoff_seq, now], @@ -423,10 +438,31 @@ impl SessionStore { )?; let now = current_timestamp(); + let is_user_message = message.role == "user"; tx.execute( " UPDATE sessions SET message_count = message_count + 1, + user_turn_count = user_turn_count + ?3, + updated_at = ?2, + last_active_at = ?2, + archived_at = NULL + WHERE id = ?1 AND deleted_at IS NULL + ", + params![session_id, now, if is_user_message { 1 } else { 0 }], + )?; + + tx.commit()?; + Ok(()) + } + + pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> { + let now = current_timestamp(); + let conn = self.conn.lock().expect("session db mutex poisoned"); + conn.execute( + " + UPDATE sessions + SET agent_prompt_reinjection_count = agent_prompt_reinjection_count + 1, updated_at = ?2, last_active_at = ?2, archived_at = NULL @@ -434,8 +470,6 @@ impl SessionStore { ", params![session_id, now], )?; - - tx.commit()?; Ok(()) } @@ -779,6 +813,8 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result Result<(), StorageError> { )?; } + if !has_column(conn, "sessions", "user_turn_count")? { + conn.execute( + "ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0", + [], + )?; + } + + if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? { + conn.execute( + "ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0", + [], + )?; + } + Ok(()) } @@ -947,6 +997,8 @@ mod tests { assert_eq!(session.chat_id, session.id); assert_eq!(session.message_count, 0); assert_eq!(session.reset_cutoff_seq, 0); + assert_eq!(session.user_turn_count, 0); + assert_eq!(session.agent_prompt_reinjection_count, 0); let first = ChatMessage::user("hello"); let second = ChatMessage::assistant("world"); @@ -957,6 +1009,8 @@ mod tests { assert_eq!(stored.message_count, 2); assert!(stored.archived_at.is_none()); assert_eq!(stored.reset_cutoff_seq, 0); + assert_eq!(stored.user_turn_count, 1); + assert_eq!(stored.agent_prompt_reinjection_count, 0); let messages = store.load_messages(&session.id).unwrap(); assert_eq!(messages.len(), 2); @@ -984,6 +1038,8 @@ mod tests { assert!(cleared.is_empty()); let cleared_session = store.get_session(&session.id).unwrap().unwrap(); assert_eq!(cleared_session.message_count, 0); + assert_eq!(cleared_session.user_turn_count, 0); + assert_eq!(cleared_session.agent_prompt_reinjection_count, 0); store.delete_session(&session.id).unwrap(); assert!(store.get_session(&session.id).unwrap().is_none()); @@ -1036,6 +1092,8 @@ mod tests { let stored = store.get_session(&session.id).unwrap().unwrap(); assert_eq!(stored.reset_cutoff_seq, 2); + assert_eq!(stored.user_turn_count, 0); + assert_eq!(stored.agent_prompt_reinjection_count, 0); let active_messages = store.load_messages(&session.id).unwrap(); assert!(active_messages.is_empty()); @@ -1049,6 +1107,9 @@ mod tests { let active_messages = store.load_messages(&session.id).unwrap(); assert_eq!(active_messages.len(), 1); assert_eq!(active_messages[0].content, "after"); + + let stored = store.get_session(&session.id).unwrap().unwrap(); + assert_eq!(stored.user_turn_count, 1); } #[test] @@ -1091,6 +1152,20 @@ mod tests { let store = SessionStore::from_connection(conn).unwrap(); let session = store.create_cli_session(Some("migrated")).unwrap(); assert_eq!(session.reset_cutoff_seq, 0); + assert_eq!(session.user_turn_count, 0); + assert_eq!(session.agent_prompt_reinjection_count, 0); + } + + #[test] + fn test_mark_agent_prompt_reinjected_increments_counter() { + let store = SessionStore::in_memory().unwrap(); + let session = store.create_cli_session(Some("prompt")).unwrap(); + + store.mark_agent_prompt_reinjected(&session.id).unwrap(); + store.mark_agent_prompt_reinjected(&session.id).unwrap(); + + let stored = store.get_session(&session.id).unwrap().unwrap(); + assert_eq!(stored.agent_prompt_reinjection_count, 2); } #[test]