feat: 更新 load_messages_for_topic 方法,支持按 session_id 过滤消息
This commit is contained in:
parent
24bbd5f8c9
commit
6f8c4a7ce8
@ -323,7 +323,7 @@ impl InboundProcessor {
|
||||
tokio::spawn(async move {
|
||||
// 从 DB 查询该 topic 的第一条用户消息作为描述生成的依据
|
||||
let first_user_message = store_clone
|
||||
.load_messages_for_topic(&topic_id_clone)
|
||||
.load_messages_for_topic(&topic_id_clone, None)
|
||||
.ok()
|
||||
.and_then(|msgs| msgs.into_iter().find(|m| m.role == "user"))
|
||||
.map(|m| m.content);
|
||||
|
||||
@ -274,10 +274,11 @@ impl Session {
|
||||
// 先设置当前话题(set_history 需要这个)
|
||||
self.history.set_chat_topic(chat_id, topic_id.to_string());
|
||||
|
||||
// 加载新话题的历史
|
||||
// 加载新话题的历史(按 session_id 过滤,排除子智能体消息)
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
let messages = self
|
||||
.store
|
||||
.load_messages_for_topic(topic_id)
|
||||
.load_messages_for_topic(topic_id, Some(&session_id))
|
||||
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||||
|
||||
self.history.set_history(chat_id, messages);
|
||||
@ -464,11 +465,12 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
// 如果当前有 topic,加载该 topic 的消息
|
||||
// 如果当前有 topic,加载该 topic 的消息(按 session_id 过滤,排除子智能体消息)
|
||||
if let Some(topic_id) = self.history.chat_topic(chat_id) {
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
let messages = self
|
||||
.store
|
||||
.load_messages_for_topic(topic_id)
|
||||
.load_messages_for_topic(topic_id, Some(&session_id))
|
||||
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||||
self.history.set_history(chat_id, messages);
|
||||
} else {
|
||||
|
||||
@ -66,8 +66,9 @@ impl SessionHistory {
|
||||
|
||||
// 如果提供了 topic_id,按 topic 加载;否则按 session 加载
|
||||
let history = if let Some(tid) = topic_id {
|
||||
let sid = self.persistent_session_id(chat_id);
|
||||
self.conversations
|
||||
.load_messages_for_topic(tid)
|
||||
.load_messages_for_topic(tid, Some(&sid))
|
||||
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?
|
||||
} else {
|
||||
self.conversations
|
||||
|
||||
@ -637,12 +637,12 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
|
||||
/// 加载并发送话题历史消息
|
||||
async fn send_topic_history(
|
||||
store: &Arc<crate::storage::SessionStore>,
|
||||
_session_id: &str,
|
||||
session_id: &str,
|
||||
topic_id: &str,
|
||||
sender: &mpsc::Sender<WsOutbound>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// 加载话题消息
|
||||
let messages = store.load_messages_for_topic(topic_id)?;
|
||||
// 加载话题消息,按 session_id 过滤,避免混入子智能体消息
|
||||
let messages = store.load_messages_for_topic(topic_id, Some(session_id))?;
|
||||
|
||||
tracing::info!(topic_id = %topic_id, message_count = messages.len(), "Sending topic history");
|
||||
|
||||
|
||||
@ -1412,66 +1412,49 @@ impl SessionStore {
|
||||
load_messages_after(&conn, session_id, 0)
|
||||
}
|
||||
|
||||
pub fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
pub fn load_messages_for_topic(
|
||||
&self,
|
||||
topic_id: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
let conn = self.pool.get()?;
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
||||
FROM messages
|
||||
WHERE topic_id = ?1
|
||||
ORDER BY seq ASC
|
||||
",
|
||||
)?;
|
||||
|
||||
let rows = stmt.query_map(params![topic_id], |row| {
|
||||
let media_refs_json: String = row.get(5)?;
|
||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
media_refs_json.len(),
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_calls_json: Option<String> = row.get(9)?;
|
||||
let tool_calls = tool_calls_json
|
||||
.as_deref()
|
||||
.map(serde_json::from_str)
|
||||
.transpose()
|
||||
.map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
9,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(ChatMessage {
|
||||
id: row.get(0)?,
|
||||
role: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
system_context: row.get(3)?,
|
||||
reasoning_content: row.get(4)?,
|
||||
media_refs,
|
||||
timestamp: row.get(6)?,
|
||||
tool_call_id: row.get(7)?,
|
||||
tool_name: row.get(8)?,
|
||||
tool_state: None,
|
||||
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
|
||||
tool_calls,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
for row in rows {
|
||||
messages.push(row?);
|
||||
if let Some(sid) = session_id {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
||||
FROM messages
|
||||
WHERE topic_id = ?1 AND session_id = ?2
|
||||
ORDER BY seq ASC
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![topic_id, sid], map_chat_message_row)?;
|
||||
let mut messages = Vec::new();
|
||||
for row in rows {
|
||||
messages.push(row?);
|
||||
}
|
||||
Ok(messages)
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
||||
FROM messages
|
||||
WHERE topic_id = ?1
|
||||
ORDER BY seq ASC
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![topic_id], map_chat_message_row)?;
|
||||
let mut messages = Vec::new();
|
||||
for row in rows {
|
||||
messages.push(row?);
|
||||
}
|
||||
Ok(messages)
|
||||
}
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
/// 获取指定话题的消息数量(动态计算,确保准确)
|
||||
pub fn get_topic_message_count(&self, topic_id: &str) -> Result<usize, StorageError> {
|
||||
self.load_messages_for_topic(topic_id).map(|msgs| msgs.len())
|
||||
self.load_messages_for_topic(topic_id, None).map(|msgs| msgs.len())
|
||||
}
|
||||
|
||||
pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
@ -1673,6 +1656,45 @@ fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEven
|
||||
})
|
||||
}
|
||||
|
||||
fn map_chat_message_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<ChatMessage> {
|
||||
let media_refs_json: String = row.get(5)?;
|
||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
media_refs_json.len(),
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_calls_json: Option<String> = row.get(9)?;
|
||||
let tool_calls = tool_calls_json
|
||||
.as_deref()
|
||||
.map(serde_json::from_str)
|
||||
.transpose()
|
||||
.map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
9,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(ChatMessage {
|
||||
id: row.get(0)?,
|
||||
role: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
system_context: row.get(3)?,
|
||||
reasoning_content: row.get(4)?,
|
||||
media_refs,
|
||||
timestamp: row.get(6)?,
|
||||
tool_call_id: row.get(7)?,
|
||||
tool_name: row.get(8)?,
|
||||
tool_state: None,
|
||||
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
|
||||
tool_calls,
|
||||
})
|
||||
}
|
||||
|
||||
fn map_memory_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<MemoryRecord> {
|
||||
Ok(MemoryRecord {
|
||||
id: row.get(0)?,
|
||||
|
||||
@ -22,7 +22,11 @@ pub trait ConversationRepository: Send + Sync + 'static {
|
||||
|
||||
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
|
||||
|
||||
fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
|
||||
fn load_messages_for_topic(
|
||||
&self,
|
||||
topic_id: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<Vec<ChatMessage>, StorageError>;
|
||||
|
||||
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
|
||||
|
||||
@ -180,8 +184,12 @@ impl ConversationRepository for super::SessionStore {
|
||||
super::SessionStore::load_messages(self, session_id)
|
||||
}
|
||||
|
||||
fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
super::SessionStore::load_messages_for_topic(self, topic_id)
|
||||
fn load_messages_for_topic(
|
||||
&self,
|
||||
topic_id: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
super::SessionStore::load_messages_for_topic(self, topic_id, session_id)
|
||||
}
|
||||
|
||||
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user