444 lines
13 KiB
Rust
444 lines
13 KiB
Rust
use regex::Regex;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::env;
|
|
use std::fs;
|
|
use std::path::{Path, PathBuf};
|
|
|
|
/// Get the user configuration directory (~/.picobot)
|
|
pub fn get_user_config_dir() -> PathBuf {
|
|
dirs::home_dir()
|
|
.unwrap_or_else(|| PathBuf::from("."))
|
|
.join(".picobot")
|
|
}
|
|
|
|
/// Get the default workspace directory (~/.picobot/workspace)
|
|
pub fn get_default_workspace_dir() -> PathBuf {
|
|
get_user_config_dir().join("workspace")
|
|
}
|
|
|
|
/// Expand ~ in path to user home directory
|
|
pub fn expand_path(path: &str) -> PathBuf {
|
|
if path.starts_with("~/") {
|
|
dirs::home_dir()
|
|
.unwrap_or_else(|| PathBuf::from("."))
|
|
.join(&path[2..])
|
|
} else {
|
|
PathBuf::from(path)
|
|
}
|
|
}
|
|
|
|
/// Ensure workspace directory exists, create if needed
|
|
pub fn ensure_workspace_dir(path: &Path) -> Result<PathBuf, std::io::Error> {
|
|
if !path.exists() {
|
|
tracing::info!("Creating workspace directory: {}", path.display());
|
|
fs::create_dir_all(path)?;
|
|
}
|
|
// Return canonical path
|
|
path.canonicalize().or_else(|_| Ok(path.to_path_buf()))
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct Config {
|
|
pub providers: HashMap<String, ProviderConfig>,
|
|
pub models: HashMap<String, ModelConfig>,
|
|
pub agents: HashMap<String, AgentConfig>,
|
|
#[serde(default)]
|
|
pub gateway: GatewayConfig,
|
|
#[serde(default)]
|
|
pub client: ClientConfig,
|
|
#[serde(default)]
|
|
pub channels: HashMap<String, FeishuChannelConfig>,
|
|
#[serde(default = "default_workspace_dir")]
|
|
pub workspace_dir: String,
|
|
}
|
|
|
|
fn default_workspace_dir() -> String {
|
|
get_default_workspace_dir().to_string_lossy().to_string()
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct FeishuChannelConfig {
|
|
#[serde(default)]
|
|
pub enabled: bool,
|
|
pub app_id: String,
|
|
pub app_secret: String,
|
|
#[serde(default = "default_allow_from")]
|
|
pub allow_from: Vec<String>,
|
|
#[serde(default)]
|
|
pub agent: String,
|
|
#[serde(default = "default_media_dir")]
|
|
pub media_dir: String,
|
|
/// Emoji type for message reactions (e.g. "THUMBSUP", "OK", "EYES").
|
|
#[serde(default = "default_reaction_emoji")]
|
|
pub reaction_emoji: String,
|
|
}
|
|
|
|
/// Serde default for `FeishuChannelConfig::allow_from`: a single `"*"` wildcard entry.
fn default_allow_from() -> Vec<String> {
    std::iter::once(String::from("*")).collect()
}
|
|
|
|
fn default_media_dir() -> String {
|
|
get_user_config_dir()
|
|
.join("media/feishu")
|
|
.to_string_lossy()
|
|
.to_string()
|
|
}
|
|
|
|
/// Serde default for `FeishuChannelConfig::reaction_emoji`.
fn default_reaction_emoji() -> String {
    String::from("Typing")
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ProviderConfig {
|
|
#[serde(rename = "type")]
|
|
pub provider_type: String,
|
|
pub base_url: String,
|
|
pub api_key: String,
|
|
#[serde(default)]
|
|
pub extra_headers: HashMap<String, String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ModelConfig {
|
|
pub model_id: String,
|
|
#[serde(default)]
|
|
pub temperature: Option<f32>,
|
|
#[serde(default)]
|
|
pub max_tokens: Option<u32>,
|
|
#[serde(flatten)]
|
|
pub extra: HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct AgentConfig {
|
|
pub provider: String,
|
|
pub model: String,
|
|
#[serde(default = "default_max_tool_iterations")]
|
|
pub max_tool_iterations: usize,
|
|
#[serde(default = "default_token_limit")]
|
|
pub token_limit: usize,
|
|
}
|
|
|
|
/// Serde default for `AgentConfig::max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    const DEFAULT_MAX_TOOL_ITERATIONS: usize = 20;
    DEFAULT_MAX_TOOL_ITERATIONS
}

/// Serde default for `AgentConfig::token_limit` (128k).
fn default_token_limit() -> usize {
    128 * 1000
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct GatewayConfig {
|
|
#[serde(default = "default_gateway_host")]
|
|
pub host: String,
|
|
#[serde(default = "default_gateway_port")]
|
|
pub port: u16,
|
|
#[serde(default, rename = "session_ttl_hours")]
|
|
pub session_ttl_hours: Option<u64>,
|
|
#[serde(default, rename = "cleanup_interval_minutes")]
|
|
pub cleanup_interval_minutes: Option<u64>,
|
|
#[serde(default, rename = "session_db_path")]
|
|
pub session_db_path: Option<String>,
|
|
#[serde(default)]
|
|
pub scheduler: Option<SchedulerConfig>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SchedulerConfig {
|
|
/// Whether the scheduler is enabled
|
|
#[serde(default = "default_scheduler_enabled")]
|
|
pub enabled: bool,
|
|
/// Poll interval in seconds (how often to check for due jobs)
|
|
#[serde(default = "default_poll_interval_secs")]
|
|
pub poll_interval_secs: u64,
|
|
/// Maximum concurrent job executions (currently sequential, reserved for future)
|
|
#[serde(default = "default_max_concurrent")]
|
|
pub max_concurrent: usize,
|
|
}
|
|
|
|
/// Serde default for `SchedulerConfig::enabled`: the scheduler is on unless disabled.
fn default_scheduler_enabled() -> bool {
    true
}

/// Serde default for `SchedulerConfig::poll_interval_secs`: poll once a minute.
fn default_poll_interval_secs() -> u64 {
    const ONE_MINUTE_SECS: u64 = 60;
    ONE_MINUTE_SECS
}

/// Serde default for `SchedulerConfig::max_concurrent`: one job at a time.
fn default_max_concurrent() -> usize {
    const SEQUENTIAL: usize = 1;
    SEQUENTIAL
}
|
|
|
|
impl Default for SchedulerConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
enabled: true,
|
|
poll_interval_secs: 60,
|
|
max_concurrent: 1,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ClientConfig {
|
|
#[serde(default = "default_gateway_url")]
|
|
pub gateway_url: String,
|
|
}
|
|
|
|
/// Serde default for `GatewayConfig::host`: the IPv4 loopback address.
fn default_gateway_host() -> String {
    "127.0.0.1".to_string()
}

/// Serde default for `GatewayConfig::port`.
fn default_gateway_port() -> u16 {
    19876
}

/// Serde default for `ClientConfig::gateway_url`.
///
/// Built from `default_gateway_host()` and `default_gateway_port()` so the client's
/// default URL cannot drift from the gateway's default address when either changes.
fn default_gateway_url() -> String {
    format!("ws://{}:{}/ws", default_gateway_host(), default_gateway_port())
}
|
|
|
|
impl Default for GatewayConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
host: default_gateway_host(),
|
|
port: default_gateway_port(),
|
|
session_ttl_hours: None,
|
|
cleanup_interval_minutes: None,
|
|
session_db_path: None,
|
|
scheduler: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for ClientConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
gateway_url: default_gateway_url(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct LLMProviderConfig {
|
|
pub provider_type: String,
|
|
pub name: String,
|
|
pub base_url: String,
|
|
pub api_key: String,
|
|
pub extra_headers: HashMap<String, String>,
|
|
pub model_id: String,
|
|
pub temperature: Option<f32>,
|
|
pub max_tokens: Option<u32>,
|
|
pub model_extra: HashMap<String, serde_json::Value>,
|
|
pub max_tool_iterations: usize,
|
|
pub token_limit: usize,
|
|
pub workspace_dir: PathBuf,
|
|
}
|
|
|
|
fn get_default_config_path() -> PathBuf {
|
|
get_user_config_dir().join("config.json")
|
|
}
|
|
|
|
impl Config {
|
|
pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
|
Self::load_from(Path::new(path))
|
|
}
|
|
|
|
pub fn load_default() -> Result<Self, Box<dyn std::error::Error>> {
|
|
let path = get_default_config_path();
|
|
Self::load_from(&path)
|
|
}
|
|
|
|
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
|
|
load_env_file()?;
|
|
let content = if path.exists() {
|
|
tracing::info!(path = %path.display(), "Config loaded");
|
|
fs::read_to_string(path)?
|
|
} else {
|
|
// Fallback to current directory
|
|
let fallback = Path::new("config.json");
|
|
if fallback.exists() {
|
|
tracing::info!(path = %fallback.display(), "Config loaded from fallback path");
|
|
fs::read_to_string(fallback)?
|
|
} else {
|
|
return Err(Box::new(ConfigError::ConfigNotFound(
|
|
path.to_string_lossy().to_string(),
|
|
)));
|
|
}
|
|
};
|
|
let content = resolve_env_placeholders(&content);
|
|
let config: Config = serde_json::from_str(&content)?;
|
|
Ok(config)
|
|
}
|
|
|
|
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
|
|
let agent = self
|
|
.agents
|
|
.get(agent_name)
|
|
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
|
|
|
|
let provider = self
|
|
.providers
|
|
.get(&agent.provider)
|
|
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
|
|
|
|
let model = self
|
|
.models
|
|
.get(&agent.model)
|
|
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
|
|
|
|
Ok(LLMProviderConfig {
|
|
provider_type: provider.provider_type.clone(),
|
|
name: agent.provider.clone(),
|
|
base_url: provider.base_url.clone(),
|
|
api_key: provider.api_key.clone(),
|
|
extra_headers: provider.extra_headers.clone(),
|
|
model_id: model.model_id.clone(),
|
|
temperature: model.temperature,
|
|
max_tokens: model.max_tokens,
|
|
model_extra: model.extra.clone(),
|
|
max_tool_iterations: agent.max_tool_iterations,
|
|
token_limit: agent.token_limit,
|
|
workspace_dir: expand_path(&self.workspace_dir),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Errors from locating the config file or resolving an agent's provider and model.
#[derive(Debug)]
pub enum ConfigError {
    /// Neither the requested config file nor `./config.json` exists; holds the requested
    /// path.
    ConfigNotFound(String),
    /// The agent name is not a key of `Config::agents`.
    AgentNotFound(String),
    /// The provider named by an agent is not a key of `Config::providers`.
    ProviderNotFound(String),
    /// The model named by an agent is not a key of `Config::models`.
    ModelNotFound(String),
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AgentNotFound(agent) => write!(f, "Agent not found: {}", agent),
            Self::ProviderNotFound(provider) => write!(f, "Provider not found: {}", provider),
            Self::ModelNotFound(model) => write!(f, "Model not found: {}", model),
            Self::ConfigNotFound(missing) => write!(
                f,
                "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
                missing
            ),
        }
    }
}

impl std::error::Error for ConfigError {}
|
|
|
|
/// Loads `KEY=value` pairs from a `.env` file in the current working directory into the
/// process environment. Does nothing if the file does not exist.
///
/// Line syntax is handled by `parse_env_line`:
/// - Blank lines and `#` comments are ignored.
/// - An optional `export ` prefix is accepted.
/// - A value wrapped in one matching pair of `"` or `'` quotes is unquoted.
/// - Lines without `=`, with an empty key or empty value, or containing NUL bytes are
///   skipped. The last two would make `env::set_var` panic.
///
/// Existing environment variables with the same name are overwritten.
///
/// # Errors
/// Returns the I/O error if `.env` exists but cannot be read.
fn load_env_file() -> Result<(), Box<dyn std::error::Error>> {
    let env_path = Path::new(".env");
    if !env_path.exists() {
        return Ok(());
    }

    let content = fs::read_to_string(env_path)?;
    for (key, value) in content.lines().filter_map(parse_env_line) {
        // SAFETY: `env::set_var` is unsafe because another thread reading or writing the
        // environment at the same time is undefined behavior on some platforms.
        // NOTE(review): this is only sound if config loading runs before any other
        // thread touches the environment — confirm at the call sites.
        unsafe { env::set_var(key, value) };
    }
    Ok(())
}

/// Parses one `.env` line into `(key, value)`.
///
/// Returns `None` for lines that should be skipped: blank lines, comments, lines without
/// `=`, and lines with an empty key, an empty value, or a NUL byte.
fn parse_env_line(line: &str) -> Option<(&str, &str)> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }

    let line = line.strip_prefix("export ").map_or(line, str::trim_start);
    let (key, value) = line.split_once('=')?;
    let key = key.trim();
    let value = strip_matching_quotes(value.trim());

    // `set_var` panics on an empty key or on NUL in key or value. `key` cannot contain
    // `=` because we split on the first one.
    if key.is_empty() || value.is_empty() || key.contains('\0') || value.contains('\0') {
        return None;
    }
    Some((key, value))
}

/// Removes one pair of matching surrounding quotes (`"…"` or `'…'`) if present; otherwise
/// returns `value` unchanged.
fn strip_matching_quotes(value: &str) -> &str {
    for quote in ['"', '\''] {
        if let Some(inner) = value.strip_prefix(quote).and_then(|v| v.strip_suffix(quote)) {
            return inner;
        }
    }
    value
}
|
|
|
|
fn resolve_env_placeholders(content: &str) -> String {
|
|
let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex");
|
|
re.replace_all(content, |caps: ®ex::Captures| {
|
|
let var_name = &caps[1];
|
|
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
|
})
|
|
.to_string()
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A config with two providers, two models, one agent and an explicit gateway section.
    const TEST_CONFIG: &str = r#"{
        "providers": {
            "aliyun": {
                "type": "openai",
                "base_url": "https://example.invalid/v1",
                "api_key": "test-key",
                "extra_headers": {}
            },
            "volcengine": {
                "type": "openai",
                "base_url": "https://example.invalid/volc",
                "api_key": "test-key-2",
                "extra_headers": {}
            }
        },
        "models": {
            "qwen-plus": {
                "model_id": "qwen-plus",
                "temperature": 0.0
            },
            "doubao-seed-2-0-lite-260215": {
                "model_id": "doubao-seed-2-0-lite-260215"
            }
        },
        "agents": {
            "default": {
                "provider": "aliyun",
                "model": "qwen-plus"
            }
        },
        "gateway": {
            "host": "0.0.0.0",
            "port": 19876
        }
    }"#;

    /// Writes `json` to a fresh temporary file; the file is deleted when the handle drops.
    fn write_config(json: &str) -> tempfile::NamedTempFile {
        let file = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(file.path(), json).unwrap();
        file
    }

    fn write_test_config() -> tempfile::NamedTempFile {
        write_config(TEST_CONFIG)
    }

    fn load(file: &tempfile::NamedTempFile) -> Config {
        Config::load(file.path().to_str().unwrap()).unwrap()
    }

    #[test]
    fn test_config_load() {
        let file = write_test_config();
        let config = load(&file);

        // Check providers
        assert!(config.providers.contains_key("volcengine"));
        assert!(config.providers.contains_key("aliyun"));

        // Check models
        assert!(config.models.contains_key("doubao-seed-2-0-lite-260215"));
        assert!(config.models.contains_key("qwen-plus"));

        // Check agents
        assert!(config.agents.contains_key("default"));
    }

    #[test]
    fn test_get_provider_config() {
        let file = write_test_config();
        let config = load(&file);
        let provider_config = config.get_provider_config("default").unwrap();

        assert_eq!(provider_config.provider_type, "openai");
        assert_eq!(provider_config.name, "aliyun");
        assert_eq!(provider_config.base_url, "https://example.invalid/v1");
        assert_eq!(provider_config.model_id, "qwen-plus");
        assert_eq!(provider_config.temperature, Some(0.0));
        // Agent limits fall back to their serde defaults.
        assert_eq!(provider_config.max_tool_iterations, 20);
        assert_eq!(provider_config.token_limit, 128_000);
    }

    #[test]
    fn test_explicit_gateway_config() {
        let file = write_test_config();
        let config = load(&file);
        assert_eq!(config.gateway.host, "0.0.0.0");
        assert_eq!(config.gateway.port, 19876);
    }

    #[test]
    fn test_default_gateway_config() {
        // No gateway/client/channels sections at all: every default must kick in.
        let file = write_config(r#"{"providers": {}, "models": {}, "agents": {}}"#);
        let config = load(&file);

        assert_eq!(config.gateway.host, "127.0.0.1");
        assert_eq!(config.gateway.port, 19876);
        assert!(config.gateway.session_ttl_hours.is_none());
        assert!(config.gateway.scheduler.is_none());
        assert_eq!(config.client.gateway_url, "ws://127.0.0.1:19876/ws");
        assert!(config.channels.is_empty());
    }

    #[test]
    fn test_unknown_agent_is_error() {
        let file = write_test_config();
        let config = load(&file);
        assert!(matches!(
            config.get_provider_config("missing"),
            Err(ConfigError::AgentNotFound(name)) if name == "missing"
        ));
    }

    #[test]
    fn test_agent_with_unknown_model_is_error() {
        let file = write_config(
            r#"{
                "providers": {
                    "p": { "type": "openai", "base_url": "https://example.invalid", "api_key": "k" }
                },
                "models": {},
                "agents": { "a": { "provider": "p", "model": "nope" } }
            }"#,
        );
        let config = load(&file);
        assert!(matches!(
            config.get_provider_config("a"),
            Err(ConfigError::ModelNotFound(name)) if name == "nope"
        ));
    }
}
|