PicoBot/src/config/mod.rs

444 lines
13 KiB
Rust

use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
/// Return the per-user PicoBot directory, `<home>/.picobot`.
///
/// When the home directory cannot be determined, the path is built relative to
/// the current directory instead (`./.picobot`).
pub fn get_user_config_dir() -> PathBuf {
    let base = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    base.join(".picobot")
}
/// Return the default workspace directory, `<user config dir>/workspace`
/// (normally `~/.picobot/workspace`).
pub fn get_default_workspace_dir() -> PathBuf {
    let mut workspace = get_user_config_dir();
    workspace.push("workspace");
    workspace
}
/// Expand a leading `~` in `path` to the user's home directory.
///
/// Both a bare `~` and a `~/...` prefix are expanded. Any other input, including
/// the unsupported `~user/...` form, is returned unchanged as a `PathBuf`. When
/// the home directory cannot be determined, `.` is used in its place.
pub fn expand_path(path: &str) -> PathBuf {
    let home = || dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
    if path == "~" {
        home()
    } else if let Some(rest) = path.strip_prefix("~/") {
        home().join(rest)
    } else {
        PathBuf::from(path)
    }
}
/// Ensure `path` exists as a directory, creating it (and missing parents) if needed.
///
/// Returns the canonicalized path. If canonicalization fails, `path` is returned as given.
///
/// # Errors
/// Propagates the I/O error from `fs::create_dir_all`. This includes the case where
/// `path` already exists but is not a directory (e.g. a regular file).
pub fn ensure_workspace_dir(path: &Path) -> Result<PathBuf, std::io::Error> {
    // `is_dir` rather than `exists`: an existing *file* at `path` must not be
    // reported as a usable workspace. `create_dir_all` fails for it instead.
    if !path.is_dir() {
        tracing::info!("Creating workspace directory: {}", path.display());
        fs::create_dir_all(path)?;
    }
    Ok(path.canonicalize().unwrap_or_else(|_| path.to_path_buf()))
}
/// Root of the PicoBot configuration file (JSON), as parsed by `Config::load`.
///
/// `providers`, `models` and `agents` are required. Every other section falls back
/// to its default when omitted.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
/// Provider connection settings, keyed by the name that `AgentConfig::provider` refers to.
pub providers: HashMap<String, ProviderConfig>,
/// Model settings, keyed by the name that `AgentConfig::model` refers to.
pub models: HashMap<String, ModelConfig>,
/// Agent definitions, keyed by agent name (looked up by `Config::get_provider_config`).
pub agents: HashMap<String, AgentConfig>,
/// Gateway settings; `GatewayConfig::default()` when the section is missing.
#[serde(default)]
pub gateway: GatewayConfig,
/// Client settings; `ClientConfig::default()` when the section is missing.
#[serde(default)]
pub client: ClientConfig,
/// Channel settings keyed by string; every entry is parsed as a `FeishuChannelConfig`.
/// Empty when the section is missing.
#[serde(default)]
pub channels: HashMap<String, FeishuChannelConfig>,
/// Workspace directory; defaults to `<home>/.picobot/workspace`.
/// A leading `~/` is expanded when building an `LLMProviderConfig`.
#[serde(default = "default_workspace_dir")]
pub workspace_dir: String,
}
/// Serde default for `Config::workspace_dir`: the default workspace path,
/// converted lossily to a UTF-8 string.
fn default_workspace_dir() -> String {
    let dir = get_default_workspace_dir();
    dir.to_string_lossy().into_owned()
}
/// Settings for one Feishu channel (an entry of `Config::channels`).
///
/// `Debug` is implemented by hand so that `app_secret` never ends up in logs.
#[derive(Clone, Deserialize, Serialize)]
pub struct FeishuChannelConfig {
    /// Whether the channel is enabled; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Feishu application ID.
    pub app_id: String,
    /// Feishu application secret. Redacted in `Debug` output.
    pub app_secret: String,
    /// Sender allow-list; defaults to `["*"]`.
    /// NOTE(review): `"*"` presumably means "allow everyone" — confirm in the channel code.
    #[serde(default = "default_allow_from")]
    pub allow_from: Vec<String>,
    /// Agent name for this channel; empty string when omitted.
    /// NOTE(review): presumably a key of `Config::agents` — confirm in the channel code.
    #[serde(default)]
    pub agent: String,
    /// Media directory; defaults to `<home>/.picobot/media/feishu`.
    #[serde(default = "default_media_dir")]
    pub media_dir: String,
    /// Emoji type for message reactions (e.g. "THUMBSUP", "OK", "EYES"); defaults to `"Typing"`.
    #[serde(default = "default_reaction_emoji")]
    pub reaction_emoji: String,
}

