343 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use super::GatewayState;
use crate::agent::{AgentError, CompositeSystemPromptProvider};
use crate::bus::InboundMessage;
use crate::command::adapter::{InputAdapter, OutputAdapter};
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
use crate::command::context::CommandContext;
use crate::command::handler::CommandRouter;
use crate::command::handlers::save_session::SaveSessionCommandHandler;
use crate::command::handlers::session::SessionCommandHandler;
use crate::command::handlers::session_query::SessionQueryCommandHandler;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound};
use crate::skills::SkillPromptProvider;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Channel name stamped on `InboundMessage`s that `handle_inbound` publishes to the bus
/// for chat messages received over a CLI WebSocket connection.
const CLI_CHANNEL_NAME: &str = "cli";
/// Axum handler for the CLI WebSocket endpoint.
///
/// Accepts the upgrade request and hands the resulting socket, together with the shared
/// gateway state, to `handle_socket`, which runs for the lifetime of the connection.
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
    // `handle_socket` already yields a `Future<Output = ()>` that owns its arguments,
    // so it can be handed to `on_upgrade` directly without an extra async wrapper.
    ws.on_upgrade(move |socket| handle_socket(socket, state))
}
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let cli_sessions = state.session_manager.cli_sessions();
let initial_record = match cli_sessions.create(None) {
Ok(record) => record,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session");
return;
}
};
let runtime_session_id = uuid::Uuid::new_v4().to_string();
let mut current_session_id = initial_record.id.clone();
let mut current_topic_id: Option<String> = None;
state
.channel_manager
.cli_channel()
.register_connection(
current_session_id.clone(),
runtime_session_id.clone(),
sender.clone(),
)
.await;
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
let _ = sender
.send(WsOutbound::SessionEstablished {
session_id: current_session_id.clone(),
})
.await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
let session_id_for_sender = runtime_session_id.clone();
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
break;
}
}
}
});
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let text = text.to_string();
match parse_inbound(&text) {
Ok(inbound) => {
if let Err(e) = handle_inbound(
&state,
&sender,
&runtime_session_id,
&mut current_session_id,
&mut current_topic_id,
inbound,
)
.await
{
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = sender
.send(WsOutbound::Error {
code: "SESSION_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = sender
.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
break;
}
_ => {}
}
}
state
.channel_manager
.cli_channel()
.unregister_connection(&runtime_session_id)
.await;
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
}
async fn handle_inbound(
state: &Arc<GatewayState>,
sender: &mpsc::Sender<WsOutbound>,
runtime_session_id: &str,
current_session_id: &mut String,
current_topic_id: &mut Option<String>,
inbound: WsInbound,
) -> Result<(), AgentError> {
match inbound {
WsInbound::Message {
content,
chat_id,
sender_id,
..
} => {
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
state
.channel_manager
.cli_channel()
.register_connection(
chat_id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
state
.bus
.publish_inbound(InboundMessage {
channel: CLI_CHANNEL_NAME.to_string(),
sender_id,
chat_id,
content,
timestamp: current_timestamp(),
media: Vec::new(),
metadata: HashMap::new(),
forwarded_metadata: HashMap::new(),
})
.await
.map_err(|error| AgentError::Other(error.to_string()))?;
Ok(())
}
WsInbound::Command { payload } => {
// 使用 Command 系统处理命令
let input_adapter = WebSocketInputAdapter::new();
let output_adapter = WebSocketOutputAdapter::new();
// 解析命令
let adapter_ctx = crate::command::context::AdapterContext::new("websocket")
.with_session_id(current_session_id.as_str());
let cmd = match input_adapter.try_parse(&payload, adapter_ctx) {
Ok(Some(cmd)) => cmd,
Ok(None) => {
// 不是命令,返回错误
let _ = sender
.send(WsOutbound::Error {
code: "INVALID_COMMAND".to_string(),
message: "Invalid command payload".to_string(),
})
.await;
return Ok(());
}
Err(e) => {
let _ = sender
.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
})
.await;
return Ok(());
}
};
// 创建命令路由器
let _cli_sessions = state.session_manager.cli_sessions();
let store = state.session_manager.store();
let skills = state.session_manager.skills();
let provider_config = state.config.get_provider_config("default")
.map_err(|e| AgentError::Other(e.to_string()))?;
let prompt_repository = state.session_manager.store().clone();
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
Box::new(AgentPromptProvider::new(
0,
provider_config.clone(),
prompt_repository.clone(),
)),
Box::new(SkillPromptProvider::new(skills)),
]));
let mut router = CommandRouter::new();
// 注册 Session 处理器,添加 SessionManager
let session_handler = SessionCommandHandler::new(store.clone())
.with_session_manager(state.session_manager.clone());
router.register(Box::new(session_handler));
// 注册 SessionQuery 处理器
let session_query_handler = SessionQueryCommandHandler::new(store.clone())
.with_session_manager(state.session_manager.clone());
router.register(Box::new(session_query_handler));
router.register(Box::new(SaveSessionCommandHandler::new(
store,
system_prompt_provider,
)));
// 构建命令上下文
tracing::debug!(
current_session_id = %current_session_id,
current_topic_id = ?current_topic_id,
"Building CommandContext for WebSocket command"
);
let cmd_ctx = CommandContext::new("websocket", "cli")
.with_session_id(current_session_id.as_str())
.with_chat_id(current_session_id.as_str())
.with_topic_id(current_topic_id.as_deref().unwrap_or(""));
// 执行命令
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
// 处理响应
if response.success {
// 更新当前会话 ID如果是创建会话
if let Some(session_id) = response.metadata.get("session_id") {
tracing::info!(
old_session_id = %current_session_id,
new_session_id = %session_id,
"Updating current_session_id"
);
*current_session_id = session_id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
session_id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
}
// 更新当前话题 ID如果是创建话题或切换话题
if let Some(topic_id) = response.metadata.get("topic_id") {
tracing::info!(
old_topic_id = ?current_topic_id,
new_topic_id = %topic_id,
"Updating current_topic_id"
);
*current_topic_id = Some(topic_id.clone());
}
} else if let Some(ref error) = response.error {
tracing::warn!(
error_code = %error.code,
error_message = %error.message,
"Command failed"
);
}
// 适配并发送响应
let outbounds = output_adapter.adapt(response);
for msg in outbounds {
let _ = sender.send(msg).await;
}
Ok(())
}
WsInbound::Ping => {
let _ = sender.send(WsOutbound::Pong).await;
Ok(())
}
}
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Never panics. If the system clock reads earlier than 1970-01-01, the result is the
/// negative distance to the epoch instead of aborting the calling task. Values outside
/// the `i64` range saturate at `i64::MAX` / `i64::MIN`.
fn current_timestamp() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX),
        // Clock is set before the epoch: report how far before rather than panicking.
        Err(err) => i64::try_from(err.duration().as_millis()).map_or(i64::MIN, |ms| -ms),
    }
}
/// Chooses the sender id attached to an inbound WebSocket chat message.
///
/// A client-supplied id takes precedence once surrounding whitespace is trimmed. When
/// it is absent or blank, the connection's `runtime_session_id` is used instead.
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
    match sender_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
        _ => runtime_session_id.to_owned(),
    }
}
#[cfg(test)]
mod tests {
    use super::{current_timestamp, resolve_ws_sender_id};

