ooodc 73dab09bfe Refactor code for improved readability and consistency
- Adjusted formatting and indentation in various files for better clarity.
- Consolidated multi-line statements into single lines where appropriate.
- Enhanced error handling messages for better debugging.
- Added a new InboundProcessor struct to handle inbound messages more effectively.
- Updated test cases to ensure they align with the new code structure.
2026-04-28 10:33:31 +08:00

619 lines
22 KiB
Rust

use super::{
GatewayState,
session::{Session, handle_in_chat_command, schedule_background_history_compaction},
};
use crate::agent::EmittedMessageHandler;
use crate::bus::ChatMessage;
use crate::bus::message::{ToolMessageState, format_tool_call_content};
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
use async_trait::async_trait;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
/// Live [`EmittedMessageHandler`] that forwards agent messages to the client as
/// soon as the agent emits them, instead of waiting for the whole run to finish.
struct WsToolCallEmitter {
    /// Outbound frame queue (in `handle_inbound` this is the session's `user_tx`).
    sender: mpsc::Sender<WsOutbound>,
    /// Whether completed tool results are forwarded; pending tool actions are
    /// forwarded regardless (see `should_display_message_to_user`).
    show_tool_results: bool,
}

#[async_trait]
impl EmittedMessageHandler for WsToolCallEmitter {
    /// Converts `message` into WebSocket frames and queues them, skipping messages
    /// the user should not see.
    async fn handle(&self, message: ChatMessage) {
        let visible = should_display_message_to_user(self.show_tool_results, &message);
        if visible {
            for frame in ws_outbound_from_chat_message(&message) {
                // Best-effort: a send error only means the receiving side was dropped.
                let _ = self.sender.send(frame).await;
            }
        }
    }
}
/// Axum route handler: upgrades the HTTP request to a WebSocket and hands the
/// socket, together with the shared gateway state, to `handle_socket`.
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
    ws.on_upgrade(move |socket| handle_socket(socket, state))
}
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let provider_config = match state.config.get_provider_config("default") {
Ok(cfg) => cfg,
Err(e) => {
tracing::error!(error = %e, "Failed to get provider config");
return;
}
};
let initial_record = match state.session_manager.create_cli_session(None) {
Ok(record) => record,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session");
return;
}
};
let channel_name = "cli".to_string();
// 创建 CLI session
let session = match Session::new(
channel_name.clone(),
provider_config,
sender,
state.session_manager.tools(),
state.session_manager.skills(),
state.session_manager.store(),
state.config.gateway.agent_prompt_reinject_every,
)
.await
{
Ok(s) => Arc::new(Mutex::new(s)),
Err(e) => {
tracing::error!(error = %e, "Failed to create session");
return;
}
};
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
return;
}
let runtime_session_id = session.lock().await.id.to_string();
let mut current_session_id = initial_record.id.clone();
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
let _ = session
.lock()
.await
.send(WsOutbound::SessionEstablished {
session_id: current_session_id.clone(),
})
.await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
let session_id_for_sender = runtime_session_id.clone();
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
break;
}
}
}
});
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let text = text.to_string();
match parse_inbound(&text) {
Ok(inbound) => {
if let Err(e) = handle_inbound(
&state,
&session,
&runtime_session_id,
&mut current_session_id,
inbound,
)
.await
{
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = session
.lock()
.await
.send(WsOutbound::Error {
code: "SESSION_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = session
.lock()
.await
.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
break;
}
_ => {}
}
}
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
}
/// Converts a stored session record into the summary shape sent in `SessionList`
/// frames. The record's `id` becomes the summary's `session_id`; the other listed
/// fields are carried over unchanged.
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
    let crate::storage::SessionRecord {
        id,
        title,
        channel_name,
        chat_id,
        message_count,
        last_active_at,
        archived_at,
        ..
    } = record;
    SessionSummary {
        session_id: id,
        title,
        channel_name,
        chat_id,
        message_count,
        last_active_at,
        archived_at,
    }
}
/// Translates a chat message into the WebSocket frames the client should receive.
///
/// * `assistant` with tool calls: an `AssistantResponse` only when the text is not
///   blank, followed by one `ToolCall` frame per tool call. Each `ToolCall` carries a
///   formatted `content` summary of the call.
/// * `assistant` without tool calls: always exactly one `AssistantResponse`.
/// * `tool`: one `ToolPending` frame (with a resume hint) when the tool waits on the
///   user, otherwise one `ToolResult`. A missing state is treated as completed.
/// * Any other role: no frames.
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
    match message.role.as_str() {
        "tool" => {
            let id = message.id.clone();
            let tool_call_id = message.tool_call_id.clone().unwrap_or_default();
            let tool_name = message.tool_name.clone().unwrap_or_default();
            let content = message.content.clone();
            let role = message.role.clone();
            let frame = match message
                .tool_state
                .as_ref()
                .unwrap_or(&ToolMessageState::Completed)
            {
                ToolMessageState::Completed => WsOutbound::ToolResult {
                    id,
                    tool_call_id,
                    tool_name,
                    content,
                    role,
                },
                ToolMessageState::PendingUserAction => WsOutbound::ToolPending {
                    id,
                    tool_call_id,
                    tool_name,
                    content,
                    role,
                    resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
                },
            };
            vec![frame]
        }
        "assistant" => {
            let assistant_text = || WsOutbound::AssistantResponse {
                id: message.id.clone(),
                content: message.content.clone(),
                role: message.role.clone(),
            };
            let Some(tool_calls) = &message.tool_calls else {
                return vec![assistant_text()];
            };
            let mut frames = Vec::new();
            // With tool calls present, blank text carries nothing worth showing.
            if !message.content.trim().is_empty() {
                frames.push(assistant_text());
            }
            for tool_call in tool_calls.iter() {
                frames.push(WsOutbound::ToolCall {
                    id: message.id.clone(),
                    tool_call_id: tool_call.id.clone(),
                    tool_name: tool_call.name.clone(),
                    arguments: tool_call.arguments.clone(),
                    content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
                    role: message.role.clone(),
                });
            }
            frames
        }
        _ => Vec::new(),
    }
}
/// Decides whether `message` should be shown to the user.
///
/// Every non-`tool` message is shown. `tool` messages are shown when
/// `show_tool_results` is enabled. Otherwise only those whose state is
/// `PendingUserAction` are shown; a missing state counts as completed and is hidden.
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
    let is_tool_message = message.role == "tool";
    if !is_tool_message || show_tool_results {
        return true;
    }
    // A pending tool action needs the user to act, so it is never hidden.
    matches!(
        &message.tool_state,
        Some(ToolMessageState::PendingUserAction)
    )
}
async fn handle_inbound(
state: &Arc<GatewayState>,
session: &Arc<Mutex<Session>>,
runtime_session_id: &str,
current_session_id: &mut String,
inbound: WsInbound,
) -> Result<(), crate::agent::AgentError> {
match inbound {
WsInbound::UserInput {
content,
chat_id,
sender_id,
..
} => {
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
let (history, agent, user_tx) = {
let mut session_guard = session.lock().await;
session_guard.ensure_persistent_session(&chat_id)?;
session_guard.ensure_chat_loaded(&chat_id)?;
if let Some(command_response) =
handle_in_chat_command(&mut session_guard, &chat_id, &content)?
{
let _ = session_guard
.send(WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
content: command_response,
role: "assistant".to_string(),
})
.await;
return Ok(());
}
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?;
let history = session_guard.get_or_create_history(&chat_id).clone();
session_guard.record_skill_offer(&chat_id)?;
let live_emitter = Arc::new(WsToolCallEmitter {
sender: session_guard.user_tx.clone(),
show_tool_results: state.config.gateway.show_tool_results,
});
let agent = session_guard
.create_agent(&chat_id, Some(&sender_id), Some(&user_message_id))?
.with_emitted_message_handler(live_emitter);
(history, agent, session_guard.user_tx.clone())
};
match agent.process(history).await {
Ok(result) => {
let mut session_guard = session.lock().await;
session_guard
.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
for outbound in result
.emitted_messages
.iter()
.filter(|message| {
!message.is_assistant_tool_call_message()
&& should_display_message_to_user(
state.config.gateway.show_tool_results,
message,
)
})
.flat_map(ws_outbound_from_chat_message)
{
let _ = session_guard.send(outbound).await;
}
drop(session_guard);
if let Err(error) =
schedule_background_history_compaction(session.clone(), chat_id.clone())
.await
{
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
}
}
Err(error) => {
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
let _ = user_tx
.send(WsOutbound::Error {
code: "LLM_ERROR".to_string(),
message: error.to_string(),
})
.await;
}
}
Ok(())
}
WsInbound::ClearHistory {
session_id,
chat_id,
} => {
let target = session_id
.or(chat_id)
.unwrap_or_else(|| current_session_id.clone());
state.session_manager.clear_session_messages(&target)?;
let mut session_guard = session.lock().await;
session_guard.remove_history(&target);
let _ = session_guard
.send(WsOutbound::HistoryCleared { session_id: target })
.await;
Ok(())
}
WsInbound::CreateSession { title } => {
let record = state.session_manager.create_cli_session(title.as_deref())?;
*current_session_id = record.id.clone();
let mut session_guard = session.lock().await;
session_guard.ensure_chat_loaded(&record.id)?;
let _ = session_guard
.send(WsOutbound::SessionCreated {
session_id: record.id,
title: record.title,
})
.await;
Ok(())
}
WsInbound::ListSessions { include_archived } => {
let records = state.session_manager.list_cli_sessions(include_archived)?;
let summaries = records.into_iter().map(to_session_summary).collect();
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionList {
sessions: summaries,
current_session_id: Some(current_session_id.clone()),
})
.await;
Ok(())
}
WsInbound::LoadSession { session_id } => {
let Some(record) = state.session_manager.get_session_record(&session_id)? else {
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::Error {
code: "SESSION_NOT_FOUND".to_string(),
message: format!("Session not found: {}", session_id),
})
.await;
return Ok(());
};
*current_session_id = record.id.clone();
let mut session_guard = session.lock().await;
session_guard.ensure_chat_loaded(&record.id)?;
let _ = session_guard
.send(WsOutbound::SessionLoaded {
session_id: record.id,
title: record.title,
message_count: record.message_count,
})
.await;
Ok(())
}
WsInbound::RenameSession { session_id, title } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.rename_session(&target, &title)?;
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionRenamed {
session_id: target,
title,
})
.await;
Ok(())
}
WsInbound::ArchiveSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.archive_session(&target)?;
let session_guard = session.lock().await;
let _ = session_guard
.send(WsOutbound::SessionArchived { session_id: target })
.await;
Ok(())
}
WsInbound::DeleteSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.delete_session(&target)?;
let replacement = if target == *current_session_id {
Some(state.session_manager.create_cli_session(None)?)
} else {
None
};
let mut session_guard = session.lock().await;
session_guard.remove_history(&target);
let _ = session_guard
.send(WsOutbound::SessionDeleted {
session_id: target.clone(),
})
.await;
if let Some(record) = replacement {
*current_session_id = record.id.clone();
session_guard.ensure_chat_loaded(&record.id)?;
let _ = session_guard
.send(WsOutbound::SessionCreated {
session_id: record.id,
title: record.title,
})
.await;
}
Ok(())
}
WsInbound::Ping => {
let session_guard = session.lock().await;
let _ = session_guard.send(WsOutbound::Pong).await;
Ok(())
}
}
}
/// Picks the sender id to attribute a user message to.
///
/// The client-supplied `sender_id` is used with surrounding whitespace trimmed. When
/// it is absent or blank after trimming, the connection's `runtime_session_id` is
/// used instead.
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
    match sender_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
        _ => runtime_session_id.to_owned(),
    }
}
#[cfg(test)]
mod tests {
    use super::{
        WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
        ws_outbound_from_chat_message,
    };
    use crate::agent::EmittedMessageHandler;
    use crate::bus::ChatMessage;
    use crate::bus::message::ToolMessageState;
    use crate::protocol::WsOutbound;
    use crate::providers::ToolCall;
    use serde_json::json;
    use tokio::sync::mpsc;
    use tokio::sync::mpsc::error::TryRecvError;

