Compare commits
No commits in common. "9834bd75cf2a1c3c06c65c08fcb1f83fdb484f3e" and "04736f9f46217e1ece03d5097e70a09cecc55448" have entirely different histories.
9834bd75cf
...
04736f9f46
@ -19,7 +19,3 @@ futures-util = "0.3"
|
|||||||
clap = { version = "4", features = ["derive"] }
|
clap = { version = "4", features = ["derive"] }
|
||||||
dirs = "6.0.0"
|
dirs = "6.0.0"
|
||||||
prost = "0.14"
|
prost = "0.14"
|
||||||
tracing = "0.1"
|
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
|
||||||
tracing-appender = "0.2"
|
|
||||||
anyhow = "1.0"
|
|
||||||
|
|||||||
@ -1,13 +1,10 @@
|
|||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
|
||||||
use crate::tools::ToolRegistry;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
pub struct AgentLoop {
|
pub struct AgentLoop {
|
||||||
provider: Box<dyn LLMProvider>,
|
provider: Box<dyn LLMProvider>,
|
||||||
history: Vec<ChatMessage>,
|
history: Vec<ChatMessage>,
|
||||||
tools: Arc<ToolRegistry>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
@ -18,25 +15,9 @@ impl AgentLoop {
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
provider,
|
provider,
|
||||||
history: Vec::new(),
|
history: Vec::new(),
|
||||||
tools: Arc::new(ToolRegistry::new()),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
|
||||||
let provider = create_provider(provider_config)
|
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
provider,
|
|
||||||
history: Vec::new(),
|
|
||||||
tools,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
|
||||||
&self.tools
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
|
pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
|
||||||
self.history.push(user_message.clone());
|
self.history.push(user_message.clone());
|
||||||
|
|
||||||
@ -45,52 +26,18 @@ impl AgentLoop {
|
|||||||
.map(|m| Message {
|
.map(|m| Message {
|
||||||
role: m.role.clone(),
|
role: m.role.clone(),
|
||||||
content: m.content.clone(),
|
content: m.content.clone(),
|
||||||
tool_call_id: m.tool_call_id.clone(),
|
|
||||||
name: m.tool_name.clone(),
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
tracing::debug!(history_len = self.history.len(), "Sending request to LLM");
|
|
||||||
|
|
||||||
let tools = if self.tools.has_tools() {
|
|
||||||
Some(self.tools.get_definitions())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
temperature: None,
|
temperature: None,
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
tools,
|
tools: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = (*self.provider).chat(request).await
|
let response = (*self.provider).chat(request).await
|
||||||
.map_err(|e| {
|
.map_err(|e| AgentError::LlmError(e.to_string()))?;
|
||||||
tracing::error!(error = %e, "LLM request failed");
|
|
||||||
AgentError::LlmError(e.to_string())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
tracing::debug!(response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received");
|
|
||||||
|
|
||||||
if !response.tool_calls.is_empty() {
|
|
||||||
tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
|
||||||
let assistant_message = ChatMessage::assistant(response.content.clone());
|
|
||||||
self.history.push(assistant_message.clone());
|
|
||||||
|
|
||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
|
||||||
|
|
||||||
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
|
||||||
let tool_message = ChatMessage::tool(
|
|
||||||
tool_call.id.clone(),
|
|
||||||
tool_call.name.clone(),
|
|
||||||
result.clone(),
|
|
||||||
);
|
|
||||||
self.history.push(tool_message);
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.continue_with_tool_results(response.content).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
let assistant_message = ChatMessage::assistant(response.content);
|
||||||
self.history.push(assistant_message.clone());
|
self.history.push(assistant_message.clone());
|
||||||
@ -98,81 +45,8 @@ impl AgentLoop {
|
|||||||
Ok(assistant_message)
|
Ok(assistant_message)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn continue_with_tool_results(&mut self, _original_content: String) -> Result<ChatMessage, AgentError> {
|
|
||||||
let messages: Vec<Message> = self.history
|
|
||||||
.iter()
|
|
||||||
.map(|m| Message {
|
|
||||||
role: m.role.clone(),
|
|
||||||
content: m.content.clone(),
|
|
||||||
tool_call_id: m.tool_call_id.clone(),
|
|
||||||
name: m.tool_name.clone(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let tools = if self.tools.has_tools() {
|
|
||||||
Some(self.tools.get_definitions())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages,
|
|
||||||
temperature: None,
|
|
||||||
max_tokens: None,
|
|
||||||
tools,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = (*self.provider).chat(request).await
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::error!(error = %e, "LLM continuation request failed");
|
|
||||||
AgentError::LlmError(e.to_string())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
|
||||||
self.history.push(assistant_message.clone());
|
|
||||||
|
|
||||||
Ok(assistant_message)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> {
|
|
||||||
let mut results = Vec::with_capacity(tool_calls.len());
|
|
||||||
|
|
||||||
for tool_call in tool_calls {
|
|
||||||
let result = self.execute_tool(tool_call).await;
|
|
||||||
results.push(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
results
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute_tool(&self, tool_call: &ToolCall) -> String {
|
|
||||||
let tool = match self.tools.get(&tool_call.name) {
|
|
||||||
Some(t) => t,
|
|
||||||
None => {
|
|
||||||
tracing::warn!(tool = %tool_call.name, "Tool not found");
|
|
||||||
return format!("Error: Tool '{}' not found", tool_call.name);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match tool.execute(tool_call.arguments.clone()).await {
|
|
||||||
Ok(result) => {
|
|
||||||
if result.success {
|
|
||||||
result.output
|
|
||||||
} else {
|
|
||||||
format!("Error: {}", result.error.unwrap_or_default())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
|
||||||
format!("Error: {}", e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn clear_history(&mut self) {
|
pub fn clear_history(&mut self) {
|
||||||
let len = self.history.len();
|
|
||||||
self.history.clear();
|
self.history.clear();
|
||||||
tracing::debug!(previous_len = len, "Chat history cleared");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn history(&self) -> &[ChatMessage] {
|
pub fn history(&self) -> &[ChatMessage] {
|
||||||
@ -184,7 +58,6 @@ impl AgentLoop {
|
|||||||
pub enum AgentError {
|
pub enum AgentError {
|
||||||
ProviderCreation(String),
|
ProviderCreation(String),
|
||||||
LlmError(String),
|
LlmError(String),
|
||||||
Other(String),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for AgentError {
|
impl std::fmt::Display for AgentError {
|
||||||
@ -192,7 +65,6 @@ impl std::fmt::Display for AgentError {
|
|||||||
match self {
|
match self {
|
||||||
AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
|
AgentError::ProviderCreation(e) => write!(f, "Provider creation error: {}", e),
|
||||||
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
|
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
|
||||||
AgentError::Other(e) => write!(f, "{}", e),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,10 +6,6 @@ pub struct ChatMessage {
|
|||||||
pub role: String,
|
pub role: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub timestamp: i64,
|
pub timestamp: i64,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_call_id: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_name: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatMessage {
|
impl ChatMessage {
|
||||||
@ -19,8 +15,6 @@ impl ChatMessage {
|
|||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: content.into(),
|
content: content.into(),
|
||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
|
||||||
tool_name: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -30,8 +24,6 @@ impl ChatMessage {
|
|||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: content.into(),
|
content: content.into(),
|
||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
|
||||||
tool_name: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,19 +33,6 @@ impl ChatMessage {
|
|||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: content.into(),
|
content: content.into(),
|
||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
|
||||||
tool_name: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
|
||||||
role: "tool".to_string(),
|
|
||||||
content: content.into(),
|
|
||||||
timestamp: current_timestamp(),
|
|
||||||
tool_call_id: Some(tool_call_id.into()),
|
|
||||||
tool_name: Some(tool_name.into()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tokio::sync::{broadcast, RwLock};
|
use tokio::sync::{broadcast, RwLock, Mutex};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use prost::{Message as ProstMessage, bytes::Bytes};
|
use prost::{Message as ProstMessage, bytes::Bytes};
|
||||||
|
|
||||||
|
use crate::agent::AgentLoop;
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
use crate::channels::manager::MessageHandler;
|
|
||||||
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
||||||
|
|
||||||
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
|
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
|
||||||
@ -131,31 +134,25 @@ pub struct FeishuChannel {
|
|||||||
running: Arc<RwLock<bool>>,
|
running: Arc<RwLock<bool>>,
|
||||||
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
|
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
|
||||||
connected: Arc<RwLock<bool>>,
|
connected: Arc<RwLock<bool>>,
|
||||||
/// Message handler for routing messages to Gateway
|
/// Dedup: message_id -> timestamp (cleaned after 30 min)
|
||||||
message_handler: Arc<dyn MessageHandler>,
|
seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||||
}
|
/// Agent for processing messages
|
||||||
|
agent: Arc<Mutex<AgentLoop>>,
|
||||||
/// Parsed message data from a Feishu frame
|
|
||||||
struct ParsedMessage {
|
|
||||||
message_id: String,
|
|
||||||
open_id: String,
|
|
||||||
chat_id: String,
|
|
||||||
content: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeishuChannel {
|
impl FeishuChannel {
|
||||||
pub fn new(
|
pub fn new(config: FeishuChannelConfig, provider_config: LLMProviderConfig) -> Result<Self, ChannelError> {
|
||||||
config: FeishuChannelConfig,
|
let agent = AgentLoop::new(provider_config)
|
||||||
message_handler: Arc<dyn MessageHandler>,
|
.map_err(|e| ChannelError::Other(format!("Failed to create agent: {}", e)))?;
|
||||||
_provider_config: LLMProviderConfig,
|
|
||||||
) -> Result<Self, ChannelError> {
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
config,
|
config,
|
||||||
http_client: reqwest::Client::new(),
|
http_client: reqwest::Client::new(),
|
||||||
running: Arc::new(RwLock::new(false)),
|
running: Arc::new(RwLock::new(false)),
|
||||||
shutdown_tx: Arc::new(RwLock::new(None)),
|
shutdown_tx: Arc::new(RwLock::new(None)),
|
||||||
connected: Arc::new(RwLock::new(false)),
|
connected: Arc::new(RwLock::new(false)),
|
||||||
message_handler,
|
seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
agent: Arc::new(Mutex::new(agent)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,22 +259,23 @@ impl FeishuChannel {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle incoming message - delegate to message handler and send response
|
/// Handle incoming message - process through agent and send response
|
||||||
async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> {
|
async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> {
|
||||||
tracing::info!(open_id, chat_id, "Processing message from Feishu");
|
println!("Feishu: processing message from {} in chat {}: {}", open_id, chat_id, content);
|
||||||
|
|
||||||
// Delegate to message handler (Gateway)
|
// Process through agent
|
||||||
let response = self.message_handler
|
let user_msg = ChatMessage::user(content);
|
||||||
.handle_message("feishu", open_id, chat_id, content)
|
let mut agent = self.agent.lock().await;
|
||||||
.await?;
|
let response = agent.process(user_msg).await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Agent error: {}", e)))?;
|
||||||
|
|
||||||
// Send response to the chat
|
// Send response to the chat
|
||||||
// Use open_id for p2p chats, chat_id for group chats
|
// Use open_id for p2p chats, chat_id for group chats
|
||||||
let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id };
|
let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id };
|
||||||
let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
||||||
|
|
||||||
self.send_message(receive_id, receive_id_type, &response).await?;
|
self.send_message(receive_id, receive_id_type, &response.content).await?;
|
||||||
tracing::info!(receive_id, "Sent response to Feishu");
|
println!("Feishu: sent response to {}", receive_id);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -295,8 +293,8 @@ impl FeishuChannel {
|
|||||||
.unwrap_or(0)
|
.unwrap_or(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle incoming binary PbFrame - returns Some(ParsedMessage) if we need to ack
|
/// Handle incoming binary PbFrame - returns Some(message_id) if we need to ack
|
||||||
async fn handle_frame(&self, frame: &PbFrame) -> Result<Option<ParsedMessage>, ChannelError> {
|
async fn handle_frame(&self, frame: &PbFrame) -> Result<Option<String>, ChannelError> {
|
||||||
// method 0 = CONTROL (ping/pong)
|
// method 0 = CONTROL (ping/pong)
|
||||||
if frame.method == 0 {
|
if frame.method == 0 {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
@ -327,7 +325,20 @@ impl FeishuChannel {
|
|||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deduplication check with TTL cleanup
|
||||||
let message_id = payload_data.message.message_id.clone();
|
let message_id = payload_data.message.message_id.clone();
|
||||||
|
{
|
||||||
|
let mut seen = self.seen_ids.write().await;
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
// Clean expired entries (older than 30 min)
|
||||||
|
seen.retain(|_, ts| now.duration_since(*ts).as_secs() < 1800);
|
||||||
|
|
||||||
|
if seen.contains_key(&message_id) {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
seen.insert(message_id.clone(), now);
|
||||||
|
}
|
||||||
|
|
||||||
let open_id = payload_data.sender.sender_id.open_id
|
let open_id = payload_data.sender.sender_id.open_id
|
||||||
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
|
||||||
@ -337,12 +348,13 @@ impl FeishuChannel {
|
|||||||
let msg_type = msg.message_type.as_str();
|
let msg_type = msg.message_type.as_str();
|
||||||
let content = parse_message_content(msg_type, &msg.content);
|
let content = parse_message_content(msg_type, &msg.content);
|
||||||
|
|
||||||
Ok(Some(ParsedMessage {
|
// Handle the message - process and send response
|
||||||
message_id,
|
if let Err(e) = self.handle_message(&open_id, &chat_id, &content).await {
|
||||||
open_id,
|
eprintln!("Error handling message: {}", e);
|
||||||
chat_id,
|
}
|
||||||
content,
|
|
||||||
}))
|
// Return message_id for ack
|
||||||
|
Ok(Some(message_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send acknowledgment for a message
|
/// Send acknowledgment for a message
|
||||||
@ -363,14 +375,14 @@ impl FeishuChannel {
|
|||||||
let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?;
|
let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?;
|
||||||
|
|
||||||
let service_id = Self::extract_service_id(&wss_url);
|
let service_id = Self::extract_service_id(&wss_url);
|
||||||
tracing::info!(url = %wss_url, "Connecting to Feishu WebSocket");
|
println!("Feishu: connecting to {}", wss_url);
|
||||||
|
|
||||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url)
|
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)))?;
|
.map_err(|e| ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)))?;
|
||||||
|
|
||||||
*self.connected.write().await = true;
|
*self.connected.write().await = true;
|
||||||
tracing::info!("Feishu WebSocket connected");
|
println!("Feishu channel connected");
|
||||||
|
|
||||||
let (mut write, mut read) = ws_stream.split();
|
let (mut write, mut read) = ws_stream.split();
|
||||||
|
|
||||||
@ -404,25 +416,17 @@ impl FeishuChannel {
|
|||||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
|
||||||
let bytes: Bytes = data;
|
let bytes: Bytes = data;
|
||||||
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
|
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
|
||||||
// Parse the frame first
|
// Handle the frame and get message_id for ack if needed
|
||||||
match self.handle_frame(&frame).await {
|
match self.handle_frame(&frame).await {
|
||||||
Ok(Some(parsed)) => {
|
Ok(Some(_message_id)) => {
|
||||||
// Send ACK immediately (Feishu requires within 3 s)
|
// Send ACK immediately (Feishu requires within 3 s)
|
||||||
if let Err(e) = Self::send_ack(&frame, &mut write).await {
|
if let Err(e) = Self::send_ack(&frame, &mut write).await {
|
||||||
tracing::error!(error = %e, "Failed to send ACK to Feishu");
|
eprintln!("Error sending ack: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then process message asynchronously (don't await)
|
|
||||||
let channel = self.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
if let Err(e) = channel.handle_message(&parsed.open_id, &parsed.chat_id, &parsed.content).await {
|
|
||||||
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to handle Feishu message");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
Ok(None) => {}
|
Ok(None) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to parse Feishu frame");
|
eprintln!("Error handling frame: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -442,11 +446,10 @@ impl FeishuChannel {
|
|||||||
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
|
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
|
||||||
}
|
}
|
||||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
|
||||||
tracing::debug!("Feishu WebSocket closed");
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Some(Err(e)) => {
|
Some(Err(e)) => {
|
||||||
tracing::warn!(error = %e, "Feishu WebSocket error");
|
eprintln!("WS error: {}", e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
@ -466,12 +469,12 @@ impl FeishuChannel {
|
|||||||
payload: None,
|
payload: None,
|
||||||
};
|
};
|
||||||
if write.send(tokio_tungstenite::tungstenite::Message::Binary(ping.encode_to_vec().into())).await.is_err() {
|
if write.send(tokio_tungstenite::tungstenite::Message::Binary(ping.encode_to_vec().into())).await.is_err() {
|
||||||
tracing::warn!("Feishu ping failed, reconnecting");
|
eprintln!("Feishu: ping failed, reconnecting");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ = shutdown_rx.recv() => {
|
_ = shutdown_rx.recv() => {
|
||||||
tracing::info!("Feishu channel shutdown signal received");
|
println!("Feishu channel shutdown signal received");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -553,13 +556,13 @@ impl Channel for FeishuChannel {
|
|||||||
let shutdown_rx = shutdown_tx.subscribe();
|
let shutdown_rx = shutdown_tx.subscribe();
|
||||||
match channel.run_ws_loop(shutdown_rx).await {
|
match channel.run_ws_loop(shutdown_rx).await {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
tracing::info!("Feishu WebSocket disconnected");
|
println!("Feishu WebSocket disconnected");
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
consecutive_failures += 1;
|
consecutive_failures += 1;
|
||||||
tracing::error!(attempt = consecutive_failures, error = %e, "Feishu WebSocket error");
|
eprintln!("Feishu WebSocket error (attempt {}): {}", consecutive_failures, e);
|
||||||
if consecutive_failures >= max_failures {
|
if consecutive_failures >= max_failures {
|
||||||
tracing::error!("Feishu channel: max failures reached, stopping");
|
eprintln!("Feishu channel: max failures reached, stopping");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -569,15 +572,15 @@ impl Channel for FeishuChannel {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("Feishu channel retrying in 5s...");
|
println!("Feishu channel retrying in 5s...");
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
*channel.running.write().await = false;
|
*channel.running.write().await = false;
|
||||||
tracing::info!("Feishu channel stopped");
|
println!("Feishu channel stopped");
|
||||||
});
|
});
|
||||||
|
|
||||||
tracing::info!("Feishu channel started");
|
println!("Feishu channel started");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -592,7 +595,6 @@ impl Channel for FeishuChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn is_running(&self) -> bool {
|
fn is_running(&self) -> bool {
|
||||||
// Note: blocking read, acceptable for this use case
|
false
|
||||||
self.running.try_read().map(|r| *r).unwrap_or(false)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,60 +1,41 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use async_trait::async_trait;
|
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
use crate::channels::feishu::FeishuChannel;
|
use crate::channels::feishu::FeishuChannel;
|
||||||
use crate::config::{Config, FeishuChannelConfig};
|
use crate::config::Config;
|
||||||
|
|
||||||
/// MessageHandler trait - Channel 通过这个 trait 与业务逻辑解耦
|
|
||||||
#[async_trait]
|
|
||||||
pub trait MessageHandler: Send + Sync {
|
|
||||||
async fn handle_message(
|
|
||||||
&self,
|
|
||||||
channel_name: &str,
|
|
||||||
sender_id: &str,
|
|
||||||
chat_id: &str,
|
|
||||||
content: &str,
|
|
||||||
) -> Result<String, ChannelError>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ChannelManager 管理所有 Channel
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ChannelManager {
|
pub struct ChannelManager {
|
||||||
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
|
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>,
|
||||||
message_handler: Arc<dyn MessageHandler>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChannelManager {
|
impl ChannelManager {
|
||||||
pub fn new(message_handler: Arc<dyn MessageHandler>) -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||||
message_handler,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取 MessageHandler 用于让 Channel 调用
|
pub async fn init(&self, config: &Config) -> Result<(), ChannelError> {
|
||||||
pub fn get_handler(&self) -> Arc<dyn MessageHandler> {
|
|
||||||
self.message_handler.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 初始化所有 Channel
|
|
||||||
pub async fn init(&self, config: &Config, provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
|
|
||||||
// Initialize Feishu channel if enabled
|
// Initialize Feishu channel if enabled
|
||||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||||
if feishu_config.enabled {
|
if feishu_config.enabled {
|
||||||
let handler = self.get_handler();
|
let agent_name = &feishu_config.agent;
|
||||||
let channel = FeishuChannel::new(feishu_config.clone(), handler, provider_config)
|
let provider_config = config.get_provider_config(agent_name)
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to get provider config: {}", e)))?;
|
||||||
|
|
||||||
|
let channel = FeishuChannel::new(feishu_config.clone(), provider_config)
|
||||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
||||||
|
|
||||||
self.channels
|
self.channels
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert("feishu".to_string(), Arc::new(channel));
|
.insert("feishu".to_string(), Arc::new(channel));
|
||||||
tracing::info!("Feishu channel registered");
|
println!("Feishu channel registered");
|
||||||
} else {
|
} else {
|
||||||
tracing::info!("Feishu channel disabled in config");
|
println!("Feishu channel disabled in config");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -63,9 +44,11 @@ impl ChannelManager {
|
|||||||
pub async fn start_all(&self) -> Result<(), ChannelError> {
|
pub async fn start_all(&self) -> Result<(), ChannelError> {
|
||||||
let channels = self.channels.read().await;
|
let channels = self.channels.read().await;
|
||||||
for (name, channel) in channels.iter() {
|
for (name, channel) in channels.iter() {
|
||||||
tracing::info!(channel = %name, "Starting channel");
|
println!("Starting channel: {}", name);
|
||||||
if let Err(e) = channel.start().await {
|
if let Err(e) = channel.start().await {
|
||||||
tracing::error!(channel = %name, error = %e, "Failed to start channel");
|
eprintln!("Warning: Failed to start channel {}: {}", name, e);
|
||||||
|
// Channel failed to start - it should have logged why
|
||||||
|
// Continue starting other channels
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -74,44 +57,22 @@ impl ChannelManager {
|
|||||||
pub async fn stop_all(&self) -> Result<(), ChannelError> {
|
pub async fn stop_all(&self) -> Result<(), ChannelError> {
|
||||||
let mut channels = self.channels.write().await;
|
let mut channels = self.channels.write().await;
|
||||||
for (name, channel) in channels.iter() {
|
for (name, channel) in channels.iter() {
|
||||||
tracing::info!(channel = %name, "Stopping channel");
|
println!("Stopping channel: {}", name);
|
||||||
if let Err(e) = channel.stop().await {
|
if let Err(e) = channel.stop().await {
|
||||||
tracing::error!(channel = %name, error = %e, "Error stopping channel");
|
eprintln!("Error stopping channel {}: {}", name, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
channels.clear();
|
channels.clear();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel + Send + Sync>> {
|
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel>> {
|
||||||
self.channels.read().await.get(name).cloned()
|
self.channels.read().await.get(name).cloned()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gateway 实现 MessageHandler trait
|
impl Default for ChannelManager {
|
||||||
#[derive(Clone)]
|
fn default() -> Self {
|
||||||
pub struct GatewayMessageHandler {
|
Self::new()
|
||||||
session_manager: crate::gateway::session::SessionManager,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GatewayMessageHandler {
|
|
||||||
pub fn new(session_manager: crate::gateway::session::SessionManager) -> Self {
|
|
||||||
Self { session_manager }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl MessageHandler for GatewayMessageHandler {
|
|
||||||
async fn handle_message(
|
|
||||||
&self,
|
|
||||||
channel_name: &str,
|
|
||||||
sender_id: &str,
|
|
||||||
chat_id: &str,
|
|
||||||
content: &str,
|
|
||||||
) -> Result<String, ChannelError> {
|
|
||||||
self.session_manager
|
|
||||||
.handle_message(channel_name, sender_id, chat_id, content)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,7 +11,7 @@ fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
|||||||
|
|
||||||
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (ws_stream, _) = connect_async(gateway_url).await?;
|
let (ws_stream, _) = connect_async(gateway_url).await?;
|
||||||
tracing::info!(url = %gateway_url, "Connected to gateway");
|
println!("Connected to gateway");
|
||||||
|
|
||||||
let (mut sender, mut receiver) = ws_stream.split();
|
let (mut sender, mut receiver) = ws_stream.split();
|
||||||
|
|
||||||
@ -35,7 +35,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
input.write_output(&format!("Error: {}", message)).await?;
|
input.write_output(&format!("Error: {}", message)).await?;
|
||||||
}
|
}
|
||||||
WsOutbound::SessionEstablished { session_id } => {
|
WsOutbound::SessionEstablished { session_id } => {
|
||||||
tracing::debug!(session_id = %session_id, "Session established");
|
|
||||||
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
@ -43,7 +42,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(Ok(Message::Close(_))) | None => {
|
Some(Ok(Message::Close(_))) | None => {
|
||||||
tracing::info!("Gateway disconnected");
|
|
||||||
input.write_output("Gateway disconnected").await?;
|
input.write_output("Gateway disconnected").await?;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -60,7 +58,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
"__CLEAR__" => {
|
"__CLEAR__" => {
|
||||||
let inbound = WsInbound::ClearHistory { chat_id: None };
|
let inbound = WsInbound::ClearHistory;
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
}
|
}
|
||||||
@ -69,22 +67,17 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
let inbound = WsInbound::UserInput {
|
let inbound = WsInbound::UserInput { content: msg.content };
|
||||||
content: msg.content,
|
|
||||||
channel: None,
|
|
||||||
chat_id: None,
|
|
||||||
sender_id: None,
|
|
||||||
};
|
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
if sender.send(Message::Text(text.into())).await.is_err() {
|
if sender.send(Message::Text(text.into())).await.is_err() {
|
||||||
tracing::error!("Failed to send message to gateway");
|
eprintln!("Failed to send message");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(None) => break,
|
Ok(None) => break,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(error = %e, "Input error");
|
eprintln!("Input error: {}", e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -67,8 +67,6 @@ pub struct GatewayConfig {
|
|||||||
pub host: String,
|
pub host: String,
|
||||||
#[serde(default = "default_gateway_port")]
|
#[serde(default = "default_gateway_port")]
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
#[serde(default, rename = "session_ttl_hours")]
|
|
||||||
pub session_ttl_hours: Option<u64>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@ -94,7 +92,6 @@ impl Default for GatewayConfig {
|
|||||||
Self {
|
Self {
|
||||||
host: default_gateway_host(),
|
host: default_gateway_host(),
|
||||||
port: default_gateway_port(),
|
port: default_gateway_port(),
|
||||||
session_ttl_hours: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -122,7 +119,7 @@ pub struct LLMProviderConfig {
|
|||||||
|
|
||||||
fn get_default_config_path() -> PathBuf {
|
fn get_default_config_path() -> PathBuf {
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||||
home.join(".picobot").join("config.json")
|
home.join(".config").join("picobot").join("config.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@ -138,13 +135,13 @@ impl Config {
|
|||||||
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
|
fn load_from(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
load_env_file()?;
|
load_env_file()?;
|
||||||
let content = if path.exists() {
|
let content = if path.exists() {
|
||||||
tracing::info!(path = %path.display(), "Config loaded");
|
println!("Config loaded from: {}", path.display());
|
||||||
fs::read_to_string(path)?
|
fs::read_to_string(path)?
|
||||||
} else {
|
} else {
|
||||||
// Fallback to current directory
|
// Fallback to current directory
|
||||||
let fallback = Path::new("config.json");
|
let fallback = Path::new("config.json");
|
||||||
if fallback.exists() {
|
if fallback.exists() {
|
||||||
tracing::info!(path = %fallback.display(), "Config loaded from fallback path");
|
println!("Config loaded from: {}", fallback.display());
|
||||||
fs::read_to_string(fallback)?
|
fs::read_to_string(fallback)?
|
||||||
} else {
|
} else {
|
||||||
return Err(Box::new(ConfigError::ConfigNotFound(
|
return Err(Box::new(ConfigError::ConfigNotFound(
|
||||||
@ -192,7 +189,7 @@ pub enum ConfigError {
|
|||||||
impl std::fmt::Display for ConfigError {
|
impl std::fmt::Display for ConfigError {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
|
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.config/picobot/config.json", path),
|
||||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||||
|
|||||||
@ -6,9 +6,8 @@ use std::sync::Arc;
|
|||||||
use axum::{routing, Router};
|
use axum::{routing, Router};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
use crate::channels::{ChannelManager, manager::GatewayMessageHandler};
|
use crate::channels::ChannelManager;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::logging;
|
|
||||||
use session::SessionManager;
|
use session::SessionManager;
|
||||||
|
|
||||||
pub struct GatewayState {
|
pub struct GatewayState {
|
||||||
@ -20,37 +19,20 @@ pub struct GatewayState {
|
|||||||
impl GatewayState {
|
impl GatewayState {
|
||||||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
let config = Config::load_default()?;
|
let config = Config::load_default()?;
|
||||||
|
let channel_manager = ChannelManager::new();
|
||||||
// Get provider config for SessionManager
|
|
||||||
let provider_config = config.get_provider_config("default")?;
|
|
||||||
|
|
||||||
// Session TTL from config (default 4 hours)
|
|
||||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
|
||||||
|
|
||||||
let session_manager = SessionManager::new(session_ttl_hours, provider_config);
|
|
||||||
let message_handler = Arc::new(GatewayMessageHandler::new(session_manager.clone()));
|
|
||||||
let channel_manager = ChannelManager::new(message_handler);
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
config,
|
config,
|
||||||
session_manager,
|
session_manager: SessionManager::new(),
|
||||||
channel_manager,
|
channel_manager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// Initialize logging
|
|
||||||
logging::init_logging();
|
|
||||||
tracing::info!("Starting PicoBot Gateway");
|
|
||||||
|
|
||||||
let state = Arc::new(GatewayState::new()?);
|
let state = Arc::new(GatewayState::new()?);
|
||||||
|
|
||||||
// Get provider config for channels
|
|
||||||
let provider_config = state.config.get_provider_config("default")?;
|
|
||||||
|
|
||||||
// Initialize and start channels
|
// Initialize and start channels
|
||||||
state.channel_manager.init(&state.config, provider_config).await?;
|
state.channel_manager.init(&state.config).await?;
|
||||||
state.channel_manager.start_all().await?;
|
state.channel_manager.start_all().await?;
|
||||||
|
|
||||||
// CLI args override config file values
|
// CLI args override config file values
|
||||||
@ -64,7 +46,7 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
|
|||||||
|
|
||||||
let addr = format!("{}:{}", bind_host, bind_port);
|
let addr = format!("{}:{}", bind_host, bind_port);
|
||||||
let listener = TcpListener::bind(&addr).await?;
|
let listener = TcpListener::bind(&addr).await?;
|
||||||
tracing::info!(address = %addr, "Gateway listening");
|
println!("Gateway listening on {}", addr);
|
||||||
|
|
||||||
// Graceful shutdown using oneshot channel
|
// Graceful shutdown using oneshot channel
|
||||||
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
||||||
@ -73,7 +55,7 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
|
|||||||
// Spawn ctrl_c handler
|
// Spawn ctrl_c handler
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
tokio::signal::ctrl_c().await.ok();
|
tokio::signal::ctrl_c().await.ok();
|
||||||
tracing::info!("Shutdown signal received");
|
println!("Shutting down...");
|
||||||
let _ = channel_manager.stop_all().await;
|
let _ = channel_manager.stop_all().await;
|
||||||
let _ = shutdown_tx.send(());
|
let _ = shutdown_tx.send(());
|
||||||
});
|
});
|
||||||
|
|||||||
@ -1,204 +1,67 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use crate::bus::ChatMessage;
|
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::agent::{AgentLoop, AgentError};
|
use crate::agent::AgentLoop;
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::tools::{CalculatorTool, ToolRegistry};
|
|
||||||
|
|
||||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
pub id: Uuid,
|
pub id: Uuid,
|
||||||
pub channel_name: String,
|
pub agent_loop: Arc<Mutex<AgentLoop>>,
|
||||||
/// 按 chat_id 路由到不同 AgentLoop,支持多用户多会话
|
|
||||||
chat_agents: HashMap<String, Arc<Mutex<AgentLoop>>>,
|
|
||||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
tools: Arc<ToolRegistry>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
pub async fn new(
|
pub async fn new(
|
||||||
channel_name: String,
|
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
user_tx: mpsc::Sender<WsOutbound>,
|
user_tx: mpsc::Sender<WsOutbound>,
|
||||||
tools: Arc<ToolRegistry>,
|
) -> Result<Self, crate::agent::AgentError> {
|
||||||
) -> Result<Self, AgentError> {
|
let agent_loop = AgentLoop::new(provider_config)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
channel_name,
|
agent_loop: Arc::new(Mutex::new(agent_loop)),
|
||||||
chat_agents: HashMap::new(),
|
|
||||||
user_tx,
|
user_tx,
|
||||||
provider_config,
|
|
||||||
tools,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取或创建指定 chat_id 的 AgentLoop
|
|
||||||
pub async fn get_or_create_agent(&mut self, chat_id: &str) -> Result<Arc<Mutex<AgentLoop>>, AgentError> {
|
|
||||||
if let Some(agent) = self.chat_agents.get(chat_id) {
|
|
||||||
tracing::trace!(chat_id = %chat_id, "Reusing existing agent");
|
|
||||||
return Ok(agent.clone());
|
|
||||||
}
|
|
||||||
tracing::debug!(chat_id = %chat_id, "Creating new agent for chat");
|
|
||||||
let agent = AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())?;
|
|
||||||
let arc = Arc::new(Mutex::new(agent));
|
|
||||||
self.chat_agents.insert(chat_id.to_string(), arc.clone());
|
|
||||||
Ok(arc)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 获取指定 chat_id 的 AgentLoop(不创建)
|
|
||||||
pub fn get_agent(&self, chat_id: &str) -> Option<Arc<Mutex<AgentLoop>>> {
|
|
||||||
self.chat_agents.get(chat_id).cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 清除指定 chat_id 的历史
|
|
||||||
pub async fn clear_chat_history(&mut self, chat_id: &str) {
|
|
||||||
if let Some(agent) = self.chat_agents.get(chat_id) {
|
|
||||||
agent.lock().await.clear_history();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 清除所有历史
|
|
||||||
pub async fn clear_all_history(&mut self) {
|
|
||||||
for agent in self.chat_agents.values() {
|
|
||||||
agent.lock().await.clear_history();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send(&self, msg: WsOutbound) {
|
pub async fn send(&self, msg: WsOutbound) {
|
||||||
let _ = self.user_tx.send(msg).await;
|
let _ = self.user_tx.send(msg).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
use std::collections::HashMap;
|
||||||
/// 使用 Arc<Mutex<SessionManager>> 以从 Arc 获取可变访问
|
use std::sync::RwLock;
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct SessionManager {
|
pub struct SessionManager {
|
||||||
inner: Arc<Mutex<SessionManagerInner>>,
|
sessions: RwLock<HashMap<Uuid, Arc<Session>>>,
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
tools: Arc<ToolRegistry>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct SessionManagerInner {
|
|
||||||
sessions: HashMap<String, Arc<Mutex<Session>>>,
|
|
||||||
session_timestamps: HashMap<String, Instant>,
|
|
||||||
session_ttl: Duration,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_tools() -> ToolRegistry {
|
|
||||||
let mut registry = ToolRegistry::new();
|
|
||||||
registry.register(CalculatorTool::new());
|
|
||||||
registry
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionManager {
|
impl SessionManager {
|
||||||
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
sessions: RwLock::new(HashMap::new()),
|
||||||
sessions: HashMap::new(),
|
|
||||||
session_timestamps: HashMap::new(),
|
|
||||||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
|
||||||
})),
|
|
||||||
provider_config,
|
|
||||||
tools: Arc::new(default_tools()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tools(&self) -> Arc<ToolRegistry> {
|
pub fn add(&self, session: Arc<Session>) {
|
||||||
self.tools.clone()
|
self.sessions.write().unwrap().insert(session.id, session);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 确保 session 存在且未超时,超时则重建
|
pub fn remove(&self, id: &Uuid) {
|
||||||
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
self.sessions.write().unwrap().remove(id);
|
||||||
let mut inner = self.inner.lock().await;
|
|
||||||
|
|
||||||
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) {
|
|
||||||
let elapsed = last_active.elapsed();
|
|
||||||
if elapsed > inner.session_ttl {
|
|
||||||
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tracing::debug!(channel = %channel_name, "Creating new session");
|
|
||||||
true
|
|
||||||
};
|
|
||||||
|
|
||||||
if should_recreate {
|
|
||||||
// 移除旧 session
|
|
||||||
inner.sessions.remove(channel_name);
|
|
||||||
|
|
||||||
// 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS)
|
|
||||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
|
||||||
let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx, self.tools.clone()).await?;
|
|
||||||
let arc = Arc::new(Mutex::new(session));
|
|
||||||
|
|
||||||
inner.sessions.insert(channel_name.to_string(), arc.clone());
|
|
||||||
inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
pub fn get(&self, id: &Uuid) -> Option<Arc<Session>> {
|
||||||
|
self.sessions.read().unwrap().get(id).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取 session(不检查超时)
|
pub fn len(&self) -> usize {
|
||||||
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
|
self.sessions.read().unwrap().len()
|
||||||
let inner = self.inner.lock().await;
|
}
|
||||||
inner.sessions.get(channel_name).cloned()
|
}
|
||||||
}
|
|
||||||
|
impl Default for SessionManager {
|
||||||
/// 更新最后活跃时间
|
fn default() -> Self {
|
||||||
pub async fn touch(&self, channel_name: &str) {
|
Self::new()
|
||||||
let mut inner = self.inner.lock().await;
|
|
||||||
inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 处理消息:路由到对应 session 的 agent
|
|
||||||
pub async fn handle_message(
|
|
||||||
&self,
|
|
||||||
channel_name: &str,
|
|
||||||
_sender_id: &str,
|
|
||||||
chat_id: &str,
|
|
||||||
content: &str,
|
|
||||||
) -> Result<String, AgentError> {
|
|
||||||
tracing::debug!(channel = %channel_name, chat_id = %chat_id, content_len = content.len(), "Routing message to agent");
|
|
||||||
|
|
||||||
// 确保 session 存在(可能需要重建)
|
|
||||||
self.ensure_session(channel_name).await?;
|
|
||||||
|
|
||||||
// 更新活跃时间
|
|
||||||
self.touch(channel_name).await;
|
|
||||||
|
|
||||||
// 获取 session
|
|
||||||
let session = self.get(channel_name).await
|
|
||||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
|
||||||
|
|
||||||
// 获取或创建 chat_id 对应的 agent
|
|
||||||
let mut session_guard = session.lock().await;
|
|
||||||
let agent = session_guard.get_or_create_agent(chat_id).await?;
|
|
||||||
drop(session_guard);
|
|
||||||
|
|
||||||
let mut agent = agent.lock().await;
|
|
||||||
|
|
||||||
// 处理消息
|
|
||||||
let user_msg = ChatMessage::user(content);
|
|
||||||
let response = agent.process(user_msg).await?;
|
|
||||||
|
|
||||||
tracing::debug!(channel = %channel_name, chat_id = %chat_id, response_len = response.content.len(), "Agent response received");
|
|
||||||
|
|
||||||
Ok(response.content)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 清除指定 session 的所有历史
|
|
||||||
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
|
|
||||||
if let Some(session) = self.get(channel_name).await {
|
|
||||||
let mut session_guard = session.lock().await;
|
|
||||||
session_guard.clear_all_history().await;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,7 +3,7 @@ use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
|||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::mpsc;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
|
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
|
||||||
use super::{GatewayState, session::Session};
|
use super::{GatewayState, session::Session};
|
||||||
@ -20,39 +20,33 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
let provider_config = match state.config.get_provider_config("default") {
|
let provider_config = match state.config.get_provider_config("default") {
|
||||||
Ok(cfg) => cfg,
|
Ok(cfg) => cfg,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(error = %e, "Failed to get provider config");
|
eprintln!("Failed to get provider config: {}", e);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// CLI 使用独立的 session,channel_name = "cli-{uuid}"
|
let session = match Session::new(provider_config, sender).await {
|
||||||
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
|
Ok(s) => Arc::new(s),
|
||||||
|
|
||||||
// 创建 CLI session
|
|
||||||
let session = match Session::new(channel_name.clone(), provider_config, sender, state.session_manager.tools()).await {
|
|
||||||
Ok(s) => Arc::new(Mutex::new(s)),
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(error = %e, "Failed to create session");
|
eprintln!("Failed to create session: {}", e);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let session_id = session.lock().await.id;
|
let session_id = session.id;
|
||||||
tracing::info!(session_id = %session_id, "CLI session established");
|
state.session_manager.add(session.clone());
|
||||||
|
|
||||||
let _ = session.lock().await.send(WsOutbound::SessionEstablished {
|
let _ = session.send(WsOutbound::SessionEstablished {
|
||||||
session_id: session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
}).await;
|
}).await;
|
||||||
|
|
||||||
let (mut ws_sender, mut ws_receiver) = ws.split();
|
let (mut ws_sender, mut ws_receiver) = ws.split();
|
||||||
|
|
||||||
let mut receiver = receiver;
|
let mut receiver = receiver;
|
||||||
let session_id_for_sender = session_id;
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
if let Ok(text) = serialize_outbound(&msg) {
|
if let Ok(text) = serialize_outbound(&msg) {
|
||||||
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||||||
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -68,8 +62,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
handle_inbound(&session, inbound).await;
|
handle_inbound(&session, inbound).await;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
let _ = session.send(WsOutbound::Error {
|
||||||
let _ = session.lock().await.send(WsOutbound::Error {
|
|
||||||
code: "PARSE_ERROR".to_string(),
|
code: "PARSE_ERROR".to_string(),
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
}).await;
|
}).await;
|
||||||
@ -77,62 +70,47 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(WsMessage::Close(_)) | Err(_) => {
|
Ok(WsMessage::Close(_)) | Err(_) => {
|
||||||
tracing::debug!(session_id = %session_id, "WebSocket closed");
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!(session_id = %session_id, "CLI session ended");
|
state.session_manager.remove(&session_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
async fn handle_inbound(session: &Arc<Session>, inbound: WsInbound) {
|
||||||
let inbound_clone = inbound.clone();
|
match inbound {
|
||||||
|
WsInbound::UserInput { content } => {
|
||||||
// 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id)
|
|
||||||
let (content, chat_id) = match inbound_clone {
|
|
||||||
WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => {
|
|
||||||
// CLI 使用 session 中的 channel_name 作为标识
|
|
||||||
// chat_id 使用传入的或使用默认
|
|
||||||
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
|
|
||||||
(content, chat_id)
|
|
||||||
}
|
|
||||||
_ => return,
|
|
||||||
};
|
|
||||||
|
|
||||||
let user_msg = ChatMessage::user(content);
|
let user_msg = ChatMessage::user(content);
|
||||||
|
let mut agent = session.agent_loop.lock().await;
|
||||||
let mut session_guard = session.lock().await;
|
|
||||||
let agent = match session_guard.get_or_create_agent(&chat_id).await {
|
|
||||||
Ok(a) => a,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(chat_id = %chat_id, error = %e, "Failed to get or create agent");
|
|
||||||
let _ = session_guard.send(WsOutbound::Error {
|
|
||||||
code: "AGENT_ERROR".to_string(),
|
|
||||||
message: e.to_string(),
|
|
||||||
}).await;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
drop(session_guard);
|
|
||||||
|
|
||||||
let mut agent = agent.lock().await;
|
|
||||||
match agent.process(user_msg).await {
|
match agent.process(user_msg).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
tracing::debug!(chat_id = %chat_id, "Agent response sent");
|
let _ = session.send(WsOutbound::AssistantResponse {
|
||||||
let _ = session.lock().await.send(WsOutbound::AssistantResponse {
|
|
||||||
id: response.id,
|
id: response.id,
|
||||||
content: response.content,
|
content: response.content,
|
||||||
role: response.role,
|
role: response.role,
|
||||||
}).await;
|
}).await;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(chat_id = %chat_id, error = %e, "Agent process error");
|
let _ = session.send(WsOutbound::Error {
|
||||||
let _ = session.lock().await.send(WsOutbound::Error {
|
|
||||||
code: "LLM_ERROR".to_string(),
|
code: "LLM_ERROR".to_string(),
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
}).await;
|
}).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
WsInbound::ClearHistory => {
|
||||||
|
let mut agent = session.agent_loop.lock().await;
|
||||||
|
agent.clear_history();
|
||||||
|
let _ = session.send(WsOutbound::AssistantResponse {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
content: "History cleared.".to_string(),
|
||||||
|
role: "system".to_string(),
|
||||||
|
}).await;
|
||||||
|
}
|
||||||
|
WsInbound::Ping => {
|
||||||
|
let _ = session.send(WsOutbound::Pong).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -7,5 +7,3 @@ pub mod gateway;
|
|||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
pub mod channels;
|
pub mod channels;
|
||||||
pub mod logging;
|
|
||||||
pub mod tools;
|
|
||||||
|
|||||||
@ -1,80 +0,0 @@
|
|||||||
use std::path::PathBuf;
|
|
||||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
|
||||||
use tracing_subscriber::{
|
|
||||||
fmt,
|
|
||||||
layer::SubscriberExt,
|
|
||||||
util::SubscriberInitExt,
|
|
||||||
EnvFilter,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Get the default log directory path: ~/.picobot/logs
|
|
||||||
pub fn get_default_log_dir() -> PathBuf {
|
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
|
||||||
home.join(".picobot").join("logs")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the default config file path: ~/.picobot/config.json
|
|
||||||
pub fn get_default_config_path() -> PathBuf {
|
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
|
||||||
home.join(".picobot").join("config.json")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Initialize logging with file appender
|
|
||||||
/// Logs are written to ~/.picobot/logs/ with daily rotation
|
|
||||||
pub fn init_logging() {
|
|
||||||
let log_dir = get_default_log_dir();
|
|
||||||
|
|
||||||
// Create log directory if it doesn't exist
|
|
||||||
if !log_dir.exists() {
|
|
||||||
if let Err(e) = std::fs::create_dir_all(&log_dir) {
|
|
||||||
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create file appender with daily rotation
|
|
||||||
let file_appender = RollingFileAppender::new(
|
|
||||||
Rotation::DAILY,
|
|
||||||
&log_dir,
|
|
||||||
"picobot.log",
|
|
||||||
);
|
|
||||||
|
|
||||||
// Build subscriber with both console and file output
|
|
||||||
let env_filter = EnvFilter::try_from_default_env()
|
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
|
||||||
|
|
||||||
let file_layer = fmt::layer()
|
|
||||||
.with_writer(file_appender)
|
|
||||||
.with_ansi(false)
|
|
||||||
.with_target(true)
|
|
||||||
.with_level(true)
|
|
||||||
.with_thread_ids(true);
|
|
||||||
|
|
||||||
let console_layer = fmt::layer()
|
|
||||||
.with_target(true)
|
|
||||||
.with_level(true);
|
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
|
||||||
.with(env_filter)
|
|
||||||
.with(console_layer)
|
|
||||||
.with(file_layer)
|
|
||||||
.init();
|
|
||||||
|
|
||||||
tracing::info!("Logging initialized. Log directory: {}", log_dir.display());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Initialize logging without file output (console only)
|
|
||||||
pub fn init_logging_console_only() {
|
|
||||||
let env_filter = EnvFilter::try_from_default_env()
|
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
|
||||||
|
|
||||||
let console_layer = fmt::layer()
|
|
||||||
.with_target(true)
|
|
||||||
.with_level(true);
|
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
|
||||||
.with(env_filter)
|
|
||||||
.with(console_layer)
|
|
||||||
.init();
|
|
||||||
|
|
||||||
tracing::info!("Logging initialized (console only)");
|
|
||||||
}
|
|
||||||
@ -4,20 +4,9 @@ use serde::{Deserialize, Serialize};
|
|||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum WsInbound {
|
pub enum WsInbound {
|
||||||
#[serde(rename = "user_input")]
|
#[serde(rename = "user_input")]
|
||||||
UserInput {
|
UserInput { content: String },
|
||||||
content: String,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
channel: Option<String>,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
chat_id: Option<String>,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
sender_id: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "clear_history")]
|
#[serde(rename = "clear_history")]
|
||||||
ClearHistory {
|
ClearHistory,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
chat_id: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "ping")]
|
#[serde(rename = "ping")]
|
||||||
Ping,
|
Ping,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -104,19 +104,10 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
"messages": request.messages.iter().map(|m| {
|
"messages": request.messages.iter().map(|m| {
|
||||||
if m.role == "tool" {
|
|
||||||
json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content,
|
|
||||||
"tool_call_id": m.tool_call_id,
|
|
||||||
"name": m.name,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
json!({
|
json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": m.content
|
"content": m.content
|
||||||
})
|
})
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>(),
|
}).collect::<Vec<_>>(),
|
||||||
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
||||||
"max_tokens": request.max_tokens.or(self.max_tokens),
|
"max_tokens": request.max_tokens.or(self.max_tokens),
|
||||||
|
|||||||
@ -5,10 +5,6 @@ use serde::{Deserialize, Serialize};
|
|||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_call_id: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub name: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|||||||
@ -1,824 +0,0 @@
|
|||||||
use super::traits::{Tool, ToolResult};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
pub struct CalculatorTool;
|
|
||||||
|
|
||||||
impl CalculatorTool {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for CalculatorTool {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for CalculatorTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"calculator"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Perform arithmetic and statistical calculations. Supports 25 functions: \
|
|
||||||
add, subtract, divide, multiply, pow, sqrt, abs, modulo, round, \
|
|
||||||
log, ln, exp, factorial, sum, average, median, mode, min, max, \
|
|
||||||
range, variance, stdev, percentile, count, percentage_change, clamp. \
|
|
||||||
Use this tool whenever you need to compute a numeric result instead of guessing."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"function": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Calculation to perform. \
|
|
||||||
Arithmetic: add(values), subtract(values), divide(values), multiply(values), pow(a,b), sqrt(x), abs(x), modulo(a,b), round(x,decimals). \
|
|
||||||
Logarithmic/exponential: log(x,base?), ln(x), exp(x), factorial(x). \
|
|
||||||
Aggregation: sum(values), average(values), count(values), min(values), max(values), range(values). \
|
|
||||||
Statistics: median(values), mode(values), variance(values), stdev(values), percentile(values,p). \
|
|
||||||
Utility: percentage_change(a,b), clamp(x,min_val,max_val).",
|
|
||||||
"enum": [
|
|
||||||
"add", "subtract", "divide", "multiply", "pow", "sqrt",
|
|
||||||
"abs", "modulo", "round", "log", "ln", "exp", "factorial",
|
|
||||||
"sum", "average", "median", "mode", "min", "max", "range",
|
|
||||||
"variance", "stdev", "percentile", "count",
|
|
||||||
"percentage_change", "clamp"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"values": {
|
|
||||||
"type": "array",
|
|
||||||
"items": { "type": "number" },
|
|
||||||
"description": "Array of numeric values. Required for: add, subtract, divide, multiply, sum, average, median, mode, min, max, range, variance, stdev, percentile, count."
|
|
||||||
},
|
|
||||||
"a": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "First operand. Required for: pow, modulo, percentage_change."
|
|
||||||
},
|
|
||||||
"b": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "Second operand. Required for: pow, modulo, percentage_change."
|
|
||||||
},
|
|
||||||
"x": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "Input number. Required for: sqrt, abs, exp, ln, log, factorial."
|
|
||||||
},
|
|
||||||
"base": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "Logarithm base (default: 10). Optional for: log."
|
|
||||||
},
|
|
||||||
"decimals": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Number of decimal places for rounding. Required for: round."
|
|
||||||
},
|
|
||||||
"p": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Percentile rank (0-100). Required for: percentile."
|
|
||||||
},
|
|
||||||
"min_val": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "Minimum bound. Required for: clamp."
|
|
||||||
},
|
|
||||||
"max_val": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "Maximum bound. Required for: clamp."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["function"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let function = match args.get("function").and_then(|v| v.as_str()) {
|
|
||||||
Some(f) => f,
|
|
||||||
None => {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("Missing required parameter: function".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = match function {
|
|
||||||
"add" => calc_add(&args),
|
|
||||||
"subtract" => calc_subtract(&args),
|
|
||||||
"divide" => calc_divide(&args),
|
|
||||||
"multiply" => calc_multiply(&args),
|
|
||||||
"pow" => calc_pow(&args),
|
|
||||||
"sqrt" => calc_sqrt(&args),
|
|
||||||
"abs" => calc_abs(&args),
|
|
||||||
"modulo" => calc_modulo(&args),
|
|
||||||
"round" => calc_round(&args),
|
|
||||||
"log" => calc_log(&args),
|
|
||||||
"ln" => calc_ln(&args),
|
|
||||||
"exp" => calc_exp(&args),
|
|
||||||
"factorial" => calc_factorial(&args),
|
|
||||||
"sum" => calc_sum(&args),
|
|
||||||
"average" => calc_average(&args),
|
|
||||||
"median" => calc_median(&args),
|
|
||||||
"mode" => calc_mode(&args),
|
|
||||||
"min" => calc_min(&args),
|
|
||||||
"max" => calc_max(&args),
|
|
||||||
"range" => calc_range(&args),
|
|
||||||
"variance" => calc_variance(&args),
|
|
||||||
"stdev" => calc_stdev(&args),
|
|
||||||
"percentile" => calc_percentile(&args),
|
|
||||||
"count" => calc_count(&args),
|
|
||||||
"percentage_change" => calc_percentage_change(&args),
|
|
||||||
"clamp" => calc_clamp(&args),
|
|
||||||
other => Err(format!("Unknown function: {other}")),
|
|
||||||
};
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(output) => Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
}),
|
|
||||||
Err(err) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(err),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_f64(args: &serde_json::Value, key: &str, name: &str) -> Result<f64, String> {
|
|
||||||
args.get(key)
|
|
||||||
.and_then(|v| v.as_f64())
|
|
||||||
.ok_or_else(|| format!("Missing required parameter: {name}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_i64(args: &serde_json::Value, key: &str, name: &str) -> Result<i64, String> {
|
|
||||||
args.get(key)
|
|
||||||
.and_then(|v| v.as_i64())
|
|
||||||
.ok_or_else(|| format!("Missing required parameter: {name}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_values(args: &serde_json::Value, min_len: usize) -> Result<Vec<f64>, String> {
|
|
||||||
let values = args
|
|
||||||
.get("values")
|
|
||||||
.and_then(|v| v.as_array())
|
|
||||||
.ok_or_else(|| "Missing required parameter: values (array of numbers)".to_string())?;
|
|
||||||
if values.len() < min_len {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected at least {min_len} value(s), got {}",
|
|
||||||
values.len()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
let mut nums = Vec::with_capacity(values.len());
|
|
||||||
for (i, v) in values.iter().enumerate() {
|
|
||||||
match v.as_f64() {
|
|
||||||
Some(n) => nums.push(n),
|
|
||||||
None => return Err(format!("values[{i}] is not a valid number")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(nums)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_num(n: f64) -> String {
|
|
||||||
if n == n.floor() && n.abs() < 1e15 {
|
|
||||||
#[allow(clippy::cast_possible_truncation)]
|
|
||||||
let rounded = n.round() as i128;
|
|
||||||
format!("{rounded}")
|
|
||||||
} else {
|
|
||||||
format!("{n}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_add(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 2)?;
|
|
||||||
Ok(format_num(values.iter().sum()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_subtract(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 2)?;
|
|
||||||
let mut iter = values.iter();
|
|
||||||
let mut result = *iter.next().unwrap();
|
|
||||||
for v in iter {
|
|
||||||
result -= v;
|
|
||||||
}
|
|
||||||
Ok(format_num(result))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_divide(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 2)?;
|
|
||||||
let mut iter = values.iter();
|
|
||||||
let mut result = *iter.next().unwrap();
|
|
||||||
for v in iter {
|
|
||||||
if *v == 0.0 {
|
|
||||||
return Err("Division by zero".to_string());
|
|
||||||
}
|
|
||||||
result /= v;
|
|
||||||
}
|
|
||||||
Ok(format_num(result))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_multiply(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 2)?;
|
|
||||||
let mut result = 1.0;
|
|
||||||
for v in &values {
|
|
||||||
result *= v;
|
|
||||||
}
|
|
||||||
Ok(format_num(result))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_pow(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let base = extract_f64(args, "a", "a (base)")?;
|
|
||||||
let exp = extract_f64(args, "b", "b (exponent)")?;
|
|
||||||
Ok(format_num(base.powf(exp)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_sqrt(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let x = extract_f64(args, "x", "x")?;
|
|
||||||
if x < 0.0 {
|
|
||||||
return Err("Cannot compute square root of a negative number".to_string());
|
|
||||||
}
|
|
||||||
Ok(format_num(x.sqrt()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_abs(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let x = extract_f64(args, "x", "x")?;
|
|
||||||
Ok(format_num(x.abs()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_modulo(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let a = extract_f64(args, "a", "a")?;
|
|
||||||
let b = extract_f64(args, "b", "b")?;
|
|
||||||
if b == 0.0 {
|
|
||||||
return Err("Modulo by zero".to_string());
|
|
||||||
}
|
|
||||||
Ok(format_num(a % b))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_round(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let x = extract_f64(args, "x", "x")?;
|
|
||||||
let decimals = extract_i64(args, "decimals", "decimals")?;
|
|
||||||
if decimals < 0 {
|
|
||||||
return Err("decimals must be non-negative".to_string());
|
|
||||||
}
|
|
||||||
let multiplier = 10_f64.powi(i32::try_from(decimals).unwrap_or(i32::MAX));
|
|
||||||
Ok(format_num((x * multiplier).round() / multiplier))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_log(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let x = extract_f64(args, "x", "x")?;
|
|
||||||
if x <= 0.0 {
|
|
||||||
return Err("Logarithm requires a positive number".to_string());
|
|
||||||
}
|
|
||||||
let base = args.get("base").and_then(|v| v.as_f64()).unwrap_or(10.0);
|
|
||||||
if base <= 0.0 || base == 1.0 {
|
|
||||||
return Err("Logarithm base must be positive and not equal to 1".to_string());
|
|
||||||
}
|
|
||||||
Ok(format_num(x.log(base)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_ln(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let x = extract_f64(args, "x", "x")?;
|
|
||||||
if x <= 0.0 {
|
|
||||||
return Err("Natural logarithm requires a positive number".to_string());
|
|
||||||
}
|
|
||||||
Ok(format_num(x.ln()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_exp(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let x = extract_f64(args, "x", "x")?;
|
|
||||||
Ok(format_num(x.exp()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_factorial(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let x = extract_f64(args, "x", "x")?;
|
|
||||||
if x < 0.0 || x != x.floor() {
|
|
||||||
return Err("Factorial requires a non-negative integer".to_string());
|
|
||||||
}
|
|
||||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
|
||||||
let n = x.round() as u128;
|
|
||||||
if n > 170 {
|
|
||||||
return Err("Factorial result exceeds f64 range (max input: 170)".to_string());
|
|
||||||
}
|
|
||||||
let mut result: u128 = 1;
|
|
||||||
for i in 2..=n {
|
|
||||||
result *= i;
|
|
||||||
}
|
|
||||||
Ok(result.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_sum(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
Ok(format_num(values.iter().sum()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_average(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
if values.is_empty() {
|
|
||||||
return Err("Cannot compute average of an empty array".to_string());
|
|
||||||
}
|
|
||||||
Ok(format_num(values.iter().sum::<f64>() / values.len() as f64))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_median(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let mut values = extract_values(args, 1)?;
|
|
||||||
if values.is_empty() {
|
|
||||||
return Err("Cannot compute median of an empty array".to_string());
|
|
||||||
}
|
|
||||||
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
||||||
let len = values.len();
|
|
||||||
if len % 2 == 0 {
|
|
||||||
Ok(format_num(f64::midpoint(
|
|
||||||
values[len / 2 - 1],
|
|
||||||
values[len / 2],
|
|
||||||
)))
|
|
||||||
} else {
|
|
||||||
Ok(format_num(values[len / 2]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_mode(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
if values.is_empty() {
|
|
||||||
return Err("Cannot compute mode of an empty array".to_string());
|
|
||||||
}
|
|
||||||
let mut freq: std::collections::HashMap<u64, usize> = std::collections::HashMap::new();
|
|
||||||
for &v in &values {
|
|
||||||
let key = v.to_bits();
|
|
||||||
*freq.entry(key).or_insert(0) += 1;
|
|
||||||
}
|
|
||||||
let max_freq = *freq.values().max().unwrap();
|
|
||||||
let mut seen = std::collections::HashSet::new();
|
|
||||||
let mut modes = Vec::new();
|
|
||||||
for &v in &values {
|
|
||||||
let key = v.to_bits();
|
|
||||||
if freq[&key] == max_freq && seen.insert(key) {
|
|
||||||
modes.push(v);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if modes.len() == 1 {
|
|
||||||
Ok(format_num(modes[0]))
|
|
||||||
} else {
|
|
||||||
let formatted: Vec<String> = modes.iter().map(|v| format_num(*v)).collect();
|
|
||||||
Ok(format!("Modes: {}", formatted.join(", ")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_min(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
let Some(min_val) = values.iter().copied().reduce(f64::min) else {
|
|
||||||
return Err("Cannot compute min of an empty array".to_string());
|
|
||||||
};
|
|
||||||
Ok(format_num(min_val))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_max(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
let Some(max_val) = values.iter().copied().reduce(f64::max) else {
|
|
||||||
return Err("Cannot compute max of an empty array".to_string());
|
|
||||||
};
|
|
||||||
Ok(format_num(max_val))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_range(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
if values.is_empty() {
|
|
||||||
return Err("Cannot compute range of an empty array".to_string());
|
|
||||||
}
|
|
||||||
let min_val = values.iter().copied().fold(f64::INFINITY, f64::min);
|
|
||||||
let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
|
|
||||||
Ok(format_num(max_val - min_val))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_variance(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
if values.len() < 2 {
|
|
||||||
return Err("Variance requires at least 2 values".to_string());
|
|
||||||
}
|
|
||||||
let mean = values.iter().sum::<f64>() / values.len() as f64;
|
|
||||||
let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
|
|
||||||
Ok(format_num(variance))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_stdev(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
if values.len() < 2 {
|
|
||||||
return Err("Standard deviation requires at least 2 values".to_string());
|
|
||||||
}
|
|
||||||
let mean = values.iter().sum::<f64>() / values.len() as f64;
|
|
||||||
let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
|
|
||||||
Ok(format_num(variance.sqrt()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_percentile(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let mut values = extract_values(args, 1)?;
|
|
||||||
if values.is_empty() {
|
|
||||||
return Err("Cannot compute percentile of an empty array".to_string());
|
|
||||||
}
|
|
||||||
let p = extract_i64(args, "p", "p (percentile rank 0-100)")?;
|
|
||||||
if !(0..=100).contains(&p) {
|
|
||||||
return Err("Percentile rank must be between 0 and 100".to_string());
|
|
||||||
}
|
|
||||||
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
||||||
|
|
||||||
let idx_f = p as f64 / 100.0 * (values.len() - 1) as f64;
|
|
||||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
|
||||||
let index = idx_f.round().clamp(0.0, (values.len() - 1) as f64) as usize;
|
|
||||||
Ok(format_num(values[index]))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_count(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let values = extract_values(args, 1)?;
|
|
||||||
Ok(values.len().to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_percentage_change(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let old = extract_f64(args, "a", "a (old value)")?;
|
|
||||||
let new = extract_f64(args, "b", "b (new value)")?;
|
|
||||||
if old == 0.0 {
|
|
||||||
return Err("Cannot compute percentage change from zero".to_string());
|
|
||||||
}
|
|
||||||
Ok(format_num((new - old) / old.abs() * 100.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_clamp(args: &serde_json::Value) -> Result<String, String> {
|
|
||||||
let x = extract_f64(args, "x", "x")?;
|
|
||||||
let min_val = extract_f64(args, "min_val", "min_val")?;
|
|
||||||
let max_val = extract_f64(args, "max_val", "max_val")?;
|
|
||||||
if min_val > max_val {
|
|
||||||
return Err("min_val must be less than or equal to max_val".to_string());
|
|
||||||
}
|
|
||||||
Ok(format_num(x.clamp(min_val, max_val)))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_add() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "add", "values": [1.0, 2.0, 3.5]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "6.5");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_subtract() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "subtract", "values": [10.0, 3.0, 1.5]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "5.5");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_divide() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "divide", "values": [100.0, 4.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "25");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_divide_by_zero() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "divide", "values": [10.0, 0.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(!result.success);
|
|
||||||
assert!(result.error.as_ref().unwrap().contains("zero"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_multiply() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "multiply", "values": [3.0, 4.0, 5.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "60");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_pow() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "pow", "a": 2.0, "b": 10.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "1024");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_sqrt() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "sqrt", "x": 144.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "12");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_sqrt_negative() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "sqrt", "x": -4.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(!result.success);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_abs() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "abs", "x": -42.5}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "42.5");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_modulo() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "modulo", "a": 17.0, "b": 5.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "2");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_round() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "round", "x": 2.715, "decimals": 2}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "2.72");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_log_base10() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "log", "x": 100.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "2");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_log_custom_base() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "log", "x": 8.0, "base": 2.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "3");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_ln() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "ln", "x": 1.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "0");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_exp() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "exp", "x": 0.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "1");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_factorial() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "factorial", "x": 5.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "120");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_average() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "average", "values": [10.0, 20.0, 30.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "20");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_median_odd() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "median", "values": [3.0, 1.0, 2.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "2");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_median_even() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "median", "values": [4.0, 1.0, 3.0, 2.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "2.5");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_mode() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "mode", "values": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "3");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_min() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "min", "values": [5.0, 2.0, 8.0, 1.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "1");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_max() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "max", "values": [5.0, 2.0, 8.0, 1.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "8");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_range() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "range", "values": [1.0, 5.0, 10.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "9");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_variance() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(
|
|
||||||
json!({"function": "variance", "values": [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "4");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_stdev() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(
|
|
||||||
json!({"function": "stdev", "values": [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "2");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_percentile_50() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(
|
|
||||||
json!({"function": "percentile", "values": [1.0, 2.0, 3.0, 4.0, 5.0], "p": 50}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "3");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_count() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "count", "values": [1.0, 2.0, 3.0, 4.0, 5.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "5");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_percentage_change() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "percentage_change", "a": 50.0, "b": 75.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "50");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_clamp_within_range() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "clamp", "x": 5.0, "min_val": 1.0, "max_val": 10.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "5");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_clamp_below_min() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "clamp", "x": -5.0, "min_val": 0.0, "max_val": 10.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "0");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_clamp_above_max() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "clamp", "x": 15.0, "min_val": 0.0, "max_val": 10.0}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "10");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_unknown_function() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool.execute(json!({"function": "unknown"})).await.unwrap();
|
|
||||||
assert!(!result.success);
|
|
||||||
assert!(result.error.as_ref().unwrap().contains("Unknown function"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_sum() {
|
|
||||||
let tool = CalculatorTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({"function": "sum", "values": [1.0, 2.0, 3.0, 4.0, 5.0]}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result.success);
|
|
||||||
assert_eq!(result.output, "15");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,7 +0,0 @@
|
|||||||
pub mod calculator;
|
|
||||||
pub mod registry;
|
|
||||||
pub mod traits;
|
|
||||||
|
|
||||||
pub use calculator::CalculatorTool;
|
|
||||||
pub use registry::ToolRegistry;
|
|
||||||
pub use traits::{Tool, ToolResult};
|
|
||||||
@ -1,53 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use crate::providers::{Tool, ToolFunction};
|
|
||||||
|
|
||||||
use super::traits::Tool as ToolTrait;
|
|
||||||
|
|
||||||
pub struct ToolRegistry {
|
|
||||||
tools: HashMap<String, Box<dyn ToolTrait>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ToolRegistry {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
tools: HashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn register<T: ToolTrait + 'static>(&mut self, tool: T) {
|
|
||||||
self.tools.insert(tool.name().to_string(), Box::new(tool));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get(&self, name: &str) -> Option<&Box<dyn ToolTrait>> {
|
|
||||||
self.tools.get(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_definitions(&self) -> Vec<Tool> {
|
|
||||||
self.tools
|
|
||||||
.values()
|
|
||||||
.map(|tool| Tool {
|
|
||||||
tool_type: "function".to_string(),
|
|
||||||
function: ToolFunction {
|
|
||||||
name: tool.name().to_string(),
|
|
||||||
description: tool.description().to_string(),
|
|
||||||
parameters: tool.parameters_schema(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn has_tools(&self) -> bool {
|
|
||||||
!self.tools.is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tool_names(&self) -> Vec<String> {
|
|
||||||
self.tools.keys().cloned().collect()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for ToolRegistry {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,16 +0,0 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ToolResult {
|
|
||||||
pub success: bool,
|
|
||||||
pub output: String,
|
|
||||||
pub error: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait Tool: Send + Sync + 'static {
|
|
||||||
fn name(&self) -> &str;
|
|
||||||
fn description(&self) -> &str;
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value;
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
|
|
||||||
}
|
|
||||||
Loading…
x
Reference in New Issue
Block a user