Compare commits

..

No commits in common. "69364e484b7782ff4da1571ce021ab07f6c14deb" and "3f5ed6e4e46bd1d776a46542f07211ce2560fb12" have entirely different histories.

37 changed files with 139 additions and 4412 deletions

1
.gitignore vendored
View File

@ -8,4 +8,3 @@ Cargo.lock
.playwright-cli/
.venv
PicoBot.code-workspace
.picobot

View File

@ -34,5 +34,3 @@ image = { version = "0.25", default-features = false, features = ["jpeg", "png",
tempfile = "3"
meval = "0.2"
rusqlite = { version = "0.32", features = ["bundled"] }
rustls = { version = "0.23", features = ["ring"] }
wechatbot = { path = "vendor/wechatbot" }

View File

@ -378,9 +378,9 @@ PicoBot 支持基于文件系统的技能系统,用来给 Agent 注入某一
当前技能运行时按从低到高优先级合并多个来源,后加载来源可覆盖同名技能:
- 用户级技能:~/.picobot/skills/*/SKILL.md
- 用户 Agent 级技能:~/.agents/skills/*/SKILL.md
- 用户 Agent 级技能:~/.picobot/agent/skills/*/SKILL.md
- 项目级技能:.picobot/skills/*/SKILL.md
- 项目 Agent 级技能:.agents/skills/*/SKILL.md
- 项目 Agent 级技能:.picobot/agent/skills/*/SKILL.md
### 7.2 最小 SKILL.md 格式
@ -410,7 +410,7 @@ description: 用于总结 Rust 项目架构
内置工具:
- skill_list只读列出技能
- skill_manage运行时创建、更新、删除、批量禁用、读取和重载技能
- skill_manage运行时创建、更新、删除、读取和重载技能
skill_manage 支持的 action
@ -419,40 +419,8 @@ skill_manage 支持的 action
- create
- update
- delete
- disable
- reload
skill 的启用/禁用状态不会写入 config.json而是写入独立状态文件
- 用户级状态:~/.picobot/skill-state.json
- 项目级状态:.picobot/skill-state.json
状态文件当前使用最小 JSON 结构:
```json
{
"disabled_skills": ["example-skill"]
}
```
说明:
- disable 默认写入项目级状态文件,可通过 tool 参数中的 scope 指定 user 或 project
- disable 只接受 names 数组;即使只禁用 1 个 skill也需要传单元素数组
- 一次 disable 调用会批量处理 names 里的所有 skill并只做一次 reload
- 用户级与项目级状态同时生效,项目运行时会同时读取两者
- 某个 skill 只要在任一层状态文件中被禁用,就不会出现在 skill_list、skill_activate 和技能索引提示里
批量示例:
```json
{
"action": "disable",
"scope": "project",
"names": ["lark-calendar", "lark-vc", "lark-minutes"]
}
```
skills 配置示例:
```json

View File

@ -1,3 +0,0 @@
pub fn initialize_process_runtime() {
let _ = rustls::crypto::ring::default_provider().install_default();
}

View File

@ -538,8 +538,6 @@ mod tests {
use serde_json::json;
use std::collections::HashMap;
const TEST_CHANNEL: &str = "test-channel";
#[test]
fn test_from_chat_message_expands_tool_calls() {
let message = ChatMessage::assistant_with_tool_calls(
@ -558,13 +556,8 @@ mod tests {
],
);
let outbound = OutboundMessage::from_chat_message(
TEST_CHANNEL,
"chat-1",
None,
&HashMap::new(),
&message,
);
let outbound =
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
assert_eq!(outbound.len(), 2);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
@ -595,13 +588,8 @@ mod tests {
}],
);
let outbound = OutboundMessage::from_chat_message(
TEST_CHANNEL,
"chat-1",
None,
&HashMap::new(),
&message,
);
let outbound =
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
assert_eq!(outbound.len(), 2);
assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse);
@ -614,13 +602,8 @@ mod tests {
fn test_from_chat_message_includes_tool_result() {
let message = ChatMessage::tool("call-9", "calculator", "2");
let outbound = OutboundMessage::from_chat_message(
TEST_CHANNEL,
"chat-1",
None,
&HashMap::new(),
&message,
);
let outbound =
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
@ -635,13 +618,8 @@ mod tests {
ToolMessageState::PendingUserAction,
);
let outbound = OutboundMessage::from_chat_message(
TEST_CHANNEL,
"chat-1",
None,
&HashMap::new(),
&message,
);
let outbound =
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);

View File

@ -149,7 +149,6 @@ struct CachedTenantToken {
#[derive(Clone)]
pub struct FeishuChannel {
name: String,
config: FeishuChannelConfig,
http_client: reqwest::Client,
running: Arc<RwLock<bool>>,
@ -175,12 +174,10 @@ struct ParsedMessage {
impl FeishuChannel {
pub fn new(
name: String,
config: FeishuChannelConfig,
_provider_config: LLMProviderConfig,
) -> Result<Self, ChannelError> {
Ok(Self {
name,
config,
http_client: reqwest::Client::new(),
running: Arc::new(RwLock::new(false)),
@ -1254,7 +1251,7 @@ impl FeishuChannel {
#[cfg(debug_assertions)]
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
let msg = crate::bus::InboundMessage {
channel: channel.name().to_string(),
channel: "feishu".to_string(),
sender_id: parsed.open_id.clone(),
chat_id: parsed.chat_id.clone(),
content: parsed.content.clone(),
@ -2284,7 +2281,7 @@ mod tests {
#[async_trait]
impl Channel for FeishuChannel {
fn name(&self) -> &str {
&self.name
"feishu"
}
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError> {

View File

@ -6,8 +6,7 @@ use crate::bus::MessageBus;
use crate::channels::base::{Channel, ChannelError};
use crate::channels::cli::CliChannel;
use crate::channels::feishu::FeishuChannel;
use crate::channels::wechat::WechatChannel;
use crate::config::{Config, TaggedChannelConfig};
use crate::config::Config;
/// ChannelManager manages all Channel instances and the MessageBus
#[derive(Clone)]
@ -43,57 +42,23 @@ impl ChannelManager {
pub async fn init(
&self,
config: &Config,
provider_config: crate::config::LLMProviderConfig,
_provider_config: crate::config::LLMProviderConfig,
) -> Result<(), ChannelError> {
for (name, channel_config) in &config.channels {
match channel_config {
crate::config::ChannelConfig::Tagged(TaggedChannelConfig::Feishu(feishu_config))
| crate::config::ChannelConfig::LegacyFeishu(feishu_config) => {
if feishu_config.enabled {
let channel = FeishuChannel::new(
name.clone(),
feishu_config.clone(),
provider_config.clone(),
)
.map_err(|e| {
ChannelError::Other(format!(
"Failed to create Feishu channel '{}': {}",
name, e
))
})?;
// Initialize Feishu channel if enabled
if let Some(feishu_config) = config.channels.get("feishu") {
if feishu_config.enabled {
let channel =
FeishuChannel::new(feishu_config.clone(), _provider_config).map_err(|e| {
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
})?;
self.channels
.write()
.await
.insert(name.clone(), Arc::new(channel));
tracing::info!(channel = %name, kind = channel_config.kind(), "Channel registered");
} else {
tracing::info!(channel = %name, kind = channel_config.kind(), "Channel disabled in config");
}
}
crate::config::ChannelConfig::Tagged(TaggedChannelConfig::Wechat(wechat_config)) => {
if wechat_config.enabled {
let channel = WechatChannel::new(
name.clone(),
wechat_config.clone(),
provider_config.clone(),
)
.map_err(|e| {
ChannelError::Other(format!(
"Failed to create WeChat channel '{}': {}",
name, e
))
})?;
self.channels
.write()
.await
.insert(name.clone(), Arc::new(channel));
tracing::info!(channel = %name, kind = channel_config.kind(), "Channel registered");
} else {
tracing::info!(channel = %name, kind = channel_config.kind(), "Channel disabled in config");
}
}
self.channels
.write()
.await
.insert("feishu".to_string(), Arc::new(channel));
tracing::info!("Feishu channel registered");
} else {
tracing::info!("Feishu channel disabled in config");
}
}
Ok(())
@ -136,128 +101,3 @@ impl ChannelManager {
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn write_test_config() -> tempfile::NamedTempFile {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"channels": {
"primary": {
"type": "feishu",
"enabled": true,
"app_id": "app-id-1",
"app_secret": "secret-1"
},
"backup": {
"type": "feishu",
"enabled": true,
"app_id": "app-id-2",
"app_secret": "secret-2"
}
}
}"#,
)
.unwrap();
file
}
#[tokio::test]
async fn init_registers_all_configured_channels_by_instance_name() {
let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap();
let manager = ChannelManager::new();
manager.init(&config, provider_config).await.unwrap();
let mut names = manager
.channels()
.await
.into_iter()
.map(|(name, _)| name)
.collect::<Vec<_>>();
names.sort();
assert_eq!(names, vec!["backup", "cli", "primary"]);
assert_eq!(manager.get_channel("primary").await.unwrap().name(), "primary");
assert_eq!(manager.get_channel("backup").await.unwrap().name(), "backup");
}
#[tokio::test]
async fn init_registers_wechat_channel_by_instance_name() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"channels": {
"wechat_main": {
"type": "wechat",
"enabled": true,
"cred_path": "/tmp/wechat-creds.json"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap();
let manager = ChannelManager::new();
manager.init(&config, provider_config).await.unwrap();
let mut names = manager
.channels()
.await
.into_iter()
.map(|(name, _)| name)
.collect::<Vec<_>>();
names.sort();
assert_eq!(names, vec!["cli", "wechat_main"]);
assert_eq!(manager.get_channel("wechat_main").await.unwrap().name(), "wechat_main");
}
}

View File

@ -2,10 +2,8 @@ pub mod base;
pub mod cli;
pub mod feishu;
pub mod manager;
pub mod wechat;
pub use base::{Channel, ChannelError};
pub use cli::CliChannel;
pub use feishu::FeishuChannel;
pub use manager::ChannelManager;
pub use wechat::WechatChannel;

View File

@ -1,383 +0,0 @@
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use std::time::UNIX_EPOCH;
use async_trait::async_trait;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use wechatbot::{BotOptions, SendContent, WeChatBot};
use crate::bus::{InboundMessage, MediaItem, MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
use crate::config::{LLMProviderConfig, WechatChannelConfig};
#[derive(Clone)]
pub struct WechatChannel {
name: String,
config: WechatChannelConfig,
bot: Arc<WeChatBot>,
running: Arc<AtomicBool>,
task: Arc<RwLock<Option<JoinHandle<()>>>>,
}
impl WechatChannel {
pub fn new(
name: String,
config: WechatChannelConfig,
_provider_config: LLMProviderConfig,
) -> Result<Self, ChannelError> {
let channel_name = name.clone();
let bot = WeChatBot::new(BotOptions {
base_url: Some(config.base_url.clone()),
cred_path: Some(config.cred_path.clone()),
on_qr_url: Some(Box::new(move |url| {
tracing::info!(channel = %channel_name, qr_url = %url, "WeChat QR code ready");
})),
on_error: Some(Box::new(move |error| {
tracing::error!(error = %error, "WeChat SDK error");
})),
});
Ok(Self {
name,
config,
bot: Arc::new(bot),
running: Arc::new(AtomicBool::new(false)),
task: Arc::new(RwLock::new(None)),
})
}
fn sender_allowed(&self, sender_id: &str) -> bool {
self.config.allow_from.iter().any(|pattern| pattern == "*" || pattern == sender_id)
}
fn media_to_send_content(
media: &MediaItem,
caption: Option<String>,
) -> Result<SendContent, ChannelError> {
let data = std::fs::read(&media.path).map_err(|error| {
ChannelError::SendError(format!(
"WeChat media read failed for '{}': {}",
media.path, error
))
})?;
if data.is_empty() {
return Err(ChannelError::SendError(format!(
"WeChat media file is empty: {}",
media.path
)));
}
let file_name = Path::new(&media.path)
.file_name()
.and_then(|name| name.to_str())
.unwrap_or("attachment.bin")
.to_string();
match media.media_type.as_str() {
"image" => Ok(SendContent::Image { data, caption }),
"video" => Ok(SendContent::Video { data, caption }),
_ => Ok(SendContent::File {
data,
file_name,
caption,
}),
}
}
fn default_media_dir() -> PathBuf {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
home.join(".picobot").join("media").join("wechat")
}
fn build_download_filename(
media_type: &str,
file_name: Option<&str>,
format: Option<&str>,
) -> String {
if let Some(file_name) = file_name {
let sanitized: String = file_name
.chars()
.map(|ch| match ch {
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
_ => ch,
})
.collect();
if !sanitized.trim().is_empty() {
return format!("{}_{}", uuid::Uuid::new_v4(), sanitized);
}
}
let ext = match (media_type, format) {
("image", _) => "jpg",
("video", _) => "mp4",
("voice", Some(fmt)) if !fmt.trim().is_empty() => fmt,
("voice", _) => "bin",
_ => "bin",
};
format!("{}_{}.{}", media_type, uuid::Uuid::new_v4(), ext)
}
async fn download_inbound_media(
bot: Arc<WeChatBot>,
msg: wechatbot::IncomingMessage,
) -> Result<Vec<MediaItem>, ChannelError> {
let Some(downloaded) = bot.download(&msg).await.map_err(|error| {
ChannelError::Other(format!("WeChat media download failed: {}", error))
})? else {
return Ok(Vec::new());
};
let media_dir = Self::default_media_dir();
tokio::fs::create_dir_all(&media_dir)
.await
.map_err(|error| ChannelError::Other(format!("Failed to create WeChat media dir: {}", error)))?;
let filename = Self::build_download_filename(
&downloaded.media_type,
downloaded.file_name.as_deref(),
downloaded.format.as_deref(),
);
let file_path = media_dir.join(&filename);
tokio::fs::write(&file_path, downloaded.data)
.await
.map_err(|error| ChannelError::Other(format!("Failed to write WeChat media file: {}", error)))?;
tracing::info!(filename = %filename, media_type = %downloaded.media_type, "Downloaded WeChat media");
let mut media_item = MediaItem::new(
file_path.to_string_lossy().to_string(),
downloaded.media_type,
);
media_item.mime_type = mime_guess::from_path(&file_path)
.first_raw()
.map(ToOwned::to_owned);
Ok(vec![media_item])
}
}
#[async_trait]
impl Channel for WechatChannel {
fn name(&self) -> &str {
&self.name
}
fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError> {
if self.running.swap(true, Ordering::SeqCst) {
return Ok(());
}
let channel_name = self.name.clone();
let allow_from = self.config.allow_from.clone();
let bus_for_handler = bus.clone();
let bot_for_handler = self.bot.clone();
self.bot
.on_message(Box::new(move |msg| {
let sender_id = msg.user_id.clone();
let allowed = allow_from
.iter()
.any(|pattern| pattern == "*" || pattern == &sender_id);
if !allowed {
tracing::warn!(channel = %channel_name, sender = %sender_id, "Access denied");
return;
}
let msg = msg.clone();
let timestamp = msg
.timestamp
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
let bus = bus_for_handler.clone();
let bot = bot_for_handler.clone();
let channel_name_for_publish = channel_name.clone();
tokio::spawn(async move {
let media = match Self::download_inbound_media(bot, msg.clone()).await {
Ok(media) => media,
Err(error) => {
tracing::error!(error = %error, "Failed to download WeChat inbound media");
Vec::new()
}
};
let mut metadata = HashMap::new();
metadata.insert("context_token".to_string(), msg.context_token().to_string());
let inbound = InboundMessage {
channel: channel_name_for_publish,
sender_id: sender_id.clone(),
chat_id: sender_id,
content: msg.text.clone(),
timestamp,
media,
metadata,
forwarded_metadata: HashMap::new(),
};
if let Err(error) = bus.publish_inbound(inbound).await {
tracing::error!(error = %error, "Failed to publish WeChat inbound message");
}
});
}))
.await;
let bot = self.bot.clone();
let channel_name = self.name.clone();
let force_login = self.config.force_login;
let running = self.running.clone();
let handle = tokio::spawn(async move {
match bot.login(force_login).await {
Ok(creds) => {
tracing::info!(
channel = %channel_name,
account_id = %creds.account_id,
user_id = %creds.user_id,
"WeChat login succeeded"
);
}
Err(error) => {
running.store(false, Ordering::SeqCst);
tracing::error!(channel = %channel_name, error = %error, "WeChat login failed");
return;
}
}
if let Err(error) = bot.run().await {
tracing::error!(channel = %channel_name, error = %error, "WeChat channel stopped with error");
}
running.store(false, Ordering::SeqCst);
});
*self.task.write().await = Some(handle);
tracing::info!(channel = %self.name, "WeChat channel started");
Ok(())
}
async fn stop(&self) -> Result<(), ChannelError> {
self.running.store(false, Ordering::SeqCst);
self.bot.stop().await;
if let Some(handle) = self.task.write().await.take() {
handle.abort();
}
tracing::info!(channel = %self.name, "WeChat channel stopped");
Ok(())
}
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let text = msg.content.trim().to_string();
let mut text_sent = false;
if !text.is_empty() {
self.bot.send(&msg.chat_id, &text).await.map_err(|error| {
ChannelError::SendError(format!("WeChat text send failed: {}", error))
})?;
tracing::info!(
channel = %self.name,
chat_id = %msg.chat_id,
content_len = text.len(),
"WeChat text message sent"
);
text_sent = true;
}
for (index, media) in msg.media.iter().enumerate() {
let caption = if !text.is_empty() && !text_sent && index == 0 {
Some(text.clone())
} else {
None
};
let content = Self::media_to_send_content(media, caption)?;
self.bot.send_media(&msg.chat_id, content).await.map_err(|error| {
ChannelError::SendError(format!("WeChat media send failed: {}", error))
})?;
tracing::info!(
channel = %self.name,
chat_id = %msg.chat_id,
media_type = %media.media_type,
media_path = %media.path,
"WeChat media message sent"
);
}
if text.is_empty() && msg.media.is_empty() {
return Ok(());
}
Ok(())
}
fn is_allowed(&self, sender_id: &str) -> bool {
self.sender_allowed(sender_id)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
#[test]
fn build_download_filename_preserves_file_name() {
let filename = WechatChannel::build_download_filename("file", Some("README.md"), None);
assert!(filename.ends_with("_README.md"));
}
#[test]
fn build_download_filename_adds_voice_extension_when_missing_name() {
let filename = WechatChannel::build_download_filename("voice", None, Some("silk"));
assert!(filename.starts_with("voice_"));
assert!(filename.ends_with(".silk"));
}
#[test]
fn media_to_send_content_maps_image() {
let file = NamedTempFile::new().unwrap();
std::fs::write(file.path(), b"demo-image").unwrap();
let image_path = file.path().with_extension("png");
std::fs::rename(file.path(), &image_path).unwrap();
let media = MediaItem::new(image_path.to_string_lossy().to_string(), "image");
let content = WechatChannel::media_to_send_content(&media, None).unwrap();
assert!(matches!(content, SendContent::Image { .. }));
}
#[test]
fn media_to_send_content_maps_generic_file() {
let file = NamedTempFile::new().unwrap();
std::fs::write(file.path(), b"hello").unwrap();
let doc_path = file.path().with_extension("md");
std::fs::rename(file.path(), &doc_path).unwrap();
let media = MediaItem::new(doc_path.to_string_lossy().to_string(), "file");
let content = WechatChannel::media_to_send_content(&media, Some("note".to_string())).unwrap();
match content {
SendContent::File {
file_name,
caption,
..
} => {
assert_eq!(file_name, doc_path.file_name().unwrap().to_string_lossy());
assert_eq!(caption.as_deref(), Some("note"));
}
_ => panic!("expected file send content"),
}
}
}

View File

@ -22,7 +22,7 @@ pub struct Config {
#[serde(default)]
pub client: ClientConfig,
#[serde(default)]
pub channels: HashMap<String, ChannelConfig>,
pub channels: HashMap<String, FeishuChannelConfig>,
#[serde(default)]
pub skills: SkillsConfig,
}
@ -96,54 +96,6 @@ impl Default for SkillsConfig {
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChannelConfig {
Tagged(TaggedChannelConfig),
LegacyFeishu(FeishuChannelConfig),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TaggedChannelConfig {
Feishu(FeishuChannelConfig),
Wechat(WechatChannelConfig),
}
impl ChannelConfig {
pub fn kind(&self) -> &'static str {
match self {
Self::Tagged(TaggedChannelConfig::Feishu(_)) | Self::LegacyFeishu(_) => "feishu",
Self::Tagged(TaggedChannelConfig::Wechat(_)) => "wechat",
}
}
pub fn enabled(&self) -> bool {
match self {
Self::Tagged(TaggedChannelConfig::Feishu(config)) | Self::LegacyFeishu(config) => {
config.enabled
}
Self::Tagged(TaggedChannelConfig::Wechat(config)) => config.enabled,
}
}
pub fn as_feishu(&self) -> Option<&FeishuChannelConfig> {
match self {
Self::Tagged(TaggedChannelConfig::Feishu(config)) | Self::LegacyFeishu(config) => {
Some(config)
}
Self::Tagged(TaggedChannelConfig::Wechat(_)) => None,
}
}
pub fn as_wechat(&self) -> Option<&WechatChannelConfig> {
match self {
Self::Tagged(TaggedChannelConfig::Wechat(config)) => Some(config),
Self::Tagged(TaggedChannelConfig::Feishu(_)) | Self::LegacyFeishu(_) => None,
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FeishuChannelConfig {
#[serde(default)]
@ -165,22 +117,6 @@ pub struct FeishuChannelConfig {
pub reply_context_max_chars: usize,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WechatChannelConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default = "default_allow_from")]
pub allow_from: Vec<String>,
#[serde(default)]
pub agent: String,
#[serde(default = "default_wechat_base_url")]
pub base_url: String,
#[serde(default = "default_wechat_cred_path")]
pub cred_path: String,
#[serde(default)]
pub force_login: bool,
}
fn default_allow_from() -> Vec<String> {
vec!["*".to_string()]
}
@ -192,17 +128,6 @@ fn default_media_dir() -> String {
.to_string()
}
fn default_wechat_base_url() -> String {
"https://ilinkai.weixin.qq.com".to_string()
}
fn default_wechat_cred_path() -> String {
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
home.join(".picobot/wechat/credentials.json")
.to_string_lossy()
.to_string()
}
fn default_reaction_emoji() -> String {
"Typing".to_string()
}
@ -1246,105 +1171,11 @@ mod tests {
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let feishu = config.channels["feishu"].as_feishu().unwrap();
let feishu = &config.channels["feishu"];
assert_eq!(feishu.max_message_chars, 20_000);
assert_eq!(feishu.reply_context_max_chars, 20_000);
}
#[test]
fn test_tagged_feishu_channel_config_loads() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"channels": {
"primary": {
"type": "feishu",
"enabled": true,
"app_id": "app-id",
"app_secret": "secret"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let feishu = config.channels["primary"].as_feishu().unwrap();
assert_eq!(config.channels["primary"].kind(), "feishu");
assert!(config.channels["primary"].enabled());
assert_eq!(feishu.app_id, "app-id");
}
#[test]
fn test_tagged_wechat_channel_config_loads() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"channels": {
"wechat_main": {
"type": "wechat",
"enabled": true,
"base_url": "https://ilinkai.weixin.qq.com",
"cred_path": "/tmp/wechat-creds.json",
"force_login": true,
"allow_from": ["wxid_1"]
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let wechat = config.channels["wechat_main"].as_wechat().unwrap();
assert_eq!(config.channels["wechat_main"].kind(), "wechat");
assert!(config.channels["wechat_main"].enabled());
assert_eq!(wechat.cred_path, "/tmp/wechat-creds.json");
assert!(wechat.force_login);
assert_eq!(wechat.allow_from, vec!["wxid_1"]);
}
#[test]
fn test_feishu_channel_config_loads_custom_truncation_limits() {
let file = tempfile::NamedTempFile::new().unwrap();
@ -1384,7 +1215,7 @@ mod tests {
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let feishu = config.channels["feishu"].as_feishu().unwrap();
let feishu = &config.channels["feishu"];
assert_eq!(feishu.max_message_chars, 3456);
assert_eq!(feishu.reply_context_max_chars, 4567);
}

View File

@ -40,8 +40,6 @@ mod tests {
use std::sync::Arc;
use tokio::sync::mpsc;
const TEST_CHANNEL: &str = "test-channel";
fn test_provider_config() -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
@ -82,7 +80,7 @@ mod tests {
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(ToolRegistry::new());
let mut session = Session::new(
TEST_CHANNEL.to_string(),
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
@ -132,7 +130,7 @@ mod tests {
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(ToolRegistry::new());
let mut session = Session::new(
TEST_CHANNEL.to_string(),
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,

View File

@ -542,7 +542,7 @@ mod tests {
.build(),
);
let mut session = Session::new(
"test-channel".to_string(),
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
@ -587,7 +587,7 @@ mod tests {
.build(),
);
let mut session = Session::new(
"test-channel".to_string(),
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
@ -791,7 +791,7 @@ mod tests {
.unwrap();
let outbound = session_manager
.handle_message("test-channel", "user-1", "chat-1", "hello", Vec::new(), None)
.handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None)
.await
.unwrap();
@ -840,7 +840,7 @@ mod tests {
let planner_outbound = session_manager
.run_scheduled_agent_task(
"test-channel",
"feishu",
"chat-planner",
"请规划今天工作",
ScheduledAgentTaskOptions {
@ -856,7 +856,7 @@ mod tests {
let default_outbound = session_manager
.run_scheduled_agent_task(
"test-channel",
"feishu",
"chat-default",
"请规划今天工作",
ScheduledAgentTaskOptions {
@ -904,7 +904,7 @@ mod tests {
session_manager
.run_scheduled_agent_task(
"test-channel",
"feishu",
"chat-guard",
"每小时执行以下流程:检查邮箱并同步待办",
ScheduledAgentTaskOptions {
@ -916,7 +916,7 @@ mod tests {
.await
.unwrap();
let session = session_manager.get("test-channel").await.unwrap();
let session = session_manager.get("feishu").await.unwrap();
let session_guard = session.lock().await;
let persisted_messages = session_guard
.store()
@ -1477,13 +1477,7 @@ mod tests {
async fn test_bus_tool_call_emitter_hides_completed_tool_results_when_disabled() {
let bus = MessageBus::new(4);
let emitter =
BusToolCallEmitter::new(
bus.clone(),
"test-channel",
"chat-1",
HashMap::new(),
false,
);
BusToolCallEmitter::new(bus.clone(), "feishu", "chat-1", HashMap::new(), false);
emitter
.handle(ChatMessage::tool("call-1", "calculator", "2"))
@ -1514,7 +1508,7 @@ mod tests {
.build(),
);
let mut session = Session::new(
"test-channel".to_string(),
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
@ -1552,7 +1546,7 @@ mod tests {
.build(),
);
let mut session = Session::new(
"test-channel".to_string(),
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
@ -1618,7 +1612,7 @@ mod tests {
.build(),
);
let mut session = Session::new(
"test-channel".to_string(),
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,

View File

@ -42,7 +42,6 @@ impl SessionMessageSender for BusSessionMessageSender {
.is_some();
if let Some(text) = request.text.filter(|value| !value.trim().is_empty()) {
let content_len = text.len();
self.bus
.publish_outbound(OutboundMessage::assistant(
channel_name.to_string(),
@ -53,18 +52,10 @@ impl SessionMessageSender for BusSessionMessageSender {
))
.await?;
published_messages += 1;
tracing::info!(
channel = %channel_name,
chat_id = %chat_id,
content_len = content_len,
"Published session text message to outbound bus"
);
}
let attachment_count = request.attachments.len();
for attachment in request.attachments {
let media_path = attachment.path.clone();
let media_type = attachment.media_type.clone();
let mut outbound = OutboundMessage::assistant(
channel_name.to_string(),
chat_id.to_string(),
@ -75,13 +66,6 @@ impl SessionMessageSender for BusSessionMessageSender {
outbound.media = vec![attachment];
self.bus.publish_outbound(outbound).await?;
published_messages += 1;
tracing::info!(
channel = %channel_name,
chat_id = %chat_id,
media_type = %media_type,
media_path = %media_path,
"Published session attachment to outbound bus"
);
}
Ok(SessionSendOutcome {
@ -97,14 +81,12 @@ mod tests {
use super::*;
use crate::bus::MediaItem;
const TEST_CHANNEL: &str = "test-channel";
#[tokio::test]
async fn bus_sender_publishes_text_then_attachment() {
let bus = MessageBus::new(8);
let sender = BusSessionMessageSender::new(bus.clone());
let context = ToolContext {
channel_name: Some(TEST_CHANNEL.to_string()),
channel_name: Some("feishu".to_string()),
chat_id: Some("chat-1".to_string()),
..ToolContext::default()
};

View File

@ -1,5 +1,4 @@
pub mod agent;
pub mod bootstrap;
pub mod bus;
pub mod channels;
pub mod cli;

View File

@ -23,8 +23,6 @@ enum Command {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
picobot::bootstrap::initialize_process_runtime();
let mut cmd = Command::command();
// If no arguments, print help

View File

@ -812,7 +812,7 @@ mod agent_task_tests {
interval_secs: 0,
startup_delay_secs: 0,
target: serde_json::json!({
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo"
}),
payload: serde_json::json!({
@ -859,7 +859,7 @@ mod agent_task_tests {
interval_secs: 0,
startup_delay_secs: 0,
target: serde_json::json!({
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo",
"session_chat_id": "scheduler/agent.daily_summary.background"
}),
@ -905,7 +905,7 @@ mod agent_task_tests {
startup_delay_secs: 0,
},
target: SchedulerJobTarget {
channel: Some("test-channel".to_string()),
channel: Some("feishu".to_string()),
chat_id: Some("oc_demo".to_string()),
session_chat_id: None,
reply_to: None,
@ -965,7 +965,7 @@ mod agent_task_tests {
startup_delay_secs: 0,
},
target: SchedulerJobTarget {
channel: Some("test-channel".to_string()),
channel: Some("feishu".to_string()),
chat_id: Some("oc_demo".to_string()),
session_chat_id: None,
reply_to: None,
@ -1101,7 +1101,7 @@ mod tests {
interval_secs: 0,
startup_delay_secs: 0,
target: serde_json::json!({
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo"
}),
payload: serde_json::json!({"content": "hello"}),
@ -1151,7 +1151,7 @@ mod tests {
interval_secs: 60,
startup_delay_secs: 0,
target: serde_json::json!({
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo"
}),
payload: serde_json::json!({
@ -1271,7 +1271,7 @@ mod tests {
startup_delay_secs: 0,
},
target: SchedulerJobTarget {
channel: Some("test-channel".to_string()),
channel: Some("feishu".to_string()),
chat_id: Some("oc_demo".to_string()),
session_chat_id: Some("scheduler/agent.daily_summary.background".to_string()),
reply_to: None,
@ -1300,7 +1300,7 @@ mod tests {
)
.await
.unwrap();
assert_eq!(outbound.channel, "test-channel");
assert_eq!(outbound.channel, "feishu");
assert_eq!(outbound.chat_id, "oc_demo");
assert!(outbound.content.contains("定时任务执行失败"));
assert!(outbound.content.contains("agent.daily_summary.background"));

View File

@ -1,18 +1,10 @@
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::RwLock;
#[cfg(test)]
static SKILL_TEST_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
#[cfg(test)]
pub(crate) fn acquire_skill_test_env_lock() -> std::sync::MutexGuard<'static, ()> {
SKILL_TEST_ENV_LOCK.lock().unwrap_or_else(|err| err.into_inner())
}
use crate::config::SkillsConfig;
#[derive(Debug, Clone)]
@ -32,7 +24,7 @@ pub enum SkillSource {
ProjectAgent,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SkillScope {
User,
Project,
@ -70,16 +62,6 @@ pub struct SkillRuntime {
catalog: RwLock<SkillCatalog>,
}
#[derive(Debug, Clone)]
pub struct SkillAvailabilityChange {
pub name: String,
pub scope: SkillScope,
pub state_path: PathBuf,
pub changed: bool,
pub disabled_in_scopes: Vec<SkillScope>,
pub available: bool,
}
impl Default for SkillRuntime {
fn default() -> Self {
Self {
@ -238,78 +220,6 @@ impl SkillRuntime {
}
Ok(dir)
}
pub fn disable_skill(
&self,
scope: SkillScope,
name: &str,
reload: bool,
) -> Result<SkillAvailabilityChange, String> {
self.set_skill_enabled(scope, name, false, reload)
}
pub fn enable_skill(
&self,
scope: SkillScope,
name: &str,
reload: bool,
) -> Result<SkillAvailabilityChange, String> {
self.set_skill_enabled(scope, name, true, reload)
}
pub fn has_skill_definition(&self, name: &str) -> Result<bool, String> {
validate_skill_name(name)?;
let cwd = std::env::current_dir()
.map_err(|err| format!("failed to get current dir: {}", err))?;
Ok(SkillCatalog::discover_without_state(&self.config, &cwd)
.find_skill(name)
.is_some())
}
fn set_skill_enabled(
&self,
scope: SkillScope,
name: &str,
enabled: bool,
reload: bool,
) -> Result<SkillAvailabilityChange, String> {
validate_skill_name(name)?;
if !self.has_skill_definition(name)? {
return Err(format!("skill '{}' not found", name));
}
let state_path = skill_state_path(scope)?;
let mut state = load_skill_state_file(&state_path)?;
let mut disabled: HashSet<String> = state.disabled_skills.into_iter().collect();
let changed = if enabled {
disabled.remove(name)
} else {
disabled.insert(name.to_string())
};
let mut disabled_skills: Vec<String> = disabled.into_iter().collect();
disabled_skills.sort();
state.disabled_skills = disabled_skills;
save_skill_state_file(&state_path, &state)?;
if reload {
let _ = self.reload()?;
}
let cwd = std::env::current_dir()
.map_err(|err| format!("failed to get current dir: {}", err))?;
let effective_state = load_skill_disable_state(&cwd);
let disabled_in_scopes = effective_state.disabled_scopes_for(name);
Ok(SkillAvailabilityChange {
name: name.to_string(),
scope,
state_path,
changed,
available: disabled_in_scopes.is_empty(),
disabled_in_scopes,
})
}
}
impl crate::agent::SkillProvider for SkillRuntime {
@ -352,20 +262,6 @@ impl Default for SkillCatalog {
impl SkillCatalog {
pub fn discover(config: &SkillsConfig) -> Self {
let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
let disable_state = load_skill_disable_state(&cwd);
Self::discover_with_state(config, &cwd, Some(&disable_state))
}
fn discover_without_state(config: &SkillsConfig, cwd: &Path) -> Self {
Self::discover_with_state(config, cwd, None)
}
fn discover_with_state(
config: &SkillsConfig,
cwd: &Path,
disable_state: Option<&SkillDisableState>,
) -> Self {
if !config.enabled {
return Self {
max_index_chars: config.max_index_chars,
@ -374,6 +270,7 @@ impl SkillCatalog {
};
}
let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
let mut merged: HashMap<String, Skill> = HashMap::new();
let mut sources_seen = 0usize;
@ -397,9 +294,6 @@ impl SkillCatalog {
}
let mut skills: Vec<Skill> = merged.into_values().collect();
if let Some(disable_state) = disable_state {
skills.retain(|skill| !disable_state.is_disabled(&skill.name));
}
skills.sort_by(|a, b| a.name.cmp(&b.name));
tracing::info!(
@ -505,35 +399,6 @@ impl SkillCatalog {
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
struct SkillStateFile {
#[serde(default)]
disabled_skills: Vec<String>,
}
#[derive(Debug, Clone, Default)]
struct SkillDisableState {
user_disabled: HashSet<String>,
project_disabled: HashSet<String>,
}
impl SkillDisableState {
fn is_disabled(&self, name: &str) -> bool {
self.user_disabled.contains(name) || self.project_disabled.contains(name)
}
fn disabled_scopes_for(&self, name: &str) -> Vec<SkillScope> {
let mut scopes = Vec::new();
if self.user_disabled.contains(name) {
scopes.push(SkillScope::User);
}
if self.project_disabled.contains(name) {
scopes.push(SkillScope::Project);
}
scopes
}
}
fn source_order(sources: &[String]) -> Vec<SkillSource> {
let mut result = Vec::new();
for source in sources {
@ -596,18 +461,10 @@ fn project_agent_skills_root(cwd: &Path) -> PathBuf {
cwd.join(".agents").join("skills")
}
fn project_skill_state_path(cwd: &Path) -> PathBuf {
cwd.join(".picobot").join("skill-state.json")
}
fn user_skills_root() -> Option<PathBuf> {
dirs::home_dir().map(|p| p.join(".picobot").join("skills"))
}
fn user_skill_state_path() -> Option<PathBuf> {
dirs::home_dir().map(|p| p.join(".picobot").join("skill-state.json"))
}
fn user_agent_skills_root() -> Option<PathBuf> {
dirs::home_dir().map(|p| p.join(".agents").join("skills"))
}
@ -638,78 +495,6 @@ fn skill_file_path(scope: SkillScope, name: &str) -> Result<PathBuf, String> {
Ok(skill_dir_path(scope, name)?.join("SKILL.md"))
}
fn skill_state_path(scope: SkillScope) -> Result<PathBuf, String> {
match scope {
SkillScope::User => user_skill_state_path()
.ok_or_else(|| "failed to resolve home directory".to_string()),
SkillScope::Project => {
let cwd = std::env::current_dir()
.map_err(|err| format!("failed to get current dir: {}", err))?;
Ok(project_skill_state_path(&cwd))
}
}
}
fn load_skill_disable_state(cwd: &Path) -> SkillDisableState {
SkillDisableState {
user_disabled: user_skill_state_path()
.map(|path| load_disabled_skill_names(&path))
.unwrap_or_default(),
project_disabled: load_disabled_skill_names(&project_skill_state_path(cwd)),
}
}
fn load_disabled_skill_names(path: &Path) -> HashSet<String> {
match load_skill_state_file(path) {
Ok(state) => state
.disabled_skills
.into_iter()
.filter_map(|name| normalize_skill_name(name, path))
.collect(),
Err(err) => {
tracing::warn!(path = %path.display(), error = %err, "Failed to load skill state file");
HashSet::new()
}
}
}
fn normalize_skill_name(name: String, path: &Path) -> Option<String> {
let trimmed = name.trim();
match validate_skill_name(trimmed) {
Ok(()) => Some(trimmed.to_string()),
Err(err) => {
tracing::warn!(path = %path.display(), skill = %name, error = %err, "Ignoring invalid disabled skill entry");
None
}
}
}
fn load_skill_state_file(path: &Path) -> Result<SkillStateFile, String> {
if !path.exists() {
return Ok(SkillStateFile::default());
}
let content = fs::read_to_string(path)
.map_err(|err| format!("failed to read skill state file: {}", err))?;
serde_json::from_str(&content)
.map_err(|err| format!("failed to parse skill state file: {}", err))
}
fn save_skill_state_file(path: &Path, state: &SkillStateFile) -> Result<(), String> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| format!("failed to create skill state directory: {}", err))?;
}
let content = serde_json::to_string_pretty(state)
.map_err(|err| format!("failed to render skill state file: {}", err))?;
let tmp_path = path.with_extension("json.tmp");
fs::write(&tmp_path, format!("{}\n", content))
.map_err(|err| format!("failed to write temporary skill state file: {}", err))?;
fs::rename(&tmp_path, path)
.map_err(|err| format!("failed to persist skill state file: {}", err))
}
fn render_skill_file(name: &str, description: &str, body: &str) -> Result<String, String> {
if description.trim().is_empty() {
return Err("description is required and cannot be empty".to_string());
@ -830,16 +615,14 @@ fn split_frontmatter(content: &str) -> Option<(&str, &str)> {
#[cfg(test)]
mod tests {
use super::*;
use std::ffi::OsString;
use std::sync::Mutex;
static CWD_TEST_LOCK: Mutex<()> = Mutex::new(());
struct CurrentDirGuard {
previous: PathBuf,
}
struct HomeDirGuard {
previous: Option<OsString>,
}
impl CurrentDirGuard {
fn enter(path: &Path) -> Self {
let previous = std::env::current_dir().unwrap();
@ -854,33 +637,6 @@ mod tests {
}
}
impl HomeDirGuard {
fn enter(path: &Path) -> Self {
let previous = std::env::var_os("HOME");
unsafe {
std::env::set_var("HOME", path);
}
Self { previous }
}
}
impl Drop for HomeDirGuard {
fn drop(&mut self) {
match &self.previous {
Some(value) => unsafe {
std::env::set_var("HOME", value);
},
None => unsafe {
std::env::remove_var("HOME");
},
}
}
}
fn acquire_test_lock() -> std::sync::MutexGuard<'static, ()> {
acquire_skill_test_env_lock()
}
#[test]
fn test_split_frontmatter() {
let input = "---\ndescription: demo\n---\nhello";
@ -927,14 +683,9 @@ mod tests {
#[test]
fn test_runtime_create_update_delete_reload() {
let _lock = acquire_test_lock();
let _lock = CWD_TEST_LOCK.lock().unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
fs::create_dir_all(&home_dir).unwrap();
fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let _guard = CurrentDirGuard::enter(temp_dir.path());
let runtime = SkillRuntime::from_config(SkillsConfig {
enabled: true,
@ -1005,16 +756,12 @@ mod tests {
#[test]
fn test_discover_loads_project_agent_skills() {
let _lock = acquire_test_lock();
let _lock = CWD_TEST_LOCK.lock().unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
fs::create_dir_all(&home_dir).unwrap();
fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let _guard = CurrentDirGuard::enter(temp_dir.path());
let agent_skill_dir = project_dir
let agent_skill_dir = temp_dir
.path()
.join(".agents")
.join("skills")
.join("demo-agent");
@ -1039,16 +786,11 @@ mod tests {
#[test]
fn test_discover_prefers_project_agent_on_conflict() {
let _lock = acquire_test_lock();
let _lock = CWD_TEST_LOCK.lock().unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
fs::create_dir_all(&home_dir).unwrap();
fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let _guard = CurrentDirGuard::enter(temp_dir.path());
let project_skill_dir = project_dir.join(".picobot").join("skills").join("demo");
let project_skill_dir = temp_dir.path().join(".picobot").join("skills").join("demo");
fs::create_dir_all(&project_skill_dir).unwrap();
fs::write(
project_skill_dir.join("SKILL.md"),
@ -1056,7 +798,7 @@ mod tests {
)
.unwrap();
let agent_skill_dir = project_dir.join(".agents").join("skills").join("demo");
let agent_skill_dir = temp_dir.path().join(".agents").join("skills").join("demo");
fs::create_dir_all(&agent_skill_dir).unwrap();
fs::write(
agent_skill_dir.join("SKILL.md"),
@ -1075,136 +817,4 @@ mod tests {
assert!(payload.contains("Source: project_agent"));
assert!(payload.contains("Agent body"));
}
#[test]
fn test_skill_state_file_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("skill-state.json");
let state = SkillStateFile {
disabled_skills: vec!["demo".to_string(), "other".to_string()],
};
save_skill_state_file(&path, &state).unwrap();
let loaded = load_skill_state_file(&path).unwrap();
assert_eq!(loaded, state);
}
#[test]
fn test_discover_filters_disabled_skills_from_sidecar() {
let _lock = acquire_test_lock();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
fs::create_dir_all(&home_dir).unwrap();
fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let project_skill_dir = project_dir.join(".picobot").join("skills").join("demo");
fs::create_dir_all(&project_skill_dir).unwrap();
fs::write(
project_skill_dir.join("SKILL.md"),
"---\ndescription: project version\n---\nProject body",
)
.unwrap();
save_skill_state_file(
&project_dir.join(".picobot").join("skill-state.json"),
&SkillStateFile {
disabled_skills: vec!["demo".to_string()],
},
)
.unwrap();
let catalog = SkillCatalog::discover(&SkillsConfig {
enabled: true,
sources: vec!["project".to_string()],
max_index_chars: 4000,
max_listed_skills: 32,
});
assert_eq!(catalog.len(), 0);
assert!(catalog.activation_payload("demo").is_err());
}
#[test]
fn test_runtime_disable_and_enable_skill_updates_visibility() {
let _lock = acquire_test_lock();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
fs::create_dir_all(&home_dir).unwrap();
fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let project_skill_dir = project_dir.join(".picobot").join("skills").join("demo");
fs::create_dir_all(&project_skill_dir).unwrap();
fs::write(
project_skill_dir.join("SKILL.md"),
"---\ndescription: project version\n---\nProject body",
)
.unwrap();
let runtime = SkillRuntime::from_config(SkillsConfig {
enabled: true,
sources: vec!["project".to_string()],
max_index_chars: 4000,
max_listed_skills: 32,
});
let disabled = runtime.disable_skill(SkillScope::Project, "demo", true).unwrap();
assert!(disabled.changed);
assert_eq!(disabled.disabled_in_scopes, vec![SkillScope::Project]);
assert!(!disabled.available);
assert!(runtime.get_skill("demo").is_none());
let enabled = runtime.enable_skill(SkillScope::Project, "demo", true).unwrap();
assert!(enabled.changed);
assert!(enabled.disabled_in_scopes.is_empty());
assert!(enabled.available);
assert!(runtime.get_skill("demo").is_some());
}
#[test]
fn test_user_scope_disable_overrides_project_scope_enable() {
let _lock = acquire_test_lock();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
fs::create_dir_all(&home_dir).unwrap();
fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let project_skill_dir = project_dir.join(".picobot").join("skills").join("demo");
fs::create_dir_all(&project_skill_dir).unwrap();
fs::write(
project_skill_dir.join("SKILL.md"),
"---\ndescription: project version\n---\nProject body",
)
.unwrap();
let runtime = SkillRuntime::from_config(SkillsConfig {
enabled: true,
sources: vec!["project".to_string()],
max_index_chars: 4000,
max_listed_skills: 32,
});
let user_disabled = runtime.disable_skill(SkillScope::User, "demo", true).unwrap();
assert_eq!(user_disabled.disabled_in_scopes, vec![SkillScope::User]);
assert!(runtime.get_skill("demo").is_none());
let project_enabled = runtime.enable_skill(SkillScope::Project, "demo", true).unwrap();
assert!(!project_enabled.available);
assert_eq!(project_enabled.disabled_in_scopes, vec![SkillScope::User]);
assert!(runtime.get_skill("demo").is_none());
let user_enabled = runtime.enable_skill(SkillScope::User, "demo", true).unwrap();
assert!(user_enabled.available);
assert!(user_enabled.disabled_in_scopes.is_empty());
assert!(runtime.get_skill("demo").is_some());
}
}

View File

@ -1614,12 +1614,10 @@ mod tests {
use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT;
use crate::domain::messages::ToolCall;
const TEST_CHANNEL: &str = "test-channel";
#[test]
fn test_persistent_session_id_for_cli_and_channel() {
assert_eq!(persistent_session_id("cli", "abc"), "abc");
assert_eq!(persistent_session_id(TEST_CHANNEL, "abc"), "test-channel:abc");
assert_eq!(persistent_session_id("feishu", "abc"), "feishu:abc");
}
#[test]
@ -1684,12 +1682,12 @@ mod tests {
fn test_ensure_channel_session_is_stable() {
let store = SessionStore::in_memory().unwrap();
let first = store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
let second = store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
let first = store.ensure_channel_session("feishu", "chat-1").unwrap();
let second = store.ensure_channel_session("feishu", "chat-1").unwrap();
assert_eq!(first.id, second.id);
assert_eq!(first.chat_id, "chat-1");
assert_eq!(second.channel_name, TEST_CHANNEL);
assert_eq!(second.channel_name, "feishu");
}
#[test]
@ -2042,27 +2040,27 @@ mod tests {
let saved = store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "profile".to_string(),
memory_key: "language".to_string(),
content: "Rust".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-1".to_string()),
source_message_id: Some("msg-1".to_string()),
source_message_seq: Some(7),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-1".to_string()),
})
.unwrap();
assert_eq!(saved.content, "Rust");
assert_eq!(saved.source_type, "message");
assert_eq!(saved.source_session_id.as_deref(), Some("test-channel:chat-1"));
assert_eq!(saved.source_session_id.as_deref(), Some("feishu:chat-1"));
assert_eq!(saved.source_message_id.as_deref(), Some("msg-1"));
assert_eq!(saved.source_message_seq, Some(7));
let fetched = store
.get_memory("user", "test-channel:user-1", "profile", "language")
.get_memory("user", "feishu:user-1", "profile", "language")
.unwrap()
.unwrap();
assert_eq!(fetched.id, saved.id);
@ -2076,21 +2074,21 @@ mod tests {
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "editor".to_string(),
content: "Prefers rust-analyzer and cargo test output".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-2".to_string()),
source_message_id: Some("msg-2".to_string()),
source_message_seq: Some(3),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-2".to_string()),
})
.unwrap();
let hits = store
.search_memories("user", "test-channel:user-1", "rust-analyzer", None, 10)
.search_memories("user", "feishu:user-1", "rust-analyzer", None, 10)
.unwrap();
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].memory_key, "editor");
@ -2098,36 +2096,36 @@ mod tests {
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "editor".to_string(),
content: "Prefers clippy diagnostics".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-3", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-3".to_string()),
source_message_id: Some("msg-3".to_string()),
source_message_seq: Some(4),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-3".to_string()),
})
.unwrap();
let old_hits = store
.search_memories("user", "test-channel:user-1", "rust-analyzer", None, 10)
.search_memories("user", "feishu:user-1", "rust-analyzer", None, 10)
.unwrap();
assert!(old_hits.is_empty());
let new_hits = store
.search_memories("user", "test-channel:user-1", "clippy", None, 10)
.search_memories("user", "feishu:user-1", "clippy", None, 10)
.unwrap();
assert_eq!(new_hits.len(), 1);
let deleted = store
.delete_memory("user", "test-channel:user-1", "preferences", "editor")
.delete_memory("user", "feishu:user-1", "preferences", "editor")
.unwrap();
assert!(deleted);
let hits_after_delete = store
.search_memories("user", "test-channel:user-1", "clippy", None, 10)
.search_memories("user", "feishu:user-1", "clippy", None, 10)
.unwrap();
assert!(hits_after_delete.is_empty());
}
@ -2139,21 +2137,21 @@ mod tests {
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "email_folder_preference".to_string(),
content: "用户提到邮件时默认查看代收邮箱。".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-8", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-8".to_string()),
source_message_id: Some("msg-8".to_string()),
source_message_seq: Some(8),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-8".to_string()),
})
.unwrap();
let hits = store
.search_memories("user", "test-channel:user-1", "email_folder_preference", None, 10)
.search_memories("user", "feishu:user-1", "email_folder_preference", None, 10)
.unwrap();
assert_eq!(hits.len(), 1);
@ -2167,15 +2165,15 @@ mod tests {
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "editor".to_string(),
content: "Prefers rust-analyzer and cargo test output".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-2".to_string()),
source_message_id: Some("msg-2".to_string()),
source_message_seq: Some(3),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-2".to_string()),
})
.unwrap();
@ -2183,15 +2181,15 @@ mod tests {
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "tasks".to_string(),
memory_key: "quality".to_string(),
content: "Tracks clippy warnings before release".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-3", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-3".to_string()),
source_message_id: Some("msg-3".to_string()),
source_message_seq: Some(4),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-3".to_string()),
})
.unwrap();
@ -2199,7 +2197,7 @@ mod tests {
let hits = store
.search_memories_any(
"user",
"test-channel:user-1",
"feishu:user-1",
&["rust-analyzer".to_string(), "clippy".to_string()],
None,
10,
@ -2218,45 +2216,45 @@ mod tests {
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-2", TEST_CHANNEL),
scope_key: "feishu:user-2".to_string(),
namespace: "preferences".to_string(),
memory_key: "style".to_string(),
content: "偏好简洁表达".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-2".to_string()),
source_message_id: Some("msg-2".to_string()),
source_message_seq: Some(2),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-2".to_string()),
})
.unwrap();
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "profile".to_string(),
memory_key: "work".to_string(),
content: "用户在做AI产品".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-1".to_string()),
source_message_id: Some("msg-1".to_string()),
source_message_seq: Some(1),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-1".to_string()),
})
.unwrap();
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "patterns".to_string(),
memory_key: "workflow".to_string(),
content: "习惯先问方案再要代码".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-1".to_string()),
source_message_id: Some("msg-3".to_string()),
source_message_seq: Some(3),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-1".to_string()),
})
.unwrap();
@ -2264,17 +2262,17 @@ mod tests {
let scope_keys = store.list_memory_scope_keys("user").unwrap();
assert_eq!(
scope_keys,
vec!["test-channel:user-1".to_string(), "test-channel:user-2".to_string()]
vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]
);
let full_scope = store
.list_memories_for_scope("user", "test-channel:user-1")
.list_memories_for_scope("user", "feishu:user-1")
.unwrap();
assert_eq!(full_scope.len(), 2);
assert!(
full_scope
.iter()
.all(|memory| memory.scope_key == "test-channel:user-1")
.all(|memory| memory.scope_key == "feishu:user-1")
);
assert!(full_scope.iter().any(|memory| memory.memory_key == "work"));
assert!(
@ -2300,7 +2298,7 @@ mod tests {
interval_secs: 300,
startup_delay_secs: 10,
target: serde_json::json!({
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo",
}),
payload: serde_json::json!({

View File

@ -221,17 +221,15 @@ mod tests {
use super::*;
use crate::storage::SessionStore;
const TEST_CHANNEL: &str = "test-channel";
#[tokio::test]
async fn test_memory_manage_put_returns_saved_memory() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = MemoryManageTool::new(store);
let context = ToolContext {
channel_name: Some(TEST_CHANNEL.to_string()),
channel_name: Some("feishu".to_string()),
sender_id: Some("user-1".to_string()),
chat_id: Some("chat-1".to_string()),
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
session_id: Some("feishu:chat-1".to_string()),
message_id: Some("msg-1".to_string()),
message_seq: Some(1),
};
@ -277,7 +275,7 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = MemoryManageTool::new(store);
let context = ToolContext {
channel_name: Some(TEST_CHANNEL.to_string()),
channel_name: Some("feishu".to_string()),
sender_id: Some("user-1".to_string()),
..ToolContext::default()
};

View File

@ -207,33 +207,31 @@ mod tests {
use super::*;
use crate::storage::SessionStore;
const TEST_CHANNEL: &str = "test-channel";
#[tokio::test]
async fn test_memory_search_search_and_get() {
let store = Arc::new(SessionStore::in_memory().unwrap());
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: format!("{}:user-1", TEST_CHANNEL),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "language".to_string(),
content: "User prefers Chinese responses".to_string(),
source_type: "message".to_string(),
source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
source_session_id: Some("feishu:chat-1".to_string()),
source_message_id: Some("msg-1".to_string()),
source_message_seq: Some(1),
source_channel_name: Some(TEST_CHANNEL.to_string()),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-1".to_string()),
})
.unwrap();
let tool = MemorySearchTool::new(store);
let context = ToolContext {
channel_name: Some(TEST_CHANNEL.to_string()),
channel_name: Some("feishu".to_string()),
sender_id: Some("user-1".to_string()),
chat_id: Some("chat-1".to_string()),
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
session_id: Some("feishu:chat-1".to_string()),
message_id: Some("msg-2".to_string()),
message_seq: Some(2),
};
@ -287,7 +285,7 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = MemorySearchTool::new(store);
let context = ToolContext {
channel_name: Some(TEST_CHANNEL.to_string()),
channel_name: Some("feishu".to_string()),
sender_id: Some("user-1".to_string()),
..ToolContext::default()
};

View File

@ -435,8 +435,6 @@ mod tests {
use super::*;
use crate::storage::SessionStore;
const TEST_CHANNEL: &str = "test-channel";
#[tokio::test]
async fn test_scheduler_manage_put_and_get() {
let store = Arc::new(SessionStore::in_memory().unwrap());
@ -452,7 +450,7 @@ mod tests {
"seconds": 60
},
"target": {
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo"
},
"payload": {
@ -490,7 +488,7 @@ mod tests {
"expression": "0 9 * * *"
},
"target": {
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo"
},
"payload": {
@ -520,7 +518,7 @@ mod tests {
"expression": "0 9 * * *"
},
"target": {
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo",
"session_chat_id": "scheduler/agent.daily_summary.background"
},
@ -578,10 +576,10 @@ mod tests {
let put_result = tool
.execute_with_context(
&crate::tools::ToolContext {
channel_name: Some(TEST_CHANNEL.to_string()),
channel_name: Some("feishu".to_string()),
sender_id: Some("user-1".to_string()),
chat_id: Some("oc_demo".to_string()),
session_id: Some(format!("{}:oc_demo", TEST_CHANNEL)),
session_id: Some("feishu:oc_demo".to_string()),
message_id: Some("msg-1".to_string()),
message_seq: Some(1),
},
@ -604,7 +602,7 @@ mod tests {
assert!(put_result.success);
let saved = store.get_scheduler_job("work_reminder").unwrap().unwrap();
assert_eq!(saved.target["channel"], "test-channel");
assert_eq!(saved.target["channel"], "feishu");
assert_eq!(saved.target["chat_id"], "oc_demo");
}
@ -623,7 +621,7 @@ mod tests {
"expression": "0 9 * * *"
},
"target": {
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo"
},
"payload": {
@ -655,7 +653,7 @@ mod tests {
"expression": "0 9 * * *"
},
"target": {
"channel": "test-channel",
"channel": "feishu",
"chat_id": "oc_demo"
},
"payload": {

View File

@ -240,8 +240,6 @@ mod tests {
use super::*;
use tempfile::NamedTempFile;
const TEST_CHANNEL: &str = "test-channel";
struct MockSender {
outcome: SessionSendOutcome,
}
@ -259,7 +257,7 @@ mod tests {
fn context() -> ToolContext {
ToolContext {
channel_name: Some(TEST_CHANNEL.to_string()),
channel_name: Some("feishu".to_string()),
chat_id: Some("chat-1".to_string()),
..ToolContext::default()
}

View File

@ -124,16 +124,14 @@ mod tests {
use super::*;
use crate::storage::SessionStore;
const TEST_CHANNEL: &str = "test-channel";
#[tokio::test]
async fn test_skill_activate_records_failed_activation_event() {
let skills = Arc::new(SkillRuntime::default());
let store = Arc::new(SessionStore::in_memory().unwrap());
store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
store.ensure_channel_session("feishu", "chat-1").unwrap();
let tool = SkillActivateTool::new(skills, store.clone());
let context = ToolContext {
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
session_id: Some("feishu:chat-1".to_string()),
..ToolContext::default()
};
@ -145,9 +143,7 @@ mod tests {
assert!(!result.success);
assert!(result.error.unwrap().contains("not found"));
let events = store
.list_skill_events(Some(&format!("{}:chat-1", TEST_CHANNEL)))
.unwrap();
let events = store.list_skill_events(Some("feishu:chat-1")).unwrap();
assert_eq!(events.len(), 1);
assert_eq!(events[0].event_type, "activation_failed");
assert_eq!(events[0].skill_name.as_deref(), Some("demo"));

View File

@ -32,7 +32,7 @@ impl Tool for SkillManageTool {
}
fn description(&self) -> &str {
"Manage PicoBot skills stored under .picobot/skills or ~/.picobot/skills, while discovery also reads .agents/skills and ~/.agents/skills. Supports actions: list, get, create, update, delete, disable, reload."
"Manage PicoBot skills stored under .picobot/skills or ~/.picobot/skills, while discovery also reads .agents/skills and ~/.agents/skills. Supports actions: list, get, create, update, delete, reload."
}
fn parameters_schema(&self) -> serde_json::Value {
@ -41,25 +41,18 @@ impl Tool for SkillManageTool {
"properties": {
"action": {
"type": "string",
"enum": ["list", "get", "create", "update", "delete", "disable", "reload"],
"enum": ["list", "get", "create", "update", "delete", "reload"],
"description": "Management action to perform"
},
"scope": {
"type": "string",
"enum": ["project", "user"],
"description": "Writable skill scope for create/update/delete/disable. Defaults to project. .agents discovery sources are read-only here, but can still be disabled via sidecar state."
"description": "Writable skill scope for create/update/delete. Defaults to project. .agents discovery sources are read-only here."
},
"name": {
"type": "string",
"description": "Skill name"
},
"names": {
"type": "array",
"items": {
"type": "string"
},
"description": "Skill names for batch disable; pass a single-item array to disable one skill"
},
"description": {
"type": "string",
"description": "Skill description used for discovery"
@ -100,10 +93,6 @@ impl Tool for SkillManageTool {
};
let name = args.get("name").and_then(|v| v.as_str());
let names = match parse_disable_names(&args) {
Ok(names) => names,
Err(err) => return Ok(error_result(&err)),
};
let result = match action {
"list" => list_skills_payload(&self.skills),
@ -203,30 +192,6 @@ impl Tool for SkillManageTool {
}),
Err(err) => return Ok(error_result(&err)),
},
"disable" => {
let targets = &names;
let mut changes = Vec::new();
for target in targets {
match self.skills.disable_skill(scope, target, false) {
Ok(change) => changes.push(change),
Err(err) => return Ok(error_result(&err)),
}
}
if reload {
if let Err(err) = self.skills.reload() {
return Ok(error_result(&err));
}
}
json!({
"status": "disabled",
"scope": scope.as_str(),
"count": changes.len(),
"reloaded": reload,
"changes": changes.into_iter().map(skill_change_payload).collect::<Vec<_>>(),
})
}
_ => return Ok(error_result("Unsupported action")),
};
@ -277,42 +242,6 @@ fn error_result(message: &str) -> ToolResult {
}
}
fn parse_disable_names(args: &serde_json::Value) -> Result<Vec<String>, String> {
let names = args
.get("names")
.ok_or_else(|| "disable requires names".to_string())?
.as_array()
.ok_or_else(|| "names must be an array of strings".to_string())?;
let mut parsed = Vec::new();
for item in names {
let name = item
.as_str()
.ok_or_else(|| "names must be an array of strings".to_string())?
.trim()
.to_string();
if name.is_empty() {
return Err("names must not contain empty values".to_string());
}
parsed.push(name);
}
if parsed.is_empty() {
return Err("names must not be empty".to_string());
}
Ok(parsed)
}
fn skill_change_payload(change: crate::skills::SkillAvailabilityChange) -> serde_json::Value {
json!({
"name": change.name,
"scope": change.scope.as_str(),
"path": change.state_path.display().to_string(),
"changed": change.changed,
"available": change.available,
"disabled_in_scopes": change.disabled_in_scopes.into_iter().map(|scope| scope.as_str()).collect::<Vec<_>>(),
})
}
fn list_skills_payload(skills: &Arc<SkillRuntime>) -> serde_json::Value {
let skills = skills.list_skills();
json!({
@ -330,170 +259,3 @@ fn list_skills_payload(skills: &Arc<SkillRuntime>) -> serde_json::Value {
})).collect::<Vec<_>>()
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::SkillsConfig;
use crate::skills::acquire_skill_test_env_lock;
use std::ffi::OsString;
use std::path::{Path, PathBuf};
struct CurrentDirGuard {
previous: PathBuf,
}
struct HomeDirGuard {
previous: Option<OsString>,
}
impl CurrentDirGuard {
fn enter(path: &Path) -> Self {
let previous = std::env::current_dir().unwrap();
std::env::set_current_dir(path).unwrap();
Self { previous }
}
}
impl Drop for CurrentDirGuard {
fn drop(&mut self) {
let _ = std::env::set_current_dir(&self.previous);
}
}
impl HomeDirGuard {
fn enter(path: &Path) -> Self {
let previous = std::env::var_os("HOME");
unsafe {
std::env::set_var("HOME", path);
}
Self { previous }
}
}
impl Drop for HomeDirGuard {
fn drop(&mut self) {
match &self.previous {
Some(value) => unsafe {
std::env::set_var("HOME", value);
},
None => unsafe {
std::env::remove_var("HOME");
},
}
}
}
fn acquire_test_lock() -> std::sync::MutexGuard<'static, ()> {
acquire_skill_test_env_lock()
}
#[tokio::test]
async fn test_skill_manage_disable_updates_runtime() {
let _lock = acquire_test_lock();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
std::fs::create_dir_all(&home_dir).unwrap();
std::fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let runtime = Arc::new(SkillRuntime::from_config(SkillsConfig {
enabled: true,
sources: vec!["project".to_string()],
max_index_chars: 4000,
max_listed_skills: 32,
}));
runtime
.create_skill(SkillScope::Project, "demo", "demo skill", "body", true)
.unwrap();
let tool = SkillManageTool::new(runtime.clone());
let disabled = tool
.execute(json!({
"action": "disable",
"names": ["demo"],
"scope": "project"
}))
.await
.unwrap();
assert!(disabled.success);
assert!(disabled.output.contains("disabled"));
assert!(runtime.get_skill("demo").is_none());
}
#[tokio::test]
async fn test_skill_manage_batch_disable_uses_names_array() {
let _lock = acquire_test_lock();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
std::fs::create_dir_all(&home_dir).unwrap();
std::fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let runtime = Arc::new(SkillRuntime::from_config(SkillsConfig {
enabled: true,
sources: vec!["project".to_string()],
max_index_chars: 4000,
max_listed_skills: 32,
}));
runtime
.create_skill(SkillScope::Project, "demo-a", "demo skill a", "body", true)
.unwrap();
runtime
.create_skill(SkillScope::Project, "demo-b", "demo skill b", "body", true)
.unwrap();
let tool = SkillManageTool::new(runtime.clone());
let disabled = tool
.execute(json!({
"action": "disable",
"names": ["demo-a", "demo-b"],
"scope": "project"
}))
.await
.unwrap();
assert!(disabled.success);
assert!(disabled.output.contains("\"count\": 2"));
assert!(runtime.get_skill("demo-a").is_none());
assert!(runtime.get_skill("demo-b").is_none());
}
#[tokio::test]
async fn test_skill_manage_disable_requires_names_array() {
let _lock = acquire_test_lock();
let temp_dir = tempfile::tempdir().unwrap();
let home_dir = temp_dir.path().join("home");
let project_dir = temp_dir.path().join("project");
std::fs::create_dir_all(&home_dir).unwrap();
std::fs::create_dir_all(&project_dir).unwrap();
let _home = HomeDirGuard::enter(&home_dir);
let _guard = CurrentDirGuard::enter(&project_dir);
let runtime = Arc::new(SkillRuntime::from_config(SkillsConfig {
enabled: true,
sources: vec!["project".to_string()],
max_index_chars: 4000,
max_listed_skills: 32,
}));
let tool = SkillManageTool::new(runtime);
let result = tool
.execute(json!({
"action": "disable",
"name": "demo",
"scope": "project"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("disable requires names"));
}
}

View File

@ -1 +0,0 @@
{"v":1}

View File

@ -1,6 +0,0 @@
{
"git": {
"sha1": "70bc64cc8035de4677bbe01265570e7f157bb31d"
},
"path_in_vcs": "rust"
}

View File

@ -1,91 +0,0 @@
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.
[package]
edition = "2021"
name = "wechatbot"
version = "0.3.2"
build = false
autolib = false
autobins = false
autoexamples = false
autotests = false
autobenches = false
description = "WeChat iLink Bot SDK for Rust"
homepage = "https://github.com/corespeed-io/wechatbot"
documentation = "https://docs.rs/wechatbot"
readme = "README.md"
license = "MIT"
repository = "https://github.com/corespeed-io/wechatbot"
[lib]
name = "wechatbot"
path = "src/lib.rs"
[[example]]
name = "echo_bot"
path = "examples/echo_bot.rs"
[dependencies.aes]
version = "0.8"
[dependencies.base64]
version = "0.22"
[dependencies.dirs-next]
version = "2"
[dependencies.hex]
version = "0.4"
[dependencies.md-5]
version = "0.10"
[dependencies.rand]
version = "0.10"
[dependencies.reqwest]
version = "0.12"
default-features = false
features = ["json", "rustls-tls"]
[dependencies.serde]
version = "1"
features = ["derive"]
[dependencies.serde_json]
version = "1"
[dependencies.serde_repr]
version = "0.1"
[dependencies.thiserror]
version = "2"
[dependencies.tokio]
version = "1"
features = ["full"]
[dependencies.tracing]
version = "0.1"
[dependencies.urlencoding]
version = "2"
[dependencies.uuid]
version = "1"
features = ["v4"]
[dev-dependencies.tokio-test]
version = "0.4"
[dev-dependencies.tracing-subscriber]
version = "0.3"

35
vendor/wechatbot/Cargo.toml.orig generated vendored
View File

@ -1,35 +0,0 @@
[package]
name = "wechatbot"
version = "0.3.2"
edition = "2021"
description = "WeChat iLink Bot SDK for Rust"
license = "MIT"
readme = "README.md"
repository = "https://github.com/corespeed-io/wechatbot"
homepage = "https://github.com/corespeed-io/wechatbot"
documentation = "https://docs.rs/wechatbot"
[dependencies]
aes = "0.8"
base64 = "0.22"
hex = "0.4"
md-5 = "0.10"
rand = "0.10"
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_repr = "0.1"
tokio = { version = "1", features = ["full"] }
uuid = { version = "1", features = ["v4"] }
thiserror = "2"
tracing = "0.1"
urlencoding = "2"
dirs-next = "2"
[dev-dependencies]
tokio-test = "0.4"
tracing-subscriber = "0.3"
[[example]]
name = "echo_bot"
path = "examples/echo_bot.rs"

View File

@ -1,226 +0,0 @@
# wechatbot — Rust SDK
WeChat iLink Bot SDK for Rust — async, type-safe, zero-copy where possible.
## Install
```toml
[dependencies]
wechatbot = "0.1"
tokio = { version = "1", features = ["full"] }
```
Requires Rust 2021 edition. Built on `tokio` + `reqwest`.
## Quick Start
```rust
use wechatbot::{WeChatBot, BotOptions};
#[tokio::main]
async fn main() {
let bot = WeChatBot::new(BotOptions::default());
let creds = bot.login(false).await.unwrap();
println!("Logged in: {}", creds.account_id);
bot.on_message(Box::new(|msg| {
println!("{}: {}", msg.user_id, msg.text);
})).await;
bot.run().await.unwrap();
}
```
## Architecture
```
src/
├── lib.rs ← Public re-exports
├── types.rs ← All protocol & public types (serde)
├── error.rs ← Error hierarchy (thiserror)
├── protocol.rs ← Raw iLink API calls (reqwest)
├── crypto.rs ← AES-128-ECB encrypt/decrypt + key encoding
└── bot.rs ← WeChatBot client (login, run, reply, send)
```
## API Reference
### Creating a Bot
```rust
use wechatbot::{WeChatBot, BotOptions};
let bot = WeChatBot::new(BotOptions {
base_url: None, // default: ilinkai.weixin.qq.com
cred_path: None, // default: ~/.wechatbot/credentials.json
on_qr_url: Some(Box::new(|url| {
println!("Scan: {}", url);
})),
on_error: Some(Box::new(|err| {
eprintln!("Error: {}", err);
})),
});
```
### Authentication
```rust
// Login (skips QR if credentials exist)
let creds = bot.login(false).await?;
// Force re-login
let creds = bot.login(true).await?;
// Credentials struct
println!("Token: {}", creds.token);
println!("Base URL: {}", creds.base_url);
println!("Account: {}", creds.account_id);
println!("User: {}", creds.user_id);
```
### Message Handling
```rust
bot.on_message(Box::new(|msg| {
match msg.content_type {
ContentType::Text => println!("Text: {}", msg.text),
ContentType::Image => {
for img in &msg.images {
println!("Image URL: {:?}", img.url);
}
}
ContentType::Voice => {
for voice in &msg.voices {
println!("Voice: {:?} ({}ms)", voice.text, voice.duration_ms.unwrap_or(0));
}
}
ContentType::File => {
for file in &msg.files {
println!("File: {:?}", file.file_name);
}
}
ContentType::Video => println!("Video received"),
}
if let Some(ref quoted) = msg.quoted {
println!("Quoted: {:?}", quoted.title);
}
})).await;
```
### Sending Messages
```rust
// Reply to incoming message
bot.reply(&msg, "Echo: hello").await?;
// Send to user (needs prior context_token)
bot.send(user_id, "Hello").await?;
// Typing indicator
bot.send_typing(user_id).await?;
```
### Media Operations
```rust
// Reply with media content
bot.reply_media(&msg, SendContent::Image(png_bytes)).await?;
bot.reply_media(&msg, SendContent::File { data, file_name: "report.pdf".into() }).await?;
bot.reply_media(&msg, SendContent::Video(mp4_bytes)).await?;
```
```rust
// Download media from incoming message (priority: image > file > video > voice)
if let Some(media) = bot.download(&msg).await? {
println!("Type: {}, Size: {} bytes", media.media_type, media.data.len());
if let Some(name) = &media.file_name {
println!("Filename: {}", name);
}
}
// Download a raw CDN reference directly
let raw = bot.download_raw(&msg.images[0].media.as_ref().unwrap(), None).await?;
```
```rust
// Upload to CDN without sending a message
let result = bot.upload(&file_bytes, user_id, 3).await?;
```
### Lifecycle
```rust
// Start polling (blocks)
bot.run().await?;
// Stop
bot.stop().await;
```
## Error Handling
```rust
use wechatbot::WeChatBotError;
match result {
Err(WeChatBotError::Api { message, errcode, .. }) => {
if errcode == -14 {
// session expired — handled automatically
}
}
Err(WeChatBotError::NoContext(user_id)) => {
// no context_token for this user yet
}
Err(WeChatBotError::Transport(e)) => {
// network error
}
_ => {}
}
```
## AES-128-ECB Crypto
```rust
use wechatbot::{generate_aes_key, encrypt_aes_ecb, decrypt_aes_ecb, decode_aes_key};
// Generate key
let key = generate_aes_key();
// Encrypt/decrypt
let ciphertext = encrypt_aes_ecb(b"Hello", &key);
let plaintext = decrypt_aes_ecb(&ciphertext, &key)?;
// Decode protocol key (handles all 3 formats)
let key = decode_aes_key("ABEiM0RVZneImaq7zN3u/w==")?;
let key = decode_aes_key("00112233445566778899aabbccddeeff")?;
```
## Types
All protocol types derive `Serialize` + `Deserialize` + `Clone` + `Debug`:
```rust
// Wire-level (protocol)
WireMessage, WireMessageItem, CDNMedia, TextItem, ImageItem, ...
// Parsed (user-friendly)
IncomingMessage, ImageContent, VoiceContent, FileContent, VideoContent
// Auth
Credentials
// Enums
MessageType, MessageState, MessageItemType, ContentType, MediaType
```
## Testing
```bash
cd rust
cargo test
```
## License
MIT

View File

@ -1,43 +0,0 @@
use wechatbot::{BotOptions, WeChatBot};
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let bot = WeChatBot::new(BotOptions {
on_qr_url: Some(Box::new(|url| {
println!("\nScan this URL in WeChat:\n{}\n", url);
})),
on_error: Some(Box::new(|err| {
eprintln!("Error: {}", err);
})),
..Default::default()
});
let creds = bot.login(false).await.expect("login failed");
println!("Logged in: {} ({})", creds.account_id, creds.user_id);
bot.on_message(Box::new(|msg| {
println!("[{}] {}: {}", msg.content_type_str(), msg.user_id, msg.text);
}))
.await;
println!("Listening for messages (Ctrl+C to stop)");
bot.run().await.expect("run failed");
}
trait ContentTypeStr {
fn content_type_str(&self) -> &str;
}
impl ContentTypeStr for wechatbot::IncomingMessage {
fn content_type_str(&self) -> &str {
match self.content_type {
wechatbot::ContentType::Text => "text",
wechatbot::ContentType::Image => "image",
wechatbot::ContentType::Voice => "voice",
wechatbot::ContentType::File => "file",
wechatbot::ContentType::Video => "video",
}
}
}

View File

@ -1,741 +0,0 @@
//! Main WeChatBot client.
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{sleep, Duration};
use tracing::{error, info, warn};
use crate::cdn::CdnClient;
use crate::crypto;
use crate::error::{Result, WeChatBotError};
use crate::protocol::{self, ILinkClient};
use crate::types::*;
use md5::{Digest, Md5};
use rand::Rng;
use serde_json::json;
/// Message handler callback type.
pub type MessageHandler = Box<dyn Fn(&IncomingMessage) + Send + Sync>;
/// Bot configuration options.
pub struct BotOptions {
pub base_url: Option<String>,
pub cred_path: Option<String>,
pub on_qr_url: Option<Box<dyn Fn(&str) + Send + Sync>>,
pub on_error: Option<Box<dyn Fn(&WeChatBotError) + Send + Sync>>,
}
impl Default for BotOptions {
fn default() -> Self {
Self {
base_url: None,
cred_path: None,
on_qr_url: None,
on_error: None,
}
}
}
/// WeChatBot is the main entry point.
pub struct WeChatBot {
client: Arc<ILinkClient>,
cdn: CdnClient,
credentials: RwLock<Option<Credentials>>,
context_tokens: RwLock<HashMap<String, String>>,
handlers: Mutex<Vec<MessageHandler>>,
cursor: RwLock<String>,
base_url: RwLock<String>,
cred_path: Option<String>,
stopped: RwLock<bool>,
on_qr_url: Option<Box<dyn Fn(&str) + Send + Sync>>,
on_error: Option<Box<dyn Fn(&WeChatBotError) + Send + Sync>>,
}
impl WeChatBot {
/// Create a new bot instance.
pub fn new(opts: BotOptions) -> Self {
Self {
client: Arc::new(ILinkClient::new()),
cdn: CdnClient::new(),
credentials: RwLock::new(None),
context_tokens: RwLock::new(HashMap::new()),
handlers: Mutex::new(Vec::new()),
cursor: RwLock::new(String::new()),
base_url: RwLock::new(
opts.base_url
.unwrap_or_else(|| protocol::DEFAULT_BASE_URL.to_string()),
),
cred_path: opts.cred_path,
stopped: RwLock::new(false),
on_qr_url: opts.on_qr_url,
on_error: opts.on_error,
}
}
/// Maximum number of QR code refresh attempts before giving up.
const MAX_QR_REFRESH: u32 = 3;
/// Fixed API base URL for QR code requests.
const FIXED_QR_BASE_URL: &'static str = "https://ilinkai.weixin.qq.com";
/// Login via QR code. Returns credentials on success.
pub async fn login(&self, force: bool) -> Result<Credentials> {
let base_url = self.base_url.read().await.clone();
if !force {
if let Some(creds) = self.load_credentials().await? {
*self.credentials.write().await = Some(creds.clone());
*self.base_url.write().await = creds.base_url.clone();
info!("Loaded stored credentials for {}", creds.user_id);
return Ok(creds);
}
}
// QR code login flow
let mut qr_refresh_count = 0u32;
loop {
qr_refresh_count += 1;
if qr_refresh_count > Self::MAX_QR_REFRESH {
return Err(WeChatBotError::Auth(format!(
"QR code expired {} times — login aborted",
Self::MAX_QR_REFRESH
)));
}
let qr = self.client.get_qr_code(Self::FIXED_QR_BASE_URL).await?;
if let Some(ref cb) = self.on_qr_url {
cb(&qr.qrcode_img_content);
} else {
eprintln!("[wechatbot] Scan: {}", qr.qrcode_img_content);
}
let mut last_status = String::new();
let mut current_poll_base_url = Self::FIXED_QR_BASE_URL.to_string();
loop {
let status = self
.client
.poll_qr_status(&current_poll_base_url, &qr.qrcode)
.await?;
if status.status != last_status {
last_status = status.status.clone();
match status.status.as_str() {
"scaned" => info!("QR scanned — confirm in WeChat"),
"expired" => warn!("QR expired — requesting new one"),
"confirmed" => info!("Login confirmed"),
_ => {}
}
}
if status.status == "confirmed" {
let token = status
.bot_token
.ok_or_else(|| WeChatBotError::Auth("missing bot_token".into()))?;
let creds = Credentials {
token,
base_url: status.baseurl.unwrap_or_else(|| base_url.clone()),
account_id: status.ilink_bot_id.unwrap_or_default(),
user_id: status.ilink_user_id.unwrap_or_default(),
saved_at: Some(chrono_now()),
};
self.save_credentials(&creds).await?;
*self.credentials.write().await = Some(creds.clone());
*self.base_url.write().await = creds.base_url.clone();
return Ok(creds);
}
// Handle IDC redirect
if status.status == "scaned_but_redirect" {
if let Some(ref host) = status.redirect_host {
current_poll_base_url = format!("https://{}", host);
info!("IDC redirect, switching polling host to {}", host);
} else {
warn!("Received scaned_but_redirect but redirect_host is missing");
}
sleep(Duration::from_secs(2)).await;
continue;
}
if status.status == "expired" {
break;
}
sleep(Duration::from_secs(2)).await;
}
}
}
/// Register a message handler.
pub async fn on_message(&self, handler: MessageHandler) {
self.handlers.lock().await.push(handler);
}
/// Reply to an incoming message.
pub async fn reply(&self, msg: &IncomingMessage, text: &str) -> Result<()> {
self.context_tokens
.write()
.await
.insert(msg.user_id.clone(), msg.context_token.clone());
self.send_text(&msg.user_id, text, &msg.context_token).await
}
/// Send text to a user (needs prior context_token).
pub async fn send(&self, user_id: &str, text: &str) -> Result<()> {
let ct = self.context_tokens.read().await.get(user_id).cloned();
let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?;
self.send_text(user_id, text, &ct).await
}
/// Show "typing..." indicator.
pub async fn send_typing(&self, user_id: &str) -> Result<()> {
let ct = self.context_tokens.read().await.get(user_id).cloned();
let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?;
let (base_url, token) = self.get_auth().await?;
let config = self
.client
.get_config(&base_url, &token, user_id, &ct)
.await?;
if let Some(ticket) = config.typing_ticket {
self.client
.send_typing(&base_url, &token, user_id, &ticket, 1)
.await?;
}
Ok(())
}
/// Reply with media content (image, video, or file).
pub async fn reply_media(&self, msg: &IncomingMessage, content: SendContent) -> Result<()> {
self.context_tokens
.write()
.await
.insert(msg.user_id.clone(), msg.context_token.clone());
self.send_content(&msg.user_id, &msg.context_token, content)
.await
}
/// Send any content type to a user (needs prior context_token).
pub async fn send_media(&self, user_id: &str, content: SendContent) -> Result<()> {
let ct = self.context_tokens.read().await.get(user_id).cloned();
let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?;
self.send_content(user_id, &ct, content).await
}
/// Download media from an incoming message.
/// Returns None if the message has no media. Priority: image > file > video > voice.
pub async fn download(&self, msg: &IncomingMessage) -> Result<Option<DownloadedMedia>> {
if let Some(img) = msg.images.first() {
if let Some(ref media) = img.media {
let data = self.cdn.download(media, img.aes_key.as_deref()).await?;
return Ok(Some(DownloadedMedia {
data,
media_type: "image".into(),
file_name: None,
format: None,
}));
}
}
if let Some(file) = msg.files.first() {
if let Some(ref media) = file.media {
let data = self.cdn.download(media, None).await?;
return Ok(Some(DownloadedMedia {
data,
media_type: "file".into(),
file_name: Some(file.file_name.clone().unwrap_or_else(|| "file.bin".into())),
format: None,
}));
}
}
if let Some(video) = msg.videos.first() {
if let Some(ref media) = video.media {
let data = self.cdn.download(media, None).await?;
return Ok(Some(DownloadedMedia {
data,
media_type: "video".into(),
file_name: None,
format: None,
}));
}
}
if let Some(voice) = msg.voices.first() {
if let Some(ref media) = voice.media {
let data = self.cdn.download(media, None).await?;
return Ok(Some(DownloadedMedia {
data,
media_type: "voice".into(),
file_name: None,
format: Some("silk".into()),
}));
}
}
Ok(None)
}
/// Download and decrypt a raw CDN media reference.
pub async fn download_raw(
&self,
media: &CDNMedia,
aeskey_override: Option<&str>,
) -> Result<Vec<u8>> {
self.cdn.download(media, aeskey_override).await
}
/// Upload data to WeChat CDN without sending a message.
pub async fn upload(
&self,
data: &[u8],
user_id: &str,
media_type: i32,
) -> Result<UploadResult> {
let (base_url, token) = self.get_auth().await?;
self.cdn_upload(&base_url, &token, data, user_id, media_type)
.await
}
/// Start the long-poll loop. Blocks until stopped.
pub async fn run(&self) -> Result<()> {
*self.stopped.write().await = false;
info!("Long-poll loop started");
let mut retry_delay = Duration::from_secs(1);
loop {
if *self.stopped.read().await {
break;
}
let (base_url, token) = self.get_auth().await?;
let cursor = self.cursor.read().await.clone();
match self.client.get_updates(&base_url, &token, &cursor).await {
Ok(updates) => {
if !updates.get_updates_buf.is_empty() {
*self.cursor.write().await = updates.get_updates_buf;
}
retry_delay = Duration::from_secs(1);
for wire in &updates.msgs {
self.remember_context(wire).await;
if let Some(incoming) = IncomingMessage::from_wire(wire) {
let handlers = self.handlers.lock().await;
for handler in handlers.iter() {
handler(&incoming);
}
}
}
}
Err(e) if e.is_session_expired() => {
warn!("Session expired — re-login required");
*self.context_tokens.write().await = HashMap::new();
*self.cursor.write().await = String::new();
if let Err(e) = self.login(true).await {
self.report_error(&e);
}
continue;
}
Err(e) => {
self.report_error(&e);
sleep(retry_delay).await;
retry_delay = std::cmp::min(retry_delay * 2, Duration::from_secs(10));
continue;
}
}
}
info!("Long-poll loop stopped");
Ok(())
}
/// Stop the bot.
pub async fn stop(&self) {
*self.stopped.write().await = true;
}
// --- internal media ---
fn send_content<'a>(
&'a self,
user_id: &'a str,
context_token: &'a str,
content: SendContent,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move {
let (base_url, token) = self.get_auth().await?;
match content {
SendContent::Text(text) => self.send_text(user_id, &text, context_token).await,
SendContent::Image { data, caption } => {
let result = self
.cdn_upload(&base_url, &token, &data, user_id, 1)
.await?;
let mut items = Vec::new();
if let Some(cap) = caption {
items.push(json!({"type": 1, "text_item": {"text": cap}}));
}
items.push(json!({"type": 2, "image_item": {
"media": cdn_media_json(&result.media),
"mid_size": result.encrypted_file_size,
}}));
let msg = protocol::build_media_message(user_id, context_token, items);
self.client.send_message(&base_url, &token, &msg).await
}
SendContent::Video { data, caption } => {
let result = self
.cdn_upload(&base_url, &token, &data, user_id, 2)
.await?;
let mut items = Vec::new();
if let Some(cap) = caption {
items.push(json!({"type": 1, "text_item": {"text": cap}}));
}
items.push(json!({"type": 5, "video_item": {
"media": cdn_media_json(&result.media),
"video_size": result.encrypted_file_size,
}}));
let msg = protocol::build_media_message(user_id, context_token, items);
self.client.send_message(&base_url, &token, &msg).await
}
SendContent::File {
data,
file_name,
caption,
} => {
let cat = categorize_by_extension(&file_name);
match cat {
"image" => {
self.send_content(
user_id,
context_token,
SendContent::Image { data, caption },
)
.await
}
"video" => {
self.send_content(
user_id,
context_token,
SendContent::Video { data, caption },
)
.await
}
_ => {
if let Some(cap) = caption {
self.send_text(user_id, &cap, context_token).await?;
}
let data_len = data.len();
let result = self
.cdn_upload(&base_url, &token, &data, user_id, 3)
.await?;
let items = vec![json!({"type": 4, "file_item": {
"media": cdn_media_json(&result.media),
"file_name": file_name,
"len": data_len.to_string(),
}})];
let msg = protocol::build_media_message(user_id, context_token, items);
self.client.send_message(&base_url, &token, &msg).await
}
}
}
}
})
}
async fn cdn_upload(
&self,
base_url: &str,
token: &str,
data: &[u8],
user_id: &str,
media_type: i32,
) -> Result<UploadResult> {
let aes_key = crypto::generate_aes_key();
let ciphertext = crypto::encrypt_aes_ecb(data, &aes_key);
let mut filekey_buf = [0u8; 16];
rand::rng().fill_bytes(&mut filekey_buf);
let filekey = hex::encode(filekey_buf);
let raw_md5 = hex::encode(Md5::digest(data));
let params = protocol::GetUploadUrlParams {
filekey: filekey.clone(),
media_type,
to_user_id: user_id.to_string(),
rawsize: data.len(),
rawfilemd5: raw_md5,
filesize: ciphertext.len(),
no_need_thumb: true,
aeskey: crypto::encode_aes_key_hex(&aes_key),
};
let upload_resp = self.client.get_upload_url(base_url, token, &params).await?;
let upload_param = upload_resp.upload_param.ok_or_else(|| {
WeChatBotError::Media("getuploadurl did not return upload_param".into())
})?;
let upload_url =
protocol::build_cdn_upload_url(protocol::CDN_BASE_URL, &upload_param, &filekey);
let encrypted_file_size = ciphertext.len();
let encrypt_query_param = self.client.upload_to_cdn(&upload_url, &ciphertext).await?;
Ok(UploadResult {
media: CDNMedia {
encrypt_query_param,
aes_key: crypto::encode_aes_key_base64(&aes_key),
encrypt_type: Some(1),
full_url: None,
},
aes_key,
encrypted_file_size,
})
}
// --- internal text ---
async fn send_text(&self, user_id: &str, text: &str, context_token: &str) -> Result<()> {
let (base_url, token) = self.get_auth().await?;
for chunk in chunk_text(text, 4000) {
let msg = protocol::build_text_message(user_id, context_token, &chunk);
self.client.send_message(&base_url, &token, &msg).await?;
}
Ok(())
}
async fn remember_context(&self, wire: &WireMessage) {
let user_id = if wire.message_type == MessageType::User {
&wire.from_user_id
} else {
&wire.to_user_id
};
if !user_id.is_empty() && !wire.context_token.is_empty() {
self.context_tokens
.write()
.await
.insert(user_id.clone(), wire.context_token.clone());
}
}
async fn get_auth(&self) -> Result<(String, String)> {
let creds = self.credentials.read().await;
let creds = creds
.as_ref()
.ok_or_else(|| WeChatBotError::Auth("not logged in".into()))?;
Ok((creds.base_url.clone(), creds.token.clone()))
}
async fn load_credentials(&self) -> Result<Option<Credentials>> {
let path = self.cred_path.clone().unwrap_or_else(default_cred_path);
match tokio::fs::read_to_string(&path).await {
Ok(data) => Ok(Some(serde_json::from_str(&data)?)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(e) => Err(e.into()),
}
}
async fn save_credentials(&self, creds: &Credentials) -> Result<()> {
let path = self.cred_path.clone().unwrap_or_else(default_cred_path);
let dir = std::path::Path::new(&path).parent().unwrap();
tokio::fs::create_dir_all(dir).await?;
let data = serde_json::to_string_pretty(creds)?;
tokio::fs::write(&path, format!("{}\n", data)).await?;
Ok(())
}
fn report_error(&self, err: &WeChatBotError) {
error!("{}", err);
if let Some(ref cb) = self.on_error {
cb(err);
}
}
}
/// Content to send via reply_media / send_media.
pub enum SendContent {
Text(String),
Image {
data: Vec<u8>,
caption: Option<String>,
},
Video {
data: Vec<u8>,
caption: Option<String>,
},
File {
data: Vec<u8>,
file_name: String,
caption: Option<String>,
},
}
fn cdn_media_json(media: &CDNMedia) -> serde_json::Value {
let mut v = json!({
"encrypt_query_param": media.encrypt_query_param,
"aes_key": media.aes_key,
});
if let Some(et) = media.encrypt_type {
v["encrypt_type"] = json!(et);
}
v
}
fn categorize_by_extension(filename: &str) -> &'static str {
let ext = Path::new(filename)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("")
.to_lowercase();
match ext.as_str() {
"png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" | "svg" => "image",
"mp4" | "mov" | "webm" | "mkv" | "avi" => "video",
_ => "file",
}
}
fn chunk_text(text: &str, limit: usize) -> Vec<String> {
if text.len() <= limit {
return vec![text.to_string()];
}
let mut chunks = Vec::new();
let mut remaining = text;
while !remaining.is_empty() {
if remaining.len() <= limit {
chunks.push(remaining.to_string());
break;
}
let window = &remaining[..limit];
let cut = window
.rfind("\n\n")
.filter(|&i| i > limit * 3 / 10)
.map(|i| i + 2)
.or_else(|| {
window
.rfind('\n')
.filter(|&i| i > limit * 3 / 10)
.map(|i| i + 1)
})
.or_else(|| {
window
.rfind(' ')
.filter(|&i| i > limit * 3 / 10)
.map(|i| i + 1)
})
.unwrap_or(limit);
chunks.push(remaining[..cut].to_string());
remaining = &remaining[cut..];
}
if chunks.is_empty() {
vec![String::new()]
} else {
chunks
}
}
fn default_cred_path() -> String {
let home = dirs_next::home_dir().unwrap_or_else(|| ".".into());
home.join(".wechatbot")
.join("credentials.json")
.to_string_lossy()
.to_string()
}
fn chrono_now() -> String {
// Simple ISO 8601 without chrono dependency
let dur = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap();
format!("{}Z", dur.as_secs())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn chunk_text_short() {
let chunks = chunk_text("hello", 100);
assert_eq!(chunks, vec!["hello"]);
}
#[test]
fn chunk_text_empty() {
let chunks = chunk_text("", 100);
assert_eq!(chunks, vec![""]);
}
#[test]
fn chunk_text_splits_on_paragraph() {
let text = "aaaa\n\nbbbb";
let chunks = chunk_text(text, 7);
assert_eq!(chunks, vec!["aaaa\n\n", "bbbb"]);
}
#[test]
fn chunk_text_splits_on_newline() {
let text = "aaaa\nbbbb";
let chunks = chunk_text(text, 7);
assert_eq!(chunks, vec!["aaaa\n", "bbbb"]);
}
#[test]
fn chunk_text_exact_limit() {
let text = "abcdef";
let chunks = chunk_text(text, 6);
assert_eq!(chunks, vec!["abcdef"]);
}
#[test]
fn default_cred_path_not_empty() {
let path = default_cred_path();
assert!(!path.is_empty());
assert!(path.contains(".wechatbot"));
assert!(path.contains("credentials.json"));
}
#[test]
fn categorize_image_extensions() {
assert_eq!(categorize_by_extension("photo.png"), "image");
assert_eq!(categorize_by_extension("photo.JPG"), "image");
assert_eq!(categorize_by_extension("anim.gif"), "image");
assert_eq!(categorize_by_extension("pic.webp"), "image");
}
#[test]
fn categorize_video_extensions() {
assert_eq!(categorize_by_extension("clip.mp4"), "video");
assert_eq!(categorize_by_extension("clip.MOV"), "video");
assert_eq!(categorize_by_extension("movie.webm"), "video");
}
#[test]
fn categorize_file_extensions() {
assert_eq!(categorize_by_extension("report.pdf"), "file");
assert_eq!(categorize_by_extension("data.csv"), "file");
assert_eq!(categorize_by_extension("noext"), "file");
}
#[test]
fn cdn_media_json_with_encrypt_type() {
let media = CDNMedia {
encrypt_query_param: "param=1".to_string(),
aes_key: "key123".to_string(),
encrypt_type: Some(1),
full_url: None,
};
let j = cdn_media_json(&media);
assert_eq!(j["encrypt_query_param"], "param=1");
assert_eq!(j["aes_key"], "key123");
assert_eq!(j["encrypt_type"], 1);
}
#[test]
fn cdn_media_json_without_encrypt_type() {
let media = CDNMedia {
encrypt_query_param: "p".to_string(),
aes_key: "k".to_string(),
encrypt_type: None,
full_url: None,
};
let j = cdn_media_json(&media);
assert!(j.get("encrypt_type").is_none());
}
}

View File

@ -1,138 +0,0 @@
//! Low-level CDN client for direct media download.
//!
//! [`CdnClient`] is a primitive layer that can be used independently of
//! [`WeChatBot`](crate::WeChatBot), e.g. when you drive `get_updates` yourself
//! via [`ILinkClient`](crate::protocol::ILinkClient) and only need decryption
//! for a specific attachment.
//!
//! Modeled after [`teloxide_core::Bot`]: wraps a [`reqwest::Client`] so
//! connection pool / TLS session / DNS cache are reused across calls, and is
//! cheap to [`Clone`].
use reqwest::Client;
use std::time::Duration;
use crate::crypto;
use crate::error::{Result, WeChatBotError};
use crate::protocol::CDN_BASE_URL;
use crate::types::CDNMedia;
/// HTTP client for WeChat CDN media endpoints.
///
/// Cheap to [`Clone`] — shares the underlying [`reqwest::Client`], which uses
/// an `Arc` internally.
///
/// # Example
///
/// ```no_run
/// use wechatbot::{CdnClient, CDNMedia};
///
/// # async fn demo(media: CDNMedia) -> Result<(), Box<dyn std::error::Error>> {
/// let cdn = CdnClient::new();
/// let bytes = cdn.download(&media, None).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct CdnClient {
http: Client,
base_url: String,
}
impl Default for CdnClient {
fn default() -> Self {
Self::new()
}
}
impl CdnClient {
/// Create a [`CdnClient`] with a fresh internal [`reqwest::Client`].
pub fn new() -> Self {
Self::with_client(Client::new())
}
/// Create a [`CdnClient`] that reuses an existing [`reqwest::Client`].
///
/// Useful when the caller already maintains a shared HTTP client with
/// custom proxy / TLS / timeout configuration.
pub fn with_client(http: Client) -> Self {
Self {
http,
base_url: CDN_BASE_URL.to_string(),
}
}
/// Override the CDN base URL (defaults to [`CDN_BASE_URL`]).
///
/// Primarily intended for tests and regional endpoints.
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into();
self
}
/// Download and AES-decrypt a CDN media object.
///
/// `aes_key_override` is used when the decryption key is attached to the
/// message metadata (e.g. [`ImageContent::aes_key`](crate::ImageContent::aes_key))
/// rather than embedded in the media's own `aes_key` field.
pub async fn download(
&self,
media: &CDNMedia,
aes_key_override: Option<&str>,
) -> Result<Vec<u8>> {
let download_url = format!(
"{}/download?encrypted_query_param={}",
self.base_url,
urlencoding::encode(&media.encrypt_query_param)
);
let resp = self
.http
.get(&download_url)
.timeout(Duration::from_secs(60))
.send()
.await?;
if !resp.status().is_success() {
return Err(WeChatBotError::Media(format!(
"CDN download failed: HTTP {}",
resp.status()
)));
}
let ciphertext = resp.bytes().await?.to_vec();
let key_source = aes_key_override.unwrap_or(&media.aes_key);
if key_source.is_empty() {
return Err(WeChatBotError::Media("no AES key available".into()));
}
let aes_key = crypto::decode_aes_key(key_source)?;
crypto::decrypt_aes_ecb(&ciphertext, &aes_key)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_and_new_equivalent() {
let a = CdnClient::default();
let b = CdnClient::new();
assert_eq!(a.base_url, b.base_url);
}
#[test]
fn with_base_url_overrides() {
let c = CdnClient::new().with_base_url("https://example.test/cdn");
assert_eq!(c.base_url, "https://example.test/cdn");
}
#[test]
fn clone_is_cheap_and_preserves_config() {
let c = CdnClient::new().with_base_url("https://x.y/z");
let cloned = c.clone();
assert_eq!(c.base_url, cloned.base_url);
}
}

View File

@ -1,148 +0,0 @@
//! AES-128-ECB encryption for WeChat CDN media files.
use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit};
use aes::Aes128;
use base64::Engine;
use rand::Rng;
use crate::error::{Result, WeChatBotError};
/// Encrypt plaintext with AES-128-ECB and PKCS7 padding.
pub fn encrypt_aes_ecb(plaintext: &[u8], key: &[u8; 16]) -> Vec<u8> {
let cipher = Aes128::new(key.into());
let padded = pkcs7_pad(plaintext, 16);
let mut ciphertext = padded;
for chunk in ciphertext.chunks_exact_mut(16) {
cipher.encrypt_block(chunk.into());
}
ciphertext
}
/// Decrypt AES-128-ECB ciphertext and remove PKCS7 padding.
pub fn decrypt_aes_ecb(ciphertext: &[u8], key: &[u8; 16]) -> Result<Vec<u8>> {
if ciphertext.len() % 16 != 0 {
return Err(WeChatBotError::Media(
"ciphertext length not a multiple of 16".into(),
));
}
let cipher = Aes128::new(key.into());
let mut plaintext = ciphertext.to_vec();
for chunk in plaintext.chunks_exact_mut(16) {
cipher.decrypt_block(chunk.into());
}
pkcs7_unpad(&plaintext)
}
/// Generate a random 16-byte AES key.
pub fn generate_aes_key() -> [u8; 16] {
let mut key = [0u8; 16];
rand::rng().fill_bytes(&mut key);
key
}
/// Calculate encrypted size with PKCS7 padding.
pub fn encrypted_size(raw_size: usize) -> usize {
((raw_size + 1 + 15) / 16) * 16
}
/// Decode an aes_key from the protocol (handles all three formats).
pub fn decode_aes_key(encoded: &str) -> Result<[u8; 16]> {
// Direct hex (32 chars)
if encoded.len() == 32 && encoded.chars().all(|c| c.is_ascii_hexdigit()) {
let bytes =
hex::decode(encoded).map_err(|e| WeChatBotError::Media(format!("hex decode: {e}")))?;
return bytes_to_key(&bytes);
}
// Base64 decode
let decoded = base64::engine::general_purpose::STANDARD
.decode(encoded)
.or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(encoded))
.map_err(|e| WeChatBotError::Media(format!("base64 decode: {e}")))?;
if decoded.len() == 16 {
return bytes_to_key(&decoded);
}
if decoded.len() == 32 {
let hex_str = std::str::from_utf8(&decoded)
.map_err(|_| WeChatBotError::Media("decoded key is not UTF-8".into()))?;
if hex_str.chars().all(|c| c.is_ascii_hexdigit()) {
let bytes = hex::decode(hex_str)
.map_err(|e| WeChatBotError::Media(format!("hex decode: {e}")))?;
return bytes_to_key(&bytes);
}
}
Err(WeChatBotError::Media(format!(
"unexpected decoded key length: {}",
decoded.len()
)))
}
/// Encode an AES key as hex (for getuploadurl).
pub fn encode_aes_key_hex(key: &[u8; 16]) -> String {
hex::encode(key)
}
/// Encode an AES key as base64(hex) (for CDNMedia.aes_key).
pub fn encode_aes_key_base64(key: &[u8; 16]) -> String {
base64::engine::general_purpose::STANDARD.encode(hex::encode(key))
}
fn bytes_to_key(bytes: &[u8]) -> Result<[u8; 16]> {
bytes
.try_into()
.map_err(|_| WeChatBotError::Media(format!("key length {} != 16", bytes.len())))
}
fn pkcs7_pad(data: &[u8], block_size: usize) -> Vec<u8> {
let padding = block_size - (data.len() % block_size);
let mut result = data.to_vec();
result.extend(std::iter::repeat(padding as u8).take(padding));
result
}
fn pkcs7_unpad(data: &[u8]) -> Result<Vec<u8>> {
if data.is_empty() {
return Err(WeChatBotError::Media("empty data".into()));
}
let padding = *data.last().unwrap() as usize;
if padding == 0 || padding > data.len() || padding > 16 {
return Err(WeChatBotError::Media("invalid PKCS7 padding".into()));
}
Ok(data[..data.len() - padding].to_vec())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn round_trip() {
let key = generate_aes_key();
let plaintext = b"Hello, WeChat!";
let ct = encrypt_aes_ecb(plaintext, &key);
let pt = decrypt_aes_ecb(&ct, &key).unwrap();
assert_eq!(pt, plaintext);
}
#[test]
fn encrypted_size_calc() {
assert_eq!(encrypted_size(14), 16);
assert_eq!(encrypted_size(16), 32);
assert_eq!(encrypted_size(100), 112);
}
#[test]
fn decode_direct_hex() {
let key = decode_aes_key("00112233445566778899aabbccddeeff").unwrap();
assert_eq!(key.len(), 16);
}
#[test]
fn decode_base64_raw() {
let key = decode_aes_key("ABEiM0RVZneImaq7zN3u/w==").unwrap();
assert_eq!(key.len(), 16);
}
}

