PicoBot/src/config/mod.rs

133 lines
4.0 KiB
Rust

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
/// Top-level configuration, parsed from JSON by `Config::load`.
///
/// Entries in `agents` refer to entries in `providers` and `models` by map key;
/// `Config::get_provider_config` resolves those references for a single agent.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
/// Provider definitions, keyed by the name that `AgentConfig::provider` refers to.
pub providers: HashMap<String, ProviderConfig>,
/// Model definitions, keyed by the name that `AgentConfig::model` refers to.
pub models: HashMap<String, ModelConfig>,
/// Agent definitions, keyed by the agent name passed to `get_provider_config`.
pub agents: HashMap<String, AgentConfig>,
}
/// Connection settings for a single provider entry in `Config::providers`.
///
/// NOTE(review): the derived `Debug` output includes `api_key` verbatim, so
/// formatting this struct with `{:?}` will print the key — confirm that no
/// caller logs it.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProviderConfig {
/// Provider kind; stored under the key `type` in the serialized form.
#[serde(rename = "type")]
pub provider_type: String,
/// Base URL for the provider (presumably the API endpoint root — confirm against the caller).
pub base_url: String,
/// API key for the provider.
pub api_key: String,
/// Additional headers for this provider; an empty map when the key is absent
/// from the input.
#[serde(default)]
pub extra_headers: HashMap<String, String>,
}
/// Settings for a single model entry in `Config::models`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
/// Identifier of the model (presumably the id sent to the provider — confirm against the caller).
pub model_id: String,
/// Optional temperature; `None` when the key is absent from the input.
#[serde(default)]
pub temperature: Option<f32>,
/// Optional token limit; `None` when the key is absent from the input.
#[serde(default)]
pub max_tokens: Option<u32>,
/// Every other key of the model object that is not matched by a named field
/// above, kept as raw JSON values (collected via `#[serde(flatten)]`).
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
/// An agent entry: a pairing of one provider and one model, both given by name.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentConfig {
/// Key of the provider to use; looked up in `Config::providers`.
pub provider: String,
/// Key of the model to use; looked up in `Config::models`.
pub model: String,
}
/// Flattened, owned view of one agent's provider and model settings, as built
/// by `Config::get_provider_config`.
///
/// NOTE(review): the derived `Debug` output includes `api_key` verbatim —
/// confirm that no caller logs this struct.
#[derive(Debug, Clone)]
pub struct LLMProviderConfig {
/// Copied from `ProviderConfig::provider_type`.
pub provider_type: String,
/// The provider's key in `Config::providers` (the agent's `provider` value).
pub name: String,
/// Copied from `ProviderConfig::base_url`.
pub base_url: String,
/// Copied from `ProviderConfig::api_key`.
pub api_key: String,
/// Copied from `ProviderConfig::extra_headers`.
pub extra_headers: HashMap<String, String>,
/// Copied from `ModelConfig::model_id`.
pub model_id: String,
/// Copied from `ModelConfig::temperature`.
pub temperature: Option<f32>,
/// Copied from `ModelConfig::max_tokens`.
pub max_tokens: Option<u32>,
/// Copied from `ModelConfig::extra`.
pub model_extra: HashMap<String, serde_json::Value>,
}
impl Config {
    /// Reads the file at `path` and parses its contents as a JSON-encoded `Config`.
    ///
    /// # Errors
    ///
    /// Returns the boxed I/O error if the file cannot be read, or the boxed
    /// `serde_json` error if the contents do not deserialize into `Config`.
    pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let content = fs::read_to_string(path)?;
        Ok(serde_json::from_str(&content)?)
    }

    /// Resolves the agent named `agent_name` into a self-contained
    /// `LLMProviderConfig` by looking up the provider and model it refers to.
    ///
    /// All values in the result are cloned out of `self`, so the returned
    /// struct does not borrow from this `Config`.
    ///
    /// # Errors
    ///
    /// * `ConfigError::AgentNotFound` if `agent_name` is not a key of `agents`.
    /// * `ConfigError::ProviderNotFound` if the agent's `provider` is not a key
    ///   of `providers`.
    /// * `ConfigError::ModelNotFound` if the agent's `model` is not a key of
    ///   `models`.
    ///
    /// The lookups run in that order, so the first missing entry is the one
    /// reported.
    pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
        // `ok_or_else` defers building the error (and allocating its `String`)
        // to the failure path; `ok_or` would allocate on every successful call.
        let agent = self
            .agents
            .get(agent_name)
            .ok_or_else(|| ConfigError::AgentNotFound(agent_name.to_string()))?;
        let provider = self
            .providers
            .get(&agent.provider)
            .ok_or_else(|| ConfigError::ProviderNotFound(agent.provider.clone()))?;
        let model = self
            .models
            .get(&agent.model)
            .ok_or_else(|| ConfigError::ModelNotFound(agent.model.clone()))?;

        Ok(LLMProviderConfig {
            provider_type: provider.provider_type.clone(),
            // `name` is the provider's map key, not a field of `ProviderConfig`.
            name: agent.provider.clone(),
            base_url: provider.base_url.clone(),
            api_key: provider.api_key.clone(),
            extra_headers: provider.extra_headers.clone(),
            model_id: model.model_id.clone(),
            temperature: model.temperature,
            max_tokens: model.max_tokens,
            model_extra: model.extra.clone(),
        })
    }
}
/// Errors produced when a named configuration entry cannot be resolved.
///
/// Each variant carries the key that was looked up and not found. The type is
/// `Clone` and comparable so callers can match on or assert against specific
/// failures.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// No agent exists under the given name.
    AgentNotFound(String),
    /// No provider exists under the given name.
    ProviderNotFound(String),
    /// No model exists under the given name.
    ModelNotFound(String),
}

impl std::fmt::Display for ConfigError {
    /// Writes a one-line, human-readable message naming the missing entry.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
            Self::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
            Self::ModelNotFound(name) => write!(f, "Model not found: {}", name),
        }
    }
}

impl std::error::Error for ConfigError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads `config.json` (resolved relative to the test process's working
    /// directory) and checks that the expected provider, model, and agent keys
    /// are present.
    #[test]
    fn test_config_load() {
        let Config {
            providers,
            models,
            agents,
        } = Config::load("config.json").unwrap();

        // Check providers
        assert!(providers.contains_key("volcengine"));
        assert!(providers.contains_key("aliyun"));

        // Check models
        assert!(models.contains_key("doubao-seed-2-0-lite-260215"));
        assert!(models.contains_key("qwen-plus"));

        // Check agents
        assert!(agents.contains_key("default"));
    }

    /// Resolves the `default` agent from `config.json` and checks the merged
    /// provider and model values.
    #[test]
    fn test_get_provider_config() {
        let resolved = Config::load("config.json")
            .unwrap()
            .get_provider_config("default")
            .unwrap();

        assert_eq!(resolved.provider_type, "openai");
        assert_eq!(resolved.name, "aliyun");
        assert_eq!(resolved.model_id, "qwen-plus");
        assert_eq!(resolved.temperature, Some(0.0));
    }
}