impl std::fmt::Debug for FeishuChannelConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FeishuChannelConfig")
            .field("enabled", &self.enabled)
            .field("app_id", &self.app_id)
            .field("app_secret", &"<redacted>")
            .field("allow_from", &self.allow_from)
            .field("agent", &self.agent)
            .field("media_dir", &self.media_dir)
            .field("reaction_emoji", &self.reaction_emoji)
            .finish()
    }
}
/// Serde default for `FeishuChannelConfig::allow_from`: a single `"*"` entry.
fn default_allow_from() -> Vec<String> {
    let wildcard = String::from("*");
    vec![wildcard]
}
/// Serde default for `FeishuChannelConfig::media_dir`: `<user config dir>/media/feishu`.
fn default_media_dir() -> String {
    let media = get_user_config_dir().join("media/feishu");
    String::from(media.to_string_lossy())
}
/// Serde default for `FeishuChannelConfig::reaction_emoji`.
fn default_reaction_emoji() -> String {
    String::from("Typing")
}
/// Connection settings for one LLM provider (an entry of `Config::providers`).
///
/// `Debug` is implemented by hand: the API key is redacted and only header *names*
/// are printed, so credentials never end up in logs.
#[derive(Clone, Deserialize, Serialize)]
pub struct ProviderConfig {
    /// Provider implementation selector, read from the JSON key `type` (e.g. `"openai"`).
    #[serde(rename = "type")]
    pub provider_type: String,
    /// Base URL of the provider API.
    pub base_url: String,
    /// API key. May come from a `<VAR>` placeholder resolved at load time.
    pub api_key: String,
    /// Extra headers (name → value) for this provider; empty when omitted.
    #[serde(default)]
    pub extra_headers: HashMap<String, String>,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values may carry credentials (e.g. auth tokens): show names only.
        let mut header_names: Vec<&String> = self.extra_headers.keys().collect();
        header_names.sort();
        f.debug_struct("ProviderConfig")
            .field("provider_type", &self.provider_type)
            .field("base_url", &self.base_url)
            .field("api_key", &"<redacted>")
            .field("extra_headers", &header_names)
            .finish()
    }
}
/// Model settings (an entry of `Config::models`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
/// Model identifier, copied to `LLMProviderConfig::model_id`.
pub model_id: String,
/// Optional temperature; `None` when omitted.
#[serde(default)]
pub temperature: Option<f32>,
/// Optional max-tokens setting; `None` when omitted.
#[serde(default)]
pub max_tokens: Option<u32>,
/// All remaining keys of the model object, captured via `#[serde(flatten)]` and
/// copied to `LLMProviderConfig::model_extra`.
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
/// Agent definition (an entry of `Config::agents`): which provider and model to use,
/// plus per-agent limits.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentConfig {
/// Key into `Config::providers`.
pub provider: String,
/// Key into `Config::models`.
pub model: String,
/// Defaults to 20.
/// NOTE(review): presumably caps tool-call rounds per request — confirm in the agent loop.
#[serde(default = "default_max_tool_iterations")]
pub max_tool_iterations: usize,
/// Defaults to 128_000.
/// NOTE(review): presumably a token budget for the context — confirm with the consumer.
#[serde(default = "default_token_limit")]
pub token_limit: usize,
}
/// Serde default for `AgentConfig::max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    let iterations: usize = 20;
    iterations
}
/// Serde default for `AgentConfig::token_limit` (128k).
fn default_token_limit() -> usize {
    128 * 1_000
}
/// Gateway settings (the optional `gateway` section of the config file).
///
/// When the section is missing entirely, `GatewayConfig::default()` is used.
/// Individual missing keys fall back to the per-field defaults below.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GatewayConfig {
    /// Gateway host address; defaults to `127.0.0.1`.
    #[serde(default = "default_gateway_host")]
    pub host: String,
    /// Gateway port; defaults to `19876`.
    #[serde(default = "default_gateway_port")]
    pub port: u16,
    /// Session time-to-live in hours; `None` when unset (fallback decided by the consumer).
    #[serde(default)]
    pub session_ttl_hours: Option<u64>,
    /// Session cleanup interval in minutes; `None` when unset.
    #[serde(default)]
    pub cleanup_interval_minutes: Option<u64>,
    /// Session database path; `None` when unset.
    #[serde(default)]
    pub session_db_path: Option<String>,
    /// Scheduler settings; `None` when the `scheduler` key is absent.
    #[serde(default)]
    pub scheduler: Option<SchedulerConfig>,
}
/// Scheduler settings (the optional `gateway.scheduler` section).
///
/// Every field has a serde default, so an empty object `{}` is a valid section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig {
/// Whether the scheduler is enabled (default: `true`)
#[serde(default = "default_scheduler_enabled")]
pub enabled: bool,
/// Poll interval in seconds (how often to check for due jobs; default: 60)
#[serde(default = "default_poll_interval_secs")]
pub poll_interval_secs: u64,
/// Maximum concurrent job executions (default: 1; currently sequential, reserved for future)
#[serde(default = "default_max_concurrent")]
pub max_concurrent: usize,
}
/// Serde default for `SchedulerConfig::enabled`: the scheduler is on unless disabled.
fn default_scheduler_enabled() -> bool {
    let enabled = true;
    enabled
}
/// Serde default for `SchedulerConfig::poll_interval_secs`: one minute.
fn default_poll_interval_secs() -> u64 {
    60_u64
}
/// Serde default for `SchedulerConfig::max_concurrent`: one job at a time.
fn default_max_concurrent() -> usize {
    1_usize
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self {
enabled: true,
poll_interval_secs: 60,
max_concurrent: 1,
}
}
}
/// Client settings (the optional `client` section of the config file).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ClientConfig {
/// Gateway WebSocket URL; defaults to `ws://127.0.0.1:19876/ws`.
#[serde(default = "default_gateway_url")]
pub gateway_url: String,
}
/// Serde default for `GatewayConfig::host`: IPv4 loopback.
fn default_gateway_host() -> String {
    String::from("127.0.0.1")
}
/// Serde default for `GatewayConfig::port`.
fn default_gateway_port() -> u16 {
    19_876
}
fn default_gateway_url() -> String {
"ws://127.0.0.1:19876/ws".to_string()
}
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
cleanup_interval_minutes: None,
session_db_path: None,
scheduler: None,
}
}
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
gateway_url: default_gateway_url(),
}
}
}
/// Fully resolved settings for one agent, as produced by `Config::get_provider_config`.
/// It combines the provider connection, the model parameters and the agent limits.
///
/// `Debug` is implemented by hand: the API key is redacted and only header *names*
/// are printed, so credentials never end up in logs.
#[derive(Clone)]
pub struct LLMProviderConfig {
    /// Copied from `ProviderConfig::provider_type`.
    pub provider_type: String,
    /// Name of the provider entry (its key in `Config::providers`).
    pub name: String,
    /// Copied from `ProviderConfig::base_url`.
    pub base_url: String,
    /// Copied from `ProviderConfig::api_key`. Redacted in `Debug` output.
    pub api_key: String,
    /// Copied from `ProviderConfig::extra_headers`.
    pub extra_headers: HashMap<String, String>,
    /// Copied from `ModelConfig::model_id`.
    pub model_id: String,
    /// Copied from `ModelConfig::temperature`.
    pub temperature: Option<f32>,
    /// Copied from `ModelConfig::max_tokens`.
    pub max_tokens: Option<u32>,
    /// Unrecognised keys of the model entry (`ModelConfig::extra`).
    pub model_extra: HashMap<String, serde_json::Value>,
    /// Copied from `AgentConfig::max_tool_iterations`.
    pub max_tool_iterations: usize,
    /// Copied from `AgentConfig::token_limit`.
    pub token_limit: usize,
    /// `Config::workspace_dir` with a leading `~/` expanded.
    pub workspace_dir: PathBuf,
}

