feat(session): 更新聊天加载逻辑,确保初始代理提示并添加活跃用户消息计数功能
This commit is contained in:
parent
09ccd71cc7
commit
9d15d50b09
@ -111,7 +111,7 @@ impl Session {
|
||||
|
||||
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if self.chat_histories.contains_key(chat_id) {
|
||||
return Ok(());
|
||||
return self.ensure_initial_agent_prompt(chat_id);
|
||||
}
|
||||
|
||||
let history = self
|
||||
@ -132,10 +132,14 @@ impl Session {
|
||||
.get_session(&session_id)
|
||||
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
let active_user_turns = self
|
||||
.store
|
||||
.count_active_user_messages(&session_id)
|
||||
.map_err(|err| AgentError::Other(format!("count active user messages error: {}", err)))?;
|
||||
|
||||
if self.agent_prompt_reinject_every > 0
|
||||
&& session_record.user_turn_count > 0
|
||||
&& session_record.user_turn_count / self.agent_prompt_reinject_every
|
||||
&& active_user_turns > 0
|
||||
&& active_user_turns / self.agent_prompt_reinject_every
|
||||
> session_record.agent_prompt_reinjection_count
|
||||
{
|
||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||
@ -745,6 +749,11 @@ mod tests {
|
||||
.len(),
|
||||
2,
|
||||
);
|
||||
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "system");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -852,4 +861,36 @@ mod tests {
|
||||
let system_messages = history.iter().filter(|message| message.role == "system").count();
|
||||
assert_eq!(system_messages, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
|
||||
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
|
||||
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
|
||||
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "system");
|
||||
}
|
||||
}
|
||||
|
||||
@ -784,6 +784,21 @@ impl SessionStore {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
load_messages_after(&conn, session_id, 0)
|
||||
}
|
||||
|
||||
pub fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
|
||||
conn.query_row(
|
||||
"
|
||||
SELECT COUNT(*)
|
||||
FROM messages
|
||||
WHERE session_id = ?1 AND seq > ?2 AND role = 'user'
|
||||
",
|
||||
params![session_id, cutoff_seq],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.map_err(StorageError::from)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
|
||||
@ -1156,6 +1171,28 @@ mod tests {
|
||||
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_active_user_messages_respects_reset_cutoff_seq() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("count-users")).unwrap();
|
||||
|
||||
store.append_message(&session.id, &ChatMessage::system("agent")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u1")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u2")).unwrap();
|
||||
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||
|
||||
store.reset_session(&session.id).unwrap();
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
|
||||
|
||||
store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u3")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u4")).unwrap();
|
||||
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mark_agent_prompt_reinjected_increments_counter() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user