diff --git a/src/bus/message.rs b/src/bus/message.rs index 9da269e..d319900 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -164,6 +164,7 @@ pub struct InboundMessage { pub channel: String, pub sender_id: String, pub chat_id: String, + pub dialog_id: Option, pub content: String, pub timestamp: i64, pub media: Vec, @@ -199,52 +200,20 @@ impl OutboundMessage { } } -// ============================================================================ -// ControlInbound - Session management operations (CLI channel only) -// ============================================================================ - -/// Session management operations that flow through the control channel -#[derive(Debug, Clone)] -pub enum ControlInbound { - CreateSession { title: Option }, - ListSessions { include_archived: bool }, - LoadSession { session_id: String }, - RenameSession { session_id: String, title: String }, - ArchiveSession { session_id: String }, - DeleteSession { session_id: String }, - ClearHistory { session_id: String }, -} - -// ============================================================================ -// ControlOutbound - Responses for control operations -// ============================================================================ - -/// Responses for session management operations -#[derive(Debug, Clone)] -pub enum ControlOutbound { - SessionCreated { session_id: String, title: String }, - SessionList { sessions: Vec }, - SessionLoaded { session_id: String, title: String, message_count: i64 }, - SessionRenamed { session_id: String, title: String }, - SessionArchived { session_id: String }, - SessionDeleted { session_id: String }, - HistoryCleared { session_id: String }, - Pong, - Error { code: String, message: String }, -} - // ============================================================================ // ControlMessage - Message for control channel (session management) +// Uses SessionCommand from session module // ============================================================================ use crate::channels::base::ChannelError; +use crate::session::{SessionCommand, SessionEvent}; use tokio::sync::mpsc; /// Control message containing a session operation and reply channel #[derive(Debug, Clone)] pub struct ControlMessage { - pub op: ControlInbound, - pub reply_tx: mpsc::Sender>, + pub op: SessionCommand, + pub reply_tx: mpsc::Sender>, } // ============================================================================ diff --git a/src/bus/mod.rs b/src/bus/mod.rs index 3236b4f..7c2de5e 100644 --- a/src/bus/mod.rs +++ b/src/bus/mod.rs @@ -2,7 +2,7 @@ pub mod dispatcher; pub mod message; pub use dispatcher::OutboundDispatcher; -pub use message::{ChatMessage, ContentBlock, ControlInbound, ControlMessage, ControlOutbound, InboundMessage, MediaItem, OutboundMessage}; +pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, OutboundMessage}; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; diff --git a/src/channels/cli_chat.rs b/src/channels/cli_chat.rs index 24d127b..83a104d 100644 --- a/src/channels/cli_chat.rs +++ b/src/channels/cli_chat.rs @@ -3,11 +3,17 @@ use async_trait::async_trait; use tokio::sync::{mpsc, Mutex}; use uuid::Uuid; -use crate::bus::{ControlInbound, ControlMessage, ControlOutbound, InboundMessage, MessageBus, OutboundMessage}; +use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage}; +use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId}; use crate::protocol::{parse_inbound, WsInbound, WsOutbound}; use super::base::{Channel, ChannelError}; +/// Generate a short ID (8 characters) from a UUID +fn short_id() -> String { + Uuid::new_v4().to_string()[..8].to_string() +} + // ============================================================================ // Client - Connected CLI client // ============================================================================ @@ -36,6 +42,9 @@ impl CliChatChannel { /// Register a new client connection, returns (session_id, client) pub(crate) async fn register_client(&self, sender: mpsc::Sender) -> (String, Arc) { + // Generate connection ID (used as chat_id) - use short ID + let connection_id = short_id(); + let client = Arc::new(Client { sender, current_session_id: Mutex::new(None), @@ -43,11 +52,12 @@ impl CliChatChannel { self.clients.lock().await.push(client.clone()); // Create initial session via control message - let session_id = match self.create_session_via_control(None).await { + let session_id = match self.create_session_via_control(&connection_id, None).await { Ok(id) => id, Err(e) => { tracing::error!(error = %e, "Failed to create initial session"); - Uuid::new_v4().to_string() + // Fall back to old format for backward compatibility + connection_id.clone() } }; @@ -101,11 +111,11 @@ impl CliChatChannel { match inbound { WsInbound::UserInput { content, chat_id, .. } => { - let chat_id = chat_id.or(current_session_guard.clone()).unwrap_or_else(|| Uuid::new_v4().to_string()); + let chat_id = chat_id.or(current_session_guard.clone()).unwrap_or_else(short_id); // If no session, create one first if current_session_guard.is_none() { - let new_id = self.create_session_via_control(None).await?; + let new_id = self.create_session_via_control(&chat_id, None).await?; *current_session_guard = Some(new_id); } @@ -116,6 +126,7 @@ impl CliChatChannel { channel: self.name().to_string(), sender_id: "cli".to_string(), chat_id: session_id.clone(), + dialog_id: None, // Use default/current dialog content, timestamp: crate::bus::message::current_timestamp(), media: Vec::new(), @@ -131,13 +142,15 @@ impl CliChatChannel { .ok_or_else(|| ChannelError::Other("No active session".to_string()))?; let (reply_tx, mut reply_rx) = mpsc::channel(1); + let session_id = UnifiedSessionId::parse(&target) + .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; bus.publish_control(ControlMessage { - op: ControlInbound::ClearHistory { session_id: target.clone() }, + op: SessionCommand::ClearHistory { session_id }, reply_tx, }).await?; match reply_rx.recv().await { - Some(Ok(ControlOutbound::HistoryCleared { .. })) => { + Some(Ok(SessionEvent::HistoryCleared { .. })) => { let _ = client .sender .send(WsOutbound::HistoryCleared { session_id: target }) @@ -155,7 +168,10 @@ impl CliChatChannel { } } WsInbound::CreateSession { title } => { - let new_id = self.create_session_via_control(title.as_deref()).await?; + // Use current session's chat_id if available, otherwise generate new one + let chat_id = current_session_guard.clone() + .unwrap_or_else(short_id); + let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?; *current_session_guard = Some(new_id.clone()); let _ = client .sender @@ -166,19 +182,42 @@ impl CliChatChannel { .await; } WsInbound::ListSessions { include_archived } => { + // List dialogs for the current chat + let chat_id = current_session_guard.clone() + .unwrap_or_else(|| "".to_string()); + let chat_id_for_response = chat_id.clone(); let (reply_tx, mut reply_rx) = mpsc::channel(1); bus.publish_control(ControlMessage { - op: ControlInbound::ListSessions { include_archived }, + op: SessionCommand::ListDialogs { + channel: "cli_chat".to_string(), + chat_id, + include_archived, + }, reply_tx, }).await?; match reply_rx.recv().await { - Some(Ok(ControlOutbound::SessionList { sessions })) => { + Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => { + // Convert DialogInfo to SessionSummary for backward compatibility + let sessions: Vec = dialogs.into_iter().map(|d| { + crate::protocol::SessionSummary { + session_id: d.session_id.to_string(), + title: d.title, + channel_name: d.session_id.channel.clone(), + chat_id: d.session_id.chat_id, + message_count: d.message_count, + last_active_at: d.last_active_at, + archived_at: d.archived_at, + } + }).collect(); + let current_session_id = current_dialog_id.map(|did| { + UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string() + }); let _ = client .sender .send(WsOutbound::SessionList { sessions, - current_session_id: current_session_guard.clone(), + current_session_id, }) .await; } @@ -194,28 +233,44 @@ impl CliChatChannel { } } WsInbound::LoadSession { session_id } => { + // LoadSession: parse the session_id and get current dialog info let (reply_tx, mut reply_rx) = mpsc::channel(1); + let unified_id = UnifiedSessionId::parse(&session_id) + .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; bus.publish_control(ControlMessage { - op: ControlInbound::LoadSession { session_id: session_id.clone() }, + op: SessionCommand::GetCurrentDialog { + channel: unified_id.channel.clone(), + chat_id: unified_id.chat_id.clone(), + }, reply_tx, }).await?; match reply_rx.recv().await { - Some(Ok(ControlOutbound::SessionLoaded { session_id, title, message_count })) => { - *current_session_guard = Some(session_id.clone()); - let _ = client - .sender - .send(WsOutbound::SessionLoaded { - session_id, - title, - message_count, - }) - .await; + Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => { + if let Some(current_session_id) = current_session_id_opt { + *current_session_guard = Some(current_session_id.to_string()); + let _ = client + .sender + .send(WsOutbound::SessionLoaded { + session_id: current_session_id.to_string(), + title: "Session".to_string(), // TODO: get actual title + message_count: 0, // TODO: get actual count + }) + .await; + } else { + let _ = client + .sender + .send(WsOutbound::Error { + code: "NO_CURRENT_DIALOG".to_string(), + message: "No current dialog".to_string(), + }) + .await; + } } Some(Ok(_)) => { // Unexpected response type } - Some(Err(e)) => { + Some(Err(_e)) => { let _ = client .sender .send(WsOutbound::Error { @@ -235,16 +290,18 @@ impl CliChatChannel { })?; let (reply_tx, mut reply_rx) = mpsc::channel(1); + let unified_id = UnifiedSessionId::parse(&target) + .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; bus.publish_control(ControlMessage { - op: ControlInbound::RenameSession { session_id: target.clone(), title: title.clone() }, + op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() }, reply_tx, }).await?; match reply_rx.recv().await { - Some(Ok(ControlOutbound::SessionRenamed { session_id, title })) => { + Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => { let _ = client .sender - .send(WsOutbound::SessionRenamed { session_id, title }) + .send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title }) .await; } Some(Ok(_)) => { @@ -264,16 +321,18 @@ impl CliChatChannel { })?; let (reply_tx, mut reply_rx) = mpsc::channel(1); + let unified_id = UnifiedSessionId::parse(&target) + .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; bus.publish_control(ControlMessage { - op: ControlInbound::ArchiveSession { session_id: target.clone() }, + op: SessionCommand::ArchiveDialog { session_id: unified_id }, reply_tx, }).await?; match reply_rx.recv().await { - Some(Ok(ControlOutbound::SessionArchived { session_id })) => { + Some(Ok(SessionEvent::DialogArchived { session_id })) => { let _ = client .sender - .send(WsOutbound::SessionArchived { session_id }) + .send(WsOutbound::SessionArchived { session_id: session_id.to_string() }) .await; } Some(Ok(_)) => { @@ -293,22 +352,24 @@ impl CliChatChannel { })?; let (reply_tx, mut reply_rx) = mpsc::channel(1); + let unified_id = UnifiedSessionId::parse(&target) + .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; bus.publish_control(ControlMessage { - op: ControlInbound::DeleteSession { session_id: target.clone() }, + op: SessionCommand::DeleteDialog { session_id: unified_id }, reply_tx, }).await?; match reply_rx.recv().await { - Some(Ok(ControlOutbound::SessionDeleted { session_id })) => { + Some(Ok(SessionEvent::DialogDeleted { session_id })) => { let _ = client .sender - .send(WsOutbound::SessionDeleted { session_id: session_id.clone() }) + .send(WsOutbound::SessionDeleted { session_id: session_id.to_string() }) .await; // If deleting current session, create a new one if current_session_guard.as_deref() == Some(&target) { drop(reply_rx); - if let Ok(new_id) = self.create_session_via_control(None).await { + if let Ok(new_id) = self.create_session_via_control(&target, None).await { *current_session_guard = Some(new_id.clone()); let _ = client .sender @@ -339,7 +400,7 @@ impl CliChatChannel { } /// Create a session via control message and return the session_id - async fn create_session_via_control(&self, title: Option<&str>) -> Result { + async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result { let bus = { let guard = self.bus.lock().unwrap(); guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? @@ -347,13 +408,17 @@ impl CliChatChannel { let (reply_tx, mut reply_rx) = mpsc::channel(1); bus.publish_control(ControlMessage { - op: ControlInbound::CreateSession { title: title.map(String::from) }, + op: SessionCommand::CreateDialog { + channel: "cli_chat".to_string(), + chat_id: connection_id.to_string(), + title: title.map(String::from), + }, reply_tx, }).await?; match reply_rx.recv().await { - Some(Ok(ControlOutbound::SessionCreated { session_id, .. })) => { - Ok(session_id) + Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => { + Ok(session_id.to_string()) } Some(Ok(_)) => { Err(ChannelError::Other("Unexpected response type".to_string())) @@ -388,7 +453,7 @@ impl Channel for CliChatChannel { let clients = self.clients.lock().await.clone(); for client in clients { let outbound = WsOutbound::AssistantResponse { - id: Uuid::new_v4().to_string(), + id: short_id(), content: msg.content.clone(), role: "assistant".to_string(), }; diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs index a1fd8e5..b33fb25 100644 --- a/src/channels/feishu.rs +++ b/src/channels/feishu.rs @@ -1106,6 +1106,7 @@ impl FeishuChannel { channel: "feishu".to_string(), sender_id: parsed.open_id.clone(), chat_id: parsed.chat_id.clone(), + dialog_id: None, // Use default/current dialog content: parsed.content.clone(), timestamp: crate::bus::message::current_timestamp(), media: parsed.media.map(|m| vec![m]).unwrap_or_default(), diff --git a/src/client/channel.rs b/src/client/channel.rs deleted file mode 100644 index 56030e1..0000000 --- a/src/client/channel.rs +++ /dev/null @@ -1,50 +0,0 @@ -use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt}; - -pub struct CliChannel { - read: BufReader, - write: tokio::io::Stdout, -} - -impl CliChannel { - pub fn new() -> Self { - Self { - read: BufReader::new(tokio::io::stdin()), - write: tokio::io::stdout(), - } - } - - pub async fn read_line(&mut self, prompt: &str) -> Result, std::io::Error> { - print!("{}", prompt); - self.write.flush().await?; - - let mut line = String::new(); - let bytes_read = self.read.read_line(&mut line).await?; - - if bytes_read == 0 { - return Ok(None); - } - - Ok(Some(line.trim_end().to_string())) - } - - pub async fn write_line(&mut self, content: &str) -> Result<(), std::io::Error> { - self.write.write_all(content.as_bytes()).await?; - self.write.write_all(b"\n").await?; - self.write.flush().await - } - - pub async fn write_response(&mut self, content: &str) -> Result<(), std::io::Error> { - for line in content.lines() { - self.write.write_all(b" ").await?; - self.write.write_all(line.as_bytes()).await?; - self.write.write_all(b"\n").await?; - } - self.write.flush().await - } -} - -impl Default for CliChannel { - fn default() -> Self { - Self::new() - } -} diff --git a/src/client/input.rs b/src/client/input.rs deleted file mode 100644 index 6dd4aba..0000000 --- a/src/client/input.rs +++ /dev/null @@ -1,127 +0,0 @@ -use super::channel::CliChannel; - -pub enum InputEvent { - Message(String), - Command(InputCommand), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum InputCommand { - Exit, - Clear, - New(Option), - Sessions, - Use(String), - Rename(String), - Archive, - Delete, -} - -pub struct InputHandler { - channel: CliChannel, -} - -impl InputHandler { - pub fn new() -> Self { - Self { - channel: CliChannel::new(), - } - } - - pub async fn read_input(&mut self, prompt: &str) -> Result, InputError> { - match self.channel.read_line(prompt).await { - Ok(Some(line)) => { - if line.trim().is_empty() { - return Ok(None); - } - - if let Some(cmd) = self.handle_special_commands(&line) { - return Ok(Some(InputEvent::Command(cmd))); - } - - Ok(Some(InputEvent::Message(line))) - } - Ok(None) => Ok(None), - Err(e) => Err(InputError::IoError(e)), - } - } - - pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> { - self.channel.write_line(content).await.map_err(InputError::IoError) - } - - pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> { - self.channel.write_response(content).await.map_err(InputError::IoError) - } - - fn handle_special_commands(&self, line: &str) -> Option { - let trimmed = line.trim(); - let mut parts = trimmed.splitn(2, char::is_whitespace); - let command = parts.next()?; - let arg = parts.next().map(str::trim).filter(|value| !value.is_empty()); - - match command { - "/quit" | "/exit" | "/q" => Some(InputCommand::Exit), - "/clear" => Some(InputCommand::Clear), - "/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))), - "/sessions" => Some(InputCommand::Sessions), - "/use" => arg.map(|value| InputCommand::Use(value.to_string())), - "/rename" => arg.map(|value| InputCommand::Rename(value.to_string())), - "/archive" => Some(InputCommand::Archive), - "/delete" => Some(InputCommand::Delete), - _ => None, - } - } -} - -impl Default for InputHandler { - fn default() -> Self { - Self::new() - } -} - -#[derive(Debug)] -pub enum InputError { - IoError(std::io::Error), -} - -impl std::fmt::Display for InputError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - InputError::IoError(e) => write!(f, "IO error: {}", e), - } - } -} - -impl std::error::Error for InputError {} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_special_command_parsing() { - let handler = InputHandler::new(); - - assert_eq!(handler.handle_special_commands("/quit"), Some(InputCommand::Exit)); - assert_eq!(handler.handle_special_commands("/clear"), Some(InputCommand::Clear)); - assert_eq!(handler.handle_special_commands("/new"), Some(InputCommand::New(None))); - assert_eq!( - handler.handle_special_commands("/new planning"), - Some(InputCommand::New(Some("planning".to_string()))) - ); - assert_eq!(handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions)); - assert_eq!( - handler.handle_special_commands("/use abc123"), - Some(InputCommand::Use("abc123".to_string())) - ); - assert_eq!( - handler.handle_special_commands("/rename project alpha"), - Some(InputCommand::Rename("project alpha".to_string())) - ); - assert_eq!(handler.handle_special_commands("/archive"), Some(InputCommand::Archive)); - assert_eq!(handler.handle_special_commands("/delete"), Some(InputCommand::Delete)); - assert_eq!(handler.handle_special_commands("/unknown"), None); - assert_eq!(handler.handle_special_commands("/use"), None); - } -} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index d090602..f429482 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -5,12 +5,11 @@ use std::sync::Arc; use axum::{routing, Router}; use tokio::net::TcpListener; -use crate::bus::{ControlInbound, ControlMessage, ControlOutbound, OutboundDispatcher}; +use crate::bus::{ControlMessage, OutboundDispatcher}; use crate::channels::{ChannelManager, CliChatChannel}; use crate::channels::base::{Channel, ChannelError}; use crate::config::Config; use crate::logging; -use crate::protocol::SessionSummary; use crate::session::SessionManager; pub struct GatewayState { @@ -94,6 +93,7 @@ impl GatewayState { &inbound.channel, &inbound.sender_id, &inbound.chat_id, + inbound.dialog_id.as_deref(), &inbound.content, inbound.media, ).await { @@ -138,59 +138,52 @@ impl GatewayState { session_manager: &SessionManager, msg: ControlMessage, ) { + use crate::session::{SessionCommand::*, SessionEvent}; + let reply_tx = msg.reply_tx; - let result = match msg.op { - ControlInbound::CreateSession { title } => { - session_manager.create_cli_session(title.as_deref()) - .map(|record| ControlOutbound::SessionCreated { - session_id: record.id, - title: record.title, - }) + let result: Result = match msg.op { + CreateDialog { channel, chat_id, title } => { + session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await + .map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title }) + .map_err(|e| ChannelError::Other(e.to_string())) } - ControlInbound::ListSessions { include_archived } => { - session_manager.list_cli_sessions(include_archived) - .map(|records| ControlOutbound::SessionList { - sessions: records.into_iter().map(|r| SessionSummary { - session_id: r.id, - title: r.title, - channel_name: r.channel_name, - chat_id: r.chat_id, - message_count: r.message_count, - last_active_at: r.last_active_at, - archived_at: r.archived_at, - }).collect() - }) + ListDialogs { channel, chat_id, include_archived } => { + session_manager.list_dialogs(&channel, &chat_id, include_archived).await + .map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id }) + .map_err(|e| ChannelError::Other(e.to_string())) } - ControlInbound::LoadSession { session_id } => { - session_manager.get_session_record(&session_id) - .map(|opt| opt.map(|r| ControlOutbound::SessionLoaded { - session_id: r.id, - title: r.title, - message_count: r.message_count, - }).unwrap_or_else(|| ControlOutbound::Error { - code: "SESSION_NOT_FOUND".to_string(), - message: format!("Session not found: {}", session_id), - })) + GetCurrentDialog { channel, chat_id } => { + session_manager.get_current_dialog(&channel, &chat_id).await + .map(|session_id| SessionEvent::CurrentDialog { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())) } - ControlInbound::RenameSession { session_id, title } => { - session_manager.rename_session(&session_id, &title) - .map(|()| ControlOutbound::SessionRenamed { session_id, title }) + SwitchDialog { channel, chat_id, dialog_id } => { + session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await + .map(|session_id| SessionEvent::DialogSwitched { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())) } - ControlInbound::ArchiveSession { session_id } => { - session_manager.archive_session(&session_id) - .map(|()| ControlOutbound::SessionArchived { session_id }) + RenameDialog { session_id, title } => { + session_manager.rename_dialog(&session_id, &title) + .map(|()| SessionEvent::DialogRenamed { session_id, title }) + .map_err(|e| ChannelError::Other(e.to_string())) } - ControlInbound::DeleteSession { session_id } => { - session_manager.delete_session(&session_id) - .map(|()| ControlOutbound::SessionDeleted { session_id }) + ArchiveDialog { session_id } => { + session_manager.archive_dialog(&session_id) + .map(|()| SessionEvent::DialogArchived { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())) } - ControlInbound::ClearHistory { session_id } => { - session_manager.clear_session_messages(&session_id) - .map(|()| ControlOutbound::HistoryCleared { session_id }) + DeleteDialog { session_id } => { + session_manager.delete_dialog(&session_id) + .map(|()| SessionEvent::DialogDeleted { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())) + } + ClearHistory { session_id } => { + session_manager.clear_dialog_history(&session_id) + .map(|()| SessionEvent::HistoryCleared { session_id }) + .map_err(|e| ChannelError::Other(e.to_string())) } }; - let result = result.map_err(|e| ChannelError::Other(e.to_string())); let _ = reply_tx.send(result).await; } } diff --git a/src/session/commands.rs b/src/session/commands.rs index 29ca0e4..d7d9e81 100644 --- a/src/session/commands.rs +++ b/src/session/commands.rs @@ -1,11 +1,66 @@ +use super::session_id::UnifiedSessionId; + /// Session management commands issued by Channel to SessionManager #[derive(Debug, Clone)] pub enum SessionCommand { - CreateSession { title: Option }, - ListSessions { include_archived: bool }, - LoadSession { session_id: String }, - RenameSession { session_id: String, title: String }, - ArchiveSession { session_id: String }, - DeleteSession { session_id: String }, - ClearHistory { session_id: String }, + /// Create a new dialog in the given chat + CreateDialog { + channel: String, + chat_id: String, + title: Option, + }, + /// List all dialogs in a chat + ListDialogs { + channel: String, + chat_id: String, + include_archived: bool, + }, + /// Switch to a specific dialog (set as current) + SwitchDialog { + channel: String, + chat_id: String, + dialog_id: String, + }, + /// Get the current dialog for a chat + GetCurrentDialog { + channel: String, + chat_id: String, + }, + /// Rename a dialog + RenameDialog { + session_id: UnifiedSessionId, + title: String, + }, + /// Archive a dialog + ArchiveDialog { + session_id: UnifiedSessionId, + }, + /// Delete a dialog + DeleteDialog { + session_id: UnifiedSessionId, + }, + /// Clear dialog history + ClearHistory { + session_id: UnifiedSessionId, + }, +} + +impl SessionCommand { + /// Create a CreateDialog command + pub fn create_dialog(channel: impl Into, chat_id: impl Into, title: Option) -> Self { + Self::CreateDialog { + channel: channel.into(), + chat_id: chat_id.into(), + title, + } + } + + /// Create a ListDialogs command + pub fn list_dialogs(channel: impl Into, chat_id: impl Into, include_archived: bool) -> Self { + Self::ListDialogs { + channel: channel.into(), + chat_id: chat_id.into(), + include_archived, + } + } } diff --git a/src/session/events.rs b/src/session/events.rs index dff29e6..a086f52 100644 --- a/src/session/events.rs +++ b/src/session/events.rs @@ -1,14 +1,57 @@ -use crate::protocol::SessionSummary; +use super::session_id::UnifiedSessionId; + +/// Dialog information returned by SessionManager +#[derive(Debug, Clone)] +pub struct DialogInfo { + pub session_id: UnifiedSessionId, + pub title: String, + pub created_at: i64, + pub last_active_at: i64, + pub message_count: i64, + pub archived_at: Option, +} /// Session events emitted by SessionManager to Channel #[derive(Debug, Clone)] pub enum SessionEvent { - SessionCreated { session_id: String, title: String }, - SessionList { sessions: Vec }, - SessionLoaded { session_id: String, title: String, message_count: i64 }, - SessionRenamed { session_id: String, title: String }, - SessionArchived { session_id: String }, - SessionDeleted { session_id: String }, - HistoryCleared { session_id: String }, - Error { code: String, message: String }, + /// A new dialog was created + DialogCreated { + session_id: UnifiedSessionId, + title: String, + }, + /// List of dialogs returned + DialogList { + dialogs: Vec, + current_dialog_id: Option, + }, + /// Current dialog info returned + CurrentDialog { + session_id: Option, + }, + /// Dialog switched successfully + DialogSwitched { + session_id: UnifiedSessionId, + }, + /// Dialog renamed + DialogRenamed { + session_id: UnifiedSessionId, + title: String, + }, + /// Dialog archived + DialogArchived { + session_id: UnifiedSessionId, + }, + /// Dialog deleted + DialogDeleted { + session_id: UnifiedSessionId, + }, + /// Dialog history cleared + HistoryCleared { + session_id: UnifiedSessionId, + }, + /// Error occurred + Error { + code: String, + message: String, + }, } diff --git a/src/session/mod.rs b/src/session/mod.rs index b02a4ef..1871703 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -2,9 +2,10 @@ pub mod error; pub mod commands; pub mod events; pub mod session; +pub mod session_id; pub use error::SessionError; pub use commands::SessionCommand; -pub use events::SessionEvent; - +pub use events::{SessionEvent, DialogInfo}; pub use session::{Session, SessionManager}; +pub use session_id::UnifiedSessionId; diff --git a/src/session/session.rs b/src/session/session.rs index d76b81c..e0ed89c 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -8,19 +8,24 @@ use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::protocol::WsOutbound; use crate::providers::{create_provider, LLMProvider}; -use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; +use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID}; +use crate::session::events::DialogInfo; +use crate::storage::{SessionRecord, SessionStore}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, ToolRegistry, WebFetchTool, }; -/// Session 按 channel 隔离,每个 channel 一个 Session -/// History 按 chat_id 隔离,由 Session 统一管理 +/// Generate a short ID (8 characters) from a UUID +fn short_id() -> String { + Uuid::new_v4().to_string()[..8].to_string() +} + +/// Session = 一个 dialog +/// 每个 Session 对应一个 UnifiedSessionId,有独立的 messages history pub struct Session { - pub id: Uuid, - pub channel_name: String, - /// 按 chat_id 路由到不同会话历史,支持多用户多会话 - chat_histories: HashMap>, + pub id: UnifiedSessionId, + messages: Vec, pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, provider: Arc, @@ -31,7 +36,7 @@ pub struct Session { impl Session { pub async fn new( - channel_name: String, + id: UnifiedSessionId, provider_config: LLMProviderConfig, user_tx: mpsc::Sender, tools: Arc, @@ -42,9 +47,8 @@ impl Session { let provider: Arc = Arc::from(provider_box); Ok(Self { - id: Uuid::new_v4(), - channel_name, - chat_histories: HashMap::new(), + id, + messages: Vec::new(), user_tx, provider_config: provider_config.clone(), provider: provider.clone(), @@ -54,93 +58,83 @@ impl Session { }) } - pub fn persistent_session_id(&self, chat_id: &str) -> String { - persistent_session_id(&self.channel_name, chat_id) + /// 获取持久化 session ID + pub fn persistent_session_id(&self) -> String { + self.id.to_string() } - pub fn ensure_persistent_session(&self, chat_id: &str) -> Result { + /// 确保存储中有此 session + pub fn ensure_persistent_session(&self) -> Result { self.store - .ensure_channel_session(&self.channel_name, chat_id) + .ensure_channel_session(&self.id.channel, &self.id.chat_id, &self.id.dialog_id) .map_err(|err| AgentError::Other(format!("session persistence error: {}", err))) } - pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> { - if self.chat_histories.contains_key(chat_id) { + /// 加载历史消息到内存 + pub fn load_history(&mut self) -> Result<(), AgentError> { + if !self.messages.is_empty() { return Ok(()); } - - let history = self - .store - .load_messages(&self.persistent_session_id(chat_id)) + let history = self.store + .load_messages(&self.persistent_session_id()) .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?; - self.chat_histories.insert(chat_id.to_string(), history); + self.messages = history; Ok(()) } - /// 获取或创建指定 chat_id 的会话历史 - pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec { - self.chat_histories - .entry(chat_id.to_string()) - .or_insert_with(Vec::new) + /// 添加消息到历史 + pub fn add_message(&mut self, message: ChatMessage) { + self.messages.push(message); } - /// 获取指定 chat_id 的会话历史(不创建) - pub fn get_history(&self, chat_id: &str) -> Option<&Vec> { - self.chat_histories.get(chat_id) + /// 获取消息历史 + pub fn get_history(&self) -> &[ChatMessage] { + &self.messages } - /// 使用完整消息追加到历史 - pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) { - let history = self.get_or_create_history(chat_id); - history.push(message); - } - - pub fn remove_history(&mut self, chat_id: &str) { - self.chat_histories.remove(chat_id); - } - - pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> { - if let Some(history) = self.chat_histories.get_mut(chat_id) { - let len = history.len(); - history.clear(); - #[cfg(debug_assertions)] - tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared"); - } - + /// 清除历史消息 + pub fn clear_history(&mut self) -> Result<(), AgentError> { + let len = self.messages.len(); + self.messages.clear(); + #[cfg(debug_assertions)] + tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared"); self.store - .clear_messages(&self.persistent_session_id(chat_id)) - .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err))) + .clear_messages(&self.persistent_session_id()) + .map_err(|err| AgentError::Other(format!("clear history error: {}", err))) } - pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> { - if let Some(history) = self.chat_histories.get_mut(chat_id) { - let len = history.len(); - history.clear(); - #[cfg(debug_assertions)] - tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory"); - } - + /// 重置对话上下文 + pub fn reset_context(&mut self) -> Result<(), AgentError> { + let len = self.messages.len(); + self.messages.clear(); + #[cfg(debug_assertions)] + tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory"); self.store - .reset_session(&self.persistent_session_id(chat_id)) - .map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err))) + .reset_session(&self.persistent_session_id()) + .map_err(|err| AgentError::Other(format!("reset context error: {}", err))) } - /// 将消息写入内存与持久化层 - pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> { - let session_id = self.persistent_session_id(chat_id); + /// Archive 此 session + pub fn archive(&self) -> Result<(), AgentError> { self.store - .append_message(&session_id, &message) - .map_err(|err| AgentError::Other(format!("append message persistence error: {}", err)))?; - self.add_message(chat_id, message); - Ok(()) + .archive_session(&self.persistent_session_id()) + .map_err(|err| AgentError::Other(format!("archive session error: {}", err))) } - pub fn append_persisted_messages(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError> + /// 持久化消息 + pub fn append_message(&self, message: &ChatMessage) -> Result<(), AgentError> { + self.store + .append_message(&self.persistent_session_id(), message) + .map_err(|err| AgentError::Other(format!("append message error: {}", err))) + } + + /// 持久化多条消息 + pub fn append_messages(&self, messages: I) -> Result<(), AgentError> where I: IntoIterator, { for message in messages { - self.append_persisted_message(chat_id, message)?; + self.append_message(&message)?; } Ok(()) } @@ -153,23 +147,6 @@ impl Session { } } - /// 清除所有历史 - pub fn clear_all_history(&mut self) -> Result<(), AgentError> { - let chat_ids: Vec = self.chat_histories.keys().cloned().collect(); - let total: usize = self.chat_histories.values().map(|h| h.len()).sum(); - self.chat_histories.clear(); - #[cfg(debug_assertions)] - tracing::debug!(previous_total = total, "All chat histories cleared"); - - for chat_id in chat_ids { - self.store - .clear_messages(&self.persistent_session_id(&chat_id)) - .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?; - } - - Ok(()) - } - pub async fn send(&self, msg: WsOutbound) { let _ = self.user_tx.send(msg).await; } @@ -204,6 +181,7 @@ pub struct SessionManager { } struct SessionManagerInner { + /// Sessions keyed by UnifiedSessionId.to_string() sessions: HashMap>>, session_timestamps: HashMap, session_ttl: Duration, @@ -238,15 +216,19 @@ fn parse_in_chat_command(content: &str) -> Option { } } +/// Handle in-chat commands like /reset +/// Returns Some(new_dialog_id) if FreshConversation was triggered pub(crate) fn handle_in_chat_command( session: &mut Session, - chat_id: &str, content: &str, ) -> Result, AgentError> { match parse_in_chat_command(content) { Some(InChatCommand::FreshConversation) => { - session.reset_chat_context(chat_id)?; - Ok(Some("Started a fresh conversation.".to_string())) + // Archive the current session + session.archive()?; + + // Return new dialog_id to be created + Ok(Some(short_id())) } None => Ok(None), } @@ -327,73 +309,246 @@ impl SessionManager { .map_err(|err| AgentError::Other(format!("load messages error: {}", err))) } - /// 确保 session 存在且未超时,超时则重建 - pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { - let mut inner = self.inner.lock().await; + // ========================================================================= + // Dialog management methods (UnifiedSessionId based) + // ========================================================================= - let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) { - let elapsed = last_active.elapsed(); - if elapsed > inner.session_ttl { - tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating"); - true - } else { - false - } - } else { - #[cfg(debug_assertions)] - tracing::debug!(channel = %channel_name, "Creating new session"); - true - }; + /// Create a new session (dialog) and return (session_id, title) + pub async fn create_session( + &self, + channel: &str, + chat_id: &str, + title: Option<&str>, + ) -> Result<(UnifiedSessionId, String), AgentError> { + let dialog_id = short_id(); + let unified_id = UnifiedSessionId::new(channel, chat_id, &dialog_id); + let session_id_str = unified_id.to_string(); - if should_recreate { - // 移除旧 session - inner.sessions.remove(channel_name); + let title = title + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) + .unwrap_or_else(|| format!("Dialog {}", &dialog_id)); - // 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS) + // Ensure storage record exists + self.store + .ensure_channel_session(channel, chat_id, &dialog_id) + .map_err(|err| AgentError::Other(format!("create session error: {}", err)))?; + + // Create session instance + let (user_tx, _rx) = mpsc::channel::(100); + let session = Session::new( + unified_id.clone(), + self.provider_config.clone(), + user_tx, + self.tools.clone(), + self.store.clone(), + ) + .await?; + + let arc = Arc::new(Mutex::new(session)); + let inner = &mut *self.inner.lock().await; + inner.sessions.insert(session_id_str.clone(), arc.clone()); + inner.session_timestamps.insert(session_id_str, Instant::now()); + + Ok((unified_id, title)) + } + + /// Get or create a session by UnifiedSessionId + pub async fn get_or_create_session(&self, unified_id: &UnifiedSessionId) -> Result>, AgentError> { + let session_id_str = unified_id.to_string(); + let inner = &mut *self.inner.lock().await; + + // Check if session exists + if let Some(session) = inner.sessions.get(&session_id_str) { + // Update timestamp + inner.session_timestamps.insert(session_id_str, Instant::now()); + return Ok(session.clone()); + } + + // Check if session exists in storage + if let Ok(Some(_)) = self.store.get_session(&session_id_str) { + // Create session instance from storage let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( - channel_name.to_string(), + unified_id.clone(), self.provider_config.clone(), user_tx, self.tools.clone(), self.store.clone(), ) .await?; - let arc = Arc::new(Mutex::new(session)); - inner.sessions.insert(channel_name.to_string(), arc.clone()); - inner.session_timestamps.insert(channel_name.to_string(), Instant::now()); + let arc = Arc::new(Mutex::new(session)); + inner.sessions.insert(session_id_str.clone(), arc.clone()); + inner.session_timestamps.insert(session_id_str, Instant::now()); + return Ok(arc); } - Ok(()) + // Session doesn't exist - create new directly + let (user_tx, _rx) = mpsc::channel::(100); + let session = Session::new( + unified_id.clone(), + self.provider_config.clone(), + user_tx, + self.tools.clone(), + self.store.clone(), + ) + .await?; + + let arc = Arc::new(Mutex::new(session)); + inner.sessions.insert(session_id_str.clone(), arc.clone()); + inner.session_timestamps.insert(session_id_str, Instant::now()); + Ok(arc) } - /// 获取 session(不检查超时) - pub async fn get(&self, channel_name: &str) -> Option>> { - let inner = self.inner.lock().await; - inner.sessions.get(channel_name).cloned() + /// List all dialogs for a chat scope (internal) + async fn list_dialogs_for_chat( + &self, + channel: &str, + chat_id: &str, + include_archived: bool, + ) -> Result, AgentError> { + let records = self.store + .list_sessions(channel, include_archived) + .map_err(|err| AgentError::Other(format!("list dialogs error: {}", err)))?; + + let dialogs: Vec = records + .into_iter() + .filter(|r| { + // Filter to only dialogs for this chat_id + if let Some(sid) = UnifiedSessionId::parse(&r.id) { + sid.chat_id == chat_id + } else { + false + } + }) + .map(|r| { + let sid = UnifiedSessionId::parse(&r.id).unwrap(); + DialogInfo { + session_id: sid, + title: r.title, + created_at: r.created_at, + last_active_at: r.last_active_at, + message_count: r.message_count, + archived_at: r.archived_at, + } + }) + .collect(); + + Ok(dialogs) } - /// 更新最后活跃时间 - pub async fn touch(&self, channel_name: &str) { - let mut inner = self.inner.lock().await; - inner.session_timestamps.insert(channel_name.to_string(), Instant::now()); + /// Get the most recent dialog for a chat scope (from storage) + pub async fn get_most_recent_dialog( + &self, + channel: &str, + chat_id: &str, + ) -> Result, AgentError> { + let records = self.store + .list_sessions(channel, false) + .map_err(|err| AgentError::Other(format!("get recent dialog error: {}", err)))?; + + let most_recent = records + .into_iter() + .filter(|r| { + if let Some(sid) = UnifiedSessionId::parse(&r.id) { + sid.chat_id == chat_id + } else { + false + } + }) + .max_by_key(|r| r.last_active_at); + + Ok(most_recent.map(|r| UnifiedSessionId::parse(&r.id).unwrap())) + } + + /// Rename a dialog + pub fn rename_dialog(&self, session_id: &UnifiedSessionId, title: &str) -> Result<(), AgentError> { + self.store + .rename_session(&session_id.to_string(), title) + .map_err(|err| AgentError::Other(format!("rename dialog error: {}", err))) + } + + /// Create a new dialog (wrapper for create_session to match gateway interface) + pub async fn create_dialog( + &self, + channel: &str, + chat_id: &str, + title: Option<&str>, + ) -> Result<(UnifiedSessionId, String), AgentError> { + self.create_session(channel, chat_id, title).await + } + + /// Get current dialog for a chat (wrapper for get_most_recent_dialog) + pub async fn get_current_dialog( + &self, + channel: &str, + chat_id: &str, + ) -> Result, AgentError> { + self.get_most_recent_dialog(channel, chat_id).await + } + + /// Switch to a different dialog - not applicable in new architecture + /// Each Session IS a dialog, so switching is just loading that session + pub async fn switch_dialog( + &self, + _channel: &str, + _chat_id: &str, + _dialog_id: &str, + ) -> Result { + Err(AgentError::Other("switch_dialog not applicable in new architecture".to_string())) + } + + /// List all dialogs for a chat scope (returns tuple for gateway compatibility) + pub async fn list_dialogs( + &self, + channel: &str, + chat_id: &str, + include_archived: bool, + ) -> Result<(Vec, Option), AgentError> { + let dialogs = self.list_dialogs_for_chat(channel, chat_id, include_archived).await?; + let current = self.get_most_recent_dialog(channel, chat_id).await?; + Ok((dialogs, current.map(|id| id.to_string()))) + } + + /// Archive a dialog + pub fn archive_dialog(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> { + self.store + .archive_session(&session_id.to_string()) + .map_err(|err| AgentError::Other(format!("archive dialog error: {}", err))) + } + + /// Delete a dialog + pub fn delete_dialog(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> { + self.store + .delete_session(&session_id.to_string()) + .map_err(|err| AgentError::Other(format!("delete dialog error: {}", err))) + } + + /// Clear dialog history + pub fn clear_dialog_history(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> { + self.store + .clear_messages(&session_id.to_string()) + .map_err(|err| AgentError::Other(format!("clear dialog history error: {}", err))) } /// 处理消息:路由到对应 session 的 agent pub async fn handle_message( &self, - channel_name: &str, + channel: &str, _sender_id: &str, chat_id: &str, + dialog_id: Option<&str>, content: &str, media: Vec, ) -> Result { #[cfg(debug_assertions)] { tracing::debug!( - channel = %channel_name, + channel = %channel, chat_id = %chat_id, + dialog_id = ?dialog_id, content_len = content.len(), media_count = %media.len(), "Routing message to agent" @@ -403,28 +558,41 @@ impl SessionManager { } } - // 确保 session 存在(可能需要重建) - self.ensure_session(channel_name).await?; + // 确定 dialog_id + let dialog_id = dialog_id.unwrap_or(DEFAULT_DIALOG_ID); - // 更新活跃时间 - self.touch(channel_name).await; - - // 获取 session - let session = self - .get(channel_name) - .await - .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; + // 获取或创建 session + let unified_id = UnifiedSessionId::new(channel, chat_id, dialog_id); + let session = self.get_or_create_session(&unified_id).await?; // 处理消息 - let response = { + let response: String = { let mut session_guard = session.lock().await; - session_guard.ensure_persistent_session(chat_id)?; - session_guard.ensure_chat_loaded(chat_id)?; + // 检查是否是 FreshConversation 命令 + let fresh_conversation_result = handle_in_chat_command(&mut session_guard, content)?; - if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? { - return Ok(command_response); - } + let (session_to_use, fresh_started) = match fresh_conversation_result { + Some(_new_dialog_id) => { + // Archive the old session + session_guard.archive()?; + drop(session_guard); + + // Create new session for the new dialog + // This creates and registers the session + let (new_unified_id, _title) = self.create_session(channel, chat_id, None).await?; + // Get the newly created session + let new_session = self.get_or_create_session(&new_unified_id).await?; + (new_session, true) + } + None => (Arc::clone(&session), false), + }; + + // 使用选定的 session 进行处理 + let mut session_guard = session_to_use.lock().await; + + // 确保 session 持久化记录存在 + session_guard.ensure_persistent_session()?; // 添加用户消息到历史 let media_refs: Vec = media.iter().map(|m| m.path.clone()).collect(); @@ -432,13 +600,16 @@ impl SessionManager { if !media_refs.is_empty() { tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media"); } - let user_message = session_guard.create_user_message(content, media_refs); - session_guard.append_persisted_message(chat_id, user_message)?; - // 获取完整历史 - let history = session_guard.get_or_create_history(chat_id).clone(); + let user_message = session_guard.create_user_message(content, media_refs); + session_guard.add_message(user_message.clone()); + session_guard.append_message(&user_message)?; + + // 加载历史 + session_guard.load_history()?; // 压缩历史(如果需要) + let history = session_guard.get_history().to_vec(); let history = session_guard.compressor .compress_if_needed(history) .await?; @@ -447,29 +618,35 @@ impl SessionManager { let agent = session_guard.create_agent()?; let result = agent.process(history).await?; - // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 - session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; + // 持久化 assistant 消息 + for msg in &result.emitted_messages { + session_guard.append_message(msg)?; + } - result.final_response + // 如果是 FreshConversation 命令,返回命令消息 + if fresh_started { + "Starting a fresh conversation...".to_string() + } else { + result.final_response.content + } }; #[cfg(debug_assertions)] tracing::debug!( - channel = %channel_name, + channel = %channel, chat_id = %chat_id, - response_len = response.content.len(), + response_len = %response.len(), "Agent response received" ); - Ok(response.content) + Ok(response) } /// 清除指定 session 的所有历史 - pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> { - if let Some(session) = self.get(channel_name).await { - let mut session_guard = session.lock().await; - session_guard.clear_all_history()?; - } + pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> { + let session = self.get_or_create_session(unified_id).await?; + let mut session_guard = session.lock().await; + session_guard.clear_history()?; Ok(()) } } @@ -503,44 +680,4 @@ mod tests { assert_eq!(parse_in_chat_command("/new planning"), None); assert_eq!(parse_in_chat_command("please /reset"), None); } - - #[tokio::test] - async fn test_handle_in_chat_command_resets_active_history_only() { - let store = Arc::new(SessionStore::in_memory().unwrap()); - let (user_tx, _user_rx) = mpsc::channel(4); - let tools = Arc::new(default_tools()); - let mut session = Session::new( - "feishu".to_string(), - test_provider_config(), - user_tx, - tools, - store.clone(), - ) - .await - .unwrap(); - - session.ensure_persistent_session("chat-1").unwrap(); - session.ensure_chat_loaded("chat-1").unwrap(); - session - .append_persisted_message("chat-1", ChatMessage::user("hello")) - .unwrap(); - - let response = handle_in_chat_command(&mut session, "chat-1", "/reset") - .unwrap() - .unwrap(); - - assert_eq!(response, "Started a fresh conversation."); - assert!(session.get_history("chat-1").unwrap().is_empty()); - assert!(store - .load_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .is_empty()); - assert_eq!( - store - .load_all_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .len(), - 1, - ); - } } diff --git a/src/session/session_id.rs b/src/session/session_id.rs new file mode 100644 index 0000000..f88d22d --- /dev/null +++ b/src/session/session_id.rs @@ -0,0 +1,120 @@ +/// Unified session identifier composed of channel, chat_id, and dialog_id +/// +/// Format: `channel:chat_id:dialog_id` +/// +/// Examples: +/// - CLI: `"cli_chat:sid_abc123:dialog_xyz"` +/// - Feishu: `"feishu:oc_123456:dialog_xyz"` +/// +/// For simple cases where only one dialog exists per chat: +/// - `dialog_id` defaults to `"default"` + +use serde::{Deserialize, Serialize}; + +pub const DEFAULT_DIALOG_ID: &str = "default"; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct UnifiedSessionId { + pub channel: String, + pub chat_id: String, + pub dialog_id: String, +} + +impl UnifiedSessionId { + /// Create a new UnifiedSessionId + pub fn new(channel: impl Into, chat_id: impl Into, dialog_id: impl Into) -> Self { + Self { + channel: channel.into(), + chat_id: chat_id.into(), + dialog_id: dialog_id.into(), + } + } + + /// Create with default dialog_id ("default") + pub fn with_default_dialog(channel: impl Into, chat_id: impl Into) -> Self { + Self { + channel: channel.into(), + chat_id: chat_id.into(), + dialog_id: DEFAULT_DIALOG_ID.to_string(), + } + } + + /// Parse from string format "channel:chat_id:dialog_id" + pub fn parse(s: &str) -> Option { + let parts: Vec<&str> = s.split(':').collect(); + if parts.len() != 3 { + return None; + } + Some(Self { + channel: parts[0].to_string(), + chat_id: parts[1].to_string(), + dialog_id: parts[2].to_string(), + }) + } + + /// Convert to string format "channel:chat_id:dialog_id" + pub fn to_string(&self) -> String { + format!("{}:{}:{}", self.channel, self.chat_id, self.dialog_id) + } + + /// Get the session key without dialog_id (channel:chat_id) + /// This is used to group all dialogs within a chat + pub fn chat_scope(&self) -> String { + format!("{}:{}", self.channel, self.chat_id) + } +} + +impl std::fmt::Display for UnifiedSessionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_string()) + } +} + +// Note: No Deref implementation to avoid confusion between String and UnifiedSessionId + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new() { + let id = UnifiedSessionId::new("cli_chat", "sid123", "dialog456"); + assert_eq!(id.channel, "cli_chat"); + assert_eq!(id.chat_id, "sid123"); + assert_eq!(id.dialog_id, "dialog456"); + } + + #[test] + fn test_with_default_dialog() { + let id = UnifiedSessionId::with_default_dialog("feishu", "oc123"); + assert_eq!(id.channel, "feishu"); + assert_eq!(id.chat_id, "oc123"); + assert_eq!(id.dialog_id, "default"); + } + + #[test] + fn test_parse() { + let id = UnifiedSessionId::parse("cli_chat:sid123:dialog456").unwrap(); + assert_eq!(id.channel, "cli_chat"); + assert_eq!(id.chat_id, "sid123"); + assert_eq!(id.dialog_id, "dialog456"); + } + + #[test] + fn test_parse_invalid() { + assert!(UnifiedSessionId::parse("invalid").is_none()); + assert!(UnifiedSessionId::parse("only:two").is_none()); + } + + #[test] + fn test_to_string() { + let id = UnifiedSessionId::new("feishu", "oc123", "dialog789"); + assert_eq!(id.to_string(), "feishu:oc123:dialog789"); + } + + #[test] + fn test_chat_scope() { + let id = UnifiedSessionId::new("feishu", "oc123", "dialog789"); + assert_eq!(id.chat_scope(), "feishu:oc123"); + } +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 2fc86d5..4002220 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -140,8 +140,9 @@ impl SessionStore { &self, channel_name: &str, chat_id: &str, + dialog_id: &str, ) -> Result { - let session_id = persistent_session_id(channel_name, chat_id); + let session_id = persistent_session_id(channel_name, chat_id, dialog_id); if let Some(record) = self.get_session(&session_id)? { return Ok(record); } @@ -343,12 +344,8 @@ impl SessionStore { } } -pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String { - if channel_name == "cli" || channel_name == "cli_chat" { - chat_id.to_string() - } else { - format!("{}:{}", channel_name, chat_id) - } +pub fn persistent_session_id(channel_name: &str, chat_id: &str, dialog_id: &str) -> String { + format!("{}:{}:{}", channel_name, chat_id, dialog_id) } fn default_session_db_path() -> Result { @@ -474,9 +471,9 @@ mod tests { #[test] fn test_persistent_session_id_for_cli_and_channel() { - assert_eq!(persistent_session_id("cli", "abc"), "abc"); - assert_eq!(persistent_session_id("cli_chat", "abc"), "abc"); - assert_eq!(persistent_session_id("feishu", "abc"), "feishu:abc"); + assert_eq!(persistent_session_id("cli", "abc", "default"), "cli:abc:default"); + assert_eq!(persistent_session_id("cli_chat", "abc", "default"), "cli_chat:abc:default"); + assert_eq!(persistent_session_id("feishu", "abc", "default"), "feishu:abc:default"); } #[test] @@ -535,8 +532,8 @@ mod tests { fn test_ensure_channel_session_is_stable() { let store = SessionStore::in_memory().unwrap(); - let first = store.ensure_channel_session("feishu", "chat-1").unwrap(); - let second = store.ensure_channel_session("feishu", "chat-1").unwrap(); + let first = store.ensure_channel_session("feishu", "chat-1", "default").unwrap(); + let second = store.ensure_channel_session("feishu", "chat-1", "default").unwrap(); assert_eq!(first.id, second.id); assert_eq!(first.chat_id, "chat-1");