Compare commits
4 Commits
3f5ed6e4e4
...
69364e484b
| Author | SHA1 | Date | |
|---|---|---|---|
| 69364e484b | |||
| b239083eb5 | |||
| 32690cb792 | |||
| 597881f72e |
1
.gitignore
vendored
1
.gitignore
vendored
@ -8,3 +8,4 @@ Cargo.lock
|
|||||||
.playwright-cli/
|
.playwright-cli/
|
||||||
.venv
|
.venv
|
||||||
PicoBot.code-workspace
|
PicoBot.code-workspace
|
||||||
|
.picobot
|
||||||
|
|||||||
@ -34,3 +34,5 @@ image = { version = "0.25", default-features = false, features = ["jpeg", "png",
|
|||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
meval = "0.2"
|
meval = "0.2"
|
||||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||||
|
rustls = { version = "0.23", features = ["ring"] }
|
||||||
|
wechatbot = { path = "vendor/wechatbot" }
|
||||||
|
|||||||
38
README.md
38
README.md
@ -378,9 +378,9 @@ PicoBot 支持基于文件系统的技能系统,用来给 Agent 注入某一
|
|||||||
当前技能运行时按从低到高优先级合并多个来源,后加载来源可覆盖同名技能:
|
当前技能运行时按从低到高优先级合并多个来源,后加载来源可覆盖同名技能:
|
||||||
|
|
||||||
- 用户级技能:~/.picobot/skills/*/SKILL.md
|
- 用户级技能:~/.picobot/skills/*/SKILL.md
|
||||||
- 用户 Agent 级技能:~/.picobot/agent/skills/*/SKILL.md
|
- 用户 Agent 级技能:~/.agents/skills/*/SKILL.md
|
||||||
- 项目级技能:.picobot/skills/*/SKILL.md
|
- 项目级技能:.picobot/skills/*/SKILL.md
|
||||||
- 项目 Agent 级技能:.picobot/agent/skills/*/SKILL.md
|
- 项目 Agent 级技能:.agents/skills/*/SKILL.md
|
||||||
|
|
||||||
### 7.2 最小 SKILL.md 格式
|
### 7.2 最小 SKILL.md 格式
|
||||||
|
|
||||||
@ -410,7 +410,7 @@ description: 用于总结 Rust 项目架构
|
|||||||
内置工具:
|
内置工具:
|
||||||
|
|
||||||
- skill_list:只读列出技能
|
- skill_list:只读列出技能
|
||||||
- skill_manage:运行时创建、更新、删除、读取和重载技能
|
- skill_manage:运行时创建、更新、删除、批量禁用、读取和重载技能
|
||||||
|
|
||||||
skill_manage 支持的 action:
|
skill_manage 支持的 action:
|
||||||
|
|
||||||
@ -419,8 +419,40 @@ skill_manage 支持的 action:
|
|||||||
- create
|
- create
|
||||||
- update
|
- update
|
||||||
- delete
|
- delete
|
||||||
|
- disable
|
||||||
- reload
|
- reload
|
||||||
|
|
||||||
|
skill 的启用/禁用状态不会写入 config.json,而是写入独立状态文件:
|
||||||
|
|
||||||
|
- 用户级状态:~/.picobot/skill-state.json
|
||||||
|
- 项目级状态:.picobot/skill-state.json
|
||||||
|
|
||||||
|
状态文件当前使用最小 JSON 结构:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"disabled_skills": ["example-skill"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
说明:
|
||||||
|
|
||||||
|
- disable 默认写入项目级状态文件,可通过 tool 参数中的 scope 指定 user 或 project
|
||||||
|
- disable 只接受 names 数组;即使只禁用 1 个 skill,也需要传单元素数组
|
||||||
|
- 一次 disable 调用会批量处理 names 里的所有 skill,并只做一次 reload
|
||||||
|
- 用户级与项目级状态同时生效,项目运行时会同时读取两者
|
||||||
|
- 某个 skill 只要在任一层状态文件中被禁用,就不会出现在 skill_list、skill_activate 和技能索引提示里
|
||||||
|
|
||||||
|
批量示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"action": "disable",
|
||||||
|
"scope": "project",
|
||||||
|
"names": ["lark-calendar", "lark-vc", "lark-minutes"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
skills 配置示例:
|
skills 配置示例:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
|
|||||||
3
src/bootstrap.rs
Normal file
3
src/bootstrap.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
pub fn initialize_process_runtime() {
|
||||||
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
|
}
|
||||||
@ -538,6 +538,8 @@ mod tests {
|
|||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_from_chat_message_expands_tool_calls() {
|
fn test_from_chat_message_expands_tool_calls() {
|
||||||
let message = ChatMessage::assistant_with_tool_calls(
|
let message = ChatMessage::assistant_with_tool_calls(
|
||||||
@ -556,8 +558,13 @@ mod tests {
|
|||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
let outbound =
|
let outbound = OutboundMessage::from_chat_message(
|
||||||
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
TEST_CHANNEL,
|
||||||
|
"chat-1",
|
||||||
|
None,
|
||||||
|
&HashMap::new(),
|
||||||
|
&message,
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 2);
|
assert_eq!(outbound.len(), 2);
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
||||||
@ -588,8 +595,13 @@ mod tests {
|
|||||||
}],
|
}],
|
||||||
);
|
);
|
||||||
|
|
||||||
let outbound =
|
let outbound = OutboundMessage::from_chat_message(
|
||||||
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
TEST_CHANNEL,
|
||||||
|
"chat-1",
|
||||||
|
None,
|
||||||
|
&HashMap::new(),
|
||||||
|
&message,
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 2);
|
assert_eq!(outbound.len(), 2);
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse);
|
||||||
@ -602,8 +614,13 @@ mod tests {
|
|||||||
fn test_from_chat_message_includes_tool_result() {
|
fn test_from_chat_message_includes_tool_result() {
|
||||||
let message = ChatMessage::tool("call-9", "calculator", "2");
|
let message = ChatMessage::tool("call-9", "calculator", "2");
|
||||||
|
|
||||||
let outbound =
|
let outbound = OutboundMessage::from_chat_message(
|
||||||
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
TEST_CHANNEL,
|
||||||
|
"chat-1",
|
||||||
|
None,
|
||||||
|
&HashMap::new(),
|
||||||
|
&message,
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
assert_eq!(outbound.len(), 1);
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
|
||||||
@ -618,8 +635,13 @@ mod tests {
|
|||||||
ToolMessageState::PendingUserAction,
|
ToolMessageState::PendingUserAction,
|
||||||
);
|
);
|
||||||
|
|
||||||
let outbound =
|
let outbound = OutboundMessage::from_chat_message(
|
||||||
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
TEST_CHANNEL,
|
||||||
|
"chat-1",
|
||||||
|
None,
|
||||||
|
&HashMap::new(),
|
||||||
|
&message,
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
assert_eq!(outbound.len(), 1);
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);
|
||||||
|
|||||||
@ -149,6 +149,7 @@ struct CachedTenantToken {
|
|||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct FeishuChannel {
|
pub struct FeishuChannel {
|
||||||
|
name: String,
|
||||||
config: FeishuChannelConfig,
|
config: FeishuChannelConfig,
|
||||||
http_client: reqwest::Client,
|
http_client: reqwest::Client,
|
||||||
running: Arc<RwLock<bool>>,
|
running: Arc<RwLock<bool>>,
|
||||||
@ -174,10 +175,12 @@ struct ParsedMessage {
|
|||||||
|
|
||||||
impl FeishuChannel {
|
impl FeishuChannel {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
|
name: String,
|
||||||
config: FeishuChannelConfig,
|
config: FeishuChannelConfig,
|
||||||
_provider_config: LLMProviderConfig,
|
_provider_config: LLMProviderConfig,
|
||||||
) -> Result<Self, ChannelError> {
|
) -> Result<Self, ChannelError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
name,
|
||||||
config,
|
config,
|
||||||
http_client: reqwest::Client::new(),
|
http_client: reqwest::Client::new(),
|
||||||
running: Arc::new(RwLock::new(false)),
|
running: Arc::new(RwLock::new(false)),
|
||||||
@ -1251,7 +1254,7 @@ impl FeishuChannel {
|
|||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
|
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
|
||||||
let msg = crate::bus::InboundMessage {
|
let msg = crate::bus::InboundMessage {
|
||||||
channel: "feishu".to_string(),
|
channel: channel.name().to_string(),
|
||||||
sender_id: parsed.open_id.clone(),
|
sender_id: parsed.open_id.clone(),
|
||||||
chat_id: parsed.chat_id.clone(),
|
chat_id: parsed.chat_id.clone(),
|
||||||
content: parsed.content.clone(),
|
content: parsed.content.clone(),
|
||||||
@ -2281,7 +2284,7 @@ mod tests {
|
|||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Channel for FeishuChannel {
|
impl Channel for FeishuChannel {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
"feishu"
|
&self.name
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError> {
|
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError> {
|
||||||
|
|||||||
@ -6,7 +6,8 @@ use crate::bus::MessageBus;
|
|||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
use crate::channels::cli::CliChannel;
|
use crate::channels::cli::CliChannel;
|
||||||
use crate::channels::feishu::FeishuChannel;
|
use crate::channels::feishu::FeishuChannel;
|
||||||
use crate::config::Config;
|
use crate::channels::wechat::WechatChannel;
|
||||||
|
use crate::config::{Config, TaggedChannelConfig};
|
||||||
|
|
||||||
/// ChannelManager manages all Channel instances and the MessageBus
|
/// ChannelManager manages all Channel instances and the MessageBus
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -42,23 +43,57 @@ impl ChannelManager {
|
|||||||
pub async fn init(
|
pub async fn init(
|
||||||
&self,
|
&self,
|
||||||
config: &Config,
|
config: &Config,
|
||||||
_provider_config: crate::config::LLMProviderConfig,
|
provider_config: crate::config::LLMProviderConfig,
|
||||||
) -> Result<(), ChannelError> {
|
) -> Result<(), ChannelError> {
|
||||||
// Initialize Feishu channel if enabled
|
for (name, channel_config) in &config.channels {
|
||||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
match channel_config {
|
||||||
|
crate::config::ChannelConfig::Tagged(TaggedChannelConfig::Feishu(feishu_config))
|
||||||
|
| crate::config::ChannelConfig::LegacyFeishu(feishu_config) => {
|
||||||
if feishu_config.enabled {
|
if feishu_config.enabled {
|
||||||
let channel =
|
let channel = FeishuChannel::new(
|
||||||
FeishuChannel::new(feishu_config.clone(), _provider_config).map_err(|e| {
|
name.clone(),
|
||||||
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
|
feishu_config.clone(),
|
||||||
|
provider_config.clone(),
|
||||||
|
)
|
||||||
|
.map_err(|e| {
|
||||||
|
ChannelError::Other(format!(
|
||||||
|
"Failed to create Feishu channel '{}': {}",
|
||||||
|
name, e
|
||||||
|
))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
self.channels
|
self.channels
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert("feishu".to_string(), Arc::new(channel));
|
.insert(name.clone(), Arc::new(channel));
|
||||||
tracing::info!("Feishu channel registered");
|
tracing::info!(channel = %name, kind = channel_config.kind(), "Channel registered");
|
||||||
} else {
|
} else {
|
||||||
tracing::info!("Feishu channel disabled in config");
|
tracing::info!(channel = %name, kind = channel_config.kind(), "Channel disabled in config");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
crate::config::ChannelConfig::Tagged(TaggedChannelConfig::Wechat(wechat_config)) => {
|
||||||
|
if wechat_config.enabled {
|
||||||
|
let channel = WechatChannel::new(
|
||||||
|
name.clone(),
|
||||||
|
wechat_config.clone(),
|
||||||
|
provider_config.clone(),
|
||||||
|
)
|
||||||
|
.map_err(|e| {
|
||||||
|
ChannelError::Other(format!(
|
||||||
|
"Failed to create WeChat channel '{}': {}",
|
||||||
|
name, e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
self.channels
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(name.clone(), Arc::new(channel));
|
||||||
|
tracing::info!(channel = %name, kind = channel_config.kind(), "Channel registered");
|
||||||
|
} else {
|
||||||
|
tracing::info!(channel = %name, kind = channel_config.kind(), "Channel disabled in config");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -101,3 +136,128 @@ impl ChannelManager {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn write_test_config() -> tempfile::NamedTempFile {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"channels": {
|
||||||
|
"primary": {
|
||||||
|
"type": "feishu",
|
||||||
|
"enabled": true,
|
||||||
|
"app_id": "app-id-1",
|
||||||
|
"app_secret": "secret-1"
|
||||||
|
},
|
||||||
|
"backup": {
|
||||||
|
"type": "feishu",
|
||||||
|
"enabled": true,
|
||||||
|
"app_id": "app-id-2",
|
||||||
|
"app_secret": "secret-2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
file
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn init_registers_all_configured_channels_by_instance_name() {
|
||||||
|
let file = write_test_config();
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
let manager = ChannelManager::new();
|
||||||
|
|
||||||
|
manager.init(&config, provider_config).await.unwrap();
|
||||||
|
|
||||||
|
let mut names = manager
|
||||||
|
.channels()
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.map(|(name, _)| name)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
names.sort();
|
||||||
|
|
||||||
|
assert_eq!(names, vec!["backup", "cli", "primary"]);
|
||||||
|
assert_eq!(manager.get_channel("primary").await.unwrap().name(), "primary");
|
||||||
|
assert_eq!(manager.get_channel("backup").await.unwrap().name(), "backup");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn init_registers_wechat_channel_by_instance_name() {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"channels": {
|
||||||
|
"wechat_main": {
|
||||||
|
"type": "wechat",
|
||||||
|
"enabled": true,
|
||||||
|
"cred_path": "/tmp/wechat-creds.json"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
let manager = ChannelManager::new();
|
||||||
|
|
||||||
|
manager.init(&config, provider_config).await.unwrap();
|
||||||
|
|
||||||
|
let mut names = manager
|
||||||
|
.channels()
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.map(|(name, _)| name)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
names.sort();
|
||||||
|
|
||||||
|
assert_eq!(names, vec!["cli", "wechat_main"]);
|
||||||
|
assert_eq!(manager.get_channel("wechat_main").await.unwrap().name(), "wechat_main");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -2,8 +2,10 @@ pub mod base;
|
|||||||
pub mod cli;
|
pub mod cli;
|
||||||
pub mod feishu;
|
pub mod feishu;
|
||||||
pub mod manager;
|
pub mod manager;
|
||||||
|
pub mod wechat;
|
||||||
|
|
||||||
pub use base::{Channel, ChannelError};
|
pub use base::{Channel, ChannelError};
|
||||||
pub use cli::CliChannel;
|
pub use cli::CliChannel;
|
||||||
pub use feishu::FeishuChannel;
|
pub use feishu::FeishuChannel;
|
||||||
pub use manager::ChannelManager;
|
pub use manager::ChannelManager;
|
||||||
|
pub use wechat::WechatChannel;
|
||||||
|
|||||||
383
src/channels/wechat.rs
Normal file
383
src/channels/wechat.rs
Normal file
@ -0,0 +1,383 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::{
|
||||||
|
Arc,
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
};
|
||||||
|
use std::time::UNIX_EPOCH;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use wechatbot::{BotOptions, SendContent, WeChatBot};
|
||||||
|
|
||||||
|
use crate::bus::{InboundMessage, MediaItem, MessageBus, OutboundMessage};
|
||||||
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
|
use crate::config::{LLMProviderConfig, WechatChannelConfig};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WechatChannel {
|
||||||
|
name: String,
|
||||||
|
config: WechatChannelConfig,
|
||||||
|
bot: Arc<WeChatBot>,
|
||||||
|
running: Arc<AtomicBool>,
|
||||||
|
task: Arc<RwLock<Option<JoinHandle<()>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WechatChannel {
|
||||||
|
pub fn new(
|
||||||
|
name: String,
|
||||||
|
config: WechatChannelConfig,
|
||||||
|
_provider_config: LLMProviderConfig,
|
||||||
|
) -> Result<Self, ChannelError> {
|
||||||
|
let channel_name = name.clone();
|
||||||
|
let bot = WeChatBot::new(BotOptions {
|
||||||
|
base_url: Some(config.base_url.clone()),
|
||||||
|
cred_path: Some(config.cred_path.clone()),
|
||||||
|
on_qr_url: Some(Box::new(move |url| {
|
||||||
|
tracing::info!(channel = %channel_name, qr_url = %url, "WeChat QR code ready");
|
||||||
|
})),
|
||||||
|
on_error: Some(Box::new(move |error| {
|
||||||
|
tracing::error!(error = %error, "WeChat SDK error");
|
||||||
|
})),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
name,
|
||||||
|
config,
|
||||||
|
bot: Arc::new(bot),
|
||||||
|
running: Arc::new(AtomicBool::new(false)),
|
||||||
|
task: Arc::new(RwLock::new(None)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sender_allowed(&self, sender_id: &str) -> bool {
|
||||||
|
self.config.allow_from.iter().any(|pattern| pattern == "*" || pattern == sender_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn media_to_send_content(
|
||||||
|
media: &MediaItem,
|
||||||
|
caption: Option<String>,
|
||||||
|
) -> Result<SendContent, ChannelError> {
|
||||||
|
let data = std::fs::read(&media.path).map_err(|error| {
|
||||||
|
ChannelError::SendError(format!(
|
||||||
|
"WeChat media read failed for '{}': {}",
|
||||||
|
media.path, error
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if data.is_empty() {
|
||||||
|
return Err(ChannelError::SendError(format!(
|
||||||
|
"WeChat media file is empty: {}",
|
||||||
|
media.path
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let file_name = Path::new(&media.path)
|
||||||
|
.file_name()
|
||||||
|
.and_then(|name| name.to_str())
|
||||||
|
.unwrap_or("attachment.bin")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
match media.media_type.as_str() {
|
||||||
|
"image" => Ok(SendContent::Image { data, caption }),
|
||||||
|
"video" => Ok(SendContent::Video { data, caption }),
|
||||||
|
_ => Ok(SendContent::File {
|
||||||
|
data,
|
||||||
|
file_name,
|
||||||
|
caption,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_media_dir() -> PathBuf {
|
||||||
|
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||||
|
home.join(".picobot").join("media").join("wechat")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_download_filename(
|
||||||
|
media_type: &str,
|
||||||
|
file_name: Option<&str>,
|
||||||
|
format: Option<&str>,
|
||||||
|
) -> String {
|
||||||
|
if let Some(file_name) = file_name {
|
||||||
|
let sanitized: String = file_name
|
||||||
|
.chars()
|
||||||
|
.map(|ch| match ch {
|
||||||
|
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
|
||||||
|
_ => ch,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
if !sanitized.trim().is_empty() {
|
||||||
|
return format!("{}_{}", uuid::Uuid::new_v4(), sanitized);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let ext = match (media_type, format) {
|
||||||
|
("image", _) => "jpg",
|
||||||
|
("video", _) => "mp4",
|
||||||
|
("voice", Some(fmt)) if !fmt.trim().is_empty() => fmt,
|
||||||
|
("voice", _) => "bin",
|
||||||
|
_ => "bin",
|
||||||
|
};
|
||||||
|
format!("{}_{}.{}", media_type, uuid::Uuid::new_v4(), ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn download_inbound_media(
|
||||||
|
bot: Arc<WeChatBot>,
|
||||||
|
msg: wechatbot::IncomingMessage,
|
||||||
|
) -> Result<Vec<MediaItem>, ChannelError> {
|
||||||
|
let Some(downloaded) = bot.download(&msg).await.map_err(|error| {
|
||||||
|
ChannelError::Other(format!("WeChat media download failed: {}", error))
|
||||||
|
})? else {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
};
|
||||||
|
|
||||||
|
let media_dir = Self::default_media_dir();
|
||||||
|
tokio::fs::create_dir_all(&media_dir)
|
||||||
|
.await
|
||||||
|
.map_err(|error| ChannelError::Other(format!("Failed to create WeChat media dir: {}", error)))?;
|
||||||
|
|
||||||
|
let filename = Self::build_download_filename(
|
||||||
|
&downloaded.media_type,
|
||||||
|
downloaded.file_name.as_deref(),
|
||||||
|
downloaded.format.as_deref(),
|
||||||
|
);
|
||||||
|
let file_path = media_dir.join(&filename);
|
||||||
|
tokio::fs::write(&file_path, downloaded.data)
|
||||||
|
.await
|
||||||
|
.map_err(|error| ChannelError::Other(format!("Failed to write WeChat media file: {}", error)))?;
|
||||||
|
|
||||||
|
tracing::info!(filename = %filename, media_type = %downloaded.media_type, "Downloaded WeChat media");
|
||||||
|
|
||||||
|
let mut media_item = MediaItem::new(
|
||||||
|
file_path.to_string_lossy().to_string(),
|
||||||
|
downloaded.media_type,
|
||||||
|
);
|
||||||
|
media_item.mime_type = mime_guess::from_path(&file_path)
|
||||||
|
.first_raw()
|
||||||
|
.map(ToOwned::to_owned);
|
||||||
|
Ok(vec![media_item])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Channel for WechatChannel {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_running(&self) -> bool {
|
||||||
|
self.running.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError> {
|
||||||
|
if self.running.swap(true, Ordering::SeqCst) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let channel_name = self.name.clone();
|
||||||
|
let allow_from = self.config.allow_from.clone();
|
||||||
|
let bus_for_handler = bus.clone();
|
||||||
|
let bot_for_handler = self.bot.clone();
|
||||||
|
self.bot
|
||||||
|
.on_message(Box::new(move |msg| {
|
||||||
|
let sender_id = msg.user_id.clone();
|
||||||
|
let allowed = allow_from
|
||||||
|
.iter()
|
||||||
|
.any(|pattern| pattern == "*" || pattern == &sender_id);
|
||||||
|
if !allowed {
|
||||||
|
tracing::warn!(channel = %channel_name, sender = %sender_id, "Access denied");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = msg.clone();
|
||||||
|
let timestamp = msg
|
||||||
|
.timestamp
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as i64;
|
||||||
|
let bus = bus_for_handler.clone();
|
||||||
|
let bot = bot_for_handler.clone();
|
||||||
|
let channel_name_for_publish = channel_name.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let media = match Self::download_inbound_media(bot, msg.clone()).await {
|
||||||
|
Ok(media) => media,
|
||||||
|
Err(error) => {
|
||||||
|
tracing::error!(error = %error, "Failed to download WeChat inbound media");
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut metadata = HashMap::new();
|
||||||
|
metadata.insert("context_token".to_string(), msg.context_token().to_string());
|
||||||
|
|
||||||
|
let inbound = InboundMessage {
|
||||||
|
channel: channel_name_for_publish,
|
||||||
|
sender_id: sender_id.clone(),
|
||||||
|
chat_id: sender_id,
|
||||||
|
content: msg.text.clone(),
|
||||||
|
timestamp,
|
||||||
|
media,
|
||||||
|
metadata,
|
||||||
|
forwarded_metadata: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(error) = bus.publish_inbound(inbound).await {
|
||||||
|
tracing::error!(error = %error, "Failed to publish WeChat inbound message");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let bot = self.bot.clone();
|
||||||
|
let channel_name = self.name.clone();
|
||||||
|
let force_login = self.config.force_login;
|
||||||
|
let running = self.running.clone();
|
||||||
|
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
match bot.login(force_login).await {
|
||||||
|
Ok(creds) => {
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
account_id = %creds.account_id,
|
||||||
|
user_id = %creds.user_id,
|
||||||
|
"WeChat login succeeded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
running.store(false, Ordering::SeqCst);
|
||||||
|
tracing::error!(channel = %channel_name, error = %error, "WeChat login failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(error) = bot.run().await {
|
||||||
|
tracing::error!(channel = %channel_name, error = %error, "WeChat channel stopped with error");
|
||||||
|
}
|
||||||
|
|
||||||
|
running.store(false, Ordering::SeqCst);
|
||||||
|
});
|
||||||
|
|
||||||
|
*self.task.write().await = Some(handle);
|
||||||
|
tracing::info!(channel = %self.name, "WeChat channel started");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stop(&self) -> Result<(), ChannelError> {
|
||||||
|
self.running.store(false, Ordering::SeqCst);
|
||||||
|
self.bot.stop().await;
|
||||||
|
|
||||||
|
if let Some(handle) = self.task.write().await.take() {
|
||||||
|
handle.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(channel = %self.name, "WeChat channel stopped");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||||
|
let text = msg.content.trim().to_string();
|
||||||
|
let mut text_sent = false;
|
||||||
|
|
||||||
|
if !text.is_empty() {
|
||||||
|
self.bot.send(&msg.chat_id, &text).await.map_err(|error| {
|
||||||
|
ChannelError::SendError(format!("WeChat text send failed: {}", error))
|
||||||
|
})?;
|
||||||
|
tracing::info!(
|
||||||
|
channel = %self.name,
|
||||||
|
chat_id = %msg.chat_id,
|
||||||
|
content_len = text.len(),
|
||||||
|
"WeChat text message sent"
|
||||||
|
);
|
||||||
|
text_sent = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (index, media) in msg.media.iter().enumerate() {
|
||||||
|
let caption = if !text.is_empty() && !text_sent && index == 0 {
|
||||||
|
Some(text.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let content = Self::media_to_send_content(media, caption)?;
|
||||||
|
self.bot.send_media(&msg.chat_id, content).await.map_err(|error| {
|
||||||
|
ChannelError::SendError(format!("WeChat media send failed: {}", error))
|
||||||
|
})?;
|
||||||
|
tracing::info!(
|
||||||
|
channel = %self.name,
|
||||||
|
chat_id = %msg.chat_id,
|
||||||
|
media_type = %media.media_type,
|
||||||
|
media_path = %media.path,
|
||||||
|
"WeChat media message sent"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if text.is_empty() && msg.media.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_allowed(&self, sender_id: &str) -> bool {
|
||||||
|
self.sender_allowed(sender_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_download_filename_preserves_file_name() {
|
||||||
|
let filename = WechatChannel::build_download_filename("file", Some("README.md"), None);
|
||||||
|
|
||||||
|
assert!(filename.ends_with("_README.md"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_download_filename_adds_voice_extension_when_missing_name() {
|
||||||
|
let filename = WechatChannel::build_download_filename("voice", None, Some("silk"));
|
||||||
|
|
||||||
|
assert!(filename.starts_with("voice_"));
|
||||||
|
assert!(filename.ends_with(".silk"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn media_to_send_content_maps_image() {
|
||||||
|
let file = NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(file.path(), b"demo-image").unwrap();
|
||||||
|
let image_path = file.path().with_extension("png");
|
||||||
|
std::fs::rename(file.path(), &image_path).unwrap();
|
||||||
|
|
||||||
|
let media = MediaItem::new(image_path.to_string_lossy().to_string(), "image");
|
||||||
|
let content = WechatChannel::media_to_send_content(&media, None).unwrap();
|
||||||
|
|
||||||
|
assert!(matches!(content, SendContent::Image { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn media_to_send_content_maps_generic_file() {
|
||||||
|
let file = NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(file.path(), b"hello").unwrap();
|
||||||
|
let doc_path = file.path().with_extension("md");
|
||||||
|
std::fs::rename(file.path(), &doc_path).unwrap();
|
||||||
|
|
||||||
|
let media = MediaItem::new(doc_path.to_string_lossy().to_string(), "file");
|
||||||
|
let content = WechatChannel::media_to_send_content(&media, Some("note".to_string())).unwrap();
|
||||||
|
|
||||||
|
match content {
|
||||||
|
SendContent::File {
|
||||||
|
file_name,
|
||||||
|
caption,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(file_name, doc_path.file_name().unwrap().to_string_lossy());
|
||||||
|
assert_eq!(caption.as_deref(), Some("note"));
|
||||||
|
}
|
||||||
|
_ => panic!("expected file send content"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -22,7 +22,7 @@ pub struct Config {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub client: ClientConfig,
|
pub client: ClientConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub channels: HashMap<String, FeishuChannelConfig>,
|
pub channels: HashMap<String, ChannelConfig>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub skills: SkillsConfig,
|
pub skills: SkillsConfig,
|
||||||
}
|
}
|
||||||
@ -96,6 +96,54 @@ impl Default for SkillsConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ChannelConfig {
|
||||||
|
Tagged(TaggedChannelConfig),
|
||||||
|
LegacyFeishu(FeishuChannelConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum TaggedChannelConfig {
|
||||||
|
Feishu(FeishuChannelConfig),
|
||||||
|
Wechat(WechatChannelConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelConfig {
|
||||||
|
pub fn kind(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Tagged(TaggedChannelConfig::Feishu(_)) | Self::LegacyFeishu(_) => "feishu",
|
||||||
|
Self::Tagged(TaggedChannelConfig::Wechat(_)) => "wechat",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn enabled(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Tagged(TaggedChannelConfig::Feishu(config)) | Self::LegacyFeishu(config) => {
|
||||||
|
config.enabled
|
||||||
|
}
|
||||||
|
Self::Tagged(TaggedChannelConfig::Wechat(config)) => config.enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_feishu(&self) -> Option<&FeishuChannelConfig> {
|
||||||
|
match self {
|
||||||
|
Self::Tagged(TaggedChannelConfig::Feishu(config)) | Self::LegacyFeishu(config) => {
|
||||||
|
Some(config)
|
||||||
|
}
|
||||||
|
Self::Tagged(TaggedChannelConfig::Wechat(_)) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_wechat(&self) -> Option<&WechatChannelConfig> {
|
||||||
|
match self {
|
||||||
|
Self::Tagged(TaggedChannelConfig::Wechat(config)) => Some(config),
|
||||||
|
Self::Tagged(TaggedChannelConfig::Feishu(_)) | Self::LegacyFeishu(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct FeishuChannelConfig {
|
pub struct FeishuChannelConfig {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@ -117,6 +165,22 @@ pub struct FeishuChannelConfig {
|
|||||||
pub reply_context_max_chars: usize,
|
pub reply_context_max_chars: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct WechatChannelConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub enabled: bool,
|
||||||
|
#[serde(default = "default_allow_from")]
|
||||||
|
pub allow_from: Vec<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub agent: String,
|
||||||
|
#[serde(default = "default_wechat_base_url")]
|
||||||
|
pub base_url: String,
|
||||||
|
#[serde(default = "default_wechat_cred_path")]
|
||||||
|
pub cred_path: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub force_login: bool,
|
||||||
|
}
|
||||||
|
|
||||||
fn default_allow_from() -> Vec<String> {
|
fn default_allow_from() -> Vec<String> {
|
||||||
vec!["*".to_string()]
|
vec!["*".to_string()]
|
||||||
}
|
}
|
||||||
@ -128,6 +192,17 @@ fn default_media_dir() -> String {
|
|||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_wechat_base_url() -> String {
|
||||||
|
"https://ilinkai.weixin.qq.com".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_wechat_cred_path() -> String {
|
||||||
|
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
|
||||||
|
home.join(".picobot/wechat/credentials.json")
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
fn default_reaction_emoji() -> String {
|
fn default_reaction_emoji() -> String {
|
||||||
"Typing".to_string()
|
"Typing".to_string()
|
||||||
}
|
}
|
||||||
@ -1171,11 +1246,105 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
let feishu = &config.channels["feishu"];
|
let feishu = config.channels["feishu"].as_feishu().unwrap();
|
||||||
assert_eq!(feishu.max_message_chars, 20_000);
|
assert_eq!(feishu.max_message_chars, 20_000);
|
||||||
assert_eq!(feishu.reply_context_max_chars, 20_000);
|
assert_eq!(feishu.reply_context_max_chars, 20_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tagged_feishu_channel_config_loads() {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"channels": {
|
||||||
|
"primary": {
|
||||||
|
"type": "feishu",
|
||||||
|
"enabled": true,
|
||||||
|
"app_id": "app-id",
|
||||||
|
"app_secret": "secret"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
let feishu = config.channels["primary"].as_feishu().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.channels["primary"].kind(), "feishu");
|
||||||
|
assert!(config.channels["primary"].enabled());
|
||||||
|
assert_eq!(feishu.app_id, "app-id");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tagged_wechat_channel_config_loads() {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"channels": {
|
||||||
|
"wechat_main": {
|
||||||
|
"type": "wechat",
|
||||||
|
"enabled": true,
|
||||||
|
"base_url": "https://ilinkai.weixin.qq.com",
|
||||||
|
"cred_path": "/tmp/wechat-creds.json",
|
||||||
|
"force_login": true,
|
||||||
|
"allow_from": ["wxid_1"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
let wechat = config.channels["wechat_main"].as_wechat().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.channels["wechat_main"].kind(), "wechat");
|
||||||
|
assert!(config.channels["wechat_main"].enabled());
|
||||||
|
assert_eq!(wechat.cred_path, "/tmp/wechat-creds.json");
|
||||||
|
assert!(wechat.force_login);
|
||||||
|
assert_eq!(wechat.allow_from, vec!["wxid_1"]);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_feishu_channel_config_loads_custom_truncation_limits() {
|
fn test_feishu_channel_config_loads_custom_truncation_limits() {
|
||||||
let file = tempfile::NamedTempFile::new().unwrap();
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
@ -1215,7 +1384,7 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
let feishu = &config.channels["feishu"];
|
let feishu = config.channels["feishu"].as_feishu().unwrap();
|
||||||
assert_eq!(feishu.max_message_chars, 3456);
|
assert_eq!(feishu.max_message_chars, 3456);
|
||||||
assert_eq!(feishu.reply_context_max_chars, 4567);
|
assert_eq!(feishu.reply_context_max_chars, 4567);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,6 +40,8 @@ mod tests {
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
fn test_provider_config() -> LLMProviderConfig {
|
fn test_provider_config() -> LLMProviderConfig {
|
||||||
LLMProviderConfig {
|
LLMProviderConfig {
|
||||||
provider_type: "openai".to_string(),
|
provider_type: "openai".to_string(),
|
||||||
@ -80,7 +82,7 @@ mod tests {
|
|||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(ToolRegistry::new());
|
let tools = Arc::new(ToolRegistry::new());
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
TEST_CHANNEL.to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
user_tx,
|
user_tx,
|
||||||
tools,
|
tools,
|
||||||
@ -130,7 +132,7 @@ mod tests {
|
|||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(ToolRegistry::new());
|
let tools = Arc::new(ToolRegistry::new());
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
TEST_CHANNEL.to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
user_tx,
|
user_tx,
|
||||||
tools,
|
tools,
|
||||||
|
|||||||
@ -542,7 +542,7 @@ mod tests {
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"test-channel".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
user_tx,
|
user_tx,
|
||||||
tools,
|
tools,
|
||||||
@ -587,7 +587,7 @@ mod tests {
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"test-channel".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
user_tx,
|
user_tx,
|
||||||
tools,
|
tools,
|
||||||
@ -791,7 +791,7 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let outbound = session_manager
|
let outbound = session_manager
|
||||||
.handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None)
|
.handle_message("test-channel", "user-1", "chat-1", "hello", Vec::new(), None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -840,7 +840,7 @@ mod tests {
|
|||||||
|
|
||||||
let planner_outbound = session_manager
|
let planner_outbound = session_manager
|
||||||
.run_scheduled_agent_task(
|
.run_scheduled_agent_task(
|
||||||
"feishu",
|
"test-channel",
|
||||||
"chat-planner",
|
"chat-planner",
|
||||||
"请规划今天工作",
|
"请规划今天工作",
|
||||||
ScheduledAgentTaskOptions {
|
ScheduledAgentTaskOptions {
|
||||||
@ -856,7 +856,7 @@ mod tests {
|
|||||||
|
|
||||||
let default_outbound = session_manager
|
let default_outbound = session_manager
|
||||||
.run_scheduled_agent_task(
|
.run_scheduled_agent_task(
|
||||||
"feishu",
|
"test-channel",
|
||||||
"chat-default",
|
"chat-default",
|
||||||
"请规划今天工作",
|
"请规划今天工作",
|
||||||
ScheduledAgentTaskOptions {
|
ScheduledAgentTaskOptions {
|
||||||
@ -904,7 +904,7 @@ mod tests {
|
|||||||
|
|
||||||
session_manager
|
session_manager
|
||||||
.run_scheduled_agent_task(
|
.run_scheduled_agent_task(
|
||||||
"feishu",
|
"test-channel",
|
||||||
"chat-guard",
|
"chat-guard",
|
||||||
"每小时执行以下流程:检查邮箱并同步待办",
|
"每小时执行以下流程:检查邮箱并同步待办",
|
||||||
ScheduledAgentTaskOptions {
|
ScheduledAgentTaskOptions {
|
||||||
@ -916,7 +916,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let session = session_manager.get("feishu").await.unwrap();
|
let session = session_manager.get("test-channel").await.unwrap();
|
||||||
let session_guard = session.lock().await;
|
let session_guard = session.lock().await;
|
||||||
let persisted_messages = session_guard
|
let persisted_messages = session_guard
|
||||||
.store()
|
.store()
|
||||||
@ -1477,7 +1477,13 @@ mod tests {
|
|||||||
async fn test_bus_tool_call_emitter_hides_completed_tool_results_when_disabled() {
|
async fn test_bus_tool_call_emitter_hides_completed_tool_results_when_disabled() {
|
||||||
let bus = MessageBus::new(4);
|
let bus = MessageBus::new(4);
|
||||||
let emitter =
|
let emitter =
|
||||||
BusToolCallEmitter::new(bus.clone(), "feishu", "chat-1", HashMap::new(), false);
|
BusToolCallEmitter::new(
|
||||||
|
bus.clone(),
|
||||||
|
"test-channel",
|
||||||
|
"chat-1",
|
||||||
|
HashMap::new(),
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
emitter
|
emitter
|
||||||
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
||||||
@ -1508,7 +1514,7 @@ mod tests {
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"test-channel".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
user_tx,
|
user_tx,
|
||||||
tools,
|
tools,
|
||||||
@ -1546,7 +1552,7 @@ mod tests {
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"test-channel".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
user_tx,
|
user_tx,
|
||||||
tools,
|
tools,
|
||||||
@ -1612,7 +1618,7 @@ mod tests {
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"test-channel".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
user_tx,
|
user_tx,
|
||||||
tools,
|
tools,
|
||||||
|
|||||||
@ -42,6 +42,7 @@ impl SessionMessageSender for BusSessionMessageSender {
|
|||||||
.is_some();
|
.is_some();
|
||||||
|
|
||||||
if let Some(text) = request.text.filter(|value| !value.trim().is_empty()) {
|
if let Some(text) = request.text.filter(|value| !value.trim().is_empty()) {
|
||||||
|
let content_len = text.len();
|
||||||
self.bus
|
self.bus
|
||||||
.publish_outbound(OutboundMessage::assistant(
|
.publish_outbound(OutboundMessage::assistant(
|
||||||
channel_name.to_string(),
|
channel_name.to_string(),
|
||||||
@ -52,10 +53,18 @@ impl SessionMessageSender for BusSessionMessageSender {
|
|||||||
))
|
))
|
||||||
.await?;
|
.await?;
|
||||||
published_messages += 1;
|
published_messages += 1;
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
content_len = content_len,
|
||||||
|
"Published session text message to outbound bus"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let attachment_count = request.attachments.len();
|
let attachment_count = request.attachments.len();
|
||||||
for attachment in request.attachments {
|
for attachment in request.attachments {
|
||||||
|
let media_path = attachment.path.clone();
|
||||||
|
let media_type = attachment.media_type.clone();
|
||||||
let mut outbound = OutboundMessage::assistant(
|
let mut outbound = OutboundMessage::assistant(
|
||||||
channel_name.to_string(),
|
channel_name.to_string(),
|
||||||
chat_id.to_string(),
|
chat_id.to_string(),
|
||||||
@ -66,6 +75,13 @@ impl SessionMessageSender for BusSessionMessageSender {
|
|||||||
outbound.media = vec![attachment];
|
outbound.media = vec![attachment];
|
||||||
self.bus.publish_outbound(outbound).await?;
|
self.bus.publish_outbound(outbound).await?;
|
||||||
published_messages += 1;
|
published_messages += 1;
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
media_type = %media_type,
|
||||||
|
media_path = %media_path,
|
||||||
|
"Published session attachment to outbound bus"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(SessionSendOutcome {
|
Ok(SessionSendOutcome {
|
||||||
@ -81,12 +97,14 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::bus::MediaItem;
|
use crate::bus::MediaItem;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn bus_sender_publishes_text_then_attachment() {
|
async fn bus_sender_publishes_text_then_attachment() {
|
||||||
let bus = MessageBus::new(8);
|
let bus = MessageBus::new(8);
|
||||||
let sender = BusSessionMessageSender::new(bus.clone());
|
let sender = BusSessionMessageSender::new(bus.clone());
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
channel_name: Some("feishu".to_string()),
|
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
chat_id: Some("chat-1".to_string()),
|
chat_id: Some("chat-1".to_string()),
|
||||||
..ToolContext::default()
|
..ToolContext::default()
|
||||||
};
|
};
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
pub mod agent;
|
pub mod agent;
|
||||||
|
pub mod bootstrap;
|
||||||
pub mod bus;
|
pub mod bus;
|
||||||
pub mod channels;
|
pub mod channels;
|
||||||
pub mod cli;
|
pub mod cli;
|
||||||
|
|||||||
@ -23,6 +23,8 @@ enum Command {
|
|||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
picobot::bootstrap::initialize_process_runtime();
|
||||||
|
|
||||||
let mut cmd = Command::command();
|
let mut cmd = Command::command();
|
||||||
|
|
||||||
// If no arguments, print help
|
// If no arguments, print help
|
||||||
|
|||||||
@ -812,7 +812,7 @@ mod agent_task_tests {
|
|||||||
interval_secs: 0,
|
interval_secs: 0,
|
||||||
startup_delay_secs: 0,
|
startup_delay_secs: 0,
|
||||||
target: serde_json::json!({
|
target: serde_json::json!({
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo"
|
"chat_id": "oc_demo"
|
||||||
}),
|
}),
|
||||||
payload: serde_json::json!({
|
payload: serde_json::json!({
|
||||||
@ -859,7 +859,7 @@ mod agent_task_tests {
|
|||||||
interval_secs: 0,
|
interval_secs: 0,
|
||||||
startup_delay_secs: 0,
|
startup_delay_secs: 0,
|
||||||
target: serde_json::json!({
|
target: serde_json::json!({
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo",
|
"chat_id": "oc_demo",
|
||||||
"session_chat_id": "scheduler/agent.daily_summary.background"
|
"session_chat_id": "scheduler/agent.daily_summary.background"
|
||||||
}),
|
}),
|
||||||
@ -905,7 +905,7 @@ mod agent_task_tests {
|
|||||||
startup_delay_secs: 0,
|
startup_delay_secs: 0,
|
||||||
},
|
},
|
||||||
target: SchedulerJobTarget {
|
target: SchedulerJobTarget {
|
||||||
channel: Some("feishu".to_string()),
|
channel: Some("test-channel".to_string()),
|
||||||
chat_id: Some("oc_demo".to_string()),
|
chat_id: Some("oc_demo".to_string()),
|
||||||
session_chat_id: None,
|
session_chat_id: None,
|
||||||
reply_to: None,
|
reply_to: None,
|
||||||
@ -965,7 +965,7 @@ mod agent_task_tests {
|
|||||||
startup_delay_secs: 0,
|
startup_delay_secs: 0,
|
||||||
},
|
},
|
||||||
target: SchedulerJobTarget {
|
target: SchedulerJobTarget {
|
||||||
channel: Some("feishu".to_string()),
|
channel: Some("test-channel".to_string()),
|
||||||
chat_id: Some("oc_demo".to_string()),
|
chat_id: Some("oc_demo".to_string()),
|
||||||
session_chat_id: None,
|
session_chat_id: None,
|
||||||
reply_to: None,
|
reply_to: None,
|
||||||
@ -1101,7 +1101,7 @@ mod tests {
|
|||||||
interval_secs: 0,
|
interval_secs: 0,
|
||||||
startup_delay_secs: 0,
|
startup_delay_secs: 0,
|
||||||
target: serde_json::json!({
|
target: serde_json::json!({
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo"
|
"chat_id": "oc_demo"
|
||||||
}),
|
}),
|
||||||
payload: serde_json::json!({"content": "hello"}),
|
payload: serde_json::json!({"content": "hello"}),
|
||||||
@ -1151,7 +1151,7 @@ mod tests {
|
|||||||
interval_secs: 60,
|
interval_secs: 60,
|
||||||
startup_delay_secs: 0,
|
startup_delay_secs: 0,
|
||||||
target: serde_json::json!({
|
target: serde_json::json!({
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo"
|
"chat_id": "oc_demo"
|
||||||
}),
|
}),
|
||||||
payload: serde_json::json!({
|
payload: serde_json::json!({
|
||||||
@ -1271,7 +1271,7 @@ mod tests {
|
|||||||
startup_delay_secs: 0,
|
startup_delay_secs: 0,
|
||||||
},
|
},
|
||||||
target: SchedulerJobTarget {
|
target: SchedulerJobTarget {
|
||||||
channel: Some("feishu".to_string()),
|
channel: Some("test-channel".to_string()),
|
||||||
chat_id: Some("oc_demo".to_string()),
|
chat_id: Some("oc_demo".to_string()),
|
||||||
session_chat_id: Some("scheduler/agent.daily_summary.background".to_string()),
|
session_chat_id: Some("scheduler/agent.daily_summary.background".to_string()),
|
||||||
reply_to: None,
|
reply_to: None,
|
||||||
@ -1300,7 +1300,7 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(outbound.channel, "feishu");
|
assert_eq!(outbound.channel, "test-channel");
|
||||||
assert_eq!(outbound.chat_id, "oc_demo");
|
assert_eq!(outbound.chat_id, "oc_demo");
|
||||||
assert!(outbound.content.contains("定时任务执行失败"));
|
assert!(outbound.content.contains("定时任务执行失败"));
|
||||||
assert!(outbound.content.contains("agent.daily_summary.background"));
|
assert!(outbound.content.contains("agent.daily_summary.background"));
|
||||||
|
|||||||
@ -1,10 +1,18 @@
|
|||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
static SKILL_TEST_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) fn acquire_skill_test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
SKILL_TEST_ENV_LOCK.lock().unwrap_or_else(|err| err.into_inner())
|
||||||
|
}
|
||||||
|
|
||||||
use crate::config::SkillsConfig;
|
use crate::config::SkillsConfig;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -24,7 +32,7 @@ pub enum SkillSource {
|
|||||||
ProjectAgent,
|
ProjectAgent,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum SkillScope {
|
pub enum SkillScope {
|
||||||
User,
|
User,
|
||||||
Project,
|
Project,
|
||||||
@ -62,6 +70,16 @@ pub struct SkillRuntime {
|
|||||||
catalog: RwLock<SkillCatalog>,
|
catalog: RwLock<SkillCatalog>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SkillAvailabilityChange {
|
||||||
|
pub name: String,
|
||||||
|
pub scope: SkillScope,
|
||||||
|
pub state_path: PathBuf,
|
||||||
|
pub changed: bool,
|
||||||
|
pub disabled_in_scopes: Vec<SkillScope>,
|
||||||
|
pub available: bool,
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for SkillRuntime {
|
impl Default for SkillRuntime {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -220,6 +238,78 @@ impl SkillRuntime {
|
|||||||
}
|
}
|
||||||
Ok(dir)
|
Ok(dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn disable_skill(
|
||||||
|
&self,
|
||||||
|
scope: SkillScope,
|
||||||
|
name: &str,
|
||||||
|
reload: bool,
|
||||||
|
) -> Result<SkillAvailabilityChange, String> {
|
||||||
|
self.set_skill_enabled(scope, name, false, reload)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn enable_skill(
|
||||||
|
&self,
|
||||||
|
scope: SkillScope,
|
||||||
|
name: &str,
|
||||||
|
reload: bool,
|
||||||
|
) -> Result<SkillAvailabilityChange, String> {
|
||||||
|
self.set_skill_enabled(scope, name, true, reload)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_skill_definition(&self, name: &str) -> Result<bool, String> {
|
||||||
|
validate_skill_name(name)?;
|
||||||
|
let cwd = std::env::current_dir()
|
||||||
|
.map_err(|err| format!("failed to get current dir: {}", err))?;
|
||||||
|
Ok(SkillCatalog::discover_without_state(&self.config, &cwd)
|
||||||
|
.find_skill(name)
|
||||||
|
.is_some())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_skill_enabled(
|
||||||
|
&self,
|
||||||
|
scope: SkillScope,
|
||||||
|
name: &str,
|
||||||
|
enabled: bool,
|
||||||
|
reload: bool,
|
||||||
|
) -> Result<SkillAvailabilityChange, String> {
|
||||||
|
validate_skill_name(name)?;
|
||||||
|
if !self.has_skill_definition(name)? {
|
||||||
|
return Err(format!("skill '{}' not found", name));
|
||||||
|
}
|
||||||
|
|
||||||
|
let state_path = skill_state_path(scope)?;
|
||||||
|
let mut state = load_skill_state_file(&state_path)?;
|
||||||
|
let mut disabled: HashSet<String> = state.disabled_skills.into_iter().collect();
|
||||||
|
let changed = if enabled {
|
||||||
|
disabled.remove(name)
|
||||||
|
} else {
|
||||||
|
disabled.insert(name.to_string())
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut disabled_skills: Vec<String> = disabled.into_iter().collect();
|
||||||
|
disabled_skills.sort();
|
||||||
|
state.disabled_skills = disabled_skills;
|
||||||
|
save_skill_state_file(&state_path, &state)?;
|
||||||
|
|
||||||
|
if reload {
|
||||||
|
let _ = self.reload()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cwd = std::env::current_dir()
|
||||||
|
.map_err(|err| format!("failed to get current dir: {}", err))?;
|
||||||
|
let effective_state = load_skill_disable_state(&cwd);
|
||||||
|
let disabled_in_scopes = effective_state.disabled_scopes_for(name);
|
||||||
|
|
||||||
|
Ok(SkillAvailabilityChange {
|
||||||
|
name: name.to_string(),
|
||||||
|
scope,
|
||||||
|
state_path,
|
||||||
|
changed,
|
||||||
|
available: disabled_in_scopes.is_empty(),
|
||||||
|
disabled_in_scopes,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl crate::agent::SkillProvider for SkillRuntime {
|
impl crate::agent::SkillProvider for SkillRuntime {
|
||||||
@ -262,6 +352,20 @@ impl Default for SkillCatalog {
|
|||||||
|
|
||||||
impl SkillCatalog {
|
impl SkillCatalog {
|
||||||
pub fn discover(config: &SkillsConfig) -> Self {
|
pub fn discover(config: &SkillsConfig) -> Self {
|
||||||
|
let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
|
||||||
|
let disable_state = load_skill_disable_state(&cwd);
|
||||||
|
Self::discover_with_state(config, &cwd, Some(&disable_state))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn discover_without_state(config: &SkillsConfig, cwd: &Path) -> Self {
|
||||||
|
Self::discover_with_state(config, cwd, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn discover_with_state(
|
||||||
|
config: &SkillsConfig,
|
||||||
|
cwd: &Path,
|
||||||
|
disable_state: Option<&SkillDisableState>,
|
||||||
|
) -> Self {
|
||||||
if !config.enabled {
|
if !config.enabled {
|
||||||
return Self {
|
return Self {
|
||||||
max_index_chars: config.max_index_chars,
|
max_index_chars: config.max_index_chars,
|
||||||
@ -270,7 +374,6 @@ impl SkillCatalog {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
|
|
||||||
let mut merged: HashMap<String, Skill> = HashMap::new();
|
let mut merged: HashMap<String, Skill> = HashMap::new();
|
||||||
let mut sources_seen = 0usize;
|
let mut sources_seen = 0usize;
|
||||||
|
|
||||||
@ -294,6 +397,9 @@ impl SkillCatalog {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut skills: Vec<Skill> = merged.into_values().collect();
|
let mut skills: Vec<Skill> = merged.into_values().collect();
|
||||||
|
if let Some(disable_state) = disable_state {
|
||||||
|
skills.retain(|skill| !disable_state.is_disabled(&skill.name));
|
||||||
|
}
|
||||||
skills.sort_by(|a, b| a.name.cmp(&b.name));
|
skills.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
|
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
@ -399,6 +505,35 @@ impl SkillCatalog {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
struct SkillStateFile {
|
||||||
|
#[serde(default)]
|
||||||
|
disabled_skills: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
struct SkillDisableState {
|
||||||
|
user_disabled: HashSet<String>,
|
||||||
|
project_disabled: HashSet<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SkillDisableState {
|
||||||
|
fn is_disabled(&self, name: &str) -> bool {
|
||||||
|
self.user_disabled.contains(name) || self.project_disabled.contains(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn disabled_scopes_for(&self, name: &str) -> Vec<SkillScope> {
|
||||||
|
let mut scopes = Vec::new();
|
||||||
|
if self.user_disabled.contains(name) {
|
||||||
|
scopes.push(SkillScope::User);
|
||||||
|
}
|
||||||
|
if self.project_disabled.contains(name) {
|
||||||
|
scopes.push(SkillScope::Project);
|
||||||
|
}
|
||||||
|
scopes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn source_order(sources: &[String]) -> Vec<SkillSource> {
|
fn source_order(sources: &[String]) -> Vec<SkillSource> {
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
for source in sources {
|
for source in sources {
|
||||||
@ -461,10 +596,18 @@ fn project_agent_skills_root(cwd: &Path) -> PathBuf {
|
|||||||
cwd.join(".agents").join("skills")
|
cwd.join(".agents").join("skills")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn project_skill_state_path(cwd: &Path) -> PathBuf {
|
||||||
|
cwd.join(".picobot").join("skill-state.json")
|
||||||
|
}
|
||||||
|
|
||||||
fn user_skills_root() -> Option<PathBuf> {
|
fn user_skills_root() -> Option<PathBuf> {
|
||||||
dirs::home_dir().map(|p| p.join(".picobot").join("skills"))
|
dirs::home_dir().map(|p| p.join(".picobot").join("skills"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn user_skill_state_path() -> Option<PathBuf> {
|
||||||
|
dirs::home_dir().map(|p| p.join(".picobot").join("skill-state.json"))
|
||||||
|
}
|
||||||
|
|
||||||
fn user_agent_skills_root() -> Option<PathBuf> {
|
fn user_agent_skills_root() -> Option<PathBuf> {
|
||||||
dirs::home_dir().map(|p| p.join(".agents").join("skills"))
|
dirs::home_dir().map(|p| p.join(".agents").join("skills"))
|
||||||
}
|
}
|
||||||
@ -495,6 +638,78 @@ fn skill_file_path(scope: SkillScope, name: &str) -> Result<PathBuf, String> {
|
|||||||
Ok(skill_dir_path(scope, name)?.join("SKILL.md"))
|
Ok(skill_dir_path(scope, name)?.join("SKILL.md"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn skill_state_path(scope: SkillScope) -> Result<PathBuf, String> {
|
||||||
|
match scope {
|
||||||
|
SkillScope::User => user_skill_state_path()
|
||||||
|
.ok_or_else(|| "failed to resolve home directory".to_string()),
|
||||||
|
SkillScope::Project => {
|
||||||
|
let cwd = std::env::current_dir()
|
||||||
|
.map_err(|err| format!("failed to get current dir: {}", err))?;
|
||||||
|
Ok(project_skill_state_path(&cwd))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_skill_disable_state(cwd: &Path) -> SkillDisableState {
|
||||||
|
SkillDisableState {
|
||||||
|
user_disabled: user_skill_state_path()
|
||||||
|
.map(|path| load_disabled_skill_names(&path))
|
||||||
|
.unwrap_or_default(),
|
||||||
|
project_disabled: load_disabled_skill_names(&project_skill_state_path(cwd)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_disabled_skill_names(path: &Path) -> HashSet<String> {
|
||||||
|
match load_skill_state_file(path) {
|
||||||
|
Ok(state) => state
|
||||||
|
.disabled_skills
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|name| normalize_skill_name(name, path))
|
||||||
|
.collect(),
|
||||||
|
Err(err) => {
|
||||||
|
tracing::warn!(path = %path.display(), error = %err, "Failed to load skill state file");
|
||||||
|
HashSet::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_skill_name(name: String, path: &Path) -> Option<String> {
|
||||||
|
let trimmed = name.trim();
|
||||||
|
match validate_skill_name(trimmed) {
|
||||||
|
Ok(()) => Some(trimmed.to_string()),
|
||||||
|
Err(err) => {
|
||||||
|
tracing::warn!(path = %path.display(), skill = %name, error = %err, "Ignoring invalid disabled skill entry");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_skill_state_file(path: &Path) -> Result<SkillStateFile, String> {
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(SkillStateFile::default());
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = fs::read_to_string(path)
|
||||||
|
.map_err(|err| format!("failed to read skill state file: {}", err))?;
|
||||||
|
serde_json::from_str(&content)
|
||||||
|
.map_err(|err| format!("failed to parse skill state file: {}", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_skill_state_file(path: &Path, state: &SkillStateFile) -> Result<(), String> {
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)
|
||||||
|
.map_err(|err| format!("failed to create skill state directory: {}", err))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = serde_json::to_string_pretty(state)
|
||||||
|
.map_err(|err| format!("failed to render skill state file: {}", err))?;
|
||||||
|
let tmp_path = path.with_extension("json.tmp");
|
||||||
|
fs::write(&tmp_path, format!("{}\n", content))
|
||||||
|
.map_err(|err| format!("failed to write temporary skill state file: {}", err))?;
|
||||||
|
fs::rename(&tmp_path, path)
|
||||||
|
.map_err(|err| format!("failed to persist skill state file: {}", err))
|
||||||
|
}
|
||||||
|
|
||||||
fn render_skill_file(name: &str, description: &str, body: &str) -> Result<String, String> {
|
fn render_skill_file(name: &str, description: &str, body: &str) -> Result<String, String> {
|
||||||
if description.trim().is_empty() {
|
if description.trim().is_empty() {
|
||||||
return Err("description is required and cannot be empty".to_string());
|
return Err("description is required and cannot be empty".to_string());
|
||||||
@ -615,14 +830,16 @@ fn split_frontmatter(content: &str) -> Option<(&str, &str)> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::sync::Mutex;
|
use std::ffi::OsString;
|
||||||
|
|
||||||
static CWD_TEST_LOCK: Mutex<()> = Mutex::new(());
|
|
||||||
|
|
||||||
struct CurrentDirGuard {
|
struct CurrentDirGuard {
|
||||||
previous: PathBuf,
|
previous: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct HomeDirGuard {
|
||||||
|
previous: Option<OsString>,
|
||||||
|
}
|
||||||
|
|
||||||
impl CurrentDirGuard {
|
impl CurrentDirGuard {
|
||||||
fn enter(path: &Path) -> Self {
|
fn enter(path: &Path) -> Self {
|
||||||
let previous = std::env::current_dir().unwrap();
|
let previous = std::env::current_dir().unwrap();
|
||||||
@ -637,6 +854,33 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl HomeDirGuard {
|
||||||
|
fn enter(path: &Path) -> Self {
|
||||||
|
let previous = std::env::var_os("HOME");
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("HOME", path);
|
||||||
|
}
|
||||||
|
Self { previous }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for HomeDirGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
match &self.previous {
|
||||||
|
Some(value) => unsafe {
|
||||||
|
std::env::set_var("HOME", value);
|
||||||
|
},
|
||||||
|
None => unsafe {
|
||||||
|
std::env::remove_var("HOME");
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn acquire_test_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
acquire_skill_test_env_lock()
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_split_frontmatter() {
|
fn test_split_frontmatter() {
|
||||||
let input = "---\ndescription: demo\n---\nhello";
|
let input = "---\ndescription: demo\n---\nhello";
|
||||||
@ -683,9 +927,14 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_runtime_create_update_delete_reload() {
|
fn test_runtime_create_update_delete_reload() {
|
||||||
let _lock = CWD_TEST_LOCK.lock().unwrap();
|
let _lock = acquire_test_lock();
|
||||||
let temp_dir = tempfile::tempdir().unwrap();
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
let _guard = CurrentDirGuard::enter(temp_dir.path());
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
let runtime = SkillRuntime::from_config(SkillsConfig {
|
let runtime = SkillRuntime::from_config(SkillsConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
@ -756,12 +1005,16 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_discover_loads_project_agent_skills() {
|
fn test_discover_loads_project_agent_skills() {
|
||||||
let _lock = CWD_TEST_LOCK.lock().unwrap();
|
let _lock = acquire_test_lock();
|
||||||
let temp_dir = tempfile::tempdir().unwrap();
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
let _guard = CurrentDirGuard::enter(temp_dir.path());
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
let agent_skill_dir = temp_dir
|
let agent_skill_dir = project_dir
|
||||||
.path()
|
|
||||||
.join(".agents")
|
.join(".agents")
|
||||||
.join("skills")
|
.join("skills")
|
||||||
.join("demo-agent");
|
.join("demo-agent");
|
||||||
@ -786,11 +1039,16 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_discover_prefers_project_agent_on_conflict() {
|
fn test_discover_prefers_project_agent_on_conflict() {
|
||||||
let _lock = CWD_TEST_LOCK.lock().unwrap();
|
let _lock = acquire_test_lock();
|
||||||
let temp_dir = tempfile::tempdir().unwrap();
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
let _guard = CurrentDirGuard::enter(temp_dir.path());
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
let project_skill_dir = temp_dir.path().join(".picobot").join("skills").join("demo");
|
let project_skill_dir = project_dir.join(".picobot").join("skills").join("demo");
|
||||||
fs::create_dir_all(&project_skill_dir).unwrap();
|
fs::create_dir_all(&project_skill_dir).unwrap();
|
||||||
fs::write(
|
fs::write(
|
||||||
project_skill_dir.join("SKILL.md"),
|
project_skill_dir.join("SKILL.md"),
|
||||||
@ -798,7 +1056,7 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let agent_skill_dir = temp_dir.path().join(".agents").join("skills").join("demo");
|
let agent_skill_dir = project_dir.join(".agents").join("skills").join("demo");
|
||||||
fs::create_dir_all(&agent_skill_dir).unwrap();
|
fs::create_dir_all(&agent_skill_dir).unwrap();
|
||||||
fs::write(
|
fs::write(
|
||||||
agent_skill_dir.join("SKILL.md"),
|
agent_skill_dir.join("SKILL.md"),
|
||||||
@ -817,4 +1075,136 @@ mod tests {
|
|||||||
assert!(payload.contains("Source: project_agent"));
|
assert!(payload.contains("Source: project_agent"));
|
||||||
assert!(payload.contains("Agent body"));
|
assert!(payload.contains("Agent body"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_skill_state_file_roundtrip() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let path = dir.path().join("skill-state.json");
|
||||||
|
let state = SkillStateFile {
|
||||||
|
disabled_skills: vec!["demo".to_string(), "other".to_string()],
|
||||||
|
};
|
||||||
|
|
||||||
|
save_skill_state_file(&path, &state).unwrap();
|
||||||
|
|
||||||
|
let loaded = load_skill_state_file(&path).unwrap();
|
||||||
|
assert_eq!(loaded, state);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_discover_filters_disabled_skills_from_sidecar() {
|
||||||
|
let _lock = acquire_test_lock();
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
|
let project_skill_dir = project_dir.join(".picobot").join("skills").join("demo");
|
||||||
|
fs::create_dir_all(&project_skill_dir).unwrap();
|
||||||
|
fs::write(
|
||||||
|
project_skill_dir.join("SKILL.md"),
|
||||||
|
"---\ndescription: project version\n---\nProject body",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
save_skill_state_file(
|
||||||
|
&project_dir.join(".picobot").join("skill-state.json"),
|
||||||
|
&SkillStateFile {
|
||||||
|
disabled_skills: vec!["demo".to_string()],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let catalog = SkillCatalog::discover(&SkillsConfig {
|
||||||
|
enabled: true,
|
||||||
|
sources: vec!["project".to_string()],
|
||||||
|
max_index_chars: 4000,
|
||||||
|
max_listed_skills: 32,
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(catalog.len(), 0);
|
||||||
|
assert!(catalog.activation_payload("demo").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_runtime_disable_and_enable_skill_updates_visibility() {
|
||||||
|
let _lock = acquire_test_lock();
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
|
let project_skill_dir = project_dir.join(".picobot").join("skills").join("demo");
|
||||||
|
fs::create_dir_all(&project_skill_dir).unwrap();
|
||||||
|
fs::write(
|
||||||
|
project_skill_dir.join("SKILL.md"),
|
||||||
|
"---\ndescription: project version\n---\nProject body",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let runtime = SkillRuntime::from_config(SkillsConfig {
|
||||||
|
enabled: true,
|
||||||
|
sources: vec!["project".to_string()],
|
||||||
|
max_index_chars: 4000,
|
||||||
|
max_listed_skills: 32,
|
||||||
|
});
|
||||||
|
|
||||||
|
let disabled = runtime.disable_skill(SkillScope::Project, "demo", true).unwrap();
|
||||||
|
assert!(disabled.changed);
|
||||||
|
assert_eq!(disabled.disabled_in_scopes, vec![SkillScope::Project]);
|
||||||
|
assert!(!disabled.available);
|
||||||
|
assert!(runtime.get_skill("demo").is_none());
|
||||||
|
|
||||||
|
let enabled = runtime.enable_skill(SkillScope::Project, "demo", true).unwrap();
|
||||||
|
assert!(enabled.changed);
|
||||||
|
assert!(enabled.disabled_in_scopes.is_empty());
|
||||||
|
assert!(enabled.available);
|
||||||
|
assert!(runtime.get_skill("demo").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_user_scope_disable_overrides_project_scope_enable() {
|
||||||
|
let _lock = acquire_test_lock();
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
|
let project_skill_dir = project_dir.join(".picobot").join("skills").join("demo");
|
||||||
|
fs::create_dir_all(&project_skill_dir).unwrap();
|
||||||
|
fs::write(
|
||||||
|
project_skill_dir.join("SKILL.md"),
|
||||||
|
"---\ndescription: project version\n---\nProject body",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let runtime = SkillRuntime::from_config(SkillsConfig {
|
||||||
|
enabled: true,
|
||||||
|
sources: vec!["project".to_string()],
|
||||||
|
max_index_chars: 4000,
|
||||||
|
max_listed_skills: 32,
|
||||||
|
});
|
||||||
|
|
||||||
|
let user_disabled = runtime.disable_skill(SkillScope::User, "demo", true).unwrap();
|
||||||
|
assert_eq!(user_disabled.disabled_in_scopes, vec![SkillScope::User]);
|
||||||
|
assert!(runtime.get_skill("demo").is_none());
|
||||||
|
|
||||||
|
let project_enabled = runtime.enable_skill(SkillScope::Project, "demo", true).unwrap();
|
||||||
|
assert!(!project_enabled.available);
|
||||||
|
assert_eq!(project_enabled.disabled_in_scopes, vec![SkillScope::User]);
|
||||||
|
assert!(runtime.get_skill("demo").is_none());
|
||||||
|
|
||||||
|
let user_enabled = runtime.enable_skill(SkillScope::User, "demo", true).unwrap();
|
||||||
|
assert!(user_enabled.available);
|
||||||
|
assert!(user_enabled.disabled_in_scopes.is_empty());
|
||||||
|
assert!(runtime.get_skill("demo").is_some());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1614,10 +1614,12 @@ mod tests {
|
|||||||
use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT;
|
use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT;
|
||||||
use crate::domain::messages::ToolCall;
|
use crate::domain::messages::ToolCall;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_persistent_session_id_for_cli_and_channel() {
|
fn test_persistent_session_id_for_cli_and_channel() {
|
||||||
assert_eq!(persistent_session_id("cli", "abc"), "abc");
|
assert_eq!(persistent_session_id("cli", "abc"), "abc");
|
||||||
assert_eq!(persistent_session_id("feishu", "abc"), "feishu:abc");
|
assert_eq!(persistent_session_id(TEST_CHANNEL, "abc"), "test-channel:abc");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1682,12 +1684,12 @@ mod tests {
|
|||||||
fn test_ensure_channel_session_is_stable() {
|
fn test_ensure_channel_session_is_stable() {
|
||||||
let store = SessionStore::in_memory().unwrap();
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
|
||||||
let first = store.ensure_channel_session("feishu", "chat-1").unwrap();
|
let first = store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
|
||||||
let second = store.ensure_channel_session("feishu", "chat-1").unwrap();
|
let second = store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
|
||||||
|
|
||||||
assert_eq!(first.id, second.id);
|
assert_eq!(first.id, second.id);
|
||||||
assert_eq!(first.chat_id, "chat-1");
|
assert_eq!(first.chat_id, "chat-1");
|
||||||
assert_eq!(second.channel_name, "feishu");
|
assert_eq!(second.channel_name, TEST_CHANNEL);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -2040,27 +2042,27 @@ mod tests {
|
|||||||
let saved = store
|
let saved = store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "profile".to_string(),
|
namespace: "profile".to_string(),
|
||||||
memory_key: "language".to_string(),
|
memory_key: "language".to_string(),
|
||||||
content: "Rust".to_string(),
|
content: "Rust".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-1".to_string()),
|
source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-1".to_string()),
|
source_message_id: Some("msg-1".to_string()),
|
||||||
source_message_seq: Some(7),
|
source_message_seq: Some(7),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-1".to_string()),
|
source_chat_id: Some("chat-1".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(saved.content, "Rust");
|
assert_eq!(saved.content, "Rust");
|
||||||
assert_eq!(saved.source_type, "message");
|
assert_eq!(saved.source_type, "message");
|
||||||
assert_eq!(saved.source_session_id.as_deref(), Some("feishu:chat-1"));
|
assert_eq!(saved.source_session_id.as_deref(), Some("test-channel:chat-1"));
|
||||||
assert_eq!(saved.source_message_id.as_deref(), Some("msg-1"));
|
assert_eq!(saved.source_message_id.as_deref(), Some("msg-1"));
|
||||||
assert_eq!(saved.source_message_seq, Some(7));
|
assert_eq!(saved.source_message_seq, Some(7));
|
||||||
|
|
||||||
let fetched = store
|
let fetched = store
|
||||||
.get_memory("user", "feishu:user-1", "profile", "language")
|
.get_memory("user", "test-channel:user-1", "profile", "language")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(fetched.id, saved.id);
|
assert_eq!(fetched.id, saved.id);
|
||||||
@ -2074,21 +2076,21 @@ mod tests {
|
|||||||
store
|
store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "preferences".to_string(),
|
namespace: "preferences".to_string(),
|
||||||
memory_key: "editor".to_string(),
|
memory_key: "editor".to_string(),
|
||||||
content: "Prefers rust-analyzer and cargo test output".to_string(),
|
content: "Prefers rust-analyzer and cargo test output".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-2".to_string()),
|
source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-2".to_string()),
|
source_message_id: Some("msg-2".to_string()),
|
||||||
source_message_seq: Some(3),
|
source_message_seq: Some(3),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-2".to_string()),
|
source_chat_id: Some("chat-2".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let hits = store
|
let hits = store
|
||||||
.search_memories("user", "feishu:user-1", "rust-analyzer", None, 10)
|
.search_memories("user", "test-channel:user-1", "rust-analyzer", None, 10)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(hits.len(), 1);
|
assert_eq!(hits.len(), 1);
|
||||||
assert_eq!(hits[0].memory_key, "editor");
|
assert_eq!(hits[0].memory_key, "editor");
|
||||||
@ -2096,36 +2098,36 @@ mod tests {
|
|||||||
store
|
store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "preferences".to_string(),
|
namespace: "preferences".to_string(),
|
||||||
memory_key: "editor".to_string(),
|
memory_key: "editor".to_string(),
|
||||||
content: "Prefers clippy diagnostics".to_string(),
|
content: "Prefers clippy diagnostics".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-3".to_string()),
|
source_session_id: Some(format!("{}:chat-3", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-3".to_string()),
|
source_message_id: Some("msg-3".to_string()),
|
||||||
source_message_seq: Some(4),
|
source_message_seq: Some(4),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-3".to_string()),
|
source_chat_id: Some("chat-3".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let old_hits = store
|
let old_hits = store
|
||||||
.search_memories("user", "feishu:user-1", "rust-analyzer", None, 10)
|
.search_memories("user", "test-channel:user-1", "rust-analyzer", None, 10)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(old_hits.is_empty());
|
assert!(old_hits.is_empty());
|
||||||
|
|
||||||
let new_hits = store
|
let new_hits = store
|
||||||
.search_memories("user", "feishu:user-1", "clippy", None, 10)
|
.search_memories("user", "test-channel:user-1", "clippy", None, 10)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(new_hits.len(), 1);
|
assert_eq!(new_hits.len(), 1);
|
||||||
|
|
||||||
let deleted = store
|
let deleted = store
|
||||||
.delete_memory("user", "feishu:user-1", "preferences", "editor")
|
.delete_memory("user", "test-channel:user-1", "preferences", "editor")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(deleted);
|
assert!(deleted);
|
||||||
|
|
||||||
let hits_after_delete = store
|
let hits_after_delete = store
|
||||||
.search_memories("user", "feishu:user-1", "clippy", None, 10)
|
.search_memories("user", "test-channel:user-1", "clippy", None, 10)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(hits_after_delete.is_empty());
|
assert!(hits_after_delete.is_empty());
|
||||||
}
|
}
|
||||||
@ -2137,21 +2139,21 @@ mod tests {
|
|||||||
store
|
store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "preferences".to_string(),
|
namespace: "preferences".to_string(),
|
||||||
memory_key: "email_folder_preference".to_string(),
|
memory_key: "email_folder_preference".to_string(),
|
||||||
content: "用户提到邮件时默认查看代收邮箱。".to_string(),
|
content: "用户提到邮件时默认查看代收邮箱。".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-8".to_string()),
|
source_session_id: Some(format!("{}:chat-8", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-8".to_string()),
|
source_message_id: Some("msg-8".to_string()),
|
||||||
source_message_seq: Some(8),
|
source_message_seq: Some(8),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-8".to_string()),
|
source_chat_id: Some("chat-8".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let hits = store
|
let hits = store
|
||||||
.search_memories("user", "feishu:user-1", "email_folder_preference", None, 10)
|
.search_memories("user", "test-channel:user-1", "email_folder_preference", None, 10)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(hits.len(), 1);
|
assert_eq!(hits.len(), 1);
|
||||||
@ -2165,15 +2167,15 @@ mod tests {
|
|||||||
store
|
store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "preferences".to_string(),
|
namespace: "preferences".to_string(),
|
||||||
memory_key: "editor".to_string(),
|
memory_key: "editor".to_string(),
|
||||||
content: "Prefers rust-analyzer and cargo test output".to_string(),
|
content: "Prefers rust-analyzer and cargo test output".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-2".to_string()),
|
source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-2".to_string()),
|
source_message_id: Some("msg-2".to_string()),
|
||||||
source_message_seq: Some(3),
|
source_message_seq: Some(3),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-2".to_string()),
|
source_chat_id: Some("chat-2".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -2181,15 +2183,15 @@ mod tests {
|
|||||||
store
|
store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "tasks".to_string(),
|
namespace: "tasks".to_string(),
|
||||||
memory_key: "quality".to_string(),
|
memory_key: "quality".to_string(),
|
||||||
content: "Tracks clippy warnings before release".to_string(),
|
content: "Tracks clippy warnings before release".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-3".to_string()),
|
source_session_id: Some(format!("{}:chat-3", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-3".to_string()),
|
source_message_id: Some("msg-3".to_string()),
|
||||||
source_message_seq: Some(4),
|
source_message_seq: Some(4),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-3".to_string()),
|
source_chat_id: Some("chat-3".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -2197,7 +2199,7 @@ mod tests {
|
|||||||
let hits = store
|
let hits = store
|
||||||
.search_memories_any(
|
.search_memories_any(
|
||||||
"user",
|
"user",
|
||||||
"feishu:user-1",
|
"test-channel:user-1",
|
||||||
&["rust-analyzer".to_string(), "clippy".to_string()],
|
&["rust-analyzer".to_string(), "clippy".to_string()],
|
||||||
None,
|
None,
|
||||||
10,
|
10,
|
||||||
@ -2216,45 +2218,45 @@ mod tests {
|
|||||||
store
|
store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-2".to_string(),
|
scope_key: format!("{}:user-2", TEST_CHANNEL),
|
||||||
namespace: "preferences".to_string(),
|
namespace: "preferences".to_string(),
|
||||||
memory_key: "style".to_string(),
|
memory_key: "style".to_string(),
|
||||||
content: "偏好简洁表达".to_string(),
|
content: "偏好简洁表达".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-2".to_string()),
|
source_session_id: Some(format!("{}:chat-2", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-2".to_string()),
|
source_message_id: Some("msg-2".to_string()),
|
||||||
source_message_seq: Some(2),
|
source_message_seq: Some(2),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-2".to_string()),
|
source_chat_id: Some("chat-2".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
store
|
store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "profile".to_string(),
|
namespace: "profile".to_string(),
|
||||||
memory_key: "work".to_string(),
|
memory_key: "work".to_string(),
|
||||||
content: "用户在做AI产品".to_string(),
|
content: "用户在做AI产品".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-1".to_string()),
|
source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-1".to_string()),
|
source_message_id: Some("msg-1".to_string()),
|
||||||
source_message_seq: Some(1),
|
source_message_seq: Some(1),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-1".to_string()),
|
source_chat_id: Some("chat-1".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
store
|
store
|
||||||
.put_memory(&MemoryUpsert {
|
.put_memory(&MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "patterns".to_string(),
|
namespace: "patterns".to_string(),
|
||||||
memory_key: "workflow".to_string(),
|
memory_key: "workflow".to_string(),
|
||||||
content: "习惯先问方案再要代码".to_string(),
|
content: "习惯先问方案再要代码".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-1".to_string()),
|
source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-3".to_string()),
|
source_message_id: Some("msg-3".to_string()),
|
||||||
source_message_seq: Some(3),
|
source_message_seq: Some(3),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-1".to_string()),
|
source_chat_id: Some("chat-1".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -2262,17 +2264,17 @@ mod tests {
|
|||||||
let scope_keys = store.list_memory_scope_keys("user").unwrap();
|
let scope_keys = store.list_memory_scope_keys("user").unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
scope_keys,
|
scope_keys,
|
||||||
vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]
|
vec!["test-channel:user-1".to_string(), "test-channel:user-2".to_string()]
|
||||||
);
|
);
|
||||||
|
|
||||||
let full_scope = store
|
let full_scope = store
|
||||||
.list_memories_for_scope("user", "feishu:user-1")
|
.list_memories_for_scope("user", "test-channel:user-1")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(full_scope.len(), 2);
|
assert_eq!(full_scope.len(), 2);
|
||||||
assert!(
|
assert!(
|
||||||
full_scope
|
full_scope
|
||||||
.iter()
|
.iter()
|
||||||
.all(|memory| memory.scope_key == "feishu:user-1")
|
.all(|memory| memory.scope_key == "test-channel:user-1")
|
||||||
);
|
);
|
||||||
assert!(full_scope.iter().any(|memory| memory.memory_key == "work"));
|
assert!(full_scope.iter().any(|memory| memory.memory_key == "work"));
|
||||||
assert!(
|
assert!(
|
||||||
@ -2298,7 +2300,7 @@ mod tests {
|
|||||||
interval_secs: 300,
|
interval_secs: 300,
|
||||||
startup_delay_secs: 10,
|
startup_delay_secs: 10,
|
||||||
target: serde_json::json!({
|
target: serde_json::json!({
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo",
|
"chat_id": "oc_demo",
|
||||||
}),
|
}),
|
||||||
payload: serde_json::json!({
|
payload: serde_json::json!({
|
||||||
|
|||||||
@ -221,15 +221,17 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::SessionStore;
|
use crate::storage::SessionStore;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_memory_manage_put_returns_saved_memory() {
|
async fn test_memory_manage_put_returns_saved_memory() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = MemoryManageTool::new(store);
|
let tool = MemoryManageTool::new(store);
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
channel_name: Some("feishu".to_string()),
|
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
sender_id: Some("user-1".to_string()),
|
sender_id: Some("user-1".to_string()),
|
||||||
chat_id: Some("chat-1".to_string()),
|
chat_id: Some("chat-1".to_string()),
|
||||||
session_id: Some("feishu:chat-1".to_string()),
|
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||||
message_id: Some("msg-1".to_string()),
|
message_id: Some("msg-1".to_string()),
|
||||||
message_seq: Some(1),
|
message_seq: Some(1),
|
||||||
};
|
};
|
||||||
@ -275,7 +277,7 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = MemoryManageTool::new(store);
|
let tool = MemoryManageTool::new(store);
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
channel_name: Some("feishu".to_string()),
|
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
sender_id: Some("user-1".to_string()),
|
sender_id: Some("user-1".to_string()),
|
||||||
..ToolContext::default()
|
..ToolContext::default()
|
||||||
};
|
};
|
||||||
|
|||||||
@ -207,31 +207,33 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::SessionStore;
|
use crate::storage::SessionStore;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_memory_search_search_and_get() {
|
async fn test_memory_search_search_and_get() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
store
|
store
|
||||||
.put_memory(&crate::storage::MemoryUpsert {
|
.put_memory(&crate::storage::MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: "feishu:user-1".to_string(),
|
scope_key: format!("{}:user-1", TEST_CHANNEL),
|
||||||
namespace: "preferences".to_string(),
|
namespace: "preferences".to_string(),
|
||||||
memory_key: "language".to_string(),
|
memory_key: "language".to_string(),
|
||||||
content: "User prefers Chinese responses".to_string(),
|
content: "User prefers Chinese responses".to_string(),
|
||||||
source_type: "message".to_string(),
|
source_type: "message".to_string(),
|
||||||
source_session_id: Some("feishu:chat-1".to_string()),
|
source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||||
source_message_id: Some("msg-1".to_string()),
|
source_message_id: Some("msg-1".to_string()),
|
||||||
source_message_seq: Some(1),
|
source_message_seq: Some(1),
|
||||||
source_channel_name: Some("feishu".to_string()),
|
source_channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
source_chat_id: Some("chat-1".to_string()),
|
source_chat_id: Some("chat-1".to_string()),
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let tool = MemorySearchTool::new(store);
|
let tool = MemorySearchTool::new(store);
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
channel_name: Some("feishu".to_string()),
|
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
sender_id: Some("user-1".to_string()),
|
sender_id: Some("user-1".to_string()),
|
||||||
chat_id: Some("chat-1".to_string()),
|
chat_id: Some("chat-1".to_string()),
|
||||||
session_id: Some("feishu:chat-1".to_string()),
|
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||||
message_id: Some("msg-2".to_string()),
|
message_id: Some("msg-2".to_string()),
|
||||||
message_seq: Some(2),
|
message_seq: Some(2),
|
||||||
};
|
};
|
||||||
@ -285,7 +287,7 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = MemorySearchTool::new(store);
|
let tool = MemorySearchTool::new(store);
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
channel_name: Some("feishu".to_string()),
|
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
sender_id: Some("user-1".to_string()),
|
sender_id: Some("user-1".to_string()),
|
||||||
..ToolContext::default()
|
..ToolContext::default()
|
||||||
};
|
};
|
||||||
|
|||||||
@ -435,6 +435,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::SessionStore;
|
use crate::storage::SessionStore;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_scheduler_manage_put_and_get() {
|
async fn test_scheduler_manage_put_and_get() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
@ -450,7 +452,7 @@ mod tests {
|
|||||||
"seconds": 60
|
"seconds": 60
|
||||||
},
|
},
|
||||||
"target": {
|
"target": {
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo"
|
"chat_id": "oc_demo"
|
||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
@ -488,7 +490,7 @@ mod tests {
|
|||||||
"expression": "0 9 * * *"
|
"expression": "0 9 * * *"
|
||||||
},
|
},
|
||||||
"target": {
|
"target": {
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo"
|
"chat_id": "oc_demo"
|
||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
@ -518,7 +520,7 @@ mod tests {
|
|||||||
"expression": "0 9 * * *"
|
"expression": "0 9 * * *"
|
||||||
},
|
},
|
||||||
"target": {
|
"target": {
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo",
|
"chat_id": "oc_demo",
|
||||||
"session_chat_id": "scheduler/agent.daily_summary.background"
|
"session_chat_id": "scheduler/agent.daily_summary.background"
|
||||||
},
|
},
|
||||||
@ -576,10 +578,10 @@ mod tests {
|
|||||||
let put_result = tool
|
let put_result = tool
|
||||||
.execute_with_context(
|
.execute_with_context(
|
||||||
&crate::tools::ToolContext {
|
&crate::tools::ToolContext {
|
||||||
channel_name: Some("feishu".to_string()),
|
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
sender_id: Some("user-1".to_string()),
|
sender_id: Some("user-1".to_string()),
|
||||||
chat_id: Some("oc_demo".to_string()),
|
chat_id: Some("oc_demo".to_string()),
|
||||||
session_id: Some("feishu:oc_demo".to_string()),
|
session_id: Some(format!("{}:oc_demo", TEST_CHANNEL)),
|
||||||
message_id: Some("msg-1".to_string()),
|
message_id: Some("msg-1".to_string()),
|
||||||
message_seq: Some(1),
|
message_seq: Some(1),
|
||||||
},
|
},
|
||||||
@ -602,7 +604,7 @@ mod tests {
|
|||||||
assert!(put_result.success);
|
assert!(put_result.success);
|
||||||
|
|
||||||
let saved = store.get_scheduler_job("work_reminder").unwrap().unwrap();
|
let saved = store.get_scheduler_job("work_reminder").unwrap().unwrap();
|
||||||
assert_eq!(saved.target["channel"], "feishu");
|
assert_eq!(saved.target["channel"], "test-channel");
|
||||||
assert_eq!(saved.target["chat_id"], "oc_demo");
|
assert_eq!(saved.target["chat_id"], "oc_demo");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -621,7 +623,7 @@ mod tests {
|
|||||||
"expression": "0 9 * * *"
|
"expression": "0 9 * * *"
|
||||||
},
|
},
|
||||||
"target": {
|
"target": {
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo"
|
"chat_id": "oc_demo"
|
||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
@ -653,7 +655,7 @@ mod tests {
|
|||||||
"expression": "0 9 * * *"
|
"expression": "0 9 * * *"
|
||||||
},
|
},
|
||||||
"target": {
|
"target": {
|
||||||
"channel": "feishu",
|
"channel": "test-channel",
|
||||||
"chat_id": "oc_demo"
|
"chat_id": "oc_demo"
|
||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
|
|||||||
@ -240,6 +240,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
struct MockSender {
|
struct MockSender {
|
||||||
outcome: SessionSendOutcome,
|
outcome: SessionSendOutcome,
|
||||||
}
|
}
|
||||||
@ -257,7 +259,7 @@ mod tests {
|
|||||||
|
|
||||||
fn context() -> ToolContext {
|
fn context() -> ToolContext {
|
||||||
ToolContext {
|
ToolContext {
|
||||||
channel_name: Some("feishu".to_string()),
|
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
chat_id: Some("chat-1".to_string()),
|
chat_id: Some("chat-1".to_string()),
|
||||||
..ToolContext::default()
|
..ToolContext::default()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -124,14 +124,16 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::SessionStore;
|
use crate::storage::SessionStore;
|
||||||
|
|
||||||
|
const TEST_CHANNEL: &str = "test-channel";
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_skill_activate_records_failed_activation_event() {
|
async fn test_skill_activate_records_failed_activation_event() {
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
store.ensure_channel_session("feishu", "chat-1").unwrap();
|
store.ensure_channel_session(TEST_CHANNEL, "chat-1").unwrap();
|
||||||
let tool = SkillActivateTool::new(skills, store.clone());
|
let tool = SkillActivateTool::new(skills, store.clone());
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
session_id: Some("feishu:chat-1".to_string()),
|
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||||
..ToolContext::default()
|
..ToolContext::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -143,7 +145,9 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("not found"));
|
assert!(result.error.unwrap().contains("not found"));
|
||||||
|
|
||||||
let events = store.list_skill_events(Some("feishu:chat-1")).unwrap();
|
let events = store
|
||||||
|
.list_skill_events(Some(&format!("{}:chat-1", TEST_CHANNEL)))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(events.len(), 1);
|
assert_eq!(events.len(), 1);
|
||||||
assert_eq!(events[0].event_type, "activation_failed");
|
assert_eq!(events[0].event_type, "activation_failed");
|
||||||
assert_eq!(events[0].skill_name.as_deref(), Some("demo"));
|
assert_eq!(events[0].skill_name.as_deref(), Some("demo"));
|
||||||
|
|||||||
@ -32,7 +32,7 @@ impl Tool for SkillManageTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
"Manage PicoBot skills stored under .picobot/skills or ~/.picobot/skills, while discovery also reads .agents/skills and ~/.agents/skills. Supports actions: list, get, create, update, delete, reload."
|
"Manage PicoBot skills stored under .picobot/skills or ~/.picobot/skills, while discovery also reads .agents/skills and ~/.agents/skills. Supports actions: list, get, create, update, delete, disable, reload."
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
@ -41,18 +41,25 @@ impl Tool for SkillManageTool {
|
|||||||
"properties": {
|
"properties": {
|
||||||
"action": {
|
"action": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["list", "get", "create", "update", "delete", "reload"],
|
"enum": ["list", "get", "create", "update", "delete", "disable", "reload"],
|
||||||
"description": "Management action to perform"
|
"description": "Management action to perform"
|
||||||
},
|
},
|
||||||
"scope": {
|
"scope": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["project", "user"],
|
"enum": ["project", "user"],
|
||||||
"description": "Writable skill scope for create/update/delete. Defaults to project. .agents discovery sources are read-only here."
|
"description": "Writable skill scope for create/update/delete/disable. Defaults to project. .agents discovery sources are read-only here, but can still be disabled via sidecar state."
|
||||||
},
|
},
|
||||||
"name": {
|
"name": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Skill name"
|
"description": "Skill name"
|
||||||
},
|
},
|
||||||
|
"names": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "Skill names for batch disable; pass a single-item array to disable one skill"
|
||||||
|
},
|
||||||
"description": {
|
"description": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Skill description used for discovery"
|
"description": "Skill description used for discovery"
|
||||||
@ -93,6 +100,10 @@ impl Tool for SkillManageTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let name = args.get("name").and_then(|v| v.as_str());
|
let name = args.get("name").and_then(|v| v.as_str());
|
||||||
|
let names = match parse_disable_names(&args) {
|
||||||
|
Ok(names) => names,
|
||||||
|
Err(err) => return Ok(error_result(&err)),
|
||||||
|
};
|
||||||
|
|
||||||
let result = match action {
|
let result = match action {
|
||||||
"list" => list_skills_payload(&self.skills),
|
"list" => list_skills_payload(&self.skills),
|
||||||
@ -192,6 +203,30 @@ impl Tool for SkillManageTool {
|
|||||||
}),
|
}),
|
||||||
Err(err) => return Ok(error_result(&err)),
|
Err(err) => return Ok(error_result(&err)),
|
||||||
},
|
},
|
||||||
|
"disable" => {
|
||||||
|
let targets = &names;
|
||||||
|
|
||||||
|
let mut changes = Vec::new();
|
||||||
|
for target in targets {
|
||||||
|
match self.skills.disable_skill(scope, target, false) {
|
||||||
|
Ok(change) => changes.push(change),
|
||||||
|
Err(err) => return Ok(error_result(&err)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if reload {
|
||||||
|
if let Err(err) = self.skills.reload() {
|
||||||
|
return Ok(error_result(&err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
json!({
|
||||||
|
"status": "disabled",
|
||||||
|
"scope": scope.as_str(),
|
||||||
|
"count": changes.len(),
|
||||||
|
"reloaded": reload,
|
||||||
|
"changes": changes.into_iter().map(skill_change_payload).collect::<Vec<_>>(),
|
||||||
|
})
|
||||||
|
}
|
||||||
_ => return Ok(error_result("Unsupported action")),
|
_ => return Ok(error_result("Unsupported action")),
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -242,6 +277,42 @@ fn error_result(message: &str) -> ToolResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_disable_names(args: &serde_json::Value) -> Result<Vec<String>, String> {
|
||||||
|
let names = args
|
||||||
|
.get("names")
|
||||||
|
.ok_or_else(|| "disable requires names".to_string())?
|
||||||
|
.as_array()
|
||||||
|
.ok_or_else(|| "names must be an array of strings".to_string())?;
|
||||||
|
|
||||||
|
let mut parsed = Vec::new();
|
||||||
|
for item in names {
|
||||||
|
let name = item
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| "names must be an array of strings".to_string())?
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
if name.is_empty() {
|
||||||
|
return Err("names must not contain empty values".to_string());
|
||||||
|
}
|
||||||
|
parsed.push(name);
|
||||||
|
}
|
||||||
|
if parsed.is_empty() {
|
||||||
|
return Err("names must not be empty".to_string());
|
||||||
|
}
|
||||||
|
Ok(parsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn skill_change_payload(change: crate::skills::SkillAvailabilityChange) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"name": change.name,
|
||||||
|
"scope": change.scope.as_str(),
|
||||||
|
"path": change.state_path.display().to_string(),
|
||||||
|
"changed": change.changed,
|
||||||
|
"available": change.available,
|
||||||
|
"disabled_in_scopes": change.disabled_in_scopes.into_iter().map(|scope| scope.as_str()).collect::<Vec<_>>(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn list_skills_payload(skills: &Arc<SkillRuntime>) -> serde_json::Value {
|
fn list_skills_payload(skills: &Arc<SkillRuntime>) -> serde_json::Value {
|
||||||
let skills = skills.list_skills();
|
let skills = skills.list_skills();
|
||||||
json!({
|
json!({
|
||||||
@ -259,3 +330,170 @@ fn list_skills_payload(skills: &Arc<SkillRuntime>) -> serde_json::Value {
|
|||||||
})).collect::<Vec<_>>()
|
})).collect::<Vec<_>>()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::config::SkillsConfig;
|
||||||
|
use crate::skills::acquire_skill_test_env_lock;
|
||||||
|
use std::ffi::OsString;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
struct CurrentDirGuard {
|
||||||
|
previous: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HomeDirGuard {
|
||||||
|
previous: Option<OsString>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CurrentDirGuard {
|
||||||
|
fn enter(path: &Path) -> Self {
|
||||||
|
let previous = std::env::current_dir().unwrap();
|
||||||
|
std::env::set_current_dir(path).unwrap();
|
||||||
|
Self { previous }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for CurrentDirGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let _ = std::env::set_current_dir(&self.previous);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HomeDirGuard {
|
||||||
|
fn enter(path: &Path) -> Self {
|
||||||
|
let previous = std::env::var_os("HOME");
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("HOME", path);
|
||||||
|
}
|
||||||
|
Self { previous }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for HomeDirGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
match &self.previous {
|
||||||
|
Some(value) => unsafe {
|
||||||
|
std::env::set_var("HOME", value);
|
||||||
|
},
|
||||||
|
None => unsafe {
|
||||||
|
std::env::remove_var("HOME");
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn acquire_test_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
acquire_skill_test_env_lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_skill_manage_disable_updates_runtime() {
|
||||||
|
let _lock = acquire_test_lock();
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
std::fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
std::fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
|
let runtime = Arc::new(SkillRuntime::from_config(SkillsConfig {
|
||||||
|
enabled: true,
|
||||||
|
sources: vec!["project".to_string()],
|
||||||
|
max_index_chars: 4000,
|
||||||
|
max_listed_skills: 32,
|
||||||
|
}));
|
||||||
|
runtime
|
||||||
|
.create_skill(SkillScope::Project, "demo", "demo skill", "body", true)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tool = SkillManageTool::new(runtime.clone());
|
||||||
|
|
||||||
|
let disabled = tool
|
||||||
|
.execute(json!({
|
||||||
|
"action": "disable",
|
||||||
|
"names": ["demo"],
|
||||||
|
"scope": "project"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(disabled.success);
|
||||||
|
assert!(disabled.output.contains("disabled"));
|
||||||
|
assert!(runtime.get_skill("demo").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_skill_manage_batch_disable_uses_names_array() {
|
||||||
|
let _lock = acquire_test_lock();
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
std::fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
std::fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
|
let runtime = Arc::new(SkillRuntime::from_config(SkillsConfig {
|
||||||
|
enabled: true,
|
||||||
|
sources: vec!["project".to_string()],
|
||||||
|
max_index_chars: 4000,
|
||||||
|
max_listed_skills: 32,
|
||||||
|
}));
|
||||||
|
runtime
|
||||||
|
.create_skill(SkillScope::Project, "demo-a", "demo skill a", "body", true)
|
||||||
|
.unwrap();
|
||||||
|
runtime
|
||||||
|
.create_skill(SkillScope::Project, "demo-b", "demo skill b", "body", true)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tool = SkillManageTool::new(runtime.clone());
|
||||||
|
|
||||||
|
let disabled = tool
|
||||||
|
.execute(json!({
|
||||||
|
"action": "disable",
|
||||||
|
"names": ["demo-a", "demo-b"],
|
||||||
|
"scope": "project"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(disabled.success);
|
||||||
|
assert!(disabled.output.contains("\"count\": 2"));
|
||||||
|
assert!(runtime.get_skill("demo-a").is_none());
|
||||||
|
assert!(runtime.get_skill("demo-b").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_skill_manage_disable_requires_names_array() {
|
||||||
|
let _lock = acquire_test_lock();
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let home_dir = temp_dir.path().join("home");
|
||||||
|
let project_dir = temp_dir.path().join("project");
|
||||||
|
std::fs::create_dir_all(&home_dir).unwrap();
|
||||||
|
std::fs::create_dir_all(&project_dir).unwrap();
|
||||||
|
let _home = HomeDirGuard::enter(&home_dir);
|
||||||
|
let _guard = CurrentDirGuard::enter(&project_dir);
|
||||||
|
|
||||||
|
let runtime = Arc::new(SkillRuntime::from_config(SkillsConfig {
|
||||||
|
enabled: true,
|
||||||
|
sources: vec!["project".to_string()],
|
||||||
|
max_index_chars: 4000,
|
||||||
|
max_listed_skills: 32,
|
||||||
|
}));
|
||||||
|
let tool = SkillManageTool::new(runtime);
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"action": "disable",
|
||||||
|
"name": "demo",
|
||||||
|
"scope": "project"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("disable requires names"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
1
vendor/wechatbot/.cargo-ok
vendored
Normal file
1
vendor/wechatbot/.cargo-ok
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"v":1}
|
||||||
6
vendor/wechatbot/.cargo_vcs_info.json
vendored
Normal file
6
vendor/wechatbot/.cargo_vcs_info.json
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"git": {
|
||||||
|
"sha1": "70bc64cc8035de4677bbe01265570e7f157bb31d"
|
||||||
|
},
|
||||||
|
"path_in_vcs": "rust"
|
||||||
|
}
|
||||||
91
vendor/wechatbot/Cargo.toml
vendored
Normal file
91
vendor/wechatbot/Cargo.toml
vendored
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||||
|
#
|
||||||
|
# When uploading crates to the registry Cargo will automatically
|
||||||
|
# "normalize" Cargo.toml files for maximal compatibility
|
||||||
|
# with all versions of Cargo and also rewrite `path` dependencies
|
||||||
|
# to registry (e.g., crates.io) dependencies.
|
||||||
|
#
|
||||||
|
# If you are reading this file be aware that the original Cargo.toml
|
||||||
|
# will likely look very different (and much more reasonable).
|
||||||
|
# See Cargo.toml.orig for the original contents.
|
||||||
|
|
||||||
|
[package]
|
||||||
|
edition = "2021"
|
||||||
|
name = "wechatbot"
|
||||||
|
version = "0.3.2"
|
||||||
|
build = false
|
||||||
|
autolib = false
|
||||||
|
autobins = false
|
||||||
|
autoexamples = false
|
||||||
|
autotests = false
|
||||||
|
autobenches = false
|
||||||
|
description = "WeChat iLink Bot SDK for Rust"
|
||||||
|
homepage = "https://github.com/corespeed-io/wechatbot"
|
||||||
|
documentation = "https://docs.rs/wechatbot"
|
||||||
|
readme = "README.md"
|
||||||
|
license = "MIT"
|
||||||
|
repository = "https://github.com/corespeed-io/wechatbot"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "wechatbot"
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "echo_bot"
|
||||||
|
path = "examples/echo_bot.rs"
|
||||||
|
|
||||||
|
[dependencies.aes]
|
||||||
|
version = "0.8"
|
||||||
|
|
||||||
|
[dependencies.base64]
|
||||||
|
version = "0.22"
|
||||||
|
|
||||||
|
[dependencies.dirs-next]
|
||||||
|
version = "2"
|
||||||
|
|
||||||
|
[dependencies.hex]
|
||||||
|
version = "0.4"
|
||||||
|
|
||||||
|
[dependencies.md-5]
|
||||||
|
version = "0.10"
|
||||||
|
|
||||||
|
[dependencies.rand]
|
||||||
|
version = "0.10"
|
||||||
|
|
||||||
|
[dependencies.reqwest]
|
||||||
|
version = "0.12"
|
||||||
|
default-features = false
|
||||||
|
features = ["json", "rustls-tls"]
|
||||||
|
|
||||||
|
[dependencies.serde]
|
||||||
|
version = "1"
|
||||||
|
features = ["derive"]
|
||||||
|
|
||||||
|
[dependencies.serde_json]
|
||||||
|
version = "1"
|
||||||
|
|
||||||
|
[dependencies.serde_repr]
|
||||||
|
version = "0.1"
|
||||||
|
|
||||||
|
[dependencies.thiserror]
|
||||||
|
version = "2"
|
||||||
|
|
||||||
|
[dependencies.tokio]
|
||||||
|
version = "1"
|
||||||
|
features = ["full"]
|
||||||
|
|
||||||
|
[dependencies.tracing]
|
||||||
|
version = "0.1"
|
||||||
|
|
||||||
|
[dependencies.urlencoding]
|
||||||
|
version = "2"
|
||||||
|
|
||||||
|
[dependencies.uuid]
|
||||||
|
version = "1"
|
||||||
|
features = ["v4"]
|
||||||
|
|
||||||
|
[dev-dependencies.tokio-test]
|
||||||
|
version = "0.4"
|
||||||
|
|
||||||
|
[dev-dependencies.tracing-subscriber]
|
||||||
|
version = "0.3"
|
||||||
35
vendor/wechatbot/Cargo.toml.orig
generated
vendored
Normal file
35
vendor/wechatbot/Cargo.toml.orig
generated
vendored
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
[package]
|
||||||
|
name = "wechatbot"
|
||||||
|
version = "0.3.2"
|
||||||
|
edition = "2021"
|
||||||
|
description = "WeChat iLink Bot SDK for Rust"
|
||||||
|
license = "MIT"
|
||||||
|
readme = "README.md"
|
||||||
|
repository = "https://github.com/corespeed-io/wechatbot"
|
||||||
|
homepage = "https://github.com/corespeed-io/wechatbot"
|
||||||
|
documentation = "https://docs.rs/wechatbot"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
aes = "0.8"
|
||||||
|
base64 = "0.22"
|
||||||
|
hex = "0.4"
|
||||||
|
md-5 = "0.10"
|
||||||
|
rand = "0.10"
|
||||||
|
reqwest = { version = "0.12", features = ["json"] }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
serde_repr = "0.1"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
uuid = { version = "1", features = ["v4"] }
|
||||||
|
thiserror = "2"
|
||||||
|
tracing = "0.1"
|
||||||
|
urlencoding = "2"
|
||||||
|
dirs-next = "2"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio-test = "0.4"
|
||||||
|
tracing-subscriber = "0.3"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "echo_bot"
|
||||||
|
path = "examples/echo_bot.rs"
|
||||||
226
vendor/wechatbot/README.md
vendored
Normal file
226
vendor/wechatbot/README.md
vendored
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
# wechatbot — Rust SDK
|
||||||
|
|
||||||
|
WeChat iLink Bot SDK for Rust — async, type-safe, zero-copy where possible.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[dependencies]
|
||||||
|
wechatbot = "0.1"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
```
|
||||||
|
|
||||||
|
Requires Rust 2021 edition. Built on `tokio` + `reqwest`.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use wechatbot::{WeChatBot, BotOptions};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let bot = WeChatBot::new(BotOptions::default());
|
||||||
|
let creds = bot.login(false).await.unwrap();
|
||||||
|
println!("Logged in: {}", creds.account_id);
|
||||||
|
|
||||||
|
bot.on_message(Box::new(|msg| {
|
||||||
|
println!("{}: {}", msg.user_id, msg.text);
|
||||||
|
})).await;
|
||||||
|
|
||||||
|
bot.run().await.unwrap();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
├── lib.rs ← Public re-exports
|
||||||
|
├── types.rs ← All protocol & public types (serde)
|
||||||
|
├── error.rs ← Error hierarchy (thiserror)
|
||||||
|
├── protocol.rs ← Raw iLink API calls (reqwest)
|
||||||
|
├── crypto.rs ← AES-128-ECB encrypt/decrypt + key encoding
|
||||||
|
└── bot.rs ← WeChatBot client (login, run, reply, send)
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Creating a Bot
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use wechatbot::{WeChatBot, BotOptions};
|
||||||
|
|
||||||
|
let bot = WeChatBot::new(BotOptions {
|
||||||
|
base_url: None, // default: ilinkai.weixin.qq.com
|
||||||
|
cred_path: None, // default: ~/.wechatbot/credentials.json
|
||||||
|
on_qr_url: Some(Box::new(|url| {
|
||||||
|
println!("Scan: {}", url);
|
||||||
|
})),
|
||||||
|
on_error: Some(Box::new(|err| {
|
||||||
|
eprintln!("Error: {}", err);
|
||||||
|
})),
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
### Authentication
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Login (skips QR if credentials exist)
|
||||||
|
let creds = bot.login(false).await?;
|
||||||
|
|
||||||
|
// Force re-login
|
||||||
|
let creds = bot.login(true).await?;
|
||||||
|
|
||||||
|
// Credentials struct
|
||||||
|
println!("Token: {}", creds.token);
|
||||||
|
println!("Base URL: {}", creds.base_url);
|
||||||
|
println!("Account: {}", creds.account_id);
|
||||||
|
println!("User: {}", creds.user_id);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Message Handling
|
||||||
|
|
||||||
|
```rust
|
||||||
|
bot.on_message(Box::new(|msg| {
|
||||||
|
match msg.content_type {
|
||||||
|
ContentType::Text => println!("Text: {}", msg.text),
|
||||||
|
ContentType::Image => {
|
||||||
|
for img in &msg.images {
|
||||||
|
println!("Image URL: {:?}", img.url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ContentType::Voice => {
|
||||||
|
for voice in &msg.voices {
|
||||||
|
println!("Voice: {:?} ({}ms)", voice.text, voice.duration_ms.unwrap_or(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ContentType::File => {
|
||||||
|
for file in &msg.files {
|
||||||
|
println!("File: {:?}", file.file_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ContentType::Video => println!("Video received"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref quoted) = msg.quoted {
|
||||||
|
println!("Quoted: {:?}", quoted.title);
|
||||||
|
}
|
||||||
|
})).await;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sending Messages
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Reply to incoming message
|
||||||
|
bot.reply(&msg, "Echo: hello").await?;
|
||||||
|
|
||||||
|
// Send to user (needs prior context_token)
|
||||||
|
bot.send(user_id, "Hello").await?;
|
||||||
|
|
||||||
|
// Typing indicator
|
||||||
|
bot.send_typing(user_id).await?;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Media Operations
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Reply with media content
|
||||||
|
bot.reply_media(&msg, SendContent::Image(png_bytes)).await?;
|
||||||
|
bot.reply_media(&msg, SendContent::File { data, file_name: "report.pdf".into() }).await?;
|
||||||
|
bot.reply_media(&msg, SendContent::Video(mp4_bytes)).await?;
|
||||||
|
```
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Download media from incoming message (priority: image > file > video > voice)
|
||||||
|
if let Some(media) = bot.download(&msg).await? {
|
||||||
|
println!("Type: {}, Size: {} bytes", media.media_type, media.data.len());
|
||||||
|
if let Some(name) = &media.file_name {
|
||||||
|
println!("Filename: {}", name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download a raw CDN reference directly
|
||||||
|
let raw = bot.download_raw(&msg.images[0].media.as_ref().unwrap(), None).await?;
|
||||||
|
```
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Upload to CDN without sending a message
|
||||||
|
let result = bot.upload(&file_bytes, user_id, 3).await?;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lifecycle
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Start polling (blocks)
|
||||||
|
bot.run().await?;
|
||||||
|
|
||||||
|
// Stop
|
||||||
|
bot.stop().await;
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use wechatbot::WeChatBotError;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Err(WeChatBotError::Api { message, errcode, .. }) => {
|
||||||
|
if errcode == -14 {
|
||||||
|
// session expired — handled automatically
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(WeChatBotError::NoContext(user_id)) => {
|
||||||
|
// no context_token for this user yet
|
||||||
|
}
|
||||||
|
Err(WeChatBotError::Transport(e)) => {
|
||||||
|
// network error
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## AES-128-ECB Crypto
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use wechatbot::{generate_aes_key, encrypt_aes_ecb, decrypt_aes_ecb, decode_aes_key};
|
||||||
|
|
||||||
|
// Generate key
|
||||||
|
let key = generate_aes_key();
|
||||||
|
|
||||||
|
// Encrypt/decrypt
|
||||||
|
let ciphertext = encrypt_aes_ecb(b"Hello", &key);
|
||||||
|
let plaintext = decrypt_aes_ecb(&ciphertext, &key)?;
|
||||||
|
|
||||||
|
// Decode protocol key (handles all 3 formats)
|
||||||
|
let key = decode_aes_key("ABEiM0RVZneImaq7zN3u/w==")?;
|
||||||
|
let key = decode_aes_key("00112233445566778899aabbccddeeff")?;
|
||||||
|
```
|
||||||
|
|
||||||
|
## Types
|
||||||
|
|
||||||
|
All protocol types derive `Serialize` + `Deserialize` + `Clone` + `Debug`:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Wire-level (protocol)
|
||||||
|
WireMessage, WireMessageItem, CDNMedia, TextItem, ImageItem, ...
|
||||||
|
|
||||||
|
// Parsed (user-friendly)
|
||||||
|
IncomingMessage, ImageContent, VoiceContent, FileContent, VideoContent
|
||||||
|
|
||||||
|
// Auth
|
||||||
|
Credentials
|
||||||
|
|
||||||
|
// Enums
|
||||||
|
MessageType, MessageState, MessageItemType, ContentType, MediaType
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo test
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
43
vendor/wechatbot/examples/echo_bot.rs
vendored
Normal file
43
vendor/wechatbot/examples/echo_bot.rs
vendored
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
use wechatbot::{BotOptions, WeChatBot};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
let bot = WeChatBot::new(BotOptions {
|
||||||
|
on_qr_url: Some(Box::new(|url| {
|
||||||
|
println!("\nScan this URL in WeChat:\n{}\n", url);
|
||||||
|
})),
|
||||||
|
on_error: Some(Box::new(|err| {
|
||||||
|
eprintln!("Error: {}", err);
|
||||||
|
})),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
|
||||||
|
let creds = bot.login(false).await.expect("login failed");
|
||||||
|
println!("Logged in: {} ({})", creds.account_id, creds.user_id);
|
||||||
|
|
||||||
|
bot.on_message(Box::new(|msg| {
|
||||||
|
println!("[{}] {}: {}", msg.content_type_str(), msg.user_id, msg.text);
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
println!("Listening for messages (Ctrl+C to stop)");
|
||||||
|
bot.run().await.expect("run failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
trait ContentTypeStr {
|
||||||
|
fn content_type_str(&self) -> &str;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContentTypeStr for wechatbot::IncomingMessage {
|
||||||
|
fn content_type_str(&self) -> &str {
|
||||||
|
match self.content_type {
|
||||||
|
wechatbot::ContentType::Text => "text",
|
||||||
|
wechatbot::ContentType::Image => "image",
|
||||||
|
wechatbot::ContentType::Voice => "voice",
|
||||||
|
wechatbot::ContentType::File => "file",
|
||||||
|
wechatbot::ContentType::Video => "video",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
741
vendor/wechatbot/src/bot.rs
vendored
Normal file
741
vendor/wechatbot/src/bot.rs
vendored
Normal file
@ -0,0 +1,741 @@
|
|||||||
|
//! Main WeChatBot client.
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::{Mutex, RwLock};
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
|
use crate::cdn::CdnClient;
|
||||||
|
use crate::crypto;
|
||||||
|
use crate::error::{Result, WeChatBotError};
|
||||||
|
use crate::protocol::{self, ILinkClient};
|
||||||
|
use crate::types::*;
|
||||||
|
use md5::{Digest, Md5};
|
||||||
|
use rand::Rng;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
/// Message handler callback type.
|
||||||
|
pub type MessageHandler = Box<dyn Fn(&IncomingMessage) + Send + Sync>;
|
||||||
|
|
||||||
|
/// Bot configuration options.
|
||||||
|
pub struct BotOptions {
|
||||||
|
pub base_url: Option<String>,
|
||||||
|
pub cred_path: Option<String>,
|
||||||
|
pub on_qr_url: Option<Box<dyn Fn(&str) + Send + Sync>>,
|
||||||
|
pub on_error: Option<Box<dyn Fn(&WeChatBotError) + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BotOptions {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
base_url: None,
|
||||||
|
cred_path: None,
|
||||||
|
on_qr_url: None,
|
||||||
|
on_error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WeChatBot is the main entry point.
|
||||||
|
pub struct WeChatBot {
|
||||||
|
client: Arc<ILinkClient>,
|
||||||
|
cdn: CdnClient,
|
||||||
|
credentials: RwLock<Option<Credentials>>,
|
||||||
|
context_tokens: RwLock<HashMap<String, String>>,
|
||||||
|
handlers: Mutex<Vec<MessageHandler>>,
|
||||||
|
cursor: RwLock<String>,
|
||||||
|
base_url: RwLock<String>,
|
||||||
|
cred_path: Option<String>,
|
||||||
|
stopped: RwLock<bool>,
|
||||||
|
on_qr_url: Option<Box<dyn Fn(&str) + Send + Sync>>,
|
||||||
|
on_error: Option<Box<dyn Fn(&WeChatBotError) + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WeChatBot {
|
||||||
|
/// Create a new bot instance.
|
||||||
|
pub fn new(opts: BotOptions) -> Self {
|
||||||
|
Self {
|
||||||
|
client: Arc::new(ILinkClient::new()),
|
||||||
|
cdn: CdnClient::new(),
|
||||||
|
credentials: RwLock::new(None),
|
||||||
|
context_tokens: RwLock::new(HashMap::new()),
|
||||||
|
handlers: Mutex::new(Vec::new()),
|
||||||
|
cursor: RwLock::new(String::new()),
|
||||||
|
base_url: RwLock::new(
|
||||||
|
opts.base_url
|
||||||
|
.unwrap_or_else(|| protocol::DEFAULT_BASE_URL.to_string()),
|
||||||
|
),
|
||||||
|
cred_path: opts.cred_path,
|
||||||
|
stopped: RwLock::new(false),
|
||||||
|
on_qr_url: opts.on_qr_url,
|
||||||
|
on_error: opts.on_error,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Maximum number of QR code refresh attempts before giving up.
|
||||||
|
const MAX_QR_REFRESH: u32 = 3;
|
||||||
|
/// Fixed API base URL for QR code requests.
|
||||||
|
const FIXED_QR_BASE_URL: &'static str = "https://ilinkai.weixin.qq.com";
|
||||||
|
|
||||||
|
/// Login via QR code. Returns credentials on success.
|
||||||
|
pub async fn login(&self, force: bool) -> Result<Credentials> {
|
||||||
|
let base_url = self.base_url.read().await.clone();
|
||||||
|
|
||||||
|
if !force {
|
||||||
|
if let Some(creds) = self.load_credentials().await? {
|
||||||
|
*self.credentials.write().await = Some(creds.clone());
|
||||||
|
*self.base_url.write().await = creds.base_url.clone();
|
||||||
|
info!("Loaded stored credentials for {}", creds.user_id);
|
||||||
|
return Ok(creds);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// QR code login flow
|
||||||
|
let mut qr_refresh_count = 0u32;
|
||||||
|
loop {
|
||||||
|
qr_refresh_count += 1;
|
||||||
|
if qr_refresh_count > Self::MAX_QR_REFRESH {
|
||||||
|
return Err(WeChatBotError::Auth(format!(
|
||||||
|
"QR code expired {} times — login aborted",
|
||||||
|
Self::MAX_QR_REFRESH
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let qr = self.client.get_qr_code(Self::FIXED_QR_BASE_URL).await?;
|
||||||
|
|
||||||
|
if let Some(ref cb) = self.on_qr_url {
|
||||||
|
cb(&qr.qrcode_img_content);
|
||||||
|
} else {
|
||||||
|
eprintln!("[wechatbot] Scan: {}", qr.qrcode_img_content);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut last_status = String::new();
|
||||||
|
let mut current_poll_base_url = Self::FIXED_QR_BASE_URL.to_string();
|
||||||
|
loop {
|
||||||
|
let status = self
|
||||||
|
.client
|
||||||
|
.poll_qr_status(¤t_poll_base_url, &qr.qrcode)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if status.status != last_status {
|
||||||
|
last_status = status.status.clone();
|
||||||
|
match status.status.as_str() {
|
||||||
|
"scaned" => info!("QR scanned — confirm in WeChat"),
|
||||||
|
"expired" => warn!("QR expired — requesting new one"),
|
||||||
|
"confirmed" => info!("Login confirmed"),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.status == "confirmed" {
|
||||||
|
let token = status
|
||||||
|
.bot_token
|
||||||
|
.ok_or_else(|| WeChatBotError::Auth("missing bot_token".into()))?;
|
||||||
|
let creds = Credentials {
|
||||||
|
token,
|
||||||
|
base_url: status.baseurl.unwrap_or_else(|| base_url.clone()),
|
||||||
|
account_id: status.ilink_bot_id.unwrap_or_default(),
|
||||||
|
user_id: status.ilink_user_id.unwrap_or_default(),
|
||||||
|
saved_at: Some(chrono_now()),
|
||||||
|
};
|
||||||
|
self.save_credentials(&creds).await?;
|
||||||
|
*self.credentials.write().await = Some(creds.clone());
|
||||||
|
*self.base_url.write().await = creds.base_url.clone();
|
||||||
|
return Ok(creds);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle IDC redirect
|
||||||
|
if status.status == "scaned_but_redirect" {
|
||||||
|
if let Some(ref host) = status.redirect_host {
|
||||||
|
current_poll_base_url = format!("https://{}", host);
|
||||||
|
info!("IDC redirect, switching polling host to {}", host);
|
||||||
|
} else {
|
||||||
|
warn!("Received scaned_but_redirect but redirect_host is missing");
|
||||||
|
}
|
||||||
|
sleep(Duration::from_secs(2)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.status == "expired" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
sleep(Duration::from_secs(2)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a message handler.
|
||||||
|
pub async fn on_message(&self, handler: MessageHandler) {
|
||||||
|
self.handlers.lock().await.push(handler);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reply to an incoming message.
|
||||||
|
pub async fn reply(&self, msg: &IncomingMessage, text: &str) -> Result<()> {
|
||||||
|
self.context_tokens
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(msg.user_id.clone(), msg.context_token.clone());
|
||||||
|
self.send_text(&msg.user_id, text, &msg.context_token).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send text to a user (needs prior context_token).
|
||||||
|
pub async fn send(&self, user_id: &str, text: &str) -> Result<()> {
|
||||||
|
let ct = self.context_tokens.read().await.get(user_id).cloned();
|
||||||
|
let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?;
|
||||||
|
self.send_text(user_id, text, &ct).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Show "typing..." indicator.
|
||||||
|
pub async fn send_typing(&self, user_id: &str) -> Result<()> {
|
||||||
|
let ct = self.context_tokens.read().await.get(user_id).cloned();
|
||||||
|
let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?;
|
||||||
|
let (base_url, token) = self.get_auth().await?;
|
||||||
|
let config = self
|
||||||
|
.client
|
||||||
|
.get_config(&base_url, &token, user_id, &ct)
|
||||||
|
.await?;
|
||||||
|
if let Some(ticket) = config.typing_ticket {
|
||||||
|
self.client
|
||||||
|
.send_typing(&base_url, &token, user_id, &ticket, 1)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reply with media content (image, video, or file).
|
||||||
|
pub async fn reply_media(&self, msg: &IncomingMessage, content: SendContent) -> Result<()> {
|
||||||
|
self.context_tokens
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(msg.user_id.clone(), msg.context_token.clone());
|
||||||
|
self.send_content(&msg.user_id, &msg.context_token, content)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send any content type to a user (needs prior context_token).
|
||||||
|
pub async fn send_media(&self, user_id: &str, content: SendContent) -> Result<()> {
|
||||||
|
let ct = self.context_tokens.read().await.get(user_id).cloned();
|
||||||
|
let ct = ct.ok_or_else(|| WeChatBotError::NoContext(user_id.to_string()))?;
|
||||||
|
self.send_content(user_id, &ct, content).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download media from an incoming message.
|
||||||
|
/// Returns None if the message has no media. Priority: image > file > video > voice.
|
||||||
|
pub async fn download(&self, msg: &IncomingMessage) -> Result<Option<DownloadedMedia>> {
|
||||||
|
if let Some(img) = msg.images.first() {
|
||||||
|
if let Some(ref media) = img.media {
|
||||||
|
let data = self.cdn.download(media, img.aes_key.as_deref()).await?;
|
||||||
|
return Ok(Some(DownloadedMedia {
|
||||||
|
data,
|
||||||
|
media_type: "image".into(),
|
||||||
|
file_name: None,
|
||||||
|
format: None,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(file) = msg.files.first() {
|
||||||
|
if let Some(ref media) = file.media {
|
||||||
|
let data = self.cdn.download(media, None).await?;
|
||||||
|
return Ok(Some(DownloadedMedia {
|
||||||
|
data,
|
||||||
|
media_type: "file".into(),
|
||||||
|
file_name: Some(file.file_name.clone().unwrap_or_else(|| "file.bin".into())),
|
||||||
|
format: None,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(video) = msg.videos.first() {
|
||||||
|
if let Some(ref media) = video.media {
|
||||||
|
let data = self.cdn.download(media, None).await?;
|
||||||
|
return Ok(Some(DownloadedMedia {
|
||||||
|
data,
|
||||||
|
media_type: "video".into(),
|
||||||
|
file_name: None,
|
||||||
|
format: None,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(voice) = msg.voices.first() {
|
||||||
|
if let Some(ref media) = voice.media {
|
||||||
|
let data = self.cdn.download(media, None).await?;
|
||||||
|
return Ok(Some(DownloadedMedia {
|
||||||
|
data,
|
||||||
|
media_type: "voice".into(),
|
||||||
|
file_name: None,
|
||||||
|
format: Some("silk".into()),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download and decrypt a raw CDN media reference.
|
||||||
|
pub async fn download_raw(
|
||||||
|
&self,
|
||||||
|
media: &CDNMedia,
|
||||||
|
aeskey_override: Option<&str>,
|
||||||
|
) -> Result<Vec<u8>> {
|
||||||
|
self.cdn.download(media, aeskey_override).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upload data to WeChat CDN without sending a message.
|
||||||
|
pub async fn upload(
|
||||||
|
&self,
|
||||||
|
data: &[u8],
|
||||||
|
user_id: &str,
|
||||||
|
media_type: i32,
|
||||||
|
) -> Result<UploadResult> {
|
||||||
|
let (base_url, token) = self.get_auth().await?;
|
||||||
|
self.cdn_upload(&base_url, &token, data, user_id, media_type)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the long-poll loop. Blocks until stopped.
|
||||||
|
pub async fn run(&self) -> Result<()> {
|
||||||
|
*self.stopped.write().await = false;
|
||||||
|
info!("Long-poll loop started");
|
||||||
|
let mut retry_delay = Duration::from_secs(1);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if *self.stopped.read().await {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (base_url, token) = self.get_auth().await?;
|
||||||
|
let cursor = self.cursor.read().await.clone();
|
||||||
|
|
||||||
|
match self.client.get_updates(&base_url, &token, &cursor).await {
|
||||||
|
Ok(updates) => {
|
||||||
|
if !updates.get_updates_buf.is_empty() {
|
||||||
|
*self.cursor.write().await = updates.get_updates_buf;
|
||||||
|
}
|
||||||
|
retry_delay = Duration::from_secs(1);
|
||||||
|
|
||||||
|
for wire in &updates.msgs {
|
||||||
|
self.remember_context(wire).await;
|
||||||
|
if let Some(incoming) = IncomingMessage::from_wire(wire) {
|
||||||
|
let handlers = self.handlers.lock().await;
|
||||||
|
for handler in handlers.iter() {
|
||||||
|
handler(&incoming);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) if e.is_session_expired() => {
|
||||||
|
warn!("Session expired — re-login required");
|
||||||
|
*self.context_tokens.write().await = HashMap::new();
|
||||||
|
*self.cursor.write().await = String::new();
|
||||||
|
if let Err(e) = self.login(true).await {
|
||||||
|
self.report_error(&e);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
self.report_error(&e);
|
||||||
|
sleep(retry_delay).await;
|
||||||
|
retry_delay = std::cmp::min(retry_delay * 2, Duration::from_secs(10));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Long-poll loop stopped");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stop the bot.
|
||||||
|
pub async fn stop(&self) {
|
||||||
|
*self.stopped.write().await = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- internal media ---
|
||||||
|
|
||||||
|
fn send_content<'a>(
|
||||||
|
&'a self,
|
||||||
|
user_id: &'a str,
|
||||||
|
context_token: &'a str,
|
||||||
|
content: SendContent,
|
||||||
|
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
|
||||||
|
Box::pin(async move {
|
||||||
|
let (base_url, token) = self.get_auth().await?;
|
||||||
|
match content {
|
||||||
|
SendContent::Text(text) => self.send_text(user_id, &text, context_token).await,
|
||||||
|
SendContent::Image { data, caption } => {
|
||||||
|
let result = self
|
||||||
|
.cdn_upload(&base_url, &token, &data, user_id, 1)
|
||||||
|
.await?;
|
||||||
|
let mut items = Vec::new();
|
||||||
|
if let Some(cap) = caption {
|
||||||
|
items.push(json!({"type": 1, "text_item": {"text": cap}}));
|
||||||
|
}
|
||||||
|
items.push(json!({"type": 2, "image_item": {
|
||||||
|
"media": cdn_media_json(&result.media),
|
||||||
|
"mid_size": result.encrypted_file_size,
|
||||||
|
}}));
|
||||||
|
let msg = protocol::build_media_message(user_id, context_token, items);
|
||||||
|
self.client.send_message(&base_url, &token, &msg).await
|
||||||
|
}
|
||||||
|
SendContent::Video { data, caption } => {
|
||||||
|
let result = self
|
||||||
|
.cdn_upload(&base_url, &token, &data, user_id, 2)
|
||||||
|
.await?;
|
||||||
|
let mut items = Vec::new();
|
||||||
|
if let Some(cap) = caption {
|
||||||
|
items.push(json!({"type": 1, "text_item": {"text": cap}}));
|
||||||
|
}
|
||||||
|
items.push(json!({"type": 5, "video_item": {
|
||||||
|
"media": cdn_media_json(&result.media),
|
||||||
|
"video_size": result.encrypted_file_size,
|
||||||
|
}}));
|
||||||
|
let msg = protocol::build_media_message(user_id, context_token, items);
|
||||||
|
self.client.send_message(&base_url, &token, &msg).await
|
||||||
|
}
|
||||||
|
SendContent::File {
|
||||||
|
data,
|
||||||
|
file_name,
|
||||||
|
caption,
|
||||||
|
} => {
|
||||||
|
let cat = categorize_by_extension(&file_name);
|
||||||
|
match cat {
|
||||||
|
"image" => {
|
||||||
|
self.send_content(
|
||||||
|
user_id,
|
||||||
|
context_token,
|
||||||
|
SendContent::Image { data, caption },
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
"video" => {
|
||||||
|
self.send_content(
|
||||||
|
user_id,
|
||||||
|
context_token,
|
||||||
|
SendContent::Video { data, caption },
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if let Some(cap) = caption {
|
||||||
|
self.send_text(user_id, &cap, context_token).await?;
|
||||||
|
}
|
||||||
|
let data_len = data.len();
|
||||||
|
let result = self
|
||||||
|
.cdn_upload(&base_url, &token, &data, user_id, 3)
|
||||||
|
.await?;
|
||||||
|
let items = vec![json!({"type": 4, "file_item": {
|
||||||
|
"media": cdn_media_json(&result.media),
|
||||||
|
"file_name": file_name,
|
||||||
|
"len": data_len.to_string(),
|
||||||
|
}})];
|
||||||
|
let msg = protocol::build_media_message(user_id, context_token, items);
|
||||||
|
self.client.send_message(&base_url, &token, &msg).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn cdn_upload(
|
||||||
|
&self,
|
||||||
|
base_url: &str,
|
||||||
|
token: &str,
|
||||||
|
data: &[u8],
|
||||||
|
user_id: &str,
|
||||||
|
media_type: i32,
|
||||||
|
) -> Result<UploadResult> {
|
||||||
|
let aes_key = crypto::generate_aes_key();
|
||||||
|
let ciphertext = crypto::encrypt_aes_ecb(data, &aes_key);
|
||||||
|
|
||||||
|
let mut filekey_buf = [0u8; 16];
|
||||||
|
rand::rng().fill_bytes(&mut filekey_buf);
|
||||||
|
let filekey = hex::encode(filekey_buf);
|
||||||
|
|
||||||
|
let raw_md5 = hex::encode(Md5::digest(data));
|
||||||
|
|
||||||
|
let params = protocol::GetUploadUrlParams {
|
||||||
|
filekey: filekey.clone(),
|
||||||
|
media_type,
|
||||||
|
to_user_id: user_id.to_string(),
|
||||||
|
rawsize: data.len(),
|
||||||
|
rawfilemd5: raw_md5,
|
||||||
|
filesize: ciphertext.len(),
|
||||||
|
no_need_thumb: true,
|
||||||
|
aeskey: crypto::encode_aes_key_hex(&aes_key),
|
||||||
|
};
|
||||||
|
|
||||||
|
let upload_resp = self.client.get_upload_url(base_url, token, ¶ms).await?;
|
||||||
|
let upload_param = upload_resp.upload_param.ok_or_else(|| {
|
||||||
|
WeChatBotError::Media("getuploadurl did not return upload_param".into())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let upload_url =
|
||||||
|
protocol::build_cdn_upload_url(protocol::CDN_BASE_URL, &upload_param, &filekey);
|
||||||
|
|
||||||
|
let encrypted_file_size = ciphertext.len();
|
||||||
|
|
||||||
|
let encrypt_query_param = self.client.upload_to_cdn(&upload_url, &ciphertext).await?;
|
||||||
|
|
||||||
|
Ok(UploadResult {
|
||||||
|
media: CDNMedia {
|
||||||
|
encrypt_query_param,
|
||||||
|
aes_key: crypto::encode_aes_key_base64(&aes_key),
|
||||||
|
encrypt_type: Some(1),
|
||||||
|
full_url: None,
|
||||||
|
},
|
||||||
|
aes_key,
|
||||||
|
encrypted_file_size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- internal text ---
|
||||||
|
|
||||||
|
async fn send_text(&self, user_id: &str, text: &str, context_token: &str) -> Result<()> {
|
||||||
|
let (base_url, token) = self.get_auth().await?;
|
||||||
|
for chunk in chunk_text(text, 4000) {
|
||||||
|
let msg = protocol::build_text_message(user_id, context_token, &chunk);
|
||||||
|
self.client.send_message(&base_url, &token, &msg).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn remember_context(&self, wire: &WireMessage) {
|
||||||
|
let user_id = if wire.message_type == MessageType::User {
|
||||||
|
&wire.from_user_id
|
||||||
|
} else {
|
||||||
|
&wire.to_user_id
|
||||||
|
};
|
||||||
|
if !user_id.is_empty() && !wire.context_token.is_empty() {
|
||||||
|
self.context_tokens
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(user_id.clone(), wire.context_token.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_auth(&self) -> Result<(String, String)> {
|
||||||
|
let creds = self.credentials.read().await;
|
||||||
|
let creds = creds
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| WeChatBotError::Auth("not logged in".into()))?;
|
||||||
|
Ok((creds.base_url.clone(), creds.token.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_credentials(&self) -> Result<Option<Credentials>> {
|
||||||
|
let path = self.cred_path.clone().unwrap_or_else(default_cred_path);
|
||||||
|
match tokio::fs::read_to_string(&path).await {
|
||||||
|
Ok(data) => Ok(Some(serde_json::from_str(&data)?)),
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
|
||||||
|
Err(e) => Err(e.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn save_credentials(&self, creds: &Credentials) -> Result<()> {
|
||||||
|
let path = self.cred_path.clone().unwrap_or_else(default_cred_path);
|
||||||
|
let dir = std::path::Path::new(&path).parent().unwrap();
|
||||||
|
tokio::fs::create_dir_all(dir).await?;
|
||||||
|
let data = serde_json::to_string_pretty(creds)?;
|
||||||
|
tokio::fs::write(&path, format!("{}\n", data)).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn report_error(&self, err: &WeChatBotError) {
|
||||||
|
error!("{}", err);
|
||||||
|
if let Some(ref cb) = self.on_error {
|
||||||
|
cb(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Content to send via reply_media / send_media.
|
||||||
|
pub enum SendContent {
|
||||||
|
Text(String),
|
||||||
|
Image {
|
||||||
|
data: Vec<u8>,
|
||||||
|
caption: Option<String>,
|
||||||
|
},
|
||||||
|
Video {
|
||||||
|
data: Vec<u8>,
|
||||||
|
caption: Option<String>,
|
||||||
|
},
|
||||||
|
File {
|
||||||
|
data: Vec<u8>,
|
||||||
|
file_name: String,
|
||||||
|
caption: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cdn_media_json(media: &CDNMedia) -> serde_json::Value {
|
||||||
|
let mut v = json!({
|
||||||
|
"encrypt_query_param": media.encrypt_query_param,
|
||||||
|
"aes_key": media.aes_key,
|
||||||
|
});
|
||||||
|
if let Some(et) = media.encrypt_type {
|
||||||
|
v["encrypt_type"] = json!(et);
|
||||||
|
}
|
||||||
|
v
|
||||||
|
}
|
||||||
|
|
||||||
|
fn categorize_by_extension(filename: &str) -> &'static str {
|
||||||
|
let ext = Path::new(filename)
|
||||||
|
.extension()
|
||||||
|
.and_then(|e| e.to_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_lowercase();
|
||||||
|
match ext.as_str() {
|
||||||
|
"png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" | "svg" => "image",
|
||||||
|
"mp4" | "mov" | "webm" | "mkv" | "avi" => "video",
|
||||||
|
_ => "file",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chunk_text(text: &str, limit: usize) -> Vec<String> {
|
||||||
|
if text.len() <= limit {
|
||||||
|
return vec![text.to_string()];
|
||||||
|
}
|
||||||
|
let mut chunks = Vec::new();
|
||||||
|
let mut remaining = text;
|
||||||
|
while !remaining.is_empty() {
|
||||||
|
if remaining.len() <= limit {
|
||||||
|
chunks.push(remaining.to_string());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let window = &remaining[..limit];
|
||||||
|
let cut = window
|
||||||
|
.rfind("\n\n")
|
||||||
|
.filter(|&i| i > limit * 3 / 10)
|
||||||
|
.map(|i| i + 2)
|
||||||
|
.or_else(|| {
|
||||||
|
window
|
||||||
|
.rfind('\n')
|
||||||
|
.filter(|&i| i > limit * 3 / 10)
|
||||||
|
.map(|i| i + 1)
|
||||||
|
})
|
||||||
|
.or_else(|| {
|
||||||
|
window
|
||||||
|
.rfind(' ')
|
||||||
|
.filter(|&i| i > limit * 3 / 10)
|
||||||
|
.map(|i| i + 1)
|
||||||
|
})
|
||||||
|
.unwrap_or(limit);
|
||||||
|
chunks.push(remaining[..cut].to_string());
|
||||||
|
remaining = &remaining[cut..];
|
||||||
|
}
|
||||||
|
if chunks.is_empty() {
|
||||||
|
vec![String::new()]
|
||||||
|
} else {
|
||||||
|
chunks
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_cred_path() -> String {
|
||||||
|
let home = dirs_next::home_dir().unwrap_or_else(|| ".".into());
|
||||||
|
home.join(".wechatbot")
|
||||||
|
.join("credentials.json")
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chrono_now() -> String {
|
||||||
|
// Simple ISO 8601 without chrono dependency
|
||||||
|
let dur = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap();
|
||||||
|
format!("{}Z", dur.as_secs())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chunk_text_short() {
|
||||||
|
let chunks = chunk_text("hello", 100);
|
||||||
|
assert_eq!(chunks, vec!["hello"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chunk_text_empty() {
|
||||||
|
let chunks = chunk_text("", 100);
|
||||||
|
assert_eq!(chunks, vec![""]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chunk_text_splits_on_paragraph() {
|
||||||
|
let text = "aaaa\n\nbbbb";
|
||||||
|
let chunks = chunk_text(text, 7);
|
||||||
|
assert_eq!(chunks, vec!["aaaa\n\n", "bbbb"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chunk_text_splits_on_newline() {
|
||||||
|
let text = "aaaa\nbbbb";
|
||||||
|
let chunks = chunk_text(text, 7);
|
||||||
|
assert_eq!(chunks, vec!["aaaa\n", "bbbb"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chunk_text_exact_limit() {
|
||||||
|
let text = "abcdef";
|
||||||
|
let chunks = chunk_text(text, 6);
|
||||||
|
assert_eq!(chunks, vec!["abcdef"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_cred_path_not_empty() {
|
||||||
|
let path = default_cred_path();
|
||||||
|
assert!(!path.is_empty());
|
||||||
|
assert!(path.contains(".wechatbot"));
|
||||||
|
assert!(path.contains("credentials.json"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn categorize_image_extensions() {
|
||||||
|
assert_eq!(categorize_by_extension("photo.png"), "image");
|
||||||
|
assert_eq!(categorize_by_extension("photo.JPG"), "image");
|
||||||
|
assert_eq!(categorize_by_extension("anim.gif"), "image");
|
||||||
|
assert_eq!(categorize_by_extension("pic.webp"), "image");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn categorize_video_extensions() {
|
||||||
|
assert_eq!(categorize_by_extension("clip.mp4"), "video");
|
||||||
|
assert_eq!(categorize_by_extension("clip.MOV"), "video");
|
||||||
|
assert_eq!(categorize_by_extension("movie.webm"), "video");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn categorize_file_extensions() {
|
||||||
|
assert_eq!(categorize_by_extension("report.pdf"), "file");
|
||||||
|
assert_eq!(categorize_by_extension("data.csv"), "file");
|
||||||
|
assert_eq!(categorize_by_extension("noext"), "file");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cdn_media_json_with_encrypt_type() {
|
||||||
|
let media = CDNMedia {
|
||||||
|
encrypt_query_param: "param=1".to_string(),
|
||||||
|
aes_key: "key123".to_string(),
|
||||||
|
encrypt_type: Some(1),
|
||||||
|
full_url: None,
|
||||||
|
};
|
||||||
|
let j = cdn_media_json(&media);
|
||||||
|
assert_eq!(j["encrypt_query_param"], "param=1");
|
||||||
|
assert_eq!(j["aes_key"], "key123");
|
||||||
|
assert_eq!(j["encrypt_type"], 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cdn_media_json_without_encrypt_type() {
|
||||||
|
let media = CDNMedia {
|
||||||
|
encrypt_query_param: "p".to_string(),
|
||||||
|
aes_key: "k".to_string(),
|
||||||
|
encrypt_type: None,
|
||||||
|
full_url: None,
|
||||||
|
};
|
||||||
|
let j = cdn_media_json(&media);
|
||||||
|
assert!(j.get("encrypt_type").is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
138
vendor/wechatbot/src/cdn.rs
vendored
Normal file
138
vendor/wechatbot/src/cdn.rs
vendored
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
//! Low-level CDN client for direct media download.
|
||||||
|
//!
|
||||||
|
//! [`CdnClient`] is a primitive layer that can be used independently of
|
||||||
|
//! [`WeChatBot`](crate::WeChatBot), e.g. when you drive `get_updates` yourself
|
||||||
|
//! via [`ILinkClient`](crate::protocol::ILinkClient) and only need decryption
|
||||||
|
//! for a specific attachment.
|
||||||
|
//!
|
||||||
|
//! Modeled after [`teloxide_core::Bot`]: wraps a [`reqwest::Client`] so
|
||||||
|
//! connection pool / TLS session / DNS cache are reused across calls, and is
|
||||||
|
//! cheap to [`Clone`].
|
||||||
|
|
||||||
|
use reqwest::Client;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::crypto;
|
||||||
|
use crate::error::{Result, WeChatBotError};
|
||||||
|
use crate::protocol::CDN_BASE_URL;
|
||||||
|
use crate::types::CDNMedia;
|
||||||
|
|
||||||
|
/// HTTP client for WeChat CDN media endpoints.
|
||||||
|
///
|
||||||
|
/// Cheap to [`Clone`] — shares the underlying [`reqwest::Client`], which uses
|
||||||
|
/// an `Arc` internally.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use wechatbot::{CdnClient, CDNMedia};
|
||||||
|
///
|
||||||
|
/// # async fn demo(media: CDNMedia) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
/// let cdn = CdnClient::new();
|
||||||
|
/// let bytes = cdn.download(&media, None).await?;
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CdnClient {
|
||||||
|
http: Client,
|
||||||
|
base_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CdnClient {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CdnClient {
|
||||||
|
/// Create a [`CdnClient`] with a fresh internal [`reqwest::Client`].
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::with_client(Client::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a [`CdnClient`] that reuses an existing [`reqwest::Client`].
|
||||||
|
///
|
||||||
|
/// Useful when the caller already maintains a shared HTTP client with
|
||||||
|
/// custom proxy / TLS / timeout configuration.
|
||||||
|
pub fn with_client(http: Client) -> Self {
|
||||||
|
Self {
|
||||||
|
http,
|
||||||
|
base_url: CDN_BASE_URL.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Override the CDN base URL (defaults to [`CDN_BASE_URL`]).
|
||||||
|
///
|
||||||
|
/// Primarily intended for tests and regional endpoints.
|
||||||
|
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
|
||||||
|
self.base_url = base_url.into();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download and AES-decrypt a CDN media object.
|
||||||
|
///
|
||||||
|
/// `aes_key_override` is used when the decryption key is attached to the
|
||||||
|
/// message metadata (e.g. [`ImageContent::aes_key`](crate::ImageContent::aes_key))
|
||||||
|
/// rather than embedded in the media's own `aes_key` field.
|
||||||
|
pub async fn download(
|
||||||
|
&self,
|
||||||
|
media: &CDNMedia,
|
||||||
|
aes_key_override: Option<&str>,
|
||||||
|
) -> Result<Vec<u8>> {
|
||||||
|
let download_url = format!(
|
||||||
|
"{}/download?encrypted_query_param={}",
|
||||||
|
self.base_url,
|
||||||
|
urlencoding::encode(&media.encrypt_query_param)
|
||||||
|
);
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.http
|
||||||
|
.get(&download_url)
|
||||||
|
.timeout(Duration::from_secs(60))
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(WeChatBotError::Media(format!(
|
||||||
|
"CDN download failed: HTTP {}",
|
||||||
|
resp.status()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ciphertext = resp.bytes().await?.to_vec();
|
||||||
|
|
||||||
|
let key_source = aes_key_override.unwrap_or(&media.aes_key);
|
||||||
|
if key_source.is_empty() {
|
||||||
|
return Err(WeChatBotError::Media("no AES key available".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let aes_key = crypto::decode_aes_key(key_source)?;
|
||||||
|
crypto::decrypt_aes_ecb(&ciphertext, &aes_key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_and_new_equivalent() {
|
||||||
|
let a = CdnClient::default();
|
||||||
|
let b = CdnClient::new();
|
||||||
|
assert_eq!(a.base_url, b.base_url);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn with_base_url_overrides() {
|
||||||
|
let c = CdnClient::new().with_base_url("https://example.test/cdn");
|
||||||
|
assert_eq!(c.base_url, "https://example.test/cdn");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn clone_is_cheap_and_preserves_config() {
|
||||||
|
let c = CdnClient::new().with_base_url("https://x.y/z");
|
||||||
|
let cloned = c.clone();
|
||||||
|
assert_eq!(c.base_url, cloned.base_url);
|
||||||
|
}
|
||||||
|
}
|
||||||
148
vendor/wechatbot/src/crypto.rs
vendored
Normal file
148
vendor/wechatbot/src/crypto.rs
vendored
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
//! AES-128-ECB encryption for WeChat CDN media files.
|
||||||
|
|
||||||
|
use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit};
|
||||||
|
use aes::Aes128;
|
||||||
|
use base64::Engine;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
use crate::error::{Result, WeChatBotError};
|
||||||
|
|
||||||
|
/// Encrypt plaintext with AES-128-ECB and PKCS7 padding.
|
||||||
|
pub fn encrypt_aes_ecb(plaintext: &[u8], key: &[u8; 16]) -> Vec<u8> {
|
||||||
|
let cipher = Aes128::new(key.into());
|
||||||
|
let padded = pkcs7_pad(plaintext, 16);
|
||||||
|
let mut ciphertext = padded;
|
||||||
|
for chunk in ciphertext.chunks_exact_mut(16) {
|
||||||
|
cipher.encrypt_block(chunk.into());
|
||||||
|
}
|
||||||
|
ciphertext
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypt AES-128-ECB ciphertext and remove PKCS7 padding.
|
||||||
|
pub fn decrypt_aes_ecb(ciphertext: &[u8], key: &[u8; 16]) -> Result<Vec<u8>> {
|
||||||
|
if ciphertext.len() % 16 != 0 {
|
||||||
|
return Err(WeChatBotError::Media(
|
||||||
|
"ciphertext length not a multiple of 16".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let cipher = Aes128::new(key.into());
|
||||||
|
let mut plaintext = ciphertext.to_vec();
|
||||||
|
for chunk in plaintext.chunks_exact_mut(16) {
|
||||||
|
cipher.decrypt_block(chunk.into());
|
||||||
|
}
|
||||||
|
pkcs7_unpad(&plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a random 16-byte AES key.
|
||||||
|
pub fn generate_aes_key() -> [u8; 16] {
|
||||||
|
let mut key = [0u8; 16];
|
||||||
|
rand::rng().fill_bytes(&mut key);
|
||||||
|
key
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate encrypted size with PKCS7 padding.
|
||||||
|
pub fn encrypted_size(raw_size: usize) -> usize {
|
||||||
|
((raw_size + 1 + 15) / 16) * 16
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode an aes_key from the protocol (handles all three formats).
|
||||||
|
pub fn decode_aes_key(encoded: &str) -> Result<[u8; 16]> {
|
||||||
|
// Direct hex (32 chars)
|
||||||
|
if encoded.len() == 32 && encoded.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||||
|
let bytes =
|
||||||
|
hex::decode(encoded).map_err(|e| WeChatBotError::Media(format!("hex decode: {e}")))?;
|
||||||
|
return bytes_to_key(&bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base64 decode
|
||||||
|
let decoded = base64::engine::general_purpose::STANDARD
|
||||||
|
.decode(encoded)
|
||||||
|
.or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(encoded))
|
||||||
|
.map_err(|e| WeChatBotError::Media(format!("base64 decode: {e}")))?;
|
||||||
|
|
||||||
|
if decoded.len() == 16 {
|
||||||
|
return bytes_to_key(&decoded);
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded.len() == 32 {
|
||||||
|
let hex_str = std::str::from_utf8(&decoded)
|
||||||
|
.map_err(|_| WeChatBotError::Media("decoded key is not UTF-8".into()))?;
|
||||||
|
if hex_str.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||||
|
let bytes = hex::decode(hex_str)
|
||||||
|
.map_err(|e| WeChatBotError::Media(format!("hex decode: {e}")))?;
|
||||||
|
return bytes_to_key(&bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(WeChatBotError::Media(format!(
|
||||||
|
"unexpected decoded key length: {}",
|
||||||
|
decoded.len()
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode an AES key as hex (for getuploadurl).
|
||||||
|
pub fn encode_aes_key_hex(key: &[u8; 16]) -> String {
|
||||||
|
hex::encode(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode an AES key as base64(hex) (for CDNMedia.aes_key).
|
||||||
|
pub fn encode_aes_key_base64(key: &[u8; 16]) -> String {
|
||||||
|
base64::engine::general_purpose::STANDARD.encode(hex::encode(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bytes_to_key(bytes: &[u8]) -> Result<[u8; 16]> {
|
||||||
|
bytes
|
||||||
|
.try_into()
|
||||||
|
.map_err(|_| WeChatBotError::Media(format!("key length {} != 16", bytes.len())))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pkcs7_pad(data: &[u8], block_size: usize) -> Vec<u8> {
|
||||||
|
let padding = block_size - (data.len() % block_size);
|
||||||
|
let mut result = data.to_vec();
|
||||||
|
result.extend(std::iter::repeat(padding as u8).take(padding));
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pkcs7_unpad(data: &[u8]) -> Result<Vec<u8>> {
|
||||||
|
if data.is_empty() {
|
||||||
|
return Err(WeChatBotError::Media("empty data".into()));
|
||||||
|
}
|
||||||
|
let padding = *data.last().unwrap() as usize;
|
||||||
|
if padding == 0 || padding > data.len() || padding > 16 {
|
||||||
|
return Err(WeChatBotError::Media("invalid PKCS7 padding".into()));
|
||||||
|
}
|
||||||
|
Ok(data[..data.len() - padding].to_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn round_trip() {
|
||||||
|
let key = generate_aes_key();
|
||||||
|
let plaintext = b"Hello, WeChat!";
|
||||||
|
let ct = encrypt_aes_ecb(plaintext, &key);
|
||||||
|
let pt = decrypt_aes_ecb(&ct, &key).unwrap();
|
||||||
|
assert_eq!(pt, plaintext);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn encrypted_size_calc() {
|
||||||
|
assert_eq!(encrypted_size(14), 16);
|
||||||
|
assert_eq!(encrypted_size(16), 32);
|
||||||
|
assert_eq!(encrypted_size(100), 112);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decode_direct_hex() {
|
||||||
|
let key = decode_aes_key("00112233445566778899aabbccddeeff").unwrap();
|
||||||
|
assert_eq!(key.len(), 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decode_base64_raw() {
|
||||||
|
let key = decode_aes_key("ABEiM0RVZneImaq7zN3u/w==").unwrap();
|
||||||
|
assert_eq!(key.len(), 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
93
vendor/wechatbot/src/error.rs
vendored
Normal file
93
vendor/wechatbot/src/error.rs
vendored
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// Errors that can occur in the SDK.
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum WeChatBotError {
|
||||||
|
#[error("API error: {message} (http={http_status}, errcode={errcode})")]
|
||||||
|
Api {
|
||||||
|
message: String,
|
||||||
|
http_status: u16,
|
||||||
|
errcode: i32,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("Auth error: {0}")]
|
||||||
|
Auth(String),
|
||||||
|
|
||||||
|
#[error("No context_token for user {0}")]
|
||||||
|
NoContext(String),
|
||||||
|
|
||||||
|
#[error("Media error: {0}")]
|
||||||
|
Media(String),
|
||||||
|
|
||||||
|
#[error("Transport error: {0}")]
|
||||||
|
Transport(#[from] reqwest::Error),
|
||||||
|
|
||||||
|
#[error("JSON error: {0}")]
|
||||||
|
Json(#[from] serde_json::Error),
|
||||||
|
|
||||||
|
#[error("IO error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
|
||||||
|
#[error("{0}")]
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WeChatBotError {
|
||||||
|
/// Returns true if this is a session-expired error (errcode -14).
|
||||||
|
pub fn is_session_expired(&self) -> bool {
|
||||||
|
matches!(self, WeChatBotError::Api { errcode: -14, .. })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, WeChatBotError>;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn session_expired_true() {
|
||||||
|
let err = WeChatBotError::Api {
|
||||||
|
message: "session expired".to_string(),
|
||||||
|
http_status: 200,
|
||||||
|
errcode: -14,
|
||||||
|
};
|
||||||
|
assert!(err.is_session_expired());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn session_expired_false() {
|
||||||
|
let err = WeChatBotError::Api {
|
||||||
|
message: "other error".to_string(),
|
||||||
|
http_status: 400,
|
||||||
|
errcode: -1,
|
||||||
|
};
|
||||||
|
assert!(!err.is_session_expired());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn non_api_not_session_expired() {
|
||||||
|
let err = WeChatBotError::Auth("test".to_string());
|
||||||
|
assert!(!err.is_session_expired());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_display() {
|
||||||
|
let err = WeChatBotError::Api {
|
||||||
|
message: "bad request".to_string(),
|
||||||
|
http_status: 400,
|
||||||
|
errcode: -1,
|
||||||
|
};
|
||||||
|
let msg = format!("{}", err);
|
||||||
|
assert!(msg.contains("bad request"));
|
||||||
|
assert!(msg.contains("400"));
|
||||||
|
assert!(msg.contains("-1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_context_error() {
|
||||||
|
let err = WeChatBotError::NoContext("user123".to_string());
|
||||||
|
let msg = format!("{}", err);
|
||||||
|
assert!(msg.contains("user123"));
|
||||||
|
}
|
||||||
|
}
|
||||||
38
vendor/wechatbot/src/lib.rs
vendored
Normal file
38
vendor/wechatbot/src/lib.rs
vendored
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
//! # wechatbot
|
||||||
|
//!
|
||||||
|
//! WeChat iLink Bot SDK for Rust.
|
||||||
|
//!
|
||||||
|
//! ## Quick Start
|
||||||
|
//!
|
||||||
|
//! ```rust,no_run
|
||||||
|
//! use wechatbot::{WeChatBot, BotOptions};
|
||||||
|
//!
|
||||||
|
//! #[tokio::main]
|
||||||
|
//! async fn main() {
|
||||||
|
//! let bot = WeChatBot::new(BotOptions::default());
|
||||||
|
//! bot.login(false).await.unwrap();
|
||||||
|
//!
|
||||||
|
//! bot.on_message(Box::new(|msg| {
|
||||||
|
//! println!("{}: {}", msg.user_id, msg.text);
|
||||||
|
//! })).await;
|
||||||
|
//!
|
||||||
|
//! bot.run().await.unwrap();
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
pub mod bot;
|
||||||
|
pub mod cdn;
|
||||||
|
pub mod crypto;
|
||||||
|
pub mod error;
|
||||||
|
pub mod protocol;
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
pub use bot::{BotOptions, MessageHandler, SendContent, WeChatBot};
|
||||||
|
pub use cdn::CdnClient;
|
||||||
|
pub use crypto::{
|
||||||
|
decode_aes_key, decrypt_aes_ecb, decrypt_aes_ecb as download_decrypt, encode_aes_key_base64,
|
||||||
|
encode_aes_key_hex, encrypt_aes_ecb, generate_aes_key,
|
||||||
|
};
|
||||||
|
pub use error::{Result, WeChatBotError};
|
||||||
|
pub use protocol::{build_cdn_upload_url, GetUploadUrlParams, GetUploadUrlResponse};
|
||||||
|
pub use types::*;
|
||||||
407
vendor/wechatbot/src/protocol.rs
vendored
Normal file
407
vendor/wechatbot/src/protocol.rs
vendored
Normal file
@ -0,0 +1,407 @@
|
|||||||
|
//! Raw iLink Bot API HTTP calls.
|
||||||
|
|
||||||
|
use base64::Engine;
|
||||||
|
use rand::Rng;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::time::Duration;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::error::{Result, WeChatBotError};
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
use crate::types::*;
|
||||||
|
|
||||||
|
pub const DEFAULT_BASE_URL: &str = "https://ilinkai.weixin.qq.com";
|
||||||
|
pub const CDN_BASE_URL: &str = "https://novac2c.cdn.weixin.qq.com/c2c";
|
||||||
|
pub const CHANNEL_VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||||
|
|
||||||
|
/// iLink-App-Id header value.
|
||||||
|
const ILINK_APP_ID: &str = "bot";
|
||||||
|
|
||||||
|
/// Build iLink-App-ClientVersion from the crate version (0x00MMNNPP).
|
||||||
|
fn build_client_version() -> String {
|
||||||
|
let version = env!("CARGO_PKG_VERSION");
|
||||||
|
let parts: Vec<u32> = version.split('.').filter_map(|p| p.parse().ok()).collect();
|
||||||
|
let major = parts.first().copied().unwrap_or(0) & 0xff;
|
||||||
|
let minor = parts.get(1).copied().unwrap_or(0) & 0xff;
|
||||||
|
let patch = parts.get(2).copied().unwrap_or(0) & 0xff;
|
||||||
|
let num = (major << 16) | (minor << 8) | patch;
|
||||||
|
num.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate the X-WECHAT-UIN header value.
|
||||||
|
pub fn random_wechat_uin() -> String {
|
||||||
|
let mut buf = [0u8; 4];
|
||||||
|
rand::rng().fill_bytes(&mut buf);
|
||||||
|
let val = u32::from_be_bytes(buf);
|
||||||
|
base64::engine::general_purpose::STANDARD.encode(val.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// QR code response.
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct QrCodeResponse {
|
||||||
|
pub qrcode: String,
|
||||||
|
pub qrcode_img_content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// QR status response.
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct QrStatusResponse {
|
||||||
|
pub status: String,
|
||||||
|
pub bot_token: Option<String>,
|
||||||
|
pub ilink_bot_id: Option<String>,
|
||||||
|
pub ilink_user_id: Option<String>,
|
||||||
|
pub baseurl: Option<String>,
|
||||||
|
/// New host to redirect polling to when status is "scaned_but_redirect".
|
||||||
|
pub redirect_host: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get updates response.
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct GetUpdatesResponse {
|
||||||
|
#[serde(default)]
|
||||||
|
pub ret: i32,
|
||||||
|
#[serde(default)]
|
||||||
|
pub msgs: Vec<WireMessage>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub get_updates_buf: String,
|
||||||
|
pub errcode: Option<i32>,
|
||||||
|
pub errmsg: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get config response.
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct GetConfigResponse {
|
||||||
|
pub typing_ticket: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Low-level iLink API client.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ILinkClient {
|
||||||
|
http: Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ILinkClient {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
http: Client::builder()
|
||||||
|
.timeout(Duration::from_secs(45))
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_qr_code(&self, base_url: &str) -> Result<QrCodeResponse> {
|
||||||
|
let url = format!("{}/ilink/bot/get_bot_qrcode?bot_type=3", base_url);
|
||||||
|
let resp = self
|
||||||
|
.http
|
||||||
|
.get(&url)
|
||||||
|
.header("iLink-App-Id", ILINK_APP_ID)
|
||||||
|
.header("iLink-App-ClientVersion", build_client_version())
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
Ok(resp.json().await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn poll_qr_status(&self, base_url: &str, qrcode: &str) -> Result<QrStatusResponse> {
|
||||||
|
let url = format!(
|
||||||
|
"{}/ilink/bot/get_qrcode_status?qrcode={}",
|
||||||
|
base_url,
|
||||||
|
urlencoding::encode(qrcode)
|
||||||
|
);
|
||||||
|
let resp = self
|
||||||
|
.http
|
||||||
|
.get(&url)
|
||||||
|
.header("iLink-App-Id", ILINK_APP_ID)
|
||||||
|
.header("iLink-App-ClientVersion", build_client_version())
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
Ok(resp.json().await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_updates(
|
||||||
|
&self,
|
||||||
|
base_url: &str,
|
||||||
|
token: &str,
|
||||||
|
cursor: &str,
|
||||||
|
) -> Result<GetUpdatesResponse> {
|
||||||
|
let body = json!({
|
||||||
|
"get_updates_buf": cursor,
|
||||||
|
"base_info": { "channel_version": CHANNEL_VERSION }
|
||||||
|
});
|
||||||
|
let resp = self
|
||||||
|
.api_post(base_url, "/ilink/bot/getupdates", token, &body, 45)
|
||||||
|
.await?;
|
||||||
|
let result: GetUpdatesResponse = serde_json::from_value(resp)?;
|
||||||
|
if result.ret != 0 || result.errcode.is_some_and(|c| c != 0) {
|
||||||
|
let code = result.errcode.unwrap_or(result.ret);
|
||||||
|
let msg = result
|
||||||
|
.errmsg
|
||||||
|
.unwrap_or_else(|| format!("ret={}", result.ret));
|
||||||
|
return Err(WeChatBotError::Api {
|
||||||
|
message: msg,
|
||||||
|
http_status: 200,
|
||||||
|
errcode: code,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_message(&self, base_url: &str, token: &str, msg: &Value) -> Result<()> {
|
||||||
|
let body = json!({
|
||||||
|
"msg": msg,
|
||||||
|
"base_info": { "channel_version": CHANNEL_VERSION }
|
||||||
|
});
|
||||||
|
self.api_post(base_url, "/ilink/bot/sendmessage", token, &body, 15)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_config(
|
||||||
|
&self,
|
||||||
|
base_url: &str,
|
||||||
|
token: &str,
|
||||||
|
user_id: &str,
|
||||||
|
context_token: &str,
|
||||||
|
) -> Result<GetConfigResponse> {
|
||||||
|
let body = json!({
|
||||||
|
"ilink_user_id": user_id,
|
||||||
|
"context_token": context_token,
|
||||||
|
"base_info": { "channel_version": CHANNEL_VERSION }
|
||||||
|
});
|
||||||
|
let resp = self
|
||||||
|
.api_post(base_url, "/ilink/bot/getconfig", token, &body, 15)
|
||||||
|
.await?;
|
||||||
|
Ok(serde_json::from_value(resp)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_typing(
|
||||||
|
&self,
|
||||||
|
base_url: &str,
|
||||||
|
token: &str,
|
||||||
|
user_id: &str,
|
||||||
|
ticket: &str,
|
||||||
|
status: i32,
|
||||||
|
) -> Result<()> {
|
||||||
|
let body = json!({
|
||||||
|
"ilink_user_id": user_id,
|
||||||
|
"typing_ticket": ticket,
|
||||||
|
"status": status,
|
||||||
|
"base_info": { "channel_version": CHANNEL_VERSION }
|
||||||
|
});
|
||||||
|
self.api_post(base_url, "/ilink/bot/sendtyping", token, &body, 15)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn api_post(
|
||||||
|
&self,
|
||||||
|
base_url: &str,
|
||||||
|
endpoint: &str,
|
||||||
|
token: &str,
|
||||||
|
body: &Value,
|
||||||
|
timeout_secs: u64,
|
||||||
|
) -> Result<Value> {
|
||||||
|
let url = format!("{}{}", base_url, endpoint);
|
||||||
|
let resp = self
|
||||||
|
.http
|
||||||
|
.post(&url)
|
||||||
|
.timeout(Duration::from_secs(timeout_secs))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("AuthorizationType", "ilink_bot_token")
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.header("X-WECHAT-UIN", random_wechat_uin())
|
||||||
|
.header("iLink-App-Id", ILINK_APP_ID)
|
||||||
|
.header("iLink-App-ClientVersion", build_client_version())
|
||||||
|
.json(body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = resp.status().as_u16();
|
||||||
|
let text = resp.text().await?;
|
||||||
|
let value: Value = serde_json::from_str(&text).unwrap_or(json!({}));
|
||||||
|
|
||||||
|
if status >= 400 {
|
||||||
|
return Err(WeChatBotError::Api {
|
||||||
|
message: value["errmsg"]
|
||||||
|
.as_str()
|
||||||
|
.or_else(|| value["message"].as_str())
|
||||||
|
.unwrap_or(&text)
|
||||||
|
.to_string(),
|
||||||
|
http_status: status,
|
||||||
|
errcode: value["errcode"].as_i64().unwrap_or(0) as i32,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(errcode) = value["errcode"].as_i64() {
|
||||||
|
if errcode != 0 {
|
||||||
|
return Err(WeChatBotError::Api {
|
||||||
|
message: value["errmsg"]
|
||||||
|
.as_str()
|
||||||
|
.or_else(|| value["message"].as_str())
|
||||||
|
.unwrap_or(&text)
|
||||||
|
.to_string(),
|
||||||
|
http_status: status,
|
||||||
|
errcode: errcode as i32,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a media message payload.
|
||||||
|
pub fn build_media_message(user_id: &str, context_token: &str, item_list: Vec<Value>) -> Value {
|
||||||
|
json!({
|
||||||
|
"from_user_id": "",
|
||||||
|
"to_user_id": user_id,
|
||||||
|
"client_id": Uuid::new_v4().to_string(),
|
||||||
|
"message_type": 2,
|
||||||
|
"message_state": 2,
|
||||||
|
"context_token": context_token,
|
||||||
|
"item_list": item_list
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GetUploadUrl request parameters.
|
||||||
|
pub struct GetUploadUrlParams {
|
||||||
|
pub filekey: String,
|
||||||
|
pub media_type: i32,
|
||||||
|
pub to_user_id: String,
|
||||||
|
pub rawsize: usize,
|
||||||
|
pub rawfilemd5: String,
|
||||||
|
pub filesize: usize,
|
||||||
|
pub no_need_thumb: bool,
|
||||||
|
pub aeskey: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GetUploadUrl response.
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct GetUploadUrlResponse {
|
||||||
|
pub upload_param: Option<String>,
|
||||||
|
pub thumb_upload_param: Option<String>,
|
||||||
|
pub upload_full_url: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ILinkClient {
|
||||||
|
/// Get a pre-signed CDN upload URL.
|
||||||
|
pub async fn get_upload_url(
|
||||||
|
&self,
|
||||||
|
base_url: &str,
|
||||||
|
token: &str,
|
||||||
|
params: &GetUploadUrlParams,
|
||||||
|
) -> Result<GetUploadUrlResponse> {
|
||||||
|
let body = json!({
|
||||||
|
"filekey": params.filekey,
|
||||||
|
"media_type": params.media_type,
|
||||||
|
"to_user_id": params.to_user_id,
|
||||||
|
"rawsize": params.rawsize,
|
||||||
|
"rawfilemd5": params.rawfilemd5,
|
||||||
|
"filesize": params.filesize,
|
||||||
|
"no_need_thumb": params.no_need_thumb,
|
||||||
|
"aeskey": params.aeskey,
|
||||||
|
"base_info": { "channel_version": CHANNEL_VERSION }
|
||||||
|
});
|
||||||
|
let resp = self
|
||||||
|
.api_post(base_url, "/ilink/bot/getuploadurl", token, &body, 15)
|
||||||
|
.await?;
|
||||||
|
Ok(serde_json::from_value(resp)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upload encrypted bytes to CDN with retry (up to 3 attempts).
|
||||||
|
/// Returns the download encrypted_query_param from the x-encrypted-param header.
|
||||||
|
pub async fn upload_to_cdn(&self, cdn_url: &str, ciphertext: &[u8]) -> Result<String> {
|
||||||
|
const MAX_RETRIES: u32 = 3;
|
||||||
|
let mut last_err = None;
|
||||||
|
|
||||||
|
for attempt in 1..=MAX_RETRIES {
|
||||||
|
match self
|
||||||
|
.http
|
||||||
|
.post(cdn_url)
|
||||||
|
.header("Content-Type", "application/octet-stream")
|
||||||
|
.body(ciphertext.to_vec())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => {
|
||||||
|
let status = resp.status().as_u16();
|
||||||
|
if status >= 400 && status < 500 {
|
||||||
|
let err_msg = resp
|
||||||
|
.headers()
|
||||||
|
.get("x-error-message")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("client error")
|
||||||
|
.to_string();
|
||||||
|
return Err(WeChatBotError::Media(format!(
|
||||||
|
"CDN upload client error {}: {}",
|
||||||
|
status, err_msg
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if status != 200 {
|
||||||
|
let err_msg = resp
|
||||||
|
.headers()
|
||||||
|
.get("x-error-message")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("server error")
|
||||||
|
.to_string();
|
||||||
|
last_err = Some(WeChatBotError::Media(format!(
|
||||||
|
"CDN upload server error {}: {}",
|
||||||
|
status, err_msg
|
||||||
|
)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match resp
|
||||||
|
.headers()
|
||||||
|
.get("x-encrypted-param")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
{
|
||||||
|
Some(param) => return Ok(param.to_string()),
|
||||||
|
None => {
|
||||||
|
last_err = Some(WeChatBotError::Media(
|
||||||
|
"CDN upload response missing x-encrypted-param header".into(),
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
last_err = Some(WeChatBotError::Other(format!(
|
||||||
|
"CDN upload network error: {}",
|
||||||
|
e
|
||||||
|
)));
|
||||||
|
if attempt < MAX_RETRIES {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(last_err.unwrap_or_else(|| {
|
||||||
|
WeChatBotError::Media(format!("CDN upload failed after {} attempts", MAX_RETRIES))
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a CDN upload URL from params.
|
||||||
|
pub fn build_cdn_upload_url(cdn_base_url: &str, upload_param: &str, filekey: &str) -> String {
|
||||||
|
format!(
|
||||||
|
"{}/upload?encrypted_query_param={}&filekey={}",
|
||||||
|
cdn_base_url,
|
||||||
|
urlencoding::encode(upload_param),
|
||||||
|
urlencoding::encode(filekey)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a text message payload.
|
||||||
|
pub fn build_text_message(user_id: &str, context_token: &str, text: &str) -> Value {
|
||||||
|
json!({
|
||||||
|
"from_user_id": "",
|
||||||
|
"to_user_id": user_id,
|
||||||
|
"client_id": Uuid::new_v4().to_string(),
|
||||||
|
"message_type": 2,
|
||||||
|
"message_state": 2,
|
||||||
|
"context_token": context_token,
|
||||||
|
"item_list": [{ "type": 1, "text_item": { "text": text } }]
|
||||||
|
})
|
||||||
|
}
|
||||||
858
vendor/wechatbot/src/types.rs
vendored
Normal file
858
vendor/wechatbot/src/types.rs
vendored
Normal file
@ -0,0 +1,858 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_repr::{Deserialize_repr, Serialize_repr};
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
/// Message sender type.
|
||||||
|
/// Uses serde_repr for integer (de)serialization: JSON `1` ↔ `MessageType::User`.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum MessageType {
|
||||||
|
User = 1,
|
||||||
|
Bot = 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Message delivery state.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum MessageState {
|
||||||
|
New = 0,
|
||||||
|
Generating = 1,
|
||||||
|
Finish = 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Content type of a message item.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum MessageItemType {
|
||||||
|
Text = 1,
|
||||||
|
Image = 2,
|
||||||
|
Voice = 3,
|
||||||
|
File = 4,
|
||||||
|
Video = 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Media type for upload requests.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum MediaType {
|
||||||
|
Image = 1,
|
||||||
|
Video = 2,
|
||||||
|
File = 3,
|
||||||
|
Voice = 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CDN media reference.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct CDNMedia {
|
||||||
|
pub encrypt_query_param: String,
|
||||||
|
pub aes_key: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub encrypt_type: Option<i32>,
|
||||||
|
/// Complete download URL returned by server; when set, use directly.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub full_url: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Text content.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TextItem {
|
||||||
|
pub text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Image content.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ImageItem {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub media: Option<CDNMedia>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub thumb_media: Option<CDNMedia>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub aeskey: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub url: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub mid_size: Option<i64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub thumb_width: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub thumb_height: Option<i32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Voice content.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct VoiceItem {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub media: Option<CDNMedia>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub encode_type: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub text: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub playtime: Option<i32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// File content.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct FileItem {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub media: Option<CDNMedia>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub file_name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub md5: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub len: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Video content.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct VideoItem {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub media: Option<CDNMedia>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub video_size: Option<i64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub play_length: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub thumb_media: Option<CDNMedia>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Referenced/quoted message.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct RefMessage {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub title: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub message_item: Option<Box<WireMessageItem>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single content item in a message.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct WireMessageItem {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub item_type: MessageItemType,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub text_item: Option<TextItem>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub image_item: Option<ImageItem>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub voice_item: Option<VoiceItem>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub file_item: Option<FileItem>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub video_item: Option<VideoItem>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ref_msg: Option<RefMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw wire message from the iLink API.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct WireMessage {
|
||||||
|
pub from_user_id: String,
|
||||||
|
pub to_user_id: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub create_time_ms: i64,
|
||||||
|
pub message_type: MessageType,
|
||||||
|
pub message_state: MessageState,
|
||||||
|
pub context_token: String,
|
||||||
|
pub item_list: Vec<WireMessageItem>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parsed incoming message — user-friendly.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct IncomingMessage {
|
||||||
|
pub user_id: String,
|
||||||
|
pub text: String,
|
||||||
|
pub content_type: ContentType,
|
||||||
|
pub timestamp: SystemTime,
|
||||||
|
pub images: Vec<ImageContent>,
|
||||||
|
pub voices: Vec<VoiceContent>,
|
||||||
|
pub files: Vec<FileContent>,
|
||||||
|
pub videos: Vec<VideoContent>,
|
||||||
|
pub quoted: Option<QuotedMessage>,
|
||||||
|
pub raw: WireMessage,
|
||||||
|
pub(crate) context_token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IncomingMessage {
|
||||||
|
/// Opaque reply token bound to this message.
|
||||||
|
///
|
||||||
|
/// Pass it back via [`WeChatBot::reply`](crate::WeChatBot::reply) (which
|
||||||
|
/// does this automatically) or when constructing a message payload with
|
||||||
|
/// [`protocol::build_text_message`](crate::protocol::build_text_message) /
|
||||||
|
/// [`protocol::build_media_message`](crate::protocol::build_media_message)
|
||||||
|
/// for use with [`ILinkClient::send_message`](crate::protocol::ILinkClient::send_message).
|
||||||
|
pub fn context_token(&self) -> &str {
|
||||||
|
&self.context_token
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a raw [`WireMessage`] into a user-friendly [`IncomingMessage`].
|
||||||
|
///
|
||||||
|
/// Returns `None` if the wire message is not a user-originated message
|
||||||
|
/// (e.g. it was sent by the bot itself).
|
||||||
|
///
|
||||||
|
/// This is the stable entry point for consumers who drive
|
||||||
|
/// [`ILinkClient::get_updates`](crate::protocol::ILinkClient::get_updates)
|
||||||
|
/// themselves instead of using [`WeChatBot`](crate::WeChatBot)'s
|
||||||
|
/// dispatcher.
|
||||||
|
pub fn from_wire(wire: &WireMessage) -> Option<Self> {
|
||||||
|
if wire.message_type != MessageType::User {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut msg = IncomingMessage {
|
||||||
|
user_id: wire.from_user_id.clone(),
|
||||||
|
text: extract_text(&wire.item_list),
|
||||||
|
content_type: detect_type(&wire.item_list),
|
||||||
|
timestamp: std::time::UNIX_EPOCH
|
||||||
|
+ std::time::Duration::from_millis(wire.create_time_ms as u64),
|
||||||
|
images: Vec::new(),
|
||||||
|
voices: Vec::new(),
|
||||||
|
files: Vec::new(),
|
||||||
|
videos: Vec::new(),
|
||||||
|
quoted: None,
|
||||||
|
raw: wire.clone(),
|
||||||
|
context_token: wire.context_token.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
for item in &wire.item_list {
|
||||||
|
if let Some(ref img) = item.image_item {
|
||||||
|
msg.images.push(ImageContent {
|
||||||
|
media: img.media.clone(),
|
||||||
|
thumb_media: img.thumb_media.clone(),
|
||||||
|
aes_key: img.aeskey.clone(),
|
||||||
|
url: img.url.clone(),
|
||||||
|
width: img.thumb_width,
|
||||||
|
height: img.thumb_height,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if let Some(ref voice) = item.voice_item {
|
||||||
|
msg.voices.push(VoiceContent {
|
||||||
|
media: voice.media.clone(),
|
||||||
|
text: voice.text.clone(),
|
||||||
|
duration_ms: voice.playtime,
|
||||||
|
encode_type: voice.encode_type,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if let Some(ref file) = item.file_item {
|
||||||
|
msg.files.push(FileContent {
|
||||||
|
media: file.media.clone(),
|
||||||
|
file_name: file.file_name.clone(),
|
||||||
|
md5: file.md5.clone(),
|
||||||
|
size: file.len.as_ref().and_then(|s| s.parse().ok()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if let Some(ref video) = item.video_item {
|
||||||
|
msg.videos.push(VideoContent {
|
||||||
|
media: video.media.clone(),
|
||||||
|
thumb_media: video.thumb_media.clone(),
|
||||||
|
duration_ms: video.play_length,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if let Some(ref refm) = item.ref_msg {
|
||||||
|
msg.quoted = Some(QuotedMessage {
|
||||||
|
title: refm.title.clone(),
|
||||||
|
text: refm
|
||||||
|
.message_item
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|i| i.text_item.as_ref())
|
||||||
|
.map(|t| t.text.clone()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn detect_type(items: &[WireMessageItem]) -> ContentType {
|
||||||
|
items
|
||||||
|
.first()
|
||||||
|
.map_or(ContentType::Text, |item| match item.item_type {
|
||||||
|
MessageItemType::Image => ContentType::Image,
|
||||||
|
MessageItemType::Voice => ContentType::Voice,
|
||||||
|
MessageItemType::File => ContentType::File,
|
||||||
|
MessageItemType::Video => ContentType::Video,
|
||||||
|
_ => ContentType::Text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text(items: &[WireMessageItem]) -> String {
|
||||||
|
items
|
||||||
|
.iter()
|
||||||
|
.filter_map(|item| match item.item_type {
|
||||||
|
MessageItemType::Text => item.text_item.as_ref().map(|t| t.text.clone()),
|
||||||
|
MessageItemType::Image => Some(
|
||||||
|
item.image_item
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|i| i.url.clone())
|
||||||
|
.unwrap_or_else(|| "[image]".to_string()),
|
||||||
|
),
|
||||||
|
MessageItemType::Voice => Some(
|
||||||
|
item.voice_item
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.text.clone())
|
||||||
|
.unwrap_or_else(|| "[voice]".to_string()),
|
||||||
|
),
|
||||||
|
MessageItemType::File => Some(
|
||||||
|
item.file_item
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|f| f.file_name.clone())
|
||||||
|
.unwrap_or_else(|| "[file]".to_string()),
|
||||||
|
),
|
||||||
|
MessageItemType::Video => Some("[video]".to_string()),
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Content type of an incoming message.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ContentType {
|
||||||
|
Text,
|
||||||
|
Image,
|
||||||
|
Voice,
|
||||||
|
File,
|
||||||
|
Video,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ImageContent {
|
||||||
|
pub media: Option<CDNMedia>,
|
||||||
|
pub thumb_media: Option<CDNMedia>,
|
||||||
|
pub aes_key: Option<String>,
|
||||||
|
pub url: Option<String>,
|
||||||
|
pub width: Option<i32>,
|
||||||
|
pub height: Option<i32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct VoiceContent {
|
||||||
|
pub media: Option<CDNMedia>,
|
||||||
|
pub text: Option<String>,
|
||||||
|
pub duration_ms: Option<i32>,
|
||||||
|
pub encode_type: Option<i32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct FileContent {
|
||||||
|
pub media: Option<CDNMedia>,
|
||||||
|
pub file_name: Option<String>,
|
||||||
|
pub md5: Option<String>,
|
||||||
|
pub size: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct VideoContent {
|
||||||
|
pub media: Option<CDNMedia>,
|
||||||
|
pub thumb_media: Option<CDNMedia>,
|
||||||
|
pub duration_ms: Option<i32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct QuotedMessage {
|
||||||
|
pub title: Option<String>,
|
||||||
|
pub text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of downloading media from a message.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DownloadedMedia {
|
||||||
|
pub data: Vec<u8>,
|
||||||
|
/// "image", "file", "video", "voice"
|
||||||
|
pub media_type: String,
|
||||||
|
pub file_name: Option<String>,
|
||||||
|
pub format: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of uploading media to CDN.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct UploadResult {
|
||||||
|
pub media: CDNMedia,
|
||||||
|
pub aes_key: [u8; 16],
|
||||||
|
pub encrypted_file_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stored login credentials.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Credentials {
|
||||||
|
pub token: String,
|
||||||
|
#[serde(rename = "baseUrl")]
|
||||||
|
pub base_url: String,
|
||||||
|
#[serde(rename = "accountId")]
|
||||||
|
pub account_id: String,
|
||||||
|
#[serde(rename = "userId")]
|
||||||
|
pub user_id: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub saved_at: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn message_type_values() {
|
||||||
|
assert_eq!(MessageType::User as i32, 1);
|
||||||
|
assert_eq!(MessageType::Bot as i32, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn message_state_values() {
|
||||||
|
assert_eq!(MessageState::New as i32, 0);
|
||||||
|
assert_eq!(MessageState::Generating as i32, 1);
|
||||||
|
assert_eq!(MessageState::Finish as i32, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn message_item_type_values() {
|
||||||
|
assert_eq!(MessageItemType::Text as i32, 1);
|
||||||
|
assert_eq!(MessageItemType::Image as i32, 2);
|
||||||
|
assert_eq!(MessageItemType::Voice as i32, 3);
|
||||||
|
assert_eq!(MessageItemType::File as i32, 4);
|
||||||
|
assert_eq!(MessageItemType::Video as i32, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn wire_message_json_round_trip() {
|
||||||
|
let wire = WireMessage {
|
||||||
|
from_user_id: "user1".to_string(),
|
||||||
|
to_user_id: "bot1".to_string(),
|
||||||
|
client_id: "c1".to_string(),
|
||||||
|
create_time_ms: 1700000000000,
|
||||||
|
message_type: MessageType::User,
|
||||||
|
message_state: MessageState::Finish,
|
||||||
|
context_token: "ctx".to_string(),
|
||||||
|
item_list: vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "hello".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&wire).unwrap();
|
||||||
|
let decoded: WireMessage = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(decoded.from_user_id, "user1");
|
||||||
|
assert_eq!(decoded.message_type, MessageType::User);
|
||||||
|
assert_eq!(decoded.item_list.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
decoded.item_list[0].text_item.as_ref().unwrap().text,
|
||||||
|
"hello"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn credentials_json_camel_case() {
|
||||||
|
let creds = Credentials {
|
||||||
|
token: "tok".to_string(),
|
||||||
|
base_url: "https://api.example.com".to_string(),
|
||||||
|
account_id: "acc1".to_string(),
|
||||||
|
user_id: "uid1".to_string(),
|
||||||
|
saved_at: Some("2024-01-01T00:00:00Z".to_string()),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&creds).unwrap();
|
||||||
|
assert!(json.contains("\"baseUrl\""), "expected camelCase baseUrl");
|
||||||
|
assert!(
|
||||||
|
json.contains("\"accountId\""),
|
||||||
|
"expected camelCase accountId"
|
||||||
|
);
|
||||||
|
assert!(json.contains("\"userId\""), "expected camelCase userId");
|
||||||
|
|
||||||
|
let decoded: Credentials = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(decoded.token, "tok");
|
||||||
|
assert_eq!(decoded.base_url, "https://api.example.com");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn credentials_omits_none_saved_at() {
|
||||||
|
let creds = Credentials {
|
||||||
|
token: "tok".to_string(),
|
||||||
|
base_url: "https://api.example.com".to_string(),
|
||||||
|
account_id: "acc1".to_string(),
|
||||||
|
user_id: "uid1".to_string(),
|
||||||
|
saved_at: None,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&creds).unwrap();
|
||||||
|
assert!(!json.contains("saved_at"), "should omit None saved_at");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cdn_media_json() {
|
||||||
|
let media = CDNMedia {
|
||||||
|
encrypt_query_param: "param=abc".to_string(),
|
||||||
|
aes_key: "key123".to_string(),
|
||||||
|
encrypt_type: Some(1),
|
||||||
|
full_url: None,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&media).unwrap();
|
||||||
|
let decoded: CDNMedia = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(decoded.encrypt_query_param, "param=abc");
|
||||||
|
assert_eq!(decoded.aes_key, "key123");
|
||||||
|
assert_eq!(decoded.encrypt_type, Some(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn wire_message_with_image() {
|
||||||
|
let wire = WireMessage {
|
||||||
|
from_user_id: "user1".to_string(),
|
||||||
|
to_user_id: "bot1".to_string(),
|
||||||
|
client_id: "c1".to_string(),
|
||||||
|
create_time_ms: 1700000000000,
|
||||||
|
message_type: MessageType::User,
|
||||||
|
message_state: MessageState::Finish,
|
||||||
|
context_token: "ctx".to_string(),
|
||||||
|
item_list: vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Image,
|
||||||
|
text_item: None,
|
||||||
|
image_item: Some(ImageItem {
|
||||||
|
media: None,
|
||||||
|
thumb_media: None,
|
||||||
|
aeskey: Some("key".to_string()),
|
||||||
|
url: Some("http://img.jpg".to_string()),
|
||||||
|
mid_size: Some(1024),
|
||||||
|
thumb_width: Some(100),
|
||||||
|
thumb_height: Some(200),
|
||||||
|
}),
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&wire).unwrap();
|
||||||
|
let decoded: WireMessage = serde_json::from_str(&json).unwrap();
|
||||||
|
let img = decoded.item_list[0].image_item.as_ref().unwrap();
|
||||||
|
assert_eq!(img.url, Some("http://img.jpg".to_string()));
|
||||||
|
assert_eq!(img.thumb_width, Some(100));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn content_type_equality() {
|
||||||
|
assert_eq!(ContentType::Text, ContentType::Text);
|
||||||
|
assert_ne!(ContentType::Text, ContentType::Image);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detect_type_text() {
|
||||||
|
let items = vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "hi".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}];
|
||||||
|
assert_eq!(detect_type(&items), ContentType::Text);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detect_type_image() {
|
||||||
|
let items = vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Image,
|
||||||
|
text_item: None,
|
||||||
|
image_item: Some(ImageItem {
|
||||||
|
media: None,
|
||||||
|
thumb_media: None,
|
||||||
|
aeskey: None,
|
||||||
|
url: Some("http://img".to_string()),
|
||||||
|
mid_size: None,
|
||||||
|
thumb_width: None,
|
||||||
|
thumb_height: None,
|
||||||
|
}),
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}];
|
||||||
|
assert_eq!(detect_type(&items), ContentType::Image);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detect_type_empty() {
|
||||||
|
assert_eq!(detect_type(&[]), ContentType::Text);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_text_single() {
|
||||||
|
let items = vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "hello world".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}];
|
||||||
|
assert_eq!(extract_text(&items), "hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_text_multi() {
|
||||||
|
let items = vec![
|
||||||
|
WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "line1".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
},
|
||||||
|
WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "line2".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
assert_eq!(extract_text(&items), "line1\nline2");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_text_image_url() {
|
||||||
|
let items = vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Image,
|
||||||
|
text_item: None,
|
||||||
|
image_item: Some(ImageItem {
|
||||||
|
media: None,
|
||||||
|
thumb_media: None,
|
||||||
|
aeskey: None,
|
||||||
|
url: Some("http://img.jpg".to_string()),
|
||||||
|
mid_size: None,
|
||||||
|
thumb_width: None,
|
||||||
|
thumb_height: None,
|
||||||
|
}),
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}];
|
||||||
|
assert_eq!(extract_text(&items), "http://img.jpg");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_text_image_placeholder() {
|
||||||
|
let items = vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Image,
|
||||||
|
text_item: None,
|
||||||
|
image_item: Some(ImageItem {
|
||||||
|
media: None,
|
||||||
|
thumb_media: None,
|
||||||
|
aeskey: None,
|
||||||
|
url: None,
|
||||||
|
mid_size: None,
|
||||||
|
thumb_width: None,
|
||||||
|
thumb_height: None,
|
||||||
|
}),
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}];
|
||||||
|
assert_eq!(extract_text(&items), "[image]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_text_voice_with_text() {
|
||||||
|
let items = vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Voice,
|
||||||
|
text_item: None,
|
||||||
|
image_item: None,
|
||||||
|
voice_item: Some(VoiceItem {
|
||||||
|
media: None,
|
||||||
|
encode_type: None,
|
||||||
|
text: Some("hello".to_string()),
|
||||||
|
playtime: None,
|
||||||
|
}),
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}];
|
||||||
|
assert_eq!(extract_text(&items), "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_text_file_name() {
|
||||||
|
let items = vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::File,
|
||||||
|
text_item: None,
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: Some(FileItem {
|
||||||
|
media: None,
|
||||||
|
file_name: Some("doc.pdf".to_string()),
|
||||||
|
md5: None,
|
||||||
|
len: None,
|
||||||
|
}),
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}];
|
||||||
|
assert_eq!(extract_text(&items), "doc.pdf");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_text_video() {
|
||||||
|
let items = vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Video,
|
||||||
|
text_item: None,
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: Some(VideoItem {
|
||||||
|
media: None,
|
||||||
|
video_size: None,
|
||||||
|
play_length: None,
|
||||||
|
thumb_media: None,
|
||||||
|
}),
|
||||||
|
ref_msg: None,
|
||||||
|
}];
|
||||||
|
assert_eq!(extract_text(&items), "[video]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_wire_user_text() {
|
||||||
|
let wire = WireMessage {
|
||||||
|
from_user_id: "user123".to_string(),
|
||||||
|
to_user_id: "bot456".to_string(),
|
||||||
|
client_id: "c1".to_string(),
|
||||||
|
create_time_ms: 1700000000000,
|
||||||
|
message_type: MessageType::User,
|
||||||
|
message_state: MessageState::Finish,
|
||||||
|
context_token: "ctx-abc".to_string(),
|
||||||
|
item_list: vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "hello".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
let msg = IncomingMessage::from_wire(&wire).unwrap();
|
||||||
|
assert_eq!(msg.user_id, "user123");
|
||||||
|
assert_eq!(msg.text, "hello");
|
||||||
|
assert_eq!(msg.content_type, ContentType::Text);
|
||||||
|
assert_eq!(msg.context_token(), "ctx-abc");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_wire_skips_bot() {
|
||||||
|
let wire = WireMessage {
|
||||||
|
from_user_id: "bot456".to_string(),
|
||||||
|
to_user_id: "user123".to_string(),
|
||||||
|
client_id: "c1".to_string(),
|
||||||
|
create_time_ms: 1700000000000,
|
||||||
|
message_type: MessageType::Bot,
|
||||||
|
message_state: MessageState::Finish,
|
||||||
|
context_token: "ctx".to_string(),
|
||||||
|
item_list: vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "reply".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
assert!(IncomingMessage::from_wire(&wire).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_wire_with_image() {
|
||||||
|
let wire = WireMessage {
|
||||||
|
from_user_id: "user123".to_string(),
|
||||||
|
to_user_id: "bot456".to_string(),
|
||||||
|
client_id: "c1".to_string(),
|
||||||
|
create_time_ms: 1700000000000,
|
||||||
|
message_type: MessageType::User,
|
||||||
|
message_state: MessageState::Finish,
|
||||||
|
context_token: "ctx".to_string(),
|
||||||
|
item_list: vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Image,
|
||||||
|
text_item: None,
|
||||||
|
image_item: Some(ImageItem {
|
||||||
|
media: None,
|
||||||
|
thumb_media: None,
|
||||||
|
aeskey: Some("key".to_string()),
|
||||||
|
url: Some("http://img.jpg".to_string()),
|
||||||
|
mid_size: None,
|
||||||
|
thumb_width: Some(100),
|
||||||
|
thumb_height: Some(200),
|
||||||
|
}),
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
let msg = IncomingMessage::from_wire(&wire).unwrap();
|
||||||
|
assert_eq!(msg.images.len(), 1);
|
||||||
|
assert_eq!(msg.images[0].url, Some("http://img.jpg".to_string()));
|
||||||
|
assert_eq!(msg.images[0].width, Some(100));
|
||||||
|
assert_eq!(msg.images[0].height, Some(200));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_wire_with_quoted() {
|
||||||
|
let wire = WireMessage {
|
||||||
|
from_user_id: "user123".to_string(),
|
||||||
|
to_user_id: "bot456".to_string(),
|
||||||
|
client_id: "c1".to_string(),
|
||||||
|
create_time_ms: 1700000000000,
|
||||||
|
message_type: MessageType::User,
|
||||||
|
message_state: MessageState::Finish,
|
||||||
|
context_token: "ctx".to_string(),
|
||||||
|
item_list: vec![WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "replying".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: Some(RefMessage {
|
||||||
|
title: Some("Original".to_string()),
|
||||||
|
message_item: Some(Box::new(WireMessageItem {
|
||||||
|
item_type: MessageItemType::Text,
|
||||||
|
text_item: Some(TextItem {
|
||||||
|
text: "original text".to_string(),
|
||||||
|
}),
|
||||||
|
image_item: None,
|
||||||
|
voice_item: None,
|
||||||
|
file_item: None,
|
||||||
|
video_item: None,
|
||||||
|
ref_msg: None,
|
||||||
|
})),
|
||||||
|
}),
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
let msg = IncomingMessage::from_wire(&wire).unwrap();
|
||||||
|
let quoted = msg.quoted.as_ref().unwrap();
|
||||||
|
assert_eq!(quoted.title, Some("Original".to_string()));
|
||||||
|
assert_eq!(quoted.text, Some("original text".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user