PicoBot/src/providers/traits.rs

127 lines
3.2 KiB
Rust

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::bus::message::ContentBlock;
/// A single chat message, as carried in `ChatCompletionRequest::messages`.
///
/// The optional fields marked with `skip_serializing_if` are omitted from
/// serialized output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Speaker role. The constructors in this file use `"user"`, `"assistant"`,
/// `"system"` and `"tool"`.
pub role: String,
/// Message body as a list of content blocks.
pub content: Vec<ContentBlock>,
/// Identifier of the tool call this message responds to; only set by
/// `Message::tool`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Tool name for tool-result messages; only set by `Message::tool`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Tool calls attached to this message. No constructor in this file sets it.
/// NOTE(review): presumably filled on assistant messages that request tool
/// calls — confirm against callers.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
impl Message {
    /// Shared constructor used by every public helper below.
    ///
    /// Builds a message with the given `role` and `content`, leaving all
    /// tool-related fields (`tool_call_id`, `name`, `tool_calls`) as `None`.
    fn from_role_and_blocks(role: &str, content: Vec<ContentBlock>) -> Self {
        Self {
            role: role.to_owned(),
            content,
            tool_call_id: None,
            name: None,
            tool_calls: None,
        }
    }

    /// Creates a `"user"` message holding a single text block.
    pub fn user(content: impl Into<String>) -> Self {
        Self::from_role_and_blocks("user", vec![ContentBlock::text(content)])
    }

    /// Creates a `"user"` message from caller-supplied content blocks.
    pub fn user_with_blocks(content: Vec<ContentBlock>) -> Self {
        Self::from_role_and_blocks("user", content)
    }

    /// Creates an `"assistant"` message holding a single text block.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::from_role_and_blocks("assistant", vec![ContentBlock::text(content)])
    }

    /// Creates a `"system"` message holding a single text block.
    pub fn system(content: impl Into<String>) -> Self {
        Self::from_role_and_blocks("system", vec![ContentBlock::text(content)])
    }

    /// Creates a `"tool"` message holding a single text block.
    ///
    /// `tool_call_id` is stored in the `tool_call_id` field. Presumably it is
    /// the ID of the call being answered. `tool_name` is stored in `name`.
    pub fn tool(
        tool_call_id: impl Into<String>,
        tool_name: impl Into<String>,
        content: impl Into<String>,
    ) -> Self {
        // Build the body first. This keeps the same conversion order as
        // constructing every field inline.
        let body = vec![ContentBlock::text(content)];
        Self {
            tool_call_id: Some(tool_call_id.into()),
            name: Some(tool_name.into()),
            ..Self::from_role_and_blocks("tool", body)
        }
    }
}
/// A tool definition offered to the model via `ChatCompletionRequest::tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// Tool kind, serialized under the key `"type"`.
/// NOTE(review): presumably always `"function"` given the `function` field —
/// confirm with the code that builds these.
#[serde(rename = "type")]
pub tool_type: String,
/// Name, description and parameter specification of the callable function.
pub function: ToolFunction,
}
/// Describes a function exposed to the model as part of a `Tool`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
/// Function name.
pub name: String,
/// Free-text description of the function.
pub description: String,
/// Parameter specification held as an arbitrary JSON value.
/// NOTE(review): presumably a JSON Schema object — confirm against the
/// providers that consume it.
pub parameters: serde_json::Value,
}
/// A single tool invocation: call ID, tool name and arguments.
///
/// Used both in `ChatCompletionResponse::tool_calls` and in
/// `Message::tool_calls`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Call identifier. NOTE(review): presumably the value later passed as
/// `tool_call_id` to `Message::tool` — confirm.
pub id: String,
/// Name of the tool being invoked.
pub name: String,
/// Call arguments held as a `serde_json::Value`.
pub arguments: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub messages: Vec<Message>,
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub tools: Option<Vec<Tool>>,
}
/// Provider-agnostic result returned by `LLMProvider::chat`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
/// Response identifier (presumably as reported by the provider).
pub id: String,
/// Model that produced the response.
pub model: String,
/// Text content of the reply.
/// NOTE(review): may be empty when the reply consists only of tool calls —
/// confirm per provider.
pub content: String,
/// Tool calls requested in this reply; an empty `Vec` when there are none.
pub tool_calls: Vec<ToolCall>,
/// Token accounting for this completion.
pub usage: Usage,
}
/// Token counts reported for a single completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
/// Tokens attributed to the prompt.
pub prompt_tokens: u32,
/// Tokens attributed to the generated completion.
pub completion_tokens: u32,
/// Total token count. Stored as given; nothing here enforces that it equals
/// `prompt_tokens + completion_tokens`.
pub total_tokens: u32,
}
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn chat(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
fn ptype(&self) -> &str;
fn name(&self) -> &str;
fn model_id(&self) -> &str;
}