Compare commits

..

No commits in common. "9834bd75cf2a1c3c06c65c08fcb1f83fdb484f3e" and "04736f9f46217e1ece03d5097e70a09cecc55448" have entirely different histories.

19 changed files with 171 additions and 1554 deletions

View File

@ -19,7 +19,3 @@ futures-util = "0.3"
clap = { version = "4", features = ["derive"] }
dirs = "6.0.0"
prost = "0.14"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
tracing-appender = "0.2"
anyhow = "1.0"

View File

@ -1,13 +1,10 @@
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::tools::ToolRegistry;
use std::sync::Arc;
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
pub struct AgentLoop {
provider: Box<dyn LLMProvider>,
history: Vec<ChatMessage>,
tools: Arc<ToolRegistry>,
}
impl AgentLoop {
@ -18,25 +15,9 @@ impl AgentLoop {
Ok(Self {
provider,
history: Vec::new(),
tools: Arc::new(ToolRegistry::new()),
})
}
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider,
history: Vec::new(),
tools,
})
}
pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools
}
pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
self.history.push(user_message.clone());
@ -45,52 +26,18 @@ impl AgentLoop {
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
})
.collect();
tracing::debug!(history_len = self.history.len(), "Sending request to LLM");
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
} else {
None
};
let request = ChatCompletionRequest {
messages,
temperature: None,
max_tokens: None,
tools,
tools: None,
};
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
tracing::debug!(response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received");
if !response.tool_calls.is_empty() {
tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools");
let assistant_message = ChatMessage::assistant(response.content.clone());
self.history.push(assistant_message.clone());
let tool_results = self.execute_tools(&response.tool_calls).await;
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
let tool_message = ChatMessage::tool(
tool_call.id.clone(),
tool_call.name.clone(),
result.clone(),
);
self.history.push(tool_message);
}
return self.continue_with_tool_results(response.content).await;
}
.map_err(|e| AgentError::LlmError(e.to_string()))?;
let assistant_message = ChatMessage::assistant(response.content);
self.history.push(assistant_message.clone());
@ -98,81 +45,8 @@ impl AgentLoop {
Ok(assistant_message)
}
async fn continue_with_tool_results(&mut self, _original_content: String) -> Result<ChatMessage, AgentError> {
let messages: Vec<Message> = self.history
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
})
.collect();
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
} else {
None
};
let request = ChatCompletionRequest {
messages,
temperature: None,
max_tokens: None,
tools,
};
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM continuation request failed");
AgentError::LlmError(e.to_string())
})?;
let assistant_message = ChatMessage::assistant(response.content);
self.history.push(assistant_message.clone());
Ok(assistant_message)
}
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> {
let mut results = Vec::with_capacity(tool_calls.len());
for tool_call in tool_calls {
let result = self.execute_tool(tool_call).await;
results.push(result);
}
results
}
async fn execute_tool(&self, tool_call: &ToolCall) -> String {
let tool = match self.tools.get(&tool_call.name) {
Some(t) => t,
None => {
tracing::warn!(tool = %tool_call.name, "Tool not found");
return format!("Error: Tool '{}' not found", tool_call.name);
}
};
match tool.execute(tool_call.arguments.clone()).await {
Ok(result) => {
if result.success {
result.output
} else {
format!("Error: {}", result.error.unwrap_or_default())
}
}
Err(e) => {
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
format!("Error: {}", e)
}
}
}
pub fn clear_history(&mut self) {
let len = self.history.len();
self.history.clear();
tracing::debug!(previous_len = len, "Chat history cleared");
}
pub fn history(&self) -> &[ChatMessage] {
@ -184,7 +58,6 @@ impl AgentLoop {
pub enum AgentError {
ProviderCreation(String),
LlmError(String),
Other(String),
}
impl std::fmt::Display for AgentError {
@ -192,7 +65,6 @@ impl std::fmt::Display for AgentError {
match self {
AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
AgentError::Other(e) => write!(f, "{}", e),
}
}
}

View File

@ -6,10 +6,6 @@ pub struct ChatMessage {
pub role: String,
pub content: String,
pub timestamp: i64,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
}
impl ChatMessage {
@ -19,8 +15,6 @@ impl ChatMessage {
role: "user".to_string(),
content: content.into(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
}
}
@ -30,8 +24,6 @@ impl ChatMessage {
role: "assistant".to_string(),
content: content.into(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
}
}
@ -41,19 +33,6 @@ impl ChatMessage {
role: "system".to_string(),
content: content.into(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
}
}
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "tool".to_string(),
content: content.into(),
timestamp: current_timestamp(),
tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name.into()),
}
}
}

View File

