oudecheng 2e13f6932c feat: Enhance session management with topic support
- Added topic management capabilities, allowing users to create, switch, and query topics within sessions.
- Updated command structure to include new commands: SwitchSession and GetCurrentSession.
- Introduced TopicRecord for managing topic data in the storage layer.
- Modified session handlers to accommodate topic operations, including listing and loading topics.
- Enhanced database schema to support topics, including new tables and relationships.
- Updated input adapters to recognize new commands and handle topic-related actions.
- Improved logging for session and topic operations to aid in debugging and monitoring.
2026-05-15 15:01:58 +08:00

326 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use super::GatewayState;
use crate::agent::{AgentError, CompositeSystemPromptProvider};
use crate::bus::InboundMessage;
use crate::command::adapter::{InputAdapter, OutputAdapter};
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
use crate::command::context::CommandContext;
use crate::command::handler::CommandRouter;
use crate::command::handlers::save_session::SaveSessionCommandHandler;
use crate::command::handlers::session::SessionCommandHandler;
use crate::command::handlers::session_query::SessionQueryCommandHandler;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound};
use crate::skills::SkillPromptProvider;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Channel name stamped on the `InboundMessage`s this WebSocket gateway
/// publishes to the bus.
const CLI_CHANNEL_NAME: &str = "cli";
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async {
handle_socket(socket, state).await;
})
}
async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (sender, receiver) = mpsc::channel::<WsOutbound>(100);
let cli_sessions = state.session_manager.cli_sessions();
let initial_record = match cli_sessions.create(None) {
Ok(record) => record,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial CLI session");
return;
}
};
let runtime_session_id = uuid::Uuid::new_v4().to_string();
let mut current_session_id = initial_record.id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
current_session_id.clone(),
runtime_session_id.clone(),
sender.clone(),
)
.await;
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
let _ = sender
.send(WsOutbound::SessionEstablished {
session_id: current_session_id.clone(),
})
.await;
let (mut ws_sender, mut ws_receiver) = ws.split();
let mut receiver = receiver;
let session_id_for_sender = runtime_session_id.clone();
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
break;
}
}
}
});
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let text = text.to_string();
match parse_inbound(&text) {
Ok(inbound) => {
if let Err(e) = handle_inbound(
&state,
&sender,
&runtime_session_id,
&mut current_session_id,
inbound,
)
.await
{
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = sender
.send(WsOutbound::Error {
code: "SESSION_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = sender
.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
break;
}
_ => {}
}
}
state
.channel_manager
.cli_channel()
.unregister_connection(&runtime_session_id)
.await;
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
}
/// Dispatches one parsed client frame for the connection identified by
/// `runtime_session_id`.
///
/// * `Message`:
///   - Registers this connection under the target chat id. That is the
///     client-supplied `chat_id`, or the current session id when it is absent.
///   - Publishes the content to the bus on the CLI channel.
///   - The sender id falls back to `runtime_session_id` when the client omits
///     it or sends only whitespace.
/// * `Command`:
///   - Parses the payload with the WebSocket input adapter.
///   - Runs the command through a freshly built `CommandRouter` (session,
///     session-query and save-session handlers).
///   - Forwards the adapted response frames to the client.
///   - When a successful response carries a `session_id` metadata entry,
///     `current_session_id` is switched to it and the connection is registered
///     under the new id.
///   - A payload that is not a command, or fails to parse, is reported to the
///     client as `INVALID_COMMAND` / `PARSE_ERROR` and is not an `Err`.
/// * `Ping`: replies with `Pong`.
///
/// # Errors
///
/// Returns `AgentError::Other` when publishing to the bus fails or when the
/// `"default"` provider config cannot be loaded. Failures to enqueue frames on
/// `sender` are ignored.
async fn handle_inbound(
    state: &Arc<GatewayState>,
    sender: &mpsc::Sender<WsOutbound>,
    runtime_session_id: &str,
    current_session_id: &mut String,
    inbound: WsInbound,
) -> Result<(), AgentError> {
    match inbound {
        WsInbound::Message {
            content,
            chat_id,
            sender_id,
            ..
        } => {
            // NOTE(review): the client-supplied chat_id is trusted as-is and the
            // connection is registered under it; confirm clients may only target
            // chats they own.
            let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
            let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
            state
                .channel_manager
                .cli_channel()
                .register_connection(
                    chat_id.clone(),
                    runtime_session_id.to_string(),
                    sender.clone(),
                )
                .await;
            state
                .bus
                .publish_inbound(InboundMessage {
                    channel: CLI_CHANNEL_NAME.to_string(),
                    sender_id,
                    chat_id,
                    content,
                    timestamp: current_timestamp(),
                    media: Vec::new(),
                    metadata: HashMap::new(),
                    forwarded_metadata: HashMap::new(),
                })
                .await
                .map_err(|error| AgentError::Other(error.to_string()))?;
            Ok(())
        }
        WsInbound::Command { payload } => {
            // Parse the payload into a command via the command system.
            let input_adapter = WebSocketInputAdapter::new();
            let adapter_ctx = crate::command::context::AdapterContext::new("websocket")
                .with_session_id(current_session_id.as_str());
            let cmd = match input_adapter.try_parse(&payload, adapter_ctx) {
                Ok(Some(cmd)) => cmd,
                Ok(None) => {
                    // Not a command: report it to the client instead of failing.
                    let _ = sender
                        .send(WsOutbound::Error {
                            code: "INVALID_COMMAND".to_string(),
                            message: "Invalid command payload".to_string(),
                        })
                        .await;
                    return Ok(());
                }
                Err(e) => {
                    let _ = sender
                        .send(WsOutbound::Error {
                            code: "PARSE_ERROR".to_string(),
                            message: e.to_string(),
                        })
                        .await;
                    return Ok(());
                }
            };

            // Build the command router and its dependencies.
            let store = state.session_manager.store();
            let skills = state.session_manager.skills();
            let provider_config = state
                .config
                .get_provider_config("default")
                .map_err(|e| AgentError::Other(e.to_string()))?;
            let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> =
                Arc::new(CompositeSystemPromptProvider::new(vec![
                    Box::new(AgentPromptProvider::new(
                        0,
                        provider_config.clone(),
                        store.clone(),
                    )),
                    Box::new(SkillPromptProvider::new(skills)),
                ]));
            let mut router = CommandRouter::new();
            router.register(Box::new(SessionCommandHandler::new(store.clone())));
            // The query handler also needs the SessionManager itself.
            let session_query_handler = SessionQueryCommandHandler::new(store.clone())
                .with_session_manager(state.session_manager.clone());
            router.register(Box::new(session_query_handler));
            router.register(Box::new(SaveSessionCommandHandler::new(
                store,
                system_prompt_provider,
            )));

            // Build the command context and execute the command.
            tracing::debug!(
                current_session_id = %current_session_id,
                "Building CommandContext for WebSocket command"
            );
            let cmd_ctx = CommandContext::new("websocket", "cli")
                .with_session_id(current_session_id.as_str())
                .with_chat_id(current_session_id.as_str());
            let response = router.dispatch_with_response(cmd, cmd_ctx).await;

            if response.success {
                // A successful command that reports a `session_id` becomes the
                // connection's current session.
                if let Some(session_id) = response.metadata.get("session_id") {
                    tracing::info!(
                        old_session_id = %current_session_id,
                        new_session_id = %session_id,
                        "Updating current_session_id"
                    );
                    *current_session_id = session_id.clone();
                    state
                        .channel_manager
                        .cli_channel()
                        .register_connection(
                            session_id.clone(),
                            runtime_session_id.to_string(),
                            sender.clone(),
                        )
                        .await;
                }
            } else if let Some(ref error) = response.error {
                tracing::warn!(
                    error_code = %error.code,
                    error_message = %error.message,
                    "Command failed"
                );
            }

            // Adapt the response into WebSocket frames and send them.
            let output_adapter = WebSocketOutputAdapter::new();
            for msg in output_adapter.adapt(response) {
                let _ = sender.send(msg).await;
            }
            Ok(())
        }
        WsInbound::Ping => {
            let _ = sender.send(WsOutbound::Pong).await;
            Ok(())
        }
    }
}
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// This never panics:
///
/// * A system clock set before 1970 yields `0`.
/// * A millisecond count that does not fit in `i64` saturates to `i64::MAX`
///   instead of being silently truncated by an `as` cast.
fn current_timestamp() -> i64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| {
            i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX)
        })
}
/// Chooses the sender id to attach to an inbound WebSocket message.
///
/// Uses the client-supplied `sender_id` with surrounding whitespace trimmed.
/// If it is missing, or empty after trimming, it falls back to the
/// connection's `runtime_session_id`.
fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> String {
    match sender_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
        _ => runtime_session_id.to_owned(),
    }
}
#[cfg(test)]
mod tests {
    use super::resolve_ws_sender_id;

