feat(shell): 实现交互式Shell会话管理

- 新增ShellSessionManager管理交互式shell会话,支持进程保持和交互输入
- BashTool集成会话管理,支持session_id和stdin_input参数实现输入回复
- 修改BashTool执行逻辑,检测进程等待输入状态并保存会话状态
- Windows平台新增底层进程等待输入检测实现,辅助判断Shell交互状态
- 工具注册工厂注入ShellSessionManager,保证安全复用会话管理实例
- 增加默认agent prompt中Shell交互终端说明,提示交互流程及输入格式
- 交互式命令输出增加标识和提示,区分正常与等待输入状态
- 实现会话超时自动清理和优雅关闭接口,避免资源泄露
- 单元测试中统一使用BashTool默认构造,适配会话管理新增功能
This commit is contained in:
ooodc 2026-06-13 09:06:45 +08:00
parent 43cea50df8
commit 02172b6065
7 changed files with 602 additions and 55 deletions

View File

@ -40,6 +40,14 @@ rustls = { version = "0.23", features = ["ring"] }
wechatbot = { path = "vendor/wechatbot" } wechatbot = { path = "vendor/wechatbot" }
encoding_rs = "0.8" encoding_rs = "0.8"
libc = "0.2" libc = "0.2"
[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.59", features = [
"Win32_System_Threading",
"Win32_System_Diagnostics_Debug",
"Win32_Foundation",
"Win32_System_Kernel",
] }
# MCP (Model Context Protocol) support # MCP (Model Context Protocol) support
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
"client", "client",

View File

@ -112,6 +112,13 @@
- 默认创建静默任务silent_agent_task在独立后台会话中执行不干扰主对话 - 默认创建静默任务silent_agent_task在独立后台会话中执行不干扰主对话
- 静默模式下如需发送消息给用户prompt中需显式使用 send_session_message 工具 - 静默模式下如需发送消息给用户prompt中需显式使用 send_session_message 工具
## Shell 交互终端
- 当 shell 工具返回包含 `__PICOBOT_PENDING_USER_ACTION__` 和 `[session_id: xxx]` 的结果时,表示进程正在等待输入
- 阅读已输出的内容,理解提示含义(如确认提示 Y/N、输入密码、选择选项等)
- 使用 `session_id` 和 `stdin_input` 参数回复交互内容,例如:`{"command": "echo test", "session_id": "xxx", "stdin_input": "Y"}`
- 常见场景:确认提示输入 Y/N、输入密码/验证码、选择选项、Read-Host 等
## todo工具使用规范 ## todo工具使用规范
- 严格按照既定的未完成的todo工作项执行任务如果工作项不在适用就更新不得随意遗漏工作项 - 严格按照既定的未完成的todo工作项执行任务如果工作项不在适用就更新不得随意遗漏工作项

View File