impl std::fmt::Debug for LLMProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values may carry credentials (e.g. auth tokens): show names only.
        let mut header_names: Vec<&String> = self.extra_headers.keys().collect();
        header_names.sort();
        f.debug_struct("LLMProviderConfig")
            .field("provider_type", &self.provider_type)
            .field("name", &self.name)
            .field("base_url", &self.base_url)
            .field("api_key", &"<redacted>")
            .field("extra_headers", &header_names)
            .field("model_id", &self.model_id)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("model_extra", &self.model_extra)
            .field("max_tool_iterations", &self.max_tool_iterations)
            .field("token_limit", &self.token_limit)
            .field("workspace_dir", &self.workspace_dir)
            .finish()
    }
}
/// Path of the default config file, `<user config dir>/config.json`.
fn get_default_config_path() -> PathBuf {
    let mut config_path = get_user_config_dir();
    config_path.push("config.json");
    config_path
}
impl Config {
    /// Load the configuration from the JSON file at `path`.
    ///
    /// If `path` does not exist, `./config.json` is tried instead. Before parsing,
    /// `./.env` is loaded into the process environment, and `<VAR>` placeholders in
    /// the file are replaced with environment values.
    ///
    /// # Errors
    /// Returns `ConfigError::ConfigNotFound` when neither file exists. Otherwise
    /// returns any I/O or JSON error hit while reading `.env` or the config file.
    pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        Self::load_from(Path::new(path))
    }

    /// Like `Config::load`, with the user config file (`<home>/.picobot/config.json`)
    /// as the primary path.
    pub fn load_default() -> Result<Self, Box<dyn std::error::Error>> {
        Self::load_from(&get_default_config_path())
    }

    /// Shared loader: `.env` → locate file → read → substitute placeholders → parse.
    fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
        load_env_file()?;
        let (source, is_fallback) = Self::locate_config_file(path)
            .ok_or_else(|| ConfigError::ConfigNotFound(path.to_string_lossy().into_owned()))?;
        let content = fs::read_to_string(source)?;
        // Log only after the read succeeded, so a failing read is not reported as "loaded".
        if is_fallback {
            tracing::info!(path = %source.display(), "Config loaded from fallback path");
        } else {
            tracing::info!(path = %source.display(), "Config loaded");
        }
        let content = resolve_env_placeholders(&content);
        Ok(serde_json::from_str(&content)?)
    }

    /// Pick the file to read. Returns `path` if it exists, otherwise `./config.json`
    /// (flagged as a fallback with `true`), otherwise `None`.
    fn locate_config_file(path: &Path) -> Option<(&Path, bool)> {
        if path.exists() {
            return Some((path, false));
        }
        let fallback = Path::new("config.json");
        fallback.exists().then_some((fallback, true))
    }

    /// Resolve the provider, model and limits for the agent named `agent_name`.
    ///
    /// # Errors
    /// - `ConfigError::AgentNotFound` if `agents` has no such entry.
    /// - `ConfigError::ProviderNotFound` if the agent's provider is not in `providers`.
    /// - `ConfigError::ModelNotFound` if the agent's model is not in `models`.
    pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
        let agent = self
            .agents
            .get(agent_name)
            .ok_or_else(|| ConfigError::AgentNotFound(agent_name.to_string()))?;
        let provider = self
            .providers
            .get(&agent.provider)
            .ok_or_else(|| ConfigError::ProviderNotFound(agent.provider.clone()))?;
        let model = self
            .models
            .get(&agent.model)
            .ok_or_else(|| ConfigError::ModelNotFound(agent.model.clone()))?;
        Ok(LLMProviderConfig {
            provider_type: provider.provider_type.clone(),
            name: agent.provider.clone(),
            base_url: provider.base_url.clone(),
            api_key: provider.api_key.clone(),
            extra_headers: provider.extra_headers.clone(),
            model_id: model.model_id.clone(),
            temperature: model.temperature,
            max_tokens: model.max_tokens,
            model_extra: model.extra.clone(),
            max_tool_iterations: agent.max_tool_iterations,
            token_limit: agent.token_limit,
            workspace_dir: expand_path(&self.workspace_dir),
        })
    }
}
/// Errors raised while locating the config file or resolving an agent's settings.
#[derive(Debug)]
pub enum ConfigError {
    /// No config file exists at the requested path (nor at the `./config.json` fallback).
    ConfigNotFound(String),
    /// `Config::agents` has no entry with this name.
    AgentNotFound(String),
    /// An agent refers to a provider missing from `Config::providers`.
    ProviderNotFound(String),
    /// An agent refers to a model missing from `Config::models`.
    ModelNotFound(String),
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The three lookup errors share one message shape; only the file error differs.
        let (kind, value) = match self {
            Self::ConfigNotFound(path) => {
                return write!(
                    f,
                    "Config file not found: {path}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json"
                );
            }
            Self::AgentNotFound(name) => ("Agent", name),
            Self::ProviderNotFound(name) => ("Provider", name),
            Self::ModelNotFound(name) => ("Model", name),
        };
        write!(f, "{kind} not found: {value}")
    }
}

