fix(session): replace Mutex with task-local storage to prevent reentrant deadlock in send_message

This commit is contained in:
xiaoski 2026-05-13 08:48:31 +08:00
parent 1e69fa3bd1
commit 3d42f22f83

View File

@ -7,6 +7,10 @@ use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceK
use crate::storage::{Storage, StorageError}; use crate::storage::{Storage, StorageError};
use std::sync::Arc as StdArc; use std::sync::Arc as StdArc;
tokio::task_local! {
static CURRENT_SOURCE_SESSION: Option<String>;
}
/// Result of handling a message - either an AI response or a command output /// Result of handling a message - either an AI response or a command output
pub enum HandleResult { pub enum HandleResult {
/// AI response to be sent as AssistantResponse /// AI response to be sent as AssistantResponse
@ -717,7 +721,6 @@ pub struct SessionManager {
skills_loader: Arc<SkillsLoader>, skills_loader: Arc<SkillsLoader>,
storage: Arc<Storage>, storage: Arc<Storage>,
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
current_source_session: Arc<Mutex<Option<String>>>,
memory_manager: Arc<crate::memory::MemoryManager>, memory_manager: Arc<crate::memory::MemoryManager>,
} }
@ -822,7 +825,6 @@ impl SessionManager {
skills_loader, skills_loader,
storage, storage,
bus, bus,
current_source_session: Arc::new(Mutex::new(None)),
memory_manager, memory_manager,
}) })
} }
@ -1273,10 +1275,10 @@ impl SessionManager {
media: Vec<MediaItem>, media: Vec<MediaItem>,
) -> Result<HandleResult, AgentError> { ) -> Result<HandleResult, AgentError> {
let unified_id = self.resolve_dialog_id(channel, chat_id).await?; let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
*self.current_source_session.lock().await = Some(unified_id.to_string());
tracing::debug!(unified_id = %unified_id, "handle_message resolved unified_id"); tracing::debug!(unified_id = %unified_id, "handle_message resolved unified_id");
let session = self.get_or_create_session(&unified_id).await?; let session = self.get_or_create_session(&unified_id).await?;
CURRENT_SOURCE_SESSION.scope(Some(unified_id.to_string()), async {
// Check for slash command // Check for slash command
if let Some((cmd_name, cmd_args)) = parse_slash_command(content) { if let Some((cmd_name, cmd_args)) = parse_slash_command(content) {
let result = self.execute_slash_command( let result = self.execute_slash_command(
@ -1287,16 +1289,10 @@ impl SessionManager {
Some(&unified_id), Some(&unified_id),
).await; ).await;
match result { return match result {
Ok((_new_session_id, response)) => { Ok((_new_session_id, response)) => Ok(HandleResult::CommandOutput(response)),
*self.current_source_session.lock().await = None; Err(e) => Ok(HandleResult::CommandOutput(e.to_string())),
return Ok(HandleResult::CommandOutput(response)); };
}
Err(e) => {
*self.current_source_session.lock().await = None;
return Ok(HandleResult::CommandOutput(e.to_string()));
}
}
} }
// Normal message handling through LLM // Normal message handling through LLM
@ -1432,9 +1428,8 @@ impl SessionManager {
"Agent response received" "Agent response received"
); );
*self.current_source_session.lock().await = None;
Ok(HandleResult::AgentResponse(response)) Ok(HandleResult::AgentResponse(response))
}).await
} }
/// Handle a message triggered by a scheduled cron job. /// Handle a message triggered by a scheduled cron job.
@ -1453,11 +1448,11 @@ impl SessionManager {
use crate::bus::{MessageSource, SourceKind}; use crate::bus::{MessageSource, SourceKind};
let unified_id = self.resolve_dialog_id(channel, chat_id).await?; let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
*self.current_source_session.lock().await = Some(unified_id.to_string());
tracing::debug!(unified_id = %unified_id, job_id = %job_id, "handle_cron_message resolved"); tracing::debug!(unified_id = %unified_id, job_id = %job_id, "handle_cron_message resolved");
let session = self.get_or_create_session(&unified_id).await?; let session = self.get_or_create_session(&unified_id).await?;
CURRENT_SOURCE_SESSION.scope(Some(unified_id.to_string()), async {
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel(); let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
{ {
@ -1569,9 +1564,8 @@ impl SessionManager {
"Cron agent response received" "Cron agent response received"
); );
*self.current_source_session.lock().await = None;
Ok(HandleResult::AgentResponse(response)) Ok(HandleResult::AgentResponse(response))
}).await
} }
pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> { pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
@ -1603,7 +1597,7 @@ impl OutboundMessenger for SessionManager {
) -> Result<(), String> { ) -> Result<(), String> {
// Fill origin from current source session if not provided // Fill origin from current source session if not provided
if source.from_session.is_none() { if source.from_session.is_none() {
source.from_session = self.current_source_session.lock().await.clone(); source.from_session = CURRENT_SOURCE_SESSION.try_with(|v| v.clone()).ok().flatten();
} }
let (target_sid, session) = if let Some(did) = dialog_id { let (target_sid, session) = if let Some(did) = dialog_id {