@ -1,12 +1,15 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use tokio::sync::{broadcast, RwLock};
use tokio::sync::{broadcast, RwLock, Mutex};
use serde::Deserialize;
use futures_util::{SinkExt, StreamExt};
use prost::{Message as ProstMessage, bytes::Bytes};
use crate::agent::AgentLoop;
use crate::bus::ChatMessage;
use crate::channels::base::{Channel, ChannelError};
use crate::channels::manager::MessageHandler;
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
@ -131,31 +134,25 @@ pub struct FeishuChannel {
running: Arc<RwLock<bool>>,
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
connected: Arc<RwLock<bool>>,
/// Message handler for routing messages to Gateway
message_handler: Arc<dyn MessageHandler>,
}
/// Parsed message data from a Feishu frame
struct ParsedMessage {
message_id: String,
open_id: String,
chat_id: String,
content: String,
/// Dedup: message_id -> timestamp (cleaned after 30 min)
seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
/// Agent for processing messages
agent: Arc<Mutex<AgentLoop>>,
}
impl FeishuChannel {
pub fn new(
config: FeishuChannelConfig,
message_handler: Arc<dyn MessageHandler>,
_provider_config: LLMProviderConfig,
) -> Result<Self, ChannelError> {
pub fn new(config: FeishuChannelConfig, provider_config: LLMProviderConfig) -> Result<Self, ChannelError> {
let agent = AgentLoop::new(provider_config)
.map_err(|e| ChannelError::Other(format!("Failed to create agent: {}", e)))?;
Ok(Self {
config,
http_client: reqwest::Client::new(),
running: Arc::new(RwLock::new(false)),
shutdown_tx: Arc::new(RwLock::new(None)),
connected: Arc::new(RwLock::new(false)),
message_handler,
seen_ids: Arc::new(RwLock::new(HashMap::new())),
agent: Arc::new(Mutex::new(agent)),
})
}
@ -262,22 +259,23 @@ impl FeishuChannel {
Ok(())
}
/// Handle incoming message - delegate to message handler and send response
/// Handle incoming message - process through agent and send response
async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> {
tracing::info!(open_id, chat_id, "Processing message from Feishu");
println!("Feishu: processing message from {} in chat {}: {}", open_id, chat_id, content);
// Delegate to message handler (Gateway)
let response = self.message_handler
.handle_message("feishu", open_id, chat_id, content)
.await?;
// Process through agent
let user_msg = ChatMessage::user(content);
let mut agent = self.agent.lock().await;
let response = agent.process(user_msg).await
.map_err(|e| ChannelError::Other(format!("Agent error: {}", e)))?;
// Send response to the chat
// Use open_id for p2p chats, chat_id for group chats
let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id };
let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
self.send_message(receive_id, receive_id_type, &response).await?;
tracing::info!(receive_id, "Sent response to Feishu");
self.send_message(receive_id, receive_id_type, &response.content).await?;
println!("Feishu: sent response to {}", receive_id);
Ok(())
}
@ -295,8 +293,8 @@ impl FeishuChannel {
.unwrap_or(0)
}
/// Handle incoming binary PbFrame - returns Some(ParsedMessage) if we need to ack
async fn handle_frame(&self, frame: &PbFrame) -> Result<Option<ParsedMessage>, ChannelError> {
/// Handle incoming binary PbFrame - returns Some(message_id) if we need to ack
async fn handle_frame(&self, frame: &PbFrame) -> Result<Option<String>, ChannelError> {
// method 0 = CONTROL (ping/pong)
if frame.method == 0 {
return Ok(None);
@ -327,7 +325,20 @@ impl FeishuChannel {
return Ok(None);
}
// Deduplication check with TTL cleanup
let message_id = payload_data.message.message_id.clone();
{
let mut seen = self.seen_ids.write().await;
let now = Instant::now();
// Clean expired entries (older than 30 min)
seen.retain(|_, ts| now.duration_since(*ts).as_secs() < 1800);
if seen.contains_key(&message_id) {
return Ok(None);
}
seen.insert(message_id.clone(), now);
}
let open_id = payload_data.sender.sender_id.open_id
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
@ -337,12 +348,13 @@ impl FeishuChannel {
let msg_type = msg.message_type.as_str();
let content = parse_message_content(msg_type, &msg.content);
Ok(Some(ParsedMessage {
message_id,
open_id,
chat_id,
content,
}))
// Handle the message - process and send response
if let Err(e) = self.handle_message(&open_id, &chat_id, &content).await {
eprintln!("Error handling message: {}", e);
}
// Return message_id for ack
Ok(Some(message_id))
}
/// Send acknowledgment for a message
@ -363,14 +375,14 @@ impl FeishuChannel {
let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?;
let service_id = Self::extract_service_id(&wss_url);
tracing::info!(url = %wss_url, "Connecting to Feishu WebSocket");
println!("Feishu: connecting to {}", wss_url);
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url)
.await
.map_err(|e| ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)))?;
*self.connected.write().await = true;
tracing::info!("Feishu WebSocket connected");
println!("Feishu channel connected");
let (mut write, mut read) = ws_stream.split();
@ -404,25 +416,17 @@ impl FeishuChannel {
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
let bytes: Bytes = data;
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
// Parse the frame first
// Handle the frame and get message_id for ack if needed
match self.handle_frame(&frame).await {
Ok(Some(parsed)) => {
Ok(Some(_message_id)) => {
// Send ACK immediately (Feishu requires within 3 s)
if let Err(e) = Self::send_ack(&frame, &mut write).await {
tracing::error!(error = %e, "Failed to send ACK to Feishu");
eprintln!("Error sending ack: {}", e);
}
// Then process message asynchronously (don't await)
let channel = self.clone();
tokio::spawn(async move {
if let Err(e) = channel.handle_message(&parsed.open_id, &parsed.chat_id, &parsed.content).await {
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to handle Feishu message");
}
});
}
Ok(None) => {}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse Feishu frame");
eprintln!("Error handling frame: {}", e);
}
}
}
@ -442,11 +446,10 @@ impl FeishuChannel {
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
tracing::debug!("Feishu WebSocket closed");
break;
}
Some(Err(e)) => {
tracing::warn!(error = %e, "Feishu WebSocket error");
eprintln!("WS error: {}", e);
break;
}
_ => {}
@ -466,12 +469,12 @@ impl FeishuChannel {
payload: None,
};
if write.send(tokio_tungstenite::tungstenite::Message::Binary(ping.encode_to_vec().into())).await.is_err() {
tracing::warn!("Feishu ping failed, reconnecting");
eprintln!("Feishu: ping failed, reconnecting");
break;
}
}
_ = shutdown_rx.recv() => {
tracing::info!("Feishu channel shutdown signal received");
println!("Feishu channel shutdown signal received");
break;
}
}
@ -553,13 +556,13 @@ impl Channel for FeishuChannel {
let shutdown_rx = shutdown_tx.subscribe();
match channel.run_ws_loop(shutdown_rx).await {
Ok(_) => {
tracing::info!("Feishu WebSocket disconnected");
println!("Feishu WebSocket disconnected");
}
Err(e) => {
consecutive_failures += 1;
tracing::error!(attempt = consecutive_failures, error = %e, "Feishu WebSocket error");
eprintln!("Feishu WebSocket error (attempt {}): {}", consecutive_failures, e);
if consecutive_failures >= max_failures {
tracing::error!("Feishu channel: max failures reached, stopping");
eprintln!("Feishu channel: max failures reached, stopping");
break;
}
}
@ -569,15 +572,15 @@ impl Channel for FeishuChannel {
break;
}
tracing::info!("Feishu channel retrying in 5s...");
println!("Feishu channel retrying in 5s...");
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
}
*channel.running.write().await = false;
tracing::info!("Feishu channel stopped");
println!("Feishu channel stopped");
});
tracing::info!("Feishu channel started");
println!("Feishu channel started");
Ok(())
}
@ -592,7 +595,6 @@ impl Channel for FeishuChannel {
}
fn is_running(&self) -> bool {
// Note: blocking read, acceptable for this use case
self.running.try_read().map(|r| *r).unwrap_or(false)
false
}
}

View File

