- Introduced a new CalculatorTool for performing various arithmetic and statistical calculations.
- Enhanced the AgentLoop to support tool execution, including handling tool calls in the process flow.
- Updated the ChatMessage structure to include optional fields for tool call identification and names.
- Modified the Session and SessionManager to manage tool registrations and pass them to agents.
- Updated the OpenAIProvider to serialize tool-related message fields.
- Added a ToolRegistry for managing multiple tools and their definitions.
- Implemented tests for the CalculatorTool to ensure functionality and correctness.
199 lines
5.1 KiB
Rust
199 lines
5.1 KiB
Rust
use async_trait::async_trait;
|
|
use reqwest::Client;
|
|
use serde::Deserialize;
|
|
use serde_json::json;
|
|
use std::collections::HashMap;
|
|
|
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
|
use super::traits::Usage;
|
|
|
|
pub struct OpenAIProvider {
|
|
client: Client,
|
|
name: String,
|
|
api_key: String,
|
|
base_url: String,
|
|
extra_headers: HashMap<String, String>,
|
|
model_id: String,
|
|
temperature: Option<f32>,
|
|
max_tokens: Option<u32>,
|
|
model_extra: HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
impl OpenAIProvider {
|
|
pub fn new(
|
|
name: String,
|
|
api_key: String,
|
|
base_url: String,
|
|
extra_headers: HashMap<String, String>,
|
|
model_id: String,
|
|
temperature: Option<f32>,
|
|
max_tokens: Option<u32>,
|
|
model_extra: HashMap<String, serde_json::Value>,
|
|
) -> Self {
|
|
Self {
|
|
client: Client::new(),
|
|
name,
|
|
api_key,
|
|
base_url,
|
|
extra_headers,
|
|
model_id,
|
|
temperature,
|
|
max_tokens,
|
|
model_extra,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIResponse {
|
|
id: String,
|
|
model: String,
|
|
choices: Vec<OpenAIChoice>,
|
|
#[serde(default)]
|
|
usage: OpenAIUsage,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIChoice {
|
|
message: OpenAIMessage,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIMessage {
|
|
#[serde(default)]
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
name: Option<String>,
|
|
#[serde(default)]
|
|
tool_calls: Vec<OpenAIToolCall>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAIToolCall {
|
|
id: String,
|
|
#[serde(rename = "function")]
|
|
function: OAIFunction,
|
|
#[serde(default)]
|
|
index: Option<u32>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OAIFunction {
|
|
name: String,
|
|
arguments: String,
|
|
}
|
|
|
|
#[derive(Deserialize, Default)]
|
|
struct OpenAIUsage {
|
|
#[serde(default)]
|
|
prompt_tokens: u32,
|
|
#[serde(default)]
|
|
completion_tokens: u32,
|
|
#[serde(default)]
|
|
total_tokens: u32,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LLMProvider for OpenAIProvider {
|
|
async fn chat(
|
|
&self,
|
|
request: ChatCompletionRequest,
|
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
|
let url = format!("{}/chat/completions", self.base_url);
|
|
|
|
let mut body = json!({
|
|
"model": self.model_id,
|
|
"messages": request.messages.iter().map(|m| {
|
|
if m.role == "tool" {
|
|
json!({
|
|
"role": m.role,
|
|
"content": m.content,
|
|
"tool_call_id": m.tool_call_id,
|
|
"name": m.name,
|
|
})
|
|
} else {
|
|
json!({
|
|
"role": m.role,
|
|
"content": m.content
|
|
})
|
|
}
|
|
}).collect::<Vec<_>>(),
|
|
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
|
"max_tokens": request.max_tokens.or(self.max_tokens),
|
|
});
|
|
|
|
// Add model extra fields
|
|
for (key, value) in &self.model_extra {
|
|
body[key] = value.clone();
|
|
}
|
|
|
|
if let Some(tools) = &request.tools {
|
|
body["tools"] = json!(tools);
|
|
}
|
|
|
|
let mut req_builder = self
|
|
.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("Content-Type", "application/json");
|
|
|
|
for (key, value) in &self.extra_headers {
|
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
|
}
|
|
|
|
let resp = req_builder.json(&body).send().await?;
|
|
|
|
let status = resp.status();
|
|
let text = resp.text().await?;
|
|
|
|
if !status.is_success() {
|
|
return Err(format!("API error {}: {}", status, text).into());
|
|
}
|
|
|
|
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
|
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
|
|
|
|
let content = openai_resp.choices[0]
|
|
.message
|
|
.content
|
|
.as_ref()
|
|
.unwrap_or(&String::new())
|
|
.clone();
|
|
|
|
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
|
|
.message
|
|
.tool_calls
|
|
.iter()
|
|
.map(|tc| ToolCall {
|
|
id: tc.id.clone(),
|
|
name: tc.function.name.clone(),
|
|
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
|
})
|
|
.collect();
|
|
|
|
Ok(ChatCompletionResponse {
|
|
id: openai_resp.id,
|
|
model: openai_resp.model,
|
|
content,
|
|
tool_calls,
|
|
usage: Usage {
|
|
prompt_tokens: openai_resp.usage.prompt_tokens,
|
|
completion_tokens: openai_resp.usage.completion_tokens,
|
|
total_tokens: openai_resp.usage.total_tokens,
|
|
},
|
|
})
|
|
}
|
|
|
|
fn ptype(&self) -> &str {
|
|
"openai"
|
|
}
|
|
|
|
fn name(&self) -> &str {
|
|
&self.name
|
|
}
|
|
|
|
fn model_id(&self) -> &str {
|
|
&self.model_id
|
|
}
|
|
}
|