PicoBot/src/config/mod.rs

2204 lines
63 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use chrono::{DateTime, Utc};
use chrono_tz::Tz;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::str::FromStr;
/// Top-level PicoBot configuration, deserialized from the JSON config file by
/// [`Config::load`] / [`Config::load_default`].
///
/// Only `providers`, `models` and `agents` are required; every other section
/// is `#[serde(default)]` and falls back to its type's `Default` when omitted.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
/// LLM provider endpoints, keyed by the name that `AgentConfig::provider` refers to.
pub providers: HashMap<String, ProviderConfig>,
/// Model definitions, keyed by the name that `AgentConfig::model` refers to.
pub models: HashMap<String, ModelConfig>,
/// Agents, keyed by agent name (looked up by `Config::get_provider_config`).
pub agents: HashMap<String, AgentConfig>,
/// Timezone settings; the timezone is validated while loading.
#[serde(default)]
pub time: TimeConfig,
/// Gateway server settings.
#[serde(default)]
pub gateway: GatewayConfig,
/// Job scheduler settings.
#[serde(default)]
pub scheduler: SchedulerConfig,
/// Client-side settings (gateway URL).
#[serde(default)]
pub client: ClientConfig,
/// Messaging channel configurations, keyed by name.
#[serde(default)]
pub channels: HashMap<String, ChannelConfig>,
/// Skill discovery settings.
#[serde(default)]
pub skills: SkillsConfig,
/// Tool enable/disable list and task settings.
#[serde(default)]
pub tools: ToolsConfig,
/// Safety limits for memory maintenance runs.
#[serde(default)]
pub memory_maintenance: MemoryMaintenanceConfig,
/// MCP server definitions, read from the JSON key `mcpServers`.
#[serde(default, rename = "mcpServers")]
pub mcp_servers: HashMap<String, crate::mcp::McpServerConfig>,
/// Limits on images kept in the model context.
#[serde(default)]
pub image_context: ImageContextConfig,
/// Custom sub-agent discovery settings.
#[serde(default)]
pub subagents: SubagentsConfig,
}
/// Limits on how many images from the conversation history are sent to the model.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageContextConfig {
/// Maximum number of images from the topic's context history that are sent to the model (default 1).
#[serde(default = "default_max_images_in_context")]
pub max_images_in_context: usize,
/// Images older than this many message rounds are no longer submitted to the model (default 10).
/// A "round" is the message's position in the history, i.e. its distance (in messages) from the
/// newest message, counting messages of every role (user, assistant, tool, system, etc.).
#[serde(default = "default_max_image_age_rounds")]
pub max_image_age_rounds: usize,
}
/// Serde/`Default` value for `ImageContextConfig::max_images_in_context`.
fn default_max_images_in_context() -> usize {
    const DEFAULT_MAX_IMAGES: usize = 1;
    DEFAULT_MAX_IMAGES
}
/// Serde/`Default` value for `ImageContextConfig::max_image_age_rounds`.
fn default_max_image_age_rounds() -> usize {
    const DEFAULT_MAX_AGE_ROUNDS: usize = 10;
    DEFAULT_MAX_AGE_ROUNDS
}
impl Default for ImageContextConfig {
fn default() -> Self {
Self {
max_images_in_context: default_max_images_in_context(),
max_image_age_rounds: default_max_image_age_rounds(),
}
}
}
/// Timezone settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TimeConfig {
/// IANA timezone name (e.g. `Asia/Shanghai`). When omitted, `default_timezone` uses the
/// system timezone if it parses as a `chrono_tz::Tz`, otherwise `Asia/Shanghai`.
/// Checked by [`TimeConfig::parse_timezone`] during config loading.
#[serde(default = "default_timezone")]
pub timezone: String,
}
impl TimeConfig {
pub fn parse_timezone(&self) -> Result<Tz, ConfigError> {
self.timezone.parse::<Tz>().map_err(|_| {
ConfigError::InvalidTimezone(format!(
"unsupported timezone '{}', expected an IANA timezone like 'Asia/Shanghai'",
self.timezone
))
})
}
}
impl Default for TimeConfig {
fn default() -> Self {
Self {
timezone: default_timezone(),
}
}
}
/// Skill discovery and indexing settings.
/// NOTE(review): the per-field semantics below are inferred from names; confirm against the skills module.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SkillsConfig {
/// Whether skills are enabled (default `true`).
#[serde(default = "default_skills_enabled")]
pub enabled: bool,
/// Skill source identifiers (default: `user`, `user_agent`, `user_openclaw`, `project`,
/// `project_agent`, `project_openclaw`); presumably searched in this order — confirm.
#[serde(default = "default_skills_sources")]
pub sources: Vec<String>,
/// Character budget for the skills index (default 4000).
#[serde(default = "default_skills_max_index_chars")]
pub max_index_chars: usize,
/// Maximum number of skills listed (default 32).
#[serde(default = "default_skills_max_listed")]
pub max_listed_skills: usize,
}
/// Skills are on unless explicitly disabled.
fn default_skills_enabled() -> bool {
    true
}
/// Default skill source list, user-level sources first, then project-level ones.
fn default_skills_sources() -> Vec<String> {
    const SOURCES: [&str; 6] = [
        "user",
        "user_agent",
        "user_openclaw",
        "project",
        "project_agent",
        "project_openclaw",
    ];
    SOURCES.iter().map(|source| (*source).to_owned()).collect()
}
/// Default character budget for the skills index.
fn default_skills_max_index_chars() -> usize {
    4000
}
/// Default cap on the number of listed skills.
fn default_skills_max_listed() -> usize {
    32
}
impl Default for SkillsConfig {
fn default() -> Self {
Self {
enabled: default_skills_enabled(),
sources: default_skills_sources(),
max_index_chars: default_skills_max_index_chars(),
max_listed_skills: default_skills_max_listed(),
}
}
}
/// Custom sub-agent configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SubagentsConfig {
/// Whether discovery of custom sub-agents is enabled (default `true`).
#[serde(default = "default_subagents_enabled")]
pub enabled: bool,
/// Sources of sub-agent definitions, in priority order (default: `user`, `project`).
#[serde(default = "default_subagents_sources")]
pub sources: Vec<String>,
}
/// Sub-agent discovery is on by default.
fn default_subagents_enabled() -> bool {
    true
}
/// Default sub-agent definition sources: user-level, then project-level.
fn default_subagents_sources() -> Vec<String> {
    ["user", "project"].iter().map(|s| String::from(*s)).collect()
}
impl Default for SubagentsConfig {
fn default() -> Self {
Self {
enabled: default_subagents_enabled(),
sources: default_subagents_sources(),
}
}
}
/// Tool availability settings.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ToolsConfig {
/// Names of disabled tools; checked with an exact, case-sensitive match by `ToolsConfig::is_disabled`.
#[serde(default)]
pub disabled: Vec<String>,
/// Task execution settings (see `TaskConfig`).
#[serde(default)]
pub task: TaskConfig,
}
/// Safety limits applied to a single memory maintenance run.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MemoryMaintenanceConfig {
/// Maximum fraction (0.0-1.0) of memories that may be merged/deleted in one run; default 0.3 (30%).
#[serde(default = "default_max_merge_ratio")]
pub max_merge_ratio: f32,
/// Minimum number of memories to keep; default 5.
#[serde(default = "default_min_memories_to_keep")]
pub min_memories_to_keep: usize,
/// Maximum number of source memories combined in a single merge; default 3.
#[serde(default = "default_max_merge_per_group")]
pub max_merge_per_group: usize,
}
/// At most 30% of memories may be merged/deleted per run by default.
fn default_max_merge_ratio() -> f32 {
    0.3_f32
}
/// Keep at least five memories by default.
fn default_min_memories_to_keep() -> usize {
    5_usize
}
/// Merge at most three source memories at a time by default.
fn default_max_merge_per_group() -> usize {
    3_usize
}
impl Default for MemoryMaintenanceConfig {
fn default() -> Self {
Self {
max_merge_ratio: default_max_merge_ratio(),
min_memories_to_keep: default_min_memories_to_keep(),
max_merge_per_group: default_max_merge_per_group(),
}
}
}
/// Task execution settings.
/// NOTE(review): semantics below are inferred from field names and defaults; confirm against the task tool.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TaskConfig {
/// Whether tasks are enabled (default `true`).
#[serde(default = "default_task_enabled")]
pub enabled: bool,
/// Maximum execution time in seconds (default 3600 = 60 minutes).
#[serde(default = "default_task_max_execution_secs")]
pub max_execution_secs: u64,
/// Maximum execution time in seconds for explore tasks (default 3600 = 60 minutes).
#[serde(default = "default_task_explore_max_execution_secs")]
pub explore_max_execution_secs: u64,
/// Time-to-live in hours (default 24); presumably how long task records are retained — confirm.
#[serde(default = "default_task_ttl_hours")]
pub ttl_hours: u64,
/// Tool names a task is allowed to use (default list in `default_task_allowed_tools`).
#[serde(default = "default_task_allowed_tools")]
pub allowed_tools: Vec<String>,
/// Maximum task nesting depth (default 2).
#[serde(default = "default_task_max_nesting_depth")]
pub max_nesting_depth: u32,
}
/// Tasks are enabled by default.
fn default_task_enabled() -> bool {
    true
}
/// 60 minutes.
fn default_task_max_execution_secs() -> u64 {
    60 * 60
}
/// 60 minutes, same as regular tasks.
fn default_task_explore_max_execution_secs() -> u64 {
    60 * 60
}
/// One day.
fn default_task_ttl_hours() -> u64 {
    24
}
/// Default maximum nesting depth.
fn default_task_max_nesting_depth() -> u32 {
    2
}
/// Default allow-list of tools available to tasks.
fn default_task_allowed_tools() -> Vec<String> {
    const ALLOWED: [&str; 12] = [
        "read",
        "edit",
        "write",
        "bash",
        "http_request",
        "web_fetch",
        "memory_search",
        "get_time",
        "calculator",
        "skill_activate",
        "skill_list",
        "send_session_message",
    ];
    ALLOWED.iter().map(|name| (*name).to_owned()).collect()
}
impl Default for TaskConfig {
fn default() -> Self {
Self {
enabled: default_task_enabled(),
max_execution_secs: default_task_max_execution_secs(),
explore_max_execution_secs: default_task_explore_max_execution_secs(),
ttl_hours: default_task_ttl_hours(),
allowed_tools: default_task_allowed_tools(),
max_nesting_depth: default_task_max_nesting_depth(),
}
}
}
impl ToolsConfig {
    /// Returns `true` when `tool_name` appears verbatim (case-sensitive) in `disabled`.
    pub fn is_disabled(&self, tool_name: &str) -> bool {
        self.disabled
            .iter()
            .map(String::as_str)
            .any(|disabled| disabled == tool_name)
    }
}
/// Configuration of one messaging channel.
///
/// `#[serde(untagged)]`: serde tries `Tagged` first (an object carrying a `type` field) and,
/// if that fails, falls back to `LegacyFeishu` (a Feishu config without `type`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChannelConfig {
Tagged(TaggedChannelConfig),
LegacyFeishu(FeishuChannelConfig),
}
/// Channel configuration discriminated by its `type` field (`"feishu"` or `"wechat"`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TaggedChannelConfig {
Feishu(FeishuChannelConfig),
Wechat(WechatChannelConfig),
}
impl ChannelConfig {
    /// Channel type name: `"feishu"` for tagged and legacy Feishu configs,
    /// `"wechat"` for WeChat.
    pub fn kind(&self) -> &'static str {
        match self {
            Self::Tagged(TaggedChannelConfig::Wechat(_)) => "wechat",
            Self::LegacyFeishu(_) | Self::Tagged(TaggedChannelConfig::Feishu(_)) => "feishu",
        }
    }

    /// The `enabled` flag of whichever channel config this wraps.
    pub fn enabled(&self) -> bool {
        match self.as_feishu() {
            Some(feishu) => feishu.enabled,
            None => self.as_wechat().is_some_and(|wechat| wechat.enabled),
        }
    }

    /// Borrows the Feishu settings if this is a (tagged or legacy) Feishu channel.
    pub fn as_feishu(&self) -> Option<&FeishuChannelConfig> {
        if let Self::LegacyFeishu(feishu) | Self::Tagged(TaggedChannelConfig::Feishu(feishu)) =
            self
        {
            Some(feishu)
        } else {
            None
        }
    }

    /// Mutably borrows the Feishu settings if this is a (tagged or legacy) Feishu channel.
    pub fn as_feishu_mut(&mut self) -> Option<&mut FeishuChannelConfig> {
        if let Self::LegacyFeishu(feishu) | Self::Tagged(TaggedChannelConfig::Feishu(feishu)) =
            self
        {
            Some(feishu)
        } else {
            None
        }
    }

    /// Borrows the WeChat settings if this is a WeChat channel.
    pub fn as_wechat(&self) -> Option<&WechatChannelConfig> {
        if let Self::Tagged(TaggedChannelConfig::Wechat(wechat)) = self {
            Some(wechat)
        } else {
            None
        }
    }
}
/// Settings for a Feishu bot channel.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FeishuChannelConfig {
/// Whether this channel is active (default `false`).
#[serde(default)]
pub enabled: bool,
/// Feishu app id (required).
pub app_id: String,
/// Feishu app secret (required).
pub app_secret: String,
/// Sender allow-list; default `["*"]` (presumably "allow everyone" — confirm in the channel code).
#[serde(default = "default_allow_from")]
pub allow_from: Vec<String>,
/// Name of the agent handling this channel; empty by default.
/// NOTE(review): how an empty name is resolved is decided elsewhere — confirm.
#[serde(default)]
pub agent: String,
/// Directory for media files (default `~/.picobot/media/feishu`, or `./.picobot/media/feishu` without a home dir).
#[serde(default = "default_media_dir")]
pub media_dir: String,
/// Emoji type for message reactions (e.g. "THUMBSUP", "OK", "EYES"). Defaults to "Typing".
#[serde(default = "default_reaction_emoji")]
pub reaction_emoji: String,
/// Message length limit in characters (default 20000).
#[serde(default = "default_channel_max_message_chars")]
pub max_message_chars: usize,
/// Limit in characters for reply context (default 20000).
#[serde(default = "default_channel_reply_context_max_chars")]
pub reply_context_max_chars: usize,
}
/// Settings for a WeChat channel.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WechatChannelConfig {
/// Whether this channel is active (default `false`).
#[serde(default)]
pub enabled: bool,
/// Sender allow-list; default `["*"]`.
#[serde(default = "default_allow_from")]
pub allow_from: Vec<String>,
/// Name of the agent handling this channel; empty by default.
#[serde(default)]
pub agent: String,
/// API base URL (default `https://ilinkai.weixin.qq.com`).
#[serde(default = "default_wechat_base_url")]
pub base_url: String,
/// Credentials file path (default `~/.picobot/wechat/credentials.json`).
#[serde(default = "default_wechat_cred_path")]
pub cred_path: String,
/// Default `false`; presumably forces a fresh login even when stored credentials exist — confirm.
#[serde(default)]
pub force_login: bool,
}
/// Default sender allow-list: a single `"*"` wildcard entry.
fn default_allow_from() -> Vec<String> {
    let wildcard = String::from("*");
    vec![wildcard]
}
/// `~/.picobot/media/feishu`, relative to `.` when no home directory is known.
fn default_media_dir() -> String {
    let base = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
    base.join(".picobot/media/feishu")
        .to_string_lossy()
        .into_owned()
}
/// Default WeChat API endpoint.
fn default_wechat_base_url() -> String {
    String::from("https://ilinkai.weixin.qq.com")
}
/// `~/.picobot/wechat/credentials.json`, relative to `.` when no home directory is known.
fn default_wechat_cred_path() -> String {
    let base = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
    base.join(".picobot/wechat/credentials.json")
        .to_string_lossy()
        .into_owned()
}
/// Default Feishu reaction emoji type.
fn default_reaction_emoji() -> String {
    String::from("Typing")
}
/// Default per-message character limit for channels.
fn default_channel_max_message_chars() -> usize {
    20000
}
/// Default reply-context character limit for channels.
fn default_channel_reply_context_max_chars() -> usize {
    20000
}
/// An LLM provider endpoint, referenced by name from `AgentConfig::provider`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProviderConfig {
/// Provider type, read from the JSON key `type` (e.g. `"openai"`).
#[serde(rename = "type")]
pub provider_type: String,
/// API base URL.
pub base_url: String,
/// API key; may come from a `${VAR}` / `<VAR>` placeholder resolved at load time.
pub api_key: String,
/// Extra HTTP headers (default empty).
#[serde(default)]
pub extra_headers: HashMap<String, String>,
/// Timeout for LLM requests in seconds (default 120).
#[serde(default = "default_llm_timeout_secs")]
pub llm_timeout_secs: u64,
/// Timeout for memory maintenance in seconds (default 600).
#[serde(default = "default_memory_maintenance_timeout_secs")]
pub memory_maintenance_timeout_secs: u64,
}
/// A model definition, referenced by name from `AgentConfig::model`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
/// Model identifier sent to the provider.
pub model_id: String,
/// Optional sampling temperature.
#[serde(default)]
pub temperature: Option<f32>,
/// Optional output token limit; does not affect the context window size.
#[serde(default)]
pub max_tokens: Option<u32>,
/// Optional context window size in tokens (`LLMProviderConfig::context_window_tokens` uses 128000 when unset).
#[serde(default)]
pub context_window_tokens: Option<u32>,
/// Any other keys of the model object, collected via `#[serde(flatten)]`.
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
/// An agent: which provider and model to use, plus tool-loop limits.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentConfig {
/// Key into `Config::providers`.
pub provider: String,
/// Key into `Config::models`.
pub model: String,
/// Maximum tool-call iterations (default 100).
#[serde(default = "default_max_tool_iterations")]
pub max_tool_iterations: usize,
/// Maximum characters of a tool result (default 100000).
#[serde(default = "default_tool_result_max_chars")]
pub tool_result_max_chars: usize,
/// Characters tool results are trimmed to in context (default 2000).
#[serde(default = "default_context_tool_result_trim_chars")]
pub context_tool_result_trim_chars: usize,
}
/// Default cap on tool-call iterations per agent run.
fn default_max_tool_iterations() -> usize {
    const MAX_TOOL_ITERATIONS: usize = 100;
    MAX_TOOL_ITERATIONS
}
/// Default maximum characters kept from a tool result.
fn default_tool_result_max_chars() -> usize {
    const TOOL_RESULT_MAX_CHARS: usize = 100000;
    TOOL_RESULT_MAX_CHARS
}
/// Default characters a tool result is trimmed to when kept in context.
fn default_context_tool_result_trim_chars() -> usize {
    const TRIM_CHARS: usize = 2000;
    TRIM_CHARS
}
/// Default LLM request timeout (2 minutes).
fn default_llm_timeout_secs() -> u64 {
    2 * 60
}
/// Default memory maintenance timeout (10 minutes).
fn default_memory_maintenance_timeout_secs() -> u64 {
    10 * 60
}
/// Gateway server settings.
///
/// Every field has a serde default that matches `GatewayConfig::default()`, so
/// a partially specified `gateway` section yields the same values as an
/// omitted one for the keys it does not mention.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GatewayConfig {
    /// Bind address (default `127.0.0.1`).
    #[serde(default = "default_gateway_host")]
    pub host: String,
    /// Listen port (default `19876`).
    #[serde(default = "default_gateway_port")]
    pub port: u16,
    /// Whether tool results are shown (default `false`).
    #[serde(default)]
    pub show_tool_results: bool,
    /// Agent prompt re-injection interval (default 100).
    /// NOTE(review): the unit (messages vs. turns) is decided by the gateway — confirm.
    #[serde(default = "default_agent_prompt_reinject_every")]
    pub agent_prompt_reinject_every: u64,
    /// Maximum number of concurrently processed requests (default 10).
    #[serde(default = "default_max_concurrent_requests")]
    pub max_concurrent_requests: usize,
    /// Session time-to-live in hours. Defaults to `Some(24)` whether the
    /// `gateway` section is omitted or only this key is missing. An explicit
    /// JSON `null` still yields `None` (presumably "never expire" — confirm
    /// with the session-cleanup code).
    #[serde(default = "default_session_ttl_hours")]
    pub session_ttl_hours: Option<u64>,
}

