146 lines
4.6 KiB
Rust
146 lines
4.6 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
use async_trait::async_trait;
|
|
|
|
use crate::bus::{MessageBus, OutboundMessage};
|
|
use crate::tools::{SessionMessageSender, SessionSendOutcome, SessionSendRequest, ToolContext};
|
|
|
|
pub(crate) struct BusSessionMessageSender {
|
|
bus: Arc<MessageBus>,
|
|
}
|
|
|
|
impl BusSessionMessageSender {
|
|
pub(crate) fn new(bus: Arc<MessageBus>) -> Self {
|
|
Self { bus }
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl SessionMessageSender for BusSessionMessageSender {
|
|
async fn send_to_current_session(
|
|
&self,
|
|
context: &ToolContext,
|
|
request: SessionSendRequest,
|
|
) -> anyhow::Result<SessionSendOutcome> {
|
|
let channel_name = context
|
|
.channel_name
|
|
.as_deref()
|
|
.ok_or_else(|| anyhow::anyhow!("missing channel_name in tool context"))?;
|
|
let chat_id = context
|
|
.chat_id
|
|
.as_deref()
|
|
.ok_or_else(|| anyhow::anyhow!("missing chat_id in tool context"))?;
|
|
|
|
let metadata = HashMap::new();
|
|
let attachment_count = request.attachments.len();
|
|
let mut published_messages = 0;
|
|
let text_sent = request
|
|
.text
|
|
.as_deref()
|
|
.map(str::trim)
|
|
.filter(|text| !text.is_empty())
|
|
.is_some();
|
|
|
|
if let Some(text) = request.text.filter(|value| !value.trim().is_empty()) {
|
|
let content_len = text.len();
|
|
let mut outbound = OutboundMessage::assistant(
|
|
channel_name.to_string(),
|
|
chat_id.to_string(),
|
|
None, // session_id
|
|
text,
|
|
None,
|
|
metadata.clone(),
|
|
);
|
|
if attachment_count > 0 {
|
|
outbound.media = request.attachments.clone();
|
|
}
|
|
self.bus.publish_outbound(outbound).await?;
|
|
published_messages += 1;
|
|
tracing::info!(
|
|
channel = %channel_name,
|
|
chat_id = %chat_id,
|
|
content_len = content_len,
|
|
attachment_count = attachment_count,
|
|
"Published session text message to outbound bus"
|
|
);
|
|
} else {
|
|
for attachment in request.attachments {
|
|
let media_path = attachment.path.clone();
|
|
let media_type = attachment.media_type.clone();
|
|
let mut outbound = OutboundMessage::assistant(
|
|
channel_name.to_string(),
|
|
chat_id.to_string(),
|
|
None, // session_id
|
|
String::new(),
|
|
None,
|
|
metadata.clone(),
|
|
);
|
|
outbound.media = vec![attachment];
|
|
self.bus.publish_outbound(outbound).await?;
|
|
published_messages += 1;
|
|
tracing::info!(
|
|
channel = %channel_name,
|
|
chat_id = %chat_id,
|
|
media_type = %media_type,
|
|
media_path = %media_path,
|
|
"Published session attachment to outbound bus"
|
|
);
|
|
}
|
|
}
|
|
|
|
Ok(SessionSendOutcome {
|
|
published_messages,
|
|
text_sent,
|
|
attachment_count,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bus::MediaItem;

    const TEST_CHANNEL: &str = "test-channel";

    /// Context targeting `TEST_CHANNEL` / `"chat-1"`.
    fn test_context() -> ToolContext {
        ToolContext {
            channel_name: Some(TEST_CHANNEL.to_string()),
            chat_id: Some("chat-1".to_string()),
            ..ToolContext::default()
        }
    }

    /// Attachment whose path lives under the system temp directory, so the path is
    /// well-formed on every platform.
    fn temp_media(file_name: &str, media_type: &'static str) -> MediaItem {
        MediaItem::new(
            &std::env::temp_dir().join(file_name).display().to_string(),
            media_type,
        )
    }

    #[tokio::test]
    async fn bus_sender_publishes_text_and_attachments_in_one_message() {
        let bus = MessageBus::new(8);
        let sender = BusSessionMessageSender::new(bus.clone());

        let outcome = sender
            .send_to_current_session(
                &test_context(),
                SessionSendRequest {
                    text: Some("hello".to_string()),
                    attachments: vec![temp_media("demo.png", "image")],
                },
            )
            .await
            .unwrap();

        assert_eq!(
            outcome,
            SessionSendOutcome {
                published_messages: 1,
                text_sent: true,
                attachment_count: 1,
            }
        );

        let msg = bus.consume_outbound().await.expect("bus outbound closed");
        assert_eq!(msg.content, "hello");
        assert_eq!(msg.media.len(), 1);
        assert_eq!(msg.media[0].media_type, "image");
    }

    #[tokio::test]
    async fn bus_sender_publishes_each_attachment_separately_without_text() {
        let bus = MessageBus::new(8);
        let sender = BusSessionMessageSender::new(bus.clone());

        let outcome = sender
            .send_to_current_session(
                &test_context(),
                SessionSendRequest {
                    // Whitespace-only text must be treated as "no text".
                    text: Some("   ".to_string()),
                    attachments: vec![
                        temp_media("first.png", "image"),
                        temp_media("second.mp3", "audio"),
                    ],
                },
            )
            .await
            .unwrap();

        assert_eq!(
            outcome,
            SessionSendOutcome {
                published_messages: 2,
                text_sent: false,
                attachment_count: 2,
            }
        );

        for expected_type in ["image", "audio"] {
            let msg = bus.consume_outbound().await.expect("bus outbound closed");
            assert_eq!(msg.content, "");
            assert_eq!(msg.media.len(), 1);
            assert_eq!(msg.media[0].media_type, expected_type);
        }
    }

    #[tokio::test]
    async fn bus_sender_rejects_context_without_chat_id() {
        let bus = MessageBus::new(8);
        let sender = BusSessionMessageSender::new(bus);
        let context = ToolContext {
            channel_name: Some(TEST_CHANNEL.to_string()),
            chat_id: None,
            ..ToolContext::default()
        };

        let err = sender
            .send_to_current_session(
                &context,
                SessionSendRequest {
                    text: Some("hello".to_string()),
                    attachments: Vec::new(),
                },
            )
            .await
            .expect_err("a context without chat_id must be rejected");
        assert!(err.to_string().contains("missing chat_id"));
    }
}