View File

@ -1,93 +0,0 @@
use thiserror::Error;
/// Errors that can occur in the SDK.
#[derive(Error, Debug)]
pub enum WeChatBotError {
#[error("API error: {message} (http={http_status}, errcode={errcode})")]
Api {
message: String,
http_status: u16,
errcode: i32,
},
#[error("Auth error: {0}")]
Auth(String),
#[error("No context_token for user {0}")]
NoContext(String),
#[error("Media error: {0}")]
Media(String),
#[error("Transport error: {0}")]
Transport(#[from] reqwest::Error),
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("{0}")]
Other(String),
}
impl WeChatBotError {
/// Returns true if this is a session-expired error (errcode -14).
pub fn is_session_expired(&self) -> bool {
matches!(self, WeChatBotError::Api { errcode: -14, .. })
}
}
pub type Result<T> = std::result::Result<T, WeChatBotError>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn session_expired_true() {
let err = WeChatBotError::Api {
message: "session expired".to_string(),
http_status: 200,
errcode: -14,
};
assert!(err.is_session_expired());
}
#[test]
fn session_expired_false() {
let err = WeChatBotError::Api {
message: "other error".to_string(),
http_status: 400,
errcode: -1,
};
assert!(!err.is_session_expired());
}
#[test]
fn non_api_not_session_expired() {
let err = WeChatBotError::Auth("test".to_string());
assert!(!err.is_session_expired());
}
#[test]
fn error_display() {
let err = WeChatBotError::Api {
message: "bad request".to_string(),
http_status: 400,
errcode: -1,
};
let msg = format!("{}", err);
assert!(msg.contains("bad request"));
assert!(msg.contains("400"));
assert!(msg.contains("-1"));
}
#[test]
fn no_context_error() {
let err = WeChatBotError::NoContext("user123".to_string());
let msg = format!("{}", err);
assert!(msg.contains("user123"));
}
}

