// PicoBot/src/channels/wechat.rs
//
// 423 lines
// 15 KiB
// Rust

use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use std::time::UNIX_EPOCH;
use async_trait::async_trait;
use futures_util::FutureExt;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use wechatbot::{BotOptions, SendContent, WeChatBot};
use crate::bus::{InboundMessage, MediaItem, MessageBus, OutboundMessage};
use crate::bus::message::OutboundEventKind;
use crate::channels::base::{Channel, ChannelError};
use crate::config::{LLMProviderConfig, WechatChannelConfig};
/// WeChat channel adapter built on the `wechatbot` SDK client.
///
/// `Clone` is derived; `bot`, `running` and `task` are `Arc`s, so every clone
/// shares the same SDK client, running flag and background-task slot.
#[derive(Clone)]
pub struct WechatChannel {
/// Channel name; returned by `Channel::name` and attached to log events.
name: String,
/// Channel settings. This file reads `base_url`, `cred_path`, `allow_from`
/// and `force_login` from it.
config: WechatChannelConfig,
/// Shared handle to the WeChat SDK client.
bot: Arc<WeChatBot>,
/// Running flag: set by `start`, cleared by `stop` and when the background task ends.
running: Arc<AtomicBool>,
/// Join handle of the background task spawned by `start`, if one has been spawned.
task: Arc<RwLock<Option<JoinHandle<()>>>>,
}
impl WechatChannel {
pub fn new(
name: String,
config: WechatChannelConfig,
_provider_config: LLMProviderConfig,
) -> Result<Self, ChannelError> {
let channel_name = name.clone();
let bot = WeChatBot::new(BotOptions {
base_url: Some(config.base_url.clone()),
cred_path: Some(config.cred_path.clone()),
on_qr_url: Some(Box::new(move |url| {
tracing::info!(channel = %channel_name, qr_url = %url, "WeChat QR code ready");
})),
on_error: Some(Box::new(move |error| {
tracing::error!(error = %error, "WeChat SDK error");
})),
});
Ok(Self {
name,
config,
bot: Arc::new(bot),
running: Arc::new(AtomicBool::new(false)),
task: Arc::new(RwLock::new(None)),
})
}
/// Returns `true` when `sender_id` passes the configured `allow_from` list.
///
/// An entry matches when it is the wildcard `"*"` or equals `sender_id`
/// exactly; with no matching entry (including an empty list) the sender is
/// rejected.
fn sender_allowed(&self, sender_id: &str) -> bool {
    for entry in self.config.allow_from.iter() {
        if entry == "*" || entry == sender_id {
            return true;
        }
    }
    false
}
/// Loads the file behind `media` and wraps it in the matching SDK payload.
///
/// `"image"` and `"video"` media types map to `SendContent::Image` and
/// `SendContent::Video`; every other type is sent as `SendContent::File`,
/// named after the last path component (or `attachment.bin` when that
/// component is missing or not valid UTF-8). `caption` is passed through
/// unchanged.
///
/// # Errors
///
/// Returns `ChannelError::SendError` when the file cannot be read or is empty.
fn media_to_send_content(
    media: &MediaItem,
    caption: Option<String>,
) -> Result<SendContent, ChannelError> {
    let data = match std::fs::read(&media.path) {
        Ok(bytes) => bytes,
        Err(error) => {
            return Err(ChannelError::SendError(format!(
                "WeChat media read failed for '{}': {}",
                media.path, error
            )));
        }
    };
    if data.is_empty() {
        return Err(ChannelError::SendError(format!(
            "WeChat media file is empty: {}",
            media.path
        )));
    }
    let content = match media.media_type.as_str() {
        "image" => SendContent::Image { data, caption },
        "video" => SendContent::Video { data, caption },
        _ => {
            // Only generic files carry a name, so derive it in this arm.
            let file_name = Path::new(&media.path)
                .file_name()
                .and_then(|os_name| os_name.to_str())
                .unwrap_or("attachment.bin")
                .to_string();
            SendContent::File {
                data,
                file_name,
                caption,
            }
        }
    };
    Ok(content)
}
/// Returns the directory used for downloaded WeChat media:
/// `<home>/.picobot/media/wechat`, with `.` standing in for the home
/// directory when it cannot be determined.
fn default_media_dir() -> PathBuf {
    let mut dir = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    dir.push(".picobot");
    dir.push("media");
    dir.push("wechat");
    dir
}
/// Builds a unique, filesystem-safe file name for a downloaded attachment.
///
/// * When `file_name` is given and is not blank, the result is
///   `<uuid>_<sanitized name>`: surrounding whitespace is trimmed, and path
///   separators, characters reserved on common filesystems, and control
///   characters are replaced with `_`.
/// * Otherwise the result is `<media_type>_<uuid>.<ext>`, where `ext` is
///   `jpg` for images, `mp4` for videos, the trimmed `format` for voice
///   messages when it is a plain ASCII-alphanumeric token, and `bin` in every
///   other case.
fn build_download_filename(
    media_type: &str,
    file_name: Option<&str>,
    format: Option<&str>,
) -> String {
    if let Some(file_name) = file_name {
        // Trim first so a whitespace-only name still falls through to the
        // generated name below.
        let sanitized: String = file_name
            .trim()
            .chars()
            .map(|ch| match ch {
                '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
                // Control characters (NUL, newlines, ...) are not safe in file names either.
                ch if ch.is_control() => '_',
                _ => ch,
            })
            .collect();
        if !sanitized.is_empty() {
            return format!("{}_{}", uuid::Uuid::new_v4(), sanitized);
        }
    }
    // The format string is supplied by the remote side: use the trimmed value,
    // and only when it cannot smuggle separators or dots into the file name.
    let voice_ext = format
        .map(str::trim)
        .filter(|fmt| !fmt.is_empty() && fmt.chars().all(|ch| ch.is_ascii_alphanumeric()));
    let ext = match (media_type, voice_ext) {
        ("image", _) => "jpg",
        ("video", _) => "mp4",
        ("voice", Some(fmt)) => fmt,
        _ => "bin",
    };
    format!("{}_{}.{}", media_type, uuid::Uuid::new_v4(), ext)
}
async fn download_inbound_media(
bot: Arc<WeChatBot>,
msg: wechatbot::IncomingMessage,
) -> Result<Vec<MediaItem>, ChannelError> {
let Some(downloaded) = bot.download(&msg).await.map_err(|error| {
ChannelError::Other(format!("WeChat media download failed: {}", error))
})? else {
return Ok(Vec::new());
};
let media_dir = Self::default_media_dir();
tokio::fs::create_dir_all(&media_dir)
.await
.map_err(|error| ChannelError::Other(format!("Failed to create WeChat media dir: {}", error)))?;
let filename = Self::build_download_filename(
&downloaded.media_type,
downloaded.file_name.as_deref(),
downloaded.format.as_deref(),
);
let file_path = media_dir.join(&filename);
tokio::fs::write(&file_path, downloaded.data)
.await
.map_err(|error| ChannelError::Other(format!("Failed to write WeChat media file: {}", error)))?;
tracing::info!(filename = %filename, media_type = %downloaded.media_type, "Downloaded WeChat media");
let mut media_item = MediaItem::new(
file_path.to_string_lossy().to_string(),
downloaded.media_type,
);
media_item.mime_type = mime_guess::from_path(&file_path)
.first_raw()
.map(ToOwned::to_owned);
Ok(vec![media_item])
}
/// Asks the SDK to show a typing indicator in `chat_id`.
///
/// This is best-effort: a failure is logged at debug level and otherwise
/// ignored.
async fn send_typing_indicator(bot: Arc<WeChatBot>, chat_id: &str) {
    match bot.send_typing(chat_id).await {
        Ok(_) => {}
        Err(error) => {
            tracing::debug!(chat_id = %chat_id, error = %error, "Failed to send WeChat typing indicator");
        }
    }
}
}
#[async_trait]
impl Channel for WechatChannel {
/// Returns the name this channel was constructed with.
fn name(&self) -> &str {
    self.name.as_str()
}
fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
/// Registers the inbound-message handler and spawns the login / run task.
///
/// If the channel is already marked as running this is a no-op. Otherwise:
///
/// 1. A message handler is registered on the SDK client. For each message
///    from an allowed sender it spawns a task that sends a typing indicator,
///    downloads any attached media (a download failure is logged and the
///    message is published without media), and publishes an
///    `InboundMessage` on `bus`. Messages from other senders are logged and
///    dropped.
/// 2. A background task logs in (honouring `config.force_login`) and then
///    runs the SDK loop. When that task ends for any reason — login failure,
///    run error, normal return or panic — the running flag is cleared.
///
/// # Errors
///
/// The current implementation always returns `Ok`; login and run failures
/// happen later inside the background task and are only logged.
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError> {
    // `swap` tests and sets the flag in one step, so concurrent callers cannot
    // both get past this guard.
    if self.running.swap(true, Ordering::SeqCst) {
        return Ok(());
    }
    let channel_name = self.name.clone();
    let allow_from = self.config.allow_from.clone();
    let bus_for_handler = bus.clone();
    let bot_for_handler = self.bot.clone();
    // NOTE(review): a handler is registered on every successful `start`. If the
    // SDK accumulates handlers rather than replacing them, a stop/start cycle
    // would deliver each message more than once — confirm against the SDK.
    self.bot
        .on_message(Box::new(move |msg| {
            let sender_id = msg.user_id.clone();
            let allowed = allow_from
                .iter()
                .any(|pattern| pattern == "*" || pattern == &sender_id);
            if !allowed {
                tracing::warn!(channel = %channel_name, sender = %sender_id, "Access denied");
                return;
            }
            let msg = msg.clone();
            let timestamp = msg
                .timestamp
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs() as i64;
            let bus = bus_for_handler.clone();
            let bot = bot_for_handler.clone();
            let channel_name_for_publish = channel_name.clone();
            // The handler itself is synchronous, so the async work runs in its own task.
            tokio::spawn(async move {
                Self::send_typing_indicator(bot.clone(), &sender_id).await;
                let media = match Self::download_inbound_media(bot, msg.clone()).await {
                    Ok(media) => media,
                    Err(error) => {
                        tracing::error!(error = %error, "Failed to download WeChat inbound media");
                        Vec::new()
                    }
                };
                let mut metadata = HashMap::new();
                metadata.insert("context_token".to_string(), msg.context_token().to_string());
                let inbound = InboundMessage {
                    channel: channel_name_for_publish,
                    sender_id: sender_id.clone(),
                    chat_id: sender_id,
                    content: msg.text.clone(),
                    timestamp,
                    media,
                    metadata,
                    forwarded_metadata: HashMap::new(),
                };
                if let Err(error) = bus.publish_inbound(inbound).await {
                    tracing::error!(error = %error, "Failed to publish WeChat inbound message");
                }
            });
        }))
        .await;
    let bot = self.bot.clone();
    let channel_name = self.name.clone();
    let force_login = self.config.force_login;
    let running = self.running.clone();
    let handle = tokio::spawn(async move {
        // Use catch_unwind to prevent a panic in the WeChat SDK (login or
        // long-poll loop) from crashing the entire process. Any panic is
        // logged and the channel is cleanly marked as stopped.
        // AssertUnwindSafe is needed because WeChatBot contains internal
        // locks (RwLock) that are not RefUnwindSafe.
        let result = std::panic::AssertUnwindSafe(async {
            match bot.login(force_login).await {
                Ok(creds) => {
                    tracing::info!(
                        channel = %channel_name,
                        account_id = %creds.account_id,
                        user_id = %creds.user_id,
                        "WeChat login succeeded"
                    );
                }
                Err(error) => {
                    tracing::error!(channel = %channel_name, error = %error, "WeChat login failed");
                    return;
                }
            }
            if let Err(error) = bot.run().await {
                tracing::error!(channel = %channel_name, error = %error, "WeChat channel stopped with error");
            }
        })
        .catch_unwind()
        .await;
        if let Err(payload) = result {
            // Panic payloads are usually a `&str` or a `String`; recover the
            // text so the log says why the task died, not only that it did.
            let panic_message = payload
                .downcast_ref::<&str>()
                .map(|text| (*text).to_string())
                .or_else(|| payload.downcast_ref::<String>().cloned())
                .unwrap_or_else(|| "non-string panic payload".to_string());
            tracing::error!(
                channel = %channel_name,
                panic_message = %panic_message,
                "WeChat bot task panicked — marking channel as stopped"
            );
        }
        // Reached on every exit path of the task, so the flag never stays
        // set once the bot is no longer running.
        running.store(false, Ordering::SeqCst);
    });
    *self.task.write().await = Some(handle);
    tracing::info!(channel = %self.name, "WeChat channel started");
    Ok(())
}
/// Stops the channel and waits for its background task to finish.
///
/// The running flag is cleared, the SDK client is asked to stop, and the
/// task spawned by `start` (if any) is aborted and then awaited. Awaiting
/// matters because the task clears the running flag as its last step: once
/// `stop` returns, that write can no longer race with a later `start`.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
async fn stop(&self) -> Result<(), ChannelError> {
    self.running.store(false, Ordering::SeqCst);
    self.bot.stop().await;
    // Take the handle in its own statement so the write guard is released
    // before the task is awaited below.
    let handle = self.task.write().await.take();
    if let Some(handle) = handle {
        handle.abort();
        // A cancelled task is the expected outcome here; anything else is
        // worth a log line.
        if let Err(error) = handle.await {
            if !error.is_cancelled() {
                tracing::warn!(
                    channel = %self.name,
                    error = %error,
                    "WeChat bot task ended abnormally"
                );
            }
        }
    }
    tracing::info!(channel = %self.name, "WeChat channel stopped");
    Ok(())
}
/// Delivers an outbound message to WeChat.
///
/// Tool-call, tool-result and tool-pending events, and anything whose
/// metadata marks it as a sub-agent event, are dropped without sending. For
/// everything else the trimmed text (when not empty) is sent first as its own
/// message, followed by one media message per attachment, in order. A
/// message with neither text nor media sends nothing.
///
/// # Errors
///
/// Returns `ChannelError::SendError` on the first failure (text send, media
/// read, or media send); anything after that point is not sent.
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
    // WeChat iLink Bot has a ~10-message burst limit per context_token.
    // Filter non-essential message types to conserve budget:
    // - ToolCall: internal tool invocation details, not useful to WeChat users
    // - ToolResult / ToolPending: raw tool output, not user-facing
    // - Subagent events: internal agent orchestration
    let is_tool_event = matches!(
        msg.event_kind,
        OutboundEventKind::ToolResult
            | OutboundEventKind::ToolPending
            | OutboundEventKind::ToolCall
    );
    let is_subagent_event = msg
        .metadata
        .get("is_subagent_event")
        .map(|v| v == "true")
        .unwrap_or(false);
    if is_tool_event || is_subagent_event {
        return Ok(());
    }

