PicoBot/tests/test_integration.rs
2026-05-01 21:22:07 +08:00

113 lines
3.7 KiB
Rust

use picobot::config::{Config, LLMProviderConfig};
use picobot::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
use std::collections::HashMap;
fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig {
ProviderRuntimeConfig {
provider_type: config.provider_type,
name: config.name,
base_url: config.base_url,
api_key: config.api_key,
extra_headers: config.extra_headers,
llm_timeout_secs: config.llm_timeout_secs,
model_id: config.model_id,
temperature: config.temperature,
max_tokens: config.max_tokens,
model_extra: config.model_extra,
}
}
/// Builds the provider configuration used by the live OpenAI-compatible tests.
///
/// `tests/test.env` is loaded first, on a best-effort basis. The path is relative to the
/// working directory, which `cargo test` sets to the package root. The variables are
/// then read from the process environment:
///
/// - `OPENAI_BASE_URL`: required.
/// - `OPENAI_API_KEY`: required. A value still containing the `your_` template
///   placeholder is rejected.
/// - `OPENAI_MODEL_NAME`: optional, defaults to `gpt-4`.
///
/// Returns `None` when a required value is unset, blank, or still a placeholder. This
/// lets callers report a configuration problem instead of sending a doomed request.
fn load_config() -> Option<LLMProviderConfig> {
    // A missing or unreadable env file is not fatal: the variables may already be
    // exported in the process environment (e.g. by CI), and those must still be used.
    let _ = dotenv::from_filename("tests/test.env");

    // Unset, non-UTF-8, and whitespace-only values all count as "not configured".
    // Surrounding whitespace is trimmed so a stray space cannot corrupt the URL or key.
    let non_empty_env = |key: &str| {
        std::env::var(key)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    };

    let openai_base_url = non_empty_env("OPENAI_BASE_URL")?;
    let openai_api_key = non_empty_env("OPENAI_API_KEY")?;
    let openai_model =
        non_empty_env("OPENAI_MODEL_NAME").unwrap_or_else(|| "gpt-4".to_string());

    // Reject the unedited template value shipped in the example env file.
    if openai_api_key.contains("your_") {
        return None;
    }

    Some(LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "test_openai".to_string(),
        base_url: openai_base_url,
        api_key: openai_api_key,
        extra_headers: HashMap::new(),
        llm_timeout_secs: 120,
        model_id: openai_model,
        // Temperature 0.0 and a small token cap keep live test calls cheap and focused.
        temperature: Some(0.0),
        max_tokens: Some(100),
        context_window_tokens: None,
        model_extra: HashMap::new(),
        max_tool_iterations: 20,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
    })
}
/// Wraps `content` in a single user message and returns a tool-free chat request.
///
/// The request pins `temperature` to 0.0 and caps the completion at 100 tokens.
fn create_request(content: &str) -> ChatCompletionRequest {
    let messages = vec![Message::user(content)];
    ChatCompletionRequest {
        messages,
        tools: None,
        max_tokens: Some(100),
        temperature: Some(0.0),
    }
}
/// Live smoke test against a real OpenAI-compatible endpoint.
///
/// Asking the model to "Say 'ok'" should produce a response with a non-empty id, a
/// non-empty reply containing "ok", and non-zero reported token usage.
///
/// Ignored by default. Run with `cargo test -- --ignored` once credentials are configured.
#[tokio::test]
#[ignore = "requires a live LLM endpoint configured via tests/test.env"]
async fn test_openai_simple_completion() {
    let config = load_config().expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");

    let response = provider
        .chat(create_request("Say 'ok'"))
        .await
        .expect("chat completion request failed");

    // Assertion messages echo the model output, because live-model failures are
    // otherwise impossible to diagnose from the test log.
    assert!(!response.id.is_empty(), "response id is empty");
    assert!(!response.content.is_empty(), "response content is empty");
    assert!(
        response.usage.total_tokens > 0,
        "expected non-zero token usage, got {}",
        response.usage.total_tokens
    );
    assert!(
        response.content.to_lowercase().contains("ok"),
        "expected the reply to contain \"ok\", got {:?}",
        response.content
    );
}
/// Live test that earlier conversation turns reach the model.
///
/// The user's name ("Alice") appears only in the first message, so the final reply can
/// mention it only if the whole history was sent.
///
/// Ignored by default. Run with `cargo test -- --ignored` once credentials are configured.
#[tokio::test]
#[ignore = "requires a live LLM endpoint configured via tests/test.env"]
async fn test_openai_conversation() {
    let config = load_config().expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");

    let request = ChatCompletionRequest {
        messages: vec![
            Message::user("My name is Alice"),
            Message::assistant("Hello Alice!"),
            Message::user("What is my name?"),
        ],
        temperature: Some(0.0),
        max_tokens: Some(50),
        tools: None,
    };

    let response = provider
        .chat(request)
        .await
        .expect("chat completion request failed");

    // Echo the actual reply on failure so a mis-remembered name is visible in the log.
    assert!(
        response.content.to_lowercase().contains("alice"),
        "expected the reply to mention \"alice\", got {:?}",
        response.content
    );
}
/// Checks that `config.json` resolves the `default` provider to an `openai`-type provider
/// named `aliyun` using model `qwen-plus`.
///
/// It also checks that [`create_provider`] preserves those three identifiers on the
/// constructed provider. The test is ignored by default because it depends on a local
/// `config.json`.
#[tokio::test]
#[ignore]
async fn test_config_load() {
    const EXPECTED_TYPE: &str = "openai";
    const EXPECTED_NAME: &str = "aliyun";
    const EXPECTED_MODEL: &str = "qwen-plus";

    let loaded = Config::load("config.json").expect("Failed to load config.json");
    let provider_config = loaded
        .get_provider_config("default")
        .expect("Failed to get provider config");

    // The parsed config itself.
    assert_eq!(provider_config.provider_type, EXPECTED_TYPE);
    assert_eq!(provider_config.name, EXPECTED_NAME);
    assert_eq!(provider_config.model_id, EXPECTED_MODEL);

    // The provider built from it must report the same identity.
    let runtime = to_runtime_config(provider_config);
    let provider = create_provider(runtime).expect("Failed to create provider");
    assert_eq!(provider.ptype(), EXPECTED_TYPE);
    assert_eq!(provider.name(), EXPECTED_NAME);
    assert_eq!(provider.model_id(), EXPECTED_MODEL);
}