From 597881f72e1684521d12e6b9f62efefd3d88dd7a Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Wed, 6 May 2026 14:18:47 +0800 Subject: [PATCH] feat: Implement WeChatBot SDK with error handling and message protocol - Add WeChatBotError enum for error handling with various error types. - Create a Result type alias for easier error management. - Implement ILinkClient for low-level API interactions including QR code generation, message sending, and updates retrieval. - Define message types and structures for handling incoming messages and media content. - Add tests for error handling and message parsing to ensure reliability. Co-authored-by: Copilot --- Cargo.toml | 2 + src/bootstrap.rs | 3 + src/bus/message.rs | 38 +- src/channels/feishu.rs | 7 +- src/channels/manager.rs | 192 +++++- src/channels/mod.rs | 2 + src/channels/wechat.rs | 272 ++++++++ src/config/mod.rs | 175 +++++- src/gateway/command.rs | 6 +- src/gateway/session.rs | 28 +- src/gateway/session_message_sender.rs | 4 +- src/lib.rs | 1 + src/main.rs | 2 + src/scheduler/mod.rs | 16 +- src/storage/mod.rs | 90 +-- src/tools/memory_manage.rs | 8 +- src/tools/memory_search.rs | 14 +- src/tools/scheduler_manage.rs | 18 +- src/tools/session_send.rs | 4 +- src/tools/skill_activate.rs | 10 +- vendor/wechatbot/.cargo-ok | 1 + vendor/wechatbot/.cargo_vcs_info.json | 6 + vendor/wechatbot/Cargo.toml | 91 +++ vendor/wechatbot/Cargo.toml.orig | 35 ++ vendor/wechatbot/README.md | 226 +++++++ vendor/wechatbot/examples/echo_bot.rs | 43 ++ vendor/wechatbot/src/bot.rs | 741 ++++++++++++++++++++++ vendor/wechatbot/src/cdn.rs | 138 +++++ vendor/wechatbot/src/crypto.rs | 148 +++++ vendor/wechatbot/src/error.rs | 93 +++ vendor/wechatbot/src/lib.rs | 38 ++ vendor/wechatbot/src/protocol.rs | 407 ++++++++++++ vendor/wechatbot/src/types.rs | 858 ++++++++++++++++++++++++++ 33 files changed, 3601 insertions(+), 116 deletions(-) create mode 100644 src/bootstrap.rs create mode 100644 src/channels/wechat.rs create mode 100644 vendor/wechatbot/.cargo-ok create mode 100644 vendor/wechatbot/.cargo_vcs_info.json create mode 100644 vendor/wechatbot/Cargo.toml create mode 100644 vendor/wechatbot/Cargo.toml.orig create mode 100644 vendor/wechatbot/README.md create mode 100644 vendor/wechatbot/examples/echo_bot.rs create mode 100644 vendor/wechatbot/src/bot.rs create mode 100644 vendor/wechatbot/src/cdn.rs create mode 100644 vendor/wechatbot/src/crypto.rs create mode 100644 vendor/wechatbot/src/error.rs create mode 100644 vendor/wechatbot/src/lib.rs create mode 100644 vendor/wechatbot/src/protocol.rs create mode 100644 vendor/wechatbot/src/types.rs diff --git a/Cargo.toml b/Cargo.toml index 0795f50..b0b7986 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,5 @@ image = { version = "0.25", default-features = false, features = ["jpeg", "png", tempfile = "3" meval = "0.2" rusqlite = { version = "0.32", features = ["bundled"] } +rustls = { version = "0.23", features = ["ring"] } +wechatbot = { path = "vendor/wechatbot" } diff --git a/src/bootstrap.rs b/src/bootstrap.rs new file mode 100644 index 0000000..937c325 --- /dev/null +++ b/src/bootstrap.rs @@ -0,0 +1,3 @@ +pub fn initialize_process_runtime() { + let _ = rustls::crypto::ring::default_provider().install_default(); +} \ No newline at end of file diff --git a/src/bus/message.rs b/src/bus/message.rs index cdf7ac3..652b55f 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -538,6 +538,8 @@ mod tests { use serde_json::json; use std::collections::HashMap; + const TEST_CHANNEL: &str = "test-channel"; + #[test] fn test_from_chat_message_expands_tool_calls() { let message = ChatMessage::assistant_with_tool_calls( @@ -556,8 +558,13 @@ mod tests { ], ); - let outbound = - OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message); + let outbound = OutboundMessage::from_chat_message( + TEST_CHANNEL, + "chat-1", + None, + &HashMap::new(), + &message, + ); assert_eq!(outbound.len(), 2); assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall); @@ -588,8 +595,13 @@ mod tests { }], ); - let outbound = - OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message); + let outbound = OutboundMessage::from_chat_message( + TEST_CHANNEL, + "chat-1", + None, + &HashMap::new(), + &message, + ); assert_eq!(outbound.len(), 2); assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse); @@ -602,8 +614,13 @@ mod tests { fn test_from_chat_message_includes_tool_result() { let message = ChatMessage::tool("call-9", "calculator", "2"); - let outbound = - OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message); + let outbound = OutboundMessage::from_chat_message( + TEST_CHANNEL, + "chat-1", + None, + &HashMap::new(), + &message, + ); assert_eq!(outbound.len(), 1); assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult); @@ -618,8 +635,13 @@ mod tests { ToolMessageState::PendingUserAction, ); - let outbound = - OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message); + let outbound = OutboundMessage::from_chat_message( + TEST_CHANNEL, + "chat-1", + None, + &HashMap::new(), + &message, + ); assert_eq!(outbound.len(), 1); assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending); diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs index bb07dc1..9f68f9e 100644 --- a/src/channels/feishu.rs +++ b/src/channels/feishu.rs @@ -149,6 +149,7 @@ struct CachedTenantToken { #[derive(Clone)] pub struct FeishuChannel { + name: String, config: FeishuChannelConfig, http_client: reqwest::Client, running: Arc>, @@ -174,10 +175,12 @@ struct ParsedMessage { impl FeishuChannel { pub fn new( + name: String, config: FeishuChannelConfig, _provider_config: LLMProviderConfig, ) -> Result { Ok(Self { + name, config, http_client: reqwest::Client::new(), running: Arc::new(RwLock::new(false)), @@ -1251,7 +1254,7 @@ impl FeishuChannel { #[cfg(debug_assertions)] tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus"); let msg = crate::bus::InboundMessage { - channel: "feishu".to_string(), + channel: channel.name().to_string(), sender_id: parsed.open_id.clone(), chat_id: parsed.chat_id.clone(), content: parsed.content.clone(), @@ -2281,7 +2284,7 @@ mod tests { #[async_trait] impl Channel for FeishuChannel { fn name(&self) -> &str { - "feishu" + &self.name } async fn start(&self, bus: Arc) -> Result<(), ChannelError> { diff --git a/src/channels/manager.rs b/src/channels/manager.rs index 4a6ea36..ff4a2e2 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -6,7 +6,8 @@ use crate::bus::MessageBus; use crate::channels::base::{Channel, ChannelError}; use crate::channels::cli::CliChannel; use crate::channels::feishu::FeishuChannel; -use crate::config::Config; +use crate::channels::wechat::WechatChannel; +use crate::config::{Config, TaggedChannelConfig}; /// ChannelManager manages all Channel instances and the MessageBus #[derive(Clone)] @@ -42,23 +43,57 @@ impl ChannelManager { pub async fn init( &self, config: &Config, - _provider_config: crate::config::LLMProviderConfig, + provider_config: crate::config::LLMProviderConfig, ) -> Result<(), ChannelError> { - // Initialize Feishu channel if enabled - if let Some(feishu_config) = config.channels.get("feishu") { - if feishu_config.enabled { - let channel = - FeishuChannel::new(feishu_config.clone(), _provider_config).map_err(|e| { - ChannelError::Other(format!("Failed to create Feishu channel: {}", e)) - })?; + for (name, channel_config) in &config.channels { + match channel_config { + crate::config::ChannelConfig::Tagged(TaggedChannelConfig::Feishu(feishu_config)) + | crate::config::ChannelConfig::LegacyFeishu(feishu_config) => { + if feishu_config.enabled { + let channel = FeishuChannel::new( + name.clone(), + feishu_config.clone(), + provider_config.clone(), + ) + .map_err(|e| { + ChannelError::Other(format!( + "Failed to create Feishu channel '{}': {}", + name, e + )) + })?; - self.channels - .write() - .await - .insert("feishu".to_string(), Arc::new(channel)); - tracing::info!("Feishu channel registered"); - } else { - tracing::info!("Feishu channel disabled in config"); + self.channels + .write() + .await + .insert(name.clone(), Arc::new(channel)); + tracing::info!(channel = %name, kind = channel_config.kind(), "Channel registered"); + } else { + tracing::info!(channel = %name, kind = channel_config.kind(), "Channel disabled in config"); + } + } + crate::config::ChannelConfig::Tagged(TaggedChannelConfig::Wechat(wechat_config)) => { + if wechat_config.enabled { + let channel = WechatChannel::new( + name.clone(), + wechat_config.clone(), + provider_config.clone(), + ) + .map_err(|e| { + ChannelError::Other(format!( + "Failed to create WeChat channel '{}': {}", + name, e + )) + })?; + + self.channels + .write() + .await + .insert(name.clone(), Arc::new(channel)); + tracing::info!(channel = %name, kind = channel_config.kind(), "Channel registered"); + } else { + tracing::info!(channel = %name, kind = channel_config.kind(), "Channel disabled in config"); + } + } } } Ok(()) @@ -101,3 +136,128 @@ impl ChannelManager { .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + + fn write_test_config() -> tempfile::NamedTempFile { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus" + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + }, + "channels": { + "primary": { + "type": "feishu", + "enabled": true, + "app_id": "app-id-1", + "app_secret": "secret-1" + }, + "backup": { + "type": "feishu", + "enabled": true, + "app_id": "app-id-2", + "app_secret": "secret-2" + } + } +}"#, + ) + .unwrap(); + file + } + + #[tokio::test] + async fn init_registers_all_configured_channels_by_instance_name() { + let file = write_test_config(); + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let provider_config = config.get_provider_config("default").unwrap(); + let manager = ChannelManager::new(); + + manager.init(&config, provider_config).await.unwrap(); + + let mut names = manager + .channels() + .await + .into_iter() + .map(|(name, _)| name) + .collect::>(); + names.sort(); + + assert_eq!(names, vec!["backup", "cli", "primary"]); + assert_eq!(manager.get_channel("primary").await.unwrap().name(), "primary"); + assert_eq!(manager.get_channel("backup").await.unwrap().name(), "backup"); + } + + #[tokio::test] + async fn init_registers_wechat_channel_by_instance_name() { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus" + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + }, + "channels": { + "wechat_main": { + "type": "wechat", + "enabled": true, + "cred_path": "/tmp/wechat-creds.json" + } + } +}"#, + ) + .unwrap(); + + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let provider_config = config.get_provider_config("default").unwrap(); + let manager = ChannelManager::new(); + + manager.init(&config, provider_config).await.unwrap(); + + let mut names = manager + .channels() + .await + .into_iter() + .map(|(name, _)| name) + .collect::>(); + names.sort(); + + assert_eq!(names, vec!["cli", "wechat_main"]); + assert_eq!(manager.get_channel("wechat_main").await.unwrap().name(), "wechat_main"); + } +} diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 53ce40d..5e1a7d0 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -2,8 +2,10 @@ pub mod base; pub mod cli; pub mod feishu; pub mod manager; +pub mod wechat; pub use base::{Channel, ChannelError}; pub use cli::CliChannel; pub use feishu::FeishuChannel; pub use manager::ChannelManager; +pub use wechat::WechatChannel; diff --git a/src/channels/wechat.rs b/src/channels/wechat.rs new file mode 100644 index 0000000..42a1c2c --- /dev/null +++ b/src/channels/wechat.rs @@ -0,0 +1,272 @@ +use std::collections::HashMap; +use std::path::Path; +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; +use std::time::UNIX_EPOCH; + +use async_trait::async_trait; +use tokio::sync::RwLock; +use tokio::task::JoinHandle; +use wechatbot::{BotOptions, SendContent, WeChatBot}; + +use crate::bus::{InboundMessage, MediaItem, MessageBus, OutboundMessage}; +use crate::channels::base::{Channel, ChannelError}; +use crate::config::{LLMProviderConfig, WechatChannelConfig}; + +#[derive(Clone)] +pub struct WechatChannel { + name: String, + config: WechatChannelConfig, + bot: Arc, + running: Arc, + task: Arc>>>, +} + +impl WechatChannel { + pub fn new( + name: String, + config: WechatChannelConfig, + _provider_config: LLMProviderConfig, + ) -> Result { + let channel_name = name.clone(); + let bot = WeChatBot::new(BotOptions { + base_url: Some(config.base_url.clone()), + cred_path: Some(config.cred_path.clone()), + on_qr_url: Some(Box::new(move |url| { + tracing::info!(channel = %channel_name, qr_url = %url, "WeChat QR code ready"); + })), + on_error: Some(Box::new(move |error| { + tracing::error!(error = %error, "WeChat SDK error"); + })), + }); + + Ok(Self { + name, + config, + bot: Arc::new(bot), + running: Arc::new(AtomicBool::new(false)), + task: Arc::new(RwLock::new(None)), + }) + } + + fn sender_allowed(&self, sender_id: &str) -> bool { + self.config.allow_from.iter().any(|pattern| pattern == "*" || pattern == sender_id) + } + + fn media_to_send_content( + media: &MediaItem, + caption: Option, + ) -> Result { + let data = std::fs::read(&media.path).map_err(|error| { + ChannelError::SendError(format!( + "WeChat media read failed for '{}': {}", + media.path, error + )) + })?; + + if data.is_empty() { + return Err(ChannelError::SendError(format!( + "WeChat media file is empty: {}", + media.path + ))); + } + + let file_name = Path::new(&media.path) + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or("attachment.bin") + .to_string(); + + match media.media_type.as_str() { + "image" => Ok(SendContent::Image { data, caption }), + "video" => Ok(SendContent::Video { data, caption }), + _ => Ok(SendContent::File { + data, + file_name, + caption, + }), + } + } +} + +#[async_trait] +impl Channel for WechatChannel { + fn name(&self) -> &str { + &self.name + } + + fn is_running(&self) -> bool { + self.running.load(Ordering::SeqCst) + } + + async fn start(&self, bus: Arc) -> Result<(), ChannelError> { + if self.running.swap(true, Ordering::SeqCst) { + return Ok(()); + } + + let channel_name = self.name.clone(); + let allow_from = self.config.allow_from.clone(); + let bus_for_handler = bus.clone(); + self.bot + .on_message(Box::new(move |msg| { + let sender_id = msg.user_id.clone(); + let allowed = allow_from + .iter() + .any(|pattern| pattern == "*" || pattern == &sender_id); + if !allowed { + tracing::warn!(channel = %channel_name, sender = %sender_id, "Access denied"); + return; + } + + let timestamp = msg + .timestamp + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64; + let mut metadata = HashMap::new(); + metadata.insert("context_token".to_string(), msg.context_token().to_string()); + + let inbound = InboundMessage { + channel: channel_name.clone(), + sender_id: sender_id.clone(), + chat_id: sender_id, + content: msg.text.clone(), + timestamp, + media: Vec::new(), + metadata, + forwarded_metadata: HashMap::new(), + }; + + let bus = bus_for_handler.clone(); + tokio::spawn(async move { + if let Err(error) = bus.publish_inbound(inbound).await { + tracing::error!(error = %error, "Failed to publish WeChat inbound message"); + } + }); + })) + .await; + + let bot = self.bot.clone(); + let channel_name = self.name.clone(); + let force_login = self.config.force_login; + let running = self.running.clone(); + + let handle = tokio::spawn(async move { + match bot.login(force_login).await { + Ok(creds) => { + tracing::info!( + channel = %channel_name, + account_id = %creds.account_id, + user_id = %creds.user_id, + "WeChat login succeeded" + ); + } + Err(error) => { + running.store(false, Ordering::SeqCst); + tracing::error!(channel = %channel_name, error = %error, "WeChat login failed"); + return; + } + } + + if let Err(error) = bot.run().await { + tracing::error!(channel = %channel_name, error = %error, "WeChat channel stopped with error"); + } + + running.store(false, Ordering::SeqCst); + }); + + *self.task.write().await = Some(handle); + tracing::info!(channel = %self.name, "WeChat channel started"); + Ok(()) + } + + async fn stop(&self) -> Result<(), ChannelError> { + self.running.store(false, Ordering::SeqCst); + self.bot.stop().await; + + if let Some(handle) = self.task.write().await.take() { + handle.abort(); + } + + tracing::info!(channel = %self.name, "WeChat channel stopped"); + Ok(()) + } + + async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { + let text = msg.content.trim().to_string(); + let mut text_sent = false; + + if !text.is_empty() { + self.bot.send(&msg.chat_id, &text).await.map_err(|error| { + ChannelError::SendError(format!("WeChat text send failed: {}", error)) + })?; + text_sent = true; + } + + for (index, media) in msg.media.iter().enumerate() { + let caption = if !text.is_empty() && !text_sent && index == 0 { + Some(text.clone()) + } else { + None + }; + let content = Self::media_to_send_content(media, caption)?; + self.bot.send_media(&msg.chat_id, content).await.map_err(|error| { + ChannelError::SendError(format!("WeChat media send failed: {}", error)) + })?; + } + + if text.is_empty() && msg.media.is_empty() { + return Ok(()); + } + + Ok(()) + } + + fn is_allowed(&self, sender_id: &str) -> bool { + self.sender_allowed(sender_id) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + #[test] + fn media_to_send_content_maps_image() { + let file = NamedTempFile::new().unwrap(); + std::fs::write(file.path(), b"demo-image").unwrap(); + let image_path = file.path().with_extension("png"); + std::fs::rename(file.path(), &image_path).unwrap(); + + let media = MediaItem::new(image_path.to_string_lossy().to_string(), "image"); + let content = WechatChannel::media_to_send_content(&media, None).unwrap(); + + assert!(matches!(content, SendContent::Image { .. })); + } + + #[test] + fn media_to_send_content_maps_generic_file() { + let file = NamedTempFile::new().unwrap(); + std::fs::write(file.path(), b"hello").unwrap(); + let doc_path = file.path().with_extension("md"); + std::fs::rename(file.path(), &doc_path).unwrap(); + + let media = MediaItem::new(doc_path.to_string_lossy().to_string(), "file"); + let content = WechatChannel::media_to_send_content(&media, Some("note".to_string())).unwrap(); + + match content { + SendContent::File { + file_name, + caption, + .. + } => { + assert_eq!(file_name, doc_path.file_name().unwrap().to_string_lossy()); + assert_eq!(caption.as_deref(), Some("note")); + } + _ => panic!("expected file send content"), + } + } +} \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs index 488f2b7..4bbd416 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -22,7 +22,7 @@ pub struct Config { #[serde(default)] pub client: ClientConfig, #[serde(default)] - pub channels: HashMap, + pub channels: HashMap, #[serde(default)] pub skills: SkillsConfig, } @@ -96,6 +96,54 @@ impl Default for SkillsConfig { } } +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ChannelConfig { + Tagged(TaggedChannelConfig), + LegacyFeishu(FeishuChannelConfig), +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum TaggedChannelConfig { + Feishu(FeishuChannelConfig), + Wechat(WechatChannelConfig), +} + +impl ChannelConfig { + pub fn kind(&self) -> &'static str { + match self { + Self::Tagged(TaggedChannelConfig::Feishu(_)) | Self::LegacyFeishu(_) => "feishu", + Self::Tagged(TaggedChannelConfig::Wechat(_)) => "wechat", + } + } + + pub fn enabled(&self) -> bool { + match self { + Self::Tagged(TaggedChannelConfig::Feishu(config)) | Self::LegacyFeishu(config) => { + config.enabled + } + Self::Tagged(TaggedChannelConfig::Wechat(config)) => config.enabled, + } + } + + pub fn as_feishu(&self) -> Option<&FeishuChannelConfig> { + match self { + Self::Tagged(TaggedChannelConfig::Feishu(config)) | Self::LegacyFeishu(config) => { + Some(config) + } + Self::Tagged(TaggedChannelConfig::Wechat(_)) => None, + } + } + + pub fn as_wechat(&self) -> Option<&WechatChannelConfig> { + match self { + Self::Tagged(TaggedChannelConfig::Wechat(config)) => Some(config), + Self::Tagged(TaggedChannelConfig::Feishu(_)) | Self::LegacyFeishu(_) => None, + } + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FeishuChannelConfig { #[serde(default)] @@ -117,6 +165,22 @@ pub struct FeishuChannelConfig { pub reply_context_max_chars: usize, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WechatChannelConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default = "default_allow_from")] + pub allow_from: Vec, + #[serde(default)] + pub agent: String, + #[serde(default = "default_wechat_base_url")] + pub base_url: String, + #[serde(default = "default_wechat_cred_path")] + pub cred_path: String, + #[serde(default)] + pub force_login: bool, +} + fn default_allow_from() -> Vec { vec!["*".to_string()] } @@ -128,6 +192,17 @@ fn default_media_dir() -> String { .to_string() } +fn default_wechat_base_url() -> String { + "https://ilinkai.weixin.qq.com".to_string() +} + +fn default_wechat_cred_path() -> String { + let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from(".")); + home.join(".picobot/wechat/credentials.json") + .to_string_lossy() + .to_string() +} + fn default_reaction_emoji() -> String { "Typing".to_string() } @@ -1171,11 +1246,105 @@ mod tests { .unwrap(); let config = Config::load(file.path().to_str().unwrap()).unwrap(); - let feishu = &config.channels["feishu"]; + let feishu = config.channels["feishu"].as_feishu().unwrap(); assert_eq!(feishu.max_message_chars, 20_000); assert_eq!(feishu.reply_context_max_chars, 20_000); } + #[test] + fn test_tagged_feishu_channel_config_loads() { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus" + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + }, + "channels": { + "primary": { + "type": "feishu", + "enabled": true, + "app_id": "app-id", + "app_secret": "secret" + } + } +}"#, + ) + .unwrap(); + + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let feishu = config.channels["primary"].as_feishu().unwrap(); + + assert_eq!(config.channels["primary"].kind(), "feishu"); + assert!(config.channels["primary"].enabled()); + assert_eq!(feishu.app_id, "app-id"); + } + + #[test] + fn test_tagged_wechat_channel_config_loads() { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus" + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + }, + "channels": { + "wechat_main": { + "type": "wechat", + "enabled": true, + "base_url": "https://ilinkai.weixin.qq.com", + "cred_path": "/tmp/wechat-creds.json", + "force_login": true, + "allow_from": ["wxid_1"] + } + } +}"#, + ) + .unwrap(); + + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let wechat = config.channels["wechat_main"].as_wechat().unwrap(); + + assert_eq!(config.channels["wechat_main"].kind(), "wechat"); + assert!(config.channels["wechat_main"].enabled()); + assert_eq!(wechat.cred_path, "/tmp/wechat-creds.json"); + assert!(wechat.force_login); + assert_eq!(wechat.allow_from, vec!["wxid_1"]); + } + #[test] fn test_feishu_channel_config_loads_custom_truncation_limits() { let file = tempfile::NamedTempFile::new().unwrap(); @@ -1215,7 +1384,7 @@ mod tests { .unwrap(); let config = Config::load(file.path().to_str().unwrap()).unwrap(); - let feishu = &config.channels["feishu"]; + let feishu = config.channels["feishu"].as_feishu().unwrap(); assert_eq!(feishu.max_message_chars, 3456); assert_eq!(feishu.reply_context_max_chars, 4567); } diff --git a/src/gateway/command.rs b/src/gateway/command.rs index 3763c58..49e226b 100644 --- a/src/gateway/command.rs +++ b/src/gateway/command.rs @@ -40,6 +40,8 @@ mod tests { use std::sync::Arc; use tokio::sync::mpsc; + const TEST_CHANNEL: &str = "test-channel"; + fn test_provider_config() -> LLMProviderConfig { LLMProviderConfig { provider_type: "openai".to_string(), @@ -80,7 +82,7 @@ mod tests { let skills = Arc::new(SkillRuntime::default()); let tools = Arc::new(ToolRegistry::new()); let mut session = Session::new( - "feishu".to_string(), + TEST_CHANNEL.to_string(), test_provider_config(), user_tx, tools, @@ -130,7 +132,7 @@ mod tests { let skills = Arc::new(SkillRuntime::default()); let tools = Arc::new(ToolRegistry::new()); let mut session = Session::new( - "feishu".to_string(), + TEST_CHANNEL.to_string(), test_provider_config(), user_tx, tools, diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 7a1babd..12663d0 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -542,7 +542,7 @@ mod tests { .build(), ); let mut session = Session::new( - "feishu".to_string(), + "test-channel".to_string(), test_provider_config(), user_tx, tools, @@ -587,7 +587,7 @@ mod tests { .build(), ); let mut session = Session::new( - "feishu".to_string(), + "test-channel".to_string(), test_provider_config(), user_tx, tools, @@ -791,7 +791,7 @@ mod tests { .unwrap(); let outbound = session_manager - .handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None) + .handle_message("test-channel", "user-1", "chat-1", "hello", Vec::new(), None) .await .unwrap(); @@ -840,7 +840,7 @@ mod tests { let planner_outbound = session_manager .run_scheduled_agent_task( - "feishu", + "test-channel", "chat-planner", "请规划今天工作", ScheduledAgentTaskOptions { @@ -856,7 +856,7 @@ mod tests { let default_outbound = session_manager .run_scheduled_agent_task( - "feishu", + "test-channel", "chat-default", "请规划今天工作", ScheduledAgentTaskOptions { @@ -904,7 +904,7 @@ mod tests { session_manager .run_scheduled_agent_task( - "feishu", + "test-channel", "chat-guard", "每小时执行以下流程:检查邮箱并同步待办", ScheduledAgentTaskOptions { @@ -916,7 +916,7 @@ mod tests { .await .unwrap(); - let session = session_manager.get("feishu").await.unwrap(); + let session = session_manager.get("test-channel").await.unwrap(); let session_guard = session.lock().await; let persisted_messages = session_guard .store() @@ -1477,7 +1477,13 @@ mod tests { async fn test_bus_tool_call_emitter_hides_completed_tool_results_when_disabled() { let bus = MessageBus::new(4); let emitter = - BusToolCallEmitter::new(bus.clone(), "feishu", "chat-1", HashMap::new(), false); + BusToolCallEmitter::new( + bus.clone(), + "test-channel", + "chat-1", + HashMap::new(), + false, + ); emitter .handle(ChatMessage::tool("call-1", "calculator", "2")) @@ -1508,7 +1514,7 @@ mod tests { .build(), ); let mut session = Session::new( - "feishu".to_string(), + "test-channel".to_string(), test_provider_config(), user_tx, tools, @@ -1546,7 +1552,7 @@ mod tests { .build(), ); let mut session = Session::new( - "feishu".to_string(), + "test-channel".to_string(), test_provider_config(), user_tx, tools, @@ -1612,7 +1618,7 @@ mod tests { .build(), ); let mut session = Session::new( - "feishu".to_string(), + "test-channel".to_string(), test_provider_config(), user_tx, tools, diff --git a/src/gateway/session_message_sender.rs b/src/gateway/session_message_sender.rs index 029261d..fe55b79 100644 --- a/src/gateway/session_message_sender.rs +++ b/src/gateway/session_message_sender.rs @@ -81,12 +81,14 @@ mod tests { use super::*; use crate::bus::MediaItem; + const TEST_CHANNEL: &str = "test-channel"; + #[tokio::test] async fn bus_sender_publishes_text_then_attachment() { let bus = MessageBus::new(8); let sender = BusSessionMessageSender::new(bus.clone()); let context = ToolContext { - channel_name: Some("feishu".to_string()), + channel_name: Some(TEST_CHANNEL.to_string()), chat_id: Some("chat-1".to_string()), ..ToolContext::default() }; diff --git a/src/lib.rs b/src/lib.rs index 9e5e8c6..93113cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod agent; +pub mod bootstrap; pub mod bus; pub mod channels; pub mod cli; diff --git a/src/main.rs b/src/main.rs index 20df5e9..a40f2da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,6 +23,8 @@ enum Command { #[tokio::main] async fn main() -> Result<(), Box> { + picobot::bootstrap::initialize_process_runtime(); + let mut cmd = Command::command(); // If no arguments, print help diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index ac4eabd..271f0f4 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -812,7 +812,7 @@ mod agent_task_tests { interval_secs: 0, startup_delay_secs: 0, target: serde_json::json!({ - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo" }), payload: serde_json::json!({ @@ -859,7 +859,7 @@ mod agent_task_tests { interval_secs: 0, startup_delay_secs: 0, target: serde_json::json!({ - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo", "session_chat_id": "scheduler/agent.daily_summary.background" }), @@ -905,7 +905,7 @@ mod agent_task_tests { startup_delay_secs: 0, }, target: SchedulerJobTarget { - channel: Some("feishu".to_string()), + channel: Some("test-channel".to_string()), chat_id: Some("oc_demo".to_string()), session_chat_id: None, reply_to: None, @@ -965,7 +965,7 @@ mod agent_task_tests { startup_delay_secs: 0, }, target: SchedulerJobTarget { - channel: Some("feishu".to_string()), + channel: Some("test-channel".to_string()), chat_id: Some("oc_demo".to_string()), session_chat_id: None, reply_to: None, @@ -1101,7 +1101,7 @@ mod tests { interval_secs: 0, startup_delay_secs: 0, target: serde_json::json!({ - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo" }), payload: serde_json::json!({"content": "hello"}), @@ -1151,7 +1151,7 @@ mod tests { interval_secs: 60, startup_delay_secs: 0, target: serde_json::json!({ - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo" }), payload: serde_json::json!({ @@ -1271,7 +1271,7 @@ mod tests { startup_delay_secs: 0, }, target: SchedulerJobTarget { - channel: Some("feishu".to_string()), + channel: Some("test-channel".to_string()), chat_id: Some("oc_demo".to_string()), session_chat_id: Some("scheduler/agent.daily_summary.background".to_string()), reply_to: None, @@ -1300,7 +1300,7 @@ mod tests { ) .await .unwrap(); - assert_eq!(outbound.channel, "feishu"); + assert_eq!(outbound.channel, "test-channel"); assert_eq!(outbound.chat_id, "oc_demo"); assert!(outbound.content.contains("定时任务执行失败")); assert!(outbound.content.contains("agent.daily_summary.background")); diff --git a/src/storage/mod.rs b/src/storage/mod.rs index ac7fb3b..3f3aae5 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1614,10 +1614,12 @@ mod tests { use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT; use crate::domain::messages::ToolCall; + const TEST_CHANNEL: &str = "test-channel"; + #[test] fn test_persistent_session_id_for_cli_and_channel() { assert_eq!(persistent_session_id("cli", "abc"), "abc"); - assert_eq!(persistent_session_id("feishu", "abc"), "feishu:abc"); + assert_eq!(persistent_session_id(TEST_CHANNEL, "abc"), "test-channel:abc"); } #[test] @@ -1682,12 +1684,12 @@ mod tests { fn test_ensure_channel_session_is_stable() { let store = SessionStore::in_memory().unwrap(); - let first = store.ensure_channel_session("feishu", "chat-1").unwrap(); - let second = store.ensure_channel_session("feishu", "chat-1").unwrap(); + let first = store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap(); + let second = store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap(); assert_eq!(first.id, second.id); assert_eq!(first.chat_id, "chat-1"); - assert_eq!(second.channel_name, "feishu"); + assert_eq!(second.channel_name, TEST_CHANNEL); } #[test] @@ -2040,27 +2042,27 @@ mod tests { let saved = store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "profile".to_string(), memory_key: "language".to_string(), content: "Rust".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-1".to_string()), + source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), source_message_id: Some("msg-1".to_string()), source_message_seq: Some(7), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-1".to_string()), }) .unwrap(); assert_eq!(saved.content, "Rust"); assert_eq!(saved.source_type, "message"); - assert_eq!(saved.source_session_id.as_deref(), Some("feishu:chat-1")); + assert_eq!(saved.source_session_id.as_deref(), Some("test-channel:chat-1")); assert_eq!(saved.source_message_id.as_deref(), Some("msg-1")); assert_eq!(saved.source_message_seq, Some(7)); let fetched = store - .get_memory("user", "feishu:user-1", "profile", "language") + .get_memory("user", "test-channel:user-1", "profile", "language") .unwrap() .unwrap(); assert_eq!(fetched.id, saved.id); @@ -2074,21 +2076,21 @@ mod tests { store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "preferences".to_string(), memory_key: "editor".to_string(), content: "Prefers rust-analyzer and cargo test output".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-2".to_string()), + source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)), source_message_id: Some("msg-2".to_string()), source_message_seq: Some(3), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-2".to_string()), }) .unwrap(); let hits = store - .search_memories("user", "feishu:user-1", "rust-analyzer", None, 10) + .search_memories("user", "test-channel:user-1", "rust-analyzer", None, 10) .unwrap(); assert_eq!(hits.len(), 1); assert_eq!(hits[0].memory_key, "editor"); @@ -2096,36 +2098,36 @@ mod tests { store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "preferences".to_string(), memory_key: "editor".to_string(), content: "Prefers clippy diagnostics".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-3".to_string()), + source_session_id: Some(format!("{}:chat-3", TEST_CHANNEL)), source_message_id: Some("msg-3".to_string()), source_message_seq: Some(4), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-3".to_string()), }) .unwrap(); let old_hits = store - .search_memories("user", "feishu:user-1", "rust-analyzer", None, 10) + .search_memories("user", "test-channel:user-1", "rust-analyzer", None, 10) .unwrap(); assert!(old_hits.is_empty()); let new_hits = store - .search_memories("user", "feishu:user-1", "clippy", None, 10) + .search_memories("user", "test-channel:user-1", "clippy", None, 10) .unwrap(); assert_eq!(new_hits.len(), 1); let deleted = store - .delete_memory("user", "feishu:user-1", "preferences", "editor") + .delete_memory("user", "test-channel:user-1", "preferences", "editor") .unwrap(); assert!(deleted); let hits_after_delete = store - .search_memories("user", "feishu:user-1", "clippy", None, 10) + .search_memories("user", "test-channel:user-1", "clippy", None, 10) .unwrap(); assert!(hits_after_delete.is_empty()); } @@ -2137,21 +2139,21 @@ mod tests { store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "preferences".to_string(), memory_key: "email_folder_preference".to_string(), content: "用户提到邮件时默认查看代收邮箱。".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-8".to_string()), + source_session_id: Some(format!("{}:chat-8", TEST_CHANNEL)), source_message_id: Some("msg-8".to_string()), source_message_seq: Some(8), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-8".to_string()), }) .unwrap(); let hits = store - .search_memories("user", "feishu:user-1", "email_folder_preference", None, 10) + .search_memories("user", "test-channel:user-1", "email_folder_preference", None, 10) .unwrap(); assert_eq!(hits.len(), 1); @@ -2165,15 +2167,15 @@ mod tests { store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "preferences".to_string(), memory_key: "editor".to_string(), content: "Prefers rust-analyzer and cargo test output".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-2".to_string()), + source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)), source_message_id: Some("msg-2".to_string()), source_message_seq: Some(3), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-2".to_string()), }) .unwrap(); @@ -2181,15 +2183,15 @@ mod tests { store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "tasks".to_string(), memory_key: "quality".to_string(), content: "Tracks clippy warnings before release".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-3".to_string()), + source_session_id: Some(format!("{}:chat-3", TEST_CHANNEL)), source_message_id: Some("msg-3".to_string()), source_message_seq: Some(4), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-3".to_string()), }) .unwrap(); @@ -2197,7 +2199,7 @@ mod tests { let hits = store .search_memories_any( "user", - "feishu:user-1", + "test-channel:user-1", &["rust-analyzer".to_string(), "clippy".to_string()], None, 10, @@ -2216,45 +2218,45 @@ mod tests { store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-2".to_string(), + scope_key: format!("{}:user-2", TEST_CHANNEL), namespace: "preferences".to_string(), memory_key: "style".to_string(), content: "偏好简洁表达".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-2".to_string()), + source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)), source_message_id: Some("msg-2".to_string()), source_message_seq: Some(2), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-2".to_string()), }) .unwrap(); store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "profile".to_string(), memory_key: "work".to_string(), content: "用户在做AI产品".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-1".to_string()), + source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), source_message_id: Some("msg-1".to_string()), source_message_seq: Some(1), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-1".to_string()), }) .unwrap(); store .put_memory(&MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "patterns".to_string(), memory_key: "workflow".to_string(), content: "习惯先问方案再要代码".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-1".to_string()), + source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), source_message_id: Some("msg-3".to_string()), source_message_seq: Some(3), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-1".to_string()), }) .unwrap(); @@ -2262,17 +2264,17 @@ mod tests { let scope_keys = store.list_memory_scope_keys("user").unwrap(); assert_eq!( scope_keys, - vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()] + vec!["test-channel:user-1".to_string(), "test-channel:user-2".to_string()] ); let full_scope = store - .list_memories_for_scope("user", "feishu:user-1") + .list_memories_for_scope("user", "test-channel:user-1") .unwrap(); assert_eq!(full_scope.len(), 2); assert!( full_scope .iter() - .all(|memory| memory.scope_key == "feishu:user-1") + .all(|memory| memory.scope_key == "test-channel:user-1") ); assert!(full_scope.iter().any(|memory| memory.memory_key == "work")); assert!( @@ -2298,7 +2300,7 @@ mod tests { interval_secs: 300, startup_delay_secs: 10, target: serde_json::json!({ - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo", }), payload: serde_json::json!({ diff --git a/src/tools/memory_manage.rs b/src/tools/memory_manage.rs index 07534ab..0f8f2c1 100644 --- a/src/tools/memory_manage.rs +++ b/src/tools/memory_manage.rs @@ -221,15 +221,17 @@ mod tests { use super::*; use crate::storage::SessionStore; + const TEST_CHANNEL: &str = "test-channel"; + #[tokio::test] async fn test_memory_manage_put_returns_saved_memory() { let store = Arc::new(SessionStore::in_memory().unwrap()); let tool = MemoryManageTool::new(store); let context = ToolContext { - channel_name: Some("feishu".to_string()), + channel_name: Some(TEST_CHANNEL.to_string()), sender_id: Some("user-1".to_string()), chat_id: Some("chat-1".to_string()), - session_id: Some("feishu:chat-1".to_string()), + session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), message_id: Some("msg-1".to_string()), message_seq: Some(1), }; @@ -275,7 +277,7 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let tool = MemoryManageTool::new(store); let context = ToolContext { - channel_name: Some("feishu".to_string()), + channel_name: Some(TEST_CHANNEL.to_string()), sender_id: Some("user-1".to_string()), ..ToolContext::default() }; diff --git a/src/tools/memory_search.rs b/src/tools/memory_search.rs index 699e9d0..64f5078 100644 --- a/src/tools/memory_search.rs +++ b/src/tools/memory_search.rs @@ -207,31 +207,33 @@ mod tests { use super::*; use crate::storage::SessionStore; + const TEST_CHANNEL: &str = "test-channel"; + #[tokio::test] async fn test_memory_search_search_and_get() { let store = Arc::new(SessionStore::in_memory().unwrap()); store .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), - scope_key: "feishu:user-1".to_string(), + scope_key: format!("{}:user-1", TEST_CHANNEL), namespace: "preferences".to_string(), memory_key: "language".to_string(), content: "User prefers Chinese responses".to_string(), source_type: "message".to_string(), - source_session_id: Some("feishu:chat-1".to_string()), + source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), source_message_id: Some("msg-1".to_string()), source_message_seq: Some(1), - source_channel_name: Some("feishu".to_string()), + source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-1".to_string()), }) .unwrap(); let tool = MemorySearchTool::new(store); let context = ToolContext { - channel_name: Some("feishu".to_string()), + channel_name: Some(TEST_CHANNEL.to_string()), sender_id: Some("user-1".to_string()), chat_id: Some("chat-1".to_string()), - session_id: Some("feishu:chat-1".to_string()), + session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), message_id: Some("msg-2".to_string()), message_seq: Some(2), }; @@ -285,7 +287,7 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let tool = MemorySearchTool::new(store); let context = ToolContext { - channel_name: Some("feishu".to_string()), + channel_name: Some(TEST_CHANNEL.to_string()), sender_id: Some("user-1".to_string()), ..ToolContext::default() }; diff --git a/src/tools/scheduler_manage.rs b/src/tools/scheduler_manage.rs index 8d91e0a..5022228 100644 --- a/src/tools/scheduler_manage.rs +++ b/src/tools/scheduler_manage.rs @@ -435,6 +435,8 @@ mod tests { use super::*; use crate::storage::SessionStore; + const TEST_CHANNEL: &str = "test-channel"; + #[tokio::test] async fn test_scheduler_manage_put_and_get() { let store = Arc::new(SessionStore::in_memory().unwrap()); @@ -450,7 +452,7 @@ mod tests { "seconds": 60 }, "target": { - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo" }, "payload": { @@ -488,7 +490,7 @@ mod tests { "expression": "0 9 * * *" }, "target": { - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo" }, "payload": { @@ -518,7 +520,7 @@ mod tests { "expression": "0 9 * * *" }, "target": { - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo", "session_chat_id": "scheduler/agent.daily_summary.background" }, @@ -576,10 +578,10 @@ mod tests { let put_result = tool .execute_with_context( &crate::tools::ToolContext { - channel_name: Some("feishu".to_string()), + channel_name: Some(TEST_CHANNEL.to_string()), sender_id: Some("user-1".to_string()), chat_id: Some("oc_demo".to_string()), - session_id: Some("feishu:oc_demo".to_string()), + session_id: Some(format!("{}:oc_demo", TEST_CHANNEL)), message_id: Some("msg-1".to_string()), message_seq: Some(1), }, @@ -602,7 +604,7 @@ mod tests { assert!(put_result.success); let saved = store.get_scheduler_job("work_reminder").unwrap().unwrap(); - assert_eq!(saved.target["channel"], "feishu"); + assert_eq!(saved.target["channel"], "test-channel"); assert_eq!(saved.target["chat_id"], "oc_demo"); } @@ -621,7 +623,7 @@ mod tests { "expression": "0 9 * * *" }, "target": { - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo" }, "payload": { @@ -653,7 +655,7 @@ mod tests { "expression": "0 9 * * *" }, "target": { - "channel": "feishu", + "channel": "test-channel", "chat_id": "oc_demo" }, "payload": { diff --git a/src/tools/session_send.rs b/src/tools/session_send.rs index 8ce207e..d7ce9e8 100644 --- a/src/tools/session_send.rs +++ b/src/tools/session_send.rs @@ -240,6 +240,8 @@ mod tests { use super::*; use tempfile::NamedTempFile; + const TEST_CHANNEL: &str = "test-channel"; + struct MockSender { outcome: SessionSendOutcome, } @@ -257,7 +259,7 @@ mod tests { fn context() -> ToolContext { ToolContext { - channel_name: Some("feishu".to_string()), + channel_name: Some(TEST_CHANNEL.to_string()), chat_id: Some("chat-1".to_string()), ..ToolContext::default() } diff --git a/src/tools/skill_activate.rs b/src/tools/skill_activate.rs index 5de84aa..46b95d5 100644 --- a/src/tools/skill_activate.rs +++ b/src/tools/skill_activate.rs @@ -124,14 +124,16 @@ mod tests { use super::*; use crate::storage::SessionStore; + const TEST_CHANNEL: &str = "test-channel"; + #[tokio::test] async fn test_skill_activate_records_failed_activation_event() { let skills = Arc::new(SkillRuntime::default()); let store = Arc::new(SessionStore::in_memory().unwrap()); - store.ensure_channel_session("feishu", "chat-1").unwrap(); + store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap(); let tool = SkillActivateTool::new(skills, store.clone()); let context = ToolContext { - session_id: Some("feishu:chat-1".to_string()), + session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), ..ToolContext::default() }; @@ -143,7 +145,9 @@ mod tests { assert!(!result.success); assert!(result.error.unwrap().contains("not found")); - let events = store.list_skill_events(Some("feishu:chat-1")).unwrap(); + let events = store + .list_skill_events(Some(&format!("{}:chat-1", TEST_CHANNEL))) + .unwrap(); assert_eq!(events.len(), 1); assert_eq!(events[0].event_type, "activation_failed"); assert_eq!(events[0].skill_name.as_deref(), Some("demo")); diff --git a/vendor/wechatbot/.cargo-ok b/vendor/wechatbot/.cargo-ok new file mode 100644 index 0000000..5f8b795 --- /dev/null +++ b/vendor/wechatbot/.cargo-ok @@ -0,0 +1 @@ +{"v":1} \ No newline at end of file diff --git a/vendor/wechatbot/.cargo_vcs_info.json b/vendor/wechatbot/.cargo_vcs_info.json new file mode 100644 index 0000000..6631308 --- /dev/null +++ b/vendor/wechatbot/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "70bc64cc8035de4677bbe01265570e7f157bb31d" + }, + "path_in_vcs": "rust" +} \ No newline at end of file diff --git a/vendor/wechatbot/Cargo.toml b/vendor/wechatbot/Cargo.toml new file mode 100644 index 0000000..2a11f09 --- /dev/null +++ b/vendor/wechatbot/Cargo.toml @@ -0,0 +1,91 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +name = "wechatbot" +version = "0.3.2" +build = false +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "WeChat iLink Bot SDK for Rust" +homepage = "https://github.com/corespeed-io/wechatbot" +documentation = "https://docs.rs/wechatbot" +readme = "README.md" +license = "MIT" +repository = "https://github.com/corespeed-io/wechatbot" + +[lib] +name = "wechatbot" +path = "src/lib.rs" + +[[example]] +name = "echo_bot" +path = "examples/echo_bot.rs" + +[dependencies.aes] +version = "0.8" + +[dependencies.base64] +version = "0.22" + +[dependencies.dirs-next] +version = "2" + +[dependencies.hex] +version = "0.4" + +[dependencies.md-5] +version = "0.10" + +[dependencies.rand] +version = "0.10" + +[dependencies.reqwest] +version = "0.12" +default-features = false +features = ["json", "rustls-tls"] + +[dependencies.serde] +version = "1" +features = ["derive"] + +[dependencies.serde_json] +version = "1" + +[dependencies.serde_repr] +version = "0.1" + +[dependencies.thiserror] +version = "2" + +[dependencies.tokio] +version = "1" +features = ["full"] + +[dependencies.tracing] +version = "0.1" + +[dependencies.urlencoding] +version = "2" + +[dependencies.uuid] +version = "1" +features = ["v4"] + +[dev-dependencies.tokio-test] +version = "0.4" + +[dev-dependencies.tracing-subscriber] +version = "0.3" diff --git a/vendor/wechatbot/Cargo.toml.orig b/vendor/wechatbot/Cargo.toml.orig new file mode 100644 index 0000000..27cb736 --- /dev/null +++ b/vendor/wechatbot/Cargo.toml.orig @@ -0,0 +1,35 @@ +[package] +name = "wechatbot" +version = "0.3.2" +edition = "2021" +description = "WeChat iLink Bot SDK for Rust" +license = "MIT" +readme = "README.md" +repository = "https://github.com/corespeed-io/wechatbot" +homepage = "https://github.com/corespeed-io/wechatbot" +documentation = "https://docs.rs/wechatbot" + +[dependencies] +aes = "0.8" +base64 = "0.22" +hex = "0.4" +md-5 = "0.10" +rand = "0.10" +reqwest = { version = "0.12", features = ["json"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +serde_repr = "0.1" +tokio = { version = "1", features = ["full"] } +uuid = { version = "1", features = ["v4"] } +thiserror = "2" +tracing = "0.1" +urlencoding = "2" +dirs-next = "2" + +[dev-dependencies] +tokio-test = "0.4" +tracing-subscriber = "0.3" + +[[example]] +name = "echo_bot" +path = "examples/echo_bot.rs" diff --git a/vendor/wechatbot/README.md b/vendor/wechatbot/README.md new file mode 100644 index 0000000..477c99f --- /dev/null +++ b/vendor/wechatbot/README.md @@ -0,0 +1,226 @@ +# wechatbot — Rust SDK + +WeChat iLink Bot SDK for Rust — async, type-safe, zero-copy where possible. + +## Install + +```toml +[dependencies] +wechatbot = "0.1" +tokio = { version = "1", features = ["full"] } +``` + +Requires Rust 2021 edition. Built on `tokio` + `reqwest`. + +## Quick Start + +```rust +use wechatbot::{WeChatBot, BotOptions}; + +#[tokio::main] +async fn main() { + let bot = WeChatBot::new(BotOptions::default()); + let creds = bot.login(false).await.unwrap(); + println!("Logged in: {}", creds.account_id); + + bot.on_message(Box::new(|msg| { + println!("{}: {}", msg.user_id, msg.text); + })).await; + + bot.run().await.unwrap(); +} +``` + +## Architecture + +``` +src/ +├── lib.rs ← Public re-exports +├── types.rs ← All protocol & public types (serde) +├── error.rs ← Error hierarchy (thiserror) +├── protocol.rs ← Raw iLink API calls (reqwest) +├── crypto.rs ← AES-128-ECB encrypt/decrypt + key encoding +└── bot.rs ← WeChatBot client (login, run, reply, send) +``` + +## API Reference + +### Creating a Bot + +```rust +use wechatbot::{WeChatBot, BotOptions}; + +let bot = WeChatBot::new(BotOptions { + base_url: None, // default: ilinkai.weixin.qq.com + cred_path: None, // default: ~/.wechatbot/credentials.json + on_qr_url: Some(Box::new(|url| { + println!("Scan: {}", url); + })), + on_error: Some(Box::new(|err| { + eprintln!("Error: {}", err); + })), +}); +``` + +### Authentication + +```rust +// Login (skips QR if credentials exist) +let creds = bot.login(false).await?; + +// Force re-login +let creds = bot.login(true).await?; + +// Credentials struct +println!("Token: {}", creds.token); +println!("Base URL: {}", creds.base_url); +println!("Account: {}", creds.account_id); +println!("User: {}", creds.user_id); +``` + +### Message Handling + +```rust +bot.on_message(Box::new(|msg| { + match msg.content_type { + ContentType::Text => println!("Text: {}", msg.text), + ContentType::Image => { + for img in &msg.images { + println!("Image URL: {:?}", img.url); + } + } + ContentType::Voice => { + for voice in &msg.voices { + println!("Voice: {:?} ({}ms)", voice.text, voice.duration_ms.unwrap_or(0)); + } + } + ContentType::File => { + for file in &msg.files { + println!("File: {:?}", file.file_name); + } + } + ContentType::Video => println!("Video received"), + } + + if let Some(ref quoted) = msg.quoted { + println!("Quoted: {:?}", quoted.title); + } +})).await; +``` + +### Sending Messages + +```rust +// Reply to incoming message +bot.reply(&msg, "Echo: hello").await?; + +// Send to user (needs prior context_token) +bot.send(user_id, "Hello").await?; + +// Typing indicator +bot.send_typing(user_id).await?; +``` + +### Media Operations + +```rust +// Reply with media content +bot.reply_media(&msg, SendContent::Image(png_bytes)).await?; +bot.reply_media(&msg, SendContent::File { data, file_name: "report.pdf".into() }).await?; +bot.reply_media(&msg, SendContent::Video(mp4_bytes)).await?; +``` + +```rust +// Download media from incoming message (priority: image > file > video > voice) +if let Some(media) = bot.download(&msg).await? { + println!("Type: {}, Size: {} bytes", media.media_type, media.data.len()); + if let Some(name) = &media.file_name { + println!("Filename: {}", name); + } +} + +// Download a raw CDN reference directly +let raw = bot.download_raw(&msg.images[0].media.as_ref().unwrap(), None).await?; +``` + +```rust +// Upload to CDN without sending a message +let result = bot.upload(&file_bytes, user_id, 3).await?; +``` + +### Lifecycle + +```rust +// Start polling (blocks) +bot.run().await?; + +// Stop +bot.stop().await; +``` + +## Error Handling + +```rust +use wechatbot::WeChatBotError; + +match result { + Err(WeChatBotError::Api { message, errcode, .. }) => { + if errcode == -14 { + // session expired — handled automatically + } + } + Err(WeChatBotError::NoContext(user_id)) => { + // no context_token for this user yet + } + Err(WeChatBotError::Transport(e)) => { + // network error + } + _ => {} +} +``` + +## AES-128-ECB Crypto + +```rust +use wechatbot::{generate_aes_key, encrypt_aes_ecb, decrypt_aes_ecb, decode_aes_key}; + +// Generate key +let key = generate_aes_key(); + +// Encrypt/decrypt +let ciphertext = encrypt_aes_ecb(b"Hello", &key); +let plaintext = decrypt_aes_ecb(&ciphertext, &key)?; + +// Decode protocol key (handles all 3 formats) +let key = decode_aes_key("ABEiM0RVZneImaq7zN3u/w==")?; +let key = decode_aes_key("00112233445566778899aabbccddeeff")?; +``` + +## Types + +All protocol types derive `Serialize` + `Deserialize` + `Clone` + `Debug`: + +```rust +// Wire-level (protocol) +WireMessage, WireMessageItem, CDNMedia, TextItem, ImageItem, ... + +// Parsed (user-friendly) +IncomingMessage, ImageContent, VoiceContent, FileContent, VideoContent + +// Auth +Credentials + +// Enums +MessageType, MessageState, MessageItemType, ContentType, MediaType +``` + +## Testing + +```bash +cd rust +cargo test +``` + +## License + +MIT diff --git a/vendor/wechatbot/examples/echo_bot.rs b/vendor/wechatbot/examples/echo_bot.rs new file mode 100644 index 0000000..c33d2ca --- /dev/null +++ b/vendor/wechatbot/examples/echo_bot.rs @@ -0,0 +1,43 @@ +use wechatbot::{BotOptions, WeChatBot}; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + let bot = WeChatBot::new(BotOptions { + on_qr_url: Some(Box::new(|url| { + println!("\nScan this URL in WeChat:\n{}\n", url); + })), + on_error: Some(Box::new(|err| { + eprintln!("Error: {}", err); + })), + ..Default::default() + }); + + let creds = bot.login(false).await.expect("login failed"); + println!("Logged in: {} ({})", creds.account_id, creds.user_id); + + bot.on_message(Box::new(|msg| { + println!("[{}] {}: {}", msg.content_type_str(), msg.user_id, msg.text); + })) + .await; + + println!("Listening for messages (Ctrl+C to stop)"); + bot.run().await.expect("run failed"); +} + +trait ContentTypeStr { + fn content_type_str(&self) -> &str; +} + +impl ContentTypeStr for wechatbot::IncomingMessage { + fn content_type_str(&self) -> &str { + match self.content_type { + wechatbot::ContentType::Text => "text", + wechatbot::ContentType::Image => "image", + wechatbot::ContentType::Voice => "voice", + wechatbot::ContentType::File => "file", + wechatbot::ContentType::Video => "video", + } + } +} diff --git a/vendor/wechatbot/src/bot.rs b/vendor/wechatbot/src/bot.rs new file mode 100644 index 0000000..28d949e --- /dev/null +++ b/vendor/wechatbot/src/bot.rs @@ -0,0 +1,741 @@ +//! Main WeChatBot client. + +use std::collections::HashMap; +use std::path::Path; +use std::sync::Arc; +use tokio::sync::{Mutex, RwLock}; +use tokio::time::{sleep, Duration}; +use tracing::{error, info, warn}; + +use crate::cdn::CdnClient; +use crate::crypto; +use crate::error::{Result, WeChatBotError}; +use crate::protocol::{self, ILinkClient}; +use crate::types::*; +use md5::{Digest, Md5}; +use rand::Rng; +use serde_json::json; + +/// Message handler callback type. +pub type MessageHandler = Box; + +/// Bot configuration options. +pub struct BotOptions { + pub base_url: Option, + pub cred_path: Option, + pub on_qr_url: Option>, + pub on_error: Option>, +} + +impl Default for BotOptions { + fn default() -> Self { + Self { + base_url: None, + cred_path: None, + on_qr_url: None, + on_error: None, + } + } +} + +/// WeChatBot is the main entry point. +pub struct WeChatBot { + client: Arc, + cdn: CdnClient, + credentials: RwLock>, + context_tokens: RwLock>, + handlers: Mutex>, + cursor: RwLock, + base_url: RwLock, + cred_path: Option, + stopped: RwLock, + on_qr_url: Option>, + on_error: Option>, +} + +impl WeChatBot { + /// Create a new bot instance. + pub fn new(opts: BotOptions) -> Self { + Self { + client: Arc::new(ILinkClient::new()), + cdn: CdnClient::new(), + credentials: RwLock::new(None), + context_tokens: RwLock::new(HashMap::new()), + handlers: Mutex::new(Vec::new()), + cursor: RwLock::new(String::new()), + base_url: RwLock::new( + opts.base_url + .unwrap_or_else(|| protocol::DEFAULT_BASE_URL.to_string()), + ), + cred_path: opts.cred_path, + stopped: RwLock::new(false), + on_qr_url: opts.on_qr_url, + on_error: opts.on_error, + } + } + + /// Maximum number of QR code refresh attempts before giving up. + const MAX_QR_REFRESH: u32 = 3; + /// Fixed API base URL for QR code requests. + const FIXED_QR_BASE_URL: &'static str = "https://ilinkai.weixin.qq.com"; + + /// Login via QR code. Returns credentials on success. + pub async fn login(&self, force: bool) -> Result { + let base_url = self.base_url.read().await.clone(); + + if !force { + if let Some(creds) = self.load_credentials().await? { + *self.credentials.write().await = Some(creds.clone()); + *self.base_url.write().await = creds.base_url.clone(); + info!("Loaded stored credentials for {}", creds.user_id); + return Ok(creds); + } + } + + // QR code login flow + let mut qr_refresh_count = 0u32; + loop { + qr_refresh_count += 1; + if qr_refresh_count > Self::MAX_QR_REFRESH { + return Err(WeChatBotError::Auth(format!( + "QR code expired {} times — login aborted", + Self::MAX_QR_REFRESH + ))); + } + + let qr = self.client.get_qr_code(Self::FIXED_QR_BASE_URL).await?; + + if let Some(ref cb) = self.on_qr_url { + cb(&qr.qrcode_img_content); + } else { + eprintln!("[wechatbot] Scan: {}", qr.qrcode_img_content); + } + + let mut last_status = String::new(); + let mut current_poll_base_url = Self::FIXED_QR_BASE_URL.to_string(); + loop { + let status = self + .client + .poll_qr_status(¤t_poll_base_url, &qr.qrcode) + .await?; + + if status.status != last_status { + last_status = status.status.clone(); + match status.status.as_str() { + "scaned" => info!("QR scanned — confirm in WeChat"), + "expired" => warn!("QR expired — requesting new one"), + "confirmed" => info!("Login confirmed"), + _ => {} + } + } + + if status.status == "confirmed" { + let token = status + .bot_token + .ok_or_else(|| WeChatBotError::Auth("missing bot_token".into()))?; + let creds = Credentials { + token, + base_url: status.baseurl.unwrap_or_else(|| base_url.clone()), + account_id: status.ilink_bot_id.unwrap_or_default(), + user_id: status.ilink_user_id.unwrap_or_default(), + saved_at: Some(chrono_now()), + }; + self.save_credentials(&creds).await?; + *self.credentials.write().await = Some(creds.clone()); + *self.base_url.write().await = creds.base_url.clone(); + return Ok(creds); + } + + // Handle IDC redirect + if status.status == "scaned_but_redirect" { + if let Some(ref host) = status.redirect_host { + current_poll_base_url = format!("https://{}", host); + info!("IDC redirect, switching polling host to {}", host); + } else { + warn!("Received scaned_but_redirect but redirect_host is missing"); + } + sleep(Duration::from_secs(2)).await; + continue; + } + + if status.status == "expired" { + break; + } + + sleep(Duration::from_secs(2)).await; + } + } + } + + /// Register a message handler. + pub async fn on_message(&self, handler: MessageHandler) { + self.handlers.lock().await.push(handler); + } + + /// Reply to an incoming message. + pub async fn reply(&self, msg: &IncomingMessage, text: &str) -> Result<()> { + self.context_tokens + .write() + .await + .insert(msg.user_id.clone(), msg.context_token.clone()); + self.send_text(&msg.user_id, text, &msg.context_token).await + } + + /// Send text to a user (needs prior context_token). + pub async fn send(&self, user_id: &str, text: &str) -> Result<()> { + let ct = self.context_tokens.read().await.get(user_id).cloned(); + let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?; + self.send_text(user_id, text, &ct).await + } + + /// Show "typing..." indicator. + pub async fn send_typing(&self, user_id: &str) -> Result<()> { + let ct = self.context_tokens.read().await.get(user_id).cloned(); + let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?; + let (base_url, token) = self.get_auth().await?; + let config = self + .client + .get_config(&base_url, &token, user_id, &ct) + .await?; + if let Some(ticket) = config.typing_ticket { + self.client + .send_typing(&base_url, &token, user_id, &ticket, 1) + .await?; + } + Ok(()) + } + + /// Reply with media content (image, video, or file). + pub async fn reply_media(&self, msg: &IncomingMessage, content: SendContent) -> Result<()> { + self.context_tokens + .write() + .await + .insert(msg.user_id.clone(), msg.context_token.clone()); + self.send_content(&msg.user_id, &msg.context_token, content) + .await + } + + /// Send any content type to a user (needs prior context_token). + pub async fn send_media(&self, user_id: &str, content: SendContent) -> Result<()> { + let ct = self.context_tokens.read().await.get(user_id).cloned(); + let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?; + self.send_content(user_id, &ct, content).await + } + + /// Download media from an incoming message. + /// Returns None if the message has no media. Priority: image > file > video > voice. + pub async fn download(&self, msg: &IncomingMessage) -> Result> { + if let Some(img) = msg.images.first() { + if let Some(ref media) = img.media { + let data = self.cdn.download(media, img.aes_key.as_deref()).await?; + return Ok(Some(DownloadedMedia { + data, + media_type: "image".into(), + file_name: None, + format: None, + })); + } + } + if let Some(file) = msg.files.first() { + if let Some(ref media) = file.media { + let data = self.cdn.download(media, None).await?; + return Ok(Some(DownloadedMedia { + data, + media_type: "file".into(), + file_name: Some(file.file_name.clone().unwrap_or_else(|| "file.bin".into())), + format: None, + })); + } + } + if let Some(video) = msg.videos.first() { + if let Some(ref media) = video.media { + let data = self.cdn.download(media, None).await?; + return Ok(Some(DownloadedMedia { + data, + media_type: "video".into(), + file_name: None, + format: None, + })); + } + } + if let Some(voice) = msg.voices.first() { + if let Some(ref media) = voice.media { + let data = self.cdn.download(media, None).await?; + return Ok(Some(DownloadedMedia { + data, + media_type: "voice".into(), + file_name: None, + format: Some("silk".into()), + })); + } + } + Ok(None) + } + + /// Download and decrypt a raw CDN media reference. + pub async fn download_raw( + &self, + media: &CDNMedia, + aeskey_override: Option<&str>, + ) -> Result> { + self.cdn.download(media, aeskey_override).await + } + + /// Upload data to WeChat CDN without sending a message. + pub async fn upload( + &self, + data: &[u8], + user_id: &str, + media_type: i32, + ) -> Result { + let (base_url, token) = self.get_auth().await?; + self.cdn_upload(&base_url, &token, data, user_id, media_type) + .await + } + + /// Start the long-poll loop. Blocks until stopped. + pub async fn run(&self) -> Result<()> { + *self.stopped.write().await = false; + info!("Long-poll loop started"); + let mut retry_delay = Duration::from_secs(1); + + loop { + if *self.stopped.read().await { + break; + } + + let (base_url, token) = self.get_auth().await?; + let cursor = self.cursor.read().await.clone(); + + match self.client.get_updates(&base_url, &token, &cursor).await { + Ok(updates) => { + if !updates.get_updates_buf.is_empty() { + *self.cursor.write().await = updates.get_updates_buf; + } + retry_delay = Duration::from_secs(1); + + for wire in &updates.msgs { + self.remember_context(wire).await; + if let Some(incoming) = IncomingMessage::from_wire(wire) { + let handlers = self.handlers.lock().await; + for handler in handlers.iter() { + handler(&incoming); + } + } + } + } + Err(e) if e.is_session_expired() => { + warn!("Session expired — re-login required"); + *self.context_tokens.write().await = HashMap::new(); + *self.cursor.write().await = String::new(); + if let Err(e) = self.login(true).await { + self.report_error(&e); + } + continue; + } + Err(e) => { + self.report_error(&e); + sleep(retry_delay).await; + retry_delay = std::cmp::min(retry_delay * 2, Duration::from_secs(10)); + continue; + } + } + } + + info!("Long-poll loop stopped"); + Ok(()) + } + + /// Stop the bot. + pub async fn stop(&self) { + *self.stopped.write().await = true; + } + + // --- internal media --- + + fn send_content<'a>( + &'a self, + user_id: &'a str, + context_token: &'a str, + content: SendContent, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + let (base_url, token) = self.get_auth().await?; + match content { + SendContent::Text(text) => self.send_text(user_id, &text, context_token).await, + SendContent::Image { data, caption } => { + let result = self + .cdn_upload(&base_url, &token, &data, user_id, 1) + .await?; + let mut items = Vec::new(); + if let Some(cap) = caption { + items.push(json!({"type": 1, "text_item": {"text": cap}})); + } + items.push(json!({"type": 2, "image_item": { + "media": cdn_media_json(&result.media), + "mid_size": result.encrypted_file_size, + }})); + let msg = protocol::build_media_message(user_id, context_token, items); + self.client.send_message(&base_url, &token, &msg).await + } + SendContent::Video { data, caption } => { + let result = self + .cdn_upload(&base_url, &token, &data, user_id, 2) + .await?; + let mut items = Vec::new(); + if let Some(cap) = caption { + items.push(json!({"type": 1, "text_item": {"text": cap}})); + } + items.push(json!({"type": 5, "video_item": { + "media": cdn_media_json(&result.media), + "video_size": result.encrypted_file_size, + }})); + let msg = protocol::build_media_message(user_id, context_token, items); + self.client.send_message(&base_url, &token, &msg).await + } + SendContent::File { + data, + file_name, + caption, + } => { + let cat = categorize_by_extension(&file_name); + match cat { + "image" => { + self.send_content( + user_id, + context_token, + SendContent::Image { data, caption }, + ) + .await + } + "video" => { + self.send_content( + user_id, + context_token, + SendContent::Video { data, caption }, + ) + .await + } + _ => { + if let Some(cap) = caption { + self.send_text(user_id, &cap, context_token).await?; + } + let data_len = data.len(); + let result = self + .cdn_upload(&base_url, &token, &data, user_id, 3) + .await?; + let items = vec![json!({"type": 4, "file_item": { + "media": cdn_media_json(&result.media), + "file_name": file_name, + "len": data_len.to_string(), + }})]; + let msg = protocol::build_media_message(user_id, context_token, items); + self.client.send_message(&base_url, &token, &msg).await + } + } + } + } + }) + } + + async fn cdn_upload( + &self, + base_url: &str, + token: &str, + data: &[u8], + user_id: &str, + media_type: i32, + ) -> Result { + let aes_key = crypto::generate_aes_key(); + let ciphertext = crypto::encrypt_aes_ecb(data, &aes_key); + + let mut filekey_buf = [0u8; 16]; + rand::rng().fill_bytes(&mut filekey_buf); + let filekey = hex::encode(filekey_buf); + + let raw_md5 = hex::encode(Md5::digest(data)); + + let params = protocol::GetUploadUrlParams { + filekey: filekey.clone(), + media_type, + to_user_id: user_id.to_string(), + rawsize: data.len(), + rawfilemd5: raw_md5, + filesize: ciphertext.len(), + no_need_thumb: true, + aeskey: crypto::encode_aes_key_hex(&aes_key), + }; + + let upload_resp = self.client.get_upload_url(base_url, token, ¶ms).await?; + let upload_param = upload_resp.upload_param.ok_or_else(|| { + WeChatBotError::Media("getuploadurl did not return upload_param".into()) + })?; + + let upload_url = + protocol::build_cdn_upload_url(protocol::CDN_BASE_URL, &upload_param, &filekey); + + let encrypted_file_size = ciphertext.len(); + + let encrypt_query_param = self.client.upload_to_cdn(&upload_url, &ciphertext).await?; + + Ok(UploadResult { + media: CDNMedia { + encrypt_query_param, + aes_key: crypto::encode_aes_key_base64(&aes_key), + encrypt_type: Some(1), + full_url: None, + }, + aes_key, + encrypted_file_size, + }) + } + + // --- internal text --- + + async fn send_text(&self, user_id: &str, text: &str, context_token: &str) -> Result<()> { + let (base_url, token) = self.get_auth().await?; + for chunk in chunk_text(text, 4000) { + let msg = protocol::build_text_message(user_id, context_token, &chunk); + self.client.send_message(&base_url, &token, &msg).await?; + } + Ok(()) + } + + async fn remember_context(&self, wire: &WireMessage) { + let user_id = if wire.message_type == MessageType::User { + &wire.from_user_id + } else { + &wire.to_user_id + }; + if !user_id.is_empty() && !wire.context_token.is_empty() { + self.context_tokens + .write() + .await + .insert(user_id.clone(), wire.context_token.clone()); + } + } + + async fn get_auth(&self) -> Result<(String, String)> { + let creds = self.credentials.read().await; + let creds = creds + .as_ref() + .ok_or_else(|| WeChatBotError::Auth("not logged in".into()))?; + Ok((creds.base_url.clone(), creds.token.clone())) + } + + async fn load_credentials(&self) -> Result> { + let path = self.cred_path.clone().unwrap_or_else(default_cred_path); + match tokio::fs::read_to_string(&path).await { + Ok(data) => Ok(Some(serde_json::from_str(&data)?)), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(e) => Err(e.into()), + } + } + + async fn save_credentials(&self, creds: &Credentials) -> Result<()> { + let path = self.cred_path.clone().unwrap_or_else(default_cred_path); + let dir = std::path::Path::new(&path).parent().unwrap(); + tokio::fs::create_dir_all(dir).await?; + let data = serde_json::to_string_pretty(creds)?; + tokio::fs::write(&path, format!("{}\n", data)).await?; + Ok(()) + } + + fn report_error(&self, err: &WeChatBotError) { + error!("{}", err); + if let Some(ref cb) = self.on_error { + cb(err); + } + } +} + +/// Content to send via reply_media / send_media. +pub enum SendContent { + Text(String), + Image { + data: Vec, + caption: Option, + }, + Video { + data: Vec, + caption: Option, + }, + File { + data: Vec, + file_name: String, + caption: Option, + }, +} + +fn cdn_media_json(media: &CDNMedia) -> serde_json::Value { + let mut v = json!({ + "encrypt_query_param": media.encrypt_query_param, + "aes_key": media.aes_key, + }); + if let Some(et) = media.encrypt_type { + v["encrypt_type"] = json!(et); + } + v +} + +fn categorize_by_extension(filename: &str) -> &'static str { + let ext = Path::new(filename) + .extension() + .and_then(|e| e.to_str()) + .unwrap_or("") + .to_lowercase(); + match ext.as_str() { + "png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" | "svg" => "image", + "mp4" | "mov" | "webm" | "mkv" | "avi" => "video", + _ => "file", + } +} + +fn chunk_text(text: &str, limit: usize) -> Vec { + if text.len() <= limit { + return vec![text.to_string()]; + } + let mut chunks = Vec::new(); + let mut remaining = text; + while !remaining.is_empty() { + if remaining.len() <= limit { + chunks.push(remaining.to_string()); + break; + } + let window = &remaining[..limit]; + let cut = window + .rfind("\n\n") + .filter(|&i| i > limit * 3 / 10) + .map(|i| i + 2) + .or_else(|| { + window + .rfind('\n') + .filter(|&i| i > limit * 3 / 10) + .map(|i| i + 1) + }) + .or_else(|| { + window + .rfind(' ') + .filter(|&i| i > limit * 3 / 10) + .map(|i| i + 1) + }) + .unwrap_or(limit); + chunks.push(remaining[..cut].to_string()); + remaining = &remaining[cut..]; + } + if chunks.is_empty() { + vec![String::new()] + } else { + chunks + } +} + +fn default_cred_path() -> String { + let home = dirs_next::home_dir().unwrap_or_else(|| ".".into()); + home.join(".wechatbot") + .join("credentials.json") + .to_string_lossy() + .to_string() +} + +fn chrono_now() -> String { + // Simple ISO 8601 without chrono dependency + let dur = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap(); + format!("{}Z", dur.as_secs()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn chunk_text_short() { + let chunks = chunk_text("hello", 100); + assert_eq!(chunks, vec!["hello"]); + } + + #[test] + fn chunk_text_empty() { + let chunks = chunk_text("", 100); + assert_eq!(chunks, vec![""]); + } + + #[test] + fn chunk_text_splits_on_paragraph() { + let text = "aaaa\n\nbbbb"; + let chunks = chunk_text(text, 7); + assert_eq!(chunks, vec!["aaaa\n\n", "bbbb"]); + } + + #[test] + fn chunk_text_splits_on_newline() { + let text = "aaaa\nbbbb"; + let chunks = chunk_text(text, 7); + assert_eq!(chunks, vec!["aaaa\n", "bbbb"]); + } + + #[test] + fn chunk_text_exact_limit() { + let text = "abcdef"; + let chunks = chunk_text(text, 6); + assert_eq!(chunks, vec!["abcdef"]); + } + + #[test] + fn default_cred_path_not_empty() { + let path = default_cred_path(); + assert!(!path.is_empty()); + assert!(path.contains(".wechatbot")); + assert!(path.contains("credentials.json")); + } + + #[test] + fn categorize_image_extensions() { + assert_eq!(categorize_by_extension("photo.png"), "image"); + assert_eq!(categorize_by_extension("photo.JPG"), "image"); + assert_eq!(categorize_by_extension("anim.gif"), "image"); + assert_eq!(categorize_by_extension("pic.webp"), "image"); + } + + #[test] + fn categorize_video_extensions() { + assert_eq!(categorize_by_extension("clip.mp4"), "video"); + assert_eq!(categorize_by_extension("clip.MOV"), "video"); + assert_eq!(categorize_by_extension("movie.webm"), "video"); + } + + #[test] + fn categorize_file_extensions() { + assert_eq!(categorize_by_extension("report.pdf"), "file"); + assert_eq!(categorize_by_extension("data.csv"), "file"); + assert_eq!(categorize_by_extension("noext"), "file"); + } + + #[test] + fn cdn_media_json_with_encrypt_type() { + let media = CDNMedia { + encrypt_query_param: "param=1".to_string(), + aes_key: "key123".to_string(), + encrypt_type: Some(1), + full_url: None, + }; + let j = cdn_media_json(&media); + assert_eq!(j["encrypt_query_param"], "param=1"); + assert_eq!(j["aes_key"], "key123"); + assert_eq!(j["encrypt_type"], 1); + } + + #[test] + fn cdn_media_json_without_encrypt_type() { + let media = CDNMedia { + encrypt_query_param: "p".to_string(), + aes_key: "k".to_string(), + encrypt_type: None, + full_url: None, + }; + let j = cdn_media_json(&media); + assert!(j.get("encrypt_type").is_none()); + } +} diff --git a/vendor/wechatbot/src/cdn.rs b/vendor/wechatbot/src/cdn.rs new file mode 100644 index 0000000..c64ac90 --- /dev/null +++ b/vendor/wechatbot/src/cdn.rs @@ -0,0 +1,138 @@ +//! Low-level CDN client for direct media download. +//! +//! [`CdnClient`] is a primitive layer that can be used independently of +//! [`WeChatBot`](crate::WeChatBot), e.g. when you drive `get_updates` yourself +//! via [`ILinkClient`](crate::protocol::ILinkClient) and only need decryption +//! for a specific attachment. +//! +//! Modeled after [`teloxide_core::Bot`]: wraps a [`reqwest::Client`] so +//! connection pool / TLS session / DNS cache are reused across calls, and is +//! cheap to [`Clone`]. + +use reqwest::Client; +use std::time::Duration; + +use crate::crypto; +use crate::error::{Result, WeChatBotError}; +use crate::protocol::CDN_BASE_URL; +use crate::types::CDNMedia; + +/// HTTP client for WeChat CDN media endpoints. +/// +/// Cheap to [`Clone`] — shares the underlying [`reqwest::Client`], which uses +/// an `Arc` internally. +/// +/// # Example +/// +/// ```no_run +/// use wechatbot::{CdnClient, CDNMedia}; +/// +/// # async fn demo(media: CDNMedia) -> Result<(), Box> { +/// let cdn = CdnClient::new(); +/// let bytes = cdn.download(&media, None).await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct CdnClient { + http: Client, + base_url: String, +} + +impl Default for CdnClient { + fn default() -> Self { + Self::new() + } +} + +impl CdnClient { + /// Create a [`CdnClient`] with a fresh internal [`reqwest::Client`]. + pub fn new() -> Self { + Self::with_client(Client::new()) + } + + /// Create a [`CdnClient`] that reuses an existing [`reqwest::Client`]. + /// + /// Useful when the caller already maintains a shared HTTP client with + /// custom proxy / TLS / timeout configuration. + pub fn with_client(http: Client) -> Self { + Self { + http, + base_url: CDN_BASE_URL.to_string(), + } + } + + /// Override the CDN base URL (defaults to [`CDN_BASE_URL`]). + /// + /// Primarily intended for tests and regional endpoints. + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + /// Download and AES-decrypt a CDN media object. + /// + /// `aes_key_override` is used when the decryption key is attached to the + /// message metadata (e.g. [`ImageContent::aes_key`](crate::ImageContent::aes_key)) + /// rather than embedded in the media's own `aes_key` field. + pub async fn download( + &self, + media: &CDNMedia, + aes_key_override: Option<&str>, + ) -> Result> { + let download_url = format!( + "{}/download?encrypted_query_param={}", + self.base_url, + urlencoding::encode(&media.encrypt_query_param) + ); + + let resp = self + .http + .get(&download_url) + .timeout(Duration::from_secs(60)) + .send() + .await?; + + if !resp.status().is_success() { + return Err(WeChatBotError::Media(format!( + "CDN download failed: HTTP {}", + resp.status() + ))); + } + + let ciphertext = resp.bytes().await?.to_vec(); + + let key_source = aes_key_override.unwrap_or(&media.aes_key); + if key_source.is_empty() { + return Err(WeChatBotError::Media("no AES key available".into())); + } + + let aes_key = crypto::decode_aes_key(key_source)?; + crypto::decrypt_aes_ecb(&ciphertext, &aes_key) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_and_new_equivalent() { + let a = CdnClient::default(); + let b = CdnClient::new(); + assert_eq!(a.base_url, b.base_url); + } + + #[test] + fn with_base_url_overrides() { + let c = CdnClient::new().with_base_url("https://example.test/cdn"); + assert_eq!(c.base_url, "https://example.test/cdn"); + } + + #[test] + fn clone_is_cheap_and_preserves_config() { + let c = CdnClient::new().with_base_url("https://x.y/z"); + let cloned = c.clone(); + assert_eq!(c.base_url, cloned.base_url); + } +} diff --git a/vendor/wechatbot/src/crypto.rs b/vendor/wechatbot/src/crypto.rs new file mode 100644 index 0000000..1399b4d --- /dev/null +++ b/vendor/wechatbot/src/crypto.rs @@ -0,0 +1,148 @@ +//! AES-128-ECB encryption for WeChat CDN media files. + +use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit}; +use aes::Aes128; +use base64::Engine; +use rand::Rng; + +use crate::error::{Result, WeChatBotError}; + +/// Encrypt plaintext with AES-128-ECB and PKCS7 padding. +pub fn encrypt_aes_ecb(plaintext: &[u8], key: &[u8; 16]) -> Vec { + let cipher = Aes128::new(key.into()); + let padded = pkcs7_pad(plaintext, 16); + let mut ciphertext = padded; + for chunk in ciphertext.chunks_exact_mut(16) { + cipher.encrypt_block(chunk.into()); + } + ciphertext +} + +/// Decrypt AES-128-ECB ciphertext and remove PKCS7 padding. +pub fn decrypt_aes_ecb(ciphertext: &[u8], key: &[u8; 16]) -> Result> { + if ciphertext.len() % 16 != 0 { + return Err(WeChatBotError::Media( + "ciphertext length not a multiple of 16".into(), + )); + } + let cipher = Aes128::new(key.into()); + let mut plaintext = ciphertext.to_vec(); + for chunk in plaintext.chunks_exact_mut(16) { + cipher.decrypt_block(chunk.into()); + } + pkcs7_unpad(&plaintext) +} + +/// Generate a random 16-byte AES key. +pub fn generate_aes_key() -> [u8; 16] { + let mut key = [0u8; 16]; + rand::rng().fill_bytes(&mut key); + key +} + +/// Calculate encrypted size with PKCS7 padding. +pub fn encrypted_size(raw_size: usize) -> usize { + ((raw_size + 1 + 15) / 16) * 16 +} + +/// Decode an aes_key from the protocol (handles all three formats). +pub fn decode_aes_key(encoded: &str) -> Result<[u8; 16]> { + // Direct hex (32 chars) + if encoded.len() == 32 && encoded.chars().all(|c| c.is_ascii_hexdigit()) { + let bytes = + hex::decode(encoded).map_err(|e| WeChatBotError::Media(format!("hex decode: {e}")))?; + return bytes_to_key(&bytes); + } + + // Base64 decode + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(encoded)) + .map_err(|e| WeChatBotError::Media(format!("base64 decode: {e}")))?; + + if decoded.len() == 16 { + return bytes_to_key(&decoded); + } + + if decoded.len() == 32 { + let hex_str = std::str::from_utf8(&decoded) + .map_err(|_| WeChatBotError::Media("decoded key is not UTF-8".into()))?; + if hex_str.chars().all(|c| c.is_ascii_hexdigit()) { + let bytes = hex::decode(hex_str) + .map_err(|e| WeChatBotError::Media(format!("hex decode: {e}")))?; + return bytes_to_key(&bytes); + } + } + + Err(WeChatBotError::Media(format!( + "unexpected decoded key length: {}", + decoded.len() + ))) +} + +/// Encode an AES key as hex (for getuploadurl). +pub fn encode_aes_key_hex(key: &[u8; 16]) -> String { + hex::encode(key) +} + +/// Encode an AES key as base64(hex) (for CDNMedia.aes_key). +pub fn encode_aes_key_base64(key: &[u8; 16]) -> String { + base64::engine::general_purpose::STANDARD.encode(hex::encode(key)) +} + +fn bytes_to_key(bytes: &[u8]) -> Result<[u8; 16]> { + bytes + .try_into() + .map_err(|_| WeChatBotError::Media(format!("key length {} != 16", bytes.len()))) +} + +fn pkcs7_pad(data: &[u8], block_size: usize) -> Vec { + let padding = block_size - (data.len() % block_size); + let mut result = data.to_vec(); + result.extend(std::iter::repeat(padding as u8).take(padding)); + result +} + +fn pkcs7_unpad(data: &[u8]) -> Result> { + if data.is_empty() { + return Err(WeChatBotError::Media("empty data".into())); + } + let padding = *data.last().unwrap() as usize; + if padding == 0 || padding > data.len() || padding > 16 { + return Err(WeChatBotError::Media("invalid PKCS7 padding".into())); + } + Ok(data[..data.len() - padding].to_vec()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trip() { + let key = generate_aes_key(); + let plaintext = b"Hello, WeChat!"; + let ct = encrypt_aes_ecb(plaintext, &key); + let pt = decrypt_aes_ecb(&ct, &key).unwrap(); + assert_eq!(pt, plaintext); + } + + #[test] + fn encrypted_size_calc() { + assert_eq!(encrypted_size(14), 16); + assert_eq!(encrypted_size(16), 32); + assert_eq!(encrypted_size(100), 112); + } + + #[test] + fn decode_direct_hex() { + let key = decode_aes_key("00112233445566778899aabbccddeeff").unwrap(); + assert_eq!(key.len(), 16); + } + + #[test] + fn decode_base64_raw() { + let key = decode_aes_key("ABEiM0RVZneImaq7zN3u/w==").unwrap(); + assert_eq!(key.len(), 16); + } +} diff --git a/vendor/wechatbot/src/error.rs b/vendor/wechatbot/src/error.rs new file mode 100644 index 0000000..7351223 --- /dev/null +++ b/vendor/wechatbot/src/error.rs @@ -0,0 +1,93 @@ +use thiserror::Error; + +/// Errors that can occur in the SDK. +#[derive(Error, Debug)] +pub enum WeChatBotError { + #[error("API error: {message} (http={http_status}, errcode={errcode})")] + Api { + message: String, + http_status: u16, + errcode: i32, + }, + + #[error("Auth error: {0}")] + Auth(String), + + #[error("No context_token for user {0}")] + NoContext(String), + + #[error("Media error: {0}")] + Media(String), + + #[error("Transport error: {0}")] + Transport(#[from] reqwest::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("{0}")] + Other(String), +} + +impl WeChatBotError { + /// Returns true if this is a session-expired error (errcode -14). + pub fn is_session_expired(&self) -> bool { + matches!(self, WeChatBotError::Api { errcode: -14, .. }) + } +} + +pub type Result = std::result::Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn session_expired_true() { + let err = WeChatBotError::Api { + message: "session expired".to_string(), + http_status: 200, + errcode: -14, + }; + assert!(err.is_session_expired()); + } + + #[test] + fn session_expired_false() { + let err = WeChatBotError::Api { + message: "other error".to_string(), + http_status: 400, + errcode: -1, + }; + assert!(!err.is_session_expired()); + } + + #[test] + fn non_api_not_session_expired() { + let err = WeChatBotError::Auth("test".to_string()); + assert!(!err.is_session_expired()); + } + + #[test] + fn error_display() { + let err = WeChatBotError::Api { + message: "bad request".to_string(), + http_status: 400, + errcode: -1, + }; + let msg = format!("{}", err); + assert!(msg.contains("bad request")); + assert!(msg.contains("400")); + assert!(msg.contains("-1")); + } + + #[test] + fn no_context_error() { + let err = WeChatBotError::NoContext("user123".to_string()); + let msg = format!("{}", err); + assert!(msg.contains("user123")); + } +} diff --git a/vendor/wechatbot/src/lib.rs b/vendor/wechatbot/src/lib.rs new file mode 100644 index 0000000..7191509 --- /dev/null +++ b/vendor/wechatbot/src/lib.rs @@ -0,0 +1,38 @@ +//! # wechatbot +//! +//! WeChat iLink Bot SDK for Rust. +//! +//! ## Quick Start +//! +//! ```rust,no_run +//! use wechatbot::{WeChatBot, BotOptions}; +//! +//! #[tokio::main] +//! async fn main() { +//! let bot = WeChatBot::new(BotOptions::default()); +//! bot.login(false).await.unwrap(); +//! +//! bot.on_message(Box::new(|msg| { +//! println!("{}: {}", msg.user_id, msg.text); +//! })).await; +//! +//! bot.run().await.unwrap(); +//! } +//! ``` + +pub mod bot; +pub mod cdn; +pub mod crypto; +pub mod error; +pub mod protocol; +pub mod types; + +pub use bot::{BotOptions, MessageHandler, SendContent, WeChatBot}; +pub use cdn::CdnClient; +pub use crypto::{ + decode_aes_key, decrypt_aes_ecb, decrypt_aes_ecb as download_decrypt, encode_aes_key_base64, + encode_aes_key_hex, encrypt_aes_ecb, generate_aes_key, +}; +pub use error::{Result, WeChatBotError}; +pub use protocol::{build_cdn_upload_url, GetUploadUrlParams, GetUploadUrlResponse}; +pub use types::*; diff --git a/vendor/wechatbot/src/protocol.rs b/vendor/wechatbot/src/protocol.rs new file mode 100644 index 0000000..3766686 --- /dev/null +++ b/vendor/wechatbot/src/protocol.rs @@ -0,0 +1,407 @@ +//! Raw iLink Bot API HTTP calls. + +use base64::Engine; +use rand::Rng; +use reqwest::Client; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::time::Duration; +use uuid::Uuid; + +use crate::error::{Result, WeChatBotError}; +#[allow(unused_imports)] +use crate::types::*; + +pub const DEFAULT_BASE_URL: &str = "https://ilinkai.weixin.qq.com"; +pub const CDN_BASE_URL: &str = "https://novac2c.cdn.weixin.qq.com/c2c"; +pub const CHANNEL_VERSION: &str = env!("CARGO_PKG_VERSION"); + +/// iLink-App-Id header value. +const ILINK_APP_ID: &str = "bot"; + +/// Build iLink-App-ClientVersion from the crate version (0x00MMNNPP). +fn build_client_version() -> String { + let version = env!("CARGO_PKG_VERSION"); + let parts: Vec = version.split('.').filter_map(|p| p.parse().ok()).collect(); + let major = parts.first().copied().unwrap_or(0) & 0xff; + let minor = parts.get(1).copied().unwrap_or(0) & 0xff; + let patch = parts.get(2).copied().unwrap_or(0) & 0xff; + let num = (major << 16) | (minor << 8) | patch; + num.to_string() +} + +/// Generate the X-WECHAT-UIN header value. +pub fn random_wechat_uin() -> String { + let mut buf = [0u8; 4]; + rand::rng().fill_bytes(&mut buf); + let val = u32::from_be_bytes(buf); + base64::engine::general_purpose::STANDARD.encode(val.to_string()) +} + +/// QR code response. +#[derive(Debug, Deserialize)] +pub struct QrCodeResponse { + pub qrcode: String, + pub qrcode_img_content: String, +} + +/// QR status response. +#[derive(Debug, Deserialize)] +pub struct QrStatusResponse { + pub status: String, + pub bot_token: Option, + pub ilink_bot_id: Option, + pub ilink_user_id: Option, + pub baseurl: Option, + /// New host to redirect polling to when status is "scaned_but_redirect". + pub redirect_host: Option, +} + +/// Get updates response. +#[derive(Debug, Deserialize)] +pub struct GetUpdatesResponse { + #[serde(default)] + pub ret: i32, + #[serde(default)] + pub msgs: Vec, + #[serde(default)] + pub get_updates_buf: String, + pub errcode: Option, + pub errmsg: Option, +} + +/// Get config response. +#[derive(Debug, Deserialize)] +pub struct GetConfigResponse { + pub typing_ticket: Option, +} + +/// Low-level iLink API client. +#[derive(Debug)] +pub struct ILinkClient { + http: Client, +} + +impl ILinkClient { + pub fn new() -> Self { + Self { + http: Client::builder() + .timeout(Duration::from_secs(45)) + .build() + .unwrap(), + } + } + + pub async fn get_qr_code(&self, base_url: &str) -> Result { + let url = format!("{}/ilink/bot/get_bot_qrcode?bot_type=3", base_url); + let resp = self + .http + .get(&url) + .header("iLink-App-Id", ILINK_APP_ID) + .header("iLink-App-ClientVersion", build_client_version()) + .send() + .await?; + Ok(resp.json().await?) + } + + pub async fn poll_qr_status(&self, base_url: &str, qrcode: &str) -> Result { + let url = format!( + "{}/ilink/bot/get_qrcode_status?qrcode={}", + base_url, + urlencoding::encode(qrcode) + ); + let resp = self + .http + .get(&url) + .header("iLink-App-Id", ILINK_APP_ID) + .header("iLink-App-ClientVersion", build_client_version()) + .send() + .await?; + Ok(resp.json().await?) + } + + pub async fn get_updates( + &self, + base_url: &str, + token: &str, + cursor: &str, + ) -> Result { + let body = json!({ + "get_updates_buf": cursor, + "base_info": { "channel_version": CHANNEL_VERSION } + }); + let resp = self + .api_post(base_url, "/ilink/bot/getupdates", token, &body, 45) + .await?; + let result: GetUpdatesResponse = serde_json::from_value(resp)?; + if result.ret != 0 || result.errcode.is_some_and(|c| c != 0) { + let code = result.errcode.unwrap_or(result.ret); + let msg = result + .errmsg + .unwrap_or_else(|| format!("ret={}", result.ret)); + return Err(WeChatBotError::Api { + message: msg, + http_status: 200, + errcode: code, + }); + } + Ok(result) + } + + pub async fn send_message(&self, base_url: &str, token: &str, msg: &Value) -> Result<()> { + let body = json!({ + "msg": msg, + "base_info": { "channel_version": CHANNEL_VERSION } + }); + self.api_post(base_url, "/ilink/bot/sendmessage", token, &body, 15) + .await?; + Ok(()) + } + + pub async fn get_config( + &self, + base_url: &str, + token: &str, + user_id: &str, + context_token: &str, + ) -> Result { + let body = json!({ + "ilink_user_id": user_id, + "context_token": context_token, + "base_info": { "channel_version": CHANNEL_VERSION } + }); + let resp = self + .api_post(base_url, "/ilink/bot/getconfig", token, &body, 15) + .await?; + Ok(serde_json::from_value(resp)?) + } + + pub async fn send_typing( + &self, + base_url: &str, + token: &str, + user_id: &str, + ticket: &str, + status: i32, + ) -> Result<()> { + let body = json!({ + "ilink_user_id": user_id, + "typing_ticket": ticket, + "status": status, + "base_info": { "channel_version": CHANNEL_VERSION } + }); + self.api_post(base_url, "/ilink/bot/sendtyping", token, &body, 15) + .await?; + Ok(()) + } + + async fn api_post( + &self, + base_url: &str, + endpoint: &str, + token: &str, + body: &Value, + timeout_secs: u64, + ) -> Result { + let url = format!("{}{}", base_url, endpoint); + let resp = self + .http + .post(&url) + .timeout(Duration::from_secs(timeout_secs)) + .header("Content-Type", "application/json") + .header("AuthorizationType", "ilink_bot_token") + .header("Authorization", format!("Bearer {}", token)) + .header("X-WECHAT-UIN", random_wechat_uin()) + .header("iLink-App-Id", ILINK_APP_ID) + .header("iLink-App-ClientVersion", build_client_version()) + .json(body) + .send() + .await?; + + let status = resp.status().as_u16(); + let text = resp.text().await?; + let value: Value = serde_json::from_str(&text).unwrap_or(json!({})); + + if status >= 400 { + return Err(WeChatBotError::Api { + message: value["errmsg"] + .as_str() + .or_else(|| value["message"].as_str()) + .unwrap_or(&text) + .to_string(), + http_status: status, + errcode: value["errcode"].as_i64().unwrap_or(0) as i32, + }); + } + + if let Some(errcode) = value["errcode"].as_i64() { + if errcode != 0 { + return Err(WeChatBotError::Api { + message: value["errmsg"] + .as_str() + .or_else(|| value["message"].as_str()) + .unwrap_or(&text) + .to_string(), + http_status: status, + errcode: errcode as i32, + }); + } + } + + Ok(value) + } +} + +/// Build a media message payload. +pub fn build_media_message(user_id: &str, context_token: &str, item_list: Vec) -> Value { + json!({ + "from_user_id": "", + "to_user_id": user_id, + "client_id": Uuid::new_v4().to_string(), + "message_type": 2, + "message_state": 2, + "context_token": context_token, + "item_list": item_list + }) +} + +/// GetUploadUrl request parameters. +pub struct GetUploadUrlParams { + pub filekey: String, + pub media_type: i32, + pub to_user_id: String, + pub rawsize: usize, + pub rawfilemd5: String, + pub filesize: usize, + pub no_need_thumb: bool, + pub aeskey: String, +} + +/// GetUploadUrl response. +#[derive(Debug, Deserialize)] +pub struct GetUploadUrlResponse { + pub upload_param: Option, + pub thumb_upload_param: Option, + pub upload_full_url: Option, +} + +impl ILinkClient { + /// Get a pre-signed CDN upload URL. + pub async fn get_upload_url( + &self, + base_url: &str, + token: &str, + params: &GetUploadUrlParams, + ) -> Result { + let body = json!({ + "filekey": params.filekey, + "media_type": params.media_type, + "to_user_id": params.to_user_id, + "rawsize": params.rawsize, + "rawfilemd5": params.rawfilemd5, + "filesize": params.filesize, + "no_need_thumb": params.no_need_thumb, + "aeskey": params.aeskey, + "base_info": { "channel_version": CHANNEL_VERSION } + }); + let resp = self + .api_post(base_url, "/ilink/bot/getuploadurl", token, &body, 15) + .await?; + Ok(serde_json::from_value(resp)?) + } + + /// Upload encrypted bytes to CDN with retry (up to 3 attempts). + /// Returns the download encrypted_query_param from the x-encrypted-param header. + pub async fn upload_to_cdn(&self, cdn_url: &str, ciphertext: &[u8]) -> Result { + const MAX_RETRIES: u32 = 3; + let mut last_err = None; + + for attempt in 1..=MAX_RETRIES { + match self + .http + .post(cdn_url) + .header("Content-Type", "application/octet-stream") + .body(ciphertext.to_vec()) + .send() + .await + { + Ok(resp) => { + let status = resp.status().as_u16(); + if status >= 400 && status < 500 { + let err_msg = resp + .headers() + .get("x-error-message") + .and_then(|v| v.to_str().ok()) + .unwrap_or("client error") + .to_string(); + return Err(WeChatBotError::Media(format!( + "CDN upload client error {}: {}", + status, err_msg + ))); + } + if status != 200 { + let err_msg = resp + .headers() + .get("x-error-message") + .and_then(|v| v.to_str().ok()) + .unwrap_or("server error") + .to_string(); + last_err = Some(WeChatBotError::Media(format!( + "CDN upload server error {}: {}", + status, err_msg + ))); + continue; + } + match resp + .headers() + .get("x-encrypted-param") + .and_then(|v| v.to_str().ok()) + { + Some(param) => return Ok(param.to_string()), + None => { + last_err = Some(WeChatBotError::Media( + "CDN upload response missing x-encrypted-param header".into(), + )); + continue; + } + } + } + Err(e) => { + last_err = Some(WeChatBotError::Other(format!( + "CDN upload network error: {}", + e + ))); + if attempt < MAX_RETRIES { + continue; + } + } + } + } + Err(last_err.unwrap_or_else(|| { + WeChatBotError::Media(format!("CDN upload failed after {} attempts", MAX_RETRIES)) + })) + } +} + +/// Build a CDN upload URL from params. +pub fn build_cdn_upload_url(cdn_base_url: &str, upload_param: &str, filekey: &str) -> String { + format!( + "{}/upload?encrypted_query_param={}&filekey={}", + cdn_base_url, + urlencoding::encode(upload_param), + urlencoding::encode(filekey) + ) +} + +/// Build a text message payload. +pub fn build_text_message(user_id: &str, context_token: &str, text: &str) -> Value { + json!({ + "from_user_id": "", + "to_user_id": user_id, + "client_id": Uuid::new_v4().to_string(), + "message_type": 2, + "message_state": 2, + "context_token": context_token, + "item_list": [{ "type": 1, "text_item": { "text": text } }] + }) +} diff --git a/vendor/wechatbot/src/types.rs b/vendor/wechatbot/src/types.rs new file mode 100644 index 0000000..2d063da --- /dev/null +++ b/vendor/wechatbot/src/types.rs @@ -0,0 +1,858 @@ +use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use std::time::SystemTime; + +/// Message sender type. +/// Uses serde_repr for integer (de)serialization: JSON `1` ↔ `MessageType::User`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)] +#[repr(i32)] +pub enum MessageType { + User = 1, + Bot = 2, +} + +/// Message delivery state. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)] +#[repr(i32)] +pub enum MessageState { + New = 0, + Generating = 1, + Finish = 2, +} + +/// Content type of a message item. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)] +#[repr(i32)] +pub enum MessageItemType { + Text = 1, + Image = 2, + Voice = 3, + File = 4, + Video = 5, +} + +/// Media type for upload requests. +#[derive(Debug, Clone, Copy)] +#[repr(i32)] +pub enum MediaType { + Image = 1, + Video = 2, + File = 3, + Voice = 4, +} + +/// CDN media reference. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CDNMedia { + pub encrypt_query_param: String, + pub aes_key: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub encrypt_type: Option, + /// Complete download URL returned by server; when set, use directly. + #[serde(skip_serializing_if = "Option::is_none")] + pub full_url: Option, +} + +/// Text content. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TextItem { + pub text: String, +} + +/// Image content. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageItem { + #[serde(skip_serializing_if = "Option::is_none")] + pub media: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub thumb_media: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub aeskey: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub mid_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub thumb_width: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub thumb_height: Option, +} + +/// Voice content. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceItem { + #[serde(skip_serializing_if = "Option::is_none")] + pub media: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub encode_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub playtime: Option, +} + +/// File content. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FileItem { + #[serde(skip_serializing_if = "Option::is_none")] + pub media: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub md5: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub len: Option, +} + +/// Video content. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoItem { + #[serde(skip_serializing_if = "Option::is_none")] + pub media: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub video_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub play_length: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub thumb_media: Option, +} + +/// Referenced/quoted message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RefMessage { + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message_item: Option>, +} + +/// A single content item in a message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WireMessageItem { + #[serde(rename = "type")] + pub item_type: MessageItemType, + #[serde(skip_serializing_if = "Option::is_none")] + pub text_item: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_item: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub voice_item: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_item: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub video_item: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ref_msg: Option, +} + +/// Raw wire message from the iLink API. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WireMessage { + pub from_user_id: String, + pub to_user_id: String, + pub client_id: String, + pub create_time_ms: i64, + pub message_type: MessageType, + pub message_state: MessageState, + pub context_token: String, + pub item_list: Vec, +} + +/// Parsed incoming message — user-friendly. +#[derive(Debug, Clone)] +pub struct IncomingMessage { + pub user_id: String, + pub text: String, + pub content_type: ContentType, + pub timestamp: SystemTime, + pub images: Vec, + pub voices: Vec, + pub files: Vec, + pub videos: Vec, + pub quoted: Option, + pub raw: WireMessage, + pub(crate) context_token: String, +} + +impl IncomingMessage { + /// Opaque reply token bound to this message. + /// + /// Pass it back via [`WeChatBot::reply`](crate::WeChatBot::reply) (which + /// does this automatically) or when constructing a message payload with + /// [`protocol::build_text_message`](crate::protocol::build_text_message) / + /// [`protocol::build_media_message`](crate::protocol::build_media_message) + /// for use with [`ILinkClient::send_message`](crate::protocol::ILinkClient::send_message). + pub fn context_token(&self) -> &str { + &self.context_token + } + + /// Parse a raw [`WireMessage`] into a user-friendly [`IncomingMessage`]. + /// + /// Returns `None` if the wire message is not a user-originated message + /// (e.g. it was sent by the bot itself). + /// + /// This is the stable entry point for consumers who drive + /// [`ILinkClient::get_updates`](crate::protocol::ILinkClient::get_updates) + /// themselves instead of using [`WeChatBot`](crate::WeChatBot)'s + /// dispatcher. + pub fn from_wire(wire: &WireMessage) -> Option { + if wire.message_type != MessageType::User { + return None; + } + + let mut msg = IncomingMessage { + user_id: wire.from_user_id.clone(), + text: extract_text(&wire.item_list), + content_type: detect_type(&wire.item_list), + timestamp: std::time::UNIX_EPOCH + + std::time::Duration::from_millis(wire.create_time_ms as u64), + images: Vec::new(), + voices: Vec::new(), + files: Vec::new(), + videos: Vec::new(), + quoted: None, + raw: wire.clone(), + context_token: wire.context_token.clone(), + }; + + for item in &wire.item_list { + if let Some(ref img) = item.image_item { + msg.images.push(ImageContent { + media: img.media.clone(), + thumb_media: img.thumb_media.clone(), + aes_key: img.aeskey.clone(), + url: img.url.clone(), + width: img.thumb_width, + height: img.thumb_height, + }); + } + if let Some(ref voice) = item.voice_item { + msg.voices.push(VoiceContent { + media: voice.media.clone(), + text: voice.text.clone(), + duration_ms: voice.playtime, + encode_type: voice.encode_type, + }); + } + if let Some(ref file) = item.file_item { + msg.files.push(FileContent { + media: file.media.clone(), + file_name: file.file_name.clone(), + md5: file.md5.clone(), + size: file.len.as_ref().and_then(|s| s.parse().ok()), + }); + } + if let Some(ref video) = item.video_item { + msg.videos.push(VideoContent { + media: video.media.clone(), + thumb_media: video.thumb_media.clone(), + duration_ms: video.play_length, + }); + } + if let Some(ref refm) = item.ref_msg { + msg.quoted = Some(QuotedMessage { + title: refm.title.clone(), + text: refm + .message_item + .as_ref() + .and_then(|i| i.text_item.as_ref()) + .map(|t| t.text.clone()), + }); + } + } + + Some(msg) + } +} + +fn detect_type(items: &[WireMessageItem]) -> ContentType { + items + .first() + .map_or(ContentType::Text, |item| match item.item_type { + MessageItemType::Image => ContentType::Image, + MessageItemType::Voice => ContentType::Voice, + MessageItemType::File => ContentType::File, + MessageItemType::Video => ContentType::Video, + _ => ContentType::Text, + }) +} + +fn extract_text(items: &[WireMessageItem]) -> String { + items + .iter() + .filter_map(|item| match item.item_type { + MessageItemType::Text => item.text_item.as_ref().map(|t| t.text.clone()), + MessageItemType::Image => Some( + item.image_item + .as_ref() + .and_then(|i| i.url.clone()) + .unwrap_or_else(|| "[image]".to_string()), + ), + MessageItemType::Voice => Some( + item.voice_item + .as_ref() + .and_then(|v| v.text.clone()) + .unwrap_or_else(|| "[voice]".to_string()), + ), + MessageItemType::File => Some( + item.file_item + .as_ref() + .and_then(|f| f.file_name.clone()) + .unwrap_or_else(|| "[file]".to_string()), + ), + MessageItemType::Video => Some("[video]".to_string()), + }) + .collect::>() + .join("\n") +} + +/// Content type of an incoming message. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ContentType { + Text, + Image, + Voice, + File, + Video, +} + +#[derive(Debug, Clone)] +pub struct ImageContent { + pub media: Option, + pub thumb_media: Option, + pub aes_key: Option, + pub url: Option, + pub width: Option, + pub height: Option, +} + +#[derive(Debug, Clone)] +pub struct VoiceContent { + pub media: Option, + pub text: Option, + pub duration_ms: Option, + pub encode_type: Option, +} + +#[derive(Debug, Clone)] +pub struct FileContent { + pub media: Option, + pub file_name: Option, + pub md5: Option, + pub size: Option, +} + +#[derive(Debug, Clone)] +pub struct VideoContent { + pub media: Option, + pub thumb_media: Option, + pub duration_ms: Option, +} + +#[derive(Debug, Clone)] +pub struct QuotedMessage { + pub title: Option, + pub text: Option, +} + +/// Result of downloading media from a message. +#[derive(Debug, Clone)] +pub struct DownloadedMedia { + pub data: Vec, + /// "image", "file", "video", "voice" + pub media_type: String, + pub file_name: Option, + pub format: Option, +} + +/// Result of uploading media to CDN. +#[derive(Debug, Clone)] +pub struct UploadResult { + pub media: CDNMedia, + pub aes_key: [u8; 16], + pub encrypted_file_size: usize, +} + +/// Stored login credentials. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Credentials { + pub token: String, + #[serde(rename = "baseUrl")] + pub base_url: String, + #[serde(rename = "accountId")] + pub account_id: String, + #[serde(rename = "userId")] + pub user_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub saved_at: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn message_type_values() { + assert_eq!(MessageType::User as i32, 1); + assert_eq!(MessageType::Bot as i32, 2); + } + + #[test] + fn message_state_values() { + assert_eq!(MessageState::New as i32, 0); + assert_eq!(MessageState::Generating as i32, 1); + assert_eq!(MessageState::Finish as i32, 2); + } + + #[test] + fn message_item_type_values() { + assert_eq!(MessageItemType::Text as i32, 1); + assert_eq!(MessageItemType::Image as i32, 2); + assert_eq!(MessageItemType::Voice as i32, 3); + assert_eq!(MessageItemType::File as i32, 4); + assert_eq!(MessageItemType::Video as i32, 5); + } + + #[test] + fn wire_message_json_round_trip() { + let wire = WireMessage { + from_user_id: "user1".to_string(), + to_user_id: "bot1".to_string(), + client_id: "c1".to_string(), + create_time_ms: 1700000000000, + message_type: MessageType::User, + message_state: MessageState::Finish, + context_token: "ctx".to_string(), + item_list: vec![WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "hello".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }], + }; + let json = serde_json::to_string(&wire).unwrap(); + let decoded: WireMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(decoded.from_user_id, "user1"); + assert_eq!(decoded.message_type, MessageType::User); + assert_eq!(decoded.item_list.len(), 1); + assert_eq!( + decoded.item_list[0].text_item.as_ref().unwrap().text, + "hello" + ); + } + + #[test] + fn credentials_json_camel_case() { + let creds = Credentials { + token: "tok".to_string(), + base_url: "https://api.example.com".to_string(), + account_id: "acc1".to_string(), + user_id: "uid1".to_string(), + saved_at: Some("2024-01-01T00:00:00Z".to_string()), + }; + let json = serde_json::to_string(&creds).unwrap(); + assert!(json.contains("\"baseUrl\""), "expected camelCase baseUrl"); + assert!( + json.contains("\"accountId\""), + "expected camelCase accountId" + ); + assert!(json.contains("\"userId\""), "expected camelCase userId"); + + let decoded: Credentials = serde_json::from_str(&json).unwrap(); + assert_eq!(decoded.token, "tok"); + assert_eq!(decoded.base_url, "https://api.example.com"); + } + + #[test] + fn credentials_omits_none_saved_at() { + let creds = Credentials { + token: "tok".to_string(), + base_url: "https://api.example.com".to_string(), + account_id: "acc1".to_string(), + user_id: "uid1".to_string(), + saved_at: None, + }; + let json = serde_json::to_string(&creds).unwrap(); + assert!(!json.contains("saved_at"), "should omit None saved_at"); + } + + #[test] + fn cdn_media_json() { + let media = CDNMedia { + encrypt_query_param: "param=abc".to_string(), + aes_key: "key123".to_string(), + encrypt_type: Some(1), + full_url: None, + }; + let json = serde_json::to_string(&media).unwrap(); + let decoded: CDNMedia = serde_json::from_str(&json).unwrap(); + assert_eq!(decoded.encrypt_query_param, "param=abc"); + assert_eq!(decoded.aes_key, "key123"); + assert_eq!(decoded.encrypt_type, Some(1)); + } + + #[test] + fn wire_message_with_image() { + let wire = WireMessage { + from_user_id: "user1".to_string(), + to_user_id: "bot1".to_string(), + client_id: "c1".to_string(), + create_time_ms: 1700000000000, + message_type: MessageType::User, + message_state: MessageState::Finish, + context_token: "ctx".to_string(), + item_list: vec![WireMessageItem { + item_type: MessageItemType::Image, + text_item: None, + image_item: Some(ImageItem { + media: None, + thumb_media: None, + aeskey: Some("key".to_string()), + url: Some("http://img.jpg".to_string()), + mid_size: Some(1024), + thumb_width: Some(100), + thumb_height: Some(200), + }), + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }], + }; + let json = serde_json::to_string(&wire).unwrap(); + let decoded: WireMessage = serde_json::from_str(&json).unwrap(); + let img = decoded.item_list[0].image_item.as_ref().unwrap(); + assert_eq!(img.url, Some("http://img.jpg".to_string())); + assert_eq!(img.thumb_width, Some(100)); + } + + #[test] + fn content_type_equality() { + assert_eq!(ContentType::Text, ContentType::Text); + assert_ne!(ContentType::Text, ContentType::Image); + } + + #[test] + fn detect_type_text() { + let items = vec![WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "hi".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }]; + assert_eq!(detect_type(&items), ContentType::Text); + } + + #[test] + fn detect_type_image() { + let items = vec![WireMessageItem { + item_type: MessageItemType::Image, + text_item: None, + image_item: Some(ImageItem { + media: None, + thumb_media: None, + aeskey: None, + url: Some("http://img".to_string()), + mid_size: None, + thumb_width: None, + thumb_height: None, + }), + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }]; + assert_eq!(detect_type(&items), ContentType::Image); + } + + #[test] + fn detect_type_empty() { + assert_eq!(detect_type(&[]), ContentType::Text); + } + + #[test] + fn extract_text_single() { + let items = vec![WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "hello world".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }]; + assert_eq!(extract_text(&items), "hello world"); + } + + #[test] + fn extract_text_multi() { + let items = vec![ + WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "line1".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }, + WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "line2".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }, + ]; + assert_eq!(extract_text(&items), "line1\nline2"); + } + + #[test] + fn extract_text_image_url() { + let items = vec![WireMessageItem { + item_type: MessageItemType::Image, + text_item: None, + image_item: Some(ImageItem { + media: None, + thumb_media: None, + aeskey: None, + url: Some("http://img.jpg".to_string()), + mid_size: None, + thumb_width: None, + thumb_height: None, + }), + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }]; + assert_eq!(extract_text(&items), "http://img.jpg"); + } + + #[test] + fn extract_text_image_placeholder() { + let items = vec![WireMessageItem { + item_type: MessageItemType::Image, + text_item: None, + image_item: Some(ImageItem { + media: None, + thumb_media: None, + aeskey: None, + url: None, + mid_size: None, + thumb_width: None, + thumb_height: None, + }), + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }]; + assert_eq!(extract_text(&items), "[image]"); + } + + #[test] + fn extract_text_voice_with_text() { + let items = vec![WireMessageItem { + item_type: MessageItemType::Voice, + text_item: None, + image_item: None, + voice_item: Some(VoiceItem { + media: None, + encode_type: None, + text: Some("hello".to_string()), + playtime: None, + }), + file_item: None, + video_item: None, + ref_msg: None, + }]; + assert_eq!(extract_text(&items), "hello"); + } + + #[test] + fn extract_text_file_name() { + let items = vec![WireMessageItem { + item_type: MessageItemType::File, + text_item: None, + image_item: None, + voice_item: None, + file_item: Some(FileItem { + media: None, + file_name: Some("doc.pdf".to_string()), + md5: None, + len: None, + }), + video_item: None, + ref_msg: None, + }]; + assert_eq!(extract_text(&items), "doc.pdf"); + } + + #[test] + fn extract_text_video() { + let items = vec![WireMessageItem { + item_type: MessageItemType::Video, + text_item: None, + image_item: None, + voice_item: None, + file_item: None, + video_item: Some(VideoItem { + media: None, + video_size: None, + play_length: None, + thumb_media: None, + }), + ref_msg: None, + }]; + assert_eq!(extract_text(&items), "[video]"); + } + + #[test] + fn from_wire_user_text() { + let wire = WireMessage { + from_user_id: "user123".to_string(), + to_user_id: "bot456".to_string(), + client_id: "c1".to_string(), + create_time_ms: 1700000000000, + message_type: MessageType::User, + message_state: MessageState::Finish, + context_token: "ctx-abc".to_string(), + item_list: vec![WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "hello".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }], + }; + let msg = IncomingMessage::from_wire(&wire).unwrap(); + assert_eq!(msg.user_id, "user123"); + assert_eq!(msg.text, "hello"); + assert_eq!(msg.content_type, ContentType::Text); + assert_eq!(msg.context_token(), "ctx-abc"); + } + + #[test] + fn from_wire_skips_bot() { + let wire = WireMessage { + from_user_id: "bot456".to_string(), + to_user_id: "user123".to_string(), + client_id: "c1".to_string(), + create_time_ms: 1700000000000, + message_type: MessageType::Bot, + message_state: MessageState::Finish, + context_token: "ctx".to_string(), + item_list: vec![WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "reply".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }], + }; + assert!(IncomingMessage::from_wire(&wire).is_none()); + } + + #[test] + fn from_wire_with_image() { + let wire = WireMessage { + from_user_id: "user123".to_string(), + to_user_id: "bot456".to_string(), + client_id: "c1".to_string(), + create_time_ms: 1700000000000, + message_type: MessageType::User, + message_state: MessageState::Finish, + context_token: "ctx".to_string(), + item_list: vec![WireMessageItem { + item_type: MessageItemType::Image, + text_item: None, + image_item: Some(ImageItem { + media: None, + thumb_media: None, + aeskey: Some("key".to_string()), + url: Some("http://img.jpg".to_string()), + mid_size: None, + thumb_width: Some(100), + thumb_height: Some(200), + }), + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + }], + }; + let msg = IncomingMessage::from_wire(&wire).unwrap(); + assert_eq!(msg.images.len(), 1); + assert_eq!(msg.images[0].url, Some("http://img.jpg".to_string())); + assert_eq!(msg.images[0].width, Some(100)); + assert_eq!(msg.images[0].height, Some(200)); + } + + #[test] + fn from_wire_with_quoted() { + let wire = WireMessage { + from_user_id: "user123".to_string(), + to_user_id: "bot456".to_string(), + client_id: "c1".to_string(), + create_time_ms: 1700000000000, + message_type: MessageType::User, + message_state: MessageState::Finish, + context_token: "ctx".to_string(), + item_list: vec![WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "replying".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: Some(RefMessage { + title: Some("Original".to_string()), + message_item: Some(Box::new(WireMessageItem { + item_type: MessageItemType::Text, + text_item: Some(TextItem { + text: "original text".to_string(), + }), + image_item: None, + voice_item: None, + file_item: None, + video_item: None, + ref_msg: None, + })), + }), + }], + }; + let msg = IncomingMessage::from_wire(&wire).unwrap(); + let quoted = msg.quoted.as_ref().unwrap(); + assert_eq!(quoted.title, Some("Original".to_string())); + assert_eq!(quoted.text, Some("original text".to_string())); + } +}