    #[test]
    fn test_ws_outbound_from_chat_message_expands_tool_calls() {
        let message = ChatMessage::assistant_with_tool_calls(
            "",
            vec![ToolCall {
                id: "call-1".to_string(),
                name: "calculator".to_string(),
                arguments: json!({"expression": "1 + 1"}),
            }],
        );
        let outbound = ws_outbound_from_chat_message(&message);
        assert_eq!(outbound.len(), 1);
        match &outbound[0] {
            WsOutbound::ToolCall {
                tool_call_id,
                tool_name,
                arguments,
                content,
                ..
            } => {
                assert_eq!(tool_call_id, "call-1");
                assert_eq!(tool_name, "calculator");
                assert_eq!(arguments["expression"], "1 + 1");
                assert_eq!(content, "### calculator\n- expression: 1 + 1");
            }
            other => panic!("unexpected outbound variant: {:?}", other),
        }
    }

    #[test]
    fn test_ws_outbound_keeps_assistant_content_when_tool_calls_exist() {
        let message = ChatMessage::assistant_with_tool_calls(
            "日报已整理完成。",
            vec![ToolCall {
                id: "call-1".to_string(),
                name: "memory_manage".to_string(),
                arguments: json!({"action": "put"}),
            }],
        );
        let outbound = ws_outbound_from_chat_message(&message);
        assert_eq!(outbound.len(), 2);
        assert!(matches!(outbound[0], WsOutbound::AssistantResponse { .. }));
        assert!(matches!(outbound[1], WsOutbound::ToolCall { .. }));
    }

