use regex::Regex; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::env; use std::fs; use std::path::{Path, PathBuf}; /// Get the user configuration directory (~/.picobot) pub fn get_user_config_dir() -> PathBuf { dirs::home_dir() .unwrap_or_else(|| PathBuf::from(".")) .join(".picobot") } /// Get the default workspace directory (~/.picobot/workspace) pub fn get_default_workspace_dir() -> PathBuf { get_user_config_dir().join("workspace") } /// Expand ~ in path to user home directory pub fn expand_path(path: &str) -> PathBuf { if path.starts_with("~/") { dirs::home_dir() .unwrap_or_else(|| PathBuf::from(".")) .join(&path[2..]) } else { PathBuf::from(path) } } /// Ensure workspace directory exists, create if needed pub fn ensure_workspace_dir(path: &Path) -> Result { if !path.exists() { tracing::info!("Creating workspace directory: {}", path.display()); fs::create_dir_all(path)?; } // Return canonical path path.canonicalize().or_else(|_| Ok(path.to_path_buf())) } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { pub providers: HashMap, pub models: HashMap, pub agents: HashMap, #[serde(default)] pub gateway: GatewayConfig, #[serde(default)] pub client: ClientConfig, #[serde(default)] pub channels: HashMap, #[serde(default = "default_workspace_dir")] pub workspace_dir: String, } fn default_workspace_dir() -> String { get_default_workspace_dir().to_string_lossy().to_string() } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FeishuChannelConfig { #[serde(default)] pub enabled: bool, pub app_id: String, pub app_secret: String, #[serde(default = "default_allow_from")] pub allow_from: Vec, #[serde(default)] pub agent: String, #[serde(default = "default_media_dir")] pub media_dir: String, /// Emoji type for message reactions (e.g. "THUMBSUP", "OK", "EYES"). #[serde(default = "default_reaction_emoji")] pub reaction_emoji: String, } fn default_allow_from() -> Vec { vec!["*".to_string()] } fn default_media_dir() -> String { get_user_config_dir() .join("media/feishu") .to_string_lossy() .to_string() } fn default_reaction_emoji() -> String { "Typing".to_string() } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ProviderConfig { #[serde(rename = "type")] pub provider_type: String, pub base_url: String, pub api_key: String, #[serde(default)] pub extra_headers: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ModelConfig { pub model_id: String, #[serde(default)] pub temperature: Option, #[serde(default)] pub max_tokens: Option, #[serde(flatten)] pub extra: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AgentConfig { pub provider: String, pub model: String, #[serde(default = "default_max_tool_iterations")] pub max_tool_iterations: usize, #[serde(default = "default_token_limit")] pub token_limit: usize, } fn default_max_tool_iterations() -> usize { 20 } fn default_token_limit() -> usize { 128_000 } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct GatewayConfig { #[serde(default = "default_gateway_host")] pub host: String, #[serde(default = "default_gateway_port")] pub port: u16, #[serde(default, rename = "session_ttl_hours")] pub session_ttl_hours: Option, #[serde(default, rename = "cleanup_interval_minutes")] pub cleanup_interval_minutes: Option, #[serde(default, rename = "session_db_path")] pub session_db_path: Option, #[serde(default)] pub scheduler: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SchedulerConfig { /// Whether the scheduler is enabled #[serde(default = "default_scheduler_enabled")] pub enabled: bool, /// Poll interval in seconds (how often to check for due jobs) #[serde(default = "default_poll_interval_secs")] pub poll_interval_secs: u64, /// Maximum concurrent job executions (currently sequential, reserved for future) #[serde(default = "default_max_concurrent")] pub max_concurrent: usize, } fn default_scheduler_enabled() -> bool { true } fn default_poll_interval_secs() -> u64 { 60 } fn default_max_concurrent() -> usize { 1 } impl Default for SchedulerConfig { fn default() -> Self { Self { enabled: true, poll_interval_secs: 60, max_concurrent: 1, } } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ClientConfig { #[serde(default = "default_gateway_url")] pub gateway_url: String, } fn default_gateway_host() -> String { "127.0.0.1".to_string() } fn default_gateway_port() -> u16 { 19876 } fn default_gateway_url() -> String { "ws://127.0.0.1:19876/ws".to_string() } impl Default for GatewayConfig { fn default() -> Self { Self { host: default_gateway_host(), port: default_gateway_port(), session_ttl_hours: None, cleanup_interval_minutes: None, session_db_path: None, scheduler: None, } } } impl Default for ClientConfig { fn default() -> Self { Self { gateway_url: default_gateway_url(), } } } #[derive(Debug, Clone)] pub struct LLMProviderConfig { pub provider_type: String, pub name: String, pub base_url: String, pub api_key: String, pub extra_headers: HashMap, pub model_id: String, pub temperature: Option, pub max_tokens: Option, pub model_extra: HashMap, pub max_tool_iterations: usize, pub token_limit: usize, pub workspace_dir: PathBuf, } fn get_default_config_path() -> PathBuf { get_user_config_dir().join("config.json") } impl Config { pub fn load(path: &str) -> Result> { Self::load_from(Path::new(path)) } pub fn load_default() -> Result> { let path = get_default_config_path(); Self::load_from(&path) } fn load_from(path: &Path) -> Result> { load_env_file()?; let content = if path.exists() { tracing::info!(path = %path.display(), "Config loaded"); fs::read_to_string(path)? } else { // Fallback to current directory let fallback = Path::new("config.json"); if fallback.exists() { tracing::info!(path = %fallback.display(), "Config loaded from fallback path"); fs::read_to_string(fallback)? } else { return Err(Box::new(ConfigError::ConfigNotFound( path.to_string_lossy().to_string(), ))); } }; let content = resolve_env_placeholders(&content); let config: Config = serde_json::from_str(&content)?; Ok(config) } pub fn get_provider_config(&self, agent_name: &str) -> Result { let agent = self .agents .get(agent_name) .ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?; let provider = self .providers .get(&agent.provider) .ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?; let model = self .models .get(&agent.model) .ok_or(ConfigError::ModelNotFound(agent.model.clone()))?; Ok(LLMProviderConfig { provider_type: provider.provider_type.clone(), name: agent.provider.clone(), base_url: provider.base_url.clone(), api_key: provider.api_key.clone(), extra_headers: provider.extra_headers.clone(), model_id: model.model_id.clone(), temperature: model.temperature, max_tokens: model.max_tokens, model_extra: model.extra.clone(), max_tool_iterations: agent.max_tool_iterations, token_limit: agent.token_limit, workspace_dir: expand_path(&self.workspace_dir), }) } } #[derive(Debug)] pub enum ConfigError { ConfigNotFound(String), AgentNotFound(String), ProviderNotFound(String), ModelNotFound(String), } impl std::fmt::Display for ConfigError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path), ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), } } } impl std::error::Error for ConfigError {} fn load_env_file() -> Result<(), Box> { let env_path = Path::new(".env"); if env_path.exists() { let content = fs::read_to_string(env_path)?; for line in content.lines() { let line = line.trim(); if line.is_empty() || line.starts_with('#') { continue; } if let Some((key, value)) = line.split_once('=') { let key = key.trim(); let value = value.trim().trim_matches('"').trim_matches('\''); if !value.is_empty() { // SAFETY: Setting environment variables for the current process // is safe as we're only modifying our own process state unsafe { env::set_var(key, value) }; } } } } Ok(()) } fn resolve_env_placeholders(content: &str) -> String { let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex"); re.replace_all(content, |caps: ®ex::Captures| { let var_name = &caps[1]; env::var(var_name).unwrap_or_else(|_| caps[0].to_string()) }) .to_string() } #[cfg(test)] mod tests { use super::*; fn write_test_config() -> tempfile::NamedTempFile { let file = tempfile::NamedTempFile::new().unwrap(); std::fs::write( file.path(), r#"{ "providers": { "aliyun": { "type": "openai", "base_url": "https://example.invalid/v1", "api_key": "test-key", "extra_headers": {} }, "volcengine": { "type": "openai", "base_url": "https://example.invalid/volc", "api_key": "test-key-2", "extra_headers": {} } }, "models": { "qwen-plus": { "model_id": "qwen-plus", "temperature": 0.0 }, "doubao-seed-2-0-lite-260215": { "model_id": "doubao-seed-2-0-lite-260215" } }, "agents": { "default": { "provider": "aliyun", "model": "qwen-plus" } }, "gateway": { "host": "0.0.0.0", "port": 19876 } }"#, ) .unwrap(); file } #[test] fn test_config_load() { let file = write_test_config(); let config = Config::load(file.path().to_str().unwrap()).unwrap(); // Check providers assert!(config.providers.contains_key("volcengine")); assert!(config.providers.contains_key("aliyun")); // Check models assert!(config.models.contains_key("doubao-seed-2-0-lite-260215")); assert!(config.models.contains_key("qwen-plus")); // Check agents assert!(config.agents.contains_key("default")); } #[test] fn test_get_provider_config() { let file = write_test_config(); let config = Config::load(file.path().to_str().unwrap()).unwrap(); let provider_config = config.get_provider_config("default").unwrap(); assert_eq!(provider_config.provider_type, "openai"); assert_eq!(provider_config.name, "aliyun"); assert_eq!(provider_config.model_id, "qwen-plus"); assert_eq!(provider_config.temperature, Some(0.0)); } #[test] fn test_default_gateway_config() { let file = write_test_config(); let config = Config::load(file.path().to_str().unwrap()).unwrap(); assert_eq!(config.gateway.host, "0.0.0.0"); assert_eq!(config.gateway.port, 19876); } }