- Added topic management, allowing users to create, switch, and query topics within sessions.
- Extended the command set with two new commands: `SwitchSession` and `GetCurrentSession`.
- Introduced `TopicRecord` for managing topic data in the storage layer.
- Updated the session handlers to support topic operations, including listing and loading topics.
- Extended the database schema to support topics, adding new tables and relationships.
- Updated the input adapters to recognize the new commands and handle topic-related actions.
- Improved logging for session and topic operations to aid debugging and monitoring.
326 lines · 12 KiB · Rust
use super::GatewayState;
|
||
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
||
use crate::bus::InboundMessage;
|
||
use crate::command::adapter::{InputAdapter, OutputAdapter};
|
||
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
||
use crate::command::context::CommandContext;
|
||
use crate::command::handler::CommandRouter;
|
||
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||
use crate::command::handlers::session::SessionCommandHandler;
|
||
use crate::command::handlers::session_query::SessionQueryCommandHandler;
|
||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||
use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||
use crate::skills::SkillPromptProvider;
|
||
use axum::extract::State;
|
||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||
use axum::response::Response;
|
||
use futures_util::{SinkExt, StreamExt};
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use tokio::sync::mpsc;
|
||
|
||
/// Channel name stamped on the `InboundMessage`s this WebSocket gateway publishes to the bus.
const CLI_CHANNEL_NAME: &str = "cli";
|
||
|
||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||
ws.on_upgrade(|socket| async {
|
||
handle_socket(socket, state).await;
|
||
})
|
||
}
|
||
|
||
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
|
||
|
||
let cli_sessions = state.session_manager.cli_sessions();
|
||
let initial_record = match cli_sessions.create(None) {
|
||
Ok(record) => record,
|
||
Err(e) => {
|
||
tracing::error!(error = %e, "Failed to create initial CLI session");
|
||
return;
|
||
}
|
||
};
|
||
|
||
let runtime_session_id = uuid::Uuid::new_v4().to_string();
|
||
let mut current_session_id = initial_record.id.clone();
|
||
state
|
||
.channel_manager
|
||
.cli_channel()
|
||
.register_connection(
|
||
current_session_id.clone(),
|
||
runtime_session_id.clone(),
|
||
sender.clone(),
|
||
)
|
||
.await;
|
||
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
|
||
|
||
let _ = sender
|
||
.send(WsOutbound::SessionEstablished {
|
||
session_id: current_session_id.clone(),
|
||
})
|
||
.await;
|
||
|
||
let (mut ws_sender, mut ws_receiver) = ws.split();
|
||
|
||
let mut receiver = receiver;
|
||
let session_id_for_sender = runtime_session_id.clone();
|
||
tokio::spawn(async move {
|
||
while let Some(msg) = receiver.recv().await {
|
||
if let Ok(text) = serialize_outbound(&msg) {
|
||
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
});
|
||
|
||
while let Some(msg) = ws_receiver.next().await {
|
||
match msg {
|
||
Ok(WsMessage::Text(text)) => {
|
||
let text = text.to_string();
|
||
match parse_inbound(&text) {
|
||
Ok(inbound) => {
|
||
if let Err(e) = handle_inbound(
|
||
&state,
|
||
&sender,
|
||
&runtime_session_id,
|
||
&mut current_session_id,
|
||
inbound,
|
||
)
|
||
.await
|
||
{
|
||
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
||
let _ = sender
|
||
.send(WsOutbound::Error {
|
||
code: "SESSION_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
})
|
||
.await;
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||
let _ = sender
|
||
.send(WsOutbound::Error {
|
||
code: "PARSE_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
})
|
||
.await;
|
||
}
|
||
}
|
||
}
|
||
Ok(WsMessage::Close(_)) | Err(_) => {
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
|
||
break;
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
state
|
||
.channel_manager
|
||
.cli_channel()
|
||
.unregister_connection(&runtime_session_id)
|
||
.await;
|
||
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
||
}
|
||
|
||
|
||
async fn handle_inbound(
|
||
state: &Arc<GatewayState>,
|
||
sender: &mpsc::Sender<WsOutbound>,
|
||
runtime_session_id: &str,
|
||
current_session_id: &mut String,
|
||
inbound: WsInbound,
|
||
) -> Result<(), AgentError> {
|
||
match inbound {
|
||
WsInbound::Message {
|
||
content,
|
||
chat_id,
|
||
sender_id,
|
||
..
|
||
} => {
|
||
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
||
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
||
|
||
state
|
||
.channel_manager
|
||
.cli_channel()
|
||
.register_connection(
|
||
chat_id.clone(),
|
||
runtime_session_id.to_string(),
|
||
sender.clone(),
|
||
)
|
||
.await;
|
||
|
||
state
|
||
.bus
|
||
.publish_inbound(InboundMessage {
|
||
channel: CLI_CHANNEL_NAME.to_string(),
|
||
sender_id,
|
||
chat_id,
|
||
content,
|
||
timestamp: current_timestamp(),
|
||
media: Vec::new(),
|
||
metadata: HashMap::new(),
|
||
forwarded_metadata: HashMap::new(),
|
||
})
|
||
.await
|
||
.map_err(|error| AgentError::Other(error.to_string()))?;
|
||
|
||
Ok(())
|
||
}
|
||
WsInbound::Command { payload } => {
|
||
// 使用 Command 系统处理命令
|
||
let input_adapter = WebSocketInputAdapter::new();
|
||
let output_adapter = WebSocketOutputAdapter::new();
|
||
|
||
// 解析命令
|
||
let adapter_ctx = crate::command::context::AdapterContext::new("websocket")
|
||
.with_session_id(current_session_id.as_str());
|
||
|
||
let cmd = match input_adapter.try_parse(&payload, adapter_ctx) {
|
||
Ok(Some(cmd)) => cmd,
|
||
Ok(None) => {
|
||
// 不是命令,返回错误
|
||
let _ = sender
|
||
.send(WsOutbound::Error {
|
||
code: "INVALID_COMMAND".to_string(),
|
||
message: "Invalid command payload".to_string(),
|
||
})
|
||
.await;
|
||
return Ok(());
|
||
}
|
||
Err(e) => {
|
||
let _ = sender
|
||
.send(WsOutbound::Error {
|
||
code: "PARSE_ERROR".to_string(),
|
||
message: e.to_string(),
|
||
})
|
||
.await;
|
||
return Ok(());
|
||
}
|
||
};
|
||
|
||
// 创建命令路由器
|
||
let cli_sessions = state.session_manager.cli_sessions();
|
||
let store = state.session_manager.store();
|
||
let skills = state.session_manager.skills();
|
||
let provider_config = state.config.get_provider_config("default")
|
||
.map_err(|e| AgentError::Other(e.to_string()))?;
|
||
let prompt_repository = state.session_manager.store().clone();
|
||
|
||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||
Box::new(AgentPromptProvider::new(
|
||
0,
|
||
provider_config.clone(),
|
||
prompt_repository.clone(),
|
||
)),
|
||
Box::new(SkillPromptProvider::new(skills)),
|
||
]));
|
||
|
||
let mut router = CommandRouter::new();
|
||
router.register(Box::new(SessionCommandHandler::new(store.clone())));
|
||
// 修复:添加 SessionManager 到 SessionQueryCommandHandler
|
||
let session_query_handler = SessionQueryCommandHandler::new(store.clone())
|
||
.with_session_manager(state.session_manager.clone());
|
||
router.register(Box::new(session_query_handler));
|
||
router.register(Box::new(SaveSessionCommandHandler::new(
|
||
store,
|
||
system_prompt_provider,
|
||
)));
|
||
|
||
// 构建命令上下文
|
||
tracing::debug!(
|
||
current_session_id = %current_session_id,
|
||
"Building CommandContext for WebSocket command"
|
||
);
|
||
let cmd_ctx = CommandContext::new("websocket", "cli")
|
||
.with_session_id(current_session_id.as_str())
|
||
.with_chat_id(current_session_id.as_str());
|
||
|
||
// 执行命令
|
||
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
||
|
||
// 处理响应
|
||
if response.success {
|
||
// 更新当前会话 ID(如果是创建会话)
|
||
if let Some(session_id) = response.metadata.get("session_id") {
|
||
tracing::info!(
|
||
old_session_id = %current_session_id,
|
||
new_session_id = %session_id,
|
||
"Updating current_session_id"
|
||
);
|
||
*current_session_id = session_id.clone();
|
||
state
|
||
.channel_manager
|
||
.cli_channel()
|
||
.register_connection(
|
||
session_id.clone(),
|
||
runtime_session_id.to_string(),
|
||
sender.clone(),
|
||
)
|
||
.await;
|
||
}
|
||
} else if let Some(ref error) = response.error {
|
||
tracing::warn!(
|
||
error_code = %error.code,
|
||
error_message = %error.message,
|
||
"Command failed"
|
||
);
|
||
}
|
||
|
||
// 适配并发送响应
|
||
let outbounds = output_adapter.adapt(response);
|
||
for msg in outbounds {
|
||
let _ = sender.send(msg).await;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
WsInbound::Ping => {
|
||
let _ = sender.send(WsOutbound::Pong).await;
|
||
Ok(())
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// This never panics:
/// - A system clock set before 1970 yields a negative value (milliseconds *before* the epoch)
///   instead of aborting the calling connection handler.
/// - Magnitudes that do not fit in `i64` saturate rather than being silently truncated by an
///   `as` cast.
fn current_timestamp() -> i64 {
    let saturate = |millis: u128| i64::try_from(millis).unwrap_or(i64::MAX);
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => saturate(since_epoch.as_millis()),
        // `duration()` is how far the clock is *before* the epoch.
        Err(before_epoch) => -saturate(before_epoch.duration().as_millis()),
    }
}
|
||
|
||
/// Chooses the sender id to stamp on a published WebSocket message.
///
/// The client-supplied id wins, with surrounding whitespace trimmed. When the client sent no id,
/// or one that is blank after trimming, the connection's `runtime_session_id` is used instead.
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
    match sender_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_string(),
        _ => runtime_session_id.to_owned(),
    }
}
|
||
|
||
/// Unit tests for how the WebSocket gateway picks a message's sender id.
#[cfg(test)]
mod tests {
    use super::resolve_ws_sender_id;

    /// Runtime session id used as the fallback in every case below.
    const RUNTIME_ID: &str = "runtime-1";

    #[test]
    fn test_resolve_ws_sender_id_prefers_inbound_sender() {
        // A client-supplied id is used, trimmed of surrounding whitespace.
        for raw in ["user-42", " user-42 "] {
            assert_eq!(resolve_ws_sender_id(Some(raw), RUNTIME_ID), "user-42");
        }
    }

    #[test]
    fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() {
        // A missing or blank id falls back to the runtime session id.
        for raw in [None, Some(" ")] {
            assert_eq!(resolve_ws_sender_id(raw, RUNTIME_ID), RUNTIME_ID);
        }
    }
}
|