// Change summary:
// - Adjusted formatting and indentation in various files for better clarity.
// - Consolidated multi-line statements into single lines where appropriate.
// - Enhanced error handling messages for better debugging.
// - Added a new InboundProcessor struct to handle inbound messages more effectively.
// - Updated test cases to ensure they align with the new code structure.
// (File listing: 619 lines, 22 KiB, Rust.)
use super::{
|
|
GatewayState,
|
|
session::{Session, handle_in_chat_command, schedule_background_history_compaction},
|
|
};
|
|
use crate::agent::EmittedMessageHandler;
|
|
use crate::bus::ChatMessage;
|
|
use crate::bus::message::{ToolMessageState, format_tool_call_content};
|
|
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
|
use async_trait::async_trait;
|
|
use axum::extract::State;
|
|
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
|
use axum::response::Response;
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use std::sync::Arc;
|
|
use tokio::sync::{Mutex, mpsc};
|
|
|
|
struct WsToolCallEmitter {
|
|
sender: mpsc::Sender<WsOutbound>,
|
|
show_tool_results: bool,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl EmittedMessageHandler for WsToolCallEmitter {
|
|
async fn handle(&self, message: ChatMessage) {
|
|
if !should_display_message_to_user(self.show_tool_results, &message) {
|
|
return;
|
|
}
|
|
|
|
for outbound in ws_outbound_from_chat_message(&message) {
|
|
let _ = self.sender.send(outbound).await;
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
|
ws.on_upgrade(|socket| async {
|
|
handle_socket(socket, state).await;
|
|
})
|
|
}
|
|
|
|
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
|
|
|
let provider_config = match state.config.get_provider_config("default") {
|
|
Ok(cfg) => cfg,
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to get provider config");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let initial_record = match state.session_manager.create_cli_session(None) {
|
|
Ok(record) => record,
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to create initial CLI session");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let channel_name = "cli".to_string();
|
|
|
|
// 创建 CLI session
|
|
let session = match Session::new(
|
|
channel_name.clone(),
|
|
provider_config,
|
|
sender,
|
|
state.session_manager.tools(),
|
|
state.session_manager.skills(),
|
|
state.session_manager.store(),
|
|
state.config.gateway.agent_prompt_reinject_every,
|
|
)
|
|
.await
|
|
{
|
|
Ok(s) => Arc::new(Mutex::new(s)),
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to create session");
|
|
return;
|
|
}
|
|
};
|
|
|
|
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
|
|
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
|
|
return;
|
|
}
|
|
|
|
let runtime_session_id = session.lock().await.id.to_string();
|
|
let mut current_session_id = initial_record.id.clone();
|
|
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
|
|
|
|
let _ = session
|
|
.lock()
|
|
.await
|
|
.send(WsOutbound::SessionEstablished {
|
|
session_id: current_session_id.clone(),
|
|
})
|
|
.await;
|
|
|
|
let (mut ws_sender, mut ws_receiver) = ws.split();
|
|
|
|
let mut receiver = receiver;
|
|
let session_id_for_sender = runtime_session_id.clone();
|
|
tokio::spawn(async move {
|
|
while let Some(msg) = receiver.recv().await {
|
|
if let Ok(text) = serialize_outbound(&msg) {
|
|
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
while let Some(msg) = ws_receiver.next().await {
|
|
match msg {
|
|
Ok(WsMessage::Text(text)) => {
|
|
let text = text.to_string();
|
|
match parse_inbound(&text) {
|
|
Ok(inbound) => {
|
|
if let Err(e) = handle_inbound(
|
|
&state,
|
|
&session,
|
|
&runtime_session_id,
|
|
&mut current_session_id,
|
|
inbound,
|
|
)
|
|
.await
|
|
{
|
|
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
|
let _ = session
|
|
.lock()
|
|
.await
|
|
.send(WsOutbound::Error {
|
|
code: "SESSION_ERROR".to_string(),
|
|
message: e.to_string(),
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Failed to parse inbound message");
|
|
let _ = session
|
|
.lock()
|
|
.await
|
|
.send(WsOutbound::Error {
|
|
code: "PARSE_ERROR".to_string(),
|
|
message: e.to_string(),
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
Ok(WsMessage::Close(_)) | Err(_) => {
|
|
#[cfg(debug_assertions)]
|
|
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
|
}
|
|
|
|
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
|
SessionSummary {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
channel_name: record.channel_name,
|
|
chat_id: record.chat_id,
|
|
message_count: record.message_count,
|
|
last_active_at: record.last_active_at,
|
|
archived_at: record.archived_at,
|
|
}
|
|
}
|
|
|
|
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
|
match message.role.as_str() {
|
|
"assistant" => {
|
|
if let Some(tool_calls) = &message.tool_calls {
|
|
let mut outbound = Vec::new();
|
|
if !message.content.trim().is_empty() {
|
|
outbound.push(WsOutbound::AssistantResponse {
|
|
id: message.id.clone(),
|
|
content: message.content.clone(),
|
|
role: message.role.clone(),
|
|
});
|
|
}
|
|
|
|
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
|
|
id: message.id.clone(),
|
|
tool_call_id: tool_call.id.clone(),
|
|
tool_name: tool_call.name.clone(),
|
|
arguments: tool_call.arguments.clone(),
|
|
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
|
role: message.role.clone(),
|
|
}));
|
|
outbound
|
|
} else {
|
|
vec![WsOutbound::AssistantResponse {
|
|
id: message.id.clone(),
|
|
content: message.content.clone(),
|
|
role: message.role.clone(),
|
|
}]
|
|
}
|
|
}
|
|
"tool" => match message
|
|
.tool_state
|
|
.as_ref()
|
|
.unwrap_or(&ToolMessageState::Completed)
|
|
{
|
|
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
|
id: message.id.clone(),
|
|
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
content: message.content.clone(),
|
|
role: message.role.clone(),
|
|
}],
|
|
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
|
|
id: message.id.clone(),
|
|
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
content: message.content.clone(),
|
|
role: message.role.clone(),
|
|
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
|
|
}],
|
|
},
|
|
_ => Vec::new(),
|
|
}
|
|
}
|
|
|
|
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
|
|
if message.role != "tool" {
|
|
return true;
|
|
}
|
|
|
|
show_tool_results
|
|
|| matches!(
|
|
message
|
|
.tool_state
|
|
.as_ref()
|
|
.unwrap_or(&ToolMessageState::Completed),
|
|
ToolMessageState::PendingUserAction
|
|
)
|
|
}
|
|
|
|
async fn handle_inbound(
|
|
state: &Arc<GatewayState>,
|
|
session: &Arc<Mutex<Session>>,
|
|
runtime_session_id: &str,
|
|
current_session_id: &mut String,
|
|
inbound: WsInbound,
|
|
) -> Result<(), crate::agent::AgentError> {
|
|
match inbound {
|
|
WsInbound::UserInput {
|
|
content,
|
|
chat_id,
|
|
sender_id,
|
|
..
|
|
} => {
|
|
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
|
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
|
let (history, agent, user_tx) = {
|
|
let mut session_guard = session.lock().await;
|
|
|
|
session_guard.ensure_persistent_session(&chat_id)?;
|
|
session_guard.ensure_chat_loaded(&chat_id)?;
|
|
|
|
if let Some(command_response) =
|
|
handle_in_chat_command(&mut session_guard, &chat_id, &content)?
|
|
{
|
|
let _ = session_guard
|
|
.send(WsOutbound::AssistantResponse {
|
|
id: uuid::Uuid::new_v4().to_string(),
|
|
content: command_response,
|
|
role: "assistant".to_string(),
|
|
})
|
|
.await;
|
|
return Ok(());
|
|
}
|
|
|
|
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
|
|
|
|
let user_message = session_guard.create_user_message(&content, Vec::new());
|
|
let user_message_id = user_message.id.clone();
|
|
session_guard.append_persisted_message(&chat_id, user_message)?;
|
|
|
|
let history = session_guard.get_or_create_history(&chat_id).clone();
|
|
session_guard.record_skill_offer(&chat_id)?;
|
|
|
|
let live_emitter = Arc::new(WsToolCallEmitter {
|
|
sender: session_guard.user_tx.clone(),
|
|
show_tool_results: state.config.gateway.show_tool_results,
|
|
});
|
|
let agent = session_guard
|
|
.create_agent(&chat_id, Some(&sender_id), Some(&user_message_id))?
|
|
.with_emitted_message_handler(live_emitter);
|
|
|
|
(history, agent, session_guard.user_tx.clone())
|
|
};
|
|
|
|
match agent.process(history).await {
|
|
Ok(result) => {
|
|
let mut session_guard = session.lock().await;
|
|
session_guard
|
|
.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
|
for outbound in result
|
|
.emitted_messages
|
|
.iter()
|
|
.filter(|message| {
|
|
!message.is_assistant_tool_call_message()
|
|
&& should_display_message_to_user(
|
|
state.config.gateway.show_tool_results,
|
|
message,
|
|
)
|
|
})
|
|
.flat_map(ws_outbound_from_chat_message)
|
|
{
|
|
let _ = session_guard.send(outbound).await;
|
|
}
|
|
|
|
drop(session_guard);
|
|
|
|
if let Err(error) =
|
|
schedule_background_history_compaction(session.clone(), chat_id.clone())
|
|
.await
|
|
{
|
|
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
|
|
}
|
|
}
|
|
Err(error) => {
|
|
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
|
|
let _ = user_tx
|
|
.send(WsOutbound::Error {
|
|
code: "LLM_ERROR".to_string(),
|
|
message: error.to_string(),
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
WsInbound::ClearHistory {
|
|
session_id,
|
|
chat_id,
|
|
} => {
|
|
let target = session_id
|
|
.or(chat_id)
|
|
.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.clear_session_messages(&target)?;
|
|
|
|
let mut session_guard = session.lock().await;
|
|
session_guard.remove_history(&target);
|
|
let _ = session_guard
|
|
.send(WsOutbound::HistoryCleared { session_id: target })
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::CreateSession { title } => {
|
|
let record = state.session_manager.create_cli_session(title.as_deref())?;
|
|
*current_session_id = record.id.clone();
|
|
|
|
let mut session_guard = session.lock().await;
|
|
session_guard.ensure_chat_loaded(&record.id)?;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionCreated {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::ListSessions { include_archived } => {
|
|
let records = state.session_manager.list_cli_sessions(include_archived)?;
|
|
let summaries = records.into_iter().map(to_session_summary).collect();
|
|
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionList {
|
|
sessions: summaries,
|
|
current_session_id: Some(current_session_id.clone()),
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::LoadSession { session_id } => {
|
|
let Some(record) = state.session_manager.get_session_record(&session_id)? else {
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard
|
|
.send(WsOutbound::Error {
|
|
code: "SESSION_NOT_FOUND".to_string(),
|
|
message: format!("Session not found: {}", session_id),
|
|
})
|
|
.await;
|
|
return Ok(());
|
|
};
|
|
|
|
*current_session_id = record.id.clone();
|
|
let mut session_guard = session.lock().await;
|
|
session_guard.ensure_chat_loaded(&record.id)?;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionLoaded {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
message_count: record.message_count,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::RenameSession { session_id, title } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.rename_session(&target, &title)?;
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionRenamed {
|
|
session_id: target,
|
|
title,
|
|
})
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::ArchiveSession { session_id } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.archive_session(&target)?;
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionArchived { session_id: target })
|
|
.await;
|
|
Ok(())
|
|
}
|
|
WsInbound::DeleteSession { session_id } => {
|
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
state.session_manager.delete_session(&target)?;
|
|
|
|
let replacement = if target == *current_session_id {
|
|
Some(state.session_manager.create_cli_session(None)?)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let mut session_guard = session.lock().await;
|
|
session_guard.remove_history(&target);
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionDeleted {
|
|
session_id: target.clone(),
|
|
})
|
|
.await;
|
|
|
|
if let Some(record) = replacement {
|
|
*current_session_id = record.id.clone();
|
|
session_guard.ensure_chat_loaded(&record.id)?;
|
|
let _ = session_guard
|
|
.send(WsOutbound::SessionCreated {
|
|
session_id: record.id,
|
|
title: record.title,
|
|
})
|
|
.await;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
WsInbound::Ping => {
|
|
let session_guard = session.lock().await;
|
|
let _ = session_guard.send(WsOutbound::Pong).await;
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Picks the sender id to attribute a user message to.
///
/// A client-supplied `sender_id` wins when it has any non-whitespace content;
/// it is returned with surrounding whitespace stripped. Otherwise, including
/// when it is absent, the connection's `runtime_session_id` is used.
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
    match sender_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
        _ => runtime_session_id.to_owned(),
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::{
        WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
        ws_outbound_from_chat_message,
    };
    use crate::agent::EmittedMessageHandler;
    use crate::bus::ChatMessage;
    use crate::bus::message::ToolMessageState;
    use crate::protocol::WsOutbound;
    use crate::providers::ToolCall;
    use serde_json::json;
    use tokio::sync::mpsc;

    /// One-element tool-call list with the fixed id `call-1`.
    fn one_call(name: &str, arguments: serde_json::Value) -> Vec<ToolCall> {
        vec![ToolCall {
            id: "call-1".to_string(),
            name: name.to_string(),
            arguments,
        }]
    }

    #[test]
    fn test_ws_outbound_from_chat_message_expands_tool_calls() {
        let message = ChatMessage::assistant_with_tool_calls(
            "",
            one_call("calculator", json!({"expression": "1 + 1"})),
        );

        let frames = ws_outbound_from_chat_message(&message);

        let [
            WsOutbound::ToolCall {
                tool_call_id,
                tool_name,
                arguments,
                content,
                ..
            },
        ] = frames.as_slice()
        else {
            panic!("unexpected outbound variant: {:?}", frames);
        };
        assert_eq!(tool_call_id, "call-1");
        assert_eq!(tool_name, "calculator");
        assert_eq!(arguments["expression"], "1 + 1");
        assert_eq!(content, "### calculator\n- expression: 1 + 1");
    }

    #[test]
    fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
        let message = ChatMessage::assistant_with_tool_calls(
            "日报已整理完成。",
            one_call("memory_manage", json!({"action": "put"})),
        );

        let frames = ws_outbound_from_chat_message(&message);

        assert!(matches!(
            frames.as_slice(),
            [WsOutbound::AssistantResponse { .. }, WsOutbound::ToolCall { .. }]
        ));
    }

    #[test]
    fn test_ws_outbound_from_chat_message_includes_tool_results() {
        let frames = ws_outbound_from_chat_message(&ChatMessage::tool("call-1", "calculator", "2"));

        assert!(matches!(frames.as_slice(), [WsOutbound::ToolResult { .. }]));
    }

    #[test]
    fn test_ws_outbound_from_chat_message_includes_tool_pending() {
        let message = ChatMessage::tool_with_state(
            "call-1",
            "bash",
            "等待你完成授权后再继续。",
            ToolMessageState::PendingUserAction,
        );

        let frames = ws_outbound_from_chat_message(&message);

        assert!(matches!(frames.as_slice(), [WsOutbound::ToolPending { .. }]));
    }

    #[test]
    fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
        let completed = ChatMessage::tool("call-1", "calculator", "2");
        let pending = ChatMessage::tool_with_state(
            "call-2",
            "bash",
            "waiting",
            ToolMessageState::PendingUserAction,
        );

        // (show_tool_results, message, expected visibility)
        let cases = [
            (false, &completed, false),
            (false, &pending, true),
            (true, &completed, true),
        ];
        for (show, message, expected) in cases {
            assert_eq!(should_display_message_to_user(show, message), expected);
        }
    }

    #[test]
    fn test_resolve_ws_sender_id_prefers_inbound_sender() {
        for raw in ["user-42", " user-42 "] {
            assert_eq!(resolve_ws_sender_id(Some(raw), "runtime-1"), "user-42");
        }
    }

    #[test]
    fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() {
        for raw in [None, Some(" ")] {
            assert_eq!(resolve_ws_sender_id(raw, "runtime-1"), "runtime-1");
        }
    }

    #[tokio::test]
    async fn test_ws_tool_call_emitter_hides_completed_tool_results_when_disabled() {
        let (sender, mut receiver) = mpsc::channel::<WsOutbound>(4);
        let emitter = WsToolCallEmitter {
            sender,
            show_tool_results: false,
        };

        emitter
            .handle(ChatMessage::tool("call-1", "calculator", "2"))
            .await;

        // `handle` awaits every send before returning, so anything it forwarded
        // would already be queued; an empty channel means nothing was sent.
        assert!(receiver.try_recv().is_err());
    }
}
|