//! Integration tests for OpenAI-compatible LLM providers.
//!
//! These tests talk to a live endpoint and are `#[ignore]`d by default;
//! configure `tests/test.env` (or `config.json`) and run with
//! `cargo test -- --ignored`.
use std::collections::HashMap;

use PicoBot::config::{Config, LLMProviderConfig};
use PicoBot::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
fn load_config() -> Option<LLMProviderConfig> {
|
|
dotenv::from_filename("tests/test.env").ok()?;
|
|
|
|
let openai_base_url = std::env::var("OPENAI_BASE_URL").ok()?;
|
|
let openai_api_key = std::env::var("OPENAI_API_KEY").ok()?;
|
|
let openai_model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string());
|
|
|
|
if openai_api_key.contains("your_") {
|
|
return None;
|
|
}
|
|
|
|
Some(LLMProviderConfig {
|
|
provider_type: "openai".to_string(),
|
|
name: "test_openai".to_string(),
|
|
base_url: openai_base_url,
|
|
api_key: openai_api_key,
|
|
extra_headers: HashMap::new(),
|
|
model_id: openai_model,
|
|
temperature: Some(0.0),
|
|
max_tokens: Some(100),
|
|
model_extra: HashMap::new(),
|
|
})
|
|
}
|
|
|
|
fn create_request(content: &str) -> ChatCompletionRequest {
|
|
ChatCompletionRequest {
|
|
messages: vec![Message {
|
|
role: "user".to_string(),
|
|
content: content.to_string(),
|
|
}],
|
|
temperature: Some(0.0),
|
|
max_tokens: Some(100),
|
|
tools: None,
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
#[ignore]
|
|
async fn test_openai_simple_completion() {
|
|
let config = load_config()
|
|
.expect("Please configure tests/test.env with valid API keys");
|
|
|
|
let provider = create_provider(config).expect("Failed to create provider");
|
|
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
|
|
|
assert!(!response.id.is_empty());
|
|
assert!(!response.content.is_empty());
|
|
assert!(response.usage.total_tokens > 0);
|
|
assert!(response.content.to_lowercase().contains("ok"));
|
|
}
|
|
|
|
#[tokio::test]
|
|
#[ignore]
|
|
async fn test_openai_conversation() {
|
|
let config = load_config()
|
|
.expect("Please configure tests/test.env with valid API keys");
|
|
|
|
let provider = create_provider(config).expect("Failed to create provider");
|
|
|
|
let request = ChatCompletionRequest {
|
|
messages: vec![
|
|
Message { role: "user".to_string(), content: "My name is Alice".to_string() },
|
|
Message { role: "assistant".to_string(), content: "Hello Alice!".to_string() },
|
|
Message { role: "user".to_string(), content: "What is my name?".to_string() },
|
|
],
|
|
temperature: Some(0.0),
|
|
max_tokens: Some(50),
|
|
tools: None,
|
|
};
|
|
|
|
let response = provider.chat(request).await.unwrap();
|
|
assert!(response.content.to_lowercase().contains("alice"));
|
|
}
|
|
|
|
#[tokio::test]
|
|
#[ignore]
|
|
async fn test_config_load() {
|
|
// Test that config.json can be loaded and provider config created
|
|
let config = Config::load("config.json").expect("Failed to load config.json");
|
|
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
|
|
|
assert_eq!(provider_config.provider_type, "openai");
|
|
assert_eq!(provider_config.name, "aliyun");
|
|
assert_eq!(provider_config.model_id, "qwen-plus");
|
|
|
|
let provider = create_provider(provider_config).expect("Failed to create provider");
|
|
assert_eq!(provider.ptype(), "openai");
|
|
assert_eq!(provider.name(), "aliyun");
|
|
assert_eq!(provider.model_id(), "qwen-plus");
|
|
}
|