impl std::error::Error for ConfigError {}
/// Load `KEY=VALUE` pairs from `./.env` (relative to the current working directory)
/// into the process environment. A missing `.env` is not an error.
///
/// Accepted syntax, one pair per line:
/// - blank lines and lines starting with `#` are ignored;
/// - an optional leading `export ` is stripped (shell-compatible `.env` files);
/// - key and value are trimmed, and one pair of matching surrounding quotes
///   (`"..."` or `'...'`) is removed from the value;
/// - lines without `=`, with an empty key or value, or containing a NUL byte are
///   skipped, because `env::set_var` panics on an empty key or a NUL.
///
/// Values from `.env` overwrite variables already present in the environment.
///
/// # Errors
/// Returns the I/O error if `.env` exists but cannot be read.
fn load_env_file() -> Result<(), Box<dyn std::error::Error>> {
    let env_path = Path::new(".env");
    if !env_path.exists() {
        return Ok(());
    }
    let content = fs::read_to_string(env_path)?;
    for (key, value) in content.lines().filter_map(parse_env_line) {
        // SAFETY: `set_var` is only sound while no other thread reads or writes the
        // environment concurrently. NOTE(review): this relies on `Config::load*` running
        // before any threads that touch the environment are spawned — confirm at call sites.
        unsafe { env::set_var(key, value) };
    }
    Ok(())
}