/// Serde default for `GatewayConfig::session_ttl_hours`.
///
/// Kept equal to the value in `GatewayConfig::default()`. A plain
/// `#[serde(default)]` would give `None` here, silently dropping the TTL
/// whenever a `gateway` section exists without this key.
fn default_session_ttl_hours() -> Option<u64> {
    Some(24)
}
/// Client-side settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ClientConfig {
/// WebSocket URL of the gateway (default `ws://127.0.0.1:19876/ws`).
#[serde(default = "default_gateway_url")]
pub gateway_url: String,
}
/// Job scheduler settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SchedulerConfig {
    /// Whether the scheduler runs.
    ///
    /// Defaults to `true` both when the `scheduler` section is omitted
    /// (`SchedulerConfig::default`) and when the section is present without an
    /// `enabled` key. A plain `#[serde(default)]` would give `false` in the
    /// latter case and silently disable every job, including the built-in ones.
    #[serde(default = "default_scheduler_enabled")]
    pub enabled: bool,
    /// Tick resolution in milliseconds (default 1000).
    #[serde(default = "default_scheduler_tick_resolution_ms")]
    pub tick_resolution_ms: u64,
    /// Worker queue capacity (default 64).
    #[serde(default = "default_scheduler_worker_queue_capacity")]
    pub worker_queue_capacity: usize,
    /// Policy for missed runs (default `Skip`).
    #[serde(default)]
    pub misfire_policy: SchedulerMisfirePolicy,
    /// User-defined jobs. `effective_jobs` merges them with the built-in jobs;
    /// a job whose id matches a built-in job replaces it.
    #[serde(default)]
    pub jobs: Vec<SchedulerJobConfig>,
}

