use regex::Regex; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::env; use std::fs; use std::path::{Path, PathBuf}; #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { pub providers: HashMap, pub models: HashMap, pub agents: HashMap, #[serde(default)] pub gateway: GatewayConfig, #[serde(default)] pub client: ClientConfig, #[serde(default)] pub channels: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FeishuChannelConfig { #[serde(default)] pub enabled: bool, pub app_id: String, pub app_secret: String, #[serde(default = "default_allow_from")] pub allow_from: Vec, #[serde(default)] pub agent: String, } fn default_allow_from() -> Vec { vec!["*".to_string()] } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ProviderConfig { #[serde(rename = "type")] pub provider_type: String, pub base_url: String, pub api_key: String, #[serde(default)] pub extra_headers: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ModelConfig { pub model_id: String, #[serde(default)] pub temperature: Option, #[serde(default)] pub max_tokens: Option, #[serde(flatten)] pub extra: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AgentConfig { pub provider: String, pub model: String, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct GatewayConfig { #[serde(default = "default_gateway_host")] pub host: String, #[serde(default = "default_gateway_port")] pub port: u16, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ClientConfig { #[serde(default = "default_gateway_url")] pub gateway_url: String, } fn default_gateway_host() -> String { "127.0.0.1".to_string() } fn default_gateway_port() -> u16 { 19876 } fn default_gateway_url() -> String { "ws://127.0.0.1:19876/ws".to_string() } impl Default for GatewayConfig { fn default() -> Self { Self { host: default_gateway_host(), port: default_gateway_port(), } } } impl Default for ClientConfig { fn default() -> Self { Self { gateway_url: default_gateway_url(), } } } #[derive(Debug, Clone)] pub struct LLMProviderConfig { pub provider_type: String, pub name: String, pub base_url: String, pub api_key: String, pub extra_headers: HashMap, pub model_id: String, pub temperature: Option, pub max_tokens: Option, pub model_extra: HashMap, } fn get_default_config_path() -> PathBuf { let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); home.join(".config").join("picobot").join("config.json") } impl Config { pub fn load(path: &str) -> Result> { Self::load_from(Path::new(path)) } pub fn load_default() -> Result> { let path = get_default_config_path(); Self::load_from(&path) } fn load_from(path: &Path) -> Result> { load_env_file()?; let content = if path.exists() { println!("Config loaded from: {}", path.display()); fs::read_to_string(path)? } else { // Fallback to current directory let fallback = Path::new("config.json"); if fallback.exists() { println!("Config loaded from: {}", fallback.display()); fs::read_to_string(fallback)? } else { return Err(Box::new(ConfigError::ConfigNotFound( path.to_string_lossy().to_string(), ))); } }; let content = resolve_env_placeholders(&content); let config: Config = serde_json::from_str(&content)?; Ok(config) } pub fn get_provider_config(&self, agent_name: &str) -> Result { let agent = self.agents.get(agent_name) .ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?; let provider = self.providers.get(&agent.provider) .ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?; let model = self.models.get(&agent.model) .ok_or(ConfigError::ModelNotFound(agent.model.clone()))?; Ok(LLMProviderConfig { provider_type: provider.provider_type.clone(), name: agent.provider.clone(), base_url: provider.base_url.clone(), api_key: provider.api_key.clone(), extra_headers: provider.extra_headers.clone(), model_id: model.model_id.clone(), temperature: model.temperature, max_tokens: model.max_tokens, model_extra: model.extra.clone(), }) } } #[derive(Debug)] pub enum ConfigError { ConfigNotFound(String), AgentNotFound(String), ProviderNotFound(String), ModelNotFound(String), } impl std::fmt::Display for ConfigError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.config/picobot/config.json", path), ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), } } } impl std::error::Error for ConfigError {} fn load_env_file() -> Result<(), Box> { let env_path = Path::new(".env"); if env_path.exists() { let content = fs::read_to_string(env_path)?; for line in content.lines() { let line = line.trim(); if line.is_empty() || line.starts_with('#') { continue; } if let Some((key, value)) = line.split_once('=') { let key = key.trim(); let value = value.trim().trim_matches('"').trim_matches('\''); if !value.is_empty() { // SAFETY: Setting environment variables for the current process // is safe as we're only modifying our own process state unsafe { env::set_var(key, value) }; } } } } Ok(()) } fn resolve_env_placeholders(content: &str) -> String { let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex"); re.replace_all(content, |caps: ®ex::Captures| { let var_name = &caps[1]; env::var(var_name).unwrap_or_else(|_| caps[0].to_string()) }).to_string() } #[cfg(test)] mod tests { use super::*; #[test] fn test_config_load() { let config = Config::load("config.json").unwrap(); // Check providers assert!(config.providers.contains_key("volcengine")); assert!(config.providers.contains_key("aliyun")); // Check models assert!(config.models.contains_key("doubao-seed-2-0-lite-260215")); assert!(config.models.contains_key("qwen-plus")); // Check agents assert!(config.agents.contains_key("default")); } #[test] fn test_get_provider_config() { let config = Config::load("config.json").unwrap(); let provider_config = config.get_provider_config("default").unwrap(); assert_eq!(provider_config.provider_type, "openai"); assert_eq!(provider_config.name, "aliyun"); assert_eq!(provider_config.model_id, "qwen-plus"); assert_eq!(provider_config.temperature, Some(0.0)); } #[test] fn test_default_gateway_config() { let config = Config::load("config.json").unwrap(); assert_eq!(config.gateway.host, "0.0.0.0"); assert_eq!(config.gateway.port, 19876); } }