@ -1,60 +1,41 @@
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::channels::base::{Channel, ChannelError};
use crate::channels::feishu::FeishuChannel;
use crate::config::{Config, FeishuChannelConfig};
use crate::config::Config;
/// MessageHandler trait - Channel 通过这个 trait 与业务逻辑解耦
#[async_trait]
pub trait MessageHandler: Send + Sync {
async fn handle_message(
&self,
channel_name: &str,
sender_id: &str,
chat_id: &str,
content: &str,
) -> Result<String, ChannelError>;
}
/// ChannelManager 管理所有 Channel
#[derive(Clone)]
pub struct ChannelManager {
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
message_handler: Arc<dyn MessageHandler>,
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>,
}
impl ChannelManager {
pub fn new(message_handler: Arc<dyn MessageHandler>) -> Self {
pub fn new() -> Self {
Self {
channels: Arc::new(RwLock::new(HashMap::new())),
message_handler,
}
}
/// 获取 MessageHandler 用于让 Channel 调用
pub fn get_handler(&self) -> Arc<dyn MessageHandler> {
self.message_handler.clone()
}
/// 初始化所有 Channel
pub async fn init(&self, config: &Config, provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
pub async fn init(&self, config: &Config) -> Result<(), ChannelError> {
// Initialize Feishu channel if enabled
if let Some(feishu_config) = config.channels.get("feishu") {
if feishu_config.enabled {
let handler = self.get_handler();
let channel = FeishuChannel::new(feishu_config.clone(), handler, provider_config)
let agent_name = &feishu_config.agent;
let provider_config = config.get_provider_config(agent_name)
.map_err(|e| ChannelError::Other(format!("Failed to get provider config: {}", e)))?;
let channel = FeishuChannel::new(feishu_config.clone(), provider_config)
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
self.channels
.write()
.await
.insert("feishu".to_string(), Arc::new(channel));
tracing::info!("Feishu channel registered");
println!("Feishu channel registered");
} else {
tracing::info!("Feishu channel disabled in config");
println!("Feishu channel disabled in config");
}
}
Ok(())
@ -63,9 +44,11 @@ impl ChannelManager {
pub async fn start_all(&self) -> Result<(), ChannelError> {
let channels = self.channels.read().await;
for (name, channel) in channels.iter() {
tracing::info!(channel = %name, "Starting channel");
println!("Starting channel: {}", name);
if let Err(e) = channel.start().await {
tracing::error!(channel = %name, error = %e, "Failed to start channel");
eprintln!("Warning: Failed to start channel {}: {}", name, e);
// Channel failed to start - it should have logged why
// Continue starting other channels
}
}
Ok(())
@ -74,44 +57,22 @@ impl ChannelManager {
pub async fn stop_all(&self) -> Result<(), ChannelError> {
let mut channels = self.channels.write().await;
for (name, channel) in channels.iter() {
tracing::info!(channel = %name, "Stopping channel");
println!("Stopping channel: {}", name);
if let Err(e) = channel.stop().await {
tracing::error!(channel = %name, error = %e, "Error stopping channel");
eprintln!("Error stopping channel {}: {}", name, e);
}
}
channels.clear();
Ok(())
}
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel + Send + Sync>> {
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel>> {
self.channels.read().await.get(name).cloned()
}
}
/// Gateway 实现 MessageHandler trait
#[derive(Clone)]
pub struct GatewayMessageHandler {
session_manager: crate::gateway::session::SessionManager,
}
impl GatewayMessageHandler {
pub fn new(session_manager: crate::gateway::session::SessionManager) -> Self {
Self { session_manager }
}
}
#[async_trait]
impl MessageHandler for GatewayMessageHandler {
async fn handle_message(
&self,
channel_name: &str,
sender_id: &str,
chat_id: &str,
content: &str,
) -> Result<String, ChannelError> {
self.session_manager
.handle_message(channel_name, sender_id, chat_id, content)
.await
.map_err(|e| ChannelError::Other(e.to_string()))
impl Default for ChannelManager {
fn default() -> Self {
Self::new()
}
}

View File

@ -11,7 +11,7 @@ fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
let (ws_stream, _) = connect_async(gateway_url).await?;
tracing::info!(url = %gateway_url, "Connected to gateway");
println!("Connected to gateway");
let (mut sender, mut receiver) = ws_stream.split();
@ -35,7 +35,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
input.write_output(&format!("Error: {}", message)).await?;
}
WsOutbound::SessionEstablished { session_id } => {
tracing::debug!(session_id = %session_id, "Session established");
input.write_output(&format!("Session: {}\n", session_id)).await?;
}
_ => {}
@ -43,7 +42,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
}
}
Some(Ok(Message::Close(_))) | None => {
tracing::info!("Gateway disconnected");
input.write_output("Gateway disconnected").await?;
break;
}
@ -60,7 +58,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
break;
}
"__CLEAR__" => {
let inbound = WsInbound::ClearHistory { chat_id: None };
let inbound = WsInbound::ClearHistory;
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
@ -69,22 +67,17 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
_ => {}
}
let inbound = WsInbound::UserInput {
content: msg.content,
channel: None,
chat_id: None,
sender_id: None,
};
let inbound = WsInbound::UserInput { content: msg.content };
if let Ok(text) = serialize_inbound(&inbound) {
if sender.send(Message::Text(text.into())).await.is_err() {
tracing::error!("Failed to send message to gateway");
eprintln!("Failed to send message");
break;
}
}
}
Ok(None) => break,
Err(e) => {
tracing::error!(error = %e, "Input error");
eprintln!("Input error: {}", e);
break;
}
}

View File

@ -67,8 +67,6 @@ pub struct GatewayConfig {
pub host: String,
#[serde(default = "default_gateway_port")]
pub port: u16,
#[serde(default, rename = "session_ttl_hours")]
pub session_ttl_hours: Option<u64>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -94,7 +92,6 @@ impl Default for GatewayConfig {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
}
}
}
@ -122,7 +119,7 @@ pub struct LLMProviderConfig {
fn get_default_config_path() -> PathBuf {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
home.join(".picobot").join("config.json")
home.join(".config").join("picobot").join("config.json")
}
impl Config {
@ -138,13 +135,13 @@ impl Config {
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
load_env_file()?;
let content = if path.exists() {
tracing::info!(path = %path.display(), "Config loaded");
println!("Config loaded from: {}", path.display());
fs::read_to_string(path)?
} else {
// Fallback to current directory
let fallback = Path::new("config.json");
if fallback.exists() {
tracing::info!(path = %fallback.display(), "Config loaded from fallback path");
println!("Config loaded from: {}", fallback.display());
fs::read_to_string(fallback)?
} else {
return Err(Box::new(ConfigError::ConfigNotFound(
@ -192,7 +189,7 @@ pub enum ConfigError {
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.config/picobot/config.json", path),
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),

View File

@ -6,9 +6,8 @@ use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
use crate::channels::{ChannelManager, manager::GatewayMessageHandler};
use crate::channels::ChannelManager;
use crate::config::Config;
use crate::logging;
use session::SessionManager;
pub struct GatewayState {
@ -20,37 +19,20 @@ pub struct GatewayState {
impl GatewayState {
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
let config = Config::load_default()?;
// Get provider config for SessionManager
let provider_config = config.get_provider_config("default")?;
// Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
let session_manager = SessionManager::new(session_ttl_hours, provider_config);
let message_handler = Arc::new(GatewayMessageHandler::new(session_manager.clone()));
let channel_manager = ChannelManager::new(message_handler);
let channel_manager = ChannelManager::new();
Ok(Self {
config,
session_manager,
session_manager: SessionManager::new(),
channel_manager,
})
}
}
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging
logging::init_logging();
tracing::info!("Starting PicoBot Gateway");
let state = Arc::new(GatewayState::new()?);
// Get provider config for channels
let provider_config = state.config.get_provider_config("default")?;
// Initialize and start channels
state.channel_manager.init(&state.config, provider_config).await?;
state.channel_manager.init(&state.config).await?;
state.channel_manager.start_all().await?;
// CLI args override config file values
@ -64,7 +46,7 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
let addr = format!("{}:{}", bind_host, bind_port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!(address = %addr, "Gateway listening");
println!("Gateway listening on {}", addr);
// Graceful shutdown using oneshot channel
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
@ -73,7 +55,7 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
// Spawn ctrl_c handler
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
tracing::info!("Shutdown signal received");
println!("Shutting down...");
let _ = channel_manager.stop_all().await;
let _ = shutdown_tx.send(());
});

View File

@ -1,204 +1,67 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError};
use crate::agent::AgentLoop;
use crate::protocol::WsOutbound;
use crate::tools::{CalculatorTool, ToolRegistry};
/// Session 按 channel 隔离,每个 channel 一个 Session
pub struct Session {
pub id: Uuid,
pub channel_name: String,
/// 按 chat_id 路由到不同 AgentLoop支持多用户多会话
chat_agents: HashMap<String, Arc<Mutex<AgentLoop>>>,
pub agent_loop: Arc<Mutex<AgentLoop>>,
pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
}
impl Session {
pub async fn new(
channel_name: String,
provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
) -> Result<Self, crate::agent::AgentError> {
let agent_loop = AgentLoop::new(provider_config)?;
Ok(Self {
id: Uuid::new_v4(),
channel_name,
chat_agents: HashMap::new(),
agent_loop: Arc::new(Mutex::new(agent_loop)),
user_tx,
provider_config,
tools,
})
}
/// 获取或创建指定 chat_id 的 AgentLoop
pub async fn get_or_create_agent(&mut self, chat_id: &str) -> Result<Arc<Mutex<AgentLoop>>, AgentError> {
if let Some(agent) = self.chat_agents.get(chat_id) {
tracing::trace!(chat_id = %chat_id, "Reusing existing agent");
return Ok(agent.clone());
}
tracing::debug!(chat_id = %chat_id, "Creating new agent for chat");
let agent = AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())?;
let arc = Arc::new(Mutex::new(agent));
self.chat_agents.insert(chat_id.to_string(), arc.clone());
Ok(arc)
}
/// 获取指定 chat_id 的 AgentLoop不创建
pub fn get_agent(&self, chat_id: &str) -> Option<Arc<Mutex<AgentLoop>>> {
self.chat_agents.get(chat_id).cloned()
}
/// 清除指定 chat_id 的历史
pub async fn clear_chat_history(&mut self, chat_id: &str) {
if let Some(agent) = self.chat_agents.get(chat_id) {
agent.lock().await.clear_history();
}
}
/// 清除所有历史
pub async fn clear_all_history(&mut self) {
for agent in self.chat_agents.values() {
agent.lock().await.clear_history();
}
}
pub async fn send(&self, msg: WsOutbound) {
let _ = self.user_tx.send(msg).await;
}
}
/// SessionManager 管理所有 Session按 channel_name 路由
/// 使用 Arc<Mutex<SessionManager>> 以从 Arc 获取可变访问
#[derive(Clone)]
use std::collections::HashMap;
use std::sync::RwLock;
pub struct SessionManager {
inner: Arc<Mutex<SessionManagerInner>>,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
}
struct SessionManagerInner {
sessions: HashMap<String, Arc<Mutex<Session>>>,
session_timestamps: HashMap<String, Instant>,
session_ttl: Duration,
}
fn default_tools() -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry
sessions: RwLock<HashMap<Uuid, Arc<Session>>>,
}
impl SessionManager {
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Self {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(SessionManagerInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
provider_config,
tools: Arc::new(default_tools()),
sessions: RwLock::new(HashMap::new()),
}
}
pub fn tools(&self) -> Arc<ToolRegistry> {
self.tools.clone()
pub fn add(&self, session: Arc<Session>) {
self.sessions.write().unwrap().insert(session.id, session);
}
/// 确保 session 存在且未超时,超时则重建
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await;
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) {
let elapsed = last_active.elapsed();
if elapsed > inner.session_ttl {
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
true
} else {
false
}
} else {
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
if should_recreate {
// 移除旧 session
inner.sessions.remove(channel_name);
// 创建新 session使用临时 user_tx因为 Feishu 不通过 WS
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx, self.tools.clone()).await?;
let arc = Arc::new(Mutex::new(session));
inner.sessions.insert(channel_name.to_string(), arc.clone());
inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
pub fn remove(&self, id: &Uuid) {
self.sessions.write().unwrap().remove(id);
}
Ok(())
pub fn get(&self, id: &Uuid) -> Option<Arc<Session>> {
self.sessions.read().unwrap().get(id).cloned()
}
/// 获取 session不检查超时
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
let inner = self.inner.lock().await;
inner.sessions.get(channel_name).cloned()
}
/// 更新最后活跃时间
pub async fn touch(&self, channel_name: &str) {
let mut inner = self.inner.lock().await;
inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
}
/// 处理消息:路由到对应 session 的 agent
pub async fn handle_message(
&self,
channel_name: &str,
_sender_id: &str,
chat_id: &str,
content: &str,
) -> Result<String, AgentError> {
tracing::debug!(channel = %channel_name, chat_id = %chat_id, content_len = content.len(), "Routing message to agent");
// 确保 session 存在(可能需要重建)
self.ensure_session(channel_name).await?;
// 更新活跃时间
self.touch(channel_name).await;
// 获取 session
let session = self.get(channel_name).await
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
// 获取或创建 chat_id 对应的 agent
let mut session_guard = session.lock().await;
let agent = session_guard.get_or_create_agent(chat_id).await?;
drop(session_guard);
let mut agent = agent.lock().await;
// 处理消息
let user_msg = ChatMessage::user(content);
let response = agent.process(user_msg).await?;
tracing::debug!(channel = %channel_name, chat_id = %chat_id, response_len = response.content.len(), "Agent response received");
Ok(response.content)
}
/// 清除指定 session 的所有历史
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
if let Some(session) = self.get(channel_name).await {
let mut session_guard = session.lock().await;
session_guard.clear_all_history().await;
}
Ok(())
pub fn len(&self) -> usize {
self.sessions.read().unwrap().len()
}
}
impl Default for SessionManager {
fn default() -> Self {
Self::new()
}
}

View File

@ -3,7 +3,7 @@ use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use tokio::sync::mpsc;
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
use super::{GatewayState, session::Session};
@ -20,39 +20,33 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let provider_config = match state.config.get_provider_config("default") {
Ok(cfg) => cfg,
Err(e) => {
tracing::error!(error = %e, "Failed to get provider config");
eprintln!("Failed to get provider config: {}", e);
return;
}
};
// CLI 使用独立的 sessionchannel_name = "cli-{uuid}"
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
// 创建 CLI session
let session = match Session::new(channel_name.clone(), provider_config, sender, state.session_manager.tools()).await {
Ok(s) => Arc::new(Mutex::new(s)),
let session = match Session::new(provider_config, sender).await {
Ok(s) => Arc::new(s),
Err(e) => {
tracing::error!(error = %e, "Failed to create session");
eprintln!("Failed to create session: {}", e);
return;
}
};
let session_id = session.lock().await.id;
tracing::info!(session_id = %session_id, "CLI session established");
let session_id = session.id;
state.session_manager.add(session.clone());
let _ = session.lock().await.send(WsOutbound::SessionEstablished {
let _ = session.send(WsOutbound::SessionEstablished {
session_id: session_id.to_string(),
}).await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
let session_id_for_sender = session_id;
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
break;
}
}
@ -68,8 +62,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
handle_inbound(&session, inbound).await;
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = session.lock().await.send(WsOutbound::Error {
let _ = session.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
}).await;
@ -77,62 +70,47 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
tracing::debug!(session_id = %session_id, "WebSocket closed");
break;
}
_ => {}
}
}
tracing::info!(session_id = %session_id, "CLI session ended");
state.session_manager.remove(&session_id);
}
async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
let inbound_clone = inbound.clone();
// 提取 content 和 chat_idCLI 使用 session id 作为 chat_id
let (content, chat_id) = match inbound_clone {
WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => {
// CLI 使用 session 中的 channel_name 作为标识
// chat_id 使用传入的或使用默认
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
(content, chat_id)
}
_ => return,
};
async fn handle_inbound(session: &Arc<Session>, inbound: WsInbound) {
match inbound {
WsInbound::UserInput { content } => {
let user_msg = ChatMessage::user(content);
let mut session_guard = session.lock().await;
let agent = match session_guard.get_or_create_agent(&chat_id).await {
Ok(a) => a,
Err(e) => {
tracing::error!(chat_id = %chat_id, error = %e, "Failed to get or create agent");
let _ = session_guard.send(WsOutbound::Error {
code: "AGENT_ERROR".to_string(),
message: e.to_string(),
}).await;
return;
}
};
drop(session_guard);
let mut agent = agent.lock().await;
let mut agent = session.agent_loop.lock().await;
match agent.process(user_msg).await {
Ok(response) => {
tracing::debug!(chat_id = %chat_id, "Agent response sent");
let _ = session.lock().await.send(WsOutbound::AssistantResponse {
let _ = session.send(WsOutbound::AssistantResponse {
id: response.id,
content: response.content,
role: response.role,
}).await;
}
Err(e) => {
tracing::error!(chat_id = %chat_id, error = %e, "Agent process error");
let _ = session.lock().await.send(WsOutbound::Error {
let _ = session.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: e.to_string(),
}).await;
}
}
}
WsInbound::ClearHistory => {
let mut agent = session.agent_loop.lock().await;
agent.clear_history();
let _ = session.send(WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
content: "History cleared.".to_string(),
role: "system".to_string(),
}).await;
}
WsInbound::Ping => {
let _ = session.send(WsOutbound::Pong).await;
}
}
}

View File

@ -7,5 +7,3 @@ pub mod gateway;
pub mod client;
pub mod protocol;
pub mod channels;
pub mod logging;
pub mod tools;

View File

@ -1,80 +0,0 @@
use std::path::PathBuf;
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{
fmt,
layer::SubscriberExt,
util::SubscriberInitExt,
EnvFilter,
};
/// Get the default log directory path: ~/.picobot/logs
pub fn get_default_log_dir() -> PathBuf {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
home.join(".picobot").join("logs")
}
/// Get the default config file path: ~/.picobot/config.json
pub fn get_default_config_path() -> PathBuf {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
home.join(".picobot").join("config.json")
}
/// Initialize logging with file appender
/// Logs are written to ~/.picobot/logs/ with daily rotation
pub fn init_logging() {
let log_dir = get_default_log_dir();
// Create log directory if it doesn't exist
if !log_dir.exists() {
if let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
}
}
// Create file appender with daily rotation
let file_appender = RollingFileAppender::new(
Rotation::DAILY,
&log_dir,
"picobot.log",
);
// Build subscriber with both console and file output
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let file_layer = fmt::layer()
.with_writer(file_appender)
.with_ansi(false)
.with_target(true)
.with_level(true)
.with_thread_ids(true);
let console_layer = fmt::layer()
.with_target(true)
.with_level(true);
tracing_subscriber::registry()
.with(env_filter)
.with(console_layer)
.with(file_layer)
.init();
tracing::info!("Logging initialized. Log directory: {}", log_dir.display());
}
/// Initialize logging without file output (console only)
pub fn init_logging_console_only() {
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let console_layer = fmt::layer()
.with_target(true)
.with_level(true);
tracing_subscriber::registry()
.with(env_filter)
.with(console_layer)
.init();
tracing::info!("Logging initialized (console only)");
}

View File

@ -4,20 +4,9 @@ use serde::{Deserialize, Serialize};
#[serde(tag = "type")]
pub enum WsInbound {
#[serde(rename = "user_input")]
UserInput {
content: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
channel: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
chat_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
sender_id: Option<String>,
},
UserInput { content: String },
#[serde(rename = "clear_history")]
ClearHistory {
#[serde(default, skip_serializing_if = "Option::is_none")]
chat_id: Option<String>,
},
ClearHistory,
#[serde(rename = "ping")]
Ping,
}

View File

@ -104,19 +104,10 @@ impl LLMProvider for OpenAIProvider {
let mut body = json!({
"model": self.model_id,
"messages": request.messages.iter().map(|m| {
if m.role == "tool" {
json!({
"role": m.role,
"content": m.content,
"tool_call_id": m.tool_call_id,
"name": m.name,
})
} else {
json!({
"role": m.role,
"content": m.content
})
}
}).collect::<Vec<_>>(),
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
"max_tokens": request.max_tokens.or(self.max_tokens),

View File

@ -5,10 +5,6 @@ use serde::{Deserialize, Serialize};
pub struct Message {
pub role: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@ -1,824 +0,0 @@
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
pub struct CalculatorTool;
impl CalculatorTool {
pub fn new() -> Self {
Self
}
}
impl Default for CalculatorTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for CalculatorTool {
fn name(&self) -> &str {
"calculator"
}
fn description(&self) -> &str {
"Perform arithmetic and statistical calculations. Supports 25 functions: \
add, subtract, divide, multiply, pow, sqrt, abs, modulo, round, \
log, ln, exp, factorial, sum, average, median, mode, min, max, \
range, variance, stdev, percentile, count, percentage_change, clamp. \
Use this tool whenever you need to compute a numeric result instead of guessing."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"function": {
"type": "string",
"description": "Calculation to perform. \
Arithmetic: add(values), subtract(values), divide(values), multiply(values), pow(a,b), sqrt(x), abs(x), modulo(a,b), round(x,decimals). \
Logarithmic/exponential: log(x,base?), ln(x), exp(x), factorial(x). \
Aggregation: sum(values), average(values), count(values), min(values), max(values), range(values). \
Statistics: median(values), mode(values), variance(values), stdev(values), percentile(values,p). \
Utility: percentage_change(a,b), clamp(x,min_val,max_val).",
"enum": [
"add", "subtract", "divide", "multiply", "pow", "sqrt",
"abs", "modulo", "round", "log", "ln", "exp", "factorial",
"sum", "average", "median", "mode", "min", "max", "range",
"variance", "stdev", "percentile", "count",
"percentage_change", "clamp"
]
},
"values": {
"type": "array",
"items": { "type": "number" },
"description": "Array of numeric values. Required for: add, subtract, divide, multiply, sum, average, median, mode, min, max, range, variance, stdev, percentile, count."
},
"a": {
"type": "number",
"description": "First operand. Required for: pow, modulo, percentage_change."
},
"b": {
"type": "number",
"description": "Second operand. Required for: pow, modulo, percentage_change."
},
"x": {
"type": "number",
"description": "Input number. Required for: sqrt, abs, exp, ln, log, factorial."
},
"base": {
"type": "number",
"description": "Logarithm base (default: 10). Optional for: log."
},
"decimals": {
"type": "integer",
"description": "Number of decimal places for rounding. Required for: round."
},
"p": {
"type": "integer",
"description": "Percentile rank (0-100). Required for: percentile."
},
"min_val": {
"type": "number",
"description": "Minimum bound. Required for: clamp."
},
"max_val": {
"type": "number",
"description": "Maximum bound. Required for: clamp."
}
},
"required": ["function"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let function = match args.get("function").and_then(|v| v.as_str()) {
Some(f) => f,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: function".to_string()),
});
}
};
let result = match function {
"add" => calc_add(&args),
"subtract" => calc_subtract(&args),
"divide" => calc_divide(&args),
"multiply" => calc_multiply(&args),
"pow" => calc_pow(&args),
"sqrt" => calc_sqrt(&args),
"abs" => calc_abs(&args),
"modulo" => calc_modulo(&args),
"round" => calc_round(&args),
"log" => calc_log(&args),
"ln" => calc_ln(&args),
"exp" => calc_exp(&args),
"factorial" => calc_factorial(&args),
"sum" => calc_sum(&args),
"average" => calc_average(&args),
"median" => calc_median(&args),
"mode" => calc_mode(&args),
"min" => calc_min(&args),
"max" => calc_max(&args),
"range" => calc_range(&args),
"variance" => calc_variance(&args),
"stdev" => calc_stdev(&args),
"percentile" => calc_percentile(&args),
"count" => calc_count(&args),
"percentage_change" => calc_percentage_change(&args),
"clamp" => calc_clamp(&args),
other => Err(format!("Unknown function: {other}")),
};
match result {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(err) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(err),
}),
}
}
}
fn extract_f64(args: &serde_json::Value, key: &str, name: &str) -> Result<f64, String> {
args.get(key)
.and_then(|v| v.as_f64())
.ok_or_else(|| format!("Missing required parameter: {name}"))
}
fn extract_i64(args: &serde_json::Value, key: &str, name: &str) -> Result<i64, String> {
args.get(key)
.and_then(|v| v.as_i64())
.ok_or_else(|| format!("Missing required parameter: {name}"))
}
fn extract_values(args: &serde_json::Value, min_len: usize) -> Result<Vec<f64>, String> {
let values = args
.get("values")
.and_then(|v| v.as_array())
.ok_or_else(|| "Missing required parameter: values (array of numbers)".to_string())?;
if values.len() < min_len {
return Err(format!(
"Expected at least {min_len} value(s), got {}",
values.len()
));
}
let mut nums = Vec::with_capacity(values.len());
for (i, v) in values.iter().enumerate() {
match v.as_f64() {
Some(n) => nums.push(n),
None => return Err(format!("values[{i}] is not a valid number")),
}
}
Ok(nums)
}
fn format_num(n: f64) -> String {
if n == n.floor() && n.abs() < 1e15 {
#[allow(clippy::cast_possible_truncation)]
let rounded = n.round() as i128;
format!("{rounded}")
} else {
format!("{n}")
}
}
fn calc_add(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 2)?;
Ok(format_num(values.iter().sum()))
}
fn calc_subtract(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 2)?;
let mut iter = values.iter();
let mut result = *iter.next().unwrap();
for v in iter {
result -= v;
}
Ok(format_num(result))
}
fn calc_divide(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 2)?;
let mut iter = values.iter();
let mut result = *iter.next().unwrap();
for v in iter {
if *v == 0.0 {
return Err("Division by zero".to_string());
}
result /= v;
}
Ok(format_num(result))
}
fn calc_multiply(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 2)?;
let mut result = 1.0;
for v in &values {
result *= v;
}
Ok(format_num(result))
}
fn calc_pow(args: &serde_json::Value) -> Result<String, String> {
let base = extract_f64(args, "a", "a (base)")?;
let exp = extract_f64(args, "b", "b (exponent)")?;
Ok(format_num(base.powf(exp)))
}
fn calc_sqrt(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
if x < 0.0 {
return Err("Cannot compute square root of a negative number".to_string());
}
Ok(format_num(x.sqrt()))
}
fn calc_abs(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
Ok(format_num(x.abs()))
}
fn calc_modulo(args: &serde_json::Value) -> Result<String, String> {
let a = extract_f64(args, "a", "a")?;
let b = extract_f64(args, "b", "b")?;
if b == 0.0 {
return Err("Modulo by zero".to_string());
}
Ok(format_num(a % b))
}
fn calc_round(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
let decimals = extract_i64(args, "decimals", "decimals")?;
if decimals < 0 {
return Err("decimals must be non-negative".to_string());
}
let multiplier = 10_f64.powi(i32::try_from(decimals).unwrap_or(i32::MAX));
Ok(format_num((x * multiplier).round() / multiplier))
}
fn calc_log(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
if x <= 0.0 {
return Err("Logarithm requires a positive number".to_string());
}
let base = args.get("base").and_then(|v| v.as_f64()).unwrap_or(10.0);
if base <= 0.0 || base == 1.0 {
return Err("Logarithm base must be positive and not equal to 1".to_string());
}
Ok(format_num(x.log(base)))
}
fn calc_ln(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
if x <= 0.0 {
return Err("Natural logarithm requires a positive number".to_string());
}
Ok(format_num(x.ln()))
}
fn calc_exp(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
Ok(format_num(x.exp()))
}
fn calc_factorial(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
if x < 0.0 || x != x.floor() {
return Err("Factorial requires a non-negative integer".to_string());
}
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
let n = x.round() as u128;
if n > 170 {
return Err("Factorial result exceeds f64 range (max input: 170)".to_string());
}
let mut result: u128 = 1;
for i in 2..=n {
result *= i;
}
Ok(result.to_string())
}
fn calc_sum(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
Ok(format_num(values.iter().sum()))
}
fn calc_average(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
if values.is_empty() {
return Err("Cannot compute average of an empty array".to_string());
}
Ok(format_num(values.iter().sum::<f64>() / values.len() as f64))
}
fn calc_median(args: &serde_json::Value) -> Result<String, String> {
let mut values = extract_values(args, 1)?;
if values.is_empty() {
return Err("Cannot compute median of an empty array".to_string());
}
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
let len = values.len();
if len % 2 == 0 {
Ok(format_num(f64::midpoint(
values[len / 2 - 1],
values[len / 2],
)))
} else {
Ok(format_num(values[len / 2]))
}
}
fn calc_mode(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
if values.is_empty() {
return Err("Cannot compute mode of an empty array".to_string());
}
let mut freq: std::collections::HashMap<u64, usize> = std::collections::HashMap::new();
for &v in &values {
let key = v.to_bits();
*freq.entry(key).or_insert(0) += 1;
}
let max_freq = *freq.values().max().unwrap();
let mut seen = std::collections::HashSet::new();
let mut modes = Vec::new();
for &v in &values {
let key = v.to_bits();
if freq[&key] == max_freq && seen.insert(key) {
modes.push(v);
}
}
if modes.len() == 1 {
Ok(format_num(modes[0]))
} else {
let formatted: Vec<String> = modes.iter().map(|v| format_num(*v)).collect();
Ok(format!("Modes: {}", formatted.join(", ")))
}
}
fn calc_min(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
let Some(min_val) = values.iter().copied().reduce(f64::min) else {
return Err("Cannot compute min of an empty array".to_string());
};
Ok(format_num(min_val))
}
fn calc_max(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
let Some(max_val) = values.iter().copied().reduce(f64::max) else {
return Err("Cannot compute max of an empty array".to_string());
};
Ok(format_num(max_val))
}
fn calc_range(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
if values.is_empty() {
return Err("Cannot compute range of an empty array".to_string());
}
let min_val = values.iter().copied().fold(f64::INFINITY, f64::min);
let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
Ok(format_num(max_val - min_val))
}
fn calc_variance(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
if values.len() < 2 {
return Err("Variance requires at least 2 values".to_string());
}
let mean = values.iter().sum::<f64>() / values.len() as f64;
let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
Ok(format_num(variance))
}
fn calc_stdev(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
if values.len() < 2 {
return Err("Standard deviation requires at least 2 values".to_string());
}
let mean = values.iter().sum::<f64>() / values.len() as f64;
let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
Ok(format_num(variance.sqrt()))
}
fn calc_percentile(args: &serde_json::Value) -> Result<String, String> {
let mut values = extract_values(args, 1)?;
if values.is_empty() {
return Err("Cannot compute percentile of an empty array".to_string());
}
let p = extract_i64(args, "p", "p (percentile rank 0-100)")?;
if !(0..=100).contains(&p) {
return Err("Percentile rank must be between 0 and 100".to_string());
}
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
let idx_f = p as f64 / 100.0 * (values.len() - 1) as f64;
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
let index = idx_f.round().clamp(0.0, (values.len() - 1) as f64) as usize;
Ok(format_num(values[index]))
}
fn calc_count(args: &serde_json::Value) -> Result<String, String> {
let values = extract_values(args, 1)?;
Ok(values.len().to_string())
}
fn calc_percentage_change(args: &serde_json::Value) -> Result<String, String> {
let old = extract_f64(args, "a", "a (old value)")?;
let new = extract_f64(args, "b", "b (new value)")?;
if old == 0.0 {
return Err("Cannot compute percentage change from zero".to_string());
}
Ok(format_num((new - old) / old.abs() * 100.0))
}
fn calc_clamp(args: &serde_json::Value) -> Result<String, String> {
let x = extract_f64(args, "x", "x")?;
let min_val = extract_f64(args, "min_val", "min_val")?;
let max_val = extract_f64(args, "max_val", "max_val")?;
if min_val > max_val {
return Err("min_val must be less than or equal to max_val".to_string());
}
Ok(format_num(x.clamp(min_val, max_val)))
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_add() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "add", "values": [1.0, 2.0, 3.5]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "6.5");
}
#[tokio::test]
async fn test_subtract() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "subtract", "values": [10.0, 3.0, 1.5]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "5.5");
}
#[tokio::test]
async fn test_divide() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "divide", "values": [100.0, 4.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "25");
}
#[tokio::test]
async fn test_divide_by_zero() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "divide", "values": [10.0, 0.0]}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("zero"));
}
#[tokio::test]
async fn test_multiply() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "multiply", "values": [3.0, 4.0, 5.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "60");
}
#[tokio::test]
async fn test_pow() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "pow", "a": 2.0, "b": 10.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "1024");
}
#[tokio::test]
async fn test_sqrt() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "sqrt", "x": 144.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "12");
}
#[tokio::test]
async fn test_sqrt_negative() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "sqrt", "x": -4.0}))
.await
.unwrap();
assert!(!result.success);
}
#[tokio::test]
async fn test_abs() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "abs", "x": -42.5}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "42.5");
}
#[tokio::test]
async fn test_modulo() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "modulo", "a": 17.0, "b": 5.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2");
}
#[tokio::test]
async fn test_round() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "round", "x": 2.715, "decimals": 2}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2.72");
}
#[tokio::test]
async fn test_log_base10() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "log", "x": 100.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2");
}
#[tokio::test]
async fn test_log_custom_base() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "log", "x": 8.0, "base": 2.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "3");
}
#[tokio::test]
async fn test_ln() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "ln", "x": 1.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "0");
}
#[tokio::test]
async fn test_exp() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "exp", "x": 0.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "1");
}
#[tokio::test]
async fn test_factorial() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "factorial", "x": 5.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "120");
}
#[tokio::test]
async fn test_average() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "average", "values": [10.0, 20.0, 30.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "20");
}
#[tokio::test]
async fn test_median_odd() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "median", "values": [3.0, 1.0, 2.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2");
}
#[tokio::test]
async fn test_median_even() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "median", "values": [4.0, 1.0, 3.0, 2.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2.5");
}
#[tokio::test]
async fn test_mode() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "mode", "values": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "3");
}
#[tokio::test]
async fn test_min() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "min", "values": [5.0, 2.0, 8.0, 1.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "1");
}
#[tokio::test]
async fn test_max() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "max", "values": [5.0, 2.0, 8.0, 1.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "8");
}
#[tokio::test]
async fn test_range() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "range", "values": [1.0, 5.0, 10.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "9");
}
#[tokio::test]
async fn test_variance() {
let tool = CalculatorTool::new();
let result = tool
.execute(
json!({"function": "variance", "values": [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]}),
)
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "4");
}
#[tokio::test]
async fn test_stdev() {
let tool = CalculatorTool::new();
let result = tool
.execute(
json!({"function": "stdev", "values": [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]}),
)
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "2");
}
#[tokio::test]
async fn test_percentile_50() {
let tool = CalculatorTool::new();
let result = tool
.execute(
json!({"function": "percentile", "values": [1.0, 2.0, 3.0, 4.0, 5.0], "p": 50}),
)
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "3");
}
#[tokio::test]
async fn test_count() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "count", "values": [1.0, 2.0, 3.0, 4.0, 5.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "5");
}
#[tokio::test]
async fn test_percentage_change() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "percentage_change", "a": 50.0, "b": 75.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "50");
}
#[tokio::test]
async fn test_clamp_within_range() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "clamp", "x": 5.0, "min_val": 1.0, "max_val": 10.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "5");
}
#[tokio::test]
async fn test_clamp_below_min() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "clamp", "x": -5.0, "min_val": 0.0, "max_val": 10.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "0");
}
#[tokio::test]
async fn test_clamp_above_max() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "clamp", "x": 15.0, "min_val": 0.0, "max_val": 10.0}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "10");
}
#[tokio::test]
async fn test_unknown_function() {
let tool = CalculatorTool::new();
let result = tool.execute(json!({"function": "unknown"})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("Unknown function"));
}
#[tokio::test]
async fn test_sum() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "sum", "values": [1.0, 2.0, 3.0, 4.0, 5.0]}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "15");
}
}

