实现基本的模型调用
This commit is contained in:
parent
5dc13ea7ce
commit
8b1e6e7e06
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/target
|
||||||
13
Cargo.toml
Normal file
13
Cargo.toml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
[package]
|
||||||
|
name = "PicoBot"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||||
|
dotenv = "0.15"
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
async-trait = "0.1"
|
||||||
|
thiserror = "1.0"
|
||||||
|
tokio = { version = "1.0", features = ["full"] }
|
||||||
132
src/config/mod.rs
Normal file
132
src/config/mod.rs
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fs;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub providers: HashMap<String, ProviderConfig>,
|
||||||
|
pub models: HashMap<String, ModelConfig>,
|
||||||
|
pub agents: HashMap<String, AgentConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub provider_type: String,
|
||||||
|
pub base_url: String,
|
||||||
|
pub api_key: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub extra_headers: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ModelConfig {
|
||||||
|
pub model_id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: HashMap<String, serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct AgentConfig {
|
||||||
|
pub provider: String,
|
||||||
|
pub model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LLMProviderConfig {
|
||||||
|
pub provider_type: String,
|
||||||
|
pub name: String,
|
||||||
|
pub base_url: String,
|
||||||
|
pub api_key: String,
|
||||||
|
pub extra_headers: HashMap<String, String>,
|
||||||
|
pub model_id: String,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub model_extra: HashMap<String, serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn load(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
|
let content = fs::read_to_string(path)?;
|
||||||
|
let config: Config = serde_json::from_str(&content)?;
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
|
||||||
|
let agent = self.agents.get(agent_name)
|
||||||
|
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
|
||||||
|
|
||||||
|
let provider = self.providers.get(&agent.provider)
|
||||||
|
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
|
||||||
|
|
||||||
|
let model = self.models.get(&agent.model)
|
||||||
|
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
|
||||||
|
|
||||||
|
Ok(LLMProviderConfig {
|
||||||
|
provider_type: provider.provider_type.clone(),
|
||||||
|
name: agent.provider.clone(),
|
||||||
|
base_url: provider.base_url.clone(),
|
||||||
|
api_key: provider.api_key.clone(),
|
||||||
|
extra_headers: provider.extra_headers.clone(),
|
||||||
|
model_id: model.model_id.clone(),
|
||||||
|
temperature: model.temperature,
|
||||||
|
max_tokens: model.max_tokens,
|
||||||
|
model_extra: model.extra.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ConfigError {
|
||||||
|
AgentNotFound(String),
|
||||||
|
ProviderNotFound(String),
|
||||||
|
ModelNotFound(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ConfigError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||||
|
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||||
|
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for ConfigError {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_load() {
|
||||||
|
let config = Config::load("config.json").unwrap();
|
||||||
|
|
||||||
|
// Check providers
|
||||||
|
assert!(config.providers.contains_key("volcengine"));
|
||||||
|
assert!(config.providers.contains_key("aliyun"));
|
||||||
|
|
||||||
|
// Check models
|
||||||
|
assert!(config.models.contains_key("doubao-seed-2-0-lite-260215"));
|
||||||
|
assert!(config.models.contains_key("qwen-plus"));
|
||||||
|
|
||||||
|
// Check agents
|
||||||
|
assert!(config.agents.contains_key("default"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_provider_config() {
|
||||||
|
let config = Config::load("config.json").unwrap();
|
||||||
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(provider_config.provider_type, "openai");
|
||||||
|
assert_eq!(provider_config.name, "aliyun");
|
||||||
|
assert_eq!(provider_config.model_id, "qwen-plus");
|
||||||
|
assert_eq!(provider_config.temperature, Some(0.0));
|
||||||
|
}
|
||||||
|
}
|
||||||
2
src/lib.rs
Normal file
2
src/lib.rs
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
pub mod config;
|
||||||
|
pub mod providers;
|
||||||
38
src/main.rs
Normal file
38
src/main.rs
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
mod config;
|
||||||
|
mod providers;
|
||||||
|
|
||||||
|
use config::Config;
|
||||||
|
use providers::{create_provider, ChatCompletionRequest, Message};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
// Load config
|
||||||
|
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||||
|
|
||||||
|
// Get provider config for "default" agent
|
||||||
|
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
||||||
|
|
||||||
|
// Create provider
|
||||||
|
let provider = create_provider(provider_config).expect("Failed to create provider");
|
||||||
|
|
||||||
|
println!("Provider type: {}", provider.ptype());
|
||||||
|
println!("Provider name: {}", provider.name());
|
||||||
|
println!("Model ID: {}", provider.model_id());
|
||||||
|
|
||||||
|
// Create request (no model ID needed - it's baked into the provider)
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hello!".to_string(),
|
||||||
|
}],
|
||||||
|
temperature: None, // Will use config default if not provided
|
||||||
|
max_tokens: None, // Will use config default if not provided
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Example usage:
|
||||||
|
// match provider.chat(request).await {
|
||||||
|
// Ok(resp) => println!("Response: {}", resp.content),
|
||||||
|
// Err(e) => eprintln!("Error: {}", e),
|
||||||
|
// }
|
||||||
|
}
|
||||||
198
src/providers/anthropic.rs
Normal file
198
src/providers/anthropic.rs
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
|
use super::traits::Usage;
|
||||||
|
|
||||||
|
pub struct AnthropicProvider {
|
||||||
|
client: Client,
|
||||||
|
name: String,
|
||||||
|
api_key: String,
|
||||||
|
base_url: String,
|
||||||
|
extra_headers: HashMap<String, String>,
|
||||||
|
model_id: String,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AnthropicProvider {
|
||||||
|
pub fn new(
|
||||||
|
name: String,
|
||||||
|
api_key: String,
|
||||||
|
base_url: String,
|
||||||
|
extra_headers: HashMap<String, String>,
|
||||||
|
model_id: String,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
client: Client::new(),
|
||||||
|
name,
|
||||||
|
api_key,
|
||||||
|
base_url,
|
||||||
|
extra_headers,
|
||||||
|
model_id,
|
||||||
|
temperature,
|
||||||
|
max_tokens,
|
||||||
|
model_extra,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct AnthropicRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<AnthropicMessage>,
|
||||||
|
max_tokens: u32,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<AnthropicTool>>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: HashMap<String, serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct AnthropicMessage {
|
||||||
|
role: String,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct AnthropicTool {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
input_schema: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct AnthropicResponse {
|
||||||
|
id: String,
|
||||||
|
model: String,
|
||||||
|
content: Vec<AnthropicContent>,
|
||||||
|
usage: AnthropicUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
enum AnthropicContent {
|
||||||
|
Text { text: String },
|
||||||
|
Thinking { thinking: String },
|
||||||
|
#[serde(rename = "tool_use")]
|
||||||
|
ToolUse {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
input: serde_json::Value,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct AnthropicUsage {
|
||||||
|
input_tokens: u32,
|
||||||
|
output_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LLMProvider for AnthropicProvider {
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let url = format!("{}/v1/messages", self.base_url);
|
||||||
|
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
|
||||||
|
|
||||||
|
let tools = request.tools.map(|tools| {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|t: &Tool| AnthropicTool {
|
||||||
|
name: t.function.name.clone(),
|
||||||
|
description: t.function.description.clone(),
|
||||||
|
input_schema: t.function.parameters.clone(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
|
|
||||||
|
let body = AnthropicRequest {
|
||||||
|
model: self.model_id.clone(),
|
||||||
|
messages: request
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| AnthropicMessage {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: m.content.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
max_tokens,
|
||||||
|
temperature: request.temperature.or(self.temperature),
|
||||||
|
tools,
|
||||||
|
extra: self.model_extra.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut req_builder = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.header("x-api-key", &self.api_key)
|
||||||
|
.header("anthropic-version", "2023-06-01")
|
||||||
|
.header("Content-Type", "application/json");
|
||||||
|
|
||||||
|
for (key, value) in &self.extra_headers {
|
||||||
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = req_builder.json(&body).send().await?;
|
||||||
|
|
||||||
|
let anthropic_resp: AnthropicResponse = resp.json().await?;
|
||||||
|
|
||||||
|
let mut content = String::new();
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
|
||||||
|
for c in &anthropic_resp.content {
|
||||||
|
match c {
|
||||||
|
AnthropicContent::Text { text } => {
|
||||||
|
if !text.is_empty() {
|
||||||
|
if !content.is_empty() {
|
||||||
|
content.push('\n');
|
||||||
|
}
|
||||||
|
content.push_str(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AnthropicContent::Thinking { .. } => {}
|
||||||
|
AnthropicContent::ToolUse { id, name, input } => {
|
||||||
|
tool_calls.push(ToolCall {
|
||||||
|
id: id.clone(),
|
||||||
|
name: name.clone(),
|
||||||
|
arguments: input.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ChatCompletionResponse {
|
||||||
|
id: anthropic_resp.id,
|
||||||
|
model: anthropic_resp.model,
|
||||||
|
content,
|
||||||
|
tool_calls,
|
||||||
|
usage: Usage {
|
||||||
|
prompt_tokens: anthropic_resp.usage.input_tokens,
|
||||||
|
completion_tokens: anthropic_resp.usage.output_tokens,
|
||||||
|
total_tokens: anthropic_resp.usage.input_tokens
|
||||||
|
+ anthropic_resp.usage.output_tokens,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptype(&self) -> &str {
|
||||||
|
"anthropic"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model_id(&self) -> &str {
|
||||||
|
&self.model_id
|
||||||
|
}
|
||||||
|
}
|
||||||
50
src/providers/mod.rs
Normal file
50
src/providers/mod.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
pub mod traits;
|
||||||
|
pub mod openai;
|
||||||
|
pub mod anthropic;
|
||||||
|
|
||||||
|
pub use self::openai::OpenAIProvider;
|
||||||
|
pub use self::anthropic::AnthropicProvider;
|
||||||
|
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
|
||||||
|
|
||||||
|
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
||||||
|
match config.provider_type.as_str() {
|
||||||
|
"openai" => Ok(Box::new(OpenAIProvider::new(
|
||||||
|
config.name,
|
||||||
|
config.api_key,
|
||||||
|
config.base_url,
|
||||||
|
config.extra_headers,
|
||||||
|
config.model_id,
|
||||||
|
config.temperature,
|
||||||
|
config.max_tokens,
|
||||||
|
config.model_extra,
|
||||||
|
))),
|
||||||
|
"anthropic" => Ok(Box::new(AnthropicProvider::new(
|
||||||
|
config.name,
|
||||||
|
config.api_key,
|
||||||
|
config.base_url,
|
||||||
|
config.extra_headers,
|
||||||
|
config.model_id,
|
||||||
|
config.temperature,
|
||||||
|
config.max_tokens,
|
||||||
|
config.model_extra,
|
||||||
|
))),
|
||||||
|
_ => Err(ProviderError::UnknownProviderType(config.provider_type)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ProviderError {
|
||||||
|
UnknownProviderType(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ProviderError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
ProviderError::UnknownProviderType(t) => write!(f, "Unknown provider type: {}", t),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for ProviderError {}
|
||||||
181
src/providers/openai.rs
Normal file
181
src/providers/openai.rs
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
|
use super::traits::Usage;
|
||||||
|
|
||||||
|
pub struct OpenAIProvider {
|
||||||
|
client: Client,
|
||||||
|
name: String,
|
||||||
|
api_key: String,
|
||||||
|
base_url: String,
|
||||||
|
extra_headers: HashMap<String, String>,
|
||||||
|
model_id: String,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAIProvider {
|
||||||
|
pub fn new(
|
||||||
|
name: String,
|
||||||
|
api_key: String,
|
||||||
|
base_url: String,
|
||||||
|
extra_headers: HashMap<String, String>,
|
||||||
|
model_id: String,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
client: Client::new(),
|
||||||
|
name,
|
||||||
|
api_key,
|
||||||
|
base_url,
|
||||||
|
extra_headers,
|
||||||
|
model_id,
|
||||||
|
temperature,
|
||||||
|
max_tokens,
|
||||||
|
model_extra,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIResponse {
|
||||||
|
id: String,
|
||||||
|
model: String,
|
||||||
|
choices: Vec<OpenAIChoice>,
|
||||||
|
#[serde(default)]
|
||||||
|
usage: OpenAIUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIChoice {
|
||||||
|
message: OpenAIMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIMessage {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
name: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Vec<OpenAIToolCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIToolCall {
|
||||||
|
id: String,
|
||||||
|
#[serde(rename = "function")]
|
||||||
|
function: OAIFunction,
|
||||||
|
#[serde(default)]
|
||||||
|
index: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OAIFunction {
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Default)]
|
||||||
|
struct OpenAIUsage {
|
||||||
|
#[serde(default)]
|
||||||
|
prompt_tokens: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
completion_tokens: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LLMProvider for OpenAIProvider {
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
|
let mut body = json!({
|
||||||
|
"model": self.model_id,
|
||||||
|
"messages": request.messages.iter().map(|m| {
|
||||||
|
json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": m.content
|
||||||
|
})
|
||||||
|
}).collect::<Vec<_>>(),
|
||||||
|
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
||||||
|
"max_tokens": request.max_tokens.or(self.max_tokens),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add model extra fields
|
||||||
|
for (key, value) in &self.model_extra {
|
||||||
|
body[key] = value.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(tools) = &request.tools {
|
||||||
|
body["tools"] = json!(tools);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut req_builder = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.header("Content-Type", "application/json");
|
||||||
|
|
||||||
|
for (key, value) in &self.extra_headers {
|
||||||
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = req_builder.json(&body).send().await?;
|
||||||
|
|
||||||
|
let openai_resp: OpenAIResponse = resp.json().await?;
|
||||||
|
|
||||||
|
let content = openai_resp.choices[0]
|
||||||
|
.message
|
||||||
|
.content
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&String::new())
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
|
||||||
|
.message
|
||||||
|
.tool_calls
|
||||||
|
.iter()
|
||||||
|
.map(|tc| ToolCall {
|
||||||
|
id: tc.id.clone(),
|
||||||
|
name: tc.function.name.clone(),
|
||||||
|
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(ChatCompletionResponse {
|
||||||
|
id: openai_resp.id,
|
||||||
|
model: openai_resp.model,
|
||||||
|
content,
|
||||||
|
tool_calls,
|
||||||
|
usage: Usage {
|
||||||
|
prompt_tokens: openai_resp.usage.prompt_tokens,
|
||||||
|
completion_tokens: openai_resp.usage.completion_tokens,
|
||||||
|
total_tokens: openai_resp.usage.total_tokens,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptype(&self) -> &str {
|
||||||
|
"openai"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model_id(&self) -> &str {
|
||||||
|
&self.model_id
|
||||||
|
}
|
||||||
|
}
|
||||||
67
src/providers/traits.rs
Normal file
67
src/providers/traits.rs
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Message {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Tool {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: String,
|
||||||
|
pub function: ToolFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolFunction {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub parameters: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionRequest {
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub model: String,
|
||||||
|
pub content: String,
|
||||||
|
pub tool_calls: Vec<ToolCall>,
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait LLMProvider: Send + Sync {
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
|
||||||
|
|
||||||
|
fn ptype(&self) -> &str;
|
||||||
|
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
fn model_id(&self) -> &str;
|
||||||
|
}
|
||||||
12
tests/test.env.example
Normal file
12
tests/test.env.example
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# Copy this file to test.env and fill in your API keys
|
||||||
|
# cp tests/test.env.example tests/test.env
|
||||||
|
|
||||||
|
# Anthropic Configuration
|
||||||
|
ANTHROPIIC_BASE_URL=https://api.anthropic.com/v1
|
||||||
|
ANTHROPIIC_API_KEY=your_anthropic_api_key_here
|
||||||
|
ANTHROPIIC_MODEL_NAME=claude-3-5-sonnet-20241022
|
||||||
|
|
||||||
|
# OpenAI Configuration
|
||||||
|
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||||
|
OPENAI_API_KEY=your_openai_api_key_here
|
||||||
|
OPENAI_MODEL_NAME=gpt-4
|
||||||
94
tests/test_integration.rs
Normal file
94
tests/test_integration.rs
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use PicoBot::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
|
||||||
|
use PicoBot::config::{Config, LLMProviderConfig};
|
||||||
|
|
||||||
|
fn load_config() -> Option<LLMProviderConfig> {
|
||||||
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
|
|
||||||
|
let openai_base_url = std::env::var("OPENAI_BASE_URL").ok()?;
|
||||||
|
let openai_api_key = std::env::var("OPENAI_API_KEY").ok()?;
|
||||||
|
let openai_model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string());
|
||||||
|
|
||||||
|
if openai_api_key.contains("your_") {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: "test_openai".to_string(),
|
||||||
|
base_url: openai_base_url,
|
||||||
|
api_key: openai_api_key,
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
model_id: openai_model,
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(100),
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_request(content: &str) -> ChatCompletionRequest {
|
||||||
|
ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: content.to_string(),
|
||||||
|
}],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(100),
|
||||||
|
tools: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn test_openai_simple_completion() {
|
||||||
|
let config = load_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||||
|
|
||||||
|
assert!(!response.id.is_empty());
|
||||||
|
assert!(!response.content.is_empty());
|
||||||
|
assert!(response.usage.total_tokens > 0);
|
||||||
|
assert!(response.content.to_lowercase().contains("ok"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn test_openai_conversation() {
|
||||||
|
let config = load_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message { role: "user".to_string(), content: "My name is Alice".to_string() },
|
||||||
|
Message { role: "assistant".to_string(), content: "Hello Alice!".to_string() },
|
||||||
|
Message { role: "user".to_string(), content: "What is my name?".to_string() },
|
||||||
|
],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(50),
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.chat(request).await.unwrap();
|
||||||
|
assert!(response.content.to_lowercase().contains("alice"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn test_config_load() {
|
||||||
|
// Test that config.json can be loaded and provider config created
|
||||||
|
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||||
|
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
||||||
|
|
||||||
|
assert_eq!(provider_config.provider_type, "openai");
|
||||||
|
assert_eq!(provider_config.name, "aliyun");
|
||||||
|
assert_eq!(provider_config.model_id, "qwen-plus");
|
||||||
|
|
||||||
|
let provider = create_provider(provider_config).expect("Failed to create provider");
|
||||||
|
assert_eq!(provider.ptype(), "openai");
|
||||||
|
assert_eq!(provider.name(), "aliyun");
|
||||||
|
assert_eq!(provider.model_id(), "qwen-plus");
|
||||||
|
}
|
||||||
65
tests/test_request_format.rs
Normal file
65
tests/test_request_format.rs
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
use PicoBot::providers::{ChatCompletionRequest, Message};
|
||||||
|
|
||||||
|
/// Test that message with special characters is properly escaped
|
||||||
|
#[test]
|
||||||
|
fn test_message_special_characters() {
|
||||||
|
let msg = Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hello \"world\"\nNew line\tTab".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
let deserialized: Message = serde_json::from_str(&json).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(deserialized.content, "Hello \"world\"\nNew line\tTab");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test that multi-line system prompt is preserved
|
||||||
|
#[test]
|
||||||
|
fn test_multiline_system_prompt() {
|
||||||
|
let messages = vec![
|
||||||
|
Message {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi".to_string(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&messages[0]).unwrap();
|
||||||
|
assert!(json.contains("helpful assistant"));
|
||||||
|
assert!(json.contains("rules"));
|
||||||
|
assert!(json.contains("1. Be kind"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test ChatCompletionRequest serialization (without model field)
|
||||||
|
#[test]
|
||||||
|
fn test_chat_request_serialization() {
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: "You are helpful".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hello".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: Some(0.7),
|
||||||
|
max_tokens: Some(100),
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&request).unwrap();
|
||||||
|
|
||||||
|
// Verify structure
|
||||||
|
assert!(json.contains(r#""role":"system""#));
|
||||||
|
assert!(json.contains(r#""role":"user""#));
|
||||||
|
assert!(json.contains(r#""content":"You are helpful""#));
|
||||||
|
assert!(json.contains(r#""content":"Hello""#));
|
||||||
|
assert!(json.contains(r#""temperature":0.7"#));
|
||||||
|
assert!(json.contains(r#""max_tokens":100"#));
|
||||||
|
}
|
||||||
147
tests/test_tool_calling.rs
Normal file
147
tests/test_tool_calling.rs
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use PicoBot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
||||||
|
use PicoBot::config::LLMProviderConfig;
|
||||||
|
|
||||||
|
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||||
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
|
|
||||||
|
let openai_base_url = std::env::var("OPENAI_BASE_URL").ok()?;
|
||||||
|
let openai_api_key = std::env::var("OPENAI_API_KEY").ok()?;
|
||||||
|
let openai_model = std::env::var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-4".to_string());
|
||||||
|
|
||||||
|
if openai_api_key.contains("your_") {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: "test_openai".to_string(),
|
||||||
|
base_url: openai_base_url,
|
||||||
|
api_key: openai_api_key,
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
model_id: openai_model,
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(100),
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_weather_tool() -> Tool {
|
||||||
|
Tool {
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: ToolFunction {
|
||||||
|
name: "get_weather".to_string(),
|
||||||
|
description: "Get current weather for a city".to_string(),
|
||||||
|
parameters: serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city name"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn test_openai_tool_call() {
|
||||||
|
let config = load_openai_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is the weather in Tokyo?".to_string(),
|
||||||
|
}],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(200),
|
||||||
|
tools: Some(vec![make_weather_tool()]),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.chat(request).await.unwrap();
|
||||||
|
|
||||||
|
// Should have tool calls
|
||||||
|
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
||||||
|
|
||||||
|
let tool_call = &response.tool_calls[0];
|
||||||
|
assert_eq!(tool_call.name, "get_weather");
|
||||||
|
assert!(tool_call.arguments.get("city").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn test_openai_tool_call_with_manual_execution() {
|
||||||
|
let config = load_openai_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
|
// First request with tool
|
||||||
|
let request1 = ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is the weather in Tokyo?".to_string(),
|
||||||
|
}],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(200),
|
||||||
|
tools: Some(vec![make_weather_tool()]),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response1 = provider.chat(request1).await.unwrap();
|
||||||
|
let tool_call = response1.tool_calls.first()
|
||||||
|
.expect("Expected tool call");
|
||||||
|
assert_eq!(tool_call.name, "get_weather");
|
||||||
|
|
||||||
|
// Second request with tool result
|
||||||
|
let request2 = ChatCompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is the weather in Tokyo?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: r#"I'll check the weather for you using the get_weather tool."#.to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(200),
|
||||||
|
tools: Some(vec![make_weather_tool()]),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response2 = provider.chat(request2).await.unwrap();
|
||||||
|
|
||||||
|
// Should have a response
|
||||||
|
assert!(!response2.content.is_empty() || !response2.tool_calls.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn test_openai_no_tool_when_not_provided() {
|
||||||
|
let config = load_openai_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Say hello in one word.".to_string(),
|
||||||
|
}],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(10),
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.chat(request).await.unwrap();
|
||||||
|
|
||||||
|
// Should NOT have tool calls
|
||||||
|
assert!(response.tool_calls.is_empty());
|
||||||
|
assert!(!response.content.is_empty());
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user