PicoBot/src/tools/shell_session.rs
ooodc 02172b6065 feat(shell): 实现交互式Shell会话管理
- 新增ShellSessionManager管理交互式shell会话,支持进程保持和交互输入
- BashTool集成会话管理,支持session_id和stdin_input参数实现输入回复
- 修改BashTool执行逻辑,检测进程等待输入状态并保存会话状态
- Windows平台新增底层进程等待输入检测实现,辅助判断Shell交互状态
- 工具注册工厂注入ShellSessionManager,保证安全复用会话管理实例
- 增加默认agent prompt中Shell交互终端说明,提示交互流程及输入格式
- 交互式命令输出增加标识和提示,区分正常与等待输入状态
- 实现会话超时自动清理和优雅关闭接口,避免资源泄露
- 单元测试中统一使用BashTool默认构造,适配会话管理新增功能
2026-06-13 09:06:45 +08:00

284 lines
9.6 KiB
Rust

//! Interactive shell session management.
//!
//! Provides `ShellSessionManager`, an independent service that keeps child
//! processes alive between tool calls so the Agent can interact with
//! stdin-waiting prompts (e.g. `Read-Host`, `Confirm (Y/N)`).
//!
//! The manager is created at the gateway/bootstrap layer and injected into
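//! `BashTool` via `Arc`. It does NOT start any background cleanup task on its
//! own (the only tasks it spawns are the per-session output drainers created
//! by `save_session()`); cleanup is driven externally via `cleanup_expired()`
//! or `shutdown()`.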
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::process::{Child, ChildStdin};
use tokio::sync::{Mutex, mpsc};
use tokio::time::Instant;
use uuid::Uuid;
/// Age, in seconds, after which `cleanup_expired` closes a session (5 minutes).
const SESSION_TIMEOUT_SECS: u64 = 5 * 60;
/// How long, in milliseconds, `send_input` waits for the child after writing input.
const OUTPUT_WAIT_MS: u64 = 2_000;
/// A single interactive shell session backed by a live child process.
///
/// NOTE(review): nothing in this module kills `child` when a session is
/// dropped without going through `close_session`; whether the process then
/// outlives the session depends on how it was spawned (e.g. tokio's
/// `kill_on_drop`) — confirm at the spawn site.
struct ShellSession {
    /// The child process; killed and reaped by `close_session`.
    child: Child,
    /// Write half of the child's piped stdin. `Some` while the session is
    /// open; taken and shut down by `close_session`.
    stdin_writer: Option<ChildStdin>,
    /// Accumulated stdout. Seeded with the output captured before the session
    /// was saved, then only appended to by `_drain_task`.
    stdout_buf: Arc<Mutex<String>>,
    /// Accumulated stderr; same lifecycle as `stdout_buf`.
    stderr_buf: Arc<Mutex<String>>,
    /// Background task handle that drains the output channel into buffers.
    /// Aborted by `close_session`.
    _drain_task: tokio::task::JoinHandle<()>,
    /// When the session was saved; `cleanup_expired` compares its age against
    /// `SESSION_TIMEOUT_SECS`.
    created_at: Instant,
}
/// Independent service for managing interactive shell sessions.
///
/// Thread-safe — designed to be shared via `Arc<ShellSessionManager>`.
/// All sessions live in one map guarded by a single async mutex, so
/// operations on different sessions are serialized against each other.
pub struct ShellSessionManager {
    /// Live sessions keyed by the UUID string returned from `save_session`.
    sessions: Mutex<HashMap<String, ShellSession>>,
}
impl ShellSessionManager {
/// Create a new, empty session manager.
pub fn new() -> Self {
Self {
sessions: Mutex::new(HashMap::new()),
}
}
/// Save a child process as an interactive session.
///
/// The caller provides:
/// - `child`: the spawned child process (with piped stdin/stdout/stderr)
/// - `stdin_writer`: the write-half of the piped stdin
/// - `rx`: the output channel receiver (produced by `read_stream` tasks)
/// - `initial_stdout` / `initial_stderr`: output already collected before
/// the session was created
///
/// Returns a unique `session_id` that can be used for subsequent
/// `send_input` / `get_output` / `close_session` calls.
pub async fn save_session(
&self,
mut child: Child,
stdin_writer: ChildStdin,
mut rx: mpsc::UnboundedReceiver<(bool, String)>,
initial_stdout: String,
initial_stderr: String,
) -> String {
let session_id = Uuid::new_v4().to_string();
let stdout_buf = Arc::new(Mutex::new(initial_stdout));
let stderr_buf = Arc::new(Mutex::new(initial_stderr));
// Spawn a background task that drains the channel into buffers.
let stdout_clone = stdout_buf.clone();
let stderr_clone = stderr_buf.clone();
let drain_task = tokio::spawn(async move {
while let Some((is_stderr, chunk)) = rx.recv().await {
if is_stderr {
stderr_clone.lock().await.push_str(&chunk);
} else {
stdout_clone.lock().await.push_str(&chunk);
}
}
});
// Kill the child's inherited stdin to prevent blocking on close.
// The actual stdin writing is done via stdin_writer.
let _ = child.stdin.take();
let session = ShellSession {
child,
stdin_writer: Some(stdin_writer),
stdout_buf,
stderr_buf,
_drain_task: drain_task,
created_at: Instant::now(),
};
self.sessions
.lock()
.await
.insert(session_id.clone(), session);
session_id
}
/// Send input to a session's stdin and return new output.
///
/// After writing, waits up to `OUTPUT_WAIT_MS` for new output to arrive.
/// If the child process exits during the wait, returns the final output.
pub async fn send_input(&self, session_id: &str, input: &str) -> Result<String, String> {
let mut sessions = self.sessions.lock().await;
let session = sessions
.get_mut(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
// Write input to stdin
if let Some(writer) = &mut session.stdin_writer {
let data = if input.ends_with('\n') {
input.to_string()
} else {
format!("{}\n", input)
};
writer
.write_all(data.as_bytes())
.await
.map_err(|e| format!("Failed to write stdin: {}", e))?;
writer
.flush()
.await
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
} else {
return Err("Session stdin is closed".to_string());
}
// Record output length before wait
let prev_stdout_len = session.stdout_buf.lock().await.len();
let prev_stderr_len = session.stderr_buf.lock().await.len();
// Wait for new output or process exit
let deadline = Instant::now() + Duration::from_millis(OUTPUT_WAIT_MS);
loop {
tokio::select! {
status = session.child.wait() => {
// Process exited — collect final output
let stdout = session.stdout_buf.lock().await.clone();
let stderr = session.stderr_buf.lock().await.clone();
let code = status.ok().and_then(|s| s.code());
drop(sessions);
return Ok(Self::format_output(&stdout, &stderr, code));
}
_ = tokio::time::sleep_until(deadline) => {
// Timeout — return current output
break;
}
}
}
let stdout = session.stdout_buf.lock().await.clone();
let stderr = session.stderr_buf.lock().await.clone();
let new_stdout: String = stdout.chars().skip(prev_stdout_len).collect();
let new_stderr: String = stderr.chars().skip(prev_stderr_len).collect();
let mut result = String::new();
if !new_stdout.is_empty() {
result.push_str(&new_stdout);
}
if !new_stderr.trim().is_empty() {
if !result.is_empty() {
result.push('\n');
}
result.push_str("STDERR:\n");
result.push_str(&new_stderr);
}
if result.is_empty() {
result.push_str("(No new output after input. Session still active.)");
}
Ok(result)
}
/// Get the full accumulated output of a session.
pub async fn get_output(&self, session_id: &str) -> Result<String, String> {
let sessions = self.sessions.lock().await;
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let stdout = session.stdout_buf.lock().await.clone();
let stderr = session.stderr_buf.lock().await.clone();
Ok(Self::format_output(&stdout, &stderr, None))
}
/// Close a session: kill the child process and return final output.
pub async fn close_session(&self, session_id: &str) -> Result<String, String> {
let mut sessions = self.sessions.lock().await;
let mut session = sessions
.remove(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
drop(sessions); // Release lock before awaiting
// Close stdin to signal EOF
if let Some(mut writer) = session.stdin_writer.take() {
let _ = writer.shutdown().await;
}
// Kill and wait
let _ = session.child.start_kill();
let status = session.child.wait().await.ok();
let code = status.and_then(|s| s.code());
let stdout = session.stdout_buf.lock().await.clone();
let stderr = session.stderr_buf.lock().await.clone();
// Abort drain task
session._drain_task.abort();
Ok(Self::format_output(&stdout, &stderr, code))
}
/// Remove and kill all expired sessions (older than `SESSION_TIMEOUT_SECS`).
///
/// Designed to be called periodically from the owner (e.g. gateway).
pub async fn cleanup_expired(&self) {
let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS);
let mut expired_ids = Vec::new();
{
let sessions = self.sessions.lock().await;
for (id, session) in sessions.iter() {
if session.created_at.elapsed() > timeout {
expired_ids.push(id.clone());
}
}
}
for id in expired_ids {
let _ = self.close_session(&id).await;
}
}
/// Gracefully shut down all active sessions.
pub async fn shutdown(&self) {
let ids: Vec<String> = {
let sessions = self.sessions.lock().await;
sessions.keys().cloned().collect()
};
for id in ids {
let _ = self.close_session(&id).await;
}
}
/// Number of currently active sessions.
pub async fn active_count(&self) -> usize {
self.sessions.lock().await.len()
}
fn format_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
let mut output = String::new();
if !stdout.is_empty() {
output.push_str(stdout);
}
if !stderr.trim().is_empty() {
if !output.is_empty() {
output.push('\n');
}
output.push_str("STDERR:\n");
output.push_str(stderr);
}
if let Some(code) = exit_code {
output.push_str(&format!("\nExit code: {}", code));
}
output
}
}
impl Default for ShellSessionManager {
fn default() -> Self {
Self::new()
}
}