106 lines
3.5 KiB
Rust
106 lines
3.5 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
use crate::agent::AgentError;
|
|
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
|
|
|
#[derive(Clone)]
|
|
pub(crate) struct ProviderConfigService {
|
|
default_provider_config: LLMProviderConfig,
|
|
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
|
maintenance_config: MemoryMaintenanceConfig,
|
|
}
|
|
|
|
impl ProviderConfigService {
|
|
pub(crate) fn new(
|
|
default_provider_config: LLMProviderConfig,
|
|
provider_configs: HashMap<String, LLMProviderConfig>,
|
|
maintenance_config: MemoryMaintenanceConfig,
|
|
) -> Self {
|
|
Self {
|
|
default_provider_config,
|
|
provider_configs: Arc::new(provider_configs),
|
|
maintenance_config,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn select(&self, agent_name: Option<&str>) -> Result<LLMProviderConfig, AgentError> {
|
|
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
|
None | Some("default") => Ok(self.default_provider_config.clone()),
|
|
Some(agent_name) => self
|
|
.provider_configs
|
|
.get(agent_name)
|
|
.cloned()
|
|
.ok_or_else(|| {
|
|
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
|
}),
|
|
}
|
|
}
|
|
|
|
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
|
|
self.default_provider_config.clone()
|
|
}
|
|
|
|
pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
|
|
self.maintenance_config.clone()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal provider config whose `name` and `model_id` let a
    /// test tell which entry `select` returned.
    fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: name.to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            llm_timeout_secs: 120,
            memory_maintenance_timeout_secs: 600,
            model_id: model_id.to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            context_window_tokens: None,
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            tool_result_max_chars: 100_000,
            context_tool_result_trim_chars: 20_000,
            max_images_in_context: 1,
            max_image_age_rounds: 10,
        }
    }

    /// Service with a `default-provider` fallback and a single `planner`
    /// override, shared by the tests below.
    fn service_with_planner() -> ProviderConfigService {
        ProviderConfigService::new(
            test_provider_config_named("default-provider", "default-model"),
            HashMap::from([(
                "planner".to_string(),
                test_provider_config_named("planner-provider", "planner-model"),
            )]),
            MemoryMaintenanceConfig::default(),
        )
    }

    #[test]
    fn test_select_uses_named_agent_override() {
        let service = service_with_planner();

        let selected = service.select(Some("planner")).unwrap();
        assert_eq!(selected.name, "planner-provider");
        assert_eq!(selected.model_id, "planner-model");
    }

    #[test]
    fn test_select_falls_back_to_default() {
        let default_provider = test_provider_config_named("default-provider", "default-model");
        let service = ProviderConfigService::new(
            default_provider,
            HashMap::new(),
            MemoryMaintenanceConfig::default(),
        );

        let selected = service.select(Some("default")).unwrap();
        assert_eq!(selected.name, "default-provider");
        assert_eq!(selected.model_id, "default-model");
    }

    #[test]
    fn test_select_missing_or_blank_name_uses_default() {
        let service = service_with_planner();

        // No name, an empty name and a whitespace-only name all fall back.
        for agent_name in [None, Some(""), Some("   \t")] {
            let selected = service.select(agent_name).unwrap();
            assert_eq!(selected.name, "default-provider");
            assert_eq!(selected.model_id, "default-model");
        }
    }

    #[test]
    fn test_select_trims_agent_name() {
        let service = service_with_planner();

        let selected = service.select(Some("  planner\t")).unwrap();
        assert_eq!(selected.name, "planner-provider");

        let selected = service.select(Some(" default ")).unwrap();
        assert_eq!(selected.name, "default-provider");
    }

    #[test]
    fn test_select_unknown_agent_is_an_error() {
        let service = service_with_planner();

        // Pattern-match instead of `unwrap_err` so the test does not require
        // `LLMProviderConfig: Debug`.
        let Err(AgentError::Other(message)) = service.select(Some(" ghost ")) else {
            panic!("expected AgentError::Other for an unknown agent");
        };
        assert_eq!(message, "Scheduled agent 'ghost' not found");
    }

    #[test]
    fn test_default_provider_config_accessor() {
        let service = service_with_planner();

        let config = service.default_provider_config();
        assert_eq!(config.name, "default-provider");
        assert_eq!(config.model_id, "default-model");
    }
}
|