diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 684ec1b..4008e20 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -111,7 +111,7 @@ impl Session { pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> { if self.chat_histories.contains_key(chat_id) { - return Ok(()); + return self.ensure_initial_agent_prompt(chat_id); } let history = self @@ -132,10 +132,14 @@ impl Session { .get_session(&session_id) .map_err(|err| AgentError::Other(format!("get session error: {}", err)))? .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; + let active_user_turns = self + .store + .count_active_user_messages(&session_id) + .map_err(|err| AgentError::Other(format!("count active user messages error: {}", err)))?; if self.agent_prompt_reinject_every > 0 - && session_record.user_turn_count > 0 - && session_record.user_turn_count / self.agent_prompt_reinject_every + && active_user_turns > 0 + && active_user_turns / self.agent_prompt_reinject_every > session_record.agent_prompt_reinjection_count { if let Some(agent_prompt) = load_agent_prompt()? { @@ -745,6 +749,11 @@ mod tests { .len(), 2, ); + + session.ensure_chat_loaded("chat-1").unwrap(); + let history = session.get_history("chat-1").unwrap(); + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, "system"); } #[tokio::test] @@ -852,4 +861,36 @@ mod tests { let system_messages = history.iter().filter(|message| message.role == "system").count(); assert_eq!(system_messages, 1); } + + #[tokio::test] + async fn test_reset_reinjects_agent_prompt_before_next_user_message() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let (user_tx, _user_rx) = mpsc::channel(4); + let skills = Arc::new(SkillRuntime::default()); + let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let mut session = Session::new( + "feishu".to_string(), + test_provider_config(), + user_tx, + tools, + skills, + store.clone(), + 100, + ) + .await + .unwrap(); + + session.ensure_persistent_session("chat-1").unwrap(); + session.ensure_chat_loaded("chat-1").unwrap(); + session + .append_persisted_message("chat-1", ChatMessage::user("hello")) + .unwrap(); + + handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap(); + session.ensure_agent_prompt_before_user_message("chat-1").unwrap(); + + let history = session.get_history("chat-1").unwrap(); + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, "system"); + } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 38979fa..dab0816 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -784,6 +784,21 @@ impl SessionStore { let conn = self.conn.lock().expect("session db mutex poisoned"); load_messages_after(&conn, session_id, 0) } + + pub fn count_active_user_messages(&self, session_id: &str) -> Result { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let cutoff_seq = active_reset_cutoff(&conn, session_id)?; + conn.query_row( + " + SELECT COUNT(*) + FROM messages + WHERE session_id = ?1 AND seq > ?2 AND role = 'user' + ", + params![session_id, cutoff_seq], + |row| row.get(0), + ) + .map_err(StorageError::from) + } } pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String { @@ -1156,6 +1171,28 @@ mod tests { assert_eq!(session.agent_prompt_reinjection_count, 0); } + #[test] + fn test_count_active_user_messages_respects_reset_cutoff_seq() { + let store = SessionStore::in_memory().unwrap(); + let session = store.create_cli_session(Some("count-users")).unwrap(); + + store.append_message(&session.id, &ChatMessage::system("agent")).unwrap(); + store.append_message(&session.id, &ChatMessage::user("u1")).unwrap(); + store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap(); + store.append_message(&session.id, &ChatMessage::user("u2")).unwrap(); + + assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2); + + store.reset_session(&session.id).unwrap(); + assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0); + + store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap(); + store.append_message(&session.id, &ChatMessage::user("u3")).unwrap(); + store.append_message(&session.id, &ChatMessage::user("u4")).unwrap(); + + assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2); + } + #[test] fn test_mark_agent_prompt_reinjected_increments_counter() { let store = SessionStore::in_memory().unwrap();