View File

@ -1,7 +0,0 @@
pub mod calculator;
pub mod registry;
pub mod traits;
pub use calculator::CalculatorTool;
pub use registry::ToolRegistry;
pub use traits::{Tool, ToolResult};

View File

@ -1,53 +0,0 @@
use std::collections::HashMap;
use crate::providers::{Tool, ToolFunction};
use super::traits::Tool as ToolTrait;
pub struct ToolRegistry {
tools: HashMap<String, Box<dyn ToolTrait>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
pub fn register<T: ToolTrait + 'static>(&mut self, tool: T) {
self.tools.insert(tool.name().to_string(), Box::new(tool));
}
pub fn get(&self, name: &str) -> Option<&Box<dyn ToolTrait>> {
self.tools.get(name)
}
pub fn get_definitions(&self) -> Vec<Tool> {
self.tools
.values()
.map(|tool| Tool {
tool_type: "function".to_string(),
function: ToolFunction {
name: tool.name().to_string(),
description: tool.description().to_string(),
parameters: tool.parameters_schema(),
},
})
.collect()
}
pub fn has_tools(&self) -> bool {
!self.tools.is_empty()
}
pub fn tool_names(&self) -> Vec<String> {
self.tools.keys().cloned().collect()
}
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}

View File

@ -1,16 +0,0 @@
use async_trait::async_trait;
#[derive(Debug, Clone)]
pub struct ToolResult {
pub success: bool,
pub output: String,
pub error: Option<String>,
}
#[async_trait]
pub trait Tool: Send + Sync + 'static {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn parameters_schema(&self) -> serde_json::Value;
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
}