    #[test]
    fn test_resolve_ws_sender_id_prefers_inbound_sender() {
        assert_eq!(
            resolve_ws_sender_id(Some("user-42"), "runtime-1"),
            "user-42"
        );
        assert_eq!(
            resolve_ws_sender_id(Some(" user-42 "), "runtime-1"),
            "user-42"
        );
        // Only surrounding whitespace is stripped; inner whitespace is preserved.
        assert_eq!(
            resolve_ws_sender_id(Some("\tuser 42\n"), "runtime-1"),
            "user 42"
        );
    }

    #[test]
    fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() {
        assert_eq!(resolve_ws_sender_id(None, "runtime-1"), "runtime-1");
        assert_eq!(resolve_ws_sender_id(Some(" "), "runtime-1"), "runtime-1");
        assert_eq!(resolve_ws_sender_id(Some(""), "runtime-1"), "runtime-1");
        assert_eq!(
            resolve_ws_sender_id(Some("\t\r\n"), "runtime-1"),
            "runtime-1"
        );
    }

    #[test]
    fn test_current_timestamp_is_epoch_millis() {
        // 1_600_000_000_000 ms is 2020-09-13; anything smaller would suggest seconds
        // rather than milliseconds.
        assert!(current_timestamp() > 1_600_000_000_000);
    }
}