PicoBot/src/providers/openai.rs
xiaoxixi 9834bd75cf feat: add calculator tool and integrate with agent loop
- Introduced a new CalculatorTool for performing various arithmetic and statistical calculations.
- Enhanced the AgentLoop to support tool execution, including handling tool calls in the process flow.
- Updated ChatMessage structure to include optional fields for tool call identification and names.
- Modified the Session and SessionManager to manage tool registrations and pass them to agents.
- Updated the OpenAIProvider to serialize tool-related message fields.
- Added a ToolRegistry for managing multiple tools and their definitions.
- Implemented tests for the CalculatorTool to ensure functionality and correctness.
2026-04-06 23:43:45 +08:00

199 lines
5.1 KiB
Rust

use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use std::collections::HashMap;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage;
/// An `LLMProvider` implementation for OpenAI-compatible chat-completions APIs.
pub struct OpenAIProvider {
    /// HTTP client reused across requests (connection pooling).
    client: Client,
    /// User-assigned name for this provider instance.
    name: String,
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// API root; requests go to `{base_url}/chat/completions`.
    base_url: String,
    /// Additional headers attached to every HTTP request.
    extra_headers: HashMap<String, String>,
    /// Model identifier sent as `"model"` in the request body.
    model_id: String,
    /// Default sampling temperature, used when the request sets none.
    temperature: Option<f32>,
    /// Default completion-token cap, used when the request sets none.
    max_tokens: Option<u32>,
    /// Extra top-level JSON fields merged into every request body.
    model_extra: HashMap<String, serde_json::Value>,
}
impl OpenAIProvider {
pub fn new(
name: String,
api_key: String,
base_url: String,
extra_headers: HashMap<String, String>,
model_id: String,
temperature: Option<f32>,
max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>,
) -> Self {
Self {
client: Client::new(),
name,
api_key,
base_url,
extra_headers,
model_id,
temperature,
max_tokens,
model_extra,
}
}
}
/// Top-level shape of a chat-completions response body.
#[derive(Deserialize)]
struct OpenAIResponse {
    /// Server-assigned completion id.
    id: String,
    /// Model that actually served the request.
    model: String,
    /// Candidate completions; presumably at least one on success.
    choices: Vec<OpenAIChoice>,
    /// Token accounting; zeroed when the server omits the block.
    #[serde(default)]
    usage: OpenAIUsage,
}
/// A single completion candidate; only the message is consumed here.
#[derive(Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
}
/// The assistant message inside a choice.
#[derive(Deserialize)]
struct OpenAIMessage {
    /// Text content; absent when the reply consists only of tool calls.
    #[serde(default)]
    content: Option<String>,
    /// Optional sender name. NOTE(review): deserialized but never read in
    /// this file — confirm a downstream use before removing.
    #[serde(default)]
    name: Option<String>,
    /// Tool invocations requested by the model; empty when none.
    #[serde(default)]
    tool_calls: Vec<OpenAIToolCall>,
}
/// One tool invocation requested by the model.
#[derive(Deserialize)]
struct OpenAIToolCall {
    /// Id the agent must echo back via `tool_call_id` in the "tool" reply.
    id: String,
    // NOTE(review): this rename is a no-op — the field is already named
    // `function` — but it is harmless.
    #[serde(rename = "function")]
    function: OAIFunction,
    /// Position within the tool_calls array. NOTE(review): not read in
    /// this file — possibly only meaningful for streaming deltas; verify.
    #[serde(default)]
    index: Option<u32>,
}
/// Function name plus raw JSON-encoded arguments of a tool call.
#[derive(Deserialize)]
struct OAIFunction {
    name: String,
    /// Arguments arrive as a JSON *string*; decoded later in `chat`.
    arguments: String,
}
/// Token-usage block; every field falls back to 0 when the server omits it.
#[derive(Deserialize, Default)]
struct OpenAIUsage {
    #[serde(default)]
    prompt_tokens: u32,
    #[serde(default)]
    completion_tokens: u32,
    #[serde(default)]
    total_tokens: u32,
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
    /// Send a chat-completion request to the OpenAI-compatible endpoint.
    ///
    /// Builds the JSON body from the request plus provider defaults
    /// (`temperature`, `max_tokens`, `model_extra`), posts it with the
    /// configured auth and extra headers, and decodes the first choice.
    ///
    /// # Errors
    /// Returns an error on transport failure, a non-2xx status (the raw
    /// response body is included in the message), an undecodable response,
    /// or a response containing no choices.
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
        let url = format!("{}/chat/completions", self.base_url);
        // Tool-result messages carry `tool_call_id`/`name`; plain messages
        // must not include those keys at all.
        let messages: Vec<_> = request
            .messages
            .iter()
            .map(|m| {
                if m.role == "tool" {
                    json!({
                        "role": m.role,
                        "content": m.content,
                        "tool_call_id": m.tool_call_id,
                        "name": m.name,
                    })
                } else {
                    json!({
                        "role": m.role,
                        "content": m.content
                    })
                }
            })
            .collect();
        let mut body = json!({
            "model": self.model_id,
            "messages": messages,
            "temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
        });
        // Only send max_tokens when one is configured: serializing `None`
        // would emit an explicit JSON null, which some OpenAI-compatible
        // servers reject.
        if let Some(max_tokens) = request.max_tokens.or(self.max_tokens) {
            body["max_tokens"] = json!(max_tokens);
        }
        // Per-model extra fields extend/override the base body.
        for (key, value) in &self.model_extra {
            body[key] = value.clone();
        }
        if let Some(tools) = &request.tools {
            body["tools"] = json!(tools);
        }
        let mut req_builder = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json");
        for (key, value) in &self.extra_headers {
            req_builder = req_builder.header(key.as_str(), value.as_str());
        }
        let resp = req_builder.json(&body).send().await?;
        let status = resp.status();
        let text = resp.text().await?;
        if !status.is_success() {
            return Err(format!("API error {}: {}", status, text).into());
        }
        let openai_resp: OpenAIResponse = serde_json::from_str(&text)
            .map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
        // Guard against an empty `choices` array instead of panicking on [0].
        let choice = openai_resp
            .choices
            .first()
            .ok_or_else(|| format!("no choices in response | body: {}", text))?;
        // Absent content (e.g. a pure tool-call turn) becomes "".
        let content = choice.message.content.clone().unwrap_or_default();
        let tool_calls: Vec<ToolCall> = choice
            .message
            .tool_calls
            .iter()
            .map(|tc| ToolCall {
                id: tc.id.clone(),
                name: tc.function.name.clone(),
                // Best effort: malformed argument JSON degrades to Null
                // rather than failing the whole completion.
                arguments: serde_json::from_str(&tc.function.arguments)
                    .unwrap_or(serde_json::Value::Null),
            })
            .collect();
        Ok(ChatCompletionResponse {
            id: openai_resp.id,
            model: openai_resp.model,
            content,
            tool_calls,
            usage: Usage {
                prompt_tokens: openai_resp.usage.prompt_tokens,
                completion_tokens: openai_resp.usage.completion_tokens,
                total_tokens: openai_resp.usage.total_tokens,
            },
        })
    }

    /// Provider type discriminator ("openai" wire format).
    fn ptype(&self) -> &str {
        "openai"
    }

    /// The user-assigned provider instance name.
    fn name(&self) -> &str {
        &self.name
    }

    /// The model identifier sent in requests.
    fn model_id(&self) -> &str {
        &self.model_id
    }
}