    /// Fallback runtime session id shared by every case.
    const RUNTIME_ID: &str = "runtime-1";

    #[test]
    fn test_resolve_ws_sender_id_prefers_inbound_sender() {
        assert_eq!(resolve_ws_sender_id(Some("user-42"), RUNTIME_ID), "user-42");
        assert_eq!(resolve_ws_sender_id(Some(" user-42 "), RUNTIME_ID), "user-42");
    }

    #[test]
    fn test_resolve_ws_sender_id_trims_only_outer_whitespace() {
        // Tabs and newlines are trimmed like spaces; interior whitespace is kept.
        assert_eq!(resolve_ws_sender_id(Some("\tuser 42\n"), RUNTIME_ID), "user 42");
    }

    #[test]
    fn test_resolve_ws_sender_id_falls_back_to_runtime_session_id() {
        assert_eq!(resolve_ws_sender_id(None, RUNTIME_ID), RUNTIME_ID);
        assert_eq!(resolve_ws_sender_id(Some(""), RUNTIME_ID), RUNTIME_ID);
        assert_eq!(resolve_ws_sender_id(Some(" "), RUNTIME_ID), RUNTIME_ID);
        assert_eq!(resolve_ws_sender_id(Some("\t\n"), RUNTIME_ID), RUNTIME_ID);
    }
}