@ -11,8 +11,8 @@ use crate::tools::todo_write::TodoItem;
use crate::tools::{ use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, MemorySearchTool, HttpRequestTool, MemoryManageTool, MemorySearchTool,
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager,
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool, SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
TodoWriteTool, ToolRegistry, WebFetchTool, TodoWriteTool, ToolRegistry, WebFetchTool,
}; };
@ -29,6 +29,7 @@ pub(crate) struct ToolRegistryFactory {
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>, subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
mcp_manager: Option<Arc<McpClientManager>>, mcp_manager: Option<Arc<McpClientManager>>,
todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>, todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,
shell_session_manager: Arc<ShellSessionManager>,
} }
impl ToolRegistryFactory { impl ToolRegistryFactory {
@ -56,6 +57,7 @@ impl ToolRegistryFactory {
subagent_runtime: None, subagent_runtime: None,
mcp_manager: None, mcp_manager: None,
todo_state: None, todo_state: None,
shell_session_manager: Arc::new(ShellSessionManager::new()),
} }
} }
@ -87,6 +89,11 @@ impl ToolRegistryFactory {
!self.disabled_tools.contains(tool_name) !self.disabled_tools.contains(tool_name)
} }
/// Get a reference to the shell session manager for lifecycle control.
pub(crate) fn shell_session_manager(&self) -> Arc<ShellSessionManager> {
self.shell_session_manager.clone()
}
pub(crate) fn build(&self) -> ToolRegistry { pub(crate) fn build(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new(); let mut registry = ToolRegistry::new();
@ -135,7 +142,7 @@ impl ToolRegistryFactory {
registry.register(SkillManageTool::new(self.skills.clone())); registry.register(SkillManageTool::new(self.skills.clone()));
} }
if self.is_enabled("bash") { if self.is_enabled("bash") {
registry.register(BashTool::new()); registry.register(BashTool::new(self.shell_session_manager.clone()));
} }
if self.is_enabled("http_request") { if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new( registry.register(HttpRequestTool::new(
@ -184,7 +191,7 @@ impl ToolRegistryFactory {
registry.register(FileEditTool::new()); registry.register(FileEditTool::new());
} }
if self.is_enabled("bash") { if self.is_enabled("bash") {
registry.register(BashTool::new()); registry.register(BashTool::new(self.shell_session_manager.clone()));
} }
if self.is_enabled("http_request") { if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new( registry.register(HttpRequestTool::new(

View File

@ -156,18 +156,7 @@ pub fn is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
} }
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
{ {
// Windows: a full implementation would use: windows_is_process_waiting_on_stdin(pid)
// 1. OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, pid)
// 2. NtQuerySystemInformation(SystemProcessInformation) to enumerate
// threads and check if the main thread's WaitReason is Executive
// (indicating a wait on a kernel handle, e.g. console input).
// 3. CloseHandle(hProcess)
//
// This requires the undocumented NtQuerySystemInformation API from
// ntdll.dll. Until that is in place, the keyword-matching and
// periodic output-staleness checks in bash.rs handle Windows detection.
let _ = pid;
None
} }
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))] #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
{ {
@ -176,6 +165,149 @@ pub fn is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
} }
} }
/// Windows implementation: check if a process is waiting for stdin input.
///
/// Takes a system-wide snapshot with
/// `NtQuerySystemInformation(SystemProcessInformation)`, locates the entry whose
/// process id equals `pid`, and inspects every thread recorded for it.
///
/// This is a heuristic: "all threads Waiting with wait reason Executive" is
/// treated as "blocked on I/O, likely console input". It cannot prove that the
/// wait is on stdin specifically.
///
/// # Returns
/// - `Some(true)`  — every thread of the process is Waiting (state 5) with wait
///   reason Executive (0).
/// - `Some(false)` — the process was found but has no threads, or at least one
///   thread is in another state / wait reason.
/// - `None`        — the query failed, the snapshot was malformed, or no entry
///   for `pid` exists.
#[cfg(target_os = "windows")]
fn windows_is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
    // SystemProcessInformation = 5
    const SYSTEM_PROCESS_INFORMATION: u32 = 5;
    const STATUS_INFO_LENGTH_MISMATCH: i32 = -1073741820; // 0xC0000004
    // Thread states: Running=2, Waiting=5
    const THREAD_STATE_WAITING: u32 = 5;
    // Wait reasons: Executive=0, FreePage=1, PageIn=2, PoolAllocation=3,
    // DelayExecution=4, Suspended=5, UserRequest=6, ...
    const WAIT_REASON_EXECUTIVE: u32 = 0;

    // Mirror of SYSTEM_PROCESS_INFORMATION. Most fields exist only to give the
    // ones we read the correct offsets, hence the dead_code allowance.
    #[repr(C)]
    #[allow(dead_code)]
    struct SystemProcessInfo {
        next_entry_offset: u32,
        number_of_threads: u32,
        working_set_private_size: i64,
        hard_fault_count: u32,
        number_of_threads_high_watermark: u32,
        cycle_time: u64,
        create_time: i64,
        user_time: i64,
        kernel_time: i64,
        image_name_length: u16,
        image_name_max_length: u16,
        image_name: *const u16,
        base_priority: i32,
        unique_process_id: *mut std::ffi::c_void,
        inherited_from_unique_process_id: *mut std::ffi::c_void,
        handle_count: u32,
        session_id: u32,
        unique_process_key: usize,
        peak_virtual_size: usize,
        virtual_size: usize,
        page_fault_count: u32,
        peak_working_set_size: usize,
        working_set_size: usize,
        quota_peak_paged_pool_usage: usize,
        quota_paged_pool_usage: usize,
        quota_peak_non_paged_pool_usage: usize,
        quota_non_paged_pool_usage: usize,
        pagefile_usage: usize,
        peak_pagefile_usage: usize,
        private_page_count: usize,
        read_operation_count: i64,
        write_operation_count: i64,
        other_operation_count: i64,
        read_transfer_count: i64,
        write_transfer_count: i64,
        other_transfer_count: i64,
        // SYSTEM_THREAD_INFORMATION[number_of_threads] follows in memory; this
        // field is only used to compute where that array starts.
        threads: [SystemThreadInfo; 1],
    }

    // Mirror of SYSTEM_THREAD_INFORMATION (layout-only fields are never read).
    #[repr(C)]
    #[derive(Clone, Copy)]
    #[allow(dead_code)]
    struct SystemThreadInfo {
        kernel_time: i64,
        user_time: i64,
        create_time: i64,
        wait_time: u32,
        start_address: *mut std::ffi::c_void,
        client_id_unique_process: *mut std::ffi::c_void,
        client_id_unique_thread: *mut std::ffi::c_void,
        priority: i32,
        base_priority: i32,
        context_switches: u32,
        thread_state: u32,
        wait_reason: u32,
    }

    // Link ntdll explicitly instead of relying on another crate pulling it in.
    #[link(name = "ntdll")]
    #[allow(non_snake_case)]
    unsafe extern "system" {
        fn NtQuerySystemInformation(
            system_information_class: u32,
            system_information: *mut u8,
            system_information_length: u32,
            return_length: *mut u32,
        ) -> i32;
    }

    // Query buffer size first.
    let mut needed: u32 = 0;
    // SAFETY: a null buffer with length 0 is the size-probe form of the call;
    // `needed` is a valid, writable out-pointer for the duration of the call.
    let status = unsafe {
        NtQuerySystemInformation(
            SYSTEM_PROCESS_INFORMATION,
            std::ptr::null_mut(),
            0,
            &mut needed,
        )
    };
    if status != STATUS_INFO_LENGTH_MISMATCH || needed == 0 {
        return None;
    }

    // Allocate extra space (processes may be created between the two calls),
    // rounded down to a multiple of 8 so it maps exactly onto a u64 buffer.
    let capacity: u32 = needed.saturating_mul(2).max(65536) & !7;
    let buffer_bytes = capacity as usize;
    // A Vec<u64> guarantees the 8-byte alignment the snapshot structures need;
    // a Vec<u8> does not.
    let mut buffer: Vec<u64> = vec![0u64; buffer_bytes / 8];
    let mut written: u32 = 0;
    // SAFETY: `buffer` owns exactly `capacity` writable bytes and outlives the
    // call; `written` is a valid out-pointer.
    let status = unsafe {
        NtQuerySystemInformation(
            SYSTEM_PROCESS_INFORMATION,
            buffer.as_mut_ptr().cast::<u8>(),
            capacity,
            &mut written,
        )
    };
    if status < 0 {
        return None;
    }

    let threads_offset = std::mem::offset_of!(SystemProcessInfo, threads);
    let thread_size = std::mem::size_of::<SystemThreadInfo>();
    let base = buffer.as_ptr().cast::<u8>();
    let target_pid = pid as usize;

    // Walk the linked list of SYSTEM_PROCESS_INFORMATION entries. Every read is
    // bounds-checked against the buffer so a malformed snapshot yields `None`
    // instead of an out-of-bounds access.
    let mut offset: usize = 0;
    loop {
        let header_end = offset.checked_add(threads_offset)?;
        if header_end > buffer_bytes {
            return None;
        }
        // SAFETY: bytes [offset, offset + threads_offset) lie inside `buffer`
        // (checked above), and all three fields sit before `threads`.
        // Unaligned reads make no assumption about `next_entry_offset`.
        let (next_entry_offset, thread_count, proc_id) = unsafe {
            let entry = base.add(offset).cast::<SystemProcessInfo>();
            (
                std::ptr::addr_of!((*entry).next_entry_offset).read_unaligned() as usize,
                std::ptr::addr_of!((*entry).number_of_threads).read_unaligned() as usize,
                std::ptr::addr_of!((*entry).unique_process_id).read_unaligned() as usize,
            )
        };

        if proc_id == target_pid {
            if thread_count == 0 {
                return Some(false);
            }
            let threads_end = thread_count
                .checked_mul(thread_size)
                .and_then(|bytes| bytes.checked_add(header_end))?;
            if threads_end > buffer_bytes {
                return None;
            }
            // Executive wait + all threads waiting = likely blocked on I/O.
            let all_waiting = (0..thread_count).all(|i| {
                // SAFETY: thread `i` occupies bytes inside
                // [header_end, threads_end), which was checked to be in bounds.
                let thread = unsafe {
                    base.add(header_end + i * thread_size)
                        .cast::<SystemThreadInfo>()
                        .read_unaligned()
                };
                thread.thread_state == THREAD_STATE_WAITING
                    && thread.wait_reason == WAIT_REASON_EXECUTIVE
            });
            return Some(all_waiting);
        }

        if next_entry_offset == 0 {
            // End of list: no entry for `pid`.
            return None;
        }
        offset = offset.checked_add(next_entry_offset)?;
    }
}
/// Get the user's home directory. /// Get the user's home directory.
/// ///
/// Supports environment variable overrides for testing: /// Supports environment variable overrides for testing:

View File

@ -11,13 +11,16 @@ use tokio::sync::{Mutex, mpsc};
use tokio::time::{Instant, sleep_until}; use tokio::time::{Instant, sleep_until};
use crate::platform::{ShellInfo, dangerous_command_patterns}; use crate::platform::{ShellInfo, dangerous_command_patterns};
use crate::tools::shell_session::ShellSessionManager;
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
use crate::tools::{extract_u64, extract_bool, check_null_args}; use crate::tools::{extract_u64, extract_bool, check_null_args};
const MAX_TIMEOUT_SECS: u64 = 600; const MAX_TIMEOUT_SECS: u64 = 600;
const MAX_OUTPUT_CHARS: usize = 50_000; const MAX_OUTPUT_CHARS: usize = 50_000;
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
const USER_ACTION_HINT: &str = const INTERACTIVE_HINT: &str =
"进程正在等待输入。请使用 session_id 和 stdin_input 参数回复交互内容。";
const NON_INTERACTIVE_HINT: &str =
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。"; "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
/// Shell 类型枚举,支持跨平台 /// Shell 类型枚举,支持跨平台
@ -104,15 +107,17 @@ pub struct BashTool {
working_dir: Option<String>, working_dir: Option<String>,
deny_patterns: Vec<String>, deny_patterns: Vec<String>,
shell: ShellKind, shell: ShellKind,
session_manager: Arc<ShellSessionManager>,
} }
impl BashTool { impl BashTool {
pub fn new() -> Self { pub fn new(session_manager: Arc<ShellSessionManager>) -> Self {
Self { Self {
timeout_secs: 60, timeout_secs: 60,
working_dir: None, working_dir: None,
deny_patterns: dangerous_command_patterns(), deny_patterns: dangerous_command_patterns(),
shell: ShellKind::detect(), shell: ShellKind::detect(),
session_manager,
} }
} }
@ -168,11 +173,20 @@ impl BashTool {
) )
} }
fn pending_output(&self, output: &str) -> String { fn pending_output(&self, output: &str, session_id: Option<&str>) -> String {
let hint = if session_id.is_some() {
INTERACTIVE_HINT
} else {
NON_INTERACTIVE_HINT
};
let session_line = session_id
.map(|id| format!("[session_id: {}]\n", id))
.unwrap_or_default();
format!( format!(
"{}\n{}\n\n{}", "{}\n{}{}\n\n{}",
PENDING_USER_ACTION_MARKER, PENDING_USER_ACTION_MARKER,
USER_ACTION_HINT, session_line,
hint,
self.truncate_output(output.trim()) self.truncate_output(output.trim())
) )
} }
@ -248,7 +262,7 @@ async fn drain_available_chunks(
impl Default for BashTool { impl Default for BashTool {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new(Arc::new(ShellSessionManager::new()))
} }
} }
@ -279,6 +293,14 @@ impl Tool for BashTool {
"interactive": { "interactive": {
"type": "boolean", "type": "boolean",
"description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication" "description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication"
},
"session_id": {
"type": "string",
"description": "Continue a previous interactive session by providing its session_id"
},
"stdin_input": {
"type": "string",
"description": "Input text to send to the process stdin (used with session_id)"
} }
}, },
"required": ["command"] "required": ["command"]
@ -294,6 +316,26 @@ impl Tool for BashTool {
return Ok(result); return Ok(result);
} }
// Handle session continuation first
if let Some(session_id) = args.get("session_id").and_then(|v| v.as_str()) {
let input = args
.get("stdin_input")
.and_then(|v| v.as_str())
.unwrap_or("");
return match self.session_manager.send_input(session_id, input).await {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
};
}
let command = match args.get("command").and_then(|v| v.as_str()) { let command = match args.get("command").and_then(|v| v.as_str()) {
Some(c) => c, Some(c) => c,
None => { None => {
@ -353,15 +395,19 @@ impl BashTool {
) -> Result<String, String> { ) -> Result<String, String> {
let mut cmd = Command::new(self.shell.executable()); let mut cmd = Command::new(self.shell.executable());
cmd.args(self.shell.command_args(command)) cmd.args(self.shell.command_args(command))
.stdin(Stdio::piped())
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.current_dir(cwd); .current_dir(cwd);
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
// Take stdin writer before stdout/stderr
let child_stdin = child.stdin.take();
let stdout = child.stdout.take(); let stdout = child.stdout.take();
let stderr = child.stderr.take(); let stderr = child.stderr.take();
let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>(); let (tx, rx_inner) = mpsc::unbounded_channel::<(bool, String)>();
let mut rx: Option<mpsc::UnboundedReceiver<(bool, String)>> = Some(rx_inner);
if let Some(stdout) = stdout { if let Some(stdout) = stdout {
tokio::spawn(read_stream(stdout, false, tx.clone())); tokio::spawn(read_stream(stdout, false, tx.clone()));
@ -379,8 +425,9 @@ impl BashTool {
tokio::select! { tokio::select! {
status = child.wait() => { status = child.wait() => {
let status = status.map_err(|e| format!("Failed to wait: {}", e))?; let status = status.map_err(|e| format!("Failed to wait: {}", e))?;
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; let mut rx_val = rx.take().unwrap();
while let Some((is_stderr, chunk)) = rx.recv().await { drain_available_chunks(&mut rx_val, &stdout_buf, &stderr_buf).await;
while let Some((is_stderr, chunk)) = rx_val.recv().await {
if is_stderr { if is_stderr {
stderr_buf.lock().await.push_str(&chunk); stderr_buf.lock().await.push_str(&chunk);
} else { } else {
@ -390,7 +437,12 @@ impl BashTool {
let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1))); let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1)));
return Ok(self.truncate_output(&output)); return Ok(self.truncate_output(&output));
} }
Some((is_stderr, chunk)) = rx.recv() => { Some((is_stderr, chunk)) = async {
match rx.as_mut() {
Some(r) => r.recv().await,
None => std::future::pending().await,
}
} => {
if is_stderr { if is_stderr {
stderr_buf.lock().await.push_str(&chunk); stderr_buf.lock().await.push_str(&chunk);
} else { } else {
@ -399,59 +451,115 @@ impl BashTool {
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if self.should_return_pending(interactive, &combined) { if self.should_return_pending(interactive, &combined) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; let mut rx_val = rx.take().unwrap();
drain_available_chunks(&mut rx_val, &stdout_buf, &stderr_buf).await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
// Try to save as interactive session
if let Some(stdin) = child_stdin {
let session_id = self.session_manager.save_session(
child, stdin, rx_val,
stdout_buf.lock().await.clone(),
stderr_buf.lock().await.clone(),
).await;
return Ok(self.pending_output(&combined, Some(&session_id)));
}
let _ = child.start_kill(); let _ = child.start_kill();
let _ = child.wait().await; let _ = child.wait().await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); return Ok(self.pending_output(&combined, None));
return Ok(self.pending_output(&combined));
} }
} }
_ = tokio::time::sleep(Duration::from_secs(2)) => { _ = tokio::time::sleep(Duration::from_secs(2)) => {
// Periodic safety net: when output has been silent for 2s, // Periodic safety net: check OS-level process state
// check OS-level process state to see if the child is
// genuinely blocked on stdin. Also re-run keyword detection
// in case read_stream flushed a partial line since the last
// rx.recv() iteration.
if let Some(pid) = child.id() { if let Some(pid) = child.id() {
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) { if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; if let Some(rx_ref) = rx.as_mut() {
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
}
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if !combined.trim().is_empty() { if !combined.trim().is_empty() {
if let Some(stdin) = child_stdin {
if let Some(rx_val) = rx.take() {
let session_id = self.session_manager.save_session(
child, stdin, rx_val,
stdout_buf.lock().await.clone(),
stderr_buf.lock().await.clone(),
).await;
return Ok(self.pending_output(&combined, Some(&session_id)));
}
}
let _ = child.start_kill(); let _ = child.start_kill();
let _ = child.wait().await; let _ = child.wait().await;
return Ok(self.pending_output(&combined)); return Ok(self.pending_output(&combined, None));
} }
} }
} }
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if self.should_return_pending(interactive, &combined) { if self.should_return_pending(interactive, &combined) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; if let Some(rx_ref) = rx.as_mut() {
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
}
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if let Some(stdin) = child_stdin {
if let Some(rx_val) = rx.take() {
let session_id = self.session_manager.save_session(
child, stdin, rx_val,
stdout_buf.lock().await.clone(),
stderr_buf.lock().await.clone(),
).await;
return Ok(self.pending_output(&combined, Some(&session_id)));
}
}
let _ = child.start_kill(); let _ = child.start_kill();
let _ = child.wait().await; let _ = child.wait().await;
return Ok(self.pending_output(&combined)); return Ok(self.pending_output(&combined, None));
} }
} }
_ = sleep_until(deadline) => { _ = sleep_until(deadline) => {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; if let Some(rx_ref) = rx.as_mut() {
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
}
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
let _ = child.start_kill();
let _ = child.wait().await;
// OS-level process state check: if the child was blocked on // OS-level check: if blocked on stdin, save as session
// stdin, treat it as pending rather than a hard timeout error.
if let Some(pid) = child.id() { if let Some(pid) = child.id() {
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) if crate::platform::is_process_waiting_on_stdin(pid) == Some(true)
&& !combined.trim().is_empty() && !combined.trim().is_empty()
{ {
return Ok(self.pending_output(&combined)); if let Some(stdin) = child_stdin {
if let Some(rx_val) = rx.take() {
let session_id = self.session_manager.save_session(
child, stdin, rx_val,
stdout_buf.lock().await.clone(),
stderr_buf.lock().await.clone(),
).await;
return Ok(self.pending_output(&combined, Some(&session_id)));
}
}
let _ = child.start_kill();
let _ = child.wait().await;
return Ok(self.pending_output(&combined, None));
} }
} }
if self.should_return_pending(interactive, &combined) { if self.should_return_pending(interactive, &combined) {
return Ok(self.pending_output(&combined)); if let Some(stdin) = child_stdin {
if let Some(rx_val) = rx.take() {
let session_id = self.session_manager.save_session(
child, stdin, rx_val,
stdout_buf.lock().await.clone(),
stderr_buf.lock().await.clone(),
).await;
return Ok(self.pending_output(&combined, Some(&session_id)));
} }
}
let _ = child.start_kill();
let _ = child.wait().await;
return Ok(self.pending_output(&combined, None));
}
let _ = child.start_kill();
let _ = child.wait().await;
return Err(format!("Command timed out after {} seconds", timeout_secs)); return Err(format!("Command timed out after {} seconds", timeout_secs));
} }
} }
@ -557,7 +665,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_simple_command() { async fn test_simple_command() {
let tool = BashTool::new(); let tool = BashTool::default();
let command = if cfg!(target_os = "windows") { let command = if cfg!(target_os = "windows") {
"Write-Output 'Hello World'" "Write-Output 'Hello World'"
} else { } else {
@ -574,7 +682,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_pwd_command() { async fn test_pwd_command() {
let tool = BashTool::new(); let tool = BashTool::default();
let command = if cfg!(target_os = "windows") { let command = if cfg!(target_os = "windows") {
"Get-Location" "Get-Location"
} else { } else {
@ -587,7 +695,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_ls_command() { async fn test_ls_command() {
let tool = BashTool::new(); let tool = BashTool::default();
let temp_dir = std::env::temp_dir(); let temp_dir = std::env::temp_dir();
let command = if cfg!(target_os = "windows") { let command = if cfg!(target_os = "windows") {
format!("Get-ChildItem {}", temp_dir.display()) format!("Get-ChildItem {}", temp_dir.display())
@ -604,7 +712,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_dangerous_rm() { async fn test_dangerous_rm() {
let tool = BashTool::new(); let tool = BashTool::default();
// 测试 Unix 危险命令模式 // 测试 Unix 危险命令模式
let result = tool let result = tool
.execute(json!({ "command": "rm -rf /some/path" })) .execute(json!({ "command": "rm -rf /some/path" }))
@ -617,7 +725,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_dangerous_windows_commands() { async fn test_dangerous_windows_commands() {
let tool = BashTool::new(); let tool = BashTool::default();
// 测试 Windows del 命令模式(正则应该匹配) // 测试 Windows del 命令模式(正则应该匹配)
let result = tool let result = tool
.execute(json!({ "command": "del /f /q file.txt" })) .execute(json!({ "command": "del /f /q file.txt" }))
@ -630,7 +738,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_dangerous_fork_bomb() { async fn test_dangerous_fork_bomb() {
let tool = BashTool::new(); let tool = BashTool::default();
let result = tool let result = tool
.execute(json!({ "command": ":(){ :|:& };:" })) .execute(json!({ "command": ":(){ :|:& };:" }))
.await .await
@ -642,7 +750,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_missing_command() { async fn test_missing_command() {
let tool = BashTool::new(); let tool = BashTool::default();
let result = tool.execute(json!({})).await.unwrap(); let result = tool.execute(json!({})).await.unwrap();
assert!(!result.success); assert!(!result.success);
@ -651,7 +759,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_timeout() { async fn test_timeout() {
let tool = BashTool::new(); let tool = BashTool::default();
let command = if cfg!(target_os = "windows") { let command = if cfg!(target_os = "windows") {
"Start-Sleep -Seconds 10" "Start-Sleep -Seconds 10"
} else { } else {
@ -671,7 +779,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_pending_user_action_detection() { async fn test_pending_user_action_detection() {
let tool = BashTool::new(); let tool = BashTool::default();
let command = if cfg!(target_os = "windows") { let command = if cfg!(target_os = "windows") {
"Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10" "Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
} else { } else {
@ -692,7 +800,7 @@ mod tests {
#[test] #[test]
fn test_truncate_output_handles_utf8_char_boundaries() { fn test_truncate_output_handles_utf8_char_boundaries() {
let tool = BashTool::new(); let tool = BashTool::default();
let input = "".repeat(MAX_OUTPUT_CHARS + 100); let input = "".repeat(MAX_OUTPUT_CHARS + 100);
let output = tool.truncate_output(&input); let output = tool.truncate_output(&input);

View File

@ -10,6 +10,7 @@ pub mod registry;
pub mod scheduler_manage; pub mod scheduler_manage;
pub mod session_send; pub mod session_send;
pub mod schema; pub mod schema;
pub mod shell_session;
pub mod skill_activate; pub mod skill_activate;
pub mod skill_manage; pub mod skill_manage;
pub mod task; pub mod task;
@ -33,6 +34,7 @@ pub use session_send::{
SessionSendTool, SessionSendTool,
}; };
pub use schema::{CleaningStrategy, SchemaCleanr}; pub use schema::{CleaningStrategy, SchemaCleanr};
pub use shell_session::ShellSessionManager;
pub use skill_activate::SkillActivateTool; pub use skill_activate::SkillActivateTool;
pub use skill_manage::SkillManageTool; pub use skill_manage::SkillManageTool;
pub use task::{ pub use task::{

283
src/tools/shell_session.rs Normal file
View File

@ -0,0 +1,283 @@
//! Interactive shell session management.
//!
//! Provides `ShellSessionManager`, an independent service that keeps child
//! processes alive between tool calls so the Agent can interact with
//! stdin-waiting prompts (e.g. `Read-Host`, `Confirm (Y/N)`).
//!
//! The manager is shared with `BashTool` via `Arc`. It does not run a periodic
//! cleanup task of its own (it only spawns one output-draining task per saved
//! session); expiry and shutdown are driven externally via `cleanup_expired()`
//! or `shutdown()`.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::process::{Child, ChildStdin};
use tokio::sync::{Mutex, mpsc};
use tokio::time::Instant;
use uuid::Uuid;
/// Age, measured from session creation, after which `cleanup_expired` closes a session.
const SESSION_TIMEOUT_SECS: u64 = 300; // 5 minutes
/// How long `send_input` waits after writing to stdin before returning the output collected so far.
const OUTPUT_WAIT_MS: u64 = 2000;
/// A single interactive shell session backed by a live child process.
///
/// Instances are created by `ShellSessionManager::save_session` and stored in
/// the manager's session map until `close_session` removes them.
struct ShellSession {
// The child process; killed and waited on when the session is closed.
child: Child,
// Write end of the child's stdin; taken and shut down when the session is closed.
stdin_writer: Option<ChildStdin>,
// Accumulated stdout text, shared with the drain task.
stdout_buf: Arc<Mutex<String>>,
// Accumulated stderr text, shared with the drain task.
stderr_buf: Arc<Mutex<String>>,
/// Background task handle that drains the output channel into buffers.
_drain_task: tokio::task::JoinHandle<()>,
// Creation time; compared against SESSION_TIMEOUT_SECS by `cleanup_expired`.
created_at: Instant,
}
/// Independent service for managing interactive shell sessions.
///
/// Thread-safe — designed to be shared via `Arc<ShellSessionManager>`.
/// All access to the session table goes through a single async mutex.
pub struct ShellSessionManager {
// Live sessions keyed by the session id string returned from `save_session`.
sessions: Mutex<HashMap<String, ShellSession>>,
}
impl ShellSessionManager {
/// Create a new, empty session manager.
pub fn new() -> Self {
Self {
sessions: Mutex::new(HashMap::new()),
}
}
/// Save a child process as an interactive session.
///
/// The caller provides:
/// - `child`: the spawned child process (with piped stdin/stdout/stderr)
/// - `stdin_writer`: the write-half of the piped stdin
/// - `rx`: the output channel receiver (produced by `read_stream` tasks)
/// - `initial_stdout` / `initial_stderr`: output already collected before
/// the session was created
///
/// Returns a unique `session_id` that can be used for subsequent
/// `send_input` / `get_output` / `close_session` calls.
pub async fn save_session(
&self,
mut child: Child,
stdin_writer: ChildStdin,
mut rx: mpsc::UnboundedReceiver<(bool, String)>,
initial_stdout: String,
initial_stderr: String,
) -> String {
let session_id = Uuid::new_v4().to_string();
let stdout_buf = Arc::new(Mutex::new(initial_stdout));
let stderr_buf = Arc::new(Mutex::new(initial_stderr));
// Spawn a background task that drains the channel into buffers.
let stdout_clone = stdout_buf.clone();
let stderr_clone = stderr_buf.clone();
let drain_task = tokio::spawn(async move {
while let Some((is_stderr, chunk)) = rx.recv().await {
if is_stderr {
stderr_clone.lock().await.push_str(&chunk);
} else {
stdout_clone.lock().await.push_str(&chunk);
}
}
});
// Kill the child's inherited stdin to prevent blocking on close.
// The actual stdin writing is done via stdin_writer.
let _ = child.stdin.take();
let session = ShellSession {
child,
stdin_writer: Some(stdin_writer),
stdout_buf,
stderr_buf,
_drain_task: drain_task,
created_at: Instant::now(),
};
self.sessions
.lock()
.await
.insert(session_id.clone(), session);
session_id
}
/// Send input to a session's stdin and return new output.
///
/// After writing, waits up to `OUTPUT_WAIT_MS` for new output to arrive.
/// If the child process exits during the wait, returns the final output.
pub async fn send_input(&self, session_id: &str, input: &str) -> Result<String, String> {
let mut sessions = self.sessions.lock().await;
let session = sessions
.get_mut(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
// Write input to stdin
if let Some(writer) = &mut session.stdin_writer {
let data = if input.ends_with('\n') {
input.to_string()
} else {
format!("{}\n", input)
};
writer
.write_all(data.as_bytes())
.await
.map_err(|e| format!("Failed to write stdin: {}", e))?;
writer
.flush()
.await
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
} else {
return Err("Session stdin is closed".to_string());
}
// Record output length before wait
let prev_stdout_len = session.stdout_buf.lock().await.len();
let prev_stderr_len = session.stderr_buf.lock().await.len();
// Wait for new output or process exit
let deadline = Instant::now() + Duration::from_millis(OUTPUT_WAIT_MS);
loop {
tokio::select! {
status = session.child.wait() => {
// Process exited — collect final output
let stdout = session.stdout_buf.lock().await.clone();
let stderr = session.stderr_buf.lock().await.clone();
let code = status.ok().and_then(|s| s.code());
drop(sessions);
return Ok(Self::format_output(&stdout, &stderr, code));
}
_ = tokio::time::sleep_until(deadline) => {
// Timeout — return current output
break;
}
}
}
let stdout = session.stdout_buf.lock().await.clone();
let stderr = session.stderr_buf.lock().await.clone();
let new_stdout: String = stdout.chars().skip(prev_stdout_len).collect();
let new_stderr: String = stderr.chars().skip(prev_stderr_len).collect();
let mut result = String::new();
if !new_stdout.is_empty() {
result.push_str(&new_stdout);
}
if !new_stderr.trim().is_empty() {
if !result.is_empty() {
result.push('\n');
}
result.push_str("STDERR:\n");
result.push_str(&new_stderr);
}
if result.is_empty() {
result.push_str("(No new output after input. Session still active.)");
}
Ok(result)
}
/// Get the full accumulated output of a session.
pub async fn get_output(&self, session_id: &str) -> Result<String, String> {
let sessions = self.sessions.lock().await;
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let stdout = session.stdout_buf.lock().await.clone();
let stderr = session.stderr_buf.lock().await.clone();
Ok(Self::format_output(&stdout, &stderr, None))
}
/// Close a session: kill the child process and return final output.
pub async fn close_session(&self, session_id: &str) -> Result<String, String> {
let mut sessions = self.sessions.lock().await;
let mut session = sessions
.remove(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
drop(sessions); // Release lock before awaiting
// Close stdin to signal EOF
if let Some(mut writer) = session.stdin_writer.take() {
let _ = writer.shutdown().await;
}
// Kill and wait
let _ = session.child.start_kill();
let status = session.child.wait().await.ok();
let code = status.and_then(|s| s.code());
let stdout = session.stdout_buf.lock().await.clone();
let stderr = session.stderr_buf.lock().await.clone();
// Abort drain task
session._drain_task.abort();
Ok(Self::format_output(&stdout, &stderr, code))
}
/// Remove and kill all expired sessions (older than `SESSION_TIMEOUT_SECS`).
///
/// Designed to be called periodically from the owner (e.g. gateway).
pub async fn cleanup_expired(&self) {
let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS);
let mut expired_ids = Vec::new();
{
let sessions = self.sessions.lock().await;
for (id, session) in sessions.iter() {
if session.created_at.elapsed() > timeout {
expired_ids.push(id.clone());
}
}
}
for id in expired_ids {
let _ = self.close_session(&id).await;
}
}
/// Gracefully shut down all active sessions.
pub async fn shutdown(&self) {
let ids: Vec<String> = {
let sessions = self.sessions.lock().await;
sessions.keys().cloned().collect()
};
for id in ids {
let _ = self.close_session(&id).await;
}
}
/// Number of currently active sessions.
pub async fn active_count(&self) -> usize {
self.sessions.lock().await.len()
}
fn format_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
let mut output = String::new();
if !stdout.is_empty() {
output.push_str(stdout);
}
if !stderr.trim().is_empty() {
if !output.is_empty() {
output.push('\n');
}
output.push_str("STDERR:\n");
output.push_str(stderr);
}
if let Some(code) = exit_code {
output.push_str(&format!("\nExit code: {}", code));
}
output
}
}
impl Default for ShellSessionManager {
fn default() -> Self {
Self::new()
}
}