View File

@ -1,38 +0,0 @@
//! # wechatbot
//!
//! WeChat iLink Bot SDK for Rust.
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use wechatbot::{WeChatBot, BotOptions};
//!
//! #[tokio::main]
//! async fn main() {
//! let bot = WeChatBot::new(BotOptions::default());
//! bot.login(false).await.unwrap();
//!
//! bot.on_message(Box::new(|msg| {
//! println!("{}: {}", msg.user_id, msg.text);
//! })).await;
//!
//! bot.run().await.unwrap();
//! }
//! ```
pub mod bot;
pub mod cdn;
pub mod crypto;
pub mod error;
pub mod protocol;
pub mod types;
pub use bot::{BotOptions, MessageHandler, SendContent, WeChatBot};
pub use cdn::CdnClient;
pub use crypto::{
decode_aes_key, decrypt_aes_ecb, decrypt_aes_ecb as download_decrypt, encode_aes_key_base64,
encode_aes_key_hex, encrypt_aes_ecb, generate_aes_key,
};
pub use error::{Result, WeChatBotError};
pub use protocol::{build_cdn_upload_url, GetUploadUrlParams, GetUploadUrlResponse};
pub use types::*;

View File

@ -1,407 +0,0 @@
//! Raw iLink Bot API HTTP calls.
use base64::Engine;
use rand::Rng;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use std::time::Duration;
use uuid::Uuid;
use crate::error::{Result, WeChatBotError};
#[allow(unused_imports)]
use crate::types::*;
pub const DEFAULT_BASE_URL: &str = "https://ilinkai.weixin.qq.com";
pub const CDN_BASE_URL: &str = "https://novac2c.cdn.weixin.qq.com/c2c";
pub const CHANNEL_VERSION: &str = env!("CARGO_PKG_VERSION");
/// iLink-App-Id header value.
const ILINK_APP_ID: &str = "bot";
/// Build iLink-App-ClientVersion from the crate version (0x00MMNNPP).
fn build_client_version() -> String {
let version = env!("CARGO_PKG_VERSION");
let parts: Vec<u32> = version.split('.').filter_map(|p| p.parse().ok()).collect();
let major = parts.first().copied().unwrap_or(0) & 0xff;
let minor = parts.get(1).copied().unwrap_or(0) & 0xff;
let patch = parts.get(2).copied().unwrap_or(0) & 0xff;
let num = (major << 16) | (minor << 8) | patch;
num.to_string()
}
/// Generate the X-WECHAT-UIN header value.
pub fn random_wechat_uin() -> String {
let mut buf = [0u8; 4];
rand::rng().fill_bytes(&mut buf);
let val = u32::from_be_bytes(buf);
base64::engine::general_purpose::STANDARD.encode(val.to_string())
}
/// QR code response.
#[derive(Debug, Deserialize)]
pub struct QrCodeResponse {
pub qrcode: String,
pub qrcode_img_content: String,
}
/// QR status response.
#[derive(Debug, Deserialize)]
pub struct QrStatusResponse {
pub status: String,
pub bot_token: Option<String>,
pub ilink_bot_id: Option<String>,
pub ilink_user_id: Option<String>,
pub baseurl: Option<String>,
/// New host to redirect polling to when status is "scaned_but_redirect".
pub redirect_host: Option<String>,
}
/// Get updates response.
#[derive(Debug, Deserialize)]
pub struct GetUpdatesResponse {
#[serde(default)]
pub ret: i32,
#[serde(default)]
pub msgs: Vec<WireMessage>,
#[serde(default)]
pub get_updates_buf: String,
pub errcode: Option<i32>,
pub errmsg: Option<String>,
}
/// Get config response.
#[derive(Debug, Deserialize)]
pub struct GetConfigResponse {
pub typing_ticket: Option<String>,
}
/// Low-level iLink API client.
#[derive(Debug)]
pub struct ILinkClient {
http: Client,
}
impl ILinkClient {
pub fn new() -> Self {
Self {
http: Client::builder()
.timeout(Duration::from_secs(45))
.build()
.unwrap(),
}
}
pub async fn get_qr_code(&self, base_url: &str) -> Result<QrCodeResponse> {
let url = format!("{}/ilink/bot/get_bot_qrcode?bot_type=3", base_url);
let resp = self
.http
.get(&url)
.header("iLink-App-Id", ILINK_APP_ID)
.header("iLink-App-ClientVersion", build_client_version())
.send()
.await?;
Ok(resp.json().await?)
}
pub async fn poll_qr_status(&self, base_url: &str, qrcode: &str) -> Result<QrStatusResponse> {
let url = format!(
"{}/ilink/bot/get_qrcode_status?qrcode={}",
base_url,
urlencoding::encode(qrcode)
);
let resp = self
.http
.get(&url)
.header("iLink-App-Id", ILINK_APP_ID)
.header("iLink-App-ClientVersion", build_client_version())
.send()
.await?;
Ok(resp.json().await?)
}
pub async fn get_updates(
&self,
base_url: &str,
token: &str,
cursor: &str,
) -> Result<GetUpdatesResponse> {
let body = json!({
"get_updates_buf": cursor,
"base_info": { "channel_version": CHANNEL_VERSION }
});
let resp = self
.api_post(base_url, "/ilink/bot/getupdates", token, &body, 45)
.await?;
let result: GetUpdatesResponse = serde_json::from_value(resp)?;
if result.ret != 0 || result.errcode.is_some_and(|c| c != 0) {
let code = result.errcode.unwrap_or(result.ret);
let msg = result
.errmsg
.unwrap_or_else(|| format!("ret={}", result.ret));
return Err(WeChatBotError::Api {
message: msg,
http_status: 200,
errcode: code,
});
}
Ok(result)
}
pub async fn send_message(&self, base_url: &str, token: &str, msg: &Value) -> Result<()> {
let body = json!({
"msg": msg,
"base_info": { "channel_version": CHANNEL_VERSION }
});
self.api_post(base_url, "/ilink/bot/sendmessage", token, &body, 15)
.await?;
Ok(())
}
pub async fn get_config(
&self,
base_url: &str,
token: &str,
user_id: &str,
context_token: &str,
) -> Result<GetConfigResponse> {
let body = json!({
"ilink_user_id": user_id,
"context_token": context_token,
"base_info": { "channel_version": CHANNEL_VERSION }
});
let resp = self
.api_post(base_url, "/ilink/bot/getconfig", token, &body, 15)
.await?;
Ok(serde_json::from_value(resp)?)
}
pub async fn send_typing(
&self,
base_url: &str,
token: &str,
user_id: &str,
ticket: &str,
status: i32,
) -> Result<()> {
let body = json!({
"ilink_user_id": user_id,
"typing_ticket": ticket,
"status": status,
"base_info": { "channel_version": CHANNEL_VERSION }
});
self.api_post(base_url, "/ilink/bot/sendtyping", token, &body, 15)
.await?;
Ok(())
}
async fn api_post(
&self,
base_url: &str,
endpoint: &str,
token: &str,
body: &Value,
timeout_secs: u64,
) -> Result<Value> {
let url = format!("{}{}", base_url, endpoint);
let resp = self
.http
.post(&url)
.timeout(Duration::from_secs(timeout_secs))
.header("Content-Type", "application/json")
.header("AuthorizationType", "ilink_bot_token")
.header("Authorization", format!("Bearer {}", token))
.header("X-WECHAT-UIN", random_wechat_uin())
.header("iLink-App-Id", ILINK_APP_ID)
.header("iLink-App-ClientVersion", build_client_version())
.json(body)
.send()
.await?;
let status = resp.status().as_u16();
let text = resp.text().await?;
let value: Value = serde_json::from_str(&text).unwrap_or(json!({}));
if status >= 400 {
return Err(WeChatBotError::Api {
message: value["errmsg"]
.as_str()
.or_else(|| value["message"].as_str())
.unwrap_or(&text)
.to_string(),
http_status: status,
errcode: value["errcode"].as_i64().unwrap_or(0) as i32,
});
}
if let Some(errcode) = value["errcode"].as_i64() {
if errcode != 0 {
return Err(WeChatBotError::Api {
message: value["errmsg"]
.as_str()
.or_else(|| value["message"].as_str())
.unwrap_or(&text)
.to_string(),
http_status: status,
errcode: errcode as i32,
});
}
}
Ok(value)
}
}
/// Build a media message payload.
pub fn build_media_message(user_id: &str, context_token: &str, item_list: Vec<Value>) -> Value {
json!({
"from_user_id": "",
"to_user_id": user_id,
"client_id": Uuid::new_v4().to_string(),
"message_type": 2,
"message_state": 2,
"context_token": context_token,
"item_list": item_list
})
}
/// GetUploadUrl request parameters.
pub struct GetUploadUrlParams {
pub filekey: String,
pub media_type: i32,
pub to_user_id: String,
pub rawsize: usize,
pub rawfilemd5: String,
pub filesize: usize,
pub no_need_thumb: bool,
pub aeskey: String,
}
/// GetUploadUrl response.
#[derive(Debug, Deserialize)]
pub struct GetUploadUrlResponse {
pub upload_param: Option<String>,
pub thumb_upload_param: Option<String>,
pub upload_full_url: Option<String>,
}
impl ILinkClient {
/// Get a pre-signed CDN upload URL.
pub async fn get_upload_url(
&self,
base_url: &str,
token: &str,
params: &GetUploadUrlParams,
) -> Result<GetUploadUrlResponse> {
let body = json!({
"filekey": params.filekey,
"media_type": params.media_type,
"to_user_id": params.to_user_id,
"rawsize": params.rawsize,
"rawfilemd5": params.rawfilemd5,
"filesize": params.filesize,
"no_need_thumb": params.no_need_thumb,
"aeskey": params.aeskey,
"base_info": { "channel_version": CHANNEL_VERSION }
});
let resp = self
.api_post(base_url, "/ilink/bot/getuploadurl", token, &body, 15)
.await?;
Ok(serde_json::from_value(resp)?)
}
/// Upload encrypted bytes to CDN with retry (up to 3 attempts).
/// Returns the download encrypted_query_param from the x-encrypted-param header.
pub async fn upload_to_cdn(&self, cdn_url: &str, ciphertext: &[u8]) -> Result<String> {
const MAX_RETRIES: u32 = 3;
let mut last_err = None;
for attempt in 1..=MAX_RETRIES {
match self
.http
.post(cdn_url)
.header("Content-Type", "application/octet-stream")
.body(ciphertext.to_vec())
.send()
.await
{
Ok(resp) => {
let status = resp.status().as_u16();
if status >= 400 && status < 500 {
let err_msg = resp
.headers()
.get("x-error-message")
.and_then(|v| v.to_str().ok())
.unwrap_or("client error")
.to_string();
return Err(WeChatBotError::Media(format!(
"CDN upload client error {}: {}",
status, err_msg
)));
}
if status != 200 {
let err_msg = resp
.headers()
.get("x-error-message")
.and_then(|v| v.to_str().ok())
.unwrap_or("server error")
.to_string();
last_err = Some(WeChatBotError::Media(format!(
"CDN upload server error {}: {}",
status, err_msg
)));
continue;
}
match resp
.headers()
.get("x-encrypted-param")
.and_then(|v| v.to_str().ok())
{
Some(param) => return Ok(param.to_string()),
None => {
last_err = Some(WeChatBotError::Media(
"CDN upload response missing x-encrypted-param header".into(),
));
continue;
}
}
}
Err(e) => {
last_err = Some(WeChatBotError::Other(format!(
"CDN upload network error: {}",
e
)));
if attempt < MAX_RETRIES {
continue;
}
}
}
}
Err(last_err.unwrap_or_else(|| {
WeChatBotError::Media(format!("CDN upload failed after {} attempts", MAX_RETRIES))
}))
}
}
/// Build a CDN upload URL from params.
pub fn build_cdn_upload_url(cdn_base_url: &str, upload_param: &str, filekey: &str) -> String {
format!(
"{}/upload?encrypted_query_param={}&filekey={}",
cdn_base_url,
urlencoding::encode(upload_param),
urlencoding::encode(filekey)
)
}
/// Build a text message payload.
pub fn build_text_message(user_id: &str, context_token: &str, text: &str) -> Value {
json!({
"from_user_id": "",
"to_user_id": user_id,
"client_id": Uuid::new_v4().to_string(),
"message_type": 2,
"message_state": 2,
"context_token": context_token,
"item_list": [{ "type": 1, "text_item": { "text": text } }]
})
}

