271 lines
8.0 KiB
Rust
271 lines
8.0 KiB
Rust
use regex::Regex;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::env;
|
|
use std::fs;
|
|
use std::path::{Path, PathBuf};
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct Config {
|
|
pub providers: HashMap<String, ProviderConfig>,
|
|
pub models: HashMap<String, ModelConfig>,
|
|
pub agents: HashMap<String, AgentConfig>,
|
|
#[serde(default)]
|
|
pub gateway: GatewayConfig,
|
|
#[serde(default)]
|
|
pub client: ClientConfig,
|
|
#[serde(default)]
|
|
pub channels: HashMap<String, FeishuChannelConfig>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct FeishuChannelConfig {
|
|
#[serde(default)]
|
|
pub enabled: bool,
|
|
pub app_id: String,
|
|
pub app_secret: String,
|
|
#[serde(default = "default_allow_from")]
|
|
pub allow_from: Vec<String>,
|
|
#[serde(default)]
|
|
pub agent: String,
|
|
}
|
|
|
|
/// Default for `FeishuChannelConfig::allow_from`: a list holding the single entry `"*"`.
fn default_allow_from() -> Vec<String> {
    let wildcard = String::from("*");
    vec![wildcard]
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ProviderConfig {
|
|
#[serde(rename = "type")]
|
|
pub provider_type: String,
|
|
pub base_url: String,
|
|
pub api_key: String,
|
|
#[serde(default)]
|
|
pub extra_headers: HashMap<String, String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ModelConfig {
|
|
pub model_id: String,
|
|
#[serde(default)]
|
|
pub temperature: Option<f32>,
|
|
#[serde(default)]
|
|
pub max_tokens: Option<u32>,
|
|
#[serde(flatten)]
|
|
pub extra: HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct AgentConfig {
|
|
pub provider: String,
|
|
pub model: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct GatewayConfig {
|
|
#[serde(default = "default_gateway_host")]
|
|
pub host: String,
|
|
#[serde(default = "default_gateway_port")]
|
|
pub port: u16,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ClientConfig {
|
|
#[serde(default = "default_gateway_url")]
|
|
pub gateway_url: String,
|
|
}
|
|
|
|
/// Default for `GatewayConfig::host`: the IPv4 loopback address.
fn default_gateway_host() -> String {
    String::from("127.0.0.1")
}
|
|
|
|
/// Default for `GatewayConfig::port`.
fn default_gateway_port() -> u16 {
    const GATEWAY_PORT: u16 = 19_876;
    GATEWAY_PORT
}
|
|
|
|
fn default_gateway_url() -> String {
|
|
"ws://127.0.0.1:19876/ws".to_string()
|
|
}
|
|
|
|
impl Default for GatewayConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
host: default_gateway_host(),
|
|
port: default_gateway_port(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for ClientConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
gateway_url: default_gateway_url(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct LLMProviderConfig {
|
|
pub provider_type: String,
|
|
pub name: String,
|
|
pub base_url: String,
|
|
pub api_key: String,
|
|
pub extra_headers: HashMap<String, String>,
|
|
pub model_id: String,
|
|
pub temperature: Option<f32>,
|
|
pub max_tokens: Option<u32>,
|
|
pub model_extra: HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
fn get_default_config_path() -> PathBuf {
|
|
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
|
home.join(".config").join("picobot").join("config.json")
|
|
}
|
|
|
|
impl Config {
|
|
pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
|
Self::load_from(Path::new(path))
|
|
}
|
|
|
|
pub fn load_default() -> Result<Self, Box<dyn std::error::Error>> {
|
|
let path = get_default_config_path();
|
|
Self::load_from(&path)
|
|
}
|
|
|
|
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
|
|
load_env_file()?;
|
|
let content = if path.exists() {
|
|
println!("Config loaded from: {}", path.display());
|
|
fs::read_to_string(path)?
|
|
} else {
|
|
// Fallback to current directory
|
|
let fallback = Path::new("config.json");
|
|
if fallback.exists() {
|
|
println!("Config loaded from: {}", fallback.display());
|
|
fs::read_to_string(fallback)?
|
|
} else {
|
|
return Err(Box::new(ConfigError::ConfigNotFound(
|
|
path.to_string_lossy().to_string(),
|
|
)));
|
|
}
|
|
};
|
|
let content = resolve_env_placeholders(&content);
|
|
let config: Config = serde_json::from_str(&content)?;
|
|
Ok(config)
|
|
}
|
|
|
|
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
|
|
let agent = self.agents.get(agent_name)
|
|
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
|
|
|
|
let provider = self.providers.get(&agent.provider)
|
|
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
|
|
|
|
let model = self.models.get(&agent.model)
|
|
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
|
|
|
|
Ok(LLMProviderConfig {
|
|
provider_type: provider.provider_type.clone(),
|
|
name: agent.provider.clone(),
|
|
base_url: provider.base_url.clone(),
|
|
api_key: provider.api_key.clone(),
|
|
extra_headers: provider.extra_headers.clone(),
|
|
model_id: model.model_id.clone(),
|
|
temperature: model.temperature,
|
|
max_tokens: model.max_tokens,
|
|
model_extra: model.extra.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Errors raised while locating the config file or resolving an agent's settings.
#[derive(Debug)]
pub enum ConfigError {
    /// No config file was found; carries the requested path.
    ConfigNotFound(String),
    /// No entry in `agents` has this name.
    AgentNotFound(String),
    /// An agent references this provider key, but `providers` has no such entry.
    ProviderNotFound(String),
    /// An agent references this model key, but `models` has no such entry.
    ModelNotFound(String),
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ConfigNotFound(path) => write!(
                f,
                "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.config/picobot/config.json",
                path
            ),
            Self::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
            Self::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
            Self::ModelNotFound(name) => write!(f, "Model not found: {}", name),
        }
    }
}

/// Marker impl so `ConfigError` can be boxed as `Box<dyn std::error::Error>` and used with `?`.
impl std::error::Error for ConfigError {}
|
|
|
|
/// Loads `KEY=VALUE` pairs from a `.env` file in the current working directory into the
/// process environment. A missing `.env` file is not an error.
///
/// Parsing rules:
/// - Each line, key and value is trimmed of surrounding whitespace.
/// - Blank lines, lines starting with `#`, and lines without `=` are ignored.
/// - A line is split at its first `=`, so values may themselves contain `=`.
/// - One pair of matching surrounding quotes (`"…"` or `'…'`) is removed from the value.
/// - An entry is skipped if its key is empty, its value is empty after unquoting, or either
///   contains a NUL byte. `env::set_var` would panic on an empty key or a NUL byte.
/// - Existing variables with the same name are overwritten.
///
/// # Errors
/// Returns an error if `.env` exists but cannot be read as UTF-8 text.
fn load_env_file() -> Result<(), Box<dyn std::error::Error>> {
    let env_path = Path::new(".env");
    if !env_path.exists() {
        return Ok(());
    }

    let content = fs::read_to_string(env_path)?;
    for line in content.lines() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let Some((key, value)) = line.split_once('=') else {
            continue;
        };
        let key = key.trim();
        let value = strip_matching_quotes(value.trim());
        if key.is_empty() || value.is_empty() || key.contains('\0') || value.contains('\0') {
            continue;
        }
        // SAFETY: `set_var` is only sound while no other thread reads or writes the process
        // environment concurrently. NOTE(review): this assumes config loading runs before any
        // other threads are spawned — confirm at the call sites.
        unsafe { env::set_var(key, value) };
    }
    Ok(())
}

/// Removes one pair of matching surrounding quotes (`"…"` or `'…'`) from `value`, if present;
/// otherwise returns `value` unchanged.
fn strip_matching_quotes(value: &str) -> &str {
    for quote in ['"', '\''] {
        if let Some(inner) = value.strip_prefix(quote).and_then(|rest| rest.strip_suffix(quote)) {
            return inner;
        }
    }
    value
}
|
|
|
|
fn resolve_env_placeholders(content: &str) -> String {
|
|
let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex");
|
|
re.replace_all(content, |caps: ®ex::Captures| {
|
|
let var_name = &caps[1];
|
|
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
|
}).to_string()
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config with only the required sections, all empty.
    const EMPTY_CONFIG: &str = r#"{"providers": {}, "models": {}, "agents": {}}"#;

    // The next three tests read `config.json` via the working directory and therefore depend
    // on the checked-in config file.

    #[test]
    fn test_config_load() {
        let config = Config::load("config.json").unwrap();

        // Check providers
        assert!(config.providers.contains_key("volcengine"));
        assert!(config.providers.contains_key("aliyun"));

        // Check models
        assert!(config.models.contains_key("doubao-seed-2-0-lite-260215"));
        assert!(config.models.contains_key("qwen-plus"));

        // Check agents
        assert!(config.agents.contains_key("default"));
    }

    #[test]
    fn test_get_provider_config() {
        let config = Config::load("config.json").unwrap();
        let provider_config = config.get_provider_config("default").unwrap();

        assert_eq!(provider_config.provider_type, "openai");
        assert_eq!(provider_config.name, "aliyun");
        assert_eq!(provider_config.model_id, "qwen-plus");
        assert_eq!(provider_config.temperature, Some(0.0));
    }

    /// Checks the gateway section as written in `config.json`. Its host ("0.0.0.0") differs
    /// from the built-in default (see `test_gateway_and_client_defaults`).
    #[test]
    fn test_gateway_config_from_file() {
        let config = Config::load("config.json").unwrap();
        assert_eq!(config.gateway.host, "0.0.0.0");
        assert_eq!(config.gateway.port, 19876);
    }

    #[test]
    fn test_gateway_and_client_defaults() {
        let gateway = GatewayConfig::default();
        assert_eq!(gateway.host, "127.0.0.1");
        assert_eq!(gateway.port, 19876);
        assert_eq!(ClientConfig::default().gateway_url, "ws://127.0.0.1:19876/ws");
    }

    #[test]
    fn test_optional_sections_default_when_absent() {
        let config: Config = serde_json::from_str(EMPTY_CONFIG).unwrap();
        assert_eq!(config.gateway.host, "127.0.0.1");
        assert_eq!(config.gateway.port, 19876);
        assert_eq!(config.client.gateway_url, "ws://127.0.0.1:19876/ws");
        assert!(config.channels.is_empty());
    }

    #[test]
    fn test_get_provider_config_missing_agent() {
        let config: Config = serde_json::from_str(EMPTY_CONFIG).unwrap();
        assert!(matches!(
            config.get_provider_config("missing"),
            Err(ConfigError::AgentNotFound(name)) if name == "missing"
        ));
    }

    #[test]
    fn test_get_provider_config_missing_provider() {
        let config: Config = serde_json::from_str(
            r#"{
                "providers": {},
                "models": {"m": {"model_id": "x"}},
                "agents": {"a": {"provider": "p", "model": "m"}}
            }"#,
        )
        .unwrap();
        assert!(matches!(
            config.get_provider_config("a"),
            Err(ConfigError::ProviderNotFound(name)) if name == "p"
        ));
    }

    #[test]
    fn test_resolve_env_placeholders_leaves_unset_vars() {
        let input = r#"{"api_key": "<PICOBOT_TEST_SURELY_UNSET_VAR>"}"#;
        assert_eq!(resolve_env_placeholders(input), input);
    }
}
|