    let text = msg.content.trim().to_string();
    if !text.is_empty() {
        self.bot.send(&msg.chat_id, &text).await.map_err(|error| {
            ChannelError::SendError(format!("WeChat text send failed: {}", error))
        })?;
        tracing::info!(
            channel = %self.name,
            chat_id = %msg.chat_id,
            content_len = text.len(),
            "WeChat text message sent"
        );
    }

    // Non-empty text has always gone out above by the time this loop runs, so
    // media is sent without a caption. (The old "caption on the first item"
    // branch could never be taken.)
    for media in msg.media.iter() {
        let content = Self::media_to_send_content(media, None)?;
        self.bot.send_media(&msg.chat_id, content).await.map_err(|error| {
            ChannelError::SendError(format!("WeChat media send failed: {}", error))
        })?;
        tracing::info!(
            channel = %self.name,
            chat_id = %msg.chat_id,
            media_type = %media.media_type,
            media_path = %media.path,
            "WeChat media message sent"
        );
    }
    Ok(())
}
fn is_allowed(&self, sender_id: &str) -> bool {
self.sender_allowed(sender_id)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::Builder;

    /// Creates a temp file with the given suffix and contents.
    ///
    /// The file is created with its final name, so dropping the returned
    /// handle removes it. Renaming the file afterwards would leave the
    /// renamed copy behind in the temp directory.
    fn temp_media_file(suffix: &str, contents: &[u8]) -> tempfile::NamedTempFile {
        let file = Builder::new().suffix(suffix).tempfile().unwrap();
        std::fs::write(file.path(), contents).unwrap();
        file
    }

    #[test]
    fn build_download_filename_preserves_file_name() {
        let filename = WechatChannel::build_download_filename("file", Some("README.md"), None);
        assert!(filename.ends_with("_README.md"));
    }

    #[test]
    fn build_download_filename_replaces_path_separators() {
        let filename =
            WechatChannel::build_download_filename("file", Some("../etc\\passwd"), None);
        assert!(!filename.contains('/'));
        assert!(!filename.contains('\\'));
    }

    #[test]
    fn build_download_filename_adds_voice_extension_when_missing_name() {
        let filename = WechatChannel::build_download_filename("voice", None, Some("silk"));
        assert!(filename.starts_with("voice_"));
        assert!(filename.ends_with(".silk"));
    }

    #[test]
    fn media_to_send_content_maps_image() {
        let file = temp_media_file(".png", b"demo-image");
        let media = MediaItem::new(file.path().to_string_lossy().to_string(), "image");
        let content = WechatChannel::media_to_send_content(&media, None).unwrap();
        assert!(matches!(content, SendContent::Image { .. }));
    }

    #[test]
    fn media_to_send_content_maps_generic_file() {
        let file = temp_media_file(".md", b"hello");
        let media = MediaItem::new(file.path().to_string_lossy().to_string(), "file");
        let content =
            WechatChannel::media_to_send_content(&media, Some("note".to_string())).unwrap();
        match content {
            SendContent::File {
                file_name,
                caption,
                ..
            } => {
                assert_eq!(file_name, file.path().file_name().unwrap().to_string_lossy());
                assert_eq!(caption.as_deref(), Some("note"));
            }
            _ => panic!("expected file send content"),
        }
    }

    #[test]
    fn media_to_send_content_rejects_empty_file() {
        // A freshly created temp file has no content.
        let file = Builder::new().suffix(".bin").tempfile().unwrap();
        let media = MediaItem::new(file.path().to_string_lossy().to_string(), "file");
        assert!(WechatChannel::media_to_send_content(&media, None).is_err());
    }
}