use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { pub providers: HashMap, pub models: HashMap, pub agents: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ProviderConfig { #[serde(rename = "type")] pub provider_type: String, pub base_url: String, pub api_key: String, #[serde(default)] pub extra_headers: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ModelConfig { pub model_id: String, #[serde(default)] pub temperature: Option, #[serde(default)] pub max_tokens: Option, #[serde(flatten)] pub extra: HashMap, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AgentConfig { pub provider: String, pub model: String, } #[derive(Debug, Clone)] pub struct LLMProviderConfig { pub provider_type: String, pub name: String, pub base_url: String, pub api_key: String, pub extra_headers: HashMap, pub model_id: String, pub temperature: Option, pub max_tokens: Option, pub model_extra: HashMap, } impl Config { pub fn load(path: &str) -> Result> { let content = fs::read_to_string(path)?; let config: Config = serde_json::from_str(&content)?; Ok(config) } pub fn get_provider_config(&self, agent_name: &str) -> Result { let agent = self.agents.get(agent_name) .ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?; let provider = self.providers.get(&agent.provider) .ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?; let model = self.models.get(&agent.model) .ok_or(ConfigError::ModelNotFound(agent.model.clone()))?; Ok(LLMProviderConfig { provider_type: provider.provider_type.clone(), name: agent.provider.clone(), base_url: provider.base_url.clone(), api_key: provider.api_key.clone(), extra_headers: provider.extra_headers.clone(), model_id: model.model_id.clone(), temperature: model.temperature, max_tokens: model.max_tokens, model_extra: model.extra.clone(), }) } } #[derive(Debug)] pub enum ConfigError { AgentNotFound(String), ProviderNotFound(String), ModelNotFound(String), } impl std::fmt::Display for ConfigError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), } } } impl std::error::Error for ConfigError {} #[cfg(test)] mod tests { use super::*; #[test] fn test_config_load() { let config = Config::load("config.json").unwrap(); // Check providers assert!(config.providers.contains_key("volcengine")); assert!(config.providers.contains_key("aliyun")); // Check models assert!(config.models.contains_key("doubao-seed-2-0-lite-260215")); assert!(config.models.contains_key("qwen-plus")); // Check agents assert!(config.agents.contains_key("default")); } #[test] fn test_get_provider_config() { let config = Config::load("config.json").unwrap(); let provider_config = config.get_provider_config("default").unwrap(); assert_eq!(provider_config.provider_type, "openai"); assert_eq!(provider_config.name, "aliyun"); assert_eq!(provider_config.model_id, "qwen-plus"); assert_eq!(provider_config.temperature, Some(0.0)); } }