//! Interactive shell session management. //! //! Provides `ShellSessionManager`, an independent service that keeps child //! processes alive between tool calls so the Agent can interact with //! stdin-waiting prompts (e.g. `Read-Host`, `Confirm (Y/N)`). //! //! The manager is created at the gateway/bootstrap layer and injected into //! `BashTool` via `Arc`. It does NOT start background tasks on its own; //! cleanup is driven externally via `cleanup_expired()` or `shutdown()`. use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::process::{Child, ChildStdin}; use tokio::sync::{Mutex, mpsc}; use tokio::time::Instant; use uuid::Uuid; const SESSION_TIMEOUT_SECS: u64 = 300; // 5 minutes const OUTPUT_WAIT_MS: u64 = 2000; /// A single interactive shell session backed by a live child process. struct ShellSession { child: Child, stdin_writer: Option, stdout_buf: Arc>, stderr_buf: Arc>, /// Background task handle that drains the output channel into buffers. _drain_task: tokio::task::JoinHandle<()>, created_at: Instant, } /// Independent service for managing interactive shell sessions. /// /// Thread-safe — designed to be shared via `Arc`. pub struct ShellSessionManager { sessions: Mutex>, } impl ShellSessionManager { /// Create a new, empty session manager. pub fn new() -> Self { Self { sessions: Mutex::new(HashMap::new()), } } /// Save a child process as an interactive session. /// /// The caller provides: /// - `child`: the spawned child process (with piped stdin/stdout/stderr) /// - `stdin_writer`: the write-half of the piped stdin /// - `rx`: the output channel receiver (produced by `read_stream` tasks) /// - `initial_stdout` / `initial_stderr`: output already collected before /// the session was created /// /// Returns a unique `session_id` that can be used for subsequent /// `send_input` / `get_output` / `close_session` calls. pub async fn save_session( &self, mut child: Child, stdin_writer: ChildStdin, mut rx: mpsc::UnboundedReceiver<(bool, String)>, initial_stdout: String, initial_stderr: String, ) -> String { let session_id = Uuid::new_v4().to_string(); let stdout_buf = Arc::new(Mutex::new(initial_stdout)); let stderr_buf = Arc::new(Mutex::new(initial_stderr)); // Spawn a background task that drains the channel into buffers. let stdout_clone = stdout_buf.clone(); let stderr_clone = stderr_buf.clone(); let drain_task = tokio::spawn(async move { while let Some((is_stderr, chunk)) = rx.recv().await { if is_stderr { stderr_clone.lock().await.push_str(&chunk); } else { stdout_clone.lock().await.push_str(&chunk); } } }); // Kill the child's inherited stdin to prevent blocking on close. // The actual stdin writing is done via stdin_writer. let _ = child.stdin.take(); let session = ShellSession { child, stdin_writer: Some(stdin_writer), stdout_buf, stderr_buf, _drain_task: drain_task, created_at: Instant::now(), }; self.sessions .lock() .await .insert(session_id.clone(), session); session_id } /// Send input to a session's stdin and return new output. /// /// After writing, waits up to `OUTPUT_WAIT_MS` for new output to arrive. /// If the child process exits during the wait, returns the final output. pub async fn send_input(&self, session_id: &str, input: &str) -> Result { let mut sessions = self.sessions.lock().await; let session = sessions .get_mut(session_id) .ok_or_else(|| format!("Session not found: {}", session_id))?; // Write input to stdin if let Some(writer) = &mut session.stdin_writer { let data = if input.ends_with('\n') { input.to_string() } else { format!("{}\n", input) }; writer .write_all(data.as_bytes()) .await .map_err(|e| format!("Failed to write stdin: {}", e))?; writer .flush() .await .map_err(|e| format!("Failed to flush stdin: {}", e))?; } else { return Err("Session stdin is closed".to_string()); } // Record output length before wait let prev_stdout_len = session.stdout_buf.lock().await.len(); let prev_stderr_len = session.stderr_buf.lock().await.len(); // Wait for new output or process exit let deadline = Instant::now() + Duration::from_millis(OUTPUT_WAIT_MS); loop { tokio::select! { status = session.child.wait() => { // Process exited — collect final output let stdout = session.stdout_buf.lock().await.clone(); let stderr = session.stderr_buf.lock().await.clone(); let code = status.ok().and_then(|s| s.code()); drop(sessions); return Ok(Self::format_output(&stdout, &stderr, code)); } _ = tokio::time::sleep_until(deadline) => { // Timeout — return current output break; } } } let stdout = session.stdout_buf.lock().await.clone(); let stderr = session.stderr_buf.lock().await.clone(); let new_stdout: String = stdout.chars().skip(prev_stdout_len).collect(); let new_stderr: String = stderr.chars().skip(prev_stderr_len).collect(); let mut result = String::new(); if !new_stdout.is_empty() { result.push_str(&new_stdout); } if !new_stderr.trim().is_empty() { if !result.is_empty() { result.push('\n'); } result.push_str("STDERR:\n"); result.push_str(&new_stderr); } if result.is_empty() { result.push_str("(No new output after input. Session still active.)"); } Ok(result) } /// Get the full accumulated output of a session. pub async fn get_output(&self, session_id: &str) -> Result { let sessions = self.sessions.lock().await; let session = sessions .get(session_id) .ok_or_else(|| format!("Session not found: {}", session_id))?; let stdout = session.stdout_buf.lock().await.clone(); let stderr = session.stderr_buf.lock().await.clone(); Ok(Self::format_output(&stdout, &stderr, None)) } /// Close a session: kill the child process and return final output. pub async fn close_session(&self, session_id: &str) -> Result { let mut sessions = self.sessions.lock().await; let mut session = sessions .remove(session_id) .ok_or_else(|| format!("Session not found: {}", session_id))?; drop(sessions); // Release lock before awaiting // Close stdin to signal EOF if let Some(mut writer) = session.stdin_writer.take() { let _ = writer.shutdown().await; } // Kill and wait let _ = session.child.start_kill(); let status = session.child.wait().await.ok(); let code = status.and_then(|s| s.code()); let stdout = session.stdout_buf.lock().await.clone(); let stderr = session.stderr_buf.lock().await.clone(); // Abort drain task session._drain_task.abort(); Ok(Self::format_output(&stdout, &stderr, code)) } /// Remove and kill all expired sessions (older than `SESSION_TIMEOUT_SECS`). /// /// Designed to be called periodically from the owner (e.g. gateway). pub async fn cleanup_expired(&self) { let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS); let mut expired_ids = Vec::new(); { let sessions = self.sessions.lock().await; for (id, session) in sessions.iter() { if session.created_at.elapsed() > timeout { expired_ids.push(id.clone()); } } } for id in expired_ids { let _ = self.close_session(&id).await; } } /// Gracefully shut down all active sessions. pub async fn shutdown(&self) { let ids: Vec = { let sessions = self.sessions.lock().await; sessions.keys().cloned().collect() }; for id in ids { let _ = self.close_session(&id).await; } } /// Number of currently active sessions. pub async fn active_count(&self) -> usize { self.sessions.lock().await.len() } fn format_output(stdout: &str, stderr: &str, exit_code: Option) -> String { let mut output = String::new(); if !stdout.is_empty() { output.push_str(stdout); } if !stderr.trim().is_empty() { if !output.is_empty() { output.push('\n'); } output.push_str("STDERR:\n"); output.push_str(stderr); } if let Some(code) = exit_code { output.push_str(&format!("\nExit code: {}", code)); } output } } impl Default for ShellSessionManager { fn default() -> Self { Self::new() } }