fix(session): replace Mutex with task-local storage to prevent reentrant deadlock in send_message

This commit is contained in:
xiaoski 2026-05-13 08:48:31 +08:00
parent 1e69fa3bd1
commit 3d42f22f83

View File

@ -7,6 +7,10 @@ use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceK
use crate::storage::{Storage, StorageError}; use crate::storage::{Storage, StorageError};
use std::sync::Arc as StdArc; use std::sync::Arc as StdArc;
tokio::task_local! {
static CURRENT_SOURCE_SESSION: Option<String>;
}
/// Result of handling a message - either an AI response or a command output /// Result of handling a message - either an AI response or a command output
pub enum HandleResult { pub enum HandleResult {
/// AI response to be sent as AssistantResponse /// AI response to be sent as AssistantResponse
@ -717,7 +721,6 @@ pub struct SessionManager {
skills_loader: Arc<SkillsLoader>, skills_loader: Arc<SkillsLoader>,
storage: Arc<Storage>, storage: Arc<Storage>,
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
current_source_session: Arc<Mutex<Option<String>>>,
memory_manager: Arc<crate::memory::MemoryManager>, memory_manager: Arc<crate::memory::MemoryManager>,
} }
@ -822,7 +825,6 @@ impl SessionManager {
skills_loader, skills_loader,
storage, storage,
bus, bus,
current_source_session: Arc::new(Mutex::new(None)),
memory_manager, memory_manager,
}) })
} }
@ -1273,168 +1275,161 @@ impl SessionManager {
media: Vec<MediaItem>, media: Vec<MediaItem>,
) -> Result<HandleResult, AgentError> { ) -> Result<HandleResult, AgentError> {
let unified_id = self.resolve_dialog_id(channel, chat_id).await?; let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
*self.current_source_session.lock().await = Some(unified_id.to_string());
tracing::debug!(unified_id = %unified_id, "handle_message resolved unified_id"); tracing::debug!(unified_id = %unified_id, "handle_message resolved unified_id");
let session = self.get_or_create_session(&unified_id).await?; let session = self.get_or_create_session(&unified_id).await?;
// Check for slash command CURRENT_SOURCE_SESSION.scope(Some(unified_id.to_string()), async {
if let Some((cmd_name, cmd_args)) = parse_slash_command(content) { // Check for slash command
let result = self.execute_slash_command( if let Some((cmd_name, cmd_args)) = parse_slash_command(content) {
cmd_name, let result = self.execute_slash_command(
if cmd_args.is_empty() { None } else { Some(cmd_args) }, cmd_name,
channel, if cmd_args.is_empty() { None } else { Some(cmd_args) },
chat_id, channel,
Some(&unified_id), chat_id,
).await; Some(&unified_id),
).await;
match result { return match result {
Ok((_new_session_id, response)) => { Ok((_new_session_id, response)) => Ok(HandleResult::CommandOutput(response)),
*self.current_source_session.lock().await = None; Err(e) => Ok(HandleResult::CommandOutput(e.to_string())),
return Ok(HandleResult::CommandOutput(response)); };
}
Err(e) => {
*self.current_source_session.lock().await = None;
return Ok(HandleResult::CommandOutput(e.to_string()));
}
}
}
// Normal message handling through LLM
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
// Spawn notification publisher — sends immediately when tools are detected
{
let bus = self.bus.clone();
let ch = channel.to_string();
let cid = chat_id.to_string();
tokio::spawn(async move {
while let Some(notif) = notify_rx.recv().await {
let mut metadata = HashMap::new();
metadata.insert("_type".to_string(), "notification".to_string());
let outbound = OutboundMessage {
channel: ch.clone(),
chat_id: cid.clone(),
content: notif,
reply_to: None,
media: vec![],
metadata,
};
let _ = bus.publish_outbound(outbound).await;
}
});
}
let response: String = {
let mut session_guard = session.lock().await;
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
#[cfg(debug_assertions)]
if !media_refs.is_empty() {
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
} }
let user_message = session_guard.create_user_message(content, media_refs); // Normal message handling through LLM
session_guard.add_message(user_message, true).await let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
let history = session_guard.get_history().to_vec(); // Spawn notification publisher — sends immediately when tools are detected
{
// Build skills prompt let bus = self.bus.clone();
let skills_prompt = self.skills_loader.build_skills_prompt(); let ch = channel.to_string();
let cid = chat_id.to_string();
// Fetch memory context tokio::spawn(async move {
let memory_context = match self.memory_manager.recall(content, 5, Some(crate::memory::MemoryCategory::Knowledge), None).await { while let Some(notif) = notify_rx.recv().await {
Ok(entries) if !entries.is_empty() => { let mut metadata = HashMap::new();
Some(entries.iter() metadata.insert("_type".to_string(), "notification".to_string());
.map(|e| format!("- {}: {}", e.key, e.content)) let outbound = OutboundMessage {
.collect::<Vec<_>>() channel: ch.clone(),
.join("\n")) chat_id: cid.clone(),
} content: notif,
Err(e) => { reply_to: None,
tracing::warn!(error = %e, "Failed to fetch memory context"); media: vec![],
None metadata,
} };
_ => None, let _ = bus.publish_outbound(outbound).await;
};
// Build combined system prompt and inject at position 0 AFTER compression.
// This ensures AgentLoop.process() sees a system message without it participating
// in context compression (system prompt is dynamic and should not be persisted).
let system_prompt = session_guard.build_system_prompt(&skills_prompt, memory_context.as_deref());
let result = session_guard.compressor
.compress_if_needed(history)
.await?;
if result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
}
let mut history = result.history;
history.insert(0, ChatMessage::system(system_prompt.clone()));
// Persist consolidation state
let now = chrono::Utc::now().timestamp_millis();
session_guard.last_consolidated_at = Some(now);
if let Err(e) = session_guard.persist_session_meta().await {
tracing::warn!(error = %e, "Failed to persist consolidation timestamp");
}
let agent = session_guard.create_agent_with_notify(notify_tx)?;
// Try LLM call; on context overflow, re-compress with tighter window and retry once.
let result = match agent.process(history).await {
Ok(r) => r,
Err(AgentError::LlmError(ref msg))
if is_context_overflow_error(msg) =>
{
let new_window = crate::agent::ContextCompressor::parse_context_limit_from_error(msg)
.unwrap_or(session_guard.compressor_threshold());
tracing::warn!(
new_window,
error = %msg,
"Context overflow in handle_message — retrying with tighter window"
);
session_guard.compressor.set_context_window(new_window);
let raw = session_guard.get_history().to_vec();
let retry_result = session_guard.compressor.compress_if_needed(raw).await?;
if retry_result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
if let Err(e) = session_guard.persist_session_meta().await {
tracing::warn!(error = %e, "Failed to persist compression marker on retry");
}
} }
let mut retry = retry_result.history; });
retry.insert(0, ChatMessage::system(system_prompt));
agent.process(retry).await?
}
Err(e) => return Err(e),
};
for msg in result.emitted_messages {
session_guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
} }
// Check if we need to generate a title (after 10 user messages) let response: String = {
if session_guard.should_generate_title() let mut session_guard = session.lock().await;
&& let Err(e) = session_guard.generate_title().await {
tracing::warn!("failed to generate title: {}", e); let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
#[cfg(debug_assertions)]
if !media_refs.is_empty() {
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
} }
result.final_response.content let user_message = session_guard.create_user_message(content, media_refs);
}; session_guard.add_message(user_message, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
#[cfg(debug_assertions)] let history = session_guard.get_history().to_vec();
tracing::debug!(
channel = %channel,
chat_id = %chat_id,
response_len = %response.len(),
"Agent response received"
);
*self.current_source_session.lock().await = None; // Build skills prompt
let skills_prompt = self.skills_loader.build_skills_prompt();
Ok(HandleResult::AgentResponse(response)) // Fetch memory context
let memory_context = match self.memory_manager.recall(content, 5, Some(crate::memory::MemoryCategory::Knowledge), None).await {
Ok(entries) if !entries.is_empty() => {
Some(entries.iter()
.map(|e| format!("- {}: {}", e.key, e.content))
.collect::<Vec<_>>()
.join("\n"))
}
Err(e) => {
tracing::warn!(error = %e, "Failed to fetch memory context");
None
}
_ => None,
};
// Build combined system prompt and inject at position 0 AFTER compression.
// This ensures AgentLoop.process() sees a system message without it participating
// in context compression (system prompt is dynamic and should not be persisted).
let system_prompt = session_guard.build_system_prompt(&skills_prompt, memory_context.as_deref());
let result = session_guard.compressor
.compress_if_needed(history)
.await?;
if result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
}
let mut history = result.history;
history.insert(0, ChatMessage::system(system_prompt.clone()));
// Persist consolidation state
let now = chrono::Utc::now().timestamp_millis();
session_guard.last_consolidated_at = Some(now);
if let Err(e) = session_guard.persist_session_meta().await {
tracing::warn!(error = %e, "Failed to persist consolidation timestamp");
}
let agent = session_guard.create_agent_with_notify(notify_tx)?;
// Try LLM call; on context overflow, re-compress with tighter window and retry once.
let result = match agent.process(history).await {
Ok(r) => r,
Err(AgentError::LlmError(ref msg))
if is_context_overflow_error(msg) =>
{
let new_window = crate::agent::ContextCompressor::parse_context_limit_from_error(msg)
.unwrap_or(session_guard.compressor_threshold());
tracing::warn!(
new_window,
error = %msg,
"Context overflow in handle_message — retrying with tighter window"
);
session_guard.compressor.set_context_window(new_window);
let raw = session_guard.get_history().to_vec();
let retry_result = session_guard.compressor.compress_if_needed(raw).await?;
if retry_result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
if let Err(e) = session_guard.persist_session_meta().await {
tracing::warn!(error = %e, "Failed to persist compression marker on retry");
}
}
let mut retry = retry_result.history;
retry.insert(0, ChatMessage::system(system_prompt));
agent.process(retry).await?
}
Err(e) => return Err(e),
};
for msg in result.emitted_messages {
session_guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
}
// Check if we need to generate a title (after 10 user messages)
if session_guard.should_generate_title()
&& let Err(e) = session_guard.generate_title().await {
tracing::warn!("failed to generate title: {}", e);
}
result.final_response.content
};
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel,
chat_id = %chat_id,
response_len = %response.len(),
"Agent response received"
);
Ok(HandleResult::AgentResponse(response))
}).await
} }
/// Handle a message triggered by a scheduled cron job. /// Handle a message triggered by a scheduled cron job.
@ -1453,125 +1448,124 @@ impl SessionManager {
use crate::bus::{MessageSource, SourceKind}; use crate::bus::{MessageSource, SourceKind};
let unified_id = self.resolve_dialog_id(channel, chat_id).await?; let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
*self.current_source_session.lock().await = Some(unified_id.to_string());
tracing::debug!(unified_id = %unified_id, job_id = %job_id, "handle_cron_message resolved"); tracing::debug!(unified_id = %unified_id, job_id = %job_id, "handle_cron_message resolved");
let session = self.get_or_create_session(&unified_id).await?; let session = self.get_or_create_session(&unified_id).await?;
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel(); CURRENT_SOURCE_SESSION.scope(Some(unified_id.to_string()), async {
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
{ {
use std::collections::HashMap; use std::collections::HashMap;
use crate::bus::OutboundMessage; use crate::bus::OutboundMessage;
let bus = self.bus.clone(); let bus = self.bus.clone();
let ch = channel.to_string(); let ch = channel.to_string();
let cid = chat_id.to_string(); let cid = chat_id.to_string();
tokio::spawn(async move { tokio::spawn(async move {
while let Some(notif) = notify_rx.recv().await { while let Some(notif) = notify_rx.recv().await {
let mut metadata = HashMap::new(); let mut metadata = HashMap::new();
metadata.insert("_type".to_string(), "notification".to_string()); metadata.insert("_type".to_string(), "notification".to_string());
let outbound = OutboundMessage { let outbound = OutboundMessage {
channel: ch.clone(), channel: ch.clone(),
chat_id: cid.clone(), chat_id: cid.clone(),
content: notif, content: notif,
reply_to: None, reply_to: None,
media: vec![], media: vec![],
metadata, metadata,
}; };
let _ = bus.publish_outbound(outbound).await; let _ = bus.publish_outbound(outbound).await;
} }
}); });
}
let response: String = {
let mut session_guard = session.lock().await;
let source = MessageSource {
kind: SourceKind::ExternalTrigger,
from_channel: Some(channel.to_string()),
from_session: None,
from_user_id: None,
system_name: Some(job_name.to_string()),
task_id: Some(job_id.to_string()),
};
let user_message = session_guard.create_user_message_with_source(prompt, vec![], source);
session_guard.add_message(user_message, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
let history = session_guard.get_history().to_vec();
let skills_prompt = self.skills_loader.build_skills_prompt();
let system_prompt = session_guard.build_system_prompt(&skills_prompt, None);
let cron_context = format!(
"\n\n## 定时任务执行\n\n\
{}({})\n\
: {}:{}\n\n\
\n\
- \n\
- \n\
- \n\
- 使 send_message \n\
- ",
job_name, job_id, channel, chat_id
);
let full_system_prompt = format!("{}{}", system_prompt, cron_context);
// Inject system prompt AFTER compression so it doesn't participate
// in context compression (system prompt is dynamic and should not be persisted).
let mut history = session_guard.compressor
.compress_if_needed(history)
.await?
.history;
history.insert(0, ChatMessage::system(full_system_prompt));
let agent = session_guard.create_agent_with_notify(notify_tx)?;
let result = agent.process(history).await?;
for msg in result.emitted_messages {
session_guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
} }
if session_guard.should_generate_title() let response: String = {
&& let Err(e) = session_guard.generate_title().await { let mut session_guard = session.lock().await;
tracing::warn!("failed to generate title: {}", e);
let source = MessageSource {
kind: SourceKind::ExternalTrigger,
from_channel: Some(channel.to_string()),
from_session: None,
from_user_id: None,
system_name: Some(job_name.to_string()),
task_id: Some(job_id.to_string()),
};
let user_message = session_guard.create_user_message_with_source(prompt, vec![], source);
session_guard.add_message(user_message, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
let history = session_guard.get_history().to_vec();
let skills_prompt = self.skills_loader.build_skills_prompt();
let system_prompt = session_guard.build_system_prompt(&skills_prompt, None);
let cron_context = format!(
"\n\n## 定时任务执行\n\n\
{}({})\n\
: {}:{}\n\n\
\n\
- \n\
- \n\
- \n\
- 使 send_message \n\
- ",
job_name, job_id, channel, chat_id
);
let full_system_prompt = format!("{}{}", system_prompt, cron_context);
// Inject system prompt AFTER compression so it doesn't participate
// in context compression (system prompt is dynamic and should not be persisted).
let mut history = session_guard.compressor
.compress_if_needed(history)
.await?
.history;
history.insert(0, ChatMessage::system(full_system_prompt));
let agent = session_guard.create_agent_with_notify(notify_tx)?;
let result = agent.process(history).await?;
for msg in result.emitted_messages {
session_guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
} }
let raw_response = result.final_response.content; if session_guard.should_generate_title()
let prefix = format!( && let Err(e) = session_guard.generate_title().await {
"[message from cron:{}({})]\n", tracing::warn!("failed to generate title: {}", e);
job_name, job_id }
);
let prefixed_response = format!("{}{}", prefix, raw_response);
let source = MessageSource { let raw_response = result.final_response.content;
kind: SourceKind::CrossChannel, let prefix = format!(
from_channel: Some("cron".to_string()), "[message from cron:{}({})]\n",
from_session: Some(format!("{}:{}", job_name, job_id)), job_name, job_id
from_user_id: None, );
system_name: Some(job_name.to_string()), let prefixed_response = format!("{}{}", prefix, raw_response);
task_id: Some(job_id.to_string()),
let source = MessageSource {
kind: SourceKind::CrossChannel,
from_channel: Some("cron".to_string()),
from_session: Some(format!("{}:{}", job_name, job_id)),
from_user_id: None,
system_name: Some(job_name.to_string()),
task_id: Some(job_id.to_string()),
};
let msg = ChatMessage::assistant_with_source(prefixed_response.clone(), source);
session_guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
prefixed_response
}; };
let msg = ChatMessage::assistant_with_source(prefixed_response.clone(), source);
session_guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
prefixed_response #[cfg(debug_assertions)]
}; tracing::debug!(
channel = %channel,
chat_id = %chat_id,
job_id = %job_id,
response_len = %response.len(),
"Cron agent response received"
);
#[cfg(debug_assertions)] Ok(HandleResult::AgentResponse(response))
tracing::debug!( }).await
channel = %channel,
chat_id = %chat_id,
job_id = %job_id,
response_len = %response.len(),
"Cron agent response received"
);
*self.current_source_session.lock().await = None;
Ok(HandleResult::AgentResponse(response))
} }
pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> { pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
@ -1603,7 +1597,7 @@ impl OutboundMessenger for SessionManager {
) -> Result<(), String> { ) -> Result<(), String> {
// Fill origin from current source session if not provided // Fill origin from current source session if not provided
if source.from_session.is_none() { if source.from_session.is_none() {
source.from_session = self.current_source_session.lock().await.clone(); source.from_session = CURRENT_SOURCE_SESSION.try_with(|v| v.clone()).ok().flatten();
} }
let (target_sid, session) = if let Some(did) = dialog_id { let (target_sid, session) = if let Some(did) = dialog_id {