- Introduced a new `CalculatorTool` for performing arithmetic and statistical calculations.
- Enhanced `AgentLoop` to support tool execution, including handling tool calls in the processing flow.
- Updated the `ChatMessage` structure with optional fields for tool-call identification and tool names.
- Modified `Session` and `SessionManager` to manage tool registrations and pass them to agents.
- Updated `OpenAIProvider` to serialize tool-related message fields.
- Added a `ToolRegistry` for managing multiple tools and their definitions.
- Added tests for `CalculatorTool` to verify its functionality and correctness.
139 lines
4.9 KiB
Rust
139 lines
4.9 KiB
Rust
use std::sync::Arc;
|
||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||
use axum::extract::State;
|
||
use axum::response::Response;
|
||
use futures_util::{SinkExt, StreamExt};
|
||
use tokio::sync::{mpsc, Mutex};
|
||
use crate::bus::ChatMessage;
|
||
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
|
||
use super::{GatewayState, session::Session};
|
||
|
||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||
ws.on_upgrade(|socket| async {
|
||
handle_socket(socket, state).await;
|
||
})
|
||
}
|
||
|
||
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
||
|
||
let provider_config = match state.config.get_provider_config("default") {
|
||
Ok(cfg) => cfg,
|
||
Err(e) => {
|
||
tracing::error!(error = %e, "Failed to get provider config");
|
||
return;
|
||
}
|
||
};
|
||
|
||
// CLI 使用独立的 session,channel_name = "cli-{uuid}"
|
||
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
|
||
|
||
// 创建 CLI session
|
||
let session = match Session::new(channel_name.clone(), provider_config, sender, state.session_manager.tools()).await {
|
||
Ok(s) => Arc::new(Mutex::new(s)),
|
||
Err(e) => {
|
||
tracing::error!(error = %e, "Failed to create session");
|
||
return;
|
||
}
|
||
};
|
||
|
||
let session_id = session.lock().await.id;
|
||
tracing::info!(session_id = %session_id, "CLI session established");
|
||
|
||
let _ = session.lock().await.send(WsOutbound::SessionEstablished {
|
||
session_id: session_id.to_string(),
|
||
}).await;
|
||
|
||
let (mut ws_sender, mut ws_receiver) = ws.split();
|
||
|
||
let mut receiver = receiver;
|
||
let session_id_for_sender = session_id;
|
||
tokio::spawn(async move {
|
||
while let Some(msg) = receiver.recv().await {
|
||
if let Ok(text) = serialize_outbound(&msg) {
|
||
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
});
|
||
|
||
while let Some(msg) = ws_receiver.next().await {
|
||
match msg {
|
||
Ok(WsMessage::Text(text)) => {
|
||
let text = text.to_string();
|
||
match parse_inbound(&text) {
|
||
Ok(inbound) => {
|
||
handle_inbound(&session, inbound).await;
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||
let _ = session.lock().await.send(WsOutbound::Error {
|
||
code: "PARSE_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
}).await;
|
||
}
|
||
}
|
||
}
|
||
Ok(WsMessage::Close(_)) | Err(_) => {
|
||
tracing::debug!(session_id = %session_id, "WebSocket closed");
|
||
break;
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
tracing::info!(session_id = %session_id, "CLI session ended");
|
||
}
|
||
|
||
async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
||
let inbound_clone = inbound.clone();
|
||
|
||
// 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id)
|
||
let (content, chat_id) = match inbound_clone {
|
||
WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => {
|
||
// CLI 使用 session 中的 channel_name 作为标识
|
||
// chat_id 使用传入的或使用默认
|
||
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
|
||
(content, chat_id)
|
||
}
|
||
_ => return,
|
||
};
|
||
|
||
let user_msg = ChatMessage::user(content);
|
||
|
||
let mut session_guard = session.lock().await;
|
||
let agent = match session_guard.get_or_create_agent(&chat_id).await {
|
||
Ok(a) => a,
|
||
Err(e) => {
|
||
tracing::error!(chat_id = %chat_id, error = %e, "Failed to get or create agent");
|
||
let _ = session_guard.send(WsOutbound::Error {
|
||
code: "AGENT_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
}).await;
|
||
return;
|
||
}
|
||
};
|
||
drop(session_guard);
|
||
|
||
let mut agent = agent.lock().await;
|
||
match agent.process(user_msg).await {
|
||
Ok(response) => {
|
||
tracing::debug!(chat_id = %chat_id, "Agent response sent");
|
||
let _ = session.lock().await.send(WsOutbound::AssistantResponse {
|
||
id: response.id,
|
||
content: response.content,
|
||
role: response.role,
|
||
}).await;
|
||
}
|
||
Err(e) => {
|
||
tracing::error!(chat_id = %chat_id, error = %e, "Agent process error");
|
||
let _ = session.lock().await.send(WsOutbound::Error {
|
||
code: "LLM_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
}).await;
|
||
}
|
||
}
|
||
}
|