PicoBot/src/agent/agent_loop.rs

73 lines
2.0 KiB
Rust

use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
/// A minimal conversational agent that sends its running chat history to a
/// single LLM provider and records each exchange.
pub struct AgentLoop {
    /// Backend used for chat completions.
    provider: Box<dyn LLMProvider>,
    /// Messages exchanged so far, in the order they were appended.
    history: Vec<ChatMessage>,
}
impl AgentLoop {
    /// Creates an agent loop backed by the provider described by `provider_config`,
    /// starting with an empty conversation history.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ProviderCreation`] if `create_provider` fails; the
    /// underlying error is converted to its string form.
    pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
        let provider = create_provider(provider_config)
            .map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
        Ok(Self {
            provider,
            history: Vec::new(),
        })
    }

    /// Sends `user_message`, preceded by the full conversation history, to the
    /// provider and returns the assistant's reply.
    ///
    /// The history is only updated once the provider call succeeds: the user
    /// message and the assistant reply are then appended together. If the call
    /// fails (or this future is dropped before completing), the history is left
    /// exactly as it was. A retry therefore does not send the same user message
    /// twice, and no unanswered user turn is left behind.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::LlmError`] if the provider's `chat` call fails; the
    /// underlying error is converted to its string form.
    pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
        // Build the request from the existing history plus the new message
        // without mutating `self.history` yet.
        let messages: Vec<Message> = self
            .history
            .iter()
            .chain(std::iter::once(&user_message))
            .map(|m| Message {
                role: m.role.clone(),
                content: m.content.clone(),
            })
            .collect();

        let request = ChatCompletionRequest {
            messages,
            temperature: None,
            max_tokens: None,
            tools: None,
        };

        let response = self
            .provider
            .chat(request)
            .await
            .map_err(|e| AgentError::LlmError(e.to_string()))?;

        let assistant_message = ChatMessage::assistant(response.content);
        // Commit the exchange only now that the provider call has succeeded.
        self.history.push(user_message);
        self.history.push(assistant_message.clone());
        Ok(assistant_message)
    }

    /// Removes every message from the conversation history.
    pub fn clear_history(&mut self) {
        self.history.clear();
    }

    /// Returns the conversation history recorded so far, oldest first.
    pub fn history(&self) -> &[ChatMessage] {
        &self.history
    }
}
/// Errors produced by the agent loop.
///
/// Each variant carries a text description of the underlying failure rather
/// than the original error value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentError {
    /// The LLM provider could not be constructed from its configuration.
    ProviderCreation(String),
    /// A call to the LLM provider failed.
    LlmError(String),
}
impl std::fmt::Display for AgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
}
}
}
impl std::error::Error for AgentError {}