xiaoxixi 9834bd75cf feat: add calculator tool and integrate with agent loop
- Introduced a new CalculatorTool for performing various arithmetic and statistical calculations.
- Enhanced the AgentLoop to support tool execution, including handling tool calls in the process flow.
- Updated ChatMessage structure to include optional fields for tool call identification and names.
- Modified the Session and SessionManager to manage tool registrations and pass them to agents.
- Updated the OpenAIProvider to serialize tool-related message fields.
- Added a ToolRegistry for managing multiple tools and their definitions.
- Implemented tests for the CalculatorTool to ensure functionality and correctness.
2026-04-06 23:43:45 +08:00

139 lines
4.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::Arc;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
use super::{GatewayState, session::Session};
/// Route handler that accepts a WebSocket upgrade request.
///
/// After the upgrade completes, the resulting socket is handed to
/// `handle_socket` together with the shared gateway state.
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
    // `handle_socket` already returns the future that drives the connection,
    // so it can be used as the upgrade callback's future directly.
    ws.on_upgrade(move |socket| handle_socket(socket, state))
}
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let provider_config = match state.config.get_provider_config("default") {
Ok(cfg) => cfg,
Err(e) => {
tracing::error!(error = %e, "Failed to get provider config");
return;
}
};
// CLI 使用独立的 sessionchannel_name = "cli-{uuid}"
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
// 创建 CLI session
let session = match Session::new(channel_name.clone(), provider_config, sender, state.session_manager.tools()).await {
Ok(s) => Arc::new(Mutex::new(s)),
Err(e) => {
tracing::error!(error = %e, "Failed to create session");
return;
}
};
let session_id = session.lock().await.id;
tracing::info!(session_id = %session_id, "CLI session established");
let _ = session.lock().await.send(WsOutbound::SessionEstablished {
session_id: session_id.to_string(),
}).await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
let session_id_for_sender = session_id;
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
break;
}
}
}
});
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let text = text.to_string();
match parse_inbound(&text) {
Ok(inbound) => {
handle_inbound(&session, inbound).await;
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = session.lock().await.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
}).await;
}
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
tracing::debug!(session_id = %session_id, "WebSocket closed");
break;
}
_ => {}
}
}
tracing::info!(session_id = %session_id, "CLI session ended");
}
async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
let inbound_clone = inbound.clone();
// 提取 content 和 chat_idCLI 使用 session id 作为 chat_id
let (content, chat_id) = match inbound_clone {
WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => {
// CLI 使用 session 中的 channel_name 作为标识
// chat_id 使用传入的或使用默认
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
(content, chat_id)
}
_ => return,
};
let user_msg = ChatMessage::user(content);
let mut session_guard = session.lock().await;
let agent = match session_guard.get_or_create_agent(&chat_id).await {
Ok(a) => a,
Err(e) => {
tracing::error!(chat_id = %chat_id, error = %e, "Failed to get or create agent");
let _ = session_guard.send(WsOutbound::Error {
code: "AGENT_ERROR".to_string(),
message: e.to_string(),
}).await;
return;
}
};
drop(session_guard);
let mut agent = agent.lock().await;
match agent.process(user_msg).await {
Ok(response) => {
tracing::debug!(chat_id = %chat_id, "Agent response sent");
let _ = session.lock().await.send(WsOutbound::AssistantResponse {
id: response.id,
content: response.content,
role: response.role,
}).await;
}
Err(e) => {
tracing::error!(chat_id = %chat_id, error = %e, "Agent process error");
let _ = session.lock().await.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: e.to_string(),
}).await;
}
}
}