/// Serde default for `SchedulerConfig::enabled`; matches `SchedulerConfig::default()`.
fn default_scheduler_enabled() -> bool {
    true
}
/// Id of the built-in memory maintenance job. Despite the `_daily` suffix, its built-in
/// schedule (`SchedulerConfig::builtin_jobs`) is cron `0 */4 * * *`, i.e. every 4 hours.
/// A configured job with this id replaces the built-in one.
pub const BUILTIN_MEMORY_MAINTENANCE_JOB_ID: &str = "builtin.memory_maintenance_daily";
/// Id of the built-in session cleanup job (cron `0 * * * *`, i.e. hourly).
/// A configured job with this id replaces the built-in one.
pub const BUILTIN_SESSION_CLEANUP_JOB_ID: &str = "builtin.session_cleanup_hourly";
/// How the scheduler treats runs it missed.
/// NOTE(review): exact semantics live in the scheduler runtime — confirm.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerMisfirePolicy {
/// `"catch_up"`: presumably runs missed occurrences late.
CatchUp,
/// `"skip"` (the default): presumably drops missed occurrences.
#[default]
Skip,
}
/// One scheduled job.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SchedulerJobConfig {
/// Unique job id; a configured job with a built-in id overrides that built-in job.
pub id: String,
/// Whether the job is active (default `true`).
#[serde(default = "default_scheduler_job_enabled")]
pub enabled: bool,
/// What the job does when it fires.
pub kind: SchedulerJobKind,
/// Explicit schedule. When absent, `resolved_schedule` falls back to `interval_secs`.
#[serde(default)]
pub schedule: Option<SchedulerSchedule>,
/// Start delay used only by the legacy `interval_secs` form.
#[serde(default)]
pub startup_delay_secs: u64,
/// Legacy interval in seconds, used only when `schedule` is absent and this is > 0.
#[serde(default)]
pub interval_secs: u64,
/// Where the job's output goes (all fields optional).
#[serde(default)]
pub target: SchedulerJobTarget,
/// Free-form JSON payload passed with the job (default `null`).
#[serde(default)]
pub payload: serde_json::Value,
}
/// Job kinds, written in snake_case in JSON (e.g. `"internal_event"`).
/// NOTE(review): per-kind behavior is implemented by the scheduler runtime — confirm there.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerJobKind {
InternalEvent,
OutboundMessage,
AgentTask,
SilentAgentTask,
}
/// Optional delivery target of a job; every field defaults to `None`.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SchedulerJobTarget {
#[serde(default)]
pub channel: Option<String>,
#[serde(default)]
pub chat_id: Option<String>,
#[serde(default)]
pub session_chat_id: Option<String>,
#[serde(default)]
pub reply_to: Option<String>,
}
/// A job schedule, discriminated by `type` (`delay`, `interval`, `at`, `cron`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SchedulerSchedule {
/// One-shot after `seconds` (must be > 0).
Delay {
seconds: u64,
},
/// Repeats every `seconds` (must be > 0), optionally after a start delay.
Interval {
seconds: u64,
#[serde(default)]
startup_delay_secs: u64,
},
/// One-shot at an RFC 3339 `timestamp`.
At {
timestamp: String,
},
/// Cron expression; 5-field expressions get a leading `0` seconds field before parsing.
Cron {
expression: String,
},
}
impl SchedulerJobConfig {
    /// Returns the job's effective schedule.
    ///
    /// An explicit `schedule` is validated and returned in storage form (`At`
    /// timestamps normalized to UTC). Otherwise, a positive `interval_secs`
    /// becomes an `Interval` using `startup_delay_secs`.
    ///
    /// # Errors
    /// [`ConfigError::InvalidSchedulerJob`] if the explicit schedule is invalid,
    /// or if there is neither a schedule nor a positive `interval_secs`.
    pub fn resolved_schedule(&self) -> Result<SchedulerSchedule, ConfigError> {
        match (&self.schedule, self.interval_secs) {
            (Some(explicit), _) => {
                explicit.validate(&self.id)?;
                Ok(explicit.normalized_for_storage())
            }
            (None, seconds) if seconds > 0 => Ok(SchedulerSchedule::Interval {
                seconds,
                startup_delay_secs: self.startup_delay_secs,
            }),
            (None, _) => Err(ConfigError::InvalidSchedulerJob(format!(
                "scheduler job '{}' requires schedule or interval_secs",
                self.id
            ))),
        }
    }
}
impl SchedulerConfig {
    /// The built-in jobs: memory maintenance every 4 hours and session cleanup
    /// every hour, both enabled `InternalEvent` cron jobs whose payload records
    /// the configured timezone.
    pub fn builtin_jobs(time: &TimeConfig) -> Vec<SchedulerJobConfig> {
        let memory_maintenance = Self::builtin_cron_job(
            BUILTIN_MEMORY_MAINTENANCE_JOB_ID,
            "0 */4 * * *",
            serde_json::json!({
                "event": "memory_maintenance",
                "time_zone": time.timezone,
                "local_time": "every_4_hours"
            }),
        );
        let session_cleanup = Self::builtin_cron_job(
            BUILTIN_SESSION_CLEANUP_JOB_ID,
            "0 * * * *",
            serde_json::json!({
                "event": "session_cleanup",
                "time_zone": time.timezone,
                "local_time": "every_hour"
            }),
        );
        vec![memory_maintenance, session_cleanup]
    }

    /// Builds an enabled `InternalEvent` job on a cron schedule with an empty target.
    fn builtin_cron_job(
        id: &str,
        expression: &str,
        payload: serde_json::Value,
    ) -> SchedulerJobConfig {
        SchedulerJobConfig {
            id: id.to_string(),
            enabled: true,
            kind: SchedulerJobKind::InternalEvent,
            schedule: Some(SchedulerSchedule::Cron {
                expression: expression.to_string(),
            }),
            startup_delay_secs: 0,
            interval_secs: 0,
            target: SchedulerJobTarget::default(),
            payload,
        }
    }

    /// Built-in jobs followed by configured jobs.
    ///
    /// A configured job replaces the first already-collected job with the same
    /// id (built-in or an earlier configured one); otherwise it is appended.
    pub fn effective_jobs(&self, time: &TimeConfig) -> Vec<SchedulerJobConfig> {
        self.jobs
            .iter()
            .fold(Self::builtin_jobs(time), |mut merged, job| {
                match merged.iter().position(|existing| existing.id == job.id) {
                    Some(index) => merged[index] = job.clone(),
                    None => merged.push(job.clone()),
                }
                merged
            })
    }
}
impl SchedulerSchedule {
    /// Checks that the schedule is well-formed for job `job_id`.
    ///
    /// # Errors
    /// [`ConfigError::InvalidSchedulerJob`] when a delay/interval has `seconds == 0`,
    /// an `at` timestamp is not RFC 3339, or a cron expression (after 5-field
    /// normalization) is rejected by the `cron` crate.
    pub fn validate(&self, job_id: &str) -> Result<(), ConfigError> {
        let outcome: Result<(), String> = match self {
            Self::Delay { seconds: 0 } => Err(format!(
                "scheduler job '{}' delay.seconds must be greater than 0",
                job_id
            )),
            Self::Interval { seconds: 0, .. } => Err(format!(
                "scheduler job '{}' interval.seconds must be greater than 0",
                job_id
            )),
            Self::Delay { .. } | Self::Interval { .. } => Ok(()),
            Self::At { timestamp } => DateTime::parse_from_rfc3339(timestamp)
                .map(|_| ())
                .map_err(|err| {
                    format!(
                        "scheduler job '{}' invalid at.timestamp '{}': {}",
                        job_id, timestamp, err
                    )
                }),
            Self::Cron { expression } => parse_scheduler_cron(expression)
                .map(|_| ())
                .map_err(|err| {
                    format!(
                        "scheduler job '{}' invalid cron.expression '{}': {}",
                        job_id, expression, err
                    )
                }),
        };
        outcome.map_err(ConfigError::InvalidSchedulerJob)
    }

    /// `true` for schedules that fire once (`Delay`, `At`).
    pub fn is_one_shot(&self) -> bool {
        match self {
            Self::Delay { .. } | Self::At { .. } => true,
            Self::Interval { .. } | Self::Cron { .. } => false,
        }
    }

    /// Storage form of the schedule. `At` timestamps that parse as RFC 3339 are
    /// converted to UTC and re-rendered; unparseable timestamps and every other
    /// variant are returned unchanged.
    pub fn normalized_for_storage(&self) -> Self {
        let Self::At { timestamp } = self else {
            return self.clone();
        };
        let timestamp = match DateTime::parse_from_rfc3339(timestamp) {
            Ok(parsed) => parsed.with_timezone(&Utc).to_rfc3339(),
            Err(_) => timestamp.clone(),
        };
        Self::At { timestamp }
    }

    /// Compact human-readable description, e.g. `interval:60s:start_delay:0s`.
    pub fn display(&self) -> String {
        match self {
            Self::Delay { seconds } => format!("delay:{}s", seconds),
            Self::Interval {
                seconds,
                startup_delay_secs,
            } => format!("interval:{}s:start_delay:{}s", seconds, startup_delay_secs),
            Self::At { timestamp } => format!("at:{}", timestamp),
            Self::Cron { expression } => format!("cron:{}", expression),
        }
    }
}
/// Parses a cron expression with the `cron` crate after adding a seconds field
/// to 5-field expressions (see `normalize_cron_expression`).
fn parse_scheduler_cron(expression: &str) -> Result<cron::Schedule, cron::error::Error> {
    normalize_cron_expression(expression).parse::<cron::Schedule>()
}
/// Trims `expression` and, if it has exactly five whitespace-separated fields
/// (standard crontab), prefixes a `0` seconds field for the `cron` crate.
/// Expressions with any other field count are only trimmed.
fn normalize_cron_expression(expression: &str) -> String {
    let trimmed = expression.trim();
    let field_count = expression.split_whitespace().count();
    if field_count == 5 {
        format!("0 {}", trimmed)
    } else {
        trimmed.to_string()
    }
}
/// One-second scheduler tick.
fn default_scheduler_tick_resolution_ms() -> u64 {
    1000
}
/// Default scheduler worker queue capacity.
fn default_scheduler_worker_queue_capacity() -> usize {
    64
}
/// Jobs are enabled unless configured otherwise.
fn default_scheduler_job_enabled() -> bool {
    true
}
/// Gateway binds to loopback by default.
fn default_gateway_host() -> String {
    String::from("127.0.0.1")
}
/// Default gateway port.
fn default_gateway_port() -> u16 {
    19_876
}
/// Default gateway WebSocket URL for clients (matches the default host/port).
fn default_gateway_url() -> String {
    String::from("ws://127.0.0.1:19876/ws")
}
fn default_timezone() -> String {
detect_system_timezone().unwrap_or_else(default_beijing_timezone)
}
fn detect_system_timezone() -> Option<String> {
let detected = iana_time_zone::get_timezone().ok()?;
if detected.parse::<Tz>().is_ok() {
Some(detected)
} else {
None
}
}
/// Fallback timezone used when the system timezone cannot be determined.
fn default_beijing_timezone() -> String {
    String::from("Asia/Shanghai")
}
/// Default agent prompt re-injection interval.
fn default_agent_prompt_reinject_every() -> u64 {
    100_u64
}
/// Default limit on concurrently processed gateway requests.
fn default_max_concurrent_requests() -> usize {
    10_usize
}
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
show_tool_results: false,
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
max_concurrent_requests: default_max_concurrent_requests(),
session_ttl_hours: Some(24),
}
}
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
gateway_url: default_gateway_url(),
}
}
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self {
enabled: true,
tick_resolution_ms: default_scheduler_tick_resolution_ms(),
worker_queue_capacity: default_scheduler_worker_queue_capacity(),
misfire_policy: SchedulerMisfirePolicy::default(),
jobs: Vec::new(),
}
}
}
/// Fully resolved LLM settings for one agent, flattened from its provider, model and
/// agent entries (plus `Config::image_context`) by [`Config::get_provider_config`].
#[derive(Debug, Clone)]
pub struct LLMProviderConfig {
/// Provider `type` (e.g. `"openai"`).
pub provider_type: String,
/// Provider key in `Config::providers`.
pub name: String,
pub base_url: String,
pub api_key: String,
pub extra_headers: HashMap<String, String>,
pub llm_timeout_secs: u64,
pub memory_maintenance_timeout_secs: u64,
/// `ModelConfig::model_id`.
pub model_id: String,
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
/// Raw configured value; use `context_window_tokens()` for the value with its fallback.
pub context_window_tokens: Option<u32>,
/// `ModelConfig::extra` (unrecognized model keys).
pub model_extra: HashMap<String, serde_json::Value>,
pub max_tool_iterations: usize,
pub tool_result_max_chars: usize,
pub context_tool_result_trim_chars: usize,
/// Image context limits, copied from `Config::image_context`.
pub max_images_in_context: usize,
pub max_image_age_rounds: usize,
}
impl LLMProviderConfig {
    /// Context window size in tokens; 128,000 when the model leaves it unset.
    pub fn context_window_tokens(&self) -> usize {
        match self.context_window_tokens {
            Some(tokens) => tokens as usize,
            None => 128_000,
        }
    }

    /// Character budget for a context summary: 10% of the context window at
    /// 2.5 characters per token, clamped to `1_500..=50_000`
    /// (e.g. 128,000 tokens gives 32,000 characters).
    pub fn context_summary_char_budget(&self) -> usize {
        const SUMMARY_RATIO: f64 = 0.1;
        const CHARS_PER_TOKEN: f64 = 2.5;
        const MIN_SUMMARY_CHARS: usize = 1_500;
        const MAX_SUMMARY_CHARS: usize = 50_000;
        let window = self.context_window_tokens() as f64;
        let estimated = (window * SUMMARY_RATIO * CHARS_PER_TOKEN) as usize;
        estimated.clamp(MIN_SUMMARY_CHARS, MAX_SUMMARY_CHARS)
    }
}
/// `~/.picobot/config.json`, relative to `.` when no home directory is known.
pub(crate) fn get_default_config_path() -> PathBuf {
    let mut config_path = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
    config_path.push(".picobot");
    config_path.push("config.json");
    config_path
}
impl Config {
    /// Loads the configuration from the file at `path`.
    ///
    /// If `path` does not exist, `./config.json` is tried instead (see `load_from`).
    pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        Self::load_from(Path::new(path))
    }

    /// Loads the configuration from `~/.picobot/config.json`, falling back to
    /// `./config.json`.
    pub fn load_default() -> Result<Self, Box<dyn std::error::Error>> {
        let path = get_default_config_path();
        Self::load_from(&path)
    }

    /// Shared loading pipeline:
    /// 1. loads `.env` from the current directory into the process environment;
    /// 2. reads `path`, or `./config.json` if `path` does not exist;
    /// 3. substitutes `${VAR}` / `<VAR>` placeholders from the environment;
    /// 4. deserializes the JSON and validates `time.timezone`.
    ///
    /// # Errors
    /// - [`ConfigError::ConfigNotFound`] (carrying `path`) when neither file exists;
    /// - [`ConfigError::InvalidTimezone`] for an unknown timezone;
    /// - I/O errors from `.env` or the config file, and JSON errors from parsing.
    fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
        load_env_file()?;
        let content = if path.exists() {
            let content = fs::read_to_string(path)?;
            // Log only once the read actually succeeded.
            tracing::info!(path = %path.display(), "Config loaded");
            content
        } else {
            // Fallback to current directory
            let fallback = Path::new("config.json");
            if fallback.exists() {
                let content = fs::read_to_string(fallback)?;
                tracing::info!(path = %fallback.display(), "Config loaded from fallback path");
                content
            } else {
                return Err(Box::new(ConfigError::ConfigNotFound(
                    path.to_string_lossy().to_string(),
                )));
            }
        };
        let content = resolve_env_placeholders(&content);
        let config: Config = serde_json::from_str(&content)?;
        config.time.parse_timezone()?;
        // Log MCP servers count if any
        if !config.mcp_servers.is_empty() {
            tracing::info!(
                mcp_servers = config.mcp_servers.len(),
                "MCP servers loaded from config"
            );
        }
        Ok(config)
    }

    /// Resolves agent `agent_name` into a flat [`LLMProviderConfig`] combining
    /// its provider, model, agent limits and the global image-context limits.
    ///
    /// # Errors
    /// [`ConfigError::AgentNotFound`], [`ConfigError::ProviderNotFound`] or
    /// [`ConfigError::ModelNotFound`] when a referenced entry is missing.
    pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
        // `ok_or_else` so the error strings are only allocated on the failure path.
        let agent = self
            .agents
            .get(agent_name)
            .ok_or_else(|| ConfigError::AgentNotFound(agent_name.to_string()))?;
        let provider = self
            .providers
            .get(&agent.provider)
            .ok_or_else(|| ConfigError::ProviderNotFound(agent.provider.clone()))?;
        let model = self
            .models
            .get(&agent.model)
            .ok_or_else(|| ConfigError::ModelNotFound(agent.model.clone()))?;
        Ok(LLMProviderConfig {
            provider_type: provider.provider_type.clone(),
            name: agent.provider.clone(),
            base_url: provider.base_url.clone(),
            api_key: provider.api_key.clone(),
            extra_headers: provider.extra_headers.clone(),
            llm_timeout_secs: provider.llm_timeout_secs,
            memory_maintenance_timeout_secs: provider.memory_maintenance_timeout_secs,
            model_id: model.model_id.clone(),
            temperature: model.temperature,
            max_tokens: model.max_tokens,
            context_window_tokens: model.context_window_tokens,
            model_extra: model.extra.clone(),
            max_tool_iterations: agent.max_tool_iterations,
            tool_result_max_chars: agent.tool_result_max_chars,
            context_tool_result_trim_chars: agent.context_tool_result_trim_chars,
            max_images_in_context: self.image_context.max_images_in_context,
            max_image_age_rounds: self.image_context.max_image_age_rounds,
        })
    }
}
/// Errors produced while loading or resolving the configuration.
#[derive(Debug)]
pub enum ConfigError {
    /// Neither the requested config file nor `./config.json` exists (carries the requested path).
    ConfigNotFound(String),
    /// No agent with this name in `Config::agents`.
    AgentNotFound(String),
    /// No provider with this name in `Config::providers`.
    ProviderNotFound(String),
    /// No model with this name in `Config::models`.
    ModelNotFound(String),
    /// A scheduler job failed validation (carries the full message).
    InvalidSchedulerJob(String),
    /// The configured timezone is not a recognized IANA name (carries the full message).
    InvalidTimezone(String),
}
impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ConfigNotFound(missing) => write!(
                f,
                "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
                missing
            ),
            Self::AgentNotFound(agent) => write!(f, "Agent not found: {}", agent),
            Self::ProviderNotFound(provider) => write!(f, "Provider not found: {}", provider),
            Self::ModelNotFound(model) => write!(f, "Model not found: {}", model),
            Self::InvalidSchedulerJob(detail) => write!(f, "Invalid scheduler job: {}", detail),
            Self::InvalidTimezone(detail) => write!(f, "Invalid timezone: {}", detail),
        }
    }
}
impl std::error::Error for ConfigError {}
/// Loads `KEY=VALUE` pairs from `.env` in the current working directory into
/// the process environment, if that file exists.
///
/// Parsing rules (see `parse_env_line`):
/// - blank lines and lines starting with `#` are ignored;
/// - an optional leading `export ` before the key is accepted and stripped;
/// - key and value are whitespace-trimmed, and a single pair of matching
///   surrounding quotes (`"…"` or `'…'`) is removed from the value;
/// - lines without `=`, with an empty key or value, or containing a NUL byte
///   are skipped, because `env::set_var` panics on an empty key or NUL.
///
/// Variables already present in the environment are overwritten.
///
/// # Errors
/// Returns an error if `.env` exists but cannot be read.
fn load_env_file() -> Result<(), Box<dyn std::error::Error>> {
    let env_path = Path::new(".env");
    if !env_path.exists() {
        return Ok(());
    }
    let content = fs::read_to_string(env_path)?;
    for (key, value) in content.lines().filter_map(parse_env_line) {
        // SAFETY: `set_var` is only sound while no other thread reads or writes
        // the process environment. parse_env_line has already rejected the
        // inputs that make it panic (empty key, NUL bytes).
        // NOTE(review): soundness assumes config loading runs before other
        // threads are spawned — confirm at the call sites.
        unsafe { env::set_var(key, value) };
    }
    Ok(())
}

/// Parses one `.env` line into `(key, value)`, or `None` if the line must be skipped.
fn parse_env_line(line: &str) -> Option<(&str, &str)> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }
    let (key, value) = line.split_once('=')?;
    let key = key.trim();
    let key = key.strip_prefix("export ").map_or(key, str::trim_start);
    let value = strip_matching_quotes(value.trim());
    if key.is_empty() || value.is_empty() || key.contains('\0') || value.contains('\0') {
        return None;
    }
    Some((key, value))
}

