PicoBot/tests/test_tool_calling.rs
2026-05-01 21:22:07 +08:00

155 lines
5.0 KiB
Rust

use picobot::config::LLMProviderConfig;
use picobot::providers::{
ChatCompletionRequest, Message, ProviderRuntimeConfig, Tool, ToolFunction, create_provider,
};
use std::collections::HashMap;
fn to_runtime_config(config: LLMProviderConfig) -> ProviderRuntimeConfig {
ProviderRuntimeConfig {
provider_type: config.provider_type,
name: config.name,
base_url: config.base_url,
api_key: config.api_key,
extra_headers: config.extra_headers,
llm_timeout_secs: config.llm_timeout_secs,
model_id: config.model_id,
temperature: config.temperature,
max_tokens: config.max_tokens,
model_extra: config.model_extra,
}
}
/// Builds an OpenAI-compatible `LLMProviderConfig` for the live integration tests.
///
/// `tests/test.env` (resolved against the crate root when Cargo provides
/// `CARGO_MANIFEST_DIR`, so the working directory does not matter) is loaded on
/// a best-effort basis. The settings are then read from the process environment,
/// so variables exported directly (e.g. in CI) work without the file.
///
/// Returns `None` when `OPENAI_BASE_URL` or `OPENAI_API_KEY` is unset or blank,
/// or when the API key still contains the `your_` template placeholder.
/// `OPENAI_MODEL_NAME` falls back to `gpt-4` when unset or blank.
fn load_openai_config() -> Option<LLMProviderConfig> {
    let env_file = std::path::Path::new(option_env!("CARGO_MANIFEST_DIR").unwrap_or("."))
        .join("tests")
        .join("test.env");
    // Deliberately ignored: a missing file is fine if the variables are
    // already present in the environment.
    let _ = dotenv::from_filename(&env_file);

    // Reads a variable, treating unset and whitespace-only values the same.
    let non_blank_var = |key: &str| std::env::var(key).ok().filter(|v| !v.trim().is_empty());

    let openai_base_url = non_blank_var("OPENAI_BASE_URL")?;
    let openai_api_key = non_blank_var("OPENAI_API_KEY")?;
    let openai_model = non_blank_var("OPENAI_MODEL_NAME").unwrap_or_else(|| "gpt-4".to_string());

    // An untouched template value means the user has not configured a real key.
    if openai_api_key.contains("your_") {
        return None;
    }

    Some(LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "test_openai".to_string(),
        base_url: openai_base_url,
        api_key: openai_api_key,
        extra_headers: HashMap::new(),
        llm_timeout_secs: 120,
        model_id: openai_model,
        // Deterministic sampling keeps the live assertions as stable as possible.
        temperature: Some(0.0),
        max_tokens: Some(100),
        context_window_tokens: None,
        model_extra: HashMap::new(),
        max_tool_iterations: 20,
        tool_result_max_chars: 20_000,
        context_tool_result_trim_chars: 20_000,
    })
}
/// Returns the `get_weather` function tool used by every tool-calling test.
///
/// The tool takes a single required string argument, `city`.
fn make_weather_tool() -> Tool {
    // JSON Schema describing the tool's arguments.
    let parameters = serde_json::json!({
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "The city name"
            }
        },
        "required": ["city"]
    });

    let function = ToolFunction {
        name: String::from("get_weather"),
        description: String::from("Get current weather for a city"),
        parameters,
    };

    Tool {
        tool_type: String::from("function"),
        function,
    }
}
/// Live single-turn check: when `get_weather` is offered and the user asks about
/// the weather, the model should respond with a `get_weather` call that carries
/// a `city` argument.
#[tokio::test]
#[ignore]
async fn test_openai_tool_call() {
    let provider = create_provider(to_runtime_config(
        load_openai_config().expect("Please configure tests/test.env with valid API keys"),
    ))
    .expect("Failed to create provider");

    let response = provider
        .chat(ChatCompletionRequest {
            messages: vec![Message::user("What is the weather in Tokyo?")],
            temperature: Some(0.0),
            max_tokens: Some(200),
            tools: Some(vec![make_weather_tool()]),
        })
        .await
        .unwrap();

    // The first tool call must exist; otherwise report what the model said instead.
    let first_call = response
        .tool_calls
        .first()
        .unwrap_or_else(|| panic!("Expected tool call, got: {}", response.content));

    assert_eq!(first_call.name, "get_weather");
    assert!(first_call.arguments.get("city").is_some());
}
/// Live two-turn round trip.
///
/// 1. With `get_weather` offered, the model is expected to call it.
/// 2. The test "executes" the tool locally with a canned result and sends that
///    result back, so the model can produce an answer (or another tool call).
#[tokio::test]
#[ignore]
async fn test_openai_tool_call_with_manual_execution() {
    let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
    let question = "What is the weather in Tokyo?";

    // Turn 1: offer the tool and expect the model to request it.
    let request1 = ChatCompletionRequest {
        messages: vec![Message::user(question)],
        temperature: Some(0.0),
        max_tokens: Some(200),
        tools: Some(vec![make_weather_tool()]),
    };
    let response1 = provider
        .chat(request1)
        .await
        .expect("first chat request failed");
    let tool_call = response1.tool_calls.first().expect("Expected tool call");
    assert_eq!(tool_call.name, "get_weather");

    // Simulated local execution of `get_weather`. Previously the second request
    // never carried any tool output, so the round trip was not actually tested.
    // NOTE(review): the result is sent as a plain follow-up message because only
    // `Message::user`/`Message::assistant` are used in this file; if the provider
    // exposes a native tool-result message, prefer it here — TODO confirm.
    let tool_result = "Result of the get_weather tool call: Tokyo is sunny, 22°C. \
                       Use this result to answer my question.";

    // Turn 2: replay the conversation plus the tool result.
    let request2 = ChatCompletionRequest {
        messages: vec![
            Message::user(question),
            Message::assistant("I'll check the weather for you using the get_weather tool."),
            Message::user(tool_result),
        ],
        temperature: Some(0.0),
        max_tokens: Some(200),
        tools: Some(vec![make_weather_tool()]),
    };
    let response2 = provider
        .chat(request2)
        .await
        .expect("second chat request failed");

    // The model must react to the tool result, either in text or with another call.
    assert!(
        !response2.content.is_empty() || !response2.tool_calls.is_empty(),
        "Expected a reply after supplying the tool result"
    );
}
/// Live check: when the request offers no tools (`tools: None`), the model must
/// answer in plain text and must not emit any tool calls.
#[tokio::test]
#[ignore]
async fn test_openai_no_tool_when_not_provided() {
    let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(to_runtime_config(config)).expect("Failed to create provider");
    let request = ChatCompletionRequest {
        messages: vec![Message::user("Say hello in one word.")],
        temperature: Some(0.0),
        max_tokens: Some(10),
        tools: None,
    };
    let response = provider
        .chat(request)
        .await
        .expect("chat request failed");

    // Diagnostic messages mirror `test_openai_tool_call`, so a failing live run
    // shows what the model actually returned.
    assert!(
        response.tool_calls.is_empty(),
        "Expected no tool calls when no tools are offered, got {} (content: {})",
        response.tool_calls.len(),
        response.content
    );
    assert!(
        !response.content.is_empty(),
        "Expected a non-empty text reply"
    );
}