feat: 更新 load_messages_for_topic 方法,支持按 session_id 过滤消息
This commit is contained in:
parent
24bbd5f8c9
commit
6f8c4a7ce8
@ -323,7 +323,7 @@ impl InboundProcessor {
|
|||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// 从 DB 查询该 topic 的第一条用户消息作为描述生成的依据
|
// 从 DB 查询该 topic 的第一条用户消息作为描述生成的依据
|
||||||
let first_user_message = store_clone
|
let first_user_message = store_clone
|
||||||
.load_messages_for_topic(&topic_id_clone)
|
.load_messages_for_topic(&topic_id_clone, None)
|
||||||
.ok()
|
.ok()
|
||||||
.and_then(|msgs| msgs.into_iter().find(|m| m.role == "user"))
|
.and_then(|msgs| msgs.into_iter().find(|m| m.role == "user"))
|
||||||
.map(|m| m.content);
|
.map(|m| m.content);
|
||||||
|
|||||||
@ -274,10 +274,11 @@ impl Session {
|
|||||||
// 先设置当前话题(set_history 需要这个)
|
// 先设置当前话题(set_history 需要这个)
|
||||||
self.history.set_chat_topic(chat_id, topic_id.to_string());
|
self.history.set_chat_topic(chat_id, topic_id.to_string());
|
||||||
|
|
||||||
// 加载新话题的历史
|
// 加载新话题的历史(按 session_id 过滤,排除子智能体消息)
|
||||||
|
let session_id = self.persistent_session_id(chat_id);
|
||||||
let messages = self
|
let messages = self
|
||||||
.store
|
.store
|
||||||
.load_messages_for_topic(topic_id)
|
.load_messages_for_topic(topic_id, Some(&session_id))
|
||||||
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||||||
|
|
||||||
self.history.set_history(chat_id, messages);
|
self.history.set_history(chat_id, messages);
|
||||||
@ -464,11 +465,12 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
// 如果当前有 topic,加载该 topic 的消息
|
// 如果当前有 topic,加载该 topic 的消息(按 session_id 过滤,排除子智能体消息)
|
||||||
if let Some(topic_id) = self.history.chat_topic(chat_id) {
|
if let Some(topic_id) = self.history.chat_topic(chat_id) {
|
||||||
|
let session_id = self.persistent_session_id(chat_id);
|
||||||
let messages = self
|
let messages = self
|
||||||
.store
|
.store
|
||||||
.load_messages_for_topic(topic_id)
|
.load_messages_for_topic(topic_id, Some(&session_id))
|
||||||
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("load topic messages error: {}", e)))?;
|
||||||
self.history.set_history(chat_id, messages);
|
self.history.set_history(chat_id, messages);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -66,8 +66,9 @@ impl SessionHistory {
|
|||||||
|
|
||||||
// 如果提供了 topic_id,按 topic 加载;否则按 session 加载
|
// 如果提供了 topic_id,按 topic 加载;否则按 session 加载
|
||||||
let history = if let Some(tid) = topic_id {
|
let history = if let Some(tid) = topic_id {
|
||||||
|
let sid = self.persistent_session_id(chat_id);
|
||||||
self.conversations
|
self.conversations
|
||||||
.load_messages_for_topic(tid)
|
.load_messages_for_topic(tid, Some(&sid))
|
||||||
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?
|
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?
|
||||||
} else {
|
} else {
|
||||||
self.conversations
|
self.conversations
|
||||||
|
|||||||
@ -637,12 +637,12 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
|
|||||||
/// 加载并发送话题历史消息
|
/// 加载并发送话题历史消息
|
||||||
async fn send_topic_history(
|
async fn send_topic_history(
|
||||||
store: &Arc<crate::storage::SessionStore>,
|
store: &Arc<crate::storage::SessionStore>,
|
||||||
_session_id: &str,
|
session_id: &str,
|
||||||
topic_id: &str,
|
topic_id: &str,
|
||||||
sender: &mpsc::Sender<WsOutbound>,
|
sender: &mpsc::Sender<WsOutbound>,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// 加载话题消息
|
// 加载话题消息,按 session_id 过滤,避免混入子智能体消息
|
||||||
let messages = store.load_messages_for_topic(topic_id)?;
|
let messages = store.load_messages_for_topic(topic_id, Some(session_id))?;
|
||||||
|
|
||||||
tracing::info!(topic_id = %topic_id, message_count = messages.len(), "Sending topic history");
|
tracing::info!(topic_id = %topic_id, message_count = messages.len(), "Sending topic history");
|
||||||
|
|
||||||
|
|||||||
@ -1412,66 +1412,49 @@ impl SessionStore {
|
|||||||
load_messages_after(&conn, session_id, 0)
|
load_messages_after(&conn, session_id, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
pub fn load_messages_for_topic(
|
||||||
|
&self,
|
||||||
|
topic_id: &str,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> Result<Vec<ChatMessage>, StorageError> {
|
||||||
let conn = self.pool.get()?;
|
let conn = self.pool.get()?;
|
||||||
let mut stmt = conn.prepare(
|
|
||||||
"
|
|
||||||
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
|
||||||
FROM messages
|
|
||||||
WHERE topic_id = ?1
|
|
||||||
ORDER BY seq ASC
|
|
||||||
",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let rows = stmt.query_map(params![topic_id], |row| {
|
if let Some(sid) = session_id {
|
||||||
let media_refs_json: String = row.get(5)?;
|
let mut stmt = conn.prepare(
|
||||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
"
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
||||||
media_refs_json.len(),
|
FROM messages
|
||||||
rusqlite::types::Type::Text,
|
WHERE topic_id = ?1 AND session_id = ?2
|
||||||
Box::new(err),
|
ORDER BY seq ASC
|
||||||
)
|
",
|
||||||
})?;
|
)?;
|
||||||
|
let rows = stmt.query_map(params![topic_id, sid], map_chat_message_row)?;
|
||||||
let tool_calls_json: Option<String> = row.get(9)?;
|
let mut messages = Vec::new();
|
||||||
let tool_calls = tool_calls_json
|
for row in rows {
|
||||||
.as_deref()
|
messages.push(row?);
|
||||||
.map(serde_json::from_str)
|
}
|
||||||
.transpose()
|
Ok(messages)
|
||||||
.map_err(|err| {
|
} else {
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
let mut stmt = conn.prepare(
|
||||||
9,
|
"
|
||||||
rusqlite::types::Type::Text,
|
SELECT id, role, content, system_context, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json, tool_duration_ms
|
||||||
Box::new(err),
|
FROM messages
|
||||||
)
|
WHERE topic_id = ?1
|
||||||
})?;
|
ORDER BY seq ASC
|
||||||
|
",
|
||||||
Ok(ChatMessage {
|
)?;
|
||||||
id: row.get(0)?,
|
let rows = stmt.query_map(params![topic_id], map_chat_message_row)?;
|
||||||
role: row.get(1)?,
|
let mut messages = Vec::new();
|
||||||
content: row.get(2)?,
|
for row in rows {
|
||||||
system_context: row.get(3)?,
|
messages.push(row?);
|
||||||
reasoning_content: row.get(4)?,
|
}
|
||||||
media_refs,
|
Ok(messages)
|
||||||
timestamp: row.get(6)?,
|
|
||||||
tool_call_id: row.get(7)?,
|
|
||||||
tool_name: row.get(8)?,
|
|
||||||
tool_state: None,
|
|
||||||
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
|
|
||||||
tool_calls,
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
|
||||||
for row in rows {
|
|
||||||
messages.push(row?);
|
|
||||||
}
|
}
|
||||||
Ok(messages)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取指定话题的消息数量(动态计算,确保准确)
|
/// 获取指定话题的消息数量(动态计算,确保准确)
|
||||||
pub fn get_topic_message_count(&self, topic_id: &str) -> Result<usize, StorageError> {
|
pub fn get_topic_message_count(&self, topic_id: &str) -> Result<usize, StorageError> {
|
||||||
self.load_messages_for_topic(topic_id).map(|msgs| msgs.len())
|
self.load_messages_for_topic(topic_id, None).map(|msgs| msgs.len())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
pub fn load_all_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||||
@ -1673,6 +1656,45 @@ fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEven
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn map_chat_message_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<ChatMessage> {
|
||||||
|
let media_refs_json: String = row.get(5)?;
|
||||||
|
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||||
|
rusqlite::Error::FromSqlConversionFailure(
|
||||||
|
media_refs_json.len(),
|
||||||
|
rusqlite::types::Type::Text,
|
||||||
|
Box::new(err),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let tool_calls_json: Option<String> = row.get(9)?;
|
||||||
|
let tool_calls = tool_calls_json
|
||||||
|
.as_deref()
|
||||||
|
.map(serde_json::from_str)
|
||||||
|
.transpose()
|
||||||
|
.map_err(|err| {
|
||||||
|
rusqlite::Error::FromSqlConversionFailure(
|
||||||
|
9,
|
||||||
|
rusqlite::types::Type::Text,
|
||||||
|
Box::new(err),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(ChatMessage {
|
||||||
|
id: row.get(0)?,
|
||||||
|
role: row.get(1)?,
|
||||||
|
content: row.get(2)?,
|
||||||
|
system_context: row.get(3)?,
|
||||||
|
reasoning_content: row.get(4)?,
|
||||||
|
media_refs,
|
||||||
|
timestamp: row.get(6)?,
|
||||||
|
tool_call_id: row.get(7)?,
|
||||||
|
tool_name: row.get(8)?,
|
||||||
|
tool_state: None,
|
||||||
|
tool_duration_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
|
||||||
|
tool_calls,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn map_memory_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<MemoryRecord> {
|
fn map_memory_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<MemoryRecord> {
|
||||||
Ok(MemoryRecord {
|
Ok(MemoryRecord {
|
||||||
id: row.get(0)?,
|
id: row.get(0)?,
|
||||||
|
|||||||
@ -22,7 +22,11 @@ pub trait ConversationRepository: Send + Sync + 'static {
|
|||||||
|
|
||||||
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
|
fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
|
||||||
|
|
||||||
fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError>;
|
fn load_messages_for_topic(
|
||||||
|
&self,
|
||||||
|
topic_id: &str,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> Result<Vec<ChatMessage>, StorageError>;
|
||||||
|
|
||||||
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
|
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError>;
|
||||||
|
|
||||||
@ -180,8 +184,12 @@ impl ConversationRepository for super::SessionStore {
|
|||||||
super::SessionStore::load_messages(self, session_id)
|
super::SessionStore::load_messages(self, session_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
fn load_messages_for_topic(
|
||||||
super::SessionStore::load_messages_for_topic(self, topic_id)
|
&self,
|
||||||
|
topic_id: &str,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> Result<Vec<ChatMessage>, StorageError> {
|
||||||
|
super::SessionStore::load_messages_for_topic(self, topic_id, session_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user