/// Removes one pair of identical surrounding quotes (`"` or `'`), leaving
/// unbalanced or inner quotes untouched.
fn strip_matching_quotes(value: &str) -> &str {
    for quote in ['"', '\''] {
        if let Some(inner) = value
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return inner;
        }
    }
    value
}
/// Replaces environment-variable placeholders in the raw config text.
///
/// Recognized forms, both with names made of `A-Z`, `0-9` and `_` that do not
/// start with a digit:
/// - `${ENV_VAR}` (Claude Desktop style);
/// - `<ENV_VAR>` (legacy style).
///
/// Both forms are resolved in a single pass, so text inserted from one
/// variable is never re-scanned for further placeholders. A placeholder whose
/// variable is unset (or not valid Unicode) is left as-is. Values are inserted
/// verbatim without JSON escaping, so a value containing `"` or `\` can break
/// the surrounding JSON string.
fn resolve_env_placeholders(content: &str) -> String {
    let re = Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}|<([A-Z_][A-Z0-9_]*)>").expect("invalid regex");
    re.replace_all(content, |caps: &regex::Captures| {
        // Group 1 is the `${...}` form, group 2 the `<...>` form; exactly one matches.
        let var_name = caps.get(1).or_else(|| caps.get(2)).map(|m| m.as_str());
        match var_name.and_then(|name| env::var(name).ok()) {
            Some(value) => value,
            None => caps[0].to_string(),
        }
    })
    .into_owned()
}
#[cfg(test)]
mod tests {
use super::*;
fn write_test_config() -> tempfile::NamedTempFile {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
},
"volcengine": {
"type": "openai",
"base_url": "https://example.invalid/volc",
"api_key": "test-key-2",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus",
"temperature": 0.0
},
"doubao-seed-2-0-lite-260215": {
"model_id": "doubao-seed-2-0-lite-260215"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"gateway": {
"host": "127.0.0.1",
"port": 19876,
"agent_prompt_reinject_every": 120
}
}"#,
)
.unwrap();
file
}
#[test]
fn test_config_load() {
let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
// Check providers
assert!(config.providers.contains_key("volcengine"));
assert!(config.providers.contains_key("aliyun"));
// Check models
assert!(config.models.contains_key("doubao-seed-2-0-lite-260215"));
assert!(config.models.contains_key("qwen-plus"));
// Check agents
assert!(config.agents.contains_key("default"));
}
#[test]
fn test_get_provider_config() {
let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap();
assert_eq!(provider_config.provider_type, "openai");
assert_eq!(provider_config.name, "aliyun");
assert_eq!(provider_config.model_id, "qwen-plus");
assert_eq!(provider_config.temperature, Some(0.0));
assert_eq!(provider_config.max_tokens, None);
assert_eq!(provider_config.llm_timeout_secs, 120);
assert_eq!(provider_config.tool_result_max_chars, 100_000);
assert_eq!(provider_config.context_tool_result_trim_chars, 2_000);
assert_eq!(provider_config.context_summary_char_budget(), 32_000);
}
#[test]
fn test_provider_config_loads_custom_llm_timeout() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {},
"llm_timeout_secs": 400
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap();
assert_eq!(provider_config.llm_timeout_secs, 400);
}
#[test]
fn test_default_skills_sources_include_agent_directories() {
let config = SkillsConfig::default();
assert_eq!(
config.sources,
vec![
"user".to_string(),
"user_agent".to_string(),
"user_openclaw".to_string(),
"project".to_string(),
"project_agent".to_string(),
"project_openclaw".to_string(),
]
);
}
#[test]
fn test_default_gateway_config() {
let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.gateway.host, "127.0.0.1");
assert_eq!(config.gateway.port, 19876);
assert!(!config.gateway.show_tool_results);
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
}
#[test]
fn test_gateway_config_defaults_agent_prompt_reinject_every() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert!(!config.gateway.show_tool_results);
assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
}
#[test]
fn test_config_loads_configured_timezone() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"time": {
"timezone": "Asia/Shanghai"
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.time.timezone, "Asia/Shanghai");
assert_eq!(
config.time.parse_timezone().unwrap(),
chrono_tz::Asia::Shanghai
);
}
#[test]
fn test_config_rejects_invalid_timezone() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"time": {
"timezone": "Mars/Base"
}
}"#,
)
.unwrap();
let error = Config::load(file.path().to_str().unwrap()).unwrap_err();
assert!(error.to_string().contains("Invalid timezone"));
}
#[test]
fn test_gateway_config_can_enable_tool_results() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"gateway": {
"show_tool_results": true
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert!(config.gateway.show_tool_results);
}
#[test]
fn test_agent_config_defaults_max_tool_iterations_to_100() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.agents["default"].max_tool_iterations, 100);
assert_eq!(config.agents["default"].tool_result_max_chars, 100_000);
assert_eq!(
config.agents["default"].context_tool_result_trim_chars,
2_000
);
}
/// Per-agent `tool_result_max_chars` / `context_tool_result_trim_chars`
/// overrides are kept on the agent and carried into the resolved provider
/// config. The summary budget stays at 32_000 because the model sets no
/// context window of its own.
#[test]
fn test_agent_config_loads_custom_truncation_limits() {
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus",
"tool_result_max_chars": 1234,
"context_tool_result_trim_chars": 3456
}
}
}"#;
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let agent = &loaded.agents["default"];
    let provider = loaded.get_provider_config("default").unwrap();

    assert_eq!(agent.tool_result_max_chars, 1234);
    assert_eq!(agent.context_tool_result_trim_chars, 3456);
    assert_eq!(provider.tool_result_max_chars, 1234);
    assert_eq!(provider.context_tool_result_trim_chars, 3456);
    assert_eq!(provider.context_summary_char_budget(), 32_000);
}
/// A model-level `context_window_tokens` of 4096 is exposed on the provider
/// config and shrinks the summary character budget to 1_500.
#[test]
fn test_provider_config_summary_budget_scales_with_context_window_tokens() {
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus",
"context_window_tokens": 4096
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#;
    std::fs::write(temp_file.path(), config_json).unwrap();

    let config_path = temp_file.path().to_str().unwrap();
    let provider = Config::load(config_path)
        .unwrap()
        .get_provider_config("default")
        .unwrap();
    assert_eq!(provider.context_window_tokens(), 4096);
    assert_eq!(provider.context_summary_char_budget(), 1_500);
}
/// A model's `max_tokens` is carried through unchanged but is not used as
/// the context window. The window stays at 128_000 tokens and the summary
/// budget at 32_000 chars.
#[test]
fn test_provider_config_max_tokens_does_not_change_context_window() {
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus",
"max_tokens": 4096
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#;
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let provider = loaded.get_provider_config("default").unwrap();
    assert_eq!(provider.max_tokens, Some(4096));
    assert_eq!(provider.context_window_tokens(), 128_000);
    assert_eq!(provider.context_summary_char_budget(), 32_000);
}
/// A `feishu` channel entry without a `type` tag and without explicit limits
/// gets 20_000-char defaults for both `max_message_chars` and
/// `reply_context_max_chars`.
#[test]
fn test_feishu_channel_config_defaults_truncation_limits() {
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"channels": {
"feishu": {
"enabled": true,
"app_id": "app-id",
"app_secret": "secret"
}
}
}"#;
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let feishu_cfg = loaded.channels["feishu"].as_feishu().unwrap();
    assert_eq!(feishu_cfg.max_message_chars, 20_000);
    assert_eq!(feishu_cfg.reply_context_max_chars, 20_000);
}
/// A channel entry with an explicit `"type": "feishu"` tag can use an
/// arbitrary key (`primary` here) and still loads as a Feishu channel.
#[test]
fn test_tagged_feishu_channel_config_loads() {
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"channels": {
"primary": {
"type": "feishu",
"enabled": true,
"app_id": "app-id",
"app_secret": "secret"
}
}
}"#;
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let channel = &loaded.channels["primary"];
    let feishu = channel.as_feishu().unwrap();
    assert_eq!(channel.kind(), "feishu");
    assert!(channel.enabled());
    assert_eq!(feishu.app_id, "app-id");
}
/// A `"type": "wechat"` channel entry loads as a WeChat channel with its
/// `cred_path`, `force_login` and `allow_from` values intact.
#[test]
fn test_tagged_wechat_channel_config_loads() {
    let file = tempfile::NamedTempFile::new().unwrap();
    // Use a real temporary path so the test works on every platform.
    let temp_dir = tempfile::tempdir().unwrap();
    let cred_path = temp_dir.path().join("wechat-creds.json");
    let cred_path_str = cred_path.display().to_string();
    // Encode the path as a complete JSON string literal, surrounding quotes
    // included. This escapes backslashes (Windows), quotes and control
    // characters correctly, whereas hand-escaping only backslashes would
    // produce invalid JSON for some temp-dir locations.
    let cred_path_json = serde_json::to_string(&cred_path_str).unwrap();
    std::fs::write(
        file.path(),
        r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"channels": {
"wechat_main": {
"type": "wechat",
"enabled": true,
"base_url": "https://ilinkai.weixin.qq.com",
"cred_path": "<CRED_PATH>",
"force_login": true,
"allow_from": ["wxid_1"]
}
}
}"#
        // Replace the quoted placeholder, because `cred_path_json` already
        // carries its own quotes.
        .replace("\"<CRED_PATH>\"", &cred_path_json),
    )
    .unwrap();
    let config = Config::load(file.path().to_str().unwrap()).unwrap();
    let wechat = config.channels["wechat_main"].as_wechat().unwrap();
    assert_eq!(config.channels["wechat_main"].kind(), "wechat");
    assert!(config.channels["wechat_main"].enabled());
    assert_eq!(wechat.cred_path, cred_path_str);
    assert!(wechat.force_login);
    assert_eq!(wechat.allow_from, vec!["wxid_1"]);
}
/// Explicit `max_message_chars` / `reply_context_max_chars` values on a
/// Feishu channel override the defaults.
#[test]
fn test_feishu_channel_config_loads_custom_truncation_limits() {
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"channels": {
"feishu": {
"enabled": true,
"app_id": "app-id",
"app_secret": "secret",
"max_message_chars": 3456,
"reply_context_max_chars": 4567
}
}
}"#;
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(temp_file.path(), config_json).unwrap();

    let config_path = temp_file.path().to_str().unwrap();
    let loaded = Config::load(config_path).unwrap();
    let feishu_cfg = loaded.channels["feishu"].as_feishu().unwrap();
    assert_eq!(feishu_cfg.max_message_chars, 3456);
    assert_eq!(feishu_cfg.reply_context_max_chars, 4567);
}
/// Checks scheduler defaults using the shared `write_test_config()` fixture,
/// which presumably has no `scheduler` section (verify against the fixture).
///
/// Expected defaults: enabled, 1_000 ms tick, queue capacity 64, `Skip`
/// misfire policy and no user jobs. The effective jobs are the two built-ins:
/// memory maintenance (`0 */4 * * *`) followed by session cleanup.
#[test]
fn test_scheduler_config_defaults() {
    let fixture = write_test_config();
    let config = Config::load(fixture.path().to_str().unwrap()).unwrap();

    let scheduler = &config.scheduler;
    assert!(scheduler.enabled);
    assert_eq!(scheduler.tick_resolution_ms, 1_000);
    assert_eq!(scheduler.worker_queue_capacity, 64);
    assert_eq!(scheduler.misfire_policy, SchedulerMisfirePolicy::Skip);
    assert!(scheduler.jobs.is_empty());

    let effective_jobs = config.scheduler.effective_jobs(&config.time);
    assert_eq!(effective_jobs.len(), 2);

    let memory_job = &effective_jobs[0];
    assert_eq!(memory_job.id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
    assert_eq!(memory_job.kind, SchedulerJobKind::InternalEvent);
    let expected_schedule = SchedulerSchedule::Cron {
        expression: "0 */4 * * *".to_string(),
    };
    assert_eq!(memory_job.resolved_schedule().unwrap(), expected_schedule);

    // The second built-in job is session cleanup.
    assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID);
}
/// A user job that reuses the built-in memory-maintenance id replaces that
/// built-in in place. The session-cleanup built-in keeps its defaults, and
/// custom jobs are listed after the built-ins.
#[test]
fn test_scheduler_effective_jobs_allows_builtin_override() {
    // Override of the built-in memory-maintenance job: disabled, and moved to
    // a daily 02:15 cron schedule.
    let builtin_override = SchedulerJobConfig {
        id: BUILTIN_MEMORY_MAINTENANCE_JOB_ID.to_string(),
        enabled: false,
        kind: SchedulerJobKind::InternalEvent,
        schedule: Some(SchedulerSchedule::Cron {
            expression: "15 2 * * *".to_string(),
        }),
        startup_delay_secs: 0,
        interval_secs: 0,
        target: SchedulerJobTarget::default(),
        payload: serde_json::json!({
            "event": "memory_maintenance",
            "time_zone": "UTC",
            "local_time": "02:15"
        }),
    };
    // A plain user-defined job with no built-in counterpart.
    let custom_reminder = SchedulerJobConfig {
        id: "custom.reminder".to_string(),
        enabled: true,
        kind: SchedulerJobKind::InternalEvent,
        schedule: Some(SchedulerSchedule::Delay { seconds: 30 }),
        startup_delay_secs: 0,
        interval_secs: 0,
        target: SchedulerJobTarget::default(),
        payload: serde_json::json!({"event": "custom"}),
    };

    let mut scheduler = SchedulerConfig::default();
    scheduler.jobs.extend([builtin_override, custom_reminder]);

    let shanghai = TimeConfig {
        timezone: "Asia/Shanghai".to_string(),
    };
    let effective_jobs = scheduler.effective_jobs(&shanghai);
    // Two built-ins plus one custom job.
    assert_eq!(effective_jobs.len(), 3);

    // Job 1: memory maintenance, overridden to be disabled.
    let memory_job = &effective_jobs[0];
    assert_eq!(memory_job.id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
    assert!(!memory_job.enabled);
    let overridden_schedule = SchedulerSchedule::Cron {
        expression: "15 2 * * *".to_string(),
    };
    assert_eq!(memory_job.resolved_schedule().unwrap(), overridden_schedule);

    // Job 2: session cleanup, left at its defaults.
    let cleanup_job = &effective_jobs[1];
    assert_eq!(cleanup_job.id, BUILTIN_SESSION_CLEANUP_JOB_ID);
    assert!(cleanup_job.enabled);

    // Job 3: the custom reminder.
    assert_eq!(effective_jobs[2].id, "custom.reminder");
}
/// Scheduler-level settings load from the config file. A job that uses the
/// compat `interval_secs` / `startup_delay_secs` fields (no `schedule`
/// object) resolves to `SchedulerSchedule::Interval`, and its `enabled` flag
/// defaults to true.
#[test]
fn test_scheduler_config_loads_interval_compat_jobs() {
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"scheduler": {
"enabled": true,
"tick_resolution_ms": 500,
"worker_queue_capacity": 8,
"misfire_policy": "catch_up",
"jobs": [
{
"id": "heartbeat.reminder",
"kind": "outbound_message",
"interval_secs": 60,
"startup_delay_secs": 5,
"target": {
"channel": "feishu",
"chat_id": "oc_demo"
},
"payload": {
"content": "heartbeat"
}
}
]
}
}"#;
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let scheduler = &loaded.scheduler;
    assert!(scheduler.enabled);
    assert_eq!(scheduler.tick_resolution_ms, 500);
    assert_eq!(scheduler.worker_queue_capacity, 8);
    assert_eq!(scheduler.misfire_policy, SchedulerMisfirePolicy::CatchUp);
    assert_eq!(scheduler.jobs.len(), 1);

    let job = &scheduler.jobs[0];
    assert_eq!(job.id, "heartbeat.reminder");
    assert!(job.enabled);
    assert_eq!(job.kind, SchedulerJobKind::OutboundMessage);
    assert_eq!(job.interval_secs, 60);
    assert_eq!(job.startup_delay_secs, 5);

    let target = &job.target;
    assert_eq!(target.channel.as_deref(), Some("feishu"));
    assert_eq!(target.chat_id.as_deref(), Some("oc_demo"));

    let content = job.payload.get("content").and_then(|v| v.as_str());
    assert_eq!(content, Some("heartbeat"));

    let expected_schedule = SchedulerSchedule::Interval {
        seconds: 60,
        startup_delay_secs: 5,
    };
    assert_eq!(job.resolved_schedule().unwrap(), expected_schedule);
}
/// Each `schedule.type` (`delay`, `at`, `cron`) deserializes and resolves to
/// the matching `SchedulerSchedule` variant, and each job keeps its `kind`.
/// The `at` timestamp comes back in `+00:00` offset form.
#[test]
fn test_scheduler_config_loads_schedule_variants() {
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"scheduler": {
"enabled": true,
"jobs": [
{
"id": "delay.job",
"kind": "internal_event",
"schedule": {
"type": "delay",
"seconds": 30
}
},
{
"id": "at.job",
"kind": "outbound_message",
"schedule": {
"type": "at",
"timestamp": "2026-04-23T09:00:00Z"
},
"target": {
"channel": "feishu",
"chat_id": "oc_demo"
},
"payload": {
"content": "at run"
}
},
{
"id": "cron.job",
"kind": "internal_event",
"schedule": {
"type": "cron",
"expression": "0 9 * * *"
}
}
]
}
}"#;
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let jobs = &loaded.scheduler.jobs;
    assert_eq!(jobs.len(), 3);

    // (expected resolved schedule, expected kind), in job order.
    let expectations = [
        (
            SchedulerSchedule::Delay { seconds: 30 },
            SchedulerJobKind::InternalEvent,
        ),
        (
            SchedulerSchedule::At {
                timestamp: "2026-04-23T09:00:00+00:00".to_string(),
            },
            SchedulerJobKind::OutboundMessage,
        ),
        (
            SchedulerSchedule::Cron {
                expression: "0 9 * * *".to_string(),
            },
            SchedulerJobKind::InternalEvent,
        ),
    ];
    for (job, (expected_schedule, expected_kind)) in jobs.iter().zip(expectations) {
        assert_eq!(job.resolved_schedule().unwrap(), expected_schedule);
        assert_eq!(job.kind, expected_kind);
    }
}
/// An `agent_task` job loads with its delivery target and its `prompt`
/// payload preserved verbatim, including non-ASCII text.
#[test]
fn test_scheduler_config_loads_agent_task_job() {
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"scheduler": {
"enabled": true,
"jobs": [
{
"id": "agent.daily_summary",
"kind": "agent_task",
"schedule": {
"type": "cron",
"expression": "0 9 * * *"
},
"target": {
"channel": "feishu",
"chat_id": "oc_demo"
},
"payload": {
"prompt": "请总结今天待办"
}
}
]
}
}"#;
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let job = &loaded.scheduler.jobs[0];
    let target = &job.target;
    assert_eq!(job.kind, SchedulerJobKind::AgentTask);
    assert_eq!(target.channel.as_deref(), Some("feishu"));
    assert_eq!(target.chat_id.as_deref(), Some("oc_demo"));

    let prompt = job.payload.get("prompt").and_then(|v| v.as_str());
    assert_eq!(prompt, Some("请总结今天待办"));
}
/// A `silent_agent_task` job loads with its delivery target, including the
/// optional `session_chat_id`, and its `prompt` payload.
#[test]
fn test_scheduler_config_loads_silent_agent_task_job() {
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"scheduler": {
"enabled": true,
"jobs": [
{
"id": "agent.daily_summary.background",
"kind": "silent_agent_task",
"schedule": {
"type": "cron",
"expression": "0 9 * * *"
},
"target": {
"channel": "feishu",
"chat_id": "oc_demo",
"session_chat_id": "scheduler/agent.daily_summary.background"
},
"payload": {
"prompt": "请后台总结今天待办"
}
}
]
}
}"#;
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let job = &loaded.scheduler.jobs[0];
    let target = &job.target;
    assert_eq!(job.kind, SchedulerJobKind::SilentAgentTask);
    assert_eq!(target.channel.as_deref(), Some("feishu"));
    assert_eq!(target.chat_id.as_deref(), Some("oc_demo"));
    assert_eq!(
        target.session_chat_id.as_deref(),
        Some("scheduler/agent.daily_summary.background")
    );

    let prompt = job.payload.get("prompt").and_then(|v| v.as_str());
    assert_eq!(prompt, Some("请后台总结今天待办"));
}
/// `SchedulerSchedule::validate` rejects a zero delay, a zero interval, an
/// unparsable `at` timestamp and an unparsable cron expression.
#[test]
fn test_scheduler_schedule_validation_rejects_invalid_values() {
    // Each invalid schedule is paired with the job id passed to `validate`.
    let invalid_cases = [
        (SchedulerSchedule::Delay { seconds: 0 }, "delay.job"),
        (
            SchedulerSchedule::Interval {
                seconds: 0,
                startup_delay_secs: 0,
            },
            "interval.job",
        ),
        (
            SchedulerSchedule::At {
                timestamp: "bad timestamp".to_string(),
            },
            "at.job",
        ),
        (
            SchedulerSchedule::Cron {
                expression: "bad cron".to_string(),
            },
            "cron.job",
        ),
    ];
    for (schedule, job_id) in invalid_cases {
        assert!(schedule.validate(job_id).is_err());
    }
}
/// `${VAR}` placeholders (Claude Desktop style) are replaced with the
/// variable's value. A placeholder naming an unset variable is left as-is.
#[test]
fn test_resolve_env_placeholders_brace_syntax() {
    /// Removes the listed environment variables when dropped. Cleanup then
    /// runs even if an assertion below panics, so no process-global state
    /// leaks into other tests.
    struct EnvVarGuard(&'static [&'static str]);
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            for &key in self.0 {
                // SAFETY: mutating the environment is unsafe because it can race
                // with concurrent environment reads on other threads (the test
                // harness runs tests in parallel). The key is specific to this
                // test.
                unsafe { env::remove_var(key) };
            }
        }
    }

    let _cleanup = EnvVarGuard(&["TEST_API_KEY"]);
    // SAFETY: same reasoning as in `EnvVarGuard::drop`; this variable name is
    // only used by this test.
    unsafe { env::set_var("TEST_API_KEY", "my-secret-key") };

    let content = r#"{"api_key": "${TEST_API_KEY}", "other": "${MISSING_VAR}"}"#;
    let resolved = resolve_env_placeholders(content);
    assert!(resolved.contains("my-secret-key"));
    // Unresolved placeholders stay as-is.
    assert!(resolved.contains("${MISSING_VAR}"));
}
/// `<VAR>` placeholders (legacy style) are replaced with the variable's
/// value. A placeholder naming an unset variable is left as-is.
#[test]
fn test_resolve_env_placeholders_angle_syntax() {
    /// Removes the listed environment variables when dropped. Cleanup then
    /// runs even if an assertion below panics, so no process-global state
    /// leaks into other tests.
    struct EnvVarGuard(&'static [&'static str]);
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            for &key in self.0 {
                // SAFETY: mutating the environment is unsafe because it can race
                // with concurrent environment reads on other threads (the test
                // harness runs tests in parallel). The key is specific to this
                // test.
                unsafe { env::remove_var(key) };
            }
        }
    }

    let _cleanup = EnvVarGuard(&["LEGACY_KEY"]);
    // SAFETY: same reasoning as in `EnvVarGuard::drop`; this variable name is
    // only used by this test.
    unsafe { env::set_var("LEGACY_KEY", "legacy-value") };

    let content = r#"{"api_key": "<LEGACY_KEY>", "other": "<MISSING>"}"#;
    let resolved = resolve_env_placeholders(content);
    assert!(resolved.contains("legacy-value"));
    // Unresolved placeholders stay as-is.
    assert!(resolved.contains("<MISSING>"));
}
/// `${VAR}` and `<VAR>` placeholders can appear in the same content, and
/// both are resolved.
#[test]
fn test_resolve_env_placeholders_mixed_syntax() {
    /// Removes the listed environment variables when dropped. Cleanup then
    /// runs even if an assertion below panics, so no process-global state
    /// leaks into other tests.
    struct EnvVarGuard(&'static [&'static str]);
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            for &key in self.0 {
                // SAFETY: mutating the environment is unsafe because it can race
                // with concurrent environment reads on other threads (the test
                // harness runs tests in parallel). The keys are specific to this
                // test.
                unsafe { env::remove_var(key) };
            }
        }
    }

    let _cleanup = EnvVarGuard(&["BRACE_VAR", "ANGLE_VAR"]);
    // SAFETY: same reasoning as in `EnvVarGuard::drop`; these variable names
    // are only used by this test.
    unsafe { env::set_var("BRACE_VAR", "brace-value") };
    // SAFETY: as above.
    unsafe { env::set_var("ANGLE_VAR", "angle-value") };

    let content = r#"{"brace": "${BRACE_VAR}", "angle": "<ANGLE_VAR>"}"#;
    let resolved = resolve_env_placeholders(content);
    assert!(resolved.contains("brace-value"));
    assert!(resolved.contains("angle-value"));
}
/// Every server listed under the root-level `mcpServers` key (here a
/// `streamableHttp` entry and a `stdio` entry) is loaded into
/// `Config::mcp_servers`.
#[test]
fn test_root_level_mcp_servers_merging() {
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"mcpServers": {
"WebSearch": {
"type": "streamableHttp",
"baseUrl": "https://api.example.com/mcp",
"isActive": true
},
"filesystem": {
"type": "stdio",
"command": "npx",
"isActive": true
}
}
}"#;
    std::fs::write(temp_file.path(), config_json).unwrap();

    let loaded = Config::load(temp_file.path().to_str().unwrap()).unwrap();
    let servers = &loaded.mcp_servers;
    // Both declared servers must be present.
    assert_eq!(servers.len(), 2);
    for name in ["WebSearch", "filesystem"] {
        assert!(servers.contains_key(name));
    }
}
/// A single server declared under the root-level `mcpServers` key is the
/// only entry in `Config::mcp_servers`.
#[test]
fn test_root_level_mcp_servers_only() {
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
},
"mcpServers": {
"WebSearch": {
"type": "streamableHttp",
"baseUrl": "https://api.example.com/mcp",
"isActive": true
}
}
}"#;
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(temp_file.path(), config_json).unwrap();

    let config_path = temp_file.path().to_str().unwrap();
    let servers = Config::load(config_path).unwrap().mcp_servers;
    // Exactly one server, taken from the root level.
    assert_eq!(servers.len(), 1);
    assert!(servers.contains_key("WebSearch"));
}
}