133 lines
4.0 KiB
Rust
133 lines
4.0 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::fs;
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct Config {
|
|
pub providers: HashMap<String, ProviderConfig>,
|
|
pub models: HashMap<String, ModelConfig>,
|
|
pub agents: HashMap<String, AgentConfig>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ProviderConfig {
|
|
#[serde(rename = "type")]
|
|
pub provider_type: String,
|
|
pub base_url: String,
|
|
pub api_key: String,
|
|
#[serde(default)]
|
|
pub extra_headers: HashMap<String, String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ModelConfig {
|
|
pub model_id: String,
|
|
#[serde(default)]
|
|
pub temperature: Option<f32>,
|
|
#[serde(default)]
|
|
pub max_tokens: Option<u32>,
|
|
#[serde(flatten)]
|
|
pub extra: HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct AgentConfig {
|
|
pub provider: String,
|
|
pub model: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct LLMProviderConfig {
|
|
pub provider_type: String,
|
|
pub name: String,
|
|
pub base_url: String,
|
|
pub api_key: String,
|
|
pub extra_headers: HashMap<String, String>,
|
|
pub model_id: String,
|
|
pub temperature: Option<f32>,
|
|
pub max_tokens: Option<u32>,
|
|
pub model_extra: HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
impl Config {
|
|
pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
|
let content = fs::read_to_string(path)?;
|
|
let config: Config = serde_json::from_str(&content)?;
|
|
Ok(config)
|
|
}
|
|
|
|
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
|
|
let agent = self.agents.get(agent_name)
|
|
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
|
|
|
|
let provider = self.providers.get(&agent.provider)
|
|
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
|
|
|
|
let model = self.models.get(&agent.model)
|
|
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
|
|
|
|
Ok(LLMProviderConfig {
|
|
provider_type: provider.provider_type.clone(),
|
|
name: agent.provider.clone(),
|
|
base_url: provider.base_url.clone(),
|
|
api_key: provider.api_key.clone(),
|
|
extra_headers: provider.extra_headers.clone(),
|
|
model_id: model.model_id.clone(),
|
|
temperature: model.temperature,
|
|
max_tokens: model.max_tokens,
|
|
model_extra: model.extra.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Failure to resolve a name while building a provider configuration.
///
/// Each variant carries the name that could not be found. The type is
/// `Clone` and comparable so callers and tests can match on exact errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// No agent with this name is defined.
    AgentNotFound(String),
    /// An agent refers to a provider name that is not defined.
    ProviderNotFound(String),
    /// An agent refers to a model name that is not defined.
    ModelNotFound(String),
}

impl std::fmt::Display for ConfigError {
    /// Writes a one-line message of the form `"<Kind> not found: <name>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
            Self::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
            Self::ModelNotFound(name) => write!(f, "Model not found: {}", name),
        }
    }
}

// No underlying source error: the default `Error` methods are sufficient.
impl std::error::Error for ConfigError {}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Self-contained configuration for the hermetic tests. It has one valid
    /// agent plus two agents that point at an undefined provider / model.
    const FIXTURE: &str = r#"{
        "providers": {
            "aliyun": {
                "type": "openai",
                "base_url": "https://example.invalid/v1",
                "api_key": "test-key"
            }
        },
        "models": {
            "qwen-plus": { "model_id": "qwen-plus", "temperature": 0.0, "top_p": 0.9 }
        },
        "agents": {
            "default": { "provider": "aliyun", "model": "qwen-plus" },
            "no_provider": { "provider": "nope", "model": "qwen-plus" },
            "no_model": { "provider": "aliyun", "model": "nope" }
        }
    }"#;

    /// Parses `FIXTURE` into a `Config`.
    fn fixture() -> Config {
        serde_json::from_str(FIXTURE).expect("FIXTURE must be valid configuration JSON")
    }

    // The two tests below read `config.json` from the working directory and
    // check the entries that file is expected to contain.

    #[test]
    fn test_config_load() {
        let config = Config::load("config.json")
            .expect("config.json must exist in the working directory and parse as a Config");

        // Check providers
        assert!(config.providers.contains_key("volcengine"));
        assert!(config.providers.contains_key("aliyun"));

        // Check models
        assert!(config.models.contains_key("doubao-seed-2-0-lite-260215"));
        assert!(config.models.contains_key("qwen-plus"));

        // Check agents
        assert!(config.agents.contains_key("default"));
    }

    #[test]
    fn test_get_provider_config() {
        let config = Config::load("config.json")
            .expect("config.json must exist in the working directory and parse as a Config");
        let provider_config = config
            .get_provider_config("default")
            .expect("agent \"default\" must resolve to a provider and a model");

        assert_eq!(provider_config.provider_type, "openai");
        assert_eq!(provider_config.name, "aliyun");
        assert_eq!(provider_config.model_id, "qwen-plus");
        assert_eq!(provider_config.temperature, Some(0.0));
    }

    // Hermetic tests: no file system access, only the inline fixture.

    #[test]
    fn test_fixture_resolves_default_agent() {
        let resolved = fixture()
            .get_provider_config("default")
            .expect("agent \"default\" is fully defined in FIXTURE");

        assert_eq!(resolved.provider_type, "openai");
        assert_eq!(resolved.name, "aliyun");
        assert_eq!(resolved.base_url, "https://example.invalid/v1");
        assert_eq!(resolved.api_key, "test-key");
        assert_eq!(resolved.model_id, "qwen-plus");
        assert_eq!(resolved.temperature, Some(0.0));
        // Keys omitted from the fixture fall back to their defaults.
        assert_eq!(resolved.max_tokens, None);
        assert!(resolved.extra_headers.is_empty());
        // Keys not named by `ModelConfig` are carried through as extras.
        assert!(resolved.model_extra.contains_key("top_p"));
    }

    #[test]
    fn test_unknown_agent_is_reported() {
        match fixture().get_provider_config("ghost") {
            Err(ConfigError::AgentNotFound(name)) => assert_eq!(name, "ghost"),
            other => panic!("expected AgentNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_provider_is_reported() {
        match fixture().get_provider_config("no_provider") {
            Err(ConfigError::ProviderNotFound(name)) => assert_eq!(name, "nope"),
            other => panic!("expected ProviderNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_model_is_reported() {
        match fixture().get_provider_config("no_model") {
            Err(ConfigError::ModelNotFound(name)) => assert_eq!(name, "nope"),
            other => panic!("expected ModelNotFound, got {:?}", other),
        }
    }
}
|