View File

@ -1,858 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::time::SystemTime;
/// Message sender type.
/// Uses serde_repr for integer (de)serialization: JSON `1` ↔ `MessageType::User`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i32)]
pub enum MessageType {
User = 1,
Bot = 2,
}
/// Message delivery state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i32)]
pub enum MessageState {
New = 0,
Generating = 1,
Finish = 2,
}
/// Content type of a message item.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i32)]
pub enum MessageItemType {
Text = 1,
Image = 2,
Voice = 3,
File = 4,
Video = 5,
}
/// Media type for upload requests.
#[derive(Debug, Clone, Copy)]
#[repr(i32)]
pub enum MediaType {
Image = 1,
Video = 2,
File = 3,
Voice = 4,
}
/// CDN media reference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CDNMedia {
pub encrypt_query_param: String,
pub aes_key: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub encrypt_type: Option<i32>,
/// Complete download URL returned by server; when set, use directly.
#[serde(skip_serializing_if = "Option::is_none")]
pub full_url: Option<String>,
}
/// Text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextItem {
pub text: String,
}
/// Image content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageItem {
#[serde(skip_serializing_if = "Option::is_none")]
pub media: Option<CDNMedia>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thumb_media: Option<CDNMedia>,
#[serde(skip_serializing_if = "Option::is_none")]
pub aeskey: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mid_size: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thumb_width: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thumb_height: Option<i32>,
}
/// Voice content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoiceItem {
#[serde(skip_serializing_if = "Option::is_none")]
pub media: Option<CDNMedia>,
#[serde(skip_serializing_if = "Option::is_none")]
pub encode_type: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub playtime: Option<i32>,
}
/// File content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileItem {
#[serde(skip_serializing_if = "Option::is_none")]
pub media: Option<CDNMedia>,
#[serde(skip_serializing_if = "Option::is_none")]
pub file_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub md5: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub len: Option<String>,
}
/// Video content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoItem {
#[serde(skip_serializing_if = "Option::is_none")]
pub media: Option<CDNMedia>,
#[serde(skip_serializing_if = "Option::is_none")]
pub video_size: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub play_length: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thumb_media: Option<CDNMedia>,
}
/// Referenced/quoted message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefMessage {
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub message_item: Option<Box<WireMessageItem>>,
}
/// A single content item in a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WireMessageItem {
#[serde(rename = "type")]
pub item_type: MessageItemType,
#[serde(skip_serializing_if = "Option::is_none")]
pub text_item: Option<TextItem>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_item: Option<ImageItem>,
#[serde(skip_serializing_if = "Option::is_none")]
pub voice_item: Option<VoiceItem>,
#[serde(skip_serializing_if = "Option::is_none")]
pub file_item: Option<FileItem>,
#[serde(skip_serializing_if = "Option::is_none")]
pub video_item: Option<VideoItem>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ref_msg: Option<RefMessage>,
}
/// Raw wire message from the iLink API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WireMessage {
pub from_user_id: String,
pub to_user_id: String,
pub client_id: String,
pub create_time_ms: i64,
pub message_type: MessageType,
pub message_state: MessageState,
pub context_token: String,
pub item_list: Vec<WireMessageItem>,
}
/// Parsed incoming message — user-friendly.
#[derive(Debug, Clone)]
pub struct IncomingMessage {
pub user_id: String,
pub text: String,
pub content_type: ContentType,
pub timestamp: SystemTime,
pub images: Vec<ImageContent>,
pub voices: Vec<VoiceContent>,
pub files: Vec<FileContent>,
pub videos: Vec<VideoContent>,
pub quoted: Option<QuotedMessage>,
pub raw: WireMessage,
pub(crate) context_token: String,
}
impl IncomingMessage {
/// Opaque reply token bound to this message.
///
/// Pass it back via [`WeChatBot::reply`](crate::WeChatBot::reply) (which
/// does this automatically) or when constructing a message payload with
/// [`protocol::build_text_message`](crate::protocol::build_text_message) /
/// [`protocol::build_media_message`](crate::protocol::build_media_message)
/// for use with [`ILinkClient::send_message`](crate::protocol::ILinkClient::send_message).
pub fn context_token(&self) -> &str {
&self.context_token
}
/// Parse a raw [`WireMessage`] into a user-friendly [`IncomingMessage`].
///
/// Returns `None` if the wire message is not a user-originated message
/// (e.g. it was sent by the bot itself).
///
/// This is the stable entry point for consumers who drive
/// [`ILinkClient::get_updates`](crate::protocol::ILinkClient::get_updates)
/// themselves instead of using [`WeChatBot`](crate::WeChatBot)'s
/// dispatcher.
pub fn from_wire(wire: &WireMessage) -> Option<Self> {
if wire.message_type != MessageType::User {
return None;
}
let mut msg = IncomingMessage {
user_id: wire.from_user_id.clone(),
text: extract_text(&wire.item_list),
content_type: detect_type(&wire.item_list),
timestamp: std::time::UNIX_EPOCH
+ std::time::Duration::from_millis(wire.create_time_ms as u64),
images: Vec::new(),
voices: Vec::new(),
files: Vec::new(),
videos: Vec::new(),
quoted: None,
raw: wire.clone(),
context_token: wire.context_token.clone(),
};
for item in &wire.item_list {
if let Some(ref img) = item.image_item {
msg.images.push(ImageContent {
media: img.media.clone(),
thumb_media: img.thumb_media.clone(),
aes_key: img.aeskey.clone(),
url: img.url.clone(),
width: img.thumb_width,
height: img.thumb_height,
});
}
if let Some(ref voice) = item.voice_item {
msg.voices.push(VoiceContent {
media: voice.media.clone(),
text: voice.text.clone(),
duration_ms: voice.playtime,
encode_type: voice.encode_type,
});
}
if let Some(ref file) = item.file_item {
msg.files.push(FileContent {
media: file.media.clone(),
file_name: file.file_name.clone(),
md5: file.md5.clone(),
size: file.len.as_ref().and_then(|s| s.parse().ok()),
});
}
if let Some(ref video) = item.video_item {
msg.videos.push(VideoContent {
media: video.media.clone(),
thumb_media: video.thumb_media.clone(),
duration_ms: video.play_length,
});
}
if let Some(ref refm) = item.ref_msg {
msg.quoted = Some(QuotedMessage {
title: refm.title.clone(),
text: refm
.message_item
.as_ref()
.and_then(|i| i.text_item.as_ref())
.map(|t| t.text.clone()),
});
}
}
Some(msg)
}
}
fn detect_type(items: &[WireMessageItem]) -> ContentType {
items
.first()
.map_or(ContentType::Text, |item| match item.item_type {
MessageItemType::Image => ContentType::Image,
MessageItemType::Voice => ContentType::Voice,
MessageItemType::File => ContentType::File,
MessageItemType::Video => ContentType::Video,
_ => ContentType::Text,
})
}
fn extract_text(items: &[WireMessageItem]) -> String {
items
.iter()
.filter_map(|item| match item.item_type {
MessageItemType::Text => item.text_item.as_ref().map(|t| t.text.clone()),
MessageItemType::Image => Some(
item.image_item
.as_ref()
.and_then(|i| i.url.clone())
.unwrap_or_else(|| "[image]".to_string()),
),
MessageItemType::Voice => Some(
item.voice_item
.as_ref()
.and_then(|v| v.text.clone())
.unwrap_or_else(|| "[voice]".to_string()),
),
MessageItemType::File => Some(
item.file_item
.as_ref()
.and_then(|f| f.file_name.clone())
.unwrap_or_else(|| "[file]".to_string()),
),
MessageItemType::Video => Some("[video]".to_string()),
})
.collect::<Vec<_>>()
.join("\n")
}
/// Content type of an incoming message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentType {
Text,
Image,
Voice,
File,
Video,
}
#[derive(Debug, Clone)]
pub struct ImageContent {
pub media: Option<CDNMedia>,
pub thumb_media: Option<CDNMedia>,
pub aes_key: Option<String>,
pub url: Option<String>,
pub width: Option<i32>,
pub height: Option<i32>,
}
#[derive(Debug, Clone)]
pub struct VoiceContent {
pub media: Option<CDNMedia>,
pub text: Option<String>,
pub duration_ms: Option<i32>,
pub encode_type: Option<i32>,
}
#[derive(Debug, Clone)]
pub struct FileContent {
pub media: Option<CDNMedia>,
pub file_name: Option<String>,
pub md5: Option<String>,
pub size: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct VideoContent {
pub media: Option<CDNMedia>,
pub thumb_media: Option<CDNMedia>,
pub duration_ms: Option<i32>,
}
#[derive(Debug, Clone)]
pub struct QuotedMessage {
pub title: Option<String>,
pub text: Option<String>,
}
/// Result of downloading media from a message.
#[derive(Debug, Clone)]
pub struct DownloadedMedia {
pub data: Vec<u8>,
/// "image", "file", "video", "voice"
pub media_type: String,
pub file_name: Option<String>,
pub format: Option<String>,
}
/// Result of uploading media to CDN.
#[derive(Debug, Clone)]
pub struct UploadResult {
pub media: CDNMedia,
pub aes_key: [u8; 16],
pub encrypted_file_size: usize,
}
/// Stored login credentials.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Credentials {
pub token: String,
#[serde(rename = "baseUrl")]
pub base_url: String,
#[serde(rename = "accountId")]
pub account_id: String,
#[serde(rename = "userId")]
pub user_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub saved_at: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn message_type_values() {
assert_eq!(MessageType::User as i32, 1);
assert_eq!(MessageType::Bot as i32, 2);
}
#[test]
fn message_state_values() {
assert_eq!(MessageState::New as i32, 0);
assert_eq!(MessageState::Generating as i32, 1);
assert_eq!(MessageState::Finish as i32, 2);
}
#[test]
fn message_item_type_values() {
assert_eq!(MessageItemType::Text as i32, 1);
assert_eq!(MessageItemType::Image as i32, 2);
assert_eq!(MessageItemType::Voice as i32, 3);
assert_eq!(MessageItemType::File as i32, 4);
assert_eq!(MessageItemType::Video as i32, 5);
}
#[test]
fn wire_message_json_round_trip() {
let wire = WireMessage {
from_user_id: "user1".to_string(),
to_user_id: "bot1".to_string(),
client_id: "c1".to_string(),
create_time_ms: 1700000000000,
message_type: MessageType::User,
message_state: MessageState::Finish,
context_token: "ctx".to_string(),
item_list: vec![WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "hello".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}],
};
let json = serde_json::to_string(&wire).unwrap();
let decoded: WireMessage = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.from_user_id, "user1");
assert_eq!(decoded.message_type, MessageType::User);
assert_eq!(decoded.item_list.len(), 1);
assert_eq!(
decoded.item_list[0].text_item.as_ref().unwrap().text,
"hello"
);
}
#[test]
fn credentials_json_camel_case() {
let creds = Credentials {
token: "tok".to_string(),
base_url: "https://api.example.com".to_string(),
account_id: "acc1".to_string(),
user_id: "uid1".to_string(),
saved_at: Some("2024-01-01T00:00:00Z".to_string()),
};
let json = serde_json::to_string(&creds).unwrap();
assert!(json.contains("\"baseUrl\""), "expected camelCase baseUrl");
assert!(
json.contains("\"accountId\""),
"expected camelCase accountId"
);
assert!(json.contains("\"userId\""), "expected camelCase userId");
let decoded: Credentials = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.token, "tok");
assert_eq!(decoded.base_url, "https://api.example.com");
}
#[test]
fn credentials_omits_none_saved_at() {
let creds = Credentials {
token: "tok".to_string(),
base_url: "https://api.example.com".to_string(),
account_id: "acc1".to_string(),
user_id: "uid1".to_string(),
saved_at: None,
};
let json = serde_json::to_string(&creds).unwrap();
assert!(!json.contains("saved_at"), "should omit None saved_at");
}
#[test]
fn cdn_media_json() {
let media = CDNMedia {
encrypt_query_param: "param=abc".to_string(),
aes_key: "key123".to_string(),
encrypt_type: Some(1),
full_url: None,
};
let json = serde_json::to_string(&media).unwrap();
let decoded: CDNMedia = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.encrypt_query_param, "param=abc");
assert_eq!(decoded.aes_key, "key123");
assert_eq!(decoded.encrypt_type, Some(1));
}
#[test]
fn wire_message_with_image() {
let wire = WireMessage {
from_user_id: "user1".to_string(),
to_user_id: "bot1".to_string(),
client_id: "c1".to_string(),
create_time_ms: 1700000000000,
message_type: MessageType::User,
message_state: MessageState::Finish,
context_token: "ctx".to_string(),
item_list: vec![WireMessageItem {
item_type: MessageItemType::Image,
text_item: None,
image_item: Some(ImageItem {
media: None,
thumb_media: None,
aeskey: Some("key".to_string()),
url: Some("http://img.jpg".to_string()),
mid_size: Some(1024),
thumb_width: Some(100),
thumb_height: Some(200),
}),
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}],
};
let json = serde_json::to_string(&wire).unwrap();
let decoded: WireMessage = serde_json::from_str(&json).unwrap();
let img = decoded.item_list[0].image_item.as_ref().unwrap();
assert_eq!(img.url, Some("http://img.jpg".to_string()));
assert_eq!(img.thumb_width, Some(100));
}
#[test]
fn content_type_equality() {
assert_eq!(ContentType::Text, ContentType::Text);
assert_ne!(ContentType::Text, ContentType::Image);
}
#[test]
fn detect_type_text() {
let items = vec![WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "hi".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}];
assert_eq!(detect_type(&items), ContentType::Text);
}
#[test]
fn detect_type_image() {
let items = vec![WireMessageItem {
item_type: MessageItemType::Image,
text_item: None,
image_item: Some(ImageItem {
media: None,
thumb_media: None,
aeskey: None,
url: Some("http://img".to_string()),
mid_size: None,
thumb_width: None,
thumb_height: None,
}),
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}];
assert_eq!(detect_type(&items), ContentType::Image);
}
#[test]
fn detect_type_empty() {
assert_eq!(detect_type(&[]), ContentType::Text);
}
#[test]
fn extract_text_single() {
let items = vec![WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "hello world".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}];
assert_eq!(extract_text(&items), "hello world");
}
#[test]
fn extract_text_multi() {
let items = vec![
WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "line1".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
},
WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "line2".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
},
];
assert_eq!(extract_text(&items), "line1\nline2");
}
#[test]
fn extract_text_image_url() {
let items = vec![WireMessageItem {
item_type: MessageItemType::Image,
text_item: None,
image_item: Some(ImageItem {
media: None,
thumb_media: None,
aeskey: None,
url: Some("http://img.jpg".to_string()),
mid_size: None,
thumb_width: None,
thumb_height: None,
}),
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}];
assert_eq!(extract_text(&items), "http://img.jpg");
}
#[test]
fn extract_text_image_placeholder() {
let items = vec![WireMessageItem {
item_type: MessageItemType::Image,
text_item: None,
image_item: Some(ImageItem {
media: None,
thumb_media: None,
aeskey: None,
url: None,
mid_size: None,
thumb_width: None,
thumb_height: None,
}),
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}];
assert_eq!(extract_text(&items), "[image]");
}
#[test]
fn extract_text_voice_with_text() {
let items = vec![WireMessageItem {
item_type: MessageItemType::Voice,
text_item: None,
image_item: None,
voice_item: Some(VoiceItem {
media: None,
encode_type: None,
text: Some("hello".to_string()),
playtime: None,
}),
file_item: None,
video_item: None,
ref_msg: None,
}];
assert_eq!(extract_text(&items), "hello");
}
#[test]
fn extract_text_file_name() {
let items = vec![WireMessageItem {
item_type: MessageItemType::File,
text_item: None,
image_item: None,
voice_item: None,
file_item: Some(FileItem {
media: None,
file_name: Some("doc.pdf".to_string()),
md5: None,
len: None,
}),
video_item: None,
ref_msg: None,
}];
assert_eq!(extract_text(&items), "doc.pdf");
}
#[test]
fn extract_text_video() {
let items = vec![WireMessageItem {
item_type: MessageItemType::Video,
text_item: None,
image_item: None,
voice_item: None,
file_item: None,
video_item: Some(VideoItem {
media: None,
video_size: None,
play_length: None,
thumb_media: None,
}),
ref_msg: None,
}];
assert_eq!(extract_text(&items), "[video]");
}
#[test]
fn from_wire_user_text() {
let wire = WireMessage {
from_user_id: "user123".to_string(),
to_user_id: "bot456".to_string(),
client_id: "c1".to_string(),
create_time_ms: 1700000000000,
message_type: MessageType::User,
message_state: MessageState::Finish,
context_token: "ctx-abc".to_string(),
item_list: vec![WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "hello".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}],
};
let msg = IncomingMessage::from_wire(&wire).unwrap();
assert_eq!(msg.user_id, "user123");
assert_eq!(msg.text, "hello");
assert_eq!(msg.content_type, ContentType::Text);
assert_eq!(msg.context_token(), "ctx-abc");
}
#[test]
fn from_wire_skips_bot() {
let wire = WireMessage {
from_user_id: "bot456".to_string(),
to_user_id: "user123".to_string(),
client_id: "c1".to_string(),
create_time_ms: 1700000000000,
message_type: MessageType::Bot,
message_state: MessageState::Finish,
context_token: "ctx".to_string(),
item_list: vec![WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "reply".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}],
};
assert!(IncomingMessage::from_wire(&wire).is_none());
}
#[test]
fn from_wire_with_image() {
let wire = WireMessage {
from_user_id: "user123".to_string(),
to_user_id: "bot456".to_string(),
client_id: "c1".to_string(),
create_time_ms: 1700000000000,
message_type: MessageType::User,
message_state: MessageState::Finish,
context_token: "ctx".to_string(),
item_list: vec![WireMessageItem {
item_type: MessageItemType::Image,
text_item: None,
image_item: Some(ImageItem {
media: None,
thumb_media: None,
aeskey: Some("key".to_string()),
url: Some("http://img.jpg".to_string()),
mid_size: None,
thumb_width: Some(100),
thumb_height: Some(200),
}),
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
}],
};
let msg = IncomingMessage::from_wire(&wire).unwrap();
assert_eq!(msg.images.len(), 1);
assert_eq!(msg.images[0].url, Some("http://img.jpg".to_string()));
assert_eq!(msg.images[0].width, Some(100));
assert_eq!(msg.images[0].height, Some(200));
}
#[test]
fn from_wire_with_quoted() {
let wire = WireMessage {
from_user_id: "user123".to_string(),
to_user_id: "bot456".to_string(),
client_id: "c1".to_string(),
create_time_ms: 1700000000000,
message_type: MessageType::User,
message_state: MessageState::Finish,
context_token: "ctx".to_string(),
item_list: vec![WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "replying".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: Some(RefMessage {
title: Some("Original".to_string()),
message_item: Some(Box::new(WireMessageItem {
item_type: MessageItemType::Text,
text_item: Some(TextItem {
text: "original text".to_string(),
}),
image_item: None,
voice_item: None,
file_item: None,
video_item: None,
ref_msg: None,
})),
}),
}],
};
let msg = IncomingMessage::from_wire(&wire).unwrap();
let quoted = msg.quoted.as_ref().unwrap();
assert_eq!(quoted.title, Some("Original".to_string()));
assert_eq!(quoted.text, Some("original text".to_string()));
}
}