feat: 更新 load_messages_for_topic 方法,支持按 session_id 过滤消息

This commit is contained in:
oudecheng 2026-06-12 19:05:06 +08:00
parent 24bbd5f8c9
commit 6f8c4a7ce8
6 changed files with 98 additions and 65 deletions

View File

@ -323,7 +323,7 @@ impl InboundProcessor {
tokio::spawn(async move { tokio::spawn(async move {
// 从 DB 查询该 topic 的第一条用户消息作为描述生成的依据 // 从 DB 查询该 topic 的第一条用户消息作为描述生成的依据
let first_user_message = store_clone let first_user_message = store_clone
.load_messages_for_topic(&topic_id_clone) .load_messages_for_topic(&topic_id_clone, None)
.ok() .ok()
.and_then(|msgs| msgs.into_iter().find(|m| m.role == "user")) .and_then(|msgs| msgs.into_iter().find(|m| m.role == "user"))
.map(|m| m.content); .map(|m| m.content);

View File

@ -274,10 +274,11 @@ impl Session {
// 先设置当前话题set_history 需要这个) // 先设置当前话题set_history 需要这个)
self.history.set_chat_topic(chat_id, topic_id.to_string()); self.history.set_chat_topic(chat_id, topic_id.to_string());
// 加载新话题的历史 // 加载新话题的历史(按 session_id 过滤,排除子智能体消息)
let session_id = self.persistent_session_id(chat_id);
let messages = self let messages = self
.store .store
.load_messages_for_topic(topic_id) .load_messages_for_topic(topic_id, Some(&session_id))
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?; .map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
self.history.set_history(chat_id, messages); self.history.set_history(chat_id, messages);
@ -464,11 +465,12 @@ impl Session {
} }
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
// 如果当前有 topic加载该 topic 的消息 // 如果当前有 topic加载该 topic 的消息(按 session_id 过滤,排除子智能体消息)
if let Some(topic_id) = self.history.chat_topic(chat_id) { if let Some(topic_id) = self.history.chat_topic(chat_id) {
let session_id = self.persistent_session_id(chat_id);
let messages = self let messages = self
.store .store
.load_messages_for_topic(topic_id) .load_messages_for_topic(topic_id, Some(&session_id))
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?; .map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
self.history.set_history(chat_id, messages); self.history.set_history(chat_id, messages);
} else { } else {

View File

@ -66,8 +66,9 @@ impl SessionHistory {
// 如果提供了 topic_id按 topic 加载;否则按 session 加载 // 如果提供了 topic_id按 topic 加载;否则按 session 加载
let history = if let Some(tid) = topic_id { let history = if let Some(tid) = topic_id {
let sid = self.persistent_session_id(chat_id);
self.conversations self.conversations
.load_messages_for_topic(tid) .load_messages_for_topic(tid, Some(&sid))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))? .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?
} else { } else {
self.conversations self.conversations

View File

@ -637,12 +637,12 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
/// 加载并发送话题历史消息 /// 加载并发送话题历史消息
async fn send_topic_history( async fn send_topic_history(
store: &Arc<crate::storage::SessionStore>, store: &Arc<crate::storage::SessionStore>,
_session_id: &str, session_id: &str,
topic_id: &str, topic_id: &str,
sender: &mpsc::Sender<WsOutbound>, sender: &mpsc::Sender<WsOutbound>,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
// 加载话题消息 // 加载话题消息,按 session_id 过滤,避免混入子智能体消息
let messages = store.load_messages_for_topic(topic_id)?; let messages = store.load_messages_for_topic(topic_id, Some(session_id))?;
tracing::info!(topic_id = %topic_id, message_count = messages.len(), "Sending topic history"); tracing::info!(topic_id = %topic_id, message_count = messages.len(), "Sending topic history");

View File

@ -1412,8 +1412,29 @@ impl SessionStore {
load_messages_after(&conn, session_id, 0) load_messages_after(&conn, session_id, 0)
} }
pub fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> { pub fn load_messages_for_topic(
&self,
topic_id: &str,
session_id: Option<&str>,
) -> Result<Vec<ChatMessage>, StorageError> {
let conn = self.pool.get()?; let conn = self.pool.get()?;
if let Some(sid) = session_id {
let mut stmt = conn.prepare(
"
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
FROM messages
WHERE topic_id = ?1 AND session_id = ?2
ORDER BY seq ASC
",
)?;
let rows = stmt.query_map(params![topic_id, sid], map_chat_message_row)?;
let mut messages = Vec::new();
for row in rows {
messages.push(row?);
}
Ok(messages)
} else {
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
" "
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
@ -1422,56 +1443,18 @@ impl SessionStore {
ORDER BY seq ASC ORDER BY seq ASC
", ",
)?; )?;
let rows = stmt.query_map(params![topic_id], map_chat_message_row)?;
let rows = stmt.query_map(params![topic_id], |row| {
let media_refs_json: String = row.get(5)?;
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(),
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
let tool_calls_json: Option<String> = row.get(9)?;
let tool_calls = tool_calls_json
.as_deref()
.map(serde_json::from_str)
.transpose()
.map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
9,
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
Ok(ChatMessage {
id: row.get(0)?,
role: row.get(1)?,
content: row.get(2)?,
system_context: row.get(3)?,
reasoning_content: row.get(4)?,
media_refs,
timestamp: row.get(6)?,
tool_call_id: row.get(7)?,
tool_name: row.get(8)?,
tool_state: None,
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
tool_calls,
})
})?;
let mut messages = Vec::new(); let mut messages = Vec::new();
for row in rows { for row in rows {
messages.push(row?); messages.push(row?);
} }
Ok(messages) Ok(messages)
} }
}
/// 获取指定话题的消息数量(动态计算,确保准确) /// 获取指定话题的消息数量(动态计算,确保准确)
pub fn get_topic_message_count(&self, topic_id: &str) -> Result<usize, StorageError> { pub fn get_topic_message_count(&self, topic_id: &str) -> Result<usize, StorageError> {
self.load_messages_for_topic(topic_id).map(|msgs| msgs.len()) self.load_messages_for_topic(topic_id, None).map(|msgs| msgs.len())
} }
pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> { pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
@ -1673,6 +1656,45 @@ fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEven
}) })
} }
fn map_chat_message_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<ChatMessage> {
let media_refs_json: String = row.get(5)?;
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(),
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
let tool_calls_json: Option<String> = row.get(9)?;
let tool_calls = tool_calls_json
.as_deref()
.map(serde_json::from_str)
.transpose()
.map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
9,
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
Ok(ChatMessage {
id: row.get(0)?,
role: row.get(1)?,
content: row.get(2)?,
system_context: row.get(3)?,
reasoning_content: row.get(4)?,
media_refs,
timestamp: row.get(6)?,
tool_call_id: row.get(7)?,
tool_name: row.get(8)?,
tool_state: None,
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
tool_calls,
})
}
fn map_memory_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<MemoryRecord> { fn map_memory_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<MemoryRecord> {
Ok(MemoryRecord { Ok(MemoryRecord {
id: row.get(0)?, id: row.get(0)?,

View File

@ -22,7 +22,11 @@ pub trait ConversationRepository: Send + Sync + 'static {
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>; fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError>; fn load_messages_for_topic(
&self,
topic_id: &str,
session_id: Option<&str>,
) -> Result<Vec<ChatMessage>, StorageError>;
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>; fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
@ -180,8 +184,12 @@ impl ConversationRepository for super::SessionStore {
super::SessionStore::load_messages(self, session_id) super::SessionStore::load_messages(self, session_id)
} }
fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> { fn load_messages_for_topic(
super::SessionStore::load_messages_for_topic(self, topic_id) &self,
topic_id: &str,
session_id: Option<&str>,
) -> Result<Vec<ChatMessage>, StorageError> {
super::SessionStore::load_messages_for_topic(self, topic_id, session_id)
} }
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> { fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {