- 新增ShellSessionManager管理交互式shell会话,支持进程保持和交互输入 - BashTool集成会话管理,支持session_id和stdin_input参数实现输入回复 - 修改BashTool执行逻辑,检测进程等待输入状态并保存会话状态 - Windows平台新增底层进程等待输入检测实现,辅助判断Shell交互状态 - 工具注册工厂注入ShellSessionManager,保证安全复用会话管理实例 - 增加默认agent prompt中Shell交互终端说明,提示交互流程及输入格式 - 交互式命令输出增加标识和提示,区分正常与等待输入状态 - 实现会话超时自动清理和优雅关闭接口,避免资源泄露 - 单元测试中统一使用BashTool默认构造,适配会话管理新增功能
284 lines
9.6 KiB
Rust
284 lines
9.6 KiB
Rust
//! Interactive shell session management.
//!
//! Provides `ShellSessionManager`, an independent service that keeps child
//! processes alive between tool calls so the Agent can answer prompts that
//! wait on stdin (e.g. `Read-Host`, `Confirm (Y/N)`).
//!
//! The manager is created at the gateway/bootstrap layer and injected into
//! `BashTool` via `Arc`. Apart from the per-session task that drains each
//! process's output into its buffers, it runs no background work of its own:
//! expiry and teardown are driven externally via `cleanup_expired()` or
//! `shutdown()`.
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio::process::{Child, ChildStdin};
|
|
use tokio::sync::{Mutex, mpsc};
|
|
use tokio::time::Instant;
|
|
|
|
use uuid::Uuid;
|
|
|
|
/// Age, in seconds, after which `cleanup_expired` closes a session (5 minutes).
const SESSION_TIMEOUT_SECS: u64 = 5 * 60;
/// How long, in milliseconds, `send_input` waits for the process to respond.
const OUTPUT_WAIT_MS: u64 = 2_000;
|
|
|
|
/// A single interactive shell session backed by a live child process.
|
|
struct ShellSession {
|
|
child: Child,
|
|
stdin_writer: Option<ChildStdin>,
|
|
stdout_buf: Arc<Mutex<String>>,
|
|
stderr_buf: Arc<Mutex<String>>,
|
|
/// Background task handle that drains the output channel into buffers.
|
|
_drain_task: tokio::task::JoinHandle<()>,
|
|
created_at: Instant,
|
|
}
|
|
|
|
/// Independent service for managing interactive shell sessions.
|
|
///
|
|
/// Thread-safe — designed to be shared via `Arc<ShellSessionManager>`.
|
|
pub struct ShellSessionManager {
|
|
sessions: Mutex<HashMap<String, ShellSession>>,
|
|
}
|
|
|
|
impl ShellSessionManager {
|
|
/// Create a new, empty session manager.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
sessions: Mutex::new(HashMap::new()),
|
|
}
|
|
}
|
|
|
|
/// Save a child process as an interactive session.
|
|
///
|
|
/// The caller provides:
|
|
/// - `child`: the spawned child process (with piped stdin/stdout/stderr)
|
|
/// - `stdin_writer`: the write-half of the piped stdin
|
|
/// - `rx`: the output channel receiver (produced by `read_stream` tasks)
|
|
/// - `initial_stdout` / `initial_stderr`: output already collected before
|
|
/// the session was created
|
|
///
|
|
/// Returns a unique `session_id` that can be used for subsequent
|
|
/// `send_input` / `get_output` / `close_session` calls.
|
|
pub async fn save_session(
|
|
&self,
|
|
mut child: Child,
|
|
stdin_writer: ChildStdin,
|
|
mut rx: mpsc::UnboundedReceiver<(bool, String)>,
|
|
initial_stdout: String,
|
|
initial_stderr: String,
|
|
) -> String {
|
|
let session_id = Uuid::new_v4().to_string();
|
|
let stdout_buf = Arc::new(Mutex::new(initial_stdout));
|
|
let stderr_buf = Arc::new(Mutex::new(initial_stderr));
|
|
|
|
// Spawn a background task that drains the channel into buffers.
|
|
let stdout_clone = stdout_buf.clone();
|
|
let stderr_clone = stderr_buf.clone();
|
|
let drain_task = tokio::spawn(async move {
|
|
while let Some((is_stderr, chunk)) = rx.recv().await {
|
|
if is_stderr {
|
|
stderr_clone.lock().await.push_str(&chunk);
|
|
} else {
|
|
stdout_clone.lock().await.push_str(&chunk);
|
|
}
|
|
}
|
|
});
|
|
|
|
// Kill the child's inherited stdin to prevent blocking on close.
|
|
// The actual stdin writing is done via stdin_writer.
|
|
let _ = child.stdin.take();
|
|
|
|
let session = ShellSession {
|
|
child,
|
|
stdin_writer: Some(stdin_writer),
|
|
stdout_buf,
|
|
stderr_buf,
|
|
_drain_task: drain_task,
|
|
created_at: Instant::now(),
|
|
};
|
|
|
|
self.sessions
|
|
.lock()
|
|
.await
|
|
.insert(session_id.clone(), session);
|
|
session_id
|
|
}
|
|
|
|
/// Send input to a session's stdin and return new output.
|
|
///
|
|
/// After writing, waits up to `OUTPUT_WAIT_MS` for new output to arrive.
|
|
/// If the child process exits during the wait, returns the final output.
|
|
pub async fn send_input(&self, session_id: &str, input: &str) -> Result<String, String> {
|
|
let mut sessions = self.sessions.lock().await;
|
|
|
|
let session = sessions
|
|
.get_mut(session_id)
|
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
|
|
|
// Write input to stdin
|
|
if let Some(writer) = &mut session.stdin_writer {
|
|
let data = if input.ends_with('\n') {
|
|
input.to_string()
|
|
} else {
|
|
format!("{}\n", input)
|
|
};
|
|
writer
|
|
.write_all(data.as_bytes())
|
|
.await
|
|
.map_err(|e| format!("Failed to write stdin: {}", e))?;
|
|
writer
|
|
.flush()
|
|
.await
|
|
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
|
|
} else {
|
|
return Err("Session stdin is closed".to_string());
|
|
}
|
|
|
|
// Record output length before wait
|
|
let prev_stdout_len = session.stdout_buf.lock().await.len();
|
|
let prev_stderr_len = session.stderr_buf.lock().await.len();
|
|
|
|
// Wait for new output or process exit
|
|
let deadline = Instant::now() + Duration::from_millis(OUTPUT_WAIT_MS);
|
|
loop {
|
|
tokio::select! {
|
|
status = session.child.wait() => {
|
|
// Process exited — collect final output
|
|
let stdout = session.stdout_buf.lock().await.clone();
|
|
let stderr = session.stderr_buf.lock().await.clone();
|
|
let code = status.ok().and_then(|s| s.code());
|
|
drop(sessions);
|
|
return Ok(Self::format_output(&stdout, &stderr, code));
|
|
}
|
|
_ = tokio::time::sleep_until(deadline) => {
|
|
// Timeout — return current output
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
let stdout = session.stdout_buf.lock().await.clone();
|
|
let stderr = session.stderr_buf.lock().await.clone();
|
|
|
|
let new_stdout: String = stdout.chars().skip(prev_stdout_len).collect();
|
|
let new_stderr: String = stderr.chars().skip(prev_stderr_len).collect();
|
|
|
|
let mut result = String::new();
|
|
if !new_stdout.is_empty() {
|
|
result.push_str(&new_stdout);
|
|
}
|
|
if !new_stderr.trim().is_empty() {
|
|
if !result.is_empty() {
|
|
result.push('\n');
|
|
}
|
|
result.push_str("STDERR:\n");
|
|
result.push_str(&new_stderr);
|
|
}
|
|
if result.is_empty() {
|
|
result.push_str("(No new output after input. Session still active.)");
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Get the full accumulated output of a session.
|
|
pub async fn get_output(&self, session_id: &str) -> Result<String, String> {
|
|
let sessions = self.sessions.lock().await;
|
|
let session = sessions
|
|
.get(session_id)
|
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
|
|
|
let stdout = session.stdout_buf.lock().await.clone();
|
|
let stderr = session.stderr_buf.lock().await.clone();
|
|
Ok(Self::format_output(&stdout, &stderr, None))
|
|
}
|
|
|
|
/// Close a session: kill the child process and return final output.
|
|
pub async fn close_session(&self, session_id: &str) -> Result<String, String> {
|
|
let mut sessions = self.sessions.lock().await;
|
|
let mut session = sessions
|
|
.remove(session_id)
|
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
|
drop(sessions); // Release lock before awaiting
|
|
|
|
// Close stdin to signal EOF
|
|
if let Some(mut writer) = session.stdin_writer.take() {
|
|
let _ = writer.shutdown().await;
|
|
}
|
|
|
|
// Kill and wait
|
|
let _ = session.child.start_kill();
|
|
let status = session.child.wait().await.ok();
|
|
let code = status.and_then(|s| s.code());
|
|
|
|
let stdout = session.stdout_buf.lock().await.clone();
|
|
let stderr = session.stderr_buf.lock().await.clone();
|
|
|
|
// Abort drain task
|
|
session._drain_task.abort();
|
|
|
|
Ok(Self::format_output(&stdout, &stderr, code))
|
|
}
|
|
|
|
/// Remove and kill all expired sessions (older than `SESSION_TIMEOUT_SECS`).
|
|
///
|
|
/// Designed to be called periodically from the owner (e.g. gateway).
|
|
pub async fn cleanup_expired(&self) {
|
|
let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS);
|
|
let mut expired_ids = Vec::new();
|
|
|
|
{
|
|
let sessions = self.sessions.lock().await;
|
|
for (id, session) in sessions.iter() {
|
|
if session.created_at.elapsed() > timeout {
|
|
expired_ids.push(id.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
for id in expired_ids {
|
|
let _ = self.close_session(&id).await;
|
|
}
|
|
}
|
|
|
|
/// Gracefully shut down all active sessions.
|
|
pub async fn shutdown(&self) {
|
|
let ids: Vec<String> = {
|
|
let sessions = self.sessions.lock().await;
|
|
sessions.keys().cloned().collect()
|
|
};
|
|
for id in ids {
|
|
let _ = self.close_session(&id).await;
|
|
}
|
|
}
|
|
|
|
/// Number of currently active sessions.
|
|
pub async fn active_count(&self) -> usize {
|
|
self.sessions.lock().await.len()
|
|
}
|
|
|
|
fn format_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
|
|
let mut output = String::new();
|
|
if !stdout.is_empty() {
|
|
output.push_str(stdout);
|
|
}
|
|
if !stderr.trim().is_empty() {
|
|
if !output.is_empty() {
|
|
output.push('\n');
|
|
}
|
|
output.push_str("STDERR:\n");
|
|
output.push_str(stderr);
|
|
}
|
|
if let Some(code) = exit_code {
|
|
output.push_str(&format!("\nExit code: {}", code));
|
|
}
|
|
output
|
|
}
|
|
}
|
|
|
|
impl Default for ShellSessionManager {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|