use std::collections::HashMap; use std::path::Path; use std::path::PathBuf; use std::sync::{ Arc, atomic::{AtomicBool, Ordering}, }; use std::time::UNIX_EPOCH; use async_trait::async_trait; use tokio::sync::RwLock; use tokio::task::JoinHandle; use wechatbot::{BotOptions, SendContent, WeChatBot}; use crate::bus::{InboundMessage, MediaItem, MessageBus, OutboundMessage}; use crate::bus::message::OutboundEventKind; use crate::channels::base::{Channel, ChannelError}; use crate::config::{LLMProviderConfig, WechatChannelConfig}; #[derive(Clone)] pub struct WechatChannel { name: String, config: WechatChannelConfig, bot: Arc, running: Arc, task: Arc>>>, } impl WechatChannel { pub fn new( name: String, config: WechatChannelConfig, _provider_config: LLMProviderConfig, ) -> Result { let channel_name = name.clone(); let bot = WeChatBot::new(BotOptions { base_url: Some(config.base_url.clone()), cred_path: Some(config.cred_path.clone()), on_qr_url: Some(Box::new(move |url| { tracing::info!(channel = %channel_name, qr_url = %url, "WeChat QR code ready"); })), on_error: Some(Box::new(move |error| { tracing::error!(error = %error, "WeChat SDK error"); })), }); Ok(Self { name, config, bot: Arc::new(bot), running: Arc::new(AtomicBool::new(false)), task: Arc::new(RwLock::new(None)), }) } fn sender_allowed(&self, sender_id: &str) -> bool { self.config.allow_from.iter().any(|pattern| pattern == "*" || pattern == sender_id) } fn media_to_send_content( media: &MediaItem, caption: Option, ) -> Result { let data = std::fs::read(&media.path).map_err(|error| { ChannelError::SendError(format!( "WeChat media read failed for '{}': {}", media.path, error )) })?; if data.is_empty() { return Err(ChannelError::SendError(format!( "WeChat media file is empty: {}", media.path ))); } let file_name = Path::new(&media.path) .file_name() .and_then(|name| name.to_str()) .unwrap_or("attachment.bin") .to_string(); match media.media_type.as_str() { "image" => Ok(SendContent::Image { data, caption }), "video" => Ok(SendContent::Video { data, caption }), _ => Ok(SendContent::File { data, file_name, caption, }), } } fn default_media_dir() -> PathBuf { let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); home.join(".picobot").join("media").join("wechat") } fn build_download_filename( media_type: &str, file_name: Option<&str>, format: Option<&str>, ) -> String { if let Some(file_name) = file_name { let sanitized: String = file_name .chars() .map(|ch| match ch { '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_', _ => ch, }) .collect(); if !sanitized.trim().is_empty() { return format!("{}_{}", uuid::Uuid::new_v4(), sanitized); } } let ext = match (media_type, format) { ("image", _) => "jpg", ("video", _) => "mp4", ("voice", Some(fmt)) if !fmt.trim().is_empty() => fmt, ("voice", _) => "bin", _ => "bin", }; format!("{}_{}.{}", media_type, uuid::Uuid::new_v4(), ext) } async fn download_inbound_media( bot: Arc, msg: wechatbot::IncomingMessage, ) -> Result, ChannelError> { let Some(downloaded) = bot.download(&msg).await.map_err(|error| { ChannelError::Other(format!("WeChat media download failed: {}", error)) })? else { return Ok(Vec::new()); }; let media_dir = Self::default_media_dir(); tokio::fs::create_dir_all(&media_dir) .await .map_err(|error| ChannelError::Other(format!("Failed to create WeChat media dir: {}", error)))?; let filename = Self::build_download_filename( &downloaded.media_type, downloaded.file_name.as_deref(), downloaded.format.as_deref(), ); let file_path = media_dir.join(&filename); tokio::fs::write(&file_path, downloaded.data) .await .map_err(|error| ChannelError::Other(format!("Failed to write WeChat media file: {}", error)))?; tracing::info!(filename = %filename, media_type = %downloaded.media_type, "Downloaded WeChat media"); let mut media_item = MediaItem::new( file_path.to_string_lossy().to_string(), downloaded.media_type, ); media_item.mime_type = mime_guess::from_path(&file_path) .first_raw() .map(ToOwned::to_owned); Ok(vec![media_item]) } async fn send_typing_indicator(bot: Arc, chat_id: &str) { if let Err(error) = bot.send_typing(chat_id).await { tracing::debug!(chat_id = %chat_id, error = %error, "Failed to send WeChat typing indicator"); } } } #[async_trait] impl Channel for WechatChannel { fn name(&self) -> &str { &self.name } fn is_running(&self) -> bool { self.running.load(Ordering::SeqCst) } async fn start(&self, bus: Arc) -> Result<(), ChannelError> { if self.running.swap(true, Ordering::SeqCst) { return Ok(()); } let channel_name = self.name.clone(); let allow_from = self.config.allow_from.clone(); let bus_for_handler = bus.clone(); let bot_for_handler = self.bot.clone(); self.bot .on_message(Box::new(move |msg| { let sender_id = msg.user_id.clone(); let allowed = allow_from .iter() .any(|pattern| pattern == "*" || pattern == &sender_id); if !allowed { tracing::warn!(channel = %channel_name, sender = %sender_id, "Access denied"); return; } let msg = msg.clone(); let timestamp = msg .timestamp .duration_since(UNIX_EPOCH) .unwrap_or_default() .as_secs() as i64; let bus = bus_for_handler.clone(); let bot = bot_for_handler.clone(); let channel_name_for_publish = channel_name.clone(); tokio::spawn(async move { Self::send_typing_indicator(bot.clone(), &sender_id).await; let media = match Self::download_inbound_media(bot, msg.clone()).await { Ok(media) => media, Err(error) => { tracing::error!(error = %error, "Failed to download WeChat inbound media"); Vec::new() } }; let mut metadata = HashMap::new(); metadata.insert("context_token".to_string(), msg.context_token().to_string()); let inbound = InboundMessage { channel: channel_name_for_publish, sender_id: sender_id.clone(), chat_id: sender_id, content: msg.text.clone(), timestamp, media, metadata, forwarded_metadata: HashMap::new(), }; if let Err(error) = bus.publish_inbound(inbound).await { tracing::error!(error = %error, "Failed to publish WeChat inbound message"); } }); })) .await; let bot = self.bot.clone(); let channel_name = self.name.clone(); let force_login = self.config.force_login; let running = self.running.clone(); let handle = tokio::spawn(async move { match bot.login(force_login).await { Ok(creds) => { tracing::info!( channel = %channel_name, account_id = %creds.account_id, user_id = %creds.user_id, "WeChat login succeeded" ); } Err(error) => { running.store(false, Ordering::SeqCst); tracing::error!(channel = %channel_name, error = %error, "WeChat login failed"); return; } } if let Err(error) = bot.run().await { tracing::error!(channel = %channel_name, error = %error, "WeChat channel stopped with error"); } running.store(false, Ordering::SeqCst); }); *self.task.write().await = Some(handle); tracing::info!(channel = %self.name, "WeChat channel started"); Ok(()) } async fn stop(&self) -> Result<(), ChannelError> { self.running.store(false, Ordering::SeqCst); self.bot.stop().await; if let Some(handle) = self.task.write().await.take() { handle.abort(); } tracing::info!(channel = %self.name, "WeChat channel stopped"); Ok(()) } async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { if matches!(msg.event_kind, OutboundEventKind::ToolResult | OutboundEventKind::ToolPending) || msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false) { return Ok(()); } let text = msg.content.trim().to_string(); let mut text_sent = false; if !text.is_empty() { self.bot.send(&msg.chat_id, &text).await.map_err(|error| { ChannelError::SendError(format!("WeChat text send failed: {}", error)) })?; tracing::info!( channel = %self.name, chat_id = %msg.chat_id, content_len = text.len(), "WeChat text message sent" ); text_sent = true; } for (index, media) in msg.media.iter().enumerate() { let caption = if !text.is_empty() && !text_sent && index == 0 { Some(text.clone()) } else { None }; let content = Self::media_to_send_content(media, caption)?; self.bot.send_media(&msg.chat_id, content).await.map_err(|error| { ChannelError::SendError(format!("WeChat media send failed: {}", error)) })?; tracing::info!( channel = %self.name, chat_id = %msg.chat_id, media_type = %media.media_type, media_path = %media.path, "WeChat media message sent" ); } if text.is_empty() && msg.media.is_empty() { return Ok(()); } Ok(()) } fn is_allowed(&self, sender_id: &str) -> bool { self.sender_allowed(sender_id) } } #[cfg(test)] mod tests { use super::*; use tempfile::NamedTempFile; #[test] fn build_download_filename_preserves_file_name() { let filename = WechatChannel::build_download_filename("file", Some("README.md"), None); assert!(filename.ends_with("_README.md")); } #[test] fn build_download_filename_adds_voice_extension_when_missing_name() { let filename = WechatChannel::build_download_filename("voice", None, Some("silk")); assert!(filename.starts_with("voice_")); assert!(filename.ends_with(".silk")); } #[test] fn media_to_send_content_maps_image() { let file = NamedTempFile::new().unwrap(); std::fs::write(file.path(), b"demo-image").unwrap(); let image_path = file.path().with_extension("png"); std::fs::rename(file.path(), &image_path).unwrap(); let media = MediaItem::new(image_path.to_string_lossy().to_string(), "image"); let content = WechatChannel::media_to_send_content(&media, None).unwrap(); assert!(matches!(content, SendContent::Image { .. })); } #[test] fn media_to_send_content_maps_generic_file() { let file = NamedTempFile::new().unwrap(); std::fs::write(file.path(), b"hello").unwrap(); let doc_path = file.path().with_extension("md"); std::fs::rename(file.path(), &doc_path).unwrap(); let media = MediaItem::new(doc_path.to_string_lossy().to_string(), "file"); let content = WechatChannel::media_to_send_content(&media, Some("note".to_string())).unwrap(); match content { SendContent::File { file_name, caption, .. } => { assert_eq!(file_name, doc_path.file_name().unwrap().to_string_lossy()); assert_eq!(caption.as_deref(), Some("note")); } _ => panic!("expected file send content"), } } }