PicoBot/tests/test_tool_calling.rs
ooodc 8bb32fa066 feat: enhance WebSocket session management and storage
- Added SessionSummary struct for session metadata.
- Updated ws_handler to create and manage CLI sessions more robustly.
- Implemented session creation, loading, renaming, archiving, and deletion via WebSocket messages.
- Introduced SessionStore for persistent session storage using SQLite.
- Enhanced error handling and logging for session operations.
- Updated protocol definitions for new session-related WebSocket messages.
- Refactored tests to cover new session functionalities and ensure proper serialization.
2026-04-18 13:09:14 +08:00

135 lines
4.3 KiB
Rust

use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
use picobot::config::LLMProviderConfig;
/// Builds an OpenAI-backed `LLMProviderConfig` from `tests/test.env`.
///
/// Returns `None` when the env file is missing, when `OPENAI_BASE_URL`
/// or `OPENAI_API_KEY` is unset, or when the API key still contains the
/// `your_` template placeholder (i.e. test.env was never filled in).
fn load_openai_config() -> Option<LLMProviderConfig> {
    dotenv::from_filename("tests/test.env").ok()?;

    let base_url = std::env::var("OPENAI_BASE_URL").ok()?;
    let api_key = std::env::var("OPENAI_API_KEY").ok()?;
    // Fall back to gpt-4 when no explicit model name is configured.
    let model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string());

    // A placeholder key means the template was copied but never configured.
    if api_key.contains("your_") {
        return None;
    }

    Some(LLMProviderConfig {
        provider_type: "openai".to_string(),
        name: "test_openai".to_string(),
        base_url,
        api_key,
        extra_headers: HashMap::new(),
        model_id: model,
        temperature: Some(0.0),
        max_tokens: Some(100),
        model_extra: HashMap::new(),
        max_tool_iterations: 20,
        token_limit: 128_000,
    })
}
/// Returns a minimal `get_weather` function-tool definition used by the
/// tool-calling tests. Its JSON schema accepts a single required string
/// property, `city`.
fn make_weather_tool() -> Tool {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "The city name"
            }
        },
        "required": ["city"]
    });

    Tool {
        tool_type: "function".to_string(),
        function: ToolFunction {
            name: "get_weather".to_string(),
            description: "Get current weather for a city".to_string(),
            parameters: schema,
        },
    }
}
/// Live test: a weather question plus an offered weather tool should make
/// the model emit a `get_weather` tool call carrying a `city` argument.
/// Ignored by default; requires valid API keys in `tests/test.env`.
#[tokio::test]
#[ignore]
async fn test_openai_tool_call() {
    let config = load_openai_config()
        .expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(config).expect("Failed to create provider");

    let response = provider
        .chat(ChatCompletionRequest {
            messages: vec![Message::user("What is the weather in Tokyo?")],
            temperature: Some(0.0),
            max_tokens: Some(200),
            tools: Some(vec![make_weather_tool()]),
        })
        .await
        .unwrap();

    // The model is expected to call the tool rather than answer directly.
    assert!(
        !response.tool_calls.is_empty(),
        "Expected tool call, got: {}",
        response.content
    );
    let call = &response.tool_calls[0];
    assert_eq!(call.name, "get_weather");
    assert!(call.arguments.get("city").is_some());
}
/// Live test of a two-turn exchange: turn one elicits a `get_weather`
/// tool call; turn two replays the conversation with an assistant
/// message appended and checks the provider still produces a response.
///
/// NOTE(review): despite the test name, no actual tool-result message is
/// sent in turn two — only `Message::user`/`Message::assistant` are used
/// here. Confirm whether the provider API offers a tool-role message
/// constructor that should carry the executed tool's output instead.
/// Ignored by default; requires valid API keys in `tests/test.env`.
#[tokio::test]
#[ignore]
async fn test_openai_tool_call_with_manual_execution() {
    let config = load_openai_config()
        .expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(config).expect("Failed to create provider");

    // Turn one: offer the weather tool and expect the model to call it.
    let first = ChatCompletionRequest {
        messages: vec![Message::user("What is the weather in Tokyo?")],
        temperature: Some(0.0),
        max_tokens: Some(200),
        tools: Some(vec![make_weather_tool()]),
    };
    let reply = provider.chat(first).await.unwrap();
    let call = reply.tool_calls.first().expect("Expected tool call");
    assert_eq!(call.name, "get_weather");

    // Turn two: replay the exchange and let the model continue.
    let second = ChatCompletionRequest {
        messages: vec![
            Message::user("What is the weather in Tokyo?"),
            Message::assistant(r#"I'll check the weather for you using the get_weather tool."#),
        ],
        temperature: Some(0.0),
        max_tokens: Some(200),
        tools: Some(vec![make_weather_tool()]),
    };
    let followup = provider.chat(second).await.unwrap();
    // Either plain content or a further tool call counts as a response.
    assert!(!followup.content.is_empty() || !followup.tool_calls.is_empty());
}
/// Live test: with no tools offered, the provider must return plain text
/// content and zero tool calls.
/// Ignored by default; requires valid API keys in `tests/test.env`.
#[tokio::test]
#[ignore]
async fn test_openai_no_tool_when_not_provided() {
    let config = load_openai_config()
        .expect("Please configure tests/test.env with valid API keys");
    let provider = create_provider(config).expect("Failed to create provider");

    let request = ChatCompletionRequest {
        messages: vec![Message::user("Say hello in one word.")],
        temperature: Some(0.0),
        max_tokens: Some(10),
        tools: None,
    };
    let response = provider.chat(request).await.unwrap();

    // No tools were offered, so none may be called — expect text only.
    assert!(response.tool_calls.is_empty());
    assert!(!response.content.is_empty());
}