feat(shell): 实现交互式Shell会话管理
- 新增ShellSessionManager管理交互式shell会话,支持进程保持和交互输入 - BashTool集成会话管理,支持session_id和stdin_input参数实现输入回复 - 修改BashTool执行逻辑,检测进程等待输入状态并保存会话状态 - Windows平台新增底层进程等待输入检测实现,辅助判断Shell交互状态 - 工具注册工厂注入ShellSessionManager,保证安全复用会话管理实例 - 增加默认agent prompt中Shell交互终端说明,提示交互流程及输入格式 - 交互式命令输出增加标识和提示,区分正常与等待输入状态 - 实现会话超时自动清理和优雅关闭接口,避免资源泄露 - 单元测试中统一使用BashTool默认构造,适配会话管理新增功能
This commit is contained in:
parent
43cea50df8
commit
02172b6065
@ -40,6 +40,14 @@ rustls = { version = "0.23", features = ["ring"] }
|
||||
wechatbot = { path = "vendor/wechatbot" }
|
||||
encoding_rs = "0.8"
|
||||
libc = "0.2"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
windows-sys = { version = "0.59", features = [
|
||||
"Win32_System_Threading",
|
||||
"Win32_System_Diagnostics_Debug",
|
||||
"Win32_Foundation",
|
||||
"Win32_System_Kernel",
|
||||
] }
|
||||
# MCP (Model Context Protocol) support
|
||||
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
|
||||
"client",
|
||||
|
||||
@ -112,6 +112,13 @@
|
||||
- 默认创建静默任务(silent_agent_task),在独立后台会话中执行,不干扰主对话
|
||||
- 静默模式下如需发送消息给用户,prompt中需显式使用 send_session_message 工具
|
||||
|
||||
## Shell 交互终端
|
||||
|
||||
- 当 shell 工具返回包含 `__PICOBOT_PENDING_USER_ACTION__` 和 `[session_id: xxx]` 的结果时,表示进程正在等待输入
|
||||
- 阅读已输出的内容,理解提示含义(如确认提示 Y/N、输入密码、选择选项等)
|
||||
- 使用 `session_id` 和 `stdin_input` 参数回复交互内容,例如:`{"command": "echo test", "session_id": "xxx", "stdin_input": "Y"}`
|
||||
- 常见场景:确认提示输入 Y/N、输入密码/验证码、选择选项、Read-Host 等
|
||||
|
||||
## todo工具使用规范
|
||||
|
||||
- 严格按照既定的未完成的todo工作项执行任务,如果工作项不在适用就更新,不得随意遗漏工作项
|
||||
|
||||
@ -11,8 +11,8 @@ use crate::tools::todo_write::TodoItem;
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||
HttpRequestTool, MemoryManageTool, MemorySearchTool,
|
||||
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool,
|
||||
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
||||
SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager,
|
||||
SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
||||
TodoWriteTool, ToolRegistry, WebFetchTool,
|
||||
};
|
||||
|
||||
@ -29,6 +29,7 @@ pub(crate) struct ToolRegistryFactory {
|
||||
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
||||
mcp_manager: Option<Arc<McpClientManager>>,
|
||||
todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,
|
||||
shell_session_manager: Arc<ShellSessionManager>,
|
||||
}
|
||||
|
||||
impl ToolRegistryFactory {
|
||||
@ -56,6 +57,7 @@ impl ToolRegistryFactory {
|
||||
subagent_runtime: None,
|
||||
mcp_manager: None,
|
||||
todo_state: None,
|
||||
shell_session_manager: Arc::new(ShellSessionManager::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -87,6 +89,11 @@ impl ToolRegistryFactory {
|
||||
!self.disabled_tools.contains(tool_name)
|
||||
}
|
||||
|
||||
/// Get a reference to the shell session manager for lifecycle control.
|
||||
pub(crate) fn shell_session_manager(&self) -> Arc<ShellSessionManager> {
|
||||
self.shell_session_manager.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn build(&self) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new();
|
||||
|
||||
@ -135,7 +142,7 @@ impl ToolRegistryFactory {
|
||||
registry.register(SkillManageTool::new(self.skills.clone()));
|
||||
}
|
||||
if self.is_enabled("bash") {
|
||||
registry.register(BashTool::new());
|
||||
registry.register(BashTool::new(self.shell_session_manager.clone()));
|
||||
}
|
||||
if self.is_enabled("http_request") {
|
||||
registry.register(HttpRequestTool::new(
|
||||
@ -184,7 +191,7 @@ impl ToolRegistryFactory {
|
||||
registry.register(FileEditTool::new());
|
||||
}
|
||||
if self.is_enabled("bash") {
|
||||
registry.register(BashTool::new());
|
||||
registry.register(BashTool::new(self.shell_session_manager.clone()));
|
||||
}
|
||||
if self.is_enabled("http_request") {
|
||||
registry.register(HttpRequestTool::new(
|
||||
|
||||
@ -156,18 +156,7 @@ pub fn is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
|
||||
}
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
// Windows: a full implementation would use:
|
||||
// 1. OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, pid)
|
||||
// 2. NtQuerySystemInformation(SystemProcessInformation) to enumerate
|
||||
// threads and check if the main thread's WaitReason is Executive
|
||||
// (indicating a wait on a kernel handle, e.g. console input).
|
||||
// 3. CloseHandle(hProcess)
|
||||
//
|
||||
// This requires the undocumented NtQuerySystemInformation API from
|
||||
// ntdll.dll. Until that is in place, the keyword-matching and
|
||||
// periodic output-staleness checks in bash.rs handle Windows detection.
|
||||
let _ = pid;
|
||||
None
|
||||
windows_is_process_waiting_on_stdin(pid)
|
||||
}
|
||||
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
|
||||
{
|
||||
@ -176,6 +165,149 @@ pub fn is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Windows implementation: check if a process is waiting for stdin input.
|
||||
///
|
||||
/// Uses NtQuerySystemInformation to enumerate process threads and check if
|
||||
/// all threads are in Wait state with Executive wait reason, which indicates
|
||||
/// the process is blocked on I/O (likely console input).
|
||||
#[cfg(target_os = "windows")]
|
||||
fn windows_is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
|
||||
// SystemProcessInformation = 5
|
||||
const SYSTEM_PROCESS_INFORMATION: u32 = 5;
|
||||
const STATUS_INFO_LENGTH_MISMATCH: i32 = -1073741820; // 0xC0000004
|
||||
|
||||
#[repr(C)]
|
||||
#[allow(non_snake_case)]
|
||||
struct SystemProcessInfo {
|
||||
next_entry_offset: u32,
|
||||
number_of_threads: u32,
|
||||
working_set_private_size: i64,
|
||||
hard_fault_count: u32,
|
||||
number_of_threads_high_watermark: u32,
|
||||
cycle_time: u64,
|
||||
create_time: i64,
|
||||
user_time: i64,
|
||||
kernel_time: i64,
|
||||
image_name_length: u16,
|
||||
image_name_max_length: u16,
|
||||
image_name: *const u16,
|
||||
base_priority: i32,
|
||||
unique_process_id: *mut std::ffi::c_void,
|
||||
inherited_from_unique_process_id: *mut std::ffi::c_void,
|
||||
handle_count: u32,
|
||||
session_id: u32,
|
||||
unique_process_key: usize,
|
||||
peak_virtual_size: usize,
|
||||
virtual_size: usize,
|
||||
page_fault_count: u32,
|
||||
peak_working_set_size: usize,
|
||||
working_set_size: usize,
|
||||
quota_peak_paged_pool_usage: usize,
|
||||
quota_paged_pool_usage: usize,
|
||||
quota_peak_non_paged_pool_usage: usize,
|
||||
quota_non_paged_pool_usage: usize,
|
||||
pagefile_usage: usize,
|
||||
peak_pagefile_usage: usize,
|
||||
private_page_count: usize,
|
||||
read_operation_count: i64,
|
||||
write_operation_count: i64,
|
||||
other_operation_count: i64,
|
||||
read_transfer_count: i64,
|
||||
write_transfer_count: i64,
|
||||
other_transfer_count: i64,
|
||||
// SYSTEM_THREAD_INFORMATION[1] follows in memory
|
||||
threads: [SystemThreadInfo; 1],
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy)]
|
||||
#[allow(non_snake_case)]
|
||||
struct SystemThreadInfo {
|
||||
kernel_time: i64,
|
||||
user_time: i64,
|
||||
create_time: i64,
|
||||
wait_time: u32,
|
||||
start_address: *mut std::ffi::c_void,
|
||||
client_id_unique_process: *mut std::ffi::c_void,
|
||||
client_id_unique_thread: *mut std::ffi::c_void,
|
||||
priority: i32,
|
||||
base_priority: i32,
|
||||
context_switches: u32,
|
||||
thread_state: u32,
|
||||
wait_reason: u32,
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
unsafe extern "system" {
|
||||
fn NtQuerySystemInformation(
|
||||
system_information_class: u32,
|
||||
system_information: *mut u8,
|
||||
system_information_length: u32,
|
||||
return_length: *mut u32,
|
||||
) -> i32;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
// Query buffer size first
|
||||
let mut buf_len: u32 = 0;
|
||||
let status = NtQuerySystemInformation(
|
||||
SYSTEM_PROCESS_INFORMATION,
|
||||
std::ptr::null_mut(),
|
||||
0,
|
||||
&mut buf_len,
|
||||
);
|
||||
if status != STATUS_INFO_LENGTH_MISMATCH || buf_len == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Allocate buffer with extra space (processes may be created between calls)
|
||||
buf_len = buf_len.saturating_mul(2).max(65536);
|
||||
let mut buffer: Vec<u8> = vec![0u8; buf_len as usize];
|
||||
|
||||
let status = NtQuerySystemInformation(
|
||||
SYSTEM_PROCESS_INFORMATION,
|
||||
buffer.as_mut_ptr(),
|
||||
buf_len,
|
||||
&mut buf_len,
|
||||
);
|
||||
if status < 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Walk the linked list of SYSTEM_PROCESS_INFORMATION
|
||||
let mut offset: usize = 0;
|
||||
loop {
|
||||
let info = &*(buffer.as_ptr().add(offset) as *const SystemProcessInfo);
|
||||
let proc_id = info.unique_process_id as u32;
|
||||
|
||||
if proc_id == pid {
|
||||
let thread_count = info.number_of_threads as usize;
|
||||
if thread_count == 0 {
|
||||
return Some(false);
|
||||
}
|
||||
|
||||
// Thread states: Running=2, Waiting=5
|
||||
// Wait reasons: Executive=0, FreePage=1, PageIn=2, PoolAllocation=3,
|
||||
// DelayExecution=4, Suspended=5, UserRequest=6, ...
|
||||
// Executive wait + all threads waiting = likely blocked on I/O
|
||||
let all_waiting = (0..thread_count).all(|i| {
|
||||
let thread = &*info.threads.as_ptr().add(i);
|
||||
thread.thread_state == 5 && thread.wait_reason == 0
|
||||
});
|
||||
|
||||
return Some(all_waiting);
|
||||
}
|
||||
|
||||
if info.next_entry_offset == 0 {
|
||||
break;
|
||||
}
|
||||
offset += info.next_entry_offset as usize;
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the user's home directory.
|
||||
///
|
||||
/// Supports environment variable overrides for testing:
|
||||
|
||||
@ -11,13 +11,16 @@ use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::time::{Instant, sleep_until};
|
||||
|
||||
use crate::platform::{ShellInfo, dangerous_command_patterns};
|
||||
use crate::tools::shell_session::ShellSessionManager;
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use crate::tools::{extract_u64, extract_bool, check_null_args};
|
||||
|
||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||
const MAX_OUTPUT_CHARS: usize = 50_000;
|
||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||
const USER_ACTION_HINT: &str =
|
||||
const INTERACTIVE_HINT: &str =
|
||||
"进程正在等待输入。请使用 session_id 和 stdin_input 参数回复交互内容。";
|
||||
const NON_INTERACTIVE_HINT: &str =
|
||||
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
||||
|
||||
/// Shell 类型枚举,支持跨平台
|
||||
@ -104,15 +107,17 @@ pub struct BashTool {
|
||||
working_dir: Option<String>,
|
||||
deny_patterns: Vec<String>,
|
||||
shell: ShellKind,
|
||||
session_manager: Arc<ShellSessionManager>,
|
||||
}
|
||||
|
||||
impl BashTool {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(session_manager: Arc<ShellSessionManager>) -> Self {
|
||||
Self {
|
||||
timeout_secs: 60,
|
||||
working_dir: None,
|
||||
deny_patterns: dangerous_command_patterns(),
|
||||
shell: ShellKind::detect(),
|
||||
session_manager,
|
||||
}
|
||||
}
|
||||
|
||||
@ -168,11 +173,20 @@ impl BashTool {
|
||||
)
|
||||
}
|
||||
|
||||
fn pending_output(&self, output: &str) -> String {
|
||||
fn pending_output(&self, output: &str, session_id: Option<&str>) -> String {
|
||||
let hint = if session_id.is_some() {
|
||||
INTERACTIVE_HINT
|
||||
} else {
|
||||
NON_INTERACTIVE_HINT
|
||||
};
|
||||
let session_line = session_id
|
||||
.map(|id| format!("[session_id: {}]\n", id))
|
||||
.unwrap_or_default();
|
||||
format!(
|
||||
"{}\n{}\n\n{}",
|
||||
"{}\n{}{}\n\n{}",
|
||||
PENDING_USER_ACTION_MARKER,
|
||||
USER_ACTION_HINT,
|
||||
session_line,
|
||||
hint,
|
||||
self.truncate_output(output.trim())
|
||||
)
|
||||
}
|
||||
@ -248,7 +262,7 @@ async fn drain_available_chunks(
|
||||
|
||||
impl Default for BashTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
Self::new(Arc::new(ShellSessionManager::new()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,6 +293,14 @@ impl Tool for BashTool {
|
||||
"interactive": {
|
||||
"type": "boolean",
|
||||
"description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication"
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "Continue a previous interactive session by providing its session_id"
|
||||
},
|
||||
"stdin_input": {
|
||||
"type": "string",
|
||||
"description": "Input text to send to the process stdin (used with session_id)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
@ -294,6 +316,26 @@ impl Tool for BashTool {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
// Handle session continuation first
|
||||
if let Some(session_id) = args.get("session_id").and_then(|v| v.as_str()) {
|
||||
let input = args
|
||||
.get("stdin_input")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
return match self.session_manager.send_input(session_id, input).await {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
@ -353,15 +395,19 @@ impl BashTool {
|
||||
) -> Result<String, String> {
|
||||
let mut cmd = Command::new(self.shell.executable());
|
||||
cmd.args(self.shell.command_args(command))
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.current_dir(cwd);
|
||||
|
||||
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||
|
||||
// Take stdin writer before stdout/stderr
|
||||
let child_stdin = child.stdin.take();
|
||||
let stdout = child.stdout.take();
|
||||
let stderr = child.stderr.take();
|
||||
let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>();
|
||||
let (tx, rx_inner) = mpsc::unbounded_channel::<(bool, String)>();
|
||||
let mut rx: Option<mpsc::UnboundedReceiver<(bool, String)>> = Some(rx_inner);
|
||||
|
||||
if let Some(stdout) = stdout {
|
||||
tokio::spawn(read_stream(stdout, false, tx.clone()));
|
||||
@ -379,8 +425,9 @@ impl BashTool {
|
||||
tokio::select! {
|
||||
status = child.wait() => {
|
||||
let status = status.map_err(|e| format!("Failed to wait: {}", e))?;
|
||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
||||
while let Some((is_stderr, chunk)) = rx.recv().await {
|
||||
let mut rx_val = rx.take().unwrap();
|
||||
drain_available_chunks(&mut rx_val, &stdout_buf, &stderr_buf).await;
|
||||
while let Some((is_stderr, chunk)) = rx_val.recv().await {
|
||||
if is_stderr {
|
||||
stderr_buf.lock().await.push_str(&chunk);
|
||||
} else {
|
||||
@ -390,7 +437,12 @@ impl BashTool {
|
||||
let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1)));
|
||||
return Ok(self.truncate_output(&output));
|
||||
}
|
||||
Some((is_stderr, chunk)) = rx.recv() => {
|
||||
Some((is_stderr, chunk)) = async {
|
||||
match rx.as_mut() {
|
||||
Some(r) => r.recv().await,
|
||||
None => std::future::pending().await,
|
||||
}
|
||||
} => {
|
||||
if is_stderr {
|
||||
stderr_buf.lock().await.push_str(&chunk);
|
||||
} else {
|
||||
@ -399,59 +451,115 @@ impl BashTool {
|
||||
|
||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||
if self.should_return_pending(interactive, &combined) {
|
||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
||||
let mut rx_val = rx.take().unwrap();
|
||||
drain_available_chunks(&mut rx_val, &stdout_buf, &stderr_buf).await;
|
||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||
// Try to save as interactive session
|
||||
if let Some(stdin) = child_stdin {
|
||||
let session_id = self.session_manager.save_session(
|
||||
child, stdin, rx_val,
|
||||
stdout_buf.lock().await.clone(),
|
||||
stderr_buf.lock().await.clone(),
|
||||
).await;
|
||||
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||
}
|
||||
let _ = child.start_kill();
|
||||
let _ = child.wait().await;
|
||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||
return Ok(self.pending_output(&combined));
|
||||
return Ok(self.pending_output(&combined, None));
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs(2)) => {
|
||||
// Periodic safety net: when output has been silent for 2s,
|
||||
// check OS-level process state to see if the child is
|
||||
// genuinely blocked on stdin. Also re-run keyword detection
|
||||
// in case read_stream flushed a partial line since the last
|
||||
// rx.recv() iteration.
|
||||
// Periodic safety net: check OS-level process state
|
||||
if let Some(pid) = child.id() {
|
||||
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) {
|
||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
||||
if let Some(rx_ref) = rx.as_mut() {
|
||||
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
|
||||
}
|
||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||
if !combined.trim().is_empty() {
|
||||
if let Some(stdin) = child_stdin {
|
||||
if let Some(rx_val) = rx.take() {
|
||||
let session_id = self.session_manager.save_session(
|
||||
child, stdin, rx_val,
|
||||
stdout_buf.lock().await.clone(),
|
||||
stderr_buf.lock().await.clone(),
|
||||
).await;
|
||||
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||
}
|
||||
}
|
||||
let _ = child.start_kill();
|
||||
let _ = child.wait().await;
|
||||
return Ok(self.pending_output(&combined));
|
||||
return Ok(self.pending_output(&combined, None));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||
if self.should_return_pending(interactive, &combined) {
|
||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
||||
if let Some(rx_ref) = rx.as_mut() {
|
||||
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
|
||||
}
|
||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||
if let Some(stdin) = child_stdin {
|
||||
if let Some(rx_val) = rx.take() {
|
||||
let session_id = self.session_manager.save_session(
|
||||
child, stdin, rx_val,
|
||||
stdout_buf.lock().await.clone(),
|
||||
stderr_buf.lock().await.clone(),
|
||||
).await;
|
||||
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||
}
|
||||
}
|
||||
let _ = child.start_kill();
|
||||
let _ = child.wait().await;
|
||||
return Ok(self.pending_output(&combined));
|
||||
return Ok(self.pending_output(&combined, None));
|
||||
}
|
||||
}
|
||||
_ = sleep_until(deadline) => {
|
||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
||||
if let Some(rx_ref) = rx.as_mut() {
|
||||
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
|
||||
}
|
||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||
let _ = child.start_kill();
|
||||
let _ = child.wait().await;
|
||||
|
||||
// OS-level process state check: if the child was blocked on
|
||||
// stdin, treat it as pending rather than a hard timeout error.
|
||||
// OS-level check: if blocked on stdin, save as session
|
||||
if let Some(pid) = child.id() {
|
||||
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true)
|
||||
&& !combined.trim().is_empty()
|
||||
{
|
||||
return Ok(self.pending_output(&combined));
|
||||
if let Some(stdin) = child_stdin {
|
||||
if let Some(rx_val) = rx.take() {
|
||||
let session_id = self.session_manager.save_session(
|
||||
child, stdin, rx_val,
|
||||
stdout_buf.lock().await.clone(),
|
||||
stderr_buf.lock().await.clone(),
|
||||
).await;
|
||||
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||
}
|
||||
}
|
||||
let _ = child.start_kill();
|
||||
let _ = child.wait().await;
|
||||
return Ok(self.pending_output(&combined, None));
|
||||
}
|
||||
}
|
||||
|
||||
if self.should_return_pending(interactive, &combined) {
|
||||
return Ok(self.pending_output(&combined));
|
||||
if let Some(stdin) = child_stdin {
|
||||
if let Some(rx_val) = rx.take() {
|
||||
let session_id = self.session_manager.save_session(
|
||||
child, stdin, rx_val,
|
||||
stdout_buf.lock().await.clone(),
|
||||
stderr_buf.lock().await.clone(),
|
||||
).await;
|
||||
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||
}
|
||||
}
|
||||
let _ = child.start_kill();
|
||||
let _ = child.wait().await;
|
||||
return Ok(self.pending_output(&combined, None));
|
||||
}
|
||||
|
||||
let _ = child.start_kill();
|
||||
let _ = child.wait().await;
|
||||
return Err(format!("Command timed out after {} seconds", timeout_secs));
|
||||
}
|
||||
}
|
||||
@ -557,7 +665,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simple_command() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
"Write-Output 'Hello World'"
|
||||
} else {
|
||||
@ -574,7 +682,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pwd_command() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
"Get-Location"
|
||||
} else {
|
||||
@ -587,7 +695,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ls_command() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
let temp_dir = std::env::temp_dir();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
format!("Get-ChildItem {}", temp_dir.display())
|
||||
@ -604,7 +712,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dangerous_rm() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
// 测试 Unix 危险命令模式
|
||||
let result = tool
|
||||
.execute(json!({ "command": "rm -rf /some/path" }))
|
||||
@ -617,7 +725,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dangerous_windows_commands() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
// 测试 Windows del 命令模式(正则应该匹配)
|
||||
let result = tool
|
||||
.execute(json!({ "command": "del /f /q file.txt" }))
|
||||
@ -630,7 +738,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dangerous_fork_bomb() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
let result = tool
|
||||
.execute(json!({ "command": ":(){ :|:& };:" }))
|
||||
.await
|
||||
@ -642,7 +750,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_missing_command() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
@ -651,7 +759,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
"Start-Sleep -Seconds 10"
|
||||
} else {
|
||||
@ -671,7 +779,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pending_user_action_detection() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
"Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
|
||||
} else {
|
||||
@ -692,7 +800,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_truncate_output_handles_utf8_char_boundaries() {
|
||||
let tool = BashTool::new();
|
||||
let tool = BashTool::default();
|
||||
let input = "全".repeat(MAX_OUTPUT_CHARS + 100);
|
||||
|
||||
let output = tool.truncate_output(&input);
|
||||
|
||||
@ -10,6 +10,7 @@ pub mod registry;
|
||||
pub mod scheduler_manage;
|
||||
pub mod session_send;
|
||||
pub mod schema;
|
||||
pub mod shell_session;
|
||||
pub mod skill_activate;
|
||||
pub mod skill_manage;
|
||||
pub mod task;
|
||||
@ -33,6 +34,7 @@ pub use session_send::{
|
||||
SessionSendTool,
|
||||
};
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use shell_session::ShellSessionManager;
|
||||
pub use skill_activate::SkillActivateTool;
|
||||
pub use skill_manage::SkillManageTool;
|
||||
pub use task::{
|
||||
|
||||
283
src/tools/shell_session.rs
Normal file
283
src/tools/shell_session.rs
Normal file
@ -0,0 +1,283 @@
|
||||
//! Interactive shell session management.
|
||||
//!
|
||||
//! Provides `ShellSessionManager`, an independent service that keeps child
|
||||
//! processes alive between tool calls so the Agent can interact with
|
||||
//! stdin-waiting prompts (e.g. `Read-Host`, `Confirm (Y/N)`).
|
||||
//!
|
||||
//! The manager is created at the gateway/bootstrap layer and injected into
|
||||
//! `BashTool` via `Arc`. It does NOT start background tasks on its own;
|
||||
//! cleanup is driven externally via `cleanup_expired()` or `shutdown()`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::{Child, ChildStdin};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::time::Instant;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
const SESSION_TIMEOUT_SECS: u64 = 300; // 5 minutes
|
||||
const OUTPUT_WAIT_MS: u64 = 2000;
|
||||
|
||||
/// A single interactive shell session backed by a live child process.
|
||||
struct ShellSession {
|
||||
child: Child,
|
||||
stdin_writer: Option<ChildStdin>,
|
||||
stdout_buf: Arc<Mutex<String>>,
|
||||
stderr_buf: Arc<Mutex<String>>,
|
||||
/// Background task handle that drains the output channel into buffers.
|
||||
_drain_task: tokio::task::JoinHandle<()>,
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
/// Independent service for managing interactive shell sessions.
|
||||
///
|
||||
/// Thread-safe — designed to be shared via `Arc<ShellSessionManager>`.
|
||||
pub struct ShellSessionManager {
|
||||
sessions: Mutex<HashMap<String, ShellSession>>,
|
||||
}
|
||||
|
||||
impl ShellSessionManager {
|
||||
/// Create a new, empty session manager.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save a child process as an interactive session.
|
||||
///
|
||||
/// The caller provides:
|
||||
/// - `child`: the spawned child process (with piped stdin/stdout/stderr)
|
||||
/// - `stdin_writer`: the write-half of the piped stdin
|
||||
/// - `rx`: the output channel receiver (produced by `read_stream` tasks)
|
||||
/// - `initial_stdout` / `initial_stderr`: output already collected before
|
||||
/// the session was created
|
||||
///
|
||||
/// Returns a unique `session_id` that can be used for subsequent
|
||||
/// `send_input` / `get_output` / `close_session` calls.
|
||||
pub async fn save_session(
|
||||
&self,
|
||||
mut child: Child,
|
||||
stdin_writer: ChildStdin,
|
||||
mut rx: mpsc::UnboundedReceiver<(bool, String)>,
|
||||
initial_stdout: String,
|
||||
initial_stderr: String,
|
||||
) -> String {
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
let stdout_buf = Arc::new(Mutex::new(initial_stdout));
|
||||
let stderr_buf = Arc::new(Mutex::new(initial_stderr));
|
||||
|
||||
// Spawn a background task that drains the channel into buffers.
|
||||
let stdout_clone = stdout_buf.clone();
|
||||
let stderr_clone = stderr_buf.clone();
|
||||
let drain_task = tokio::spawn(async move {
|
||||
while let Some((is_stderr, chunk)) = rx.recv().await {
|
||||
if is_stderr {
|
||||
stderr_clone.lock().await.push_str(&chunk);
|
||||
} else {
|
||||
stdout_clone.lock().await.push_str(&chunk);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Kill the child's inherited stdin to prevent blocking on close.
|
||||
// The actual stdin writing is done via stdin_writer.
|
||||
let _ = child.stdin.take();
|
||||
|
||||
let session = ShellSession {
|
||||
child,
|
||||
stdin_writer: Some(stdin_writer),
|
||||
stdout_buf,
|
||||
stderr_buf,
|
||||
_drain_task: drain_task,
|
||||
created_at: Instant::now(),
|
||||
};
|
||||
|
||||
self.sessions
|
||||
.lock()
|
||||
.await
|
||||
.insert(session_id.clone(), session);
|
||||
session_id
|
||||
}
|
||||
|
||||
/// Send input to a session's stdin and return new output.
|
||||
///
|
||||
/// After writing, waits up to `OUTPUT_WAIT_MS` for new output to arrive.
|
||||
/// If the child process exits during the wait, returns the final output.
|
||||
pub async fn send_input(&self, session_id: &str, input: &str) -> Result<String, String> {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
|
||||
let session = sessions
|
||||
.get_mut(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
|
||||
// Write input to stdin
|
||||
if let Some(writer) = &mut session.stdin_writer {
|
||||
let data = if input.ends_with('\n') {
|
||||
input.to_string()
|
||||
} else {
|
||||
format!("{}\n", input)
|
||||
};
|
||||
writer
|
||||
.write_all(data.as_bytes())
|
||||
.await
|
||||
.map_err(|e| format!("Failed to write stdin: {}", e))?;
|
||||
writer
|
||||
.flush()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
|
||||
} else {
|
||||
return Err("Session stdin is closed".to_string());
|
||||
}
|
||||
|
||||
// Record output length before wait
|
||||
let prev_stdout_len = session.stdout_buf.lock().await.len();
|
||||
let prev_stderr_len = session.stderr_buf.lock().await.len();
|
||||
|
||||
// Wait for new output or process exit
|
||||
let deadline = Instant::now() + Duration::from_millis(OUTPUT_WAIT_MS);
|
||||
loop {
|
||||
tokio::select! {
|
||||
status = session.child.wait() => {
|
||||
// Process exited — collect final output
|
||||
let stdout = session.stdout_buf.lock().await.clone();
|
||||
let stderr = session.stderr_buf.lock().await.clone();
|
||||
let code = status.ok().and_then(|s| s.code());
|
||||
drop(sessions);
|
||||
return Ok(Self::format_output(&stdout, &stderr, code));
|
||||
}
|
||||
_ = tokio::time::sleep_until(deadline) => {
|
||||
// Timeout — return current output
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let stdout = session.stdout_buf.lock().await.clone();
|
||||
let stderr = session.stderr_buf.lock().await.clone();
|
||||
|
||||
let new_stdout: String = stdout.chars().skip(prev_stdout_len).collect();
|
||||
let new_stderr: String = stderr.chars().skip(prev_stderr_len).collect();
|
||||
|
||||
let mut result = String::new();
|
||||
if !new_stdout.is_empty() {
|
||||
result.push_str(&new_stdout);
|
||||
}
|
||||
if !new_stderr.trim().is_empty() {
|
||||
if !result.is_empty() {
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str("STDERR:\n");
|
||||
result.push_str(&new_stderr);
|
||||
}
|
||||
if result.is_empty() {
|
||||
result.push_str("(No new output after input. Session still active.)");
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get the full accumulated output of a session.
|
||||
pub async fn get_output(&self, session_id: &str) -> Result<String, String> {
|
||||
let sessions = self.sessions.lock().await;
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
|
||||
let stdout = session.stdout_buf.lock().await.clone();
|
||||
let stderr = session.stderr_buf.lock().await.clone();
|
||||
Ok(Self::format_output(&stdout, &stderr, None))
|
||||
}
|
||||
|
||||
/// Close a session: kill the child process and return final output.
|
||||
pub async fn close_session(&self, session_id: &str) -> Result<String, String> {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
let mut session = sessions
|
||||
.remove(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
drop(sessions); // Release lock before awaiting
|
||||
|
||||
// Close stdin to signal EOF
|
||||
if let Some(mut writer) = session.stdin_writer.take() {
|
||||
let _ = writer.shutdown().await;
|
||||
}
|
||||
|
||||
// Kill and wait
|
||||
let _ = session.child.start_kill();
|
||||
let status = session.child.wait().await.ok();
|
||||
let code = status.and_then(|s| s.code());
|
||||
|
||||
let stdout = session.stdout_buf.lock().await.clone();
|
||||
let stderr = session.stderr_buf.lock().await.clone();
|
||||
|
||||
// Abort drain task
|
||||
session._drain_task.abort();
|
||||
|
||||
Ok(Self::format_output(&stdout, &stderr, code))
|
||||
}
|
||||
|
||||
/// Remove and kill all expired sessions (older than `SESSION_TIMEOUT_SECS`).
|
||||
///
|
||||
/// Designed to be called periodically from the owner (e.g. gateway).
|
||||
pub async fn cleanup_expired(&self) {
|
||||
let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS);
|
||||
let mut expired_ids = Vec::new();
|
||||
|
||||
{
|
||||
let sessions = self.sessions.lock().await;
|
||||
for (id, session) in sessions.iter() {
|
||||
if session.created_at.elapsed() > timeout {
|
||||
expired_ids.push(id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for id in expired_ids {
|
||||
let _ = self.close_session(&id).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Gracefully shut down all active sessions.
|
||||
pub async fn shutdown(&self) {
|
||||
let ids: Vec<String> = {
|
||||
let sessions = self.sessions.lock().await;
|
||||
sessions.keys().cloned().collect()
|
||||
};
|
||||
for id in ids {
|
||||
let _ = self.close_session(&id).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of currently active sessions.
|
||||
pub async fn active_count(&self) -> usize {
|
||||
self.sessions.lock().await.len()
|
||||
}
|
||||
|
||||
fn format_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
|
||||
let mut output = String::new();
|
||||
if !stdout.is_empty() {
|
||||
output.push_str(stdout);
|
||||
}
|
||||
if !stderr.trim().is_empty() {
|
||||
if !output.is_empty() {
|
||||
output.push('\n');
|
||||
}
|
||||
output.push_str("STDERR:\n");
|
||||
output.push_str(stderr);
|
||||
}
|
||||
if let Some(code) = exit_code {
|
||||
output.push_str(&format!("\nExit code: {}", code));
|
||||
}
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ShellSessionManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user