From 02172b606568186bce6df8e12f432e75a9d866dc Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Sat, 13 Jun 2026 09:06:45 +0800 Subject: [PATCH] =?UTF-8?q?feat(shell):=20=E5=AE=9E=E7=8E=B0=E4=BA=A4?= =?UTF-8?q?=E4=BA=92=E5=BC=8FShell=E4=BC=9A=E8=AF=9D=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增ShellSessionManager管理交互式shell会话,支持进程保持和交互输入 - BashTool集成会话管理,支持session_id和stdin_input参数实现输入回复 - 修改BashTool执行逻辑,检测进程等待输入状态并保存会话状态 - Windows平台新增底层进程等待输入检测实现,辅助判断Shell交互状态 - 工具注册工厂注入ShellSessionManager,保证安全复用会话管理实例 - 增加默认agent prompt中Shell交互终端说明,提示交互流程及输入格式 - 交互式命令输出增加标识和提示,区分正常与等待输入状态 - 实现会话超时自动清理和优雅关闭接口,避免资源泄露 - 单元测试中统一使用BashTool默认构造,适配会话管理新增功能 --- Cargo.toml | 8 + src/gateway/default_agent_prompt.md | 7 + src/gateway/tool_registry_factory.rs | 15 +- src/platform/mod.rs | 156 +++++++++++++-- src/tools/bash.rs | 186 ++++++++++++++---- src/tools/mod.rs | 2 + src/tools/shell_session.rs | 283 +++++++++++++++++++++++++++ 7 files changed, 602 insertions(+), 55 deletions(-) create mode 100644 src/tools/shell_session.rs diff --git a/Cargo.toml b/Cargo.toml index c25fb28..30c1e23 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,14 @@ rustls = { version = "0.23", features = ["ring"] } wechatbot = { path = "vendor/wechatbot" } encoding_rs = "0.8" libc = "0.2" + +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.59", features = [ + "Win32_System_Threading", + "Win32_System_Diagnostics_Debug", + "Win32_Foundation", + "Win32_System_Kernel", +] } # MCP (Model Context Protocol) support rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [ "client", diff --git a/src/gateway/default_agent_prompt.md b/src/gateway/default_agent_prompt.md index 141c383..9856d9a 100644 --- a/src/gateway/default_agent_prompt.md +++ b/src/gateway/default_agent_prompt.md @@ -112,6 +112,13 @@ - 默认创建静默任务(silent_agent_task),在独立后台会话中执行,不干扰主对话 - 静默模式下如需发送消息给用户,prompt中需显式使用 
send_session_message 工具 +## Shell 交互终端 + +- 当 shell 工具返回包含 `__PICOBOT_PENDING_USER_ACTION__` 和 `[session_id: xxx]` 的结果时,表示进程正在等待输入 +- 阅读已输出的内容,理解提示含义(如确认提示 Y/N、输入密码、选择选项等) +- 使用 `session_id` 和 `stdin_input` 参数回复交互内容,例如:`{"command": "echo test", "session_id": "xxx", "stdin_input": "Y"}` +- 常见场景:确认提示输入 Y/N、输入密码/验证码、选择选项、Read-Host 等 + ## todo工具使用规范 - 严格按照既定的未完成的todo工作项执行任务,如果工作项不在适用就更新,不得随意遗漏工作项 diff --git a/src/gateway/tool_registry_factory.rs b/src/gateway/tool_registry_factory.rs index 520928a..6c1e026 100644 --- a/src/gateway/tool_registry_factory.rs +++ b/src/gateway/tool_registry_factory.rs @@ -11,8 +11,8 @@ use crate::tools::todo_write::TodoItem; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, MemoryManageTool, MemorySearchTool, - SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, - SkillManageTool, SubAgentRuntime, TaskTool, TimeTool, + SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager, + SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool, TodoWriteTool, ToolRegistry, WebFetchTool, }; @@ -29,6 +29,7 @@ pub(crate) struct ToolRegistryFactory { subagent_runtime: Option>, mcp_manager: Option>, todo_state: Option>>>>, + shell_session_manager: Arc, } impl ToolRegistryFactory { @@ -56,6 +57,7 @@ impl ToolRegistryFactory { subagent_runtime: None, mcp_manager: None, todo_state: None, + shell_session_manager: Arc::new(ShellSessionManager::new()), } } @@ -87,6 +89,11 @@ impl ToolRegistryFactory { !self.disabled_tools.contains(tool_name) } + /// Get a reference to the shell session manager for lifecycle control. 
+ pub(crate) fn shell_session_manager(&self) -> Arc { + self.shell_session_manager.clone() + } + pub(crate) fn build(&self) -> ToolRegistry { let mut registry = ToolRegistry::new(); @@ -135,7 +142,7 @@ impl ToolRegistryFactory { registry.register(SkillManageTool::new(self.skills.clone())); } if self.is_enabled("bash") { - registry.register(BashTool::new()); + registry.register(BashTool::new(self.shell_session_manager.clone())); } if self.is_enabled("http_request") { registry.register(HttpRequestTool::new( @@ -184,7 +191,7 @@ impl ToolRegistryFactory { registry.register(FileEditTool::new()); } if self.is_enabled("bash") { - registry.register(BashTool::new()); + registry.register(BashTool::new(self.shell_session_manager.clone())); } if self.is_enabled("http_request") { registry.register(HttpRequestTool::new( diff --git a/src/platform/mod.rs b/src/platform/mod.rs index 535316a..be72978 100644 --- a/src/platform/mod.rs +++ b/src/platform/mod.rs @@ -156,18 +156,7 @@ pub fn is_process_waiting_on_stdin(pid: u32) -> Option { } #[cfg(target_os = "windows")] { - // Windows: a full implementation would use: - // 1. OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, pid) - // 2. NtQuerySystemInformation(SystemProcessInformation) to enumerate - // threads and check if the main thread's WaitReason is Executive - // (indicating a wait on a kernel handle, e.g. console input). - // 3. CloseHandle(hProcess) - // - // This requires the undocumented NtQuerySystemInformation API from - // ntdll.dll. Until that is in place, the keyword-matching and - // periodic output-staleness checks in bash.rs handle Windows detection. - let _ = pid; - None + windows_is_process_waiting_on_stdin(pid) } #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))] { @@ -176,6 +165,149 @@ pub fn is_process_waiting_on_stdin(pid: u32) -> Option { } } +/// Windows implementation: check if a process is waiting for stdin input. 
///
/// Takes a system-wide process snapshot with `NtQuerySystemInformation`
/// (`SystemProcessInformation`), locates the entry whose process id equals
/// `pid`, and inspects the thread records stored after that entry.
///
/// Returns:
/// - `Some(true)`  when the process has at least one thread and every thread is
///   in the `Waiting` state with wait reason `Executive`;
/// - `Some(false)` when the process was found but that condition does not hold
///   (including a process that reports zero threads);
/// - `None` when the snapshot could not be taken, looked truncated, or did not
///   contain `pid`.
///
/// NOTE(review): "all threads in an Executive wait" is a heuristic for "blocked
/// on stdin". Runtimes that keep helper threads parked with another wait reason
/// will report `Some(false)` even while their main thread is blocked on input —
/// confirm against the shells this tool actually launches.
#[cfg(target_os = "windows")]
fn windows_is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
    use std::ffi::c_void;
    use std::mem::{offset_of, size_of};
    use std::ptr::addr_of;

    /// `SYSTEM_INFORMATION_CLASS::SystemProcessInformation`.
    const SYSTEM_PROCESS_INFORMATION: u32 = 5;
    /// NTSTATUS returned when the supplied buffer is too small.
    const STATUS_INFO_LENGTH_MISMATCH: i32 = 0xC000_0004_u32 as i32;
    /// `KTHREAD_STATE::Waiting`.
    const THREAD_STATE_WAITING: u32 = 5;
    /// `KWAIT_REASON::Executive`.
    const WAIT_REASON_EXECUTIVE: u32 = 0;
    /// Smallest buffer ever requested, in bytes.
    const MIN_BUFFER_BYTES: usize = 64 * 1024;
    /// How many times the query is retried when the snapshot keeps growing.
    const MAX_ATTEMPTS: usize = 4;

    // Only a few fields are read; the rest exist to reproduce the native layout.
    #[repr(C)]
    #[allow(dead_code)]
    struct SystemProcessInfo {
        next_entry_offset: u32,
        number_of_threads: u32,
        working_set_private_size: i64,
        hard_fault_count: u32,
        number_of_threads_high_watermark: u32,
        cycle_time: u64,
        create_time: i64,
        user_time: i64,
        kernel_time: i64,
        image_name_length: u16,
        image_name_max_length: u16,
        image_name: *const u16,
        base_priority: i32,
        unique_process_id: *mut c_void,
        inherited_from_unique_process_id: *mut c_void,
        handle_count: u32,
        session_id: u32,
        unique_process_key: usize,
        peak_virtual_size: usize,
        virtual_size: usize,
        page_fault_count: u32,
        peak_working_set_size: usize,
        working_set_size: usize,
        quota_peak_paged_pool_usage: usize,
        quota_paged_pool_usage: usize,
        quota_peak_non_paged_pool_usage: usize,
        quota_non_paged_pool_usage: usize,
        pagefile_usage: usize,
        peak_pagefile_usage: usize,
        private_page_count: usize,
        read_operation_count: i64,
        write_operation_count: i64,
        other_operation_count: i64,
        read_transfer_count: i64,
        write_transfer_count: i64,
        other_transfer_count: i64,
        // `number_of_threads` SYSTEM_THREAD_INFORMATION records follow in memory.
        threads: [SystemThreadInfo; 1],
    }

    #[repr(C)]
    #[derive(Clone, Copy)]
    #[allow(dead_code)]
    struct SystemThreadInfo {
        kernel_time: i64,
        user_time: i64,
        create_time: i64,
        wait_time: u32,
        start_address: *mut c_void,
        client_id_unique_process: *mut c_void,
        client_id_unique_thread: *mut c_void,
        priority: i32,
        base_priority: i32,
        context_switches: u32,
        thread_state: u32,
        wait_reason: u32,
    }

    #[link(name = "ntdll")]
    #[allow(non_snake_case)]
    unsafe extern "system" {
        fn NtQuerySystemInformation(
            system_information_class: u32,
            system_information: *mut u8,
            system_information_length: u32,
            return_length: *mut u32,
        ) -> i32;
    }

    // Ask how large the snapshot currently is.
    let mut needed: u32 = 0;
    // SAFETY: a null buffer of length 0 is valid for a size query, and
    // `return_length` points at a live `u32`.
    let status = unsafe {
        NtQuerySystemInformation(
            SYSTEM_PROCESS_INFORMATION,
            std::ptr::null_mut(),
            0,
            &mut needed,
        )
    };
    if status != STATUS_INFO_LENGTH_MISMATCH || needed == 0 {
        return None;
    }

    // Over-allocate: processes and threads may appear between the two calls.
    // If the snapshot still outgrows the buffer, retry with a larger one.
    let mut capacity = (needed as usize).saturating_mul(2).max(MIN_BUFFER_BYTES);
    let mut attempts = 0usize;
    let snapshot: Vec<u8> = loop {
        let byte_len = u32::try_from(capacity).ok()?;
        let mut buffer = vec![0u8; capacity];
        let mut returned: u32 = 0;
        // SAFETY: `buffer` is writable for exactly `byte_len` bytes and
        // `returned` is a live `u32`.
        let status = unsafe {
            NtQuerySystemInformation(
                SYSTEM_PROCESS_INFORMATION,
                buffer.as_mut_ptr(),
                byte_len,
                &mut returned,
            )
        };
        if status >= 0 {
            break buffer;
        }
        attempts += 1;
        if status != STATUS_INFO_LENGTH_MISMATCH || attempts >= MAX_ATTEMPTS {
            return None;
        }
        capacity = capacity.max(returned as usize).saturating_mul(2);
    };

    let base = snapshot.as_ptr();
    let total = snapshot.len();
    // Size of the fixed part of an entry; thread records start right after it.
    let header_len = offset_of!(SystemProcessInfo, threads);

    // Walk the singly linked list of process entries. Every access is bounds
    // checked against `total` and performed with unaligned reads, because a
    // `Vec<u8>` gives no alignment guarantee for the native structures.
    let mut offset = 0usize;
    loop {
        let header_end = offset.checked_add(header_len)?;
        if header_end > total {
            return None;
        }

        // SAFETY: `offset <= header_end <= total`, so the pointer stays inside
        // the allocation.
        let entry = unsafe { base.add(offset) }.cast::<SystemProcessInfo>();
        // SAFETY: the fixed-size header `[offset, header_end)` lies inside the
        // buffer; `addr_of!` creates no reference and `read_unaligned` needs
        // no alignment.
        let (next, thread_count, owner) = unsafe {
            (
                addr_of!((*entry).next_entry_offset).read_unaligned(),
                addr_of!((*entry).number_of_threads).read_unaligned(),
                addr_of!((*entry).unique_process_id).read_unaligned(),
            )
        };

        // Process ids are carried in a HANDLE-sized field.
        if owner as usize == pid as usize {
            let thread_count = thread_count as usize;
            if thread_count == 0 {
                return Some(false);
            }

            let threads_end = thread_count
                .checked_mul(size_of::<SystemThreadInfo>())
                .and_then(|bytes| bytes.checked_add(header_end))?;
            if threads_end > total {
                return None;
            }

            // SAFETY: `header_end <= total`, checked above.
            let first_thread = unsafe { base.add(header_end) }.cast::<SystemThreadInfo>();
            let all_blocked = (0..thread_count).all(|index| {
                // SAFETY: record `index` ends at or before `threads_end`, which
                // was verified to lie inside the buffer.
                let thread = unsafe { first_thread.add(index).read_unaligned() };
                thread.thread_state == THREAD_STATE_WAITING
                    && thread.wait_reason == WAIT_REASON_EXECUTIVE
            });
            return Some(all_blocked);
        }

        if next == 0 {
            // End of list: `pid` is not in the snapshot.
            return None;
        }
        offset = offset.checked_add(next as usize)?;
    }
}
/// /// Supports environment variable overrides for testing: diff --git a/src/tools/bash.rs b/src/tools/bash.rs index c307a55..a4f5f44 100644 --- a/src/tools/bash.rs +++ b/src/tools/bash.rs @@ -11,13 +11,16 @@ use tokio::sync::{Mutex, mpsc}; use tokio::time::{Instant, sleep_until}; use crate::platform::{ShellInfo, dangerous_command_patterns}; +use crate::tools::shell_session::ShellSessionManager; use crate::tools::traits::{Tool, ToolResult}; use crate::tools::{extract_u64, extract_bool, check_null_args}; const MAX_TIMEOUT_SECS: u64 = 600; const MAX_OUTPUT_CHARS: usize = 50_000; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; -const USER_ACTION_HINT: &str = +const INTERACTIVE_HINT: &str = + "进程正在等待输入。请使用 session_id 和 stdin_input 参数回复交互内容。"; +const NON_INTERACTIVE_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。"; /// Shell 类型枚举,支持跨平台 @@ -104,15 +107,17 @@ pub struct BashTool { working_dir: Option, deny_patterns: Vec, shell: ShellKind, + session_manager: Arc, } impl BashTool { - pub fn new() -> Self { + pub fn new(session_manager: Arc) -> Self { Self { timeout_secs: 60, working_dir: None, deny_patterns: dangerous_command_patterns(), shell: ShellKind::detect(), + session_manager, } } @@ -168,11 +173,20 @@ impl BashTool { ) } - fn pending_output(&self, output: &str) -> String { + fn pending_output(&self, output: &str, session_id: Option<&str>) -> String { + let hint = if session_id.is_some() { + INTERACTIVE_HINT + } else { + NON_INTERACTIVE_HINT + }; + let session_line = session_id + .map(|id| format!("[session_id: {}]\n", id)) + .unwrap_or_default(); format!( - "{}\n{}\n\n{}", + "{}\n{}{}\n\n{}", PENDING_USER_ACTION_MARKER, - USER_ACTION_HINT, + session_line, + hint, self.truncate_output(output.trim()) ) } @@ -248,7 +262,7 @@ async fn drain_available_chunks( impl Default for BashTool { fn default() -> Self { - Self::new() + Self::new(Arc::new(ShellSessionManager::new())) } } @@ -279,6 +293,14 @@ impl Tool for BashTool { "interactive": { 
"type": "boolean", "description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication" + }, + "session_id": { + "type": "string", + "description": "Continue a previous interactive session by providing its session_id" + }, + "stdin_input": { + "type": "string", + "description": "Input text to send to the process stdin (used with session_id)" } }, "required": ["command"] @@ -294,6 +316,26 @@ impl Tool for BashTool { return Ok(result); } + // Handle session continuation first + if let Some(session_id) = args.get("session_id").and_then(|v| v.as_str()) { + let input = args + .get("stdin_input") + .and_then(|v| v.as_str()) + .unwrap_or(""); + return match self.session_manager.send_input(session_id, input).await { + Ok(output) => Ok(ToolResult { + success: true, + output, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }), + }; + } + let command = match args.get("command").and_then(|v| v.as_str()) { Some(c) => c, None => { @@ -353,15 +395,19 @@ impl BashTool { ) -> Result { let mut cmd = Command::new(self.shell.executable()); cmd.args(self.shell.command_args(command)) + .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .current_dir(cwd); let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; + // Take stdin writer before stdout/stderr + let child_stdin = child.stdin.take(); let stdout = child.stdout.take(); let stderr = child.stderr.take(); - let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>(); + let (tx, rx_inner) = mpsc::unbounded_channel::<(bool, String)>(); + let mut rx: Option> = Some(rx_inner); if let Some(stdout) = stdout { tokio::spawn(read_stream(stdout, false, tx.clone())); @@ -379,8 +425,9 @@ impl BashTool { tokio::select! 
{ status = child.wait() => { let status = status.map_err(|e| format!("Failed to wait: {}", e))?; - drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; - while let Some((is_stderr, chunk)) = rx.recv().await { + let mut rx_val = rx.take().unwrap(); + drain_available_chunks(&mut rx_val, &stdout_buf, &stderr_buf).await; + while let Some((is_stderr, chunk)) = rx_val.recv().await { if is_stderr { stderr_buf.lock().await.push_str(&chunk); } else { @@ -390,7 +437,12 @@ impl BashTool { let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1))); return Ok(self.truncate_output(&output)); } - Some((is_stderr, chunk)) = rx.recv() => { + Some((is_stderr, chunk)) = async { + match rx.as_mut() { + Some(r) => r.recv().await, + None => std::future::pending().await, + } + } => { if is_stderr { stderr_buf.lock().await.push_str(&chunk); } else { @@ -399,59 +451,115 @@ impl BashTool { let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); if self.should_return_pending(interactive, &combined) { - drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; + let mut rx_val = rx.take().unwrap(); + drain_available_chunks(&mut rx_val, &stdout_buf, &stderr_buf).await; + let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); + // Try to save as interactive session + if let Some(stdin) = child_stdin { + let session_id = self.session_manager.save_session( + child, stdin, rx_val, + stdout_buf.lock().await.clone(), + stderr_buf.lock().await.clone(), + ).await; + return Ok(self.pending_output(&combined, Some(&session_id))); + } let _ = child.start_kill(); let _ = child.wait().await; - let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); - return Ok(self.pending_output(&combined)); + return Ok(self.pending_output(&combined, None)); } } _ = tokio::time::sleep(Duration::from_secs(2)) => { - // 
Periodic safety net: when output has been silent for 2s, - // check OS-level process state to see if the child is - // genuinely blocked on stdin. Also re-run keyword detection - // in case read_stream flushed a partial line since the last - // rx.recv() iteration. + // Periodic safety net: check OS-level process state if let Some(pid) = child.id() { if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) { - drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; + if let Some(rx_ref) = rx.as_mut() { + drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await; + } let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); if !combined.trim().is_empty() { + if let Some(stdin) = child_stdin { + if let Some(rx_val) = rx.take() { + let session_id = self.session_manager.save_session( + child, stdin, rx_val, + stdout_buf.lock().await.clone(), + stderr_buf.lock().await.clone(), + ).await; + return Ok(self.pending_output(&combined, Some(&session_id))); + } + } let _ = child.start_kill(); let _ = child.wait().await; - return Ok(self.pending_output(&combined)); + return Ok(self.pending_output(&combined, None)); } } } let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); if self.should_return_pending(interactive, &combined) { - drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; + if let Some(rx_ref) = rx.as_mut() { + drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await; + } let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); + if let Some(stdin) = child_stdin { + if let Some(rx_val) = rx.take() { + let session_id = self.session_manager.save_session( + child, stdin, rx_val, + stdout_buf.lock().await.clone(), + stderr_buf.lock().await.clone(), + ).await; + return Ok(self.pending_output(&combined, Some(&session_id))); + } + } let _ = child.start_kill(); let _ = child.wait().await; - return 
Ok(self.pending_output(&combined)); + return Ok(self.pending_output(&combined, None)); } } _ = sleep_until(deadline) => { - drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; + if let Some(rx_ref) = rx.as_mut() { + drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await; + } let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); - let _ = child.start_kill(); - let _ = child.wait().await; - // OS-level process state check: if the child was blocked on - // stdin, treat it as pending rather than a hard timeout error. + // OS-level check: if blocked on stdin, save as session if let Some(pid) = child.id() { if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) && !combined.trim().is_empty() { - return Ok(self.pending_output(&combined)); + if let Some(stdin) = child_stdin { + if let Some(rx_val) = rx.take() { + let session_id = self.session_manager.save_session( + child, stdin, rx_val, + stdout_buf.lock().await.clone(), + stderr_buf.lock().await.clone(), + ).await; + return Ok(self.pending_output(&combined, Some(&session_id))); + } + } + let _ = child.start_kill(); + let _ = child.wait().await; + return Ok(self.pending_output(&combined, None)); } } if self.should_return_pending(interactive, &combined) { - return Ok(self.pending_output(&combined)); + if let Some(stdin) = child_stdin { + if let Some(rx_val) = rx.take() { + let session_id = self.session_manager.save_session( + child, stdin, rx_val, + stdout_buf.lock().await.clone(), + stderr_buf.lock().await.clone(), + ).await; + return Ok(self.pending_output(&combined, Some(&session_id))); + } + } + let _ = child.start_kill(); + let _ = child.wait().await; + return Ok(self.pending_output(&combined, None)); } + + let _ = child.start_kill(); + let _ = child.wait().await; return Err(format!("Command timed out after {} seconds", timeout_secs)); } } @@ -557,7 +665,7 @@ mod tests { #[tokio::test] async fn test_simple_command() { - let tool = 
BashTool::new(); + let tool = BashTool::default(); let command = if cfg!(target_os = "windows") { "Write-Output 'Hello World'" } else { @@ -574,7 +682,7 @@ mod tests { #[tokio::test] async fn test_pwd_command() { - let tool = BashTool::new(); + let tool = BashTool::default(); let command = if cfg!(target_os = "windows") { "Get-Location" } else { @@ -587,7 +695,7 @@ mod tests { #[tokio::test] async fn test_ls_command() { - let tool = BashTool::new(); + let tool = BashTool::default(); let temp_dir = std::env::temp_dir(); let command = if cfg!(target_os = "windows") { format!("Get-ChildItem {}", temp_dir.display()) @@ -604,7 +712,7 @@ mod tests { #[tokio::test] async fn test_dangerous_rm() { - let tool = BashTool::new(); + let tool = BashTool::default(); // 测试 Unix 危险命令模式 let result = tool .execute(json!({ "command": "rm -rf /some/path" })) @@ -617,7 +725,7 @@ mod tests { #[tokio::test] async fn test_dangerous_windows_commands() { - let tool = BashTool::new(); + let tool = BashTool::default(); // 测试 Windows del 命令模式(正则应该匹配) let result = tool .execute(json!({ "command": "del /f /q file.txt" })) @@ -630,7 +738,7 @@ mod tests { #[tokio::test] async fn test_dangerous_fork_bomb() { - let tool = BashTool::new(); + let tool = BashTool::default(); let result = tool .execute(json!({ "command": ":(){ :|:& };:" })) .await @@ -642,7 +750,7 @@ mod tests { #[tokio::test] async fn test_missing_command() { - let tool = BashTool::new(); + let tool = BashTool::default(); let result = tool.execute(json!({})).await.unwrap(); assert!(!result.success); @@ -651,7 +759,7 @@ mod tests { #[tokio::test] async fn test_timeout() { - let tool = BashTool::new(); + let tool = BashTool::default(); let command = if cfg!(target_os = "windows") { "Start-Sleep -Seconds 10" } else { @@ -671,7 +779,7 @@ mod tests { #[tokio::test] async fn test_pending_user_action_detection() { - let tool = BashTool::new(); + let tool = BashTool::default(); let command = if cfg!(target_os = "windows") { "Write-Host 'waiting 
for authorization'; Start-Sleep -Seconds 10" } else { @@ -692,7 +800,7 @@ mod tests { #[test] fn test_truncate_output_handles_utf8_char_boundaries() { - let tool = BashTool::new(); + let tool = BashTool::default(); let input = "全".repeat(MAX_OUTPUT_CHARS + 100); let output = tool.truncate_output(&input); diff --git a/src/tools/mod.rs b/src/tools/mod.rs index eeca684..5a421e6 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -10,6 +10,7 @@ pub mod registry; pub mod scheduler_manage; pub mod session_send; pub mod schema; +pub mod shell_session; pub mod skill_activate; pub mod skill_manage; pub mod task; @@ -33,6 +34,7 @@ pub use session_send::{ SessionSendTool, }; pub use schema::{CleaningStrategy, SchemaCleanr}; +pub use shell_session::ShellSessionManager; pub use skill_activate::SkillActivateTool; pub use skill_manage::SkillManageTool; pub use task::{ diff --git a/src/tools/shell_session.rs b/src/tools/shell_session.rs new file mode 100644 index 0000000..a495aee --- /dev/null +++ b/src/tools/shell_session.rs @@ -0,0 +1,283 @@ +//! Interactive shell session management. +//! +//! Provides `ShellSessionManager`, an independent service that keeps child +//! processes alive between tool calls so the Agent can interact with +//! stdin-waiting prompts (e.g. `Read-Host`, `Confirm (Y/N)`). +//! +//! The manager is created at the gateway/bootstrap layer and injected into +//! `BashTool` via `Arc`. It does NOT start background tasks on its own; +//! cleanup is driven externally via `cleanup_expired()` or `shutdown()`. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use tokio::io::AsyncWriteExt; +use tokio::process::{Child, ChildStdin}; +use tokio::sync::{Mutex, mpsc}; +use tokio::time::Instant; + +use uuid::Uuid; + +const SESSION_TIMEOUT_SECS: u64 = 300; // 5 minutes +const OUTPUT_WAIT_MS: u64 = 2000; + +/// A single interactive shell session backed by a live child process. 
+struct ShellSession { + child: Child, + stdin_writer: Option, + stdout_buf: Arc>, + stderr_buf: Arc>, + /// Background task handle that drains the output channel into buffers. + _drain_task: tokio::task::JoinHandle<()>, + created_at: Instant, +} + +/// Independent service for managing interactive shell sessions. +/// +/// Thread-safe — designed to be shared via `Arc`. +pub struct ShellSessionManager { + sessions: Mutex>, +} + +impl ShellSessionManager { + /// Create a new, empty session manager. + pub fn new() -> Self { + Self { + sessions: Mutex::new(HashMap::new()), + } + } + + /// Save a child process as an interactive session. + /// + /// The caller provides: + /// - `child`: the spawned child process (with piped stdin/stdout/stderr) + /// - `stdin_writer`: the write-half of the piped stdin + /// - `rx`: the output channel receiver (produced by `read_stream` tasks) + /// - `initial_stdout` / `initial_stderr`: output already collected before + /// the session was created + /// + /// Returns a unique `session_id` that can be used for subsequent + /// `send_input` / `get_output` / `close_session` calls. + pub async fn save_session( + &self, + mut child: Child, + stdin_writer: ChildStdin, + mut rx: mpsc::UnboundedReceiver<(bool, String)>, + initial_stdout: String, + initial_stderr: String, + ) -> String { + let session_id = Uuid::new_v4().to_string(); + let stdout_buf = Arc::new(Mutex::new(initial_stdout)); + let stderr_buf = Arc::new(Mutex::new(initial_stderr)); + + // Spawn a background task that drains the channel into buffers. + let stdout_clone = stdout_buf.clone(); + let stderr_clone = stderr_buf.clone(); + let drain_task = tokio::spawn(async move { + while let Some((is_stderr, chunk)) = rx.recv().await { + if is_stderr { + stderr_clone.lock().await.push_str(&chunk); + } else { + stdout_clone.lock().await.push_str(&chunk); + } + } + }); + + // Kill the child's inherited stdin to prevent blocking on close. 
+ // The actual stdin writing is done via stdin_writer. + let _ = child.stdin.take(); + + let session = ShellSession { + child, + stdin_writer: Some(stdin_writer), + stdout_buf, + stderr_buf, + _drain_task: drain_task, + created_at: Instant::now(), + }; + + self.sessions + .lock() + .await + .insert(session_id.clone(), session); + session_id + } + + /// Send input to a session's stdin and return new output. + /// + /// After writing, waits up to `OUTPUT_WAIT_MS` for new output to arrive. + /// If the child process exits during the wait, returns the final output. + pub async fn send_input(&self, session_id: &str, input: &str) -> Result { + let mut sessions = self.sessions.lock().await; + + let session = sessions + .get_mut(session_id) + .ok_or_else(|| format!("Session not found: {}", session_id))?; + + // Write input to stdin + if let Some(writer) = &mut session.stdin_writer { + let data = if input.ends_with('\n') { + input.to_string() + } else { + format!("{}\n", input) + }; + writer + .write_all(data.as_bytes()) + .await + .map_err(|e| format!("Failed to write stdin: {}", e))?; + writer + .flush() + .await + .map_err(|e| format!("Failed to flush stdin: {}", e))?; + } else { + return Err("Session stdin is closed".to_string()); + } + + // Record output length before wait + let prev_stdout_len = session.stdout_buf.lock().await.len(); + let prev_stderr_len = session.stderr_buf.lock().await.len(); + + // Wait for new output or process exit + let deadline = Instant::now() + Duration::from_millis(OUTPUT_WAIT_MS); + loop { + tokio::select! 
{ + status = session.child.wait() => { + // Process exited — collect final output + let stdout = session.stdout_buf.lock().await.clone(); + let stderr = session.stderr_buf.lock().await.clone(); + let code = status.ok().and_then(|s| s.code()); + drop(sessions); + return Ok(Self::format_output(&stdout, &stderr, code)); + } + _ = tokio::time::sleep_until(deadline) => { + // Timeout — return current output + break; + } + } + } + + let stdout = session.stdout_buf.lock().await.clone(); + let stderr = session.stderr_buf.lock().await.clone(); + + let new_stdout: String = stdout.chars().skip(prev_stdout_len).collect(); + let new_stderr: String = stderr.chars().skip(prev_stderr_len).collect(); + + let mut result = String::new(); + if !new_stdout.is_empty() { + result.push_str(&new_stdout); + } + if !new_stderr.trim().is_empty() { + if !result.is_empty() { + result.push('\n'); + } + result.push_str("STDERR:\n"); + result.push_str(&new_stderr); + } + if result.is_empty() { + result.push_str("(No new output after input. Session still active.)"); + } + + Ok(result) + } + + /// Get the full accumulated output of a session. + pub async fn get_output(&self, session_id: &str) -> Result { + let sessions = self.sessions.lock().await; + let session = sessions + .get(session_id) + .ok_or_else(|| format!("Session not found: {}", session_id))?; + + let stdout = session.stdout_buf.lock().await.clone(); + let stderr = session.stderr_buf.lock().await.clone(); + Ok(Self::format_output(&stdout, &stderr, None)) + } + + /// Close a session: kill the child process and return final output. 
+ pub async fn close_session(&self, session_id: &str) -> Result { + let mut sessions = self.sessions.lock().await; + let mut session = sessions + .remove(session_id) + .ok_or_else(|| format!("Session not found: {}", session_id))?; + drop(sessions); // Release lock before awaiting + + // Close stdin to signal EOF + if let Some(mut writer) = session.stdin_writer.take() { + let _ = writer.shutdown().await; + } + + // Kill and wait + let _ = session.child.start_kill(); + let status = session.child.wait().await.ok(); + let code = status.and_then(|s| s.code()); + + let stdout = session.stdout_buf.lock().await.clone(); + let stderr = session.stderr_buf.lock().await.clone(); + + // Abort drain task + session._drain_task.abort(); + + Ok(Self::format_output(&stdout, &stderr, code)) + } + + /// Remove and kill all expired sessions (older than `SESSION_TIMEOUT_SECS`). + /// + /// Designed to be called periodically from the owner (e.g. gateway). + pub async fn cleanup_expired(&self) { + let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS); + let mut expired_ids = Vec::new(); + + { + let sessions = self.sessions.lock().await; + for (id, session) in sessions.iter() { + if session.created_at.elapsed() > timeout { + expired_ids.push(id.clone()); + } + } + } + + for id in expired_ids { + let _ = self.close_session(&id).await; + } + } + + /// Gracefully shut down all active sessions. + pub async fn shutdown(&self) { + let ids: Vec = { + let sessions = self.sessions.lock().await; + sessions.keys().cloned().collect() + }; + for id in ids { + let _ = self.close_session(&id).await; + } + } + + /// Number of currently active sessions. 
+ pub async fn active_count(&self) -> usize { + self.sessions.lock().await.len() + } + + fn format_output(stdout: &str, stderr: &str, exit_code: Option) -> String { + let mut output = String::new(); + if !stdout.is_empty() { + output.push_str(stdout); + } + if !stderr.trim().is_empty() { + if !output.is_empty() { + output.push('\n'); + } + output.push_str("STDERR:\n"); + output.push_str(stderr); + } + if let Some(code) = exit_code { + output.push_str(&format!("\nExit code: {}", code)); + } + output + } +} + +impl Default for ShellSessionManager { + fn default() -> Self { + Self::new() + } +}