feat(agent_profile): 实现代理配置文件的注入与周期性重注入机制

This commit is contained in:
ooodc 2026-04-22 09:45:19 +08:00
parent 4725b5406e
commit 0dfa615ca9
7 changed files with 363 additions and 17 deletions

View File

@ -2,6 +2,23 @@ PicoBot
Skills (initial implementation)
Agent profile injection
PicoBot maintains a persistent agent profile file at ~/.picobot/agent/AGENT.md.
Behavior:
- The directory ~/.picobot/agent is created automatically when needed.
- If AGENT.md does not exist yet, PicoBot creates it with a default profile.
- When the active conversation is empty or has just been reset, AGENT.md is loaded as the first system message in the active history.
- After every configured number of user turns, PicoBot injects the latest AGENT.md content again before the next user message is appended.
Config:
- Set gateway.agent_prompt_reinject_every in ~/.picobot/config.json.
- Default value is 100.
- Set it to 0 to disable periodic reinjection.
This profile is persisted in session history, while the skills index system prompt is still injected dynamically by AgentLoop.
PicoBot now supports filesystem skills.
Skill discovery locations:

View File

@ -48,9 +48,11 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
2. `SessionManager` 定位到对应 channel 的运行时 `Session`
3. `Session::ensure_persistent_session(chat_id)` 确保数据库里有对应会话记录。
4. `Session::ensure_chat_loaded(chat_id)` 在内存中没有历史时,从 `messages` 表加载该会话全部历史。
5. 新的用户消息先写入 `messages`,再放入内存历史。
6. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`
7. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。
5. 如果当前活动段历史为空,系统会从 `~/.picobot/agent/AGENT.md` 读取 Agent 基本设定,并先写入一条 `system` 消息。
6. 在新的用户消息进入前,系统会检查当前活动段的 `user_turn_count` 是否刚跨过配置项 `gateway.agent_prompt_reinject_every` 指定的下一轮阈值;如果跨过,就再次把 `AGENT.md` 写入一条新的 `system` 消息。
7. 新的用户消息先写入 `messages`,再放入内存历史。
8. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`
9. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。
## 3. 会话标识规则
@ -86,6 +88,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
| `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 |
| `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 |
| `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 |
| `user_turn_count` | `INTEGER NOT NULL DEFAULT 0` | 当前活动段用户轮次数 | 只在追加 `role = user` 消息时递增,清空历史和 `/reset` 时归零 |
| `agent_prompt_reinjection_count` | `INTEGER NOT NULL DEFAULT 0` | AGENT.md 周期重注入次数 | 每完成一次“达到配置阈值后的下一轮前注入”就递增,清空历史和 `/reset` 时归零 |
索引:
@ -171,6 +175,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
其中第 6 步很重要:归档会话一旦收到新消息,会自动恢复为活跃态。
另外,只有 `role = user` 的消息会递增 `user_turn_count``system``assistant``tool` 消息不会影响周期注入阈值的判定。
### 6.3 读取历史
`load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是:
@ -236,6 +242,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
- 删除该会话在 `messages` 中的所有记录
- 将 `sessions.message_count` 重置为 0
- 将 `sessions.reset_cutoff_seq` 重置为 0
- 将 `sessions.user_turn_count` 重置为 0
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
- 更新 `updated_at``last_active_at`
- 保留会话本身
@ -247,11 +255,15 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
- 不删除 `messages` 中的任何记录
- 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq`
- 将 `sessions.user_turn_count` 重置为 0
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
- 更新 `updated_at``last_active_at`
- 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息
这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。
由于 AGENT.md 注入消息也会持久化,`/reset` 前的 Agent 设定消息仍会保留在完整历史中,但不会继续出现在新的活动段。下一次活动段首次加载时,系统会重新读取当前版本的 `~/.picobot/agent/AGENT.md`,并把它作为新的首条系统消息写入活动段。
### 8.5 删除会话
`delete_session(session_id)`

View File

@ -136,6 +136,8 @@ pub struct GatewayConfig {
pub port: u16,
#[serde(default, rename = "session_ttl_hours")]
pub session_ttl_hours: Option<u64>,
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
pub agent_prompt_reinject_every: u64,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -156,12 +158,17 @@ fn default_gateway_url() -> String {
"ws://127.0.0.1:19876/ws".to_string()
}
fn default_agent_prompt_reinject_every() -> u64 {
100
}
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
}
}
}
@ -344,7 +351,8 @@ mod tests {
},
"gateway": {
"host": "0.0.0.0",
"port": 19876
"port": 19876,
"agent_prompt_reinject_every": 120
}
}"#,
)
@ -387,5 +395,39 @@ mod tests {
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.gateway.host, "0.0.0.0");
assert_eq!(config.gateway.port, 19876);
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
}
#[test]
fn test_gateway_config_defaults_agent_prompt_reinject_every() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
}
}

View File

@ -29,10 +29,16 @@ impl GatewayState {
// Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
let session_manager = SessionManager::new(session_ttl_hours, provider_config, skills)?;
let session_manager = SessionManager::new(
session_ttl_hours,
agent_prompt_reinject_every,
provider_config,
skills,
)?;
let channel_manager = ChannelManager::new();
let bus = channel_manager.bus();

View File

@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::fs;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
@ -16,6 +17,8 @@ use crate::tools::{
WebFetchTool,
};
const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n";
/// Session 按 channel 隔离,每个 channel 一个 Session
/// History 按 chat_id 隔离,由 Session 统一管理
pub struct Session {
@ -29,6 +32,7 @@ pub struct Session {
skills: Arc<SkillRuntime>,
compressor: ContextCompressor,
store: Arc<SessionStore>,
agent_prompt_reinject_every: i64,
}
pub struct BusToolCallEmitter {
@ -79,6 +83,7 @@ impl Session {
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
) -> Result<Self, AgentError> {
Ok(Self {
id: Uuid::new_v4(),
@ -90,6 +95,7 @@ impl Session {
skills,
compressor: ContextCompressor::new(provider_config.token_limit),
store,
agent_prompt_reinject_every: agent_prompt_reinject_every as i64,
})
}
@ -113,6 +119,33 @@ impl Session {
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
self.ensure_initial_agent_prompt(chat_id)?;
Ok(())
}
pub fn ensure_agent_prompt_before_user_message(&mut self, chat_id: &str) -> Result<(), AgentError> {
self.ensure_chat_loaded(chat_id)?;
let session_id = self.persistent_session_id(chat_id);
let session_record = self
.store
.get_session(&session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
if self.agent_prompt_reinject_every > 0
&& session_record.user_turn_count > 0
&& session_record.user_turn_count / self.agent_prompt_reinject_every
> session_record.agent_prompt_reinjection_count
{
if let Some(agent_prompt) = load_agent_prompt()? {
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
self.store
.mark_agent_prompt_reinjected(&session_id)
.map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?;
}
}
Ok(())
}
@ -264,6 +297,51 @@ impl Session {
})
})
}
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
let history_is_empty = self
.get_history(chat_id)
.map(|history| history.is_empty())
.unwrap_or(true);
if !history_is_empty {
return Ok(());
}
if let Some(agent_prompt) = load_agent_prompt()? {
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
}
Ok(())
}
}
fn load_agent_prompt() -> Result<Option<String>, AgentError> {
let path = agent_prompt_path()?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
if !path.exists() {
fs::write(&path, DEFAULT_AGENT_PROMPT)
.map_err(|err| AgentError::Other(format!("create agent prompt file error: {}", err)))?;
}
let content = fs::read_to_string(&path)
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
let trimmed = content.trim();
if trimmed.is_empty() {
return Ok(None);
}
Ok(Some(trimmed.to_string()))
}
fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
let home = dirs::home_dir()
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
}
/// SessionManager 管理所有 Session按 channel_name 路由
@ -274,6 +352,7 @@ pub struct SessionManager {
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
}
struct SessionManagerInner {
@ -331,6 +410,7 @@ pub(crate) fn handle_in_chat_command(
impl SessionManager {
pub fn new(
session_ttl_hours: u64,
agent_prompt_reinject_every: u64,
provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>,
) -> Result<Self, AgentError> {
@ -353,6 +433,7 @@ impl SessionManager {
tools: Arc::new(default_tools(skills.clone(), store.clone())),
skills,
store,
agent_prompt_reinject_every,
})
}
@ -447,6 +528,7 @@ impl SessionManager {
self.tools.clone(),
self.skills.clone(),
self.store.clone(),
self.agent_prompt_reinject_every,
)
.await?;
let arc = Arc::new(Mutex::new(session));
@ -523,6 +605,8 @@ impl SessionManager {
)]);
}
session_guard.ensure_agent_prompt_before_user_message(chat_id)?;
// 添加用户消息到历史
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
#[cfg(debug_assertions)]
@ -633,6 +717,7 @@ mod tests {
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
@ -658,7 +743,113 @@ mod tests {
.load_all_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.len(),
1,
2,
);
}
#[tokio::test]
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
assert!(history[0].content.contains("PicoBot 代理配置"));
}
#[tokio::test]
async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
for turn in 0..100 {
session
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
.unwrap();
}
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 2);
let stored = store
.get_session(&session.persistent_session_id("chat-1"))
.unwrap()
.unwrap();
assert_eq!(stored.agent_prompt_reinjection_count, 1);
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 2);
}
#[tokio::test]
async fn test_agent_prompt_reinjection_can_be_disabled_by_config() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
0,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
for turn in 0..100 {
session
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
.unwrap();
}
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 1);
}
}

View File

@ -59,6 +59,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
state.session_manager.tools(),
state.session_manager.skills(),
state.session_manager.store(),
state.config.gateway.agent_prompt_reinject_every,
)
.await
{
@ -210,6 +211,8 @@ async fn handle_inbound(
return Ok(());
}
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?;

View File

@ -40,6 +40,8 @@ pub struct SessionRecord {
pub deleted_at: Option<i64>,
pub message_count: i64,
pub reset_cutoff_seq: i64,
pub user_turn_count: i64,
pub agent_prompt_reinjection_count: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -113,7 +115,9 @@ impl SessionStore {
archived_at INTEGER,
deleted_at INTEGER,
message_count INTEGER NOT NULL DEFAULT 0,
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0,
user_turn_count INTEGER NOT NULL DEFAULT 0,
agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
@ -234,8 +238,9 @@ impl SessionStore {
"
INSERT INTO sessions (
id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0)
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0, 0, 0, 0)
",
params![id, title, id, now],
)?;
@ -261,8 +266,9 @@ impl SessionStore {
"
INSERT INTO sessions (
id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0)
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0)
",
params![session_id, title, channel_name, chat_id, now],
)?;
@ -277,7 +283,8 @@ impl SessionStore {
"
SELECT id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at,
archived_at, deleted_at, message_count, reset_cutoff_seq
archived_at, deleted_at, message_count, reset_cutoff_seq,
user_turn_count, agent_prompt_reinjection_count
FROM sessions
WHERE id = ?1 AND deleted_at IS NULL
",
@ -298,7 +305,8 @@ impl SessionStore {
"
SELECT id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at,
archived_at, deleted_at, message_count, reset_cutoff_seq
archived_at, deleted_at, message_count, reset_cutoff_seq,
user_turn_count, agent_prompt_reinjection_count
FROM sessions
WHERE channel_name = ?1
AND deleted_at IS NULL
@ -354,7 +362,12 @@ impl SessionStore {
conn.execute(
"
UPDATE sessions
SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0
SET message_count = 0,
updated_at = ?2,
last_active_at = ?2,
reset_cutoff_seq = 0,
user_turn_count = 0,
agent_prompt_reinjection_count = 0
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, now],
@ -379,7 +392,9 @@ impl SessionStore {
SET reset_cutoff_seq = ?2,
updated_at = ?3,
last_active_at = ?3,
archived_at = NULL
archived_at = NULL,
user_turn_count = 0,
agent_prompt_reinjection_count = 0
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, cutoff_seq, now],
@ -423,10 +438,31 @@ impl SessionStore {
)?;
let now = current_timestamp();
let is_user_message = message.role == "user";
tx.execute(
"
UPDATE sessions
SET message_count = message_count + 1,
user_turn_count = user_turn_count + ?3,
updated_at = ?2,
last_active_at = ?2,
archived_at = NULL
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, now, if is_user_message { 1 } else { 0 }],
)?;
tx.commit()?;
Ok(())
}
pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
"
UPDATE sessions
SET agent_prompt_reinjection_count = agent_prompt_reinjection_count + 1,
updated_at = ?2,
last_active_at = ?2,
archived_at = NULL
@ -434,8 +470,6 @@ impl SessionStore {
",
params![session_id, now],
)?;
tx.commit()?;
Ok(())
}
@ -779,6 +813,8 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord
deleted_at: row.get(9)?,
message_count: row.get(10)?,
reset_cutoff_seq: row.get(11)?,
user_turn_count: row.get(12)?,
agent_prompt_reinjection_count: row.get(13)?,
})
}
@ -829,6 +865,20 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
)?;
}
if !has_column(conn, "sessions", "user_turn_count")? {
conn.execute(
"ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0",
[],
)?;
}
if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? {
conn.execute(
"ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0",
[],
)?;
}
Ok(())
}
@ -947,6 +997,8 @@ mod tests {
assert_eq!(session.chat_id, session.id);
assert_eq!(session.message_count, 0);
assert_eq!(session.reset_cutoff_seq, 0);
assert_eq!(session.user_turn_count, 0);
assert_eq!(session.agent_prompt_reinjection_count, 0);
let first = ChatMessage::user("hello");
let second = ChatMessage::assistant("world");
@ -957,6 +1009,8 @@ mod tests {
assert_eq!(stored.message_count, 2);
assert!(stored.archived_at.is_none());
assert_eq!(stored.reset_cutoff_seq, 0);
assert_eq!(stored.user_turn_count, 1);
assert_eq!(stored.agent_prompt_reinjection_count, 0);
let messages = store.load_messages(&session.id).unwrap();
assert_eq!(messages.len(), 2);
@ -984,6 +1038,8 @@ mod tests {
assert!(cleared.is_empty());
let cleared_session = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(cleared_session.message_count, 0);
assert_eq!(cleared_session.user_turn_count, 0);
assert_eq!(cleared_session.agent_prompt_reinjection_count, 0);
store.delete_session(&session.id).unwrap();
assert!(store.get_session(&session.id).unwrap().is_none());
@ -1036,6 +1092,8 @@ mod tests {
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.reset_cutoff_seq, 2);
assert_eq!(stored.user_turn_count, 0);
assert_eq!(stored.agent_prompt_reinjection_count, 0);
let active_messages = store.load_messages(&session.id).unwrap();
assert!(active_messages.is_empty());
@ -1049,6 +1107,9 @@ mod tests {
let active_messages = store.load_messages(&session.id).unwrap();
assert_eq!(active_messages.len(), 1);
assert_eq!(active_messages[0].content, "after");
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.user_turn_count, 1);
}
#[test]
@ -1091,6 +1152,20 @@ mod tests {
let store = SessionStore::from_connection(conn).unwrap();
let session = store.create_cli_session(Some("migrated")).unwrap();
assert_eq!(session.reset_cutoff_seq, 0);
assert_eq!(session.user_turn_count, 0);
assert_eq!(session.agent_prompt_reinjection_count, 0);
}
#[test]
fn test_mark_agent_prompt_reinjected_increments_counter() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("prompt")).unwrap();
store.mark_agent_prompt_reinjected(&session.id).unwrap();
store.mark_agent_prompt_reinjected(&session.id).unwrap();
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.agent_prompt_reinjection_count, 2);
}
#[test]