feat(session): 更新聊天加载逻辑,确保初始代理提示并添加活跃用户消息计数功能
This commit is contained in:
parent
09ccd71cc7
commit
9d15d50b09
@ -111,7 +111,7 @@ impl Session {
|
|||||||
|
|
||||||
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
if self.chat_histories.contains_key(chat_id) {
|
if self.chat_histories.contains_key(chat_id) {
|
||||||
return Ok(());
|
return self.ensure_initial_agent_prompt(chat_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
let history = self
|
let history = self
|
||||||
@ -132,10 +132,14 @@ impl Session {
|
|||||||
.get_session(&session_id)
|
.get_session(&session_id)
|
||||||
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
|
||||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||||
|
let active_user_turns = self
|
||||||
|
.store
|
||||||
|
.count_active_user_messages(&session_id)
|
||||||
|
.map_err(|err| AgentError::Other(format!("count active user messages error: {}", err)))?;
|
||||||
|
|
||||||
if self.agent_prompt_reinject_every > 0
|
if self.agent_prompt_reinject_every > 0
|
||||||
&& session_record.user_turn_count > 0
|
&& active_user_turns > 0
|
||||||
&& session_record.user_turn_count / self.agent_prompt_reinject_every
|
&& active_user_turns / self.agent_prompt_reinject_every
|
||||||
> session_record.agent_prompt_reinjection_count
|
> session_record.agent_prompt_reinjection_count
|
||||||
{
|
{
|
||||||
if let Some(agent_prompt) = load_agent_prompt()? {
|
if let Some(agent_prompt) = load_agent_prompt()? {
|
||||||
@ -745,6 +749,11 @@ mod tests {
|
|||||||
.len(),
|
.len(),
|
||||||
2,
|
2,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
|
let history = session.get_history("chat-1").unwrap();
|
||||||
|
assert_eq!(history.len(), 1);
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -852,4 +861,36 @@ mod tests {
|
|||||||
let system_messages = history.iter().filter(|message| message.role == "system").count();
|
let system_messages = history.iter().filter(|message| message.role == "system").count();
|
||||||
assert_eq!(system_messages, 1);
|
assert_eq!(system_messages, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
|
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||||
|
let mut session = Session::new(
|
||||||
|
"feishu".to_string(),
|
||||||
|
test_provider_config(),
|
||||||
|
user_tx,
|
||||||
|
tools,
|
||||||
|
skills,
|
||||||
|
store.clone(),
|
||||||
|
100,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
session.ensure_persistent_session("chat-1").unwrap();
|
||||||
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
|
session
|
||||||
|
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
|
||||||
|
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
|
||||||
|
|
||||||
|
let history = session.get_history("chat-1").unwrap();
|
||||||
|
assert_eq!(history.len(), 1);
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -784,6 +784,21 @@ impl SessionStore {
|
|||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
load_messages_after(&conn, session_id, 0)
|
load_messages_after(&conn, session_id, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
|
||||||
|
conn.query_row(
|
||||||
|
"
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ?1 AND seq > ?2 AND role = 'user'
|
||||||
|
",
|
||||||
|
params![session_id, cutoff_seq],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.map_err(StorageError::from)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
|
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
|
||||||
@ -1156,6 +1171,28 @@ mod tests {
|
|||||||
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_count_active_user_messages_respects_reset_cutoff_seq() {
|
||||||
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
let session = store.create_cli_session(Some("count-users")).unwrap();
|
||||||
|
|
||||||
|
store.append_message(&session.id, &ChatMessage::system("agent")).unwrap();
|
||||||
|
store.append_message(&session.id, &ChatMessage::user("u1")).unwrap();
|
||||||
|
store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap();
|
||||||
|
store.append_message(&session.id, &ChatMessage::user("u2")).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||||
|
|
||||||
|
store.reset_session(&session.id).unwrap();
|
||||||
|
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
|
||||||
|
|
||||||
|
store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap();
|
||||||
|
store.append_message(&session.id, &ChatMessage::user("u3")).unwrap();
|
||||||
|
store.append_message(&session.id, &ChatMessage::user("u4")).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mark_agent_prompt_reinjected_increments_counter() {
|
fn test_mark_agent_prompt_reinjected_increments_counter() {
|
||||||
let store = SessionStore::in_memory().unwrap();
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user