/// Parse one `.env` line into `(key, value)`, or `None` if the line must be skipped.
fn parse_env_line(line: &str) -> Option<(&str, &str)> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }
    let line = line.strip_prefix("export ").map_or(line, str::trim_start);
    let (key, value) = line.split_once('=')?;
    let key = key.trim();
    let value = strip_matching_quotes(value.trim());
    let unusable =
        key.is_empty() || value.is_empty() || key.contains('\0') || value.contains('\0');
    if unusable {
        None
    } else {
        Some((key, value))
    }
}

/// Remove exactly one pair of matching surrounding `"` or `'` quotes, if present.
fn strip_matching_quotes(value: &str) -> &str {
    let bytes = value.as_bytes();
    if bytes.len() >= 2 {
        let (first, last) = (bytes[0], bytes[bytes.len() - 1]);
        // Both ends are single-byte ASCII quotes, so slicing stays on char boundaries.
        if first == last && (first == b'"' || first == b'\'') {
            return &value[1..value.len() - 1];
        }
    }
    value
}
/// Replace every `<VAR>` placeholder in `content` with the value of environment variable `VAR`.
///
/// `VAR` consists of upper-case ASCII letters, digits and underscores, and must not start
/// with a digit. A placeholder is left untouched when its variable is unset or not valid
/// Unicode.
///
/// The substitution runs on raw JSON text, so values are JSON-string-escaped before
/// insertion. Otherwise a value containing `"`, `\` (API keys, Windows paths, …) or a
/// control character would corrupt the document. Placeholders are meant to sit inside
/// JSON string literals, e.g. `"api_key": "<OPENAI_API_KEY>"`. Plain numeric or boolean
/// values are unaffected by the escaping.
fn resolve_env_placeholders(content: &str) -> String {
    let re = Regex::new(r"<([A-Z_][A-Z0-9_]*)>").expect("placeholder pattern is a valid regex");
    re.replace_all(content, |caps: &regex::Captures| match env::var(&caps[1]) {
        Ok(value) => escape_json_string_fragment(&value),
        Err(_) => caps[0].to_string(),
    })
    .into_owned()
}

