feat(session): 更新聊天加载逻辑,确保初始代理提示并添加活跃用户消息计数功能

This commit is contained in:
ooodc 2026-04-22 10:09:39 +08:00
parent 09ccd71cc7
commit 9d15d50b09
2 changed files with 81 additions and 3 deletions

View File

@ -111,7 +111,7 @@ impl Session {
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> { pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
if self.chat_histories.contains_key(chat_id) { if self.chat_histories.contains_key(chat_id) {
return Ok(()); return self.ensure_initial_agent_prompt(chat_id);
} }
let history = self let history = self
@ -132,10 +132,14 @@ impl Session {
.get_session(&session_id) .get_session(&session_id)
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))? .map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?; .ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
let active_user_turns = self
.store
.count_active_user_messages(&session_id)
.map_err(|err| AgentError::Other(format!("count active user messages error: {}", err)))?;
if self.agent_prompt_reinject_every > 0 if self.agent_prompt_reinject_every > 0
&& session_record.user_turn_count > 0 && active_user_turns > 0
&& session_record.user_turn_count / self.agent_prompt_reinject_every && active_user_turns / self.agent_prompt_reinject_every
> session_record.agent_prompt_reinjection_count > session_record.agent_prompt_reinjection_count
{ {
if let Some(agent_prompt) = load_agent_prompt()? { if let Some(agent_prompt) = load_agent_prompt()? {
@ -745,6 +749,11 @@ mod tests {
.len(), .len(),
2, 2,
); );
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
} }
#[tokio::test] #[tokio::test]
@ -852,4 +861,36 @@ mod tests {
let system_messages = history.iter().filter(|message| message.role == "system").count(); let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 1); assert_eq!(system_messages, 1);
} }
#[tokio::test]
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
session
.append_persisted_message("chat-1", ChatMessage::user("hello"))
.unwrap();
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
}
} }

View File

@ -784,6 +784,21 @@ impl SessionStore {
let conn = self.conn.lock().expect("session db mutex poisoned"); let conn = self.conn.lock().expect("session db mutex poisoned");
load_messages_after(&conn, session_id, 0) load_messages_after(&conn, session_id, 0)
} }
pub fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
conn.query_row(
"
SELECT COUNT(*)
FROM messages
WHERE session_id = ?1 AND seq > ?2 AND role = 'user'
",
params![session_id, cutoff_seq],
|row| row.get(0),
)
.map_err(StorageError::from)
}
} }
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String { pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
@ -1156,6 +1171,28 @@ mod tests {
assert_eq!(session.agent_prompt_reinjection_count, 0); assert_eq!(session.agent_prompt_reinjection_count, 0);
} }
#[test]
fn test_count_active_user_messages_respects_reset_cutoff_seq() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("count-users")).unwrap();
store.append_message(&session.id, &ChatMessage::system("agent")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u1")).unwrap();
store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u2")).unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
store.reset_session(&session.id).unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u3")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u4")).unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
}
#[test] #[test]
fn test_mark_agent_prompt_reinjected_increments_counter() { fn test_mark_agent_prompt_reinjected_increments_counter() {
let store = SessionStore::in_memory().unwrap(); let store = SessionStore::in_memory().unwrap();