diff --git a/src/gateway/session_history.rs b/src/gateway/session_history.rs index f6c4dc5..1ca4dae 100644 --- a/src/gateway/session_history.rs +++ b/src/gateway/session_history.rs @@ -145,23 +145,6 @@ impl SessionHistory { .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err))) } - pub(crate) fn append_persisted_message( - &mut self, - chat_id: &str, - message: ChatMessage, - ) -> Result<(), AgentError> { - let session_id = self.persistent_session_id(chat_id); - // 获取当前话题 ID,用于关联消息 - let topic_id = self.chat_topic_ids.get(chat_id).map(|s| s.as_str()); - self.conversations - .append_message_with_topic(&session_id, topic_id, &message) - .map_err(|err| { - AgentError::Other(format!("append message persistence error: {}", err)) - })?; - self.add_message(chat_id, message); - Ok(()) - } - pub(crate) fn append_persisted_messages( &mut self, chat_id: &str, @@ -170,8 +153,21 @@ impl SessionHistory { where I: IntoIterator, { + let messages: Vec = messages.into_iter().collect(); + if messages.is_empty() { + return Ok(()); + } + + let session_id = self.persistent_session_id(chat_id); + let topic_id = self.chat_topic_ids.get(chat_id).map(|s| s.as_str()); + self.conversations + .append_messages_batch(&session_id, topic_id, &messages) + .map_err(|err| { + AgentError::Other(format!("batch append messages error: {}", err)) + })?; + for message in messages { - self.append_persisted_message(chat_id, message)?; + self.add_message(chat_id, message); } Ok(()) } @@ -184,16 +180,17 @@ impl SessionHistory { topic_id: &str, messages: &[ChatMessage], ) -> Result<(), AgentError> { - let session_id = self.persistent_session_id(chat_id); - - for message in messages { - self.conversations - .append_message_with_topic(&session_id, Some(topic_id), message) - .map_err(|err| { - AgentError::Other(format!("append message to topic error: {}", err)) - })?; + if messages.is_empty() { + return Ok(()); } + let session_id = self.persistent_session_id(chat_id); + self.conversations + .append_messages_batch(&session_id, Some(topic_id), messages) + .map_err(|err| { + AgentError::Other(format!("batch append to topic error: {}", err)) + })?; + Ok(()) } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 0009b8c..32e85bb 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -617,6 +617,92 @@ impl SessionStore { Ok(()) } + pub fn append_messages_batch( + &self, + session_id: &str, + topic_id: Option<&str>, + messages: &[ChatMessage], + ) -> Result<(), StorageError> { + if messages.is_empty() { + return Ok(()); + } + + let conn = self.conn.lock().expect("session db mutex poisoned"); + let tx = conn.unchecked_transaction()?; + + let mut seq: i64 = tx.query_row( + "SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE session_id = ?1", + params![session_id], + |row| row.get(0), + )?; + + for message in messages { + let media_refs_json = serde_json::to_string(&message.media_refs)?; + let tool_calls_json = message + .tool_calls + .as_ref() + .map(serde_json::to_string) + .transpose()?; + tx.execute( + " + INSERT INTO messages ( + id, session_id, topic_id, seq, role, content, + system_context, reasoning_content, media_refs_json, + tool_call_id, tool_name, tool_calls_json, created_at + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13) + ", + params![ + message.id, + session_id, + topic_id, + seq, + message.role, + message.content, + message.system_context, + message.reasoning_content, + media_refs_json, + message.tool_call_id, + message.tool_name, + tool_calls_json, + message.timestamp, + ], + )?; + seq += 1; + } + + let now = current_timestamp(); + let user_msg_count: i64 = messages + .iter() + .filter(|m| m.role == "user") + .count() + .try_into() + .unwrap_or(0); + let msg_count: i64 = messages.len() as i64; + + tx.execute( + " + UPDATE sessions + SET message_count = message_count + ?2, + user_turn_count = user_turn_count + ?3, + updated_at = ?4, + last_active_at = ?4, + archived_at = NULL + WHERE id = ?1 AND deleted_at IS NULL + ", + params![session_id, msg_count, user_msg_count, now], + )?; + + if let Some(tid) = topic_id { + tx.execute( + "UPDATE topics SET message_count = message_count + ?2, last_active_at = ?3 WHERE id = ?1", + params![tid, msg_count, now], + )?; + } + + tx.commit()?; + Ok(()) + } + pub fn compact_active_history( &self, session_id: &str, diff --git a/src/storage/ports.rs b/src/storage/ports.rs index acced9a..35fbb87 100644 --- a/src/storage/ports.rs +++ b/src/storage/ports.rs @@ -33,6 +33,13 @@ pub trait ConversationRepository: Send + Sync + 'static { message: &ChatMessage, ) -> Result<(), StorageError>; + fn append_messages_batch( + &self, + session_id: &str, + topic_id: Option<&str>, + messages: &[ChatMessage], + ) -> Result<(), StorageError>; + fn clear_messages(&self, session_id: &str) -> Result<(), StorageError>; fn compact_active_history( @@ -178,6 +185,15 @@ impl ConversationRepository for super::SessionStore { super::SessionStore::append_message_with_topic(self, session_id, topic_id, message) } + fn append_messages_batch( + &self, + session_id: &str, + topic_id: Option<&str>, + messages: &[ChatMessage], + ) -> Result<(), StorageError> { + super::SessionStore::append_messages_batch(self, session_id, topic_id, messages) + } + fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> { super::SessionStore::clear_messages(self, session_id) } diff --git a/src/tools/task/runtime.rs b/src/tools/task/runtime.rs index bc52531..3c2ccc5 100644 --- a/src/tools/task/runtime.rs +++ b/src/tools/task/runtime.rs @@ -274,12 +274,10 @@ impl DefaultSubAgentRuntime { match result { Ok(Ok(process_result)) => { - // 保存子智能体产生的所有消息到数据库 - for message in &process_result.emitted_messages { - if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) { - tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message"); - } - } + // 保存子智能体产生的所有消息到数据库(批量单事务) + self.conversation_repository + .append_messages_batch(&session.session_id, None, &process_result.emitted_messages) + .map_err(TaskError::RepositoryError)?; let final_message = process_result.final_response; Ok(TaskToolResult { @@ -326,12 +324,10 @@ impl DefaultSubAgentRuntime { match result { Ok(Ok(process_result)) => { - // 保存子智能体产生的所有消息到数据库 - for message in &process_result.emitted_messages { - if let Err(e) = self.conversation_repository.append_message(&session.session_id, message) { - tracing::warn!(error = %e, session_id = %session.session_id, "Failed to append subagent message"); - } - } + // 保存子智能体产生的所有消息到数据库(批量单事务) + self.conversation_repository + .append_messages_batch(&session.session_id, None, &process_result.emitted_messages) + .map_err(TaskError::RepositoryError)?; let final_message = process_result.final_response; Ok(TaskToolResult { diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index fa0cdf1..0311948 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -121,6 +121,7 @@ fn test_tool_call_outbound_serialization() { arguments: serde_json::json!({"expression": "1 + 1"}), content: "调用工具: calculator".to_string(), role: "assistant".to_string(), + subagent_task_id: None, }; let json = serde_json::to_string(&msg).unwrap(); @@ -152,6 +153,7 @@ fn test_tool_result_outbound_serialization() { tool_name: "calculator".to_string(), content: "工具结果: calculator\n\n2".to_string(), role: "tool".to_string(), + subagent_task_id: None, }; let json = serde_json::to_string(&msg).unwrap();