/// Escape `raw` so it can be embedded between the quotes of a JSON string literal.
fn escape_json_string_fragment(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for c in raw.chars() {
        match c {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            // Remaining C0 control characters must be \u-escaped in JSON strings.
            c if u32::from(c) < 0x20 => {
                escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
            }
            c => escaped.push(c),
        }
    }
    escaped
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Write `json` to a fresh temp file; the returned handle keeps the file alive.
    fn write_config(json: &str) -> tempfile::NamedTempFile {
        let file = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(file.path(), json).unwrap();
        file
    }

    /// Write a complete config: two providers, two models, one agent, explicit gateway.
    fn write_test_config() -> tempfile::NamedTempFile {
        write_config(
            r#"{
    "providers": {
        "aliyun": {
            "type": "openai",
            "base_url": "https://example.invalid/v1",
            "api_key": "test-key",
            "extra_headers": {}
        },
        "volcengine": {
            "type": "openai",
            "base_url": "https://example.invalid/volc",
            "api_key": "test-key-2",
            "extra_headers": {}
        }
    },
    "models": {
        "qwen-plus": {
            "model_id": "qwen-plus",
            "temperature": 0.0
        },
        "doubao-seed-2-0-lite-260215": {
            "model_id": "doubao-seed-2-0-lite-260215"
        }
    },
    "agents": {
        "default": {
            "provider": "aliyun",
            "model": "qwen-plus"
        }
    },
    "gateway": {
        "host": "0.0.0.0",
        "port": 19876
    }
}"#,
        )
    }

    /// Load the config stored in `file`.
    fn load(file: &tempfile::NamedTempFile) -> Config {
        Config::load(file.path().to_str().unwrap()).unwrap()
    }

    #[test]
    fn test_config_load() {
        let file = write_test_config();
        let config = load(&file);
        // Check providers
        assert!(config.providers.contains_key("volcengine"));
        assert!(config.providers.contains_key("aliyun"));
        // Check models
        assert!(config.models.contains_key("doubao-seed-2-0-lite-260215"));
        assert!(config.models.contains_key("qwen-plus"));
        // Check agents
        assert!(config.agents.contains_key("default"));
    }

    #[test]
    fn test_get_provider_config() {
        let file = write_test_config();
        let config = load(&file);
        let provider_config = config.get_provider_config("default").unwrap();
        assert_eq!(provider_config.provider_type, "openai");
        assert_eq!(provider_config.name, "aliyun");
        assert_eq!(provider_config.base_url, "https://example.invalid/v1");
        assert_eq!(provider_config.model_id, "qwen-plus");
        assert_eq!(provider_config.temperature, Some(0.0));
        // Agent limits omitted in the file fall back to their serde defaults.
        assert_eq!(provider_config.max_tool_iterations, 20);
        assert_eq!(provider_config.token_limit, 128_000);
    }

    #[test]
    fn test_get_provider_config_missing_agent() {
        let file = write_test_config();
        let config = load(&file);
        let err = config.get_provider_config("nope").unwrap_err();
        assert!(matches!(&err, ConfigError::AgentNotFound(name) if name == "nope"));
    }

    #[test]
    fn test_get_provider_config_missing_provider() {
        let file = write_config(
            r#"{
    "providers": {},
    "models": { "m": { "model_id": "m" } },
    "agents": { "a": { "provider": "ghost", "model": "m" } }
}"#,
        );
        let config = load(&file);
        let err = config.get_provider_config("a").unwrap_err();
        assert!(matches!(&err, ConfigError::ProviderNotFound(name) if name == "ghost"));
    }

    /// Values written in the file override the gateway defaults.
    #[test]
    fn test_gateway_config_from_file() {
        let file = write_test_config();
        let config = load(&file);
        assert_eq!(config.gateway.host, "0.0.0.0");
        assert_eq!(config.gateway.port, 19876);
    }

    /// Omitted optional sections fall back to their built-in defaults.
    #[test]
    fn test_missing_sections_use_defaults() {
        let file = write_config(r#"{ "providers": {}, "models": {}, "agents": {} }"#);
        let config = load(&file);
        assert_eq!(config.gateway.host, "127.0.0.1");
        assert_eq!(config.gateway.port, 19876);
        assert!(config.gateway.scheduler.is_none());
        assert_eq!(config.client.gateway_url, "ws://127.0.0.1:19876/ws");
        assert!(config.channels.is_empty());
        assert_eq!(config.workspace_dir, default_workspace_dir());
    }

    /// `SchedulerConfig::default()` must agree with deserializing an empty section.
    #[test]
    fn test_scheduler_default_matches_serde_defaults() {
        let from_serde: SchedulerConfig = serde_json::from_str("{}").unwrap();
        let from_default = SchedulerConfig::default();
        assert_eq!(from_serde.enabled, from_default.enabled);
        assert_eq!(from_serde.poll_interval_secs, from_default.poll_interval_secs);
        assert_eq!(from_serde.max_concurrent, from_default.max_concurrent);
    }
}