PicoBot/src/gateway/provider_config_service.rs

106 lines
3.5 KiB
Rust

use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
/// Read-only lookup service for LLM provider configurations.
///
/// Bundles the default provider configuration, a table of named provider
/// configurations, and the memory-maintenance configuration. The table is
/// held behind an [`Arc`], so cloning the service shares the same map rather
/// than copying it.
#[derive(Clone)]
pub(crate) struct ProviderConfigService {
    /// Configuration handed out when no specific agent is requested.
    default_provider_config: LLMProviderConfig,

    /// Named provider configurations, keyed by the agent name that `select`
    /// looks up.
    provider_configs: Arc<HashMap<String, LLMProviderConfig>>,

    /// Memory-maintenance configuration returned by
    /// `default_maintenance_config`.
    maintenance_config: MemoryMaintenanceConfig,
}
impl ProviderConfigService {
    /// Creates a service from a default provider configuration, a table of
    /// named provider configurations, and a memory-maintenance configuration.
    ///
    /// The table is moved into an [`Arc`] so later clones of the service
    /// share it.
    pub(crate) fn new(
        default_provider_config: LLMProviderConfig,
        provider_configs: HashMap<String, LLMProviderConfig>,
        maintenance_config: MemoryMaintenanceConfig,
    ) -> Self {
        let provider_configs = Arc::new(provider_configs);
        Self {
            default_provider_config,
            provider_configs,
            maintenance_config,
        }
    }

    /// Resolves the provider configuration for `agent_name`.
    ///
    /// The name is trimmed before use. A missing name, a name that is empty
    /// after trimming, or the literal `"default"` all yield a clone of the
    /// default configuration. Any other name is looked up in the table of
    /// named configurations and the matching entry is cloned.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when the trimmed name has no entry in
    /// the table.
    pub(crate) fn select(
        &self,
        agent_name: Option<&str>,
    ) -> Result<LLMProviderConfig, AgentError> {
        // Collapse "no name" and "blank name" into the same empty string so
        // a single guard covers every path that falls back to the default.
        let requested = agent_name.map(str::trim).unwrap_or("");
        if requested.is_empty() || requested == "default" {
            return Ok(self.default_provider_config.clone());
        }

        match self.provider_configs.get(requested) {
            Some(config) => Ok(config.clone()),
            None => Err(AgentError::Other(format!(
                "Scheduled agent '{}' not found",
                requested
            ))),
        }
    }

    /// Returns a clone of the default provider configuration.
    pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
        self.default_provider_config.clone()
    }

    /// Returns a clone of the memory-maintenance configuration.
    pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
        self.maintenance_config.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider configuration with fixed test values, varying only
    /// the provider `name` and `model_id` so tests can tell instances apart.
    fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: String::from("openai"),
            name: name.to_owned(),
            base_url: String::from("http://localhost"),
            api_key: String::from("test-key"),
            extra_headers: HashMap::new(),
            llm_timeout_secs: 120,
            memory_maintenance_timeout_secs: 600,
            model_id: model_id.to_owned(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            context_window_tokens: None,
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            tool_result_max_chars: 100_000,
            context_tool_result_trim_chars: 20_000,
            max_images_in_context: 1,
            max_image_age_rounds: 10,
        }
    }

    /// A name present in the override table resolves to that entry rather
    /// than to the default configuration.
    #[test]
    fn test_select_uses_named_agent_override() {
        let mut overrides = HashMap::new();
        overrides.insert(
            "planner".to_string(),
            test_provider_config_named("planner-provider", "planner-model"),
        );
        let service = ProviderConfigService::new(
            test_provider_config_named("default-provider", "default-model"),
            overrides,
            MemoryMaintenanceConfig::default(),
        );

        let chosen = service.select(Some("planner")).unwrap();

        assert_eq!(chosen.name, "planner-provider");
        assert_eq!(chosen.model_id, "planner-model");
    }

    /// The literal name "default" resolves to the default configuration even
    /// when the override table is empty.
    #[test]
    fn test_select_falls_back_to_default() {
        let no_overrides = HashMap::new();
        let service = ProviderConfigService::new(
            test_provider_config_named("default-provider", "default-model"),
            no_overrides,
            MemoryMaintenanceConfig::default(),
        );

        let chosen = service.select(Some("default")).unwrap();

        assert_eq!(chosen.name, "default-provider");
        assert_eq!(chosen.model_id, "default-model");
    }
}