149 lines
4.6 KiB
Rust
149 lines
4.6 KiB
Rust
use std::collections::HashMap;
|
|
use PicoBot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
|
use PicoBot::config::LLMProviderConfig;
|
|
|
|
fn load_openai_config() -> Option<LLMProviderConfig> {
|
|
dotenv::from_filename("tests/test.env").ok()?;
|
|
|
|
let openai_base_url = std::env::var("OPENAI_BASE_URL").ok()?;
|
|
let openai_api_key = std::env::var("OPENAI_API_KEY").ok()?;
|
|
let openai_model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string());
|
|
|
|
if openai_api_key.contains("your_") {
|
|
return None;
|
|
}
|
|
|
|
Some(LLMProviderConfig {
|
|
provider_type: "openai".to_string(),
|
|
name: "test_openai".to_string(),
|
|
base_url: openai_base_url,
|
|
api_key: openai_api_key,
|
|
extra_headers: HashMap::new(),
|
|
model_id: openai_model,
|
|
temperature: Some(0.0),
|
|
max_tokens: Some(100),
|
|
model_extra: HashMap::new(),
|
|
max_tool_iterations: 20,
|
|
})
|
|
}
|
|
|
|
fn make_weather_tool() -> Tool {
|
|
Tool {
|
|
tool_type: "function".to_string(),
|
|
function: ToolFunction {
|
|
name: "get_weather".to_string(),
|
|
description: "Get current weather for a city".to_string(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"city": {
|
|
"type": "string",
|
|
"description": "The city name"
|
|
}
|
|
},
|
|
"required": ["city"]
|
|
}),
|
|
},
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
#[ignore]
|
|
async fn test_openai_tool_call() {
|
|
let config = load_openai_config()
|
|
.expect("Please configure tests/test.env with valid API keys");
|
|
|
|
let provider = create_provider(config).expect("Failed to create provider");
|
|
|
|
let request = ChatCompletionRequest {
|
|
messages: vec![Message {
|
|
role: "user".to_string(),
|
|
content: "What is the weather in Tokyo?".to_string(),
|
|
}],
|
|
temperature: Some(0.0),
|
|
max_tokens: Some(200),
|
|
tools: Some(vec![make_weather_tool()]),
|
|
};
|
|
|
|
let response = provider.chat(request).await.unwrap();
|
|
|
|
// Should have tool calls
|
|
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
|
|
|
let tool_call = &response.tool_calls[0];
|
|
assert_eq!(tool_call.name, "get_weather");
|
|
assert!(tool_call.arguments.get("city").is_some());
|
|
}
|
|
|
|
#[tokio::test]
|
|
#[ignore]
|
|
async fn test_openai_tool_call_with_manual_execution() {
|
|
let config = load_openai_config()
|
|
.expect("Please configure tests/test.env with valid API keys");
|
|
|
|
let provider = create_provider(config).expect("Failed to create provider");
|
|
|
|
// First request with tool
|
|
let request1 = ChatCompletionRequest {
|
|
messages: vec![Message {
|
|
role: "user".to_string(),
|
|
content: "What is the weather in Tokyo?".to_string(),
|
|
}],
|
|
temperature: Some(0.0),
|
|
max_tokens: Some(200),
|
|
tools: Some(vec![make_weather_tool()]),
|
|
};
|
|
|
|
let response1 = provider.chat(request1).await.unwrap();
|
|
let tool_call = response1.tool_calls.first()
|
|
.expect("Expected tool call");
|
|
assert_eq!(tool_call.name, "get_weather");
|
|
|
|
// Second request with tool result
|
|
let request2 = ChatCompletionRequest {
|
|
messages: vec![
|
|
Message {
|
|
role: "user".to_string(),
|
|
content: "What is the weather in Tokyo?".to_string(),
|
|
},
|
|
Message {
|
|
role: "assistant".to_string(),
|
|
content: r#"I'll check the weather for you using the get_weather tool."#.to_string(),
|
|
},
|
|
],
|
|
temperature: Some(0.0),
|
|
max_tokens: Some(200),
|
|
tools: Some(vec![make_weather_tool()]),
|
|
};
|
|
|
|
let response2 = provider.chat(request2).await.unwrap();
|
|
|
|
// Should have a response
|
|
assert!(!response2.content.is_empty() || !response2.tool_calls.is_empty());
|
|
}
|
|
|
|
#[tokio::test]
|
|
#[ignore]
|
|
async fn test_openai_no_tool_when_not_provided() {
|
|
let config = load_openai_config()
|
|
.expect("Please configure tests/test.env with valid API keys");
|
|
|
|
let provider = create_provider(config).expect("Failed to create provider");
|
|
|
|
let request = ChatCompletionRequest {
|
|
messages: vec![Message {
|
|
role: "user".to_string(),
|
|
content: "Say hello in one word.".to_string(),
|
|
}],
|
|
temperature: Some(0.0),
|
|
max_tokens: Some(10),
|
|
tools: None,
|
|
};
|
|
|
|
let response = provider.chat(request).await.unwrap();
|
|
|
|
// Should NOT have tool calls
|
|
assert!(response.tool_calls.is_empty());
|
|
assert!(!response.content.is_empty());
|
|
}
|