feat(agent_profile): 实现代理配置文件的注入与周期性重注入机制
This commit is contained in:
parent
4725b5406e
commit
0dfa615ca9
17
README.md
17
README.md
@ -2,6 +2,23 @@ PicoBot
|
||||
|
||||
Skills (initial implementation)
|
||||
|
||||
Agent profile injection
|
||||
|
||||
PicoBot maintains a persistent agent profile file at ~/.picobot/agent/AGENT.md.
|
||||
|
||||
Behavior:
|
||||
- The directory ~/.picobot/agent is created automatically when needed.
|
||||
- If AGENT.md does not exist yet, PicoBot creates it with a default profile.
|
||||
- When the active conversation is empty or has just been reset, AGENT.md is loaded as the first system message in the active history.
|
||||
- After every configured number of user turns, PicoBot injects the latest AGENT.md content again before the next user message is appended.
|
||||
|
||||
Config:
|
||||
- Set gateway.agent_prompt_reinject_every in ~/.picobot/config.json.
|
||||
- Default value is 100.
|
||||
- Set it to 0 to disable periodic reinjection.
|
||||
|
||||
This profile is persisted in session history, while the skills index system prompt is still injected dynamically by AgentLoop.
|
||||
|
||||
PicoBot now supports filesystem skills.
|
||||
|
||||
Skill discovery locations:
|
||||
|
||||
@ -48,9 +48,11 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
2. `SessionManager` 定位到对应 channel 的运行时 `Session`。
|
||||
3. `Session::ensure_persistent_session(chat_id)` 确保数据库里有对应会话记录。
|
||||
4. `Session::ensure_chat_loaded(chat_id)` 在内存中没有历史时,从 `messages` 表加载该会话全部历史。
|
||||
5. 新的用户消息先写入 `messages`,再放入内存历史。
|
||||
6. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`。
|
||||
7. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。
|
||||
5. 如果当前活动段历史为空,系统会从 `~/.picobot/agent/AGENT.md` 读取 Agent 基本设定,并先写入一条 `system` 消息。
|
||||
6. 在新的用户消息进入前,系统会检查当前活动段的 `user_turn_count` 是否刚跨过配置项 `gateway.agent_prompt_reinject_every` 指定的下一轮阈值;如果跨过,就再次把 `AGENT.md` 写入一条新的 `system` 消息。
|
||||
7. 新的用户消息先写入 `messages`,再放入内存历史。
|
||||
8. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`。
|
||||
9. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。
|
||||
|
||||
## 3. 会话标识规则
|
||||
|
||||
@ -86,6 +88,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
| `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 |
|
||||
| `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 |
|
||||
| `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 |
|
||||
| `user_turn_count` | `INTEGER NOT NULL DEFAULT 0` | 当前活动段用户轮次数 | 只在追加 `role = user` 消息时递增,清空历史和 `/reset` 时归零 |
|
||||
| `agent_prompt_reinjection_count` | `INTEGER NOT NULL DEFAULT 0` | AGENT.md 周期重注入次数 | 每完成一次“达到配置阈值后的下一轮前注入”就递增,清空历史和 `/reset` 时归零 |
|
||||
|
||||
索引:
|
||||
|
||||
@ -171,6 +175,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
|
||||
其中第 6 步很重要:归档会话一旦收到新消息,会自动恢复为活跃态。
|
||||
|
||||
另外,只有 `role = user` 的消息会递增 `user_turn_count`;`system`、`assistant`、`tool` 消息不会影响周期注入阈值的判定。
|
||||
|
||||
### 6.3 读取历史
|
||||
|
||||
`load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是:
|
||||
@ -236,6 +242,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
- 删除该会话在 `messages` 中的所有记录
|
||||
- 将 `sessions.message_count` 重置为 0
|
||||
- 将 `sessions.reset_cutoff_seq` 重置为 0
|
||||
- 将 `sessions.user_turn_count` 重置为 0
|
||||
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
|
||||
- 更新 `updated_at` 和 `last_active_at`
|
||||
- 保留会话本身
|
||||
|
||||
@ -247,11 +255,15 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
|
||||
- 不删除 `messages` 中的任何记录
|
||||
- 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq`
|
||||
- 将 `sessions.user_turn_count` 重置为 0
|
||||
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
|
||||
- 更新 `updated_at` 和 `last_active_at`
|
||||
- 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息
|
||||
|
||||
这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。
|
||||
|
||||
由于 AGENT.md 注入消息也会持久化,`/reset` 前的 Agent 设定消息仍会保留在完整历史中,但不会继续出现在新的活动段。下一次活动段首次加载时,系统会重新读取当前版本的 `~/.picobot/agent/AGENT.md`,并把它作为新的首条系统消息写入活动段。
|
||||
|
||||
### 8.5 删除会话
|
||||
|
||||
`delete_session(session_id)`:
|
||||
|
||||
@ -136,6 +136,8 @@ pub struct GatewayConfig {
|
||||
pub port: u16,
|
||||
#[serde(default, rename = "session_ttl_hours")]
|
||||
pub session_ttl_hours: Option<u64>,
|
||||
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
|
||||
pub agent_prompt_reinject_every: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -156,12 +158,17 @@ fn default_gateway_url() -> String {
|
||||
"ws://127.0.0.1:19876/ws".to_string()
|
||||
}
|
||||
|
||||
fn default_agent_prompt_reinject_every() -> u64 {
|
||||
100
|
||||
}
|
||||
|
||||
impl Default for GatewayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_gateway_host(),
|
||||
port: default_gateway_port(),
|
||||
session_ttl_hours: None,
|
||||
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -344,7 +351,8 @@ mod tests {
|
||||
},
|
||||
"gateway": {
|
||||
"host": "0.0.0.0",
|
||||
"port": 19876
|
||||
"port": 19876,
|
||||
"agent_prompt_reinject_every": 120
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
@ -387,5 +395,39 @@ mod tests {
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
assert_eq!(config.gateway.host, "0.0.0.0");
|
||||
assert_eq!(config.gateway.port, 19876);
|
||||
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gateway_config_defaults_agent_prompt_reinject_every() {
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
r#"{
|
||||
"providers": {
|
||||
"aliyun": {
|
||||
"type": "openai",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"api_key": "test-key",
|
||||
"extra_headers": {}
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"qwen-plus": {
|
||||
"model_id": "qwen-plus"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"default": {
|
||||
"provider": "aliyun",
|
||||
"model": "qwen-plus"
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,10 +29,16 @@ impl GatewayState {
|
||||
|
||||
// Session TTL from config (default 4 hours)
|
||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
||||
|
||||
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
|
||||
|
||||
let session_manager = SessionManager::new(session_ttl_hours, provider_config, skills)?;
|
||||
let session_manager = SessionManager::new(
|
||||
session_ttl_hours,
|
||||
agent_prompt_reinject_every,
|
||||
provider_config,
|
||||
skills,
|
||||
)?;
|
||||
let channel_manager = ChannelManager::new();
|
||||
let bus = channel_manager.bus();
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use async_trait::async_trait;
|
||||
@ -16,6 +17,8 @@ use crate::tools::{
|
||||
WebFetchTool,
|
||||
};
|
||||
|
||||
const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot,一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n";
|
||||
|
||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||
pub struct Session {
|
||||
@ -29,6 +32,7 @@ pub struct Session {
|
||||
skills: Arc<SkillRuntime>,
|
||||
compressor: ContextCompressor,
|
||||
store: Arc<SessionStore>,
|
||||
agent_prompt_reinject_every: i64,
|
||||
}
|
||||
|
||||
pub struct BusToolCallEmitter {
|
||||
@ -79,6 +83,7 @@ impl Session {
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
agent_prompt_reinject_every: u64,
|
||||
) -> Result<Self, AgentError> {
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
@ -90,6 +95,7 @@ impl Session {
|
||||
skills,
|
||||
compressor: ContextCompressor::new(provider_config.token_limit),
|
||||
store,
|
||||
agent_prompt_reinject_every: agent_prompt_reinject_every as i64,
|
||||
})
|
||||
}
|
||||
|
||||
@ -113,6 +119,33 @@ impl Session {
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
self.ensure_initial_agent_prompt(chat_id)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn ensure_agent_prompt_before_user_message(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
self.ensure_chat_loaded(chat_id)?;
|
||||
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
let session_record = self
|
||||
.store
|
||||
.get_session(&session_id)
|
||||
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
|
||||
if self.agent_prompt_reinject_every > 0
|
||||
&& session_record.user_turn_count > 0
|
||||
&& session_record.user_turn_count / self.agent_prompt_reinject_every
|
||||
> session_record.agent_prompt_reinjection_count
|
||||
{
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
|
||||
self.store
|
||||
.mark_agent_prompt_reinjected(&session_id)
|
||||
.map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -264,6 +297,51 @@ impl Session {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
let history_is_empty = self
|
||||
.get_history(chat_id)
|
||||
.map(|history| history.is_empty())
|
||||
.unwrap_or(true);
|
||||
|
||||
if !history_is_empty {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn load_agent_prompt() -> Result<Option<String>, AgentError> {
|
||||
let path = agent_prompt_path()?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||
}
|
||||
|
||||
if !path.exists() {
|
||||
fs::write(&path, DEFAULT_AGENT_PROMPT)
|
||||
.map_err(|err| AgentError::Other(format!("create agent prompt file error: {}", err)))?;
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&path)
|
||||
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
|
||||
let trimmed = content.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(trimmed.to_string()))
|
||||
}
|
||||
|
||||
fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
|
||||
let home = dirs::home_dir()
|
||||
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
|
||||
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
|
||||
}
|
||||
|
||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||||
@ -274,6 +352,7 @@ pub struct SessionManager {
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
agent_prompt_reinject_every: u64,
|
||||
}
|
||||
|
||||
struct SessionManagerInner {
|
||||
@ -331,6 +410,7 @@ pub(crate) fn handle_in_chat_command(
|
||||
impl SessionManager {
|
||||
pub fn new(
|
||||
session_ttl_hours: u64,
|
||||
agent_prompt_reinject_every: u64,
|
||||
provider_config: LLMProviderConfig,
|
||||
skills: Arc<SkillRuntime>,
|
||||
) -> Result<Self, AgentError> {
|
||||
@ -353,6 +433,7 @@ impl SessionManager {
|
||||
tools: Arc::new(default_tools(skills.clone(), store.clone())),
|
||||
skills,
|
||||
store,
|
||||
agent_prompt_reinject_every,
|
||||
})
|
||||
}
|
||||
|
||||
@ -447,6 +528,7 @@ impl SessionManager {
|
||||
self.tools.clone(),
|
||||
self.skills.clone(),
|
||||
self.store.clone(),
|
||||
self.agent_prompt_reinject_every,
|
||||
)
|
||||
.await?;
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
@ -523,6 +605,8 @@ impl SessionManager {
|
||||
)]);
|
||||
}
|
||||
|
||||
session_guard.ensure_agent_prompt_before_user_message(chat_id)?;
|
||||
|
||||
// 添加用户消息到历史
|
||||
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
||||
#[cfg(debug_assertions)]
|
||||
@ -633,6 +717,7 @@ mod tests {
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@ -658,7 +743,113 @@ mod tests {
|
||||
.load_all_messages(&session.persistent_session_id("chat-1"))
|
||||
.unwrap()
|
||||
.len(),
|
||||
1,
|
||||
2,
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "system");
|
||||
assert!(history[0].content.contains("PicoBot 代理配置"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
|
||||
for turn in 0..100 {
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
|
||||
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
let system_messages = history.iter().filter(|message| message.role == "system").count();
|
||||
assert_eq!(system_messages, 2);
|
||||
|
||||
let stored = store
|
||||
.get_session(&session.persistent_session_id("chat-1"))
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(stored.agent_prompt_reinjection_count, 1);
|
||||
|
||||
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
let system_messages = history.iter().filter(|message| message.role == "system").count();
|
||||
assert_eq!(system_messages, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_prompt_reinjection_can_be_disabled_by_config() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
|
||||
for turn in 0..100 {
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
|
||||
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
let system_messages = history.iter().filter(|message| message.role == "system").count();
|
||||
assert_eq!(system_messages, 1);
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,6 +59,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
state.session_manager.tools(),
|
||||
state.session_manager.skills(),
|
||||
state.session_manager.store(),
|
||||
state.config.gateway.agent_prompt_reinject_every,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@ -210,6 +211,8 @@ async fn handle_inbound(
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
|
||||
|
||||
let user_message = session_guard.create_user_message(&content, Vec::new());
|
||||
let user_message_id = user_message.id.clone();
|
||||
session_guard.append_persisted_message(&chat_id, user_message)?;
|
||||
|
||||
@ -40,6 +40,8 @@ pub struct SessionRecord {
|
||||
pub deleted_at: Option<i64>,
|
||||
pub message_count: i64,
|
||||
pub reset_cutoff_seq: i64,
|
||||
pub user_turn_count: i64,
|
||||
pub agent_prompt_reinjection_count: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@ -113,7 +115,9 @@ impl SessionStore {
|
||||
archived_at INTEGER,
|
||||
deleted_at INTEGER,
|
||||
message_count INTEGER NOT NULL DEFAULT 0,
|
||||
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0
|
||||
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0,
|
||||
user_turn_count INTEGER NOT NULL DEFAULT 0,
|
||||
agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
|
||||
@ -234,8 +238,9 @@ impl SessionStore {
|
||||
"
|
||||
INSERT INTO sessions (
|
||||
id, title, channel_name, chat_id, summary,
|
||||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
|
||||
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0)
|
||||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
|
||||
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
|
||||
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0, 0, 0, 0)
|
||||
",
|
||||
params![id, title, id, now],
|
||||
)?;
|
||||
@ -261,8 +266,9 @@ impl SessionStore {
|
||||
"
|
||||
INSERT INTO sessions (
|
||||
id, title, channel_name, chat_id, summary,
|
||||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
|
||||
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0)
|
||||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
|
||||
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
|
||||
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0)
|
||||
",
|
||||
params![session_id, title, channel_name, chat_id, now],
|
||||
)?;
|
||||
@ -277,7 +283,8 @@ impl SessionStore {
|
||||
"
|
||||
SELECT id, title, channel_name, chat_id, summary,
|
||||
created_at, updated_at, last_active_at,
|
||||
archived_at, deleted_at, message_count, reset_cutoff_seq
|
||||
archived_at, deleted_at, message_count, reset_cutoff_seq,
|
||||
user_turn_count, agent_prompt_reinjection_count
|
||||
FROM sessions
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
",
|
||||
@ -298,7 +305,8 @@ impl SessionStore {
|
||||
"
|
||||
SELECT id, title, channel_name, chat_id, summary,
|
||||
created_at, updated_at, last_active_at,
|
||||
archived_at, deleted_at, message_count, reset_cutoff_seq
|
||||
archived_at, deleted_at, message_count, reset_cutoff_seq,
|
||||
user_turn_count, agent_prompt_reinjection_count
|
||||
FROM sessions
|
||||
WHERE channel_name = ?1
|
||||
AND deleted_at IS NULL
|
||||
@ -354,7 +362,12 @@ impl SessionStore {
|
||||
conn.execute(
|
||||
"
|
||||
UPDATE sessions
|
||||
SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0
|
||||
SET message_count = 0,
|
||||
updated_at = ?2,
|
||||
last_active_at = ?2,
|
||||
reset_cutoff_seq = 0,
|
||||
user_turn_count = 0,
|
||||
agent_prompt_reinjection_count = 0
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
",
|
||||
params![session_id, now],
|
||||
@ -379,7 +392,9 @@ impl SessionStore {
|
||||
SET reset_cutoff_seq = ?2,
|
||||
updated_at = ?3,
|
||||
last_active_at = ?3,
|
||||
archived_at = NULL
|
||||
archived_at = NULL,
|
||||
user_turn_count = 0,
|
||||
agent_prompt_reinjection_count = 0
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
",
|
||||
params![session_id, cutoff_seq, now],
|
||||
@ -423,10 +438,31 @@ impl SessionStore {
|
||||
)?;
|
||||
|
||||
let now = current_timestamp();
|
||||
let is_user_message = message.role == "user";
|
||||
tx.execute(
|
||||
"
|
||||
UPDATE sessions
|
||||
SET message_count = message_count + 1,
|
||||
user_turn_count = user_turn_count + ?3,
|
||||
updated_at = ?2,
|
||||
last_active_at = ?2,
|
||||
archived_at = NULL
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
",
|
||||
params![session_id, now, if is_user_message { 1 } else { 0 }],
|
||||
)?;
|
||||
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
let now = current_timestamp();
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
conn.execute(
|
||||
"
|
||||
UPDATE sessions
|
||||
SET agent_prompt_reinjection_count = agent_prompt_reinjection_count + 1,
|
||||
updated_at = ?2,
|
||||
last_active_at = ?2,
|
||||
archived_at = NULL
|
||||
@ -434,8 +470,6 @@ impl SessionStore {
|
||||
",
|
||||
params![session_id, now],
|
||||
)?;
|
||||
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -779,6 +813,8 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord
|
||||
deleted_at: row.get(9)?,
|
||||
message_count: row.get(10)?,
|
||||
reset_cutoff_seq: row.get(11)?,
|
||||
user_turn_count: row.get(12)?,
|
||||
agent_prompt_reinjection_count: row.get(13)?,
|
||||
})
|
||||
}
|
||||
|
||||
@ -829,6 +865,20 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
)?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "sessions", "user_turn_count")? {
|
||||
conn.execute(
|
||||
"ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? {
|
||||
conn.execute(
|
||||
"ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -947,6 +997,8 @@ mod tests {
|
||||
assert_eq!(session.chat_id, session.id);
|
||||
assert_eq!(session.message_count, 0);
|
||||
assert_eq!(session.reset_cutoff_seq, 0);
|
||||
assert_eq!(session.user_turn_count, 0);
|
||||
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
||||
|
||||
let first = ChatMessage::user("hello");
|
||||
let second = ChatMessage::assistant("world");
|
||||
@ -957,6 +1009,8 @@ mod tests {
|
||||
assert_eq!(stored.message_count, 2);
|
||||
assert!(stored.archived_at.is_none());
|
||||
assert_eq!(stored.reset_cutoff_seq, 0);
|
||||
assert_eq!(stored.user_turn_count, 1);
|
||||
assert_eq!(stored.agent_prompt_reinjection_count, 0);
|
||||
|
||||
let messages = store.load_messages(&session.id).unwrap();
|
||||
assert_eq!(messages.len(), 2);
|
||||
@ -984,6 +1038,8 @@ mod tests {
|
||||
assert!(cleared.is_empty());
|
||||
let cleared_session = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(cleared_session.message_count, 0);
|
||||
assert_eq!(cleared_session.user_turn_count, 0);
|
||||
assert_eq!(cleared_session.agent_prompt_reinjection_count, 0);
|
||||
|
||||
store.delete_session(&session.id).unwrap();
|
||||
assert!(store.get_session(&session.id).unwrap().is_none());
|
||||
@ -1036,6 +1092,8 @@ mod tests {
|
||||
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(stored.reset_cutoff_seq, 2);
|
||||
assert_eq!(stored.user_turn_count, 0);
|
||||
assert_eq!(stored.agent_prompt_reinjection_count, 0);
|
||||
|
||||
let active_messages = store.load_messages(&session.id).unwrap();
|
||||
assert!(active_messages.is_empty());
|
||||
@ -1049,6 +1107,9 @@ mod tests {
|
||||
let active_messages = store.load_messages(&session.id).unwrap();
|
||||
assert_eq!(active_messages.len(), 1);
|
||||
assert_eq!(active_messages[0].content, "after");
|
||||
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(stored.user_turn_count, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1091,6 +1152,20 @@ mod tests {
|
||||
let store = SessionStore::from_connection(conn).unwrap();
|
||||
let session = store.create_cli_session(Some("migrated")).unwrap();
|
||||
assert_eq!(session.reset_cutoff_seq, 0);
|
||||
assert_eq!(session.user_turn_count, 0);
|
||||
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mark_agent_prompt_reinjected_increments_counter() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("prompt")).unwrap();
|
||||
|
||||
store.mark_agent_prompt_reinjected(&session.id).unwrap();
|
||||
store.mark_agent_prompt_reinjected(&session.id).unwrap();
|
||||
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(stored.agent_prompt_reinjection_count, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user