// PicoBot/tests/test_tool_calling.rs
// Integration tests for LLM provider tool calling (run with `cargo test -- --ignored`).

use std::collections::HashMap;
use PicoBot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
use PicoBot::config::LLMProviderConfig;
/// Load OpenAI provider settings from `tests/test.env`.
///
/// Returns `None` when the env file is missing, a required variable
/// (`OPENAI_BASE_URL`, `OPENAI_API_KEY`) is unset, or the API key still
/// looks like an unfilled placeholder. `OPENAI_MODEL_NAME` is optional
/// and defaults to `"gpt-4"`.
fn load_openai_config() -> Option<LLMProviderConfig> {
    dotenv::from_filename("tests/test.env").ok()?;
    let base_url = std::env::var("OPENAI_BASE_URL").ok()?;
    // Reject template values such as "your_api_key_here" so the tests
    // fail fast with a clear message instead of a confusing HTTP 401.
    let api_key = std::env::var("OPENAI_API_KEY")
        .ok()
        .filter(|key| !key.contains("your_"))?;
    let model_id =
        std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string());
    Some(LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "test_openai".to_string(),
        base_url,
        api_key,
        extra_headers: HashMap::new(),
        model_id,
        temperature: Some(0.0),
        max_tokens: Some(100),
        model_extra: HashMap::new(),
    })
}
/// Build the `get_weather` function-calling tool definition shared by the
/// tests: one required string parameter, `city`.
fn make_weather_tool() -> Tool {
    // JSON Schema describing the tool's arguments.
    let parameters = serde_json::json!({
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "The city name"
            }
        },
        "required": ["city"]
    });
    let function = ToolFunction {
        name: "get_weather".to_string(),
        description: "Get current weather for a city".to_string(),
        parameters,
    };
    Tool {
        tool_type: "function".to_string(),
        function,
    }
}
/// A weather question with the weather tool available should produce a
/// `get_weather` tool call carrying a `city` argument, not a plain text reply.
#[tokio::test]
#[ignore]
async fn test_openai_tool_call() {
    let config = load_openai_config()
        .expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(config).expect("Failed to create provider");

    let messages = vec![Message {
        role: "user".to_string(),
        content: "What is the weather in Tokyo?".to_string(),
    }];
    let request = ChatCompletionRequest {
        messages,
        temperature: Some(0.0),
        max_tokens: Some(200),
        tools: Some(vec![make_weather_tool()]),
    };

    let response = provider.chat(request).await.unwrap();
    // The model must choose the tool rather than answering directly.
    assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
    let first_call = response.tool_calls.first().expect("non-empty per assert above");
    assert_eq!(first_call.name, "get_weather");
    assert!(first_call.arguments.get("city").is_some());
}
#[tokio::test]
#[ignore]
async fn test_openai_tool_call_with_manual_execution() {
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
// First request with tool
let request1 = ChatCompletionRequest {
messages: vec![Message {
role: "user".to_string(),
content: "What is the weather in Tokyo?".to_string(),
}],
temperature: Some(0.0),
max_tokens: Some(200),
tools: Some(vec![make_weather_tool()]),
};
let response1 = provider.chat(request1).await.unwrap();
let tool_call = response1.tool_calls.first()
.expect("Expected tool call");
assert_eq!(tool_call.name, "get_weather");
// Second request with tool result
let request2 = ChatCompletionRequest {
messages: vec![
Message {
role: "user".to_string(),
content: "What is the weather in Tokyo?".to_string(),
},
Message {
role: "assistant".to_string(),
content: r#"I'll check the weather for you using the get_weather tool."#.to_string(),
},
],
temperature: Some(0.0),
max_tokens: Some(200),
tools: Some(vec![make_weather_tool()]),
};
let response2 = provider.chat(request2).await.unwrap();
// Should have a response
assert!(!response2.content.is_empty() || !response2.tool_calls.is_empty());
}
/// When the request carries no tools, the provider must return ordinary
/// text content and no tool calls.
#[tokio::test]
#[ignore]
async fn test_openai_no_tool_when_not_provided() {
    let config = load_openai_config()
        .expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(config).expect("Failed to create provider");

    let user_message = Message {
        role: "user".to_string(),
        content: "Say hello in one word.".to_string(),
    };
    let request = ChatCompletionRequest {
        messages: vec![user_message],
        temperature: Some(0.0),
        max_tokens: Some(10),
        tools: None,
    };

    let response = provider.chat(request).await.unwrap();
    // Nothing was offered to call, so the reply must be plain text.
    assert!(response.tool_calls.is_empty());
    assert!(!response.content.is_empty());
}