    #[test]
    fn test_ws_outbound_skips_blank_assistant_content_when_tool_calls_exist() {
        let message = ChatMessage::assistant_with_tool_calls(
            "  \n\t",
            vec![ToolCall {
                id: "call-1".to_string(),
                name: "calculator".to_string(),
                arguments: json!({"expression": "2 * 3"}),
            }],
        );
        let outbound = ws_outbound_from_chat_message(&message);
        assert_eq!(outbound.len(), 1);
        assert!(matches!(outbound[0], WsOutbound::ToolCall { .. }));
    }

    #[test]
    fn test_ws_outbound_from_chat_message_includes_tool_results() {
        let message = ChatMessage::tool("call-1", "calculator", "2");
        let outbound = ws_outbound_from_chat_message(&message);
        assert_eq!(outbound.len(), 1);
        assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
    }

    #[test]
    fn test_ws_outbound_from_chat_message_includes_tool_pending() {
        let message = ChatMessage::tool_with_state(
            "call-1",
            "bash",
            "等待你完成授权后再继续。",
            ToolMessageState::PendingUserAction,
        );
        let outbound = ws_outbound_from_chat_message(&message);
        assert_eq!(outbound.len(), 1);
        assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
    }

    #[test]
    fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
        let completed = ChatMessage::tool("call-1", "calculator", "2");
        let pending = ChatMessage::tool_with_state(
            "call-2",
            "bash",
            "waiting",
            ToolMessageState::PendingUserAction,
        );
        assert!(!should_display_message_to_user(false, &completed));
        assert!(should_display_message_to_user(false, &pending));
        assert!(should_display_message_to_user(true, &completed));
    }

    #[test]
    fn test_resolve_ws_sender_id_prefers_inbound_sender() {
        assert_eq!(
            resolve_ws_sender_id(Some("user-42"), "runtime-1"),
            "user-42"
        );
        assert_eq!(
            resolve_ws_sender_id(Some(" user-42 "), "runtime-1"),
            "user-42"
        );
    }

    #[test]
    fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() {
        assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
        assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
    }

    #[tokio::test]
    async fn test_ws_tool_call_emitter_hides_completed_tool_results_when_disabled() {
        let (sender, mut receiver) = mpsc::channel(4);
        let emitter = WsToolCallEmitter {
            sender,
            show_tool_results: false,
        };
        emitter
            .handle(ChatMessage::tool("call-1", "calculator", "2"))
            .await;
        // `handle` awaits every send before returning, so an empty queue right now is
        // conclusive. No timeout is needed.
        assert!(matches!(receiver.try_recv(), Err(TryRecvError::Empty)));
    }

    #[tokio::test]
    async fn test_ws_tool_call_emitter_forwards_pending_tool_actions_when_disabled() {
        let (sender, mut receiver) = mpsc::channel(4);
        let emitter = WsToolCallEmitter {
            sender,
            show_tool_results: false,
        };
        emitter
            .handle(ChatMessage::tool_with_state(
                "call-1",
                "bash",
                "waiting",
                ToolMessageState::PendingUserAction,
            ))
            .await;
        assert!(matches!(
            receiver.try_recv(),
            Ok(WsOutbound::ToolPending { .. })
        ));
        assert!(matches!(receiver.try_recv(), Err(TryRecvError::Empty)));
    }

    #[tokio::test]
    async fn test_ws_tool_call_emitter_forwards_completed_tool_results_when_enabled() {
        let (sender, mut receiver) = mpsc::channel(4);
        let emitter = WsToolCallEmitter {
            sender,
            show_tool_results: true,
        };
        emitter
            .handle(ChatMessage::tool("call-1", "calculator", "2"))
            .await;
        assert!(matches!(
            receiver.try_recv(),
            Ok(WsOutbound::ToolResult { .. })
        ));
    }
}