diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..91f8c4d --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "PicoBot" +version = "0.1.0" +edition = "2024" + +[dependencies] +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +dotenv = "0.15" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +async-trait = "0.1" +thiserror = "1.0" +tokio = { version = "1.0", features = ["full"] } diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..798d014 --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,132 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Config { + pub providers: HashMap, + pub models: HashMap, + pub agents: HashMap, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ProviderConfig { + #[serde(rename = "type")] + pub provider_type: String, + pub base_url: String, + pub api_key: String, + #[serde(default)] + pub extra_headers: HashMap, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ModelConfig { + pub model_id: String, + #[serde(default)] + pub temperature: Option, + #[serde(default)] + pub max_tokens: Option, + #[serde(flatten)] + pub extra: HashMap, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AgentConfig { + pub provider: String, + pub model: String, +} + +#[derive(Debug, Clone)] +pub struct LLMProviderConfig { + pub provider_type: String, + pub name: String, + pub base_url: String, + pub api_key: String, + pub extra_headers: HashMap, + pub model_id: String, + pub temperature: Option, + pub max_tokens: Option, + pub model_extra: HashMap, +} + +impl Config { + pub fn load(path: &str) -> Result> { + let content = fs::read_to_string(path)?; + let config: Config = serde_json::from_str(&content)?; + Ok(config) + } + + pub fn get_provider_config(&self, agent_name: &str) -> Result { + let agent = self.agents.get(agent_name) + .ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?; + + let provider = self.providers.get(&agent.provider) + .ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?; + + let model = self.models.get(&agent.model) + .ok_or(ConfigError::ModelNotFound(agent.model.clone()))?; + + Ok(LLMProviderConfig { + provider_type: provider.provider_type.clone(), + name: agent.provider.clone(), + base_url: provider.base_url.clone(), + api_key: provider.api_key.clone(), + extra_headers: provider.extra_headers.clone(), + model_id: model.model_id.clone(), + temperature: model.temperature, + max_tokens: model.max_tokens, + model_extra: model.extra.clone(), + }) + } +} + +#[derive(Debug)] +pub enum ConfigError { + AgentNotFound(String), + ProviderNotFound(String), + ModelNotFound(String), +} + +impl std::fmt::Display for ConfigError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), + ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), + ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), + } + } +} + +impl std::error::Error for ConfigError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_load() { + let config = Config::load("config.json").unwrap(); + + // Check providers + assert!(config.providers.contains_key("volcengine")); + assert!(config.providers.contains_key("aliyun")); + + // Check models + assert!(config.models.contains_key("doubao-seed-2-0-lite-260215")); + assert!(config.models.contains_key("qwen-plus")); + + // Check agents + assert!(config.agents.contains_key("default")); + } + + #[test] + fn test_get_provider_config() { + let config = Config::load("config.json").unwrap(); + let provider_config = config.get_provider_config("default").unwrap(); + + assert_eq!(provider_config.provider_type, "openai"); + assert_eq!(provider_config.name, "aliyun"); + assert_eq!(provider_config.model_id, "qwen-plus"); + assert_eq!(provider_config.temperature, Some(0.0)); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..4e2fd8e --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2 @@ +pub mod config; +pub mod providers; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..6f7cadc --- /dev/null +++ b/src/main.rs @@ -0,0 +1,38 @@ +mod config; +mod providers; + +use config::Config; +use providers::{create_provider, ChatCompletionRequest, Message}; + +#[tokio::main] +async fn main() { + // Load config + let config = Config::load("config.json").expect("Failed to load config.json"); + + // Get provider config for "default" agent + let provider_config = config.get_provider_config("default").expect("Failed to get provider config"); + + // Create provider + let provider = create_provider(provider_config).expect("Failed to create provider"); + + println!("Provider type: {}", provider.ptype()); + println!("Provider name: {}", provider.name()); + println!("Model ID: {}", provider.model_id()); + + // Create request (no model ID needed - it's baked into the provider) + let request = ChatCompletionRequest { + messages: vec![Message { + role: "user".to_string(), + content: "Hello!".to_string(), + }], + temperature: None, // Will use config default if not provided + max_tokens: None, // Will use config default if not provided + tools: None, + }; + + // Example usage: + // match provider.chat(request).await { + // Ok(resp) => println!("Response: {}", resp.content), + // Err(e) => eprintln!("Error: {}", e), + // } +} diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs new file mode 100644 index 0000000..9b06156 --- /dev/null +++ b/src/providers/anthropic.rs @@ -0,0 +1,198 @@ +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; +use super::traits::Usage; + +pub struct AnthropicProvider { + client: Client, + name: String, + api_key: String, + base_url: String, + extra_headers: HashMap, + model_id: String, + temperature: Option, + max_tokens: Option, + model_extra: HashMap, +} + +impl AnthropicProvider { + pub fn new( + name: String, + api_key: String, + base_url: String, + extra_headers: HashMap, + model_id: String, + temperature: Option, + max_tokens: Option, + model_extra: HashMap, + ) -> Self { + Self { + client: Client::new(), + name, + api_key, + base_url, + extra_headers, + model_id, + temperature, + max_tokens, + model_extra, + } + } +} + +#[derive(Serialize)] +struct AnthropicRequest { + model: String, + messages: Vec, + max_tokens: u32, + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(flatten)] + extra: HashMap, +} + +#[derive(Serialize)] +struct AnthropicMessage { + role: String, + content: String, +} + +#[derive(Serialize)] +struct AnthropicTool { + name: String, + description: String, + input_schema: serde_json::Value, +} + +#[derive(Deserialize)] +struct AnthropicResponse { + id: String, + model: String, + content: Vec, + usage: AnthropicUsage, +} + +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum AnthropicContent { + Text { text: String }, + Thinking { thinking: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, +} + +#[derive(Deserialize)] +struct AnthropicUsage { + input_tokens: u32, + output_tokens: u32, +} + +#[async_trait] +impl LLMProvider for AnthropicProvider { + async fn chat( + &self, + request: ChatCompletionRequest, + ) -> Result> { + let url = format!("{}/v1/messages", self.base_url); + let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024); + + let tools = request.tools.map(|tools| { + tools + .iter() + .map(|t: &Tool| AnthropicTool { + name: t.function.name.clone(), + description: t.function.description.clone(), + input_schema: t.function.parameters.clone(), + }) + .collect() + }); + + let body = AnthropicRequest { + model: self.model_id.clone(), + messages: request + .messages + .iter() + .map(|m| AnthropicMessage { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(), + max_tokens, + temperature: request.temperature.or(self.temperature), + tools, + extra: self.model_extra.clone(), + }; + + let mut req_builder = self + .client + .post(&url) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("Content-Type", "application/json"); + + for (key, value) in &self.extra_headers { + req_builder = req_builder.header(key.as_str(), value.as_str()); + } + + let resp = req_builder.json(&body).send().await?; + + let anthropic_resp: AnthropicResponse = resp.json().await?; + + let mut content = String::new(); + let mut tool_calls = Vec::new(); + + for c in &anthropic_resp.content { + match c { + AnthropicContent::Text { text } => { + if !text.is_empty() { + if !content.is_empty() { + content.push('\n'); + } + content.push_str(text); + } + } + AnthropicContent::Thinking { .. } => {} + AnthropicContent::ToolUse { id, name, input } => { + tool_calls.push(ToolCall { + id: id.clone(), + name: name.clone(), + arguments: input.clone(), + }); + } + } + } + + Ok(ChatCompletionResponse { + id: anthropic_resp.id, + model: anthropic_resp.model, + content, + tool_calls, + usage: Usage { + prompt_tokens: anthropic_resp.usage.input_tokens, + completion_tokens: anthropic_resp.usage.output_tokens, + total_tokens: anthropic_resp.usage.input_tokens + + anthropic_resp.usage.output_tokens, + }, + }) + } + + fn ptype(&self) -> &str { + "anthropic" + } + + fn name(&self) -> &str { + &self.name + } + + fn model_id(&self) -> &str { + &self.model_id + } +} diff --git a/src/providers/mod.rs b/src/providers/mod.rs new file mode 100644 index 0000000..eedab44 --- /dev/null +++ b/src/providers/mod.rs @@ -0,0 +1,50 @@ +pub mod traits; +pub mod openai; +pub mod anthropic; + +pub use self::openai::OpenAIProvider; +pub use self::anthropic::AnthropicProvider; + +use crate::config::LLMProviderConfig; +pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage}; + +pub fn create_provider(config: LLMProviderConfig) -> Result, ProviderError> { + match config.provider_type.as_str() { + "openai" => Ok(Box::new(OpenAIProvider::new( + config.name, + config.api_key, + config.base_url, + config.extra_headers, + config.model_id, + config.temperature, + config.max_tokens, + config.model_extra, + ))), + "anthropic" => Ok(Box::new(AnthropicProvider::new( + config.name, + config.api_key, + config.base_url, + config.extra_headers, + config.model_id, + config.temperature, + config.max_tokens, + config.model_extra, + ))), + _ => Err(ProviderError::UnknownProviderType(config.provider_type)), + } +} + +#[derive(Debug)] +pub enum ProviderError { + UnknownProviderType(String), +} + +impl std::fmt::Display for ProviderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ProviderError::UnknownProviderType(t) => write!(f, "Unknown provider type: {}", t), + } + } +} + +impl std::error::Error for ProviderError {} diff --git a/src/providers/openai.rs b/src/providers/openai.rs new file mode 100644 index 0000000..043b07b --- /dev/null +++ b/src/providers/openai.rs @@ -0,0 +1,181 @@ +use async_trait::async_trait; +use reqwest::Client; +use serde::Deserialize; +use serde_json::json; +use std::collections::HashMap; + +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; +use super::traits::Usage; + +pub struct OpenAIProvider { + client: Client, + name: String, + api_key: String, + base_url: String, + extra_headers: HashMap, + model_id: String, + temperature: Option, + max_tokens: Option, + model_extra: HashMap, +} + +impl OpenAIProvider { + pub fn new( + name: String, + api_key: String, + base_url: String, + extra_headers: HashMap, + model_id: String, + temperature: Option, + max_tokens: Option, + model_extra: HashMap, + ) -> Self { + Self { + client: Client::new(), + name, + api_key, + base_url, + extra_headers, + model_id, + temperature, + max_tokens, + model_extra, + } + } +} + +#[derive(Deserialize)] +struct OpenAIResponse { + id: String, + model: String, + choices: Vec, + #[serde(default)] + usage: OpenAIUsage, +} + +#[derive(Deserialize)] +struct OpenAIChoice { + message: OpenAIMessage, +} + +#[derive(Deserialize)] +struct OpenAIMessage { + #[serde(default)] + content: Option, + #[serde(default)] + name: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Deserialize)] +struct OpenAIToolCall { + id: String, + #[serde(rename = "function")] + function: OAIFunction, + #[serde(default)] + index: Option, +} + +#[derive(Deserialize)] +struct OAIFunction { + name: String, + arguments: String, +} + +#[derive(Deserialize, Default)] +struct OpenAIUsage { + #[serde(default)] + prompt_tokens: u32, + #[serde(default)] + completion_tokens: u32, + #[serde(default)] + total_tokens: u32, +} + +#[async_trait] +impl LLMProvider for OpenAIProvider { + async fn chat( + &self, + request: ChatCompletionRequest, + ) -> Result> { + let url = format!("{}/chat/completions", self.base_url); + + let mut body = json!({ + "model": self.model_id, + "messages": request.messages.iter().map(|m| { + json!({ + "role": m.role, + "content": m.content + }) + }).collect::>(), + "temperature": request.temperature.or(self.temperature).unwrap_or(0.7), + "max_tokens": request.max_tokens.or(self.max_tokens), + }); + + // Add model extra fields + for (key, value) in &self.model_extra { + body[key] = value.clone(); + } + + if let Some(tools) = &request.tools { + body["tools"] = json!(tools); + } + + let mut req_builder = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json"); + + for (key, value) in &self.extra_headers { + req_builder = req_builder.header(key.as_str(), value.as_str()); + } + + let resp = req_builder.json(&body).send().await?; + + let openai_resp: OpenAIResponse = resp.json().await?; + + let content = openai_resp.choices[0] + .message + .content + .as_ref() + .unwrap_or(&String::new()) + .clone(); + + let tool_calls: Vec = openai_resp.choices[0] + .message + .tool_calls + .iter() + .map(|tc| ToolCall { + id: tc.id.clone(), + name: tc.function.name.clone(), + arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null), + }) + .collect(); + + Ok(ChatCompletionResponse { + id: openai_resp.id, + model: openai_resp.model, + content, + tool_calls, + usage: Usage { + prompt_tokens: openai_resp.usage.prompt_tokens, + completion_tokens: openai_resp.usage.completion_tokens, + total_tokens: openai_resp.usage.total_tokens, + }, + }) + } + + fn ptype(&self) -> &str { + "openai" + } + + fn name(&self) -> &str { + &self.name + } + + fn model_id(&self) -> &str { + &self.model_id + } +} diff --git a/src/providers/traits.rs b/src/providers/traits.rs new file mode 100644 index 0000000..843aabb --- /dev/null +++ b/src/providers/traits.rs @@ -0,0 +1,67 @@ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, + pub function: ToolFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolFunction { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + pub name: String, + pub arguments: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionRequest { + pub messages: Vec, + pub temperature: Option, + pub max_tokens: Option, + pub tools: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub model: String, + pub content: String, + pub tool_calls: Vec, + pub usage: Usage, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[async_trait] +pub trait LLMProvider: Send + Sync { + async fn chat( + &self, + request: ChatCompletionRequest, + ) -> Result>; + + fn ptype(&self) -> &str; + + fn name(&self) -> &str; + + fn model_id(&self) -> &str; +} diff --git a/tests/test.env.example b/tests/test.env.example new file mode 100644 index 0000000..c132b97 --- /dev/null +++ b/tests/test.env.example @@ -0,0 +1,12 @@ +# Copy this file to test.env and fill in your API keys +# cp tests/test.env.example tests/test.env + +# Anthropic Configuration +ANTHROPIIC_BASE_URL=https://api.anthropic.com/v1 +ANTHROPIIC_API_KEY=your_anthropic_api_key_here +ANTHROPIIC_MODEL_NAME=claude-3-5-sonnet-20241022 + +# OpenAI Configuration +OPENAI_BASE_URL=https://api.openai.com/v1 +OPENAI_API_KEY=your_openai_api_key_here +OPENAI_MODEL_NAME=gpt-4 diff --git a/tests/test_integration.rs b/tests/test_integration.rs new file mode 100644 index 0000000..aeb9e95 --- /dev/null +++ b/tests/test_integration.rs @@ -0,0 +1,94 @@ +use std::collections::HashMap; +use PicoBot::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message}; +use PicoBot::config::{Config, LLMProviderConfig}; + +fn load_config() -> Option { + dotenv::from_filename("tests/test.env").ok()?; + + let openai_base_url = std::env::var("OPENAI_BASE_URL").ok()?; + let openai_api_key = std::env::var("OPENAI_API_KEY").ok()?; + let openai_model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string()); + + if openai_api_key.contains("your_") { + return None; + } + + Some(LLMProviderConfig { + provider_type: "openai".to_string(), + name: "test_openai".to_string(), + base_url: openai_base_url, + api_key: openai_api_key, + extra_headers: HashMap::new(), + model_id: openai_model, + temperature: Some(0.0), + max_tokens: Some(100), + model_extra: HashMap::new(), + }) +} + +fn create_request(content: &str) -> ChatCompletionRequest { + ChatCompletionRequest { + messages: vec![Message { + role: "user".to_string(), + content: content.to_string(), + }], + temperature: Some(0.0), + max_tokens: Some(100), + tools: None, + } +} + +#[tokio::test] +#[ignore] +async fn test_openai_simple_completion() { + let config = load_config() + .expect("Please configure tests/test.env with valid API keys"); + + let provider = create_provider(config).expect("Failed to create provider"); + let response = provider.chat(create_request("Say 'ok'")).await.unwrap(); + + assert!(!response.id.is_empty()); + assert!(!response.content.is_empty()); + assert!(response.usage.total_tokens > 0); + assert!(response.content.to_lowercase().contains("ok")); +} + +#[tokio::test] +#[ignore] +async fn test_openai_conversation() { + let config = load_config() + .expect("Please configure tests/test.env with valid API keys"); + + let provider = create_provider(config).expect("Failed to create provider"); + + let request = ChatCompletionRequest { + messages: vec![ + Message { role: "user".to_string(), content: "My name is Alice".to_string() }, + Message { role: "assistant".to_string(), content: "Hello Alice!".to_string() }, + Message { role: "user".to_string(), content: "What is my name?".to_string() }, + ], + temperature: Some(0.0), + max_tokens: Some(50), + tools: None, + }; + + let response = provider.chat(request).await.unwrap(); + assert!(response.content.to_lowercase().contains("alice")); +} + +#[tokio::test] +#[ignore] +async fn test_config_load() { + // Test that config.json can be loaded and provider config created + let config = Config::load("config.json").expect("Failed to load config.json"); + let provider_config = config.get_provider_config("default").expect("Failed to get provider config"); + + assert_eq!(provider_config.provider_type, "openai"); + assert_eq!(provider_config.name, "aliyun"); + assert_eq!(provider_config.model_id, "qwen-plus"); + + let provider = create_provider(provider_config).expect("Failed to create provider"); + assert_eq!(provider.ptype(), "openai"); + assert_eq!(provider.name(), "aliyun"); + assert_eq!(provider.model_id(), "qwen-plus"); +} diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs new file mode 100644 index 0000000..58f34f4 --- /dev/null +++ b/tests/test_request_format.rs @@ -0,0 +1,65 @@ +use PicoBot::providers::{ChatCompletionRequest, Message}; + +/// Test that message with special characters is properly escaped +#[test] +fn test_message_special_characters() { + let msg = Message { + role: "user".to_string(), + content: "Hello \"world\"\nNew line\tTab".to_string(), + }; + + let json = serde_json::to_string(&msg).unwrap(); + let deserialized: Message = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.content, "Hello \"world\"\nNew line\tTab"); +} + +/// Test that multi-line system prompt is preserved +#[test] +fn test_multiline_system_prompt() { + let messages = vec![ + Message { + role: "system".to_string(), + content: "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate".to_string(), + }, + Message { + role: "user".to_string(), + content: "Hi".to_string(), + }, + ]; + + let json = serde_json::to_string(&messages[0]).unwrap(); + assert!(json.contains("helpful assistant")); + assert!(json.contains("rules")); + assert!(json.contains("1. Be kind")); +} + +/// Test ChatCompletionRequest serialization (without model field) +#[test] +fn test_chat_request_serialization() { + let request = ChatCompletionRequest { + messages: vec![ + Message { + role: "system".to_string(), + content: "You are helpful".to_string(), + }, + Message { + role: "user".to_string(), + content: "Hello".to_string(), + }, + ], + temperature: Some(0.7), + max_tokens: Some(100), + tools: None, + }; + + let json = serde_json::to_string(&request).unwrap(); + + // Verify structure + assert!(json.contains(r#""role":"system""#)); + assert!(json.contains(r#""role":"user""#)); + assert!(json.contains(r#""content":"You are helpful""#)); + assert!(json.contains(r#""content":"Hello""#)); + assert!(json.contains(r#""temperature":0.7"#)); + assert!(json.contains(r#""max_tokens":100"#)); +} diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs new file mode 100644 index 0000000..c96ba60 --- /dev/null +++ b/tests/test_tool_calling.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; +use PicoBot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction}; +use PicoBot::config::LLMProviderConfig; + +fn load_openai_config() -> Option { + dotenv::from_filename("tests/test.env").ok()?; + + let openai_base_url = std::env::var("OPENAI_BASE_URL").ok()?; + let openai_api_key = std::env::var("OPENAI_API_KEY").ok()?; + let openai_model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string()); + + if openai_api_key.contains("your_") { + return None; + } + + Some(LLMProviderConfig { + provider_type: "openai".to_string(), + name: "test_openai".to_string(), + base_url: openai_base_url, + api_key: openai_api_key, + extra_headers: HashMap::new(), + model_id: openai_model, + temperature: Some(0.0), + max_tokens: Some(100), + model_extra: HashMap::new(), + }) +} + +fn make_weather_tool() -> Tool { + Tool { + tool_type: "function".to_string(), + function: ToolFunction { + name: "get_weather".to_string(), + description: "Get current weather for a city".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + } + }, + "required": ["city"] + }), + }, + } +} + +#[tokio::test] +#[ignore] +async fn test_openai_tool_call() { + let config = load_openai_config() + .expect("Please configure tests/test.env with valid API keys"); + + let provider = create_provider(config).expect("Failed to create provider"); + + let request = ChatCompletionRequest { + messages: vec![Message { + role: "user".to_string(), + content: "What is the weather in Tokyo?".to_string(), + }], + temperature: Some(0.0), + max_tokens: Some(200), + tools: Some(vec![make_weather_tool()]), + }; + + let response = provider.chat(request).await.unwrap(); + + // Should have tool calls + assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content); + + let tool_call = &response.tool_calls[0]; + assert_eq!(tool_call.name, "get_weather"); + assert!(tool_call.arguments.get("city").is_some()); +} + +#[tokio::test] +#[ignore] +async fn test_openai_tool_call_with_manual_execution() { + let config = load_openai_config() + .expect("Please configure tests/test.env with valid API keys"); + + let provider = create_provider(config).expect("Failed to create provider"); + + // First request with tool + let request1 = ChatCompletionRequest { + messages: vec![Message { + role: "user".to_string(), + content: "What is the weather in Tokyo?".to_string(), + }], + temperature: Some(0.0), + max_tokens: Some(200), + tools: Some(vec![make_weather_tool()]), + }; + + let response1 = provider.chat(request1).await.unwrap(); + let tool_call = response1.tool_calls.first() + .expect("Expected tool call"); + assert_eq!(tool_call.name, "get_weather"); + + // Second request with tool result + let request2 = ChatCompletionRequest { + messages: vec![ + Message { + role: "user".to_string(), + content: "What is the weather in Tokyo?".to_string(), + }, + Message { + role: "assistant".to_string(), + content: r#"I'll check the weather for you using the get_weather tool."#.to_string(), + }, + ], + temperature: Some(0.0), + max_tokens: Some(200), + tools: Some(vec![make_weather_tool()]), + }; + + let response2 = provider.chat(request2).await.unwrap(); + + // Should have a response + assert!(!response2.content.is_empty() || !response2.tool_calls.is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_openai_no_tool_when_not_provided() { + let config = load_openai_config() + .expect("Please configure tests/test.env with valid API keys"); + + let provider = create_provider(config).expect("Failed to create provider"); + + let request = ChatCompletionRequest { + messages: vec![Message { + role: "user".to_string(), + content: "Say hello in one word.".to_string(), + }], + temperature: Some(0.0), + max_tokens: Some(10), + tools: None, + }; + + let response = provider.chat(request).await.unwrap(); + + // Should NOT have tool calls + assert!(response.tool_calls.is_empty()); + assert!(!response.content.is_empty()); +}