PicoBot/src/gateway/session_message_sender.rs

147 lines
4.5 KiB
Rust

use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use crate::bus::{MessageBus, OutboundMessage};
use crate::tools::{SessionMessageSender, SessionSendOutcome, SessionSendRequest, ToolContext};
/// [`SessionMessageSender`] implementation that delivers session messages by
/// publishing them to the outbound side of a shared [`MessageBus`].
pub(crate) struct BusSessionMessageSender {
    /// Bus that every outbound text/attachment message is published to.
    bus: Arc<MessageBus>,
}
impl BusSessionMessageSender {
pub(crate) fn new(bus: Arc<MessageBus>) -> Self {
Self { bus }
}
}
#[async_trait]
impl SessionMessageSender for BusSessionMessageSender {
    /// Publishes the text and attachments of `request` to the chat that
    /// `context` points at.
    ///
    /// Non-blank text is published first as its own message. The text is
    /// sent untrimmed; text that is empty or whitespace-only is skipped.
    /// Each attachment is then published as a separate message with empty
    /// content. Every message is published with no session id and empty
    /// metadata.
    ///
    /// # Errors
    ///
    /// Returns an error if `context` is missing `channel_name` or `chat_id`,
    /// or if any publish to the bus fails. Publishing stops at the first
    /// failure. Messages that were already published are not rolled back.
    async fn send_to_current_session(
        &self,
        context: &ToolContext,
        request: SessionSendRequest,
    ) -> anyhow::Result<SessionSendOutcome> {
        let channel_name = context
            .channel_name
            .as_deref()
            .ok_or_else(|| anyhow::anyhow!("missing channel_name in tool context"))?;
        let chat_id = context
            .chat_id
            .as_deref()
            .ok_or_else(|| anyhow::anyhow!("missing chat_id in tool context"))?;

        // Decide once whether there is text to send, so the outcome flag and
        // the publish below cannot disagree.
        let text = request.text.filter(|value| !value.trim().is_empty());
        let text_sent = text.is_some();
        let attachment_count = request.attachments.len();

        let metadata = HashMap::new();
        let mut published_messages = 0;

        if let Some(text) = text {
            let content_len = text.len();
            self.bus
                .publish_outbound(OutboundMessage::assistant(
                    channel_name.to_string(),
                    chat_id.to_string(),
                    None, // session_id
                    text,
                    None,
                    metadata.clone(),
                ))
                .await?;
            published_messages += 1;
            tracing::info!(
                channel = %channel_name,
                chat_id = %chat_id,
                content_len = content_len,
                "Published session text message to outbound bus"
            );
        }

        for attachment in request.attachments {
            // Copy the fields we log before `attachment` is moved into the message.
            let media_path = attachment.path.clone();
            let media_type = attachment.media_type.clone();
            let mut outbound = OutboundMessage::assistant(
                channel_name.to_string(),
                chat_id.to_string(),
                None, // session_id
                String::new(),
                None,
                metadata.clone(),
            );
            outbound.media = vec![attachment];
            self.bus.publish_outbound(outbound).await?;
            published_messages += 1;
            tracing::info!(
                channel = %channel_name,
                chat_id = %chat_id,
                media_type = %media_type,
                media_path = %media_path,
                "Published session attachment to outbound bus"
            );
        }

        Ok(SessionSendOutcome {
            published_messages,
            text_sent,
            attachment_count,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bus::MediaItem;

    const TEST_CHANNEL: &str = "test-channel";

    /// Builds a tool context targeting `TEST_CHANNEL` / `chat-1`.
    fn test_context() -> ToolContext {
        ToolContext {
            channel_name: Some(TEST_CHANNEL.to_string()),
            chat_id: Some("chat-1".to_string()),
            ..ToolContext::default()
        }
    }

    #[tokio::test]
    async fn bus_sender_publishes_text_then_attachment() {
        let bus = MessageBus::new(8);
        let sender = BusSessionMessageSender::new(bus.clone());
        let context = test_context();
        let outcome = sender
            .send_to_current_session(
                &context,
                SessionSendRequest {
                    text: Some("hello".to_string()),
                    // Use the temp directory so the path is valid on every platform.
                    attachments: vec![MediaItem::new(
                        &std::env::temp_dir().join("demo.png").display().to_string(),
                        "image",
                    )],
                },
            )
            .await
            .unwrap();
        assert_eq!(
            outcome,
            SessionSendOutcome {
                published_messages: 2,
                text_sent: true,
                attachment_count: 1,
            }
        );
        let first = bus.consume_outbound().await;
        assert_eq!(first.content, "hello");
        assert!(first.media.is_empty());
        let second = bus.consume_outbound().await;
        assert_eq!(second.content, "");
        assert_eq!(second.media.len(), 1);
        assert_eq!(second.media[0].media_type, "image");
    }

    #[tokio::test]
    async fn bus_sender_skips_whitespace_only_text() {
        let bus = MessageBus::new(8);
        let sender = BusSessionMessageSender::new(bus.clone());
        let outcome = sender
            .send_to_current_session(
                &test_context(),
                SessionSendRequest {
                    text: Some("  \n\t ".to_string()),
                    attachments: vec![],
                },
            )
            .await
            .unwrap();
        assert_eq!(
            outcome,
            SessionSendOutcome {
                published_messages: 0,
                text_sent: false,
                attachment_count: 0,
            }
        );
    }

    #[tokio::test]
    async fn bus_sender_rejects_context_without_target() {
        let bus = MessageBus::new(8);
        let sender = BusSessionMessageSender::new(bus.clone());

        let missing_channel = ToolContext {
            channel_name: None,
            ..test_context()
        };
        let err = sender
            .send_to_current_session(
                &missing_channel,
                SessionSendRequest {
                    text: Some("hello".to_string()),
                    attachments: vec![],
                },
            )
            .await
            .unwrap_err();
        assert!(err.to_string().contains("channel_name"));

        let missing_chat = ToolContext {
            chat_id: None,
            ..test_context()
        };
        let err = sender
            .send_to_current_session(
                &missing_chat,
                SessionSendRequest {
                    text: Some("hello".to_string()),
                    attachments: vec![],
                },
            )
            .await
            .unwrap_err();
        assert!(err.to_string().contains("chat_id"));
    }
}