feat(shell): 实现交互式Shell会话管理
- 新增ShellSessionManager管理交互式shell会话,支持进程保持和交互输入 - BashTool集成会话管理,支持session_id和stdin_input参数实现输入回复 - 修改BashTool执行逻辑,检测进程等待输入状态并保存会话状态 - Windows平台新增底层进程等待输入检测实现,辅助判断Shell交互状态 - 工具注册工厂注入ShellSessionManager,保证安全复用会话管理实例 - 增加默认agent prompt中Shell交互终端说明,提示交互流程及输入格式 - 交互式命令输出增加标识和提示,区分正常与等待输入状态 - 实现会话超时自动清理和优雅关闭接口,避免资源泄露 - 单元测试中统一使用BashTool默认构造,适配会话管理新增功能
This commit is contained in:
parent
43cea50df8
commit
02172b6065
@ -40,6 +40,14 @@ rustls = { version = "0.23", features = ["ring"] }
|
|||||||
wechatbot = { path = "vendor/wechatbot" }
|
wechatbot = { path = "vendor/wechatbot" }
|
||||||
encoding_rs = "0.8"
|
encoding_rs = "0.8"
|
||||||
libc = "0.2"
|
libc = "0.2"
|
||||||
|
|
||||||
|
[target.'cfg(windows)'.dependencies]
|
||||||
|
windows-sys = { version = "0.59", features = [
|
||||||
|
"Win32_System_Threading",
|
||||||
|
"Win32_System_Diagnostics_Debug",
|
||||||
|
"Win32_Foundation",
|
||||||
|
"Win32_System_Kernel",
|
||||||
|
] }
|
||||||
# MCP (Model Context Protocol) support
|
# MCP (Model Context Protocol) support
|
||||||
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
|
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
|
||||||
"client",
|
"client",
|
||||||
|
|||||||
@ -112,6 +112,13 @@
|
|||||||
- 默认创建静默任务(silent_agent_task),在独立后台会话中执行,不干扰主对话
|
- 默认创建静默任务(silent_agent_task),在独立后台会话中执行,不干扰主对话
|
||||||
- 静默模式下如需发送消息给用户,prompt中需显式使用 send_session_message 工具
|
- 静默模式下如需发送消息给用户,prompt中需显式使用 send_session_message 工具
|
||||||
|
|
||||||
|
## Shell 交互终端
|
||||||
|
|
||||||
|
- 当 shell 工具返回包含 `__PICOBOT_PENDING_USER_ACTION__` 和 `[session_id: xxx]` 的结果时,表示进程正在等待输入
|
||||||
|
- 阅读已输出的内容,理解提示含义(如确认提示 Y/N、输入密码、选择选项等)
|
||||||
|
- 使用 `session_id` 和 `stdin_input` 参数回复交互内容,例如:`{"command": "echo test", "session_id": "xxx", "stdin_input": "Y"}`
|
||||||
|
- 常见场景:确认提示输入 Y/N、输入密码/验证码、选择选项、Read-Host 等
|
||||||
|
|
||||||
## todo工具使用规范
|
## todo工具使用规范
|
||||||
|
|
||||||
- 严格按照既定的未完成的todo工作项执行任务,如果工作项不在适用就更新,不得随意遗漏工作项
|
- 严格按照既定的未完成的todo工作项执行任务,如果工作项不在适用就更新,不得随意遗漏工作项
|
||||||
|
|||||||
@ -11,8 +11,8 @@ use crate::tools::todo_write::TodoItem;
|
|||||||
use crate::tools::{
|
use crate::tools::{
|
||||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||||
HttpRequestTool, MemoryManageTool, MemorySearchTool,
|
HttpRequestTool, MemoryManageTool, MemorySearchTool,
|
||||||
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool,
|
SchedulerManageTool, SessionMessageSender, SessionSendTool, ShellSessionManager,
|
||||||
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
SkillActivateTool, SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
||||||
TodoWriteTool, ToolRegistry, WebFetchTool,
|
TodoWriteTool, ToolRegistry, WebFetchTool,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -29,6 +29,7 @@ pub(crate) struct ToolRegistryFactory {
|
|||||||
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
||||||
mcp_manager: Option<Arc<McpClientManager>>,
|
mcp_manager: Option<Arc<McpClientManager>>,
|
||||||
todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,
|
todo_state: Option<Arc<RwLock<HashMap<String, Vec<TodoItem>>>>>,
|
||||||
|
shell_session_manager: Arc<ShellSessionManager>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolRegistryFactory {
|
impl ToolRegistryFactory {
|
||||||
@ -56,6 +57,7 @@ impl ToolRegistryFactory {
|
|||||||
subagent_runtime: None,
|
subagent_runtime: None,
|
||||||
mcp_manager: None,
|
mcp_manager: None,
|
||||||
todo_state: None,
|
todo_state: None,
|
||||||
|
shell_session_manager: Arc::new(ShellSessionManager::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,6 +89,11 @@ impl ToolRegistryFactory {
|
|||||||
!self.disabled_tools.contains(tool_name)
|
!self.disabled_tools.contains(tool_name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get a reference to the shell session manager for lifecycle control.
|
||||||
|
pub(crate) fn shell_session_manager(&self) -> Arc<ShellSessionManager> {
|
||||||
|
self.shell_session_manager.clone()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn build(&self) -> ToolRegistry {
|
pub(crate) fn build(&self) -> ToolRegistry {
|
||||||
let mut registry = ToolRegistry::new();
|
let mut registry = ToolRegistry::new();
|
||||||
|
|
||||||
@ -135,7 +142,7 @@ impl ToolRegistryFactory {
|
|||||||
registry.register(SkillManageTool::new(self.skills.clone()));
|
registry.register(SkillManageTool::new(self.skills.clone()));
|
||||||
}
|
}
|
||||||
if self.is_enabled("bash") {
|
if self.is_enabled("bash") {
|
||||||
registry.register(BashTool::new());
|
registry.register(BashTool::new(self.shell_session_manager.clone()));
|
||||||
}
|
}
|
||||||
if self.is_enabled("http_request") {
|
if self.is_enabled("http_request") {
|
||||||
registry.register(HttpRequestTool::new(
|
registry.register(HttpRequestTool::new(
|
||||||
@ -184,7 +191,7 @@ impl ToolRegistryFactory {
|
|||||||
registry.register(FileEditTool::new());
|
registry.register(FileEditTool::new());
|
||||||
}
|
}
|
||||||
if self.is_enabled("bash") {
|
if self.is_enabled("bash") {
|
||||||
registry.register(BashTool::new());
|
registry.register(BashTool::new(self.shell_session_manager.clone()));
|
||||||
}
|
}
|
||||||
if self.is_enabled("http_request") {
|
if self.is_enabled("http_request") {
|
||||||
registry.register(HttpRequestTool::new(
|
registry.register(HttpRequestTool::new(
|
||||||
|
|||||||
@ -156,18 +156,7 @@ pub fn is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
|
|||||||
}
|
}
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
{
|
{
|
||||||
// Windows: a full implementation would use:
|
windows_is_process_waiting_on_stdin(pid)
|
||||||
// 1. OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, pid)
|
|
||||||
// 2. NtQuerySystemInformation(SystemProcessInformation) to enumerate
|
|
||||||
// threads and check if the main thread's WaitReason is Executive
|
|
||||||
// (indicating a wait on a kernel handle, e.g. console input).
|
|
||||||
// 3. CloseHandle(hProcess)
|
|
||||||
//
|
|
||||||
// This requires the undocumented NtQuerySystemInformation API from
|
|
||||||
// ntdll.dll. Until that is in place, the keyword-matching and
|
|
||||||
// periodic output-staleness checks in bash.rs handle Windows detection.
|
|
||||||
let _ = pid;
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
|
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
|
||||||
{
|
{
|
||||||
@ -176,6 +165,149 @@ pub fn is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Windows implementation: check if a process is waiting for stdin input.
|
||||||
|
///
|
||||||
|
/// Uses NtQuerySystemInformation to enumerate process threads and check if
|
||||||
|
/// all threads are in Wait state with Executive wait reason, which indicates
|
||||||
|
/// the process is blocked on I/O (likely console input).
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
fn windows_is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
|
||||||
|
// SystemProcessInformation = 5
|
||||||
|
const SYSTEM_PROCESS_INFORMATION: u32 = 5;
|
||||||
|
const STATUS_INFO_LENGTH_MISMATCH: i32 = -1073741820; // 0xC0000004
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
struct SystemProcessInfo {
|
||||||
|
next_entry_offset: u32,
|
||||||
|
number_of_threads: u32,
|
||||||
|
working_set_private_size: i64,
|
||||||
|
hard_fault_count: u32,
|
||||||
|
number_of_threads_high_watermark: u32,
|
||||||
|
cycle_time: u64,
|
||||||
|
create_time: i64,
|
||||||
|
user_time: i64,
|
||||||
|
kernel_time: i64,
|
||||||
|
image_name_length: u16,
|
||||||
|
image_name_max_length: u16,
|
||||||
|
image_name: *const u16,
|
||||||
|
base_priority: i32,
|
||||||
|
unique_process_id: *mut std::ffi::c_void,
|
||||||
|
inherited_from_unique_process_id: *mut std::ffi::c_void,
|
||||||
|
handle_count: u32,
|
||||||
|
session_id: u32,
|
||||||
|
unique_process_key: usize,
|
||||||
|
peak_virtual_size: usize,
|
||||||
|
virtual_size: usize,
|
||||||
|
page_fault_count: u32,
|
||||||
|
peak_working_set_size: usize,
|
||||||
|
working_set_size: usize,
|
||||||
|
quota_peak_paged_pool_usage: usize,
|
||||||
|
quota_paged_pool_usage: usize,
|
||||||
|
quota_peak_non_paged_pool_usage: usize,
|
||||||
|
quota_non_paged_pool_usage: usize,
|
||||||
|
pagefile_usage: usize,
|
||||||
|
peak_pagefile_usage: usize,
|
||||||
|
private_page_count: usize,
|
||||||
|
read_operation_count: i64,
|
||||||
|
write_operation_count: i64,
|
||||||
|
other_operation_count: i64,
|
||||||
|
read_transfer_count: i64,
|
||||||
|
write_transfer_count: i64,
|
||||||
|
other_transfer_count: i64,
|
||||||
|
// SYSTEM_THREAD_INFORMATION[1] follows in memory
|
||||||
|
threads: [SystemThreadInfo; 1],
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
struct SystemThreadInfo {
|
||||||
|
kernel_time: i64,
|
||||||
|
user_time: i64,
|
||||||
|
create_time: i64,
|
||||||
|
wait_time: u32,
|
||||||
|
start_address: *mut std::ffi::c_void,
|
||||||
|
client_id_unique_process: *mut std::ffi::c_void,
|
||||||
|
client_id_unique_thread: *mut std::ffi::c_void,
|
||||||
|
priority: i32,
|
||||||
|
base_priority: i32,
|
||||||
|
context_switches: u32,
|
||||||
|
thread_state: u32,
|
||||||
|
wait_reason: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
unsafe extern "system" {
|
||||||
|
fn NtQuerySystemInformation(
|
||||||
|
system_information_class: u32,
|
||||||
|
system_information: *mut u8,
|
||||||
|
system_information_length: u32,
|
||||||
|
return_length: *mut u32,
|
||||||
|
) -> i32;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
// Query buffer size first
|
||||||
|
let mut buf_len: u32 = 0;
|
||||||
|
let status = NtQuerySystemInformation(
|
||||||
|
SYSTEM_PROCESS_INFORMATION,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
0,
|
||||||
|
&mut buf_len,
|
||||||
|
);
|
||||||
|
if status != STATUS_INFO_LENGTH_MISMATCH || buf_len == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate buffer with extra space (processes may be created between calls)
|
||||||
|
buf_len = buf_len.saturating_mul(2).max(65536);
|
||||||
|
let mut buffer: Vec<u8> = vec![0u8; buf_len as usize];
|
||||||
|
|
||||||
|
let status = NtQuerySystemInformation(
|
||||||
|
SYSTEM_PROCESS_INFORMATION,
|
||||||
|
buffer.as_mut_ptr(),
|
||||||
|
buf_len,
|
||||||
|
&mut buf_len,
|
||||||
|
);
|
||||||
|
if status < 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walk the linked list of SYSTEM_PROCESS_INFORMATION
|
||||||
|
let mut offset: usize = 0;
|
||||||
|
loop {
|
||||||
|
let info = &*(buffer.as_ptr().add(offset) as *const SystemProcessInfo);
|
||||||
|
let proc_id = info.unique_process_id as u32;
|
||||||
|
|
||||||
|
if proc_id == pid {
|
||||||
|
let thread_count = info.number_of_threads as usize;
|
||||||
|
if thread_count == 0 {
|
||||||
|
return Some(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Thread states: Running=2, Waiting=5
|
||||||
|
// Wait reasons: Executive=0, FreePage=1, PageIn=2, PoolAllocation=3,
|
||||||
|
// DelayExecution=4, Suspended=5, UserRequest=6, ...
|
||||||
|
// Executive wait + all threads waiting = likely blocked on I/O
|
||||||
|
let all_waiting = (0..thread_count).all(|i| {
|
||||||
|
let thread = &*info.threads.as_ptr().add(i);
|
||||||
|
thread.thread_state == 5 && thread.wait_reason == 0
|
||||||
|
});
|
||||||
|
|
||||||
|
return Some(all_waiting);
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.next_entry_offset == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
offset += info.next_entry_offset as usize;
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the user's home directory.
|
/// Get the user's home directory.
|
||||||
///
|
///
|
||||||
/// Supports environment variable overrides for testing:
|
/// Supports environment variable overrides for testing:
|
||||||
|
|||||||
@ -11,13 +11,16 @@ use tokio::sync::{Mutex, mpsc};
|
|||||||
use tokio::time::{Instant, sleep_until};
|
use tokio::time::{Instant, sleep_until};
|
||||||
|
|
||||||
use crate::platform::{ShellInfo, dangerous_command_patterns};
|
use crate::platform::{ShellInfo, dangerous_command_patterns};
|
||||||
|
use crate::tools::shell_session::ShellSessionManager;
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
use crate::tools::{extract_u64, extract_bool, check_null_args};
|
use crate::tools::{extract_u64, extract_bool, check_null_args};
|
||||||
|
|
||||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||||
const MAX_OUTPUT_CHARS: usize = 50_000;
|
const MAX_OUTPUT_CHARS: usize = 50_000;
|
||||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||||
const USER_ACTION_HINT: &str =
|
const INTERACTIVE_HINT: &str =
|
||||||
|
"进程正在等待输入。请使用 session_id 和 stdin_input 参数回复交互内容。";
|
||||||
|
const NON_INTERACTIVE_HINT: &str =
|
||||||
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
||||||
|
|
||||||
/// Shell 类型枚举,支持跨平台
|
/// Shell 类型枚举,支持跨平台
|
||||||
@ -104,15 +107,17 @@ pub struct BashTool {
|
|||||||
working_dir: Option<String>,
|
working_dir: Option<String>,
|
||||||
deny_patterns: Vec<String>,
|
deny_patterns: Vec<String>,
|
||||||
shell: ShellKind,
|
shell: ShellKind,
|
||||||
|
session_manager: Arc<ShellSessionManager>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BashTool {
|
impl BashTool {
|
||||||
pub fn new() -> Self {
|
pub fn new(session_manager: Arc<ShellSessionManager>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
timeout_secs: 60,
|
timeout_secs: 60,
|
||||||
working_dir: None,
|
working_dir: None,
|
||||||
deny_patterns: dangerous_command_patterns(),
|
deny_patterns: dangerous_command_patterns(),
|
||||||
shell: ShellKind::detect(),
|
shell: ShellKind::detect(),
|
||||||
|
session_manager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,11 +173,20 @@ impl BashTool {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pending_output(&self, output: &str) -> String {
|
fn pending_output(&self, output: &str, session_id: Option<&str>) -> String {
|
||||||
|
let hint = if session_id.is_some() {
|
||||||
|
INTERACTIVE_HINT
|
||||||
|
} else {
|
||||||
|
NON_INTERACTIVE_HINT
|
||||||
|
};
|
||||||
|
let session_line = session_id
|
||||||
|
.map(|id| format!("[session_id: {}]\n", id))
|
||||||
|
.unwrap_or_default();
|
||||||
format!(
|
format!(
|
||||||
"{}\n{}\n\n{}",
|
"{}\n{}{}\n\n{}",
|
||||||
PENDING_USER_ACTION_MARKER,
|
PENDING_USER_ACTION_MARKER,
|
||||||
USER_ACTION_HINT,
|
session_line,
|
||||||
|
hint,
|
||||||
self.truncate_output(output.trim())
|
self.truncate_output(output.trim())
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -248,7 +262,7 @@ async fn drain_available_chunks(
|
|||||||
|
|
||||||
impl Default for BashTool {
|
impl Default for BashTool {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self::new()
|
Self::new(Arc::new(ShellSessionManager::new()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,6 +293,14 @@ impl Tool for BashTool {
|
|||||||
"interactive": {
|
"interactive": {
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication"
|
"description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication"
|
||||||
|
},
|
||||||
|
"session_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Continue a previous interactive session by providing its session_id"
|
||||||
|
},
|
||||||
|
"stdin_input": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Input text to send to the process stdin (used with session_id)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["command"]
|
"required": ["command"]
|
||||||
@ -294,6 +316,26 @@ impl Tool for BashTool {
|
|||||||
return Ok(result);
|
return Ok(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle session continuation first
|
||||||
|
if let Some(session_id) = args.get("session_id").and_then(|v| v.as_str()) {
|
||||||
|
let input = args
|
||||||
|
.get("stdin_input")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("");
|
||||||
|
return match self.session_manager.send_input(session_id, input).await {
|
||||||
|
Ok(output) => Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output,
|
||||||
|
error: None,
|
||||||
|
}),
|
||||||
|
Err(e) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||||
Some(c) => c,
|
Some(c) => c,
|
||||||
None => {
|
None => {
|
||||||
@ -353,15 +395,19 @@ impl BashTool {
|
|||||||
) -> Result<String, String> {
|
) -> Result<String, String> {
|
||||||
let mut cmd = Command::new(self.shell.executable());
|
let mut cmd = Command::new(self.shell.executable());
|
||||||
cmd.args(self.shell.command_args(command))
|
cmd.args(self.shell.command_args(command))
|
||||||
|
.stdin(Stdio::piped())
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped())
|
.stderr(Stdio::piped())
|
||||||
.current_dir(cwd);
|
.current_dir(cwd);
|
||||||
|
|
||||||
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||||
|
|
||||||
|
// Take stdin writer before stdout/stderr
|
||||||
|
let child_stdin = child.stdin.take();
|
||||||
let stdout = child.stdout.take();
|
let stdout = child.stdout.take();
|
||||||
let stderr = child.stderr.take();
|
let stderr = child.stderr.take();
|
||||||
let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>();
|
let (tx, rx_inner) = mpsc::unbounded_channel::<(bool, String)>();
|
||||||
|
let mut rx: Option<mpsc::UnboundedReceiver<(bool, String)>> = Some(rx_inner);
|
||||||
|
|
||||||
if let Some(stdout) = stdout {
|
if let Some(stdout) = stdout {
|
||||||
tokio::spawn(read_stream(stdout, false, tx.clone()));
|
tokio::spawn(read_stream(stdout, false, tx.clone()));
|
||||||
@ -379,8 +425,9 @@ impl BashTool {
|
|||||||
tokio::select! {
|
tokio::select! {
|
||||||
status = child.wait() => {
|
status = child.wait() => {
|
||||||
let status = status.map_err(|e| format!("Failed to wait: {}", e))?;
|
let status = status.map_err(|e| format!("Failed to wait: {}", e))?;
|
||||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
let mut rx_val = rx.take().unwrap();
|
||||||
while let Some((is_stderr, chunk)) = rx.recv().await {
|
drain_available_chunks(&mut rx_val, &stdout_buf, &stderr_buf).await;
|
||||||
|
while let Some((is_stderr, chunk)) = rx_val.recv().await {
|
||||||
if is_stderr {
|
if is_stderr {
|
||||||
stderr_buf.lock().await.push_str(&chunk);
|
stderr_buf.lock().await.push_str(&chunk);
|
||||||
} else {
|
} else {
|
||||||
@ -390,7 +437,12 @@ impl BashTool {
|
|||||||
let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1)));
|
let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1)));
|
||||||
return Ok(self.truncate_output(&output));
|
return Ok(self.truncate_output(&output));
|
||||||
}
|
}
|
||||||
Some((is_stderr, chunk)) = rx.recv() => {
|
Some((is_stderr, chunk)) = async {
|
||||||
|
match rx.as_mut() {
|
||||||
|
Some(r) => r.recv().await,
|
||||||
|
None => std::future::pending().await,
|
||||||
|
}
|
||||||
|
} => {
|
||||||
if is_stderr {
|
if is_stderr {
|
||||||
stderr_buf.lock().await.push_str(&chunk);
|
stderr_buf.lock().await.push_str(&chunk);
|
||||||
} else {
|
} else {
|
||||||
@ -399,59 +451,115 @@ impl BashTool {
|
|||||||
|
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||||
if self.should_return_pending(interactive, &combined) {
|
if self.should_return_pending(interactive, &combined) {
|
||||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
let mut rx_val = rx.take().unwrap();
|
||||||
|
drain_available_chunks(&mut rx_val, &stdout_buf, &stderr_buf).await;
|
||||||
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||||
|
// Try to save as interactive session
|
||||||
|
if let Some(stdin) = child_stdin {
|
||||||
|
let session_id = self.session_manager.save_session(
|
||||||
|
child, stdin, rx_val,
|
||||||
|
stdout_buf.lock().await.clone(),
|
||||||
|
stderr_buf.lock().await.clone(),
|
||||||
|
).await;
|
||||||
|
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||||
|
}
|
||||||
let _ = child.start_kill();
|
let _ = child.start_kill();
|
||||||
let _ = child.wait().await;
|
let _ = child.wait().await;
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
return Ok(self.pending_output(&combined, None));
|
||||||
return Ok(self.pending_output(&combined));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ = tokio::time::sleep(Duration::from_secs(2)) => {
|
_ = tokio::time::sleep(Duration::from_secs(2)) => {
|
||||||
// Periodic safety net: when output has been silent for 2s,
|
// Periodic safety net: check OS-level process state
|
||||||
// check OS-level process state to see if the child is
|
|
||||||
// genuinely blocked on stdin. Also re-run keyword detection
|
|
||||||
// in case read_stream flushed a partial line since the last
|
|
||||||
// rx.recv() iteration.
|
|
||||||
if let Some(pid) = child.id() {
|
if let Some(pid) = child.id() {
|
||||||
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) {
|
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) {
|
||||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
if let Some(rx_ref) = rx.as_mut() {
|
||||||
|
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
|
||||||
|
}
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||||
if !combined.trim().is_empty() {
|
if !combined.trim().is_empty() {
|
||||||
|
if let Some(stdin) = child_stdin {
|
||||||
|
if let Some(rx_val) = rx.take() {
|
||||||
|
let session_id = self.session_manager.save_session(
|
||||||
|
child, stdin, rx_val,
|
||||||
|
stdout_buf.lock().await.clone(),
|
||||||
|
stderr_buf.lock().await.clone(),
|
||||||
|
).await;
|
||||||
|
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||||
|
}
|
||||||
|
}
|
||||||
let _ = child.start_kill();
|
let _ = child.start_kill();
|
||||||
let _ = child.wait().await;
|
let _ = child.wait().await;
|
||||||
return Ok(self.pending_output(&combined));
|
return Ok(self.pending_output(&combined, None));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||||
if self.should_return_pending(interactive, &combined) {
|
if self.should_return_pending(interactive, &combined) {
|
||||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
if let Some(rx_ref) = rx.as_mut() {
|
||||||
|
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
|
||||||
|
}
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||||
|
if let Some(stdin) = child_stdin {
|
||||||
|
if let Some(rx_val) = rx.take() {
|
||||||
|
let session_id = self.session_manager.save_session(
|
||||||
|
child, stdin, rx_val,
|
||||||
|
stdout_buf.lock().await.clone(),
|
||||||
|
stderr_buf.lock().await.clone(),
|
||||||
|
).await;
|
||||||
|
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||||
|
}
|
||||||
|
}
|
||||||
let _ = child.start_kill();
|
let _ = child.start_kill();
|
||||||
let _ = child.wait().await;
|
let _ = child.wait().await;
|
||||||
return Ok(self.pending_output(&combined));
|
return Ok(self.pending_output(&combined, None));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ = sleep_until(deadline) => {
|
_ = sleep_until(deadline) => {
|
||||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
if let Some(rx_ref) = rx.as_mut() {
|
||||||
|
drain_available_chunks(rx_ref, &stdout_buf, &stderr_buf).await;
|
||||||
|
}
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
||||||
let _ = child.start_kill();
|
|
||||||
let _ = child.wait().await;
|
|
||||||
|
|
||||||
// OS-level process state check: if the child was blocked on
|
// OS-level check: if blocked on stdin, save as session
|
||||||
// stdin, treat it as pending rather than a hard timeout error.
|
|
||||||
if let Some(pid) = child.id() {
|
if let Some(pid) = child.id() {
|
||||||
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true)
|
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true)
|
||||||
&& !combined.trim().is_empty()
|
&& !combined.trim().is_empty()
|
||||||
{
|
{
|
||||||
return Ok(self.pending_output(&combined));
|
if let Some(stdin) = child_stdin {
|
||||||
|
if let Some(rx_val) = rx.take() {
|
||||||
|
let session_id = self.session_manager.save_session(
|
||||||
|
child, stdin, rx_val,
|
||||||
|
stdout_buf.lock().await.clone(),
|
||||||
|
stderr_buf.lock().await.clone(),
|
||||||
|
).await;
|
||||||
|
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = child.start_kill();
|
||||||
|
let _ = child.wait().await;
|
||||||
|
return Ok(self.pending_output(&combined, None));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.should_return_pending(interactive, &combined) {
|
if self.should_return_pending(interactive, &combined) {
|
||||||
return Ok(self.pending_output(&combined));
|
if let Some(stdin) = child_stdin {
|
||||||
|
if let Some(rx_val) = rx.take() {
|
||||||
|
let session_id = self.session_manager.save_session(
|
||||||
|
child, stdin, rx_val,
|
||||||
|
stdout_buf.lock().await.clone(),
|
||||||
|
stderr_buf.lock().await.clone(),
|
||||||
|
).await;
|
||||||
|
return Ok(self.pending_output(&combined, Some(&session_id)));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
let _ = child.start_kill();
|
||||||
|
let _ = child.wait().await;
|
||||||
|
return Ok(self.pending_output(&combined, None));
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = child.start_kill();
|
||||||
|
let _ = child.wait().await;
|
||||||
return Err(format!("Command timed out after {} seconds", timeout_secs));
|
return Err(format!("Command timed out after {} seconds", timeout_secs));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -557,7 +665,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_simple_command() {
|
async fn test_simple_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
let command = if cfg!(target_os = "windows") {
|
let command = if cfg!(target_os = "windows") {
|
||||||
"Write-Output 'Hello World'"
|
"Write-Output 'Hello World'"
|
||||||
} else {
|
} else {
|
||||||
@ -574,7 +682,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_pwd_command() {
|
async fn test_pwd_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
let command = if cfg!(target_os = "windows") {
|
let command = if cfg!(target_os = "windows") {
|
||||||
"Get-Location"
|
"Get-Location"
|
||||||
} else {
|
} else {
|
||||||
@ -587,7 +695,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_ls_command() {
|
async fn test_ls_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
let temp_dir = std::env::temp_dir();
|
let temp_dir = std::env::temp_dir();
|
||||||
let command = if cfg!(target_os = "windows") {
|
let command = if cfg!(target_os = "windows") {
|
||||||
format!("Get-ChildItem {}", temp_dir.display())
|
format!("Get-ChildItem {}", temp_dir.display())
|
||||||
@ -604,7 +712,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_dangerous_rm() {
|
async fn test_dangerous_rm() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
// 测试 Unix 危险命令模式
|
// 测试 Unix 危险命令模式
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({ "command": "rm -rf /some/path" }))
|
.execute(json!({ "command": "rm -rf /some/path" }))
|
||||||
@ -617,7 +725,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_dangerous_windows_commands() {
|
async fn test_dangerous_windows_commands() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
// 测试 Windows del 命令模式(正则应该匹配)
|
// 测试 Windows del 命令模式(正则应该匹配)
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({ "command": "del /f /q file.txt" }))
|
.execute(json!({ "command": "del /f /q file.txt" }))
|
||||||
@ -630,7 +738,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_dangerous_fork_bomb() {
|
async fn test_dangerous_fork_bomb() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({ "command": ":(){ :|:& };:" }))
|
.execute(json!({ "command": ":(){ :|:& };:" }))
|
||||||
.await
|
.await
|
||||||
@ -642,7 +750,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_missing_command() {
|
async fn test_missing_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
let result = tool.execute(json!({})).await.unwrap();
|
let result = tool.execute(json!({})).await.unwrap();
|
||||||
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
@ -651,7 +759,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timeout() {
|
async fn test_timeout() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
let command = if cfg!(target_os = "windows") {
|
let command = if cfg!(target_os = "windows") {
|
||||||
"Start-Sleep -Seconds 10"
|
"Start-Sleep -Seconds 10"
|
||||||
} else {
|
} else {
|
||||||
@ -671,7 +779,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_pending_user_action_detection() {
|
async fn test_pending_user_action_detection() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
let command = if cfg!(target_os = "windows") {
|
let command = if cfg!(target_os = "windows") {
|
||||||
"Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
|
"Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
|
||||||
} else {
|
} else {
|
||||||
@ -692,7 +800,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_truncate_output_handles_utf8_char_boundaries() {
|
fn test_truncate_output_handles_utf8_char_boundaries() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::default();
|
||||||
let input = "全".repeat(MAX_OUTPUT_CHARS + 100);
|
let input = "全".repeat(MAX_OUTPUT_CHARS + 100);
|
||||||
|
|
||||||
let output = tool.truncate_output(&input);
|
let output = tool.truncate_output(&input);
|
||||||
|
|||||||
@ -10,6 +10,7 @@ pub mod registry;
|
|||||||
pub mod scheduler_manage;
|
pub mod scheduler_manage;
|
||||||
pub mod session_send;
|
pub mod session_send;
|
||||||
pub mod schema;
|
pub mod schema;
|
||||||
|
pub mod shell_session;
|
||||||
pub mod skill_activate;
|
pub mod skill_activate;
|
||||||
pub mod skill_manage;
|
pub mod skill_manage;
|
||||||
pub mod task;
|
pub mod task;
|
||||||
@ -33,6 +34,7 @@ pub use session_send::{
|
|||||||
SessionSendTool,
|
SessionSendTool,
|
||||||
};
|
};
|
||||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
|
pub use shell_session::ShellSessionManager;
|
||||||
pub use skill_activate::SkillActivateTool;
|
pub use skill_activate::SkillActivateTool;
|
||||||
pub use skill_manage::SkillManageTool;
|
pub use skill_manage::SkillManageTool;
|
||||||
pub use task::{
|
pub use task::{
|
||||||
|
|||||||
283
src/tools/shell_session.rs
Normal file
283
src/tools/shell_session.rs
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
//! Interactive shell session management.
|
||||||
|
//!
|
||||||
|
//! Provides `ShellSessionManager`, an independent service that keeps child
|
||||||
|
//! processes alive between tool calls so the Agent can interact with
|
||||||
|
//! stdin-waiting prompts (e.g. `Read-Host`, `Confirm (Y/N)`).
|
||||||
|
//!
|
||||||
|
//! The manager is created at the gateway/bootstrap layer and injected into
|
||||||
|
//! `BashTool` via `Arc`. It does NOT start background tasks on its own;
|
||||||
|
//! cleanup is driven externally via `cleanup_expired()` or `shutdown()`.
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
use tokio::process::{Child, ChildStdin};
|
||||||
|
use tokio::sync::{Mutex, mpsc};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
const SESSION_TIMEOUT_SECS: u64 = 300; // 5 minutes
|
||||||
|
const OUTPUT_WAIT_MS: u64 = 2000;
|
||||||
|
|
||||||
|
/// A single interactive shell session backed by a live child process.
|
||||||
|
struct ShellSession {
|
||||||
|
child: Child,
|
||||||
|
stdin_writer: Option<ChildStdin>,
|
||||||
|
stdout_buf: Arc<Mutex<String>>,
|
||||||
|
stderr_buf: Arc<Mutex<String>>,
|
||||||
|
/// Background task handle that drains the output channel into buffers.
|
||||||
|
_drain_task: tokio::task::JoinHandle<()>,
|
||||||
|
created_at: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Independent service for managing interactive shell sessions.
|
||||||
|
///
|
||||||
|
/// Thread-safe — designed to be shared via `Arc<ShellSessionManager>`.
|
||||||
|
pub struct ShellSessionManager {
|
||||||
|
sessions: Mutex<HashMap<String, ShellSession>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShellSessionManager {
|
||||||
|
/// Create a new, empty session manager.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sessions: Mutex::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Save a child process as an interactive session.
|
||||||
|
///
|
||||||
|
/// The caller provides:
|
||||||
|
/// - `child`: the spawned child process (with piped stdin/stdout/stderr)
|
||||||
|
/// - `stdin_writer`: the write-half of the piped stdin
|
||||||
|
/// - `rx`: the output channel receiver (produced by `read_stream` tasks)
|
||||||
|
/// - `initial_stdout` / `initial_stderr`: output already collected before
|
||||||
|
/// the session was created
|
||||||
|
///
|
||||||
|
/// Returns a unique `session_id` that can be used for subsequent
|
||||||
|
/// `send_input` / `get_output` / `close_session` calls.
|
||||||
|
pub async fn save_session(
|
||||||
|
&self,
|
||||||
|
mut child: Child,
|
||||||
|
stdin_writer: ChildStdin,
|
||||||
|
mut rx: mpsc::UnboundedReceiver<(bool, String)>,
|
||||||
|
initial_stdout: String,
|
||||||
|
initial_stderr: String,
|
||||||
|
) -> String {
|
||||||
|
let session_id = Uuid::new_v4().to_string();
|
||||||
|
let stdout_buf = Arc::new(Mutex::new(initial_stdout));
|
||||||
|
let stderr_buf = Arc::new(Mutex::new(initial_stderr));
|
||||||
|
|
||||||
|
// Spawn a background task that drains the channel into buffers.
|
||||||
|
let stdout_clone = stdout_buf.clone();
|
||||||
|
let stderr_clone = stderr_buf.clone();
|
||||||
|
let drain_task = tokio::spawn(async move {
|
||||||
|
while let Some((is_stderr, chunk)) = rx.recv().await {
|
||||||
|
if is_stderr {
|
||||||
|
stderr_clone.lock().await.push_str(&chunk);
|
||||||
|
} else {
|
||||||
|
stdout_clone.lock().await.push_str(&chunk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Kill the child's inherited stdin to prevent blocking on close.
|
||||||
|
// The actual stdin writing is done via stdin_writer.
|
||||||
|
let _ = child.stdin.take();
|
||||||
|
|
||||||
|
let session = ShellSession {
|
||||||
|
child,
|
||||||
|
stdin_writer: Some(stdin_writer),
|
||||||
|
stdout_buf,
|
||||||
|
stderr_buf,
|
||||||
|
_drain_task: drain_task,
|
||||||
|
created_at: Instant::now(),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.sessions
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.insert(session_id.clone(), session);
|
||||||
|
session_id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send input to a session's stdin and return new output.
|
||||||
|
///
|
||||||
|
/// After writing, waits up to `OUTPUT_WAIT_MS` for new output to arrive.
|
||||||
|
/// If the child process exits during the wait, returns the final output.
|
||||||
|
pub async fn send_input(&self, session_id: &str, input: &str) -> Result<String, String> {
|
||||||
|
let mut sessions = self.sessions.lock().await;
|
||||||
|
|
||||||
|
let session = sessions
|
||||||
|
.get_mut(session_id)
|
||||||
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||||
|
|
||||||
|
// Write input to stdin
|
||||||
|
if let Some(writer) = &mut session.stdin_writer {
|
||||||
|
let data = if input.ends_with('\n') {
|
||||||
|
input.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{}\n", input)
|
||||||
|
};
|
||||||
|
writer
|
||||||
|
.write_all(data.as_bytes())
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to write stdin: {}", e))?;
|
||||||
|
writer
|
||||||
|
.flush()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
|
||||||
|
} else {
|
||||||
|
return Err("Session stdin is closed".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record output length before wait
|
||||||
|
let prev_stdout_len = session.stdout_buf.lock().await.len();
|
||||||
|
let prev_stderr_len = session.stderr_buf.lock().await.len();
|
||||||
|
|
||||||
|
// Wait for new output or process exit
|
||||||
|
let deadline = Instant::now() + Duration::from_millis(OUTPUT_WAIT_MS);
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
status = session.child.wait() => {
|
||||||
|
// Process exited — collect final output
|
||||||
|
let stdout = session.stdout_buf.lock().await.clone();
|
||||||
|
let stderr = session.stderr_buf.lock().await.clone();
|
||||||
|
let code = status.ok().and_then(|s| s.code());
|
||||||
|
drop(sessions);
|
||||||
|
return Ok(Self::format_output(&stdout, &stderr, code));
|
||||||
|
}
|
||||||
|
_ = tokio::time::sleep_until(deadline) => {
|
||||||
|
// Timeout — return current output
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let stdout = session.stdout_buf.lock().await.clone();
|
||||||
|
let stderr = session.stderr_buf.lock().await.clone();
|
||||||
|
|
||||||
|
let new_stdout: String = stdout.chars().skip(prev_stdout_len).collect();
|
||||||
|
let new_stderr: String = stderr.chars().skip(prev_stderr_len).collect();
|
||||||
|
|
||||||
|
let mut result = String::new();
|
||||||
|
if !new_stdout.is_empty() {
|
||||||
|
result.push_str(&new_stdout);
|
||||||
|
}
|
||||||
|
if !new_stderr.trim().is_empty() {
|
||||||
|
if !result.is_empty() {
|
||||||
|
result.push('\n');
|
||||||
|
}
|
||||||
|
result.push_str("STDERR:\n");
|
||||||
|
result.push_str(&new_stderr);
|
||||||
|
}
|
||||||
|
if result.is_empty() {
|
||||||
|
result.push_str("(No new output after input. Session still active.)");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the full accumulated output of a session.
|
||||||
|
pub async fn get_output(&self, session_id: &str) -> Result<String, String> {
|
||||||
|
let sessions = self.sessions.lock().await;
|
||||||
|
let session = sessions
|
||||||
|
.get(session_id)
|
||||||
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||||
|
|
||||||
|
let stdout = session.stdout_buf.lock().await.clone();
|
||||||
|
let stderr = session.stderr_buf.lock().await.clone();
|
||||||
|
Ok(Self::format_output(&stdout, &stderr, None))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Close a session: kill the child process and return final output.
|
||||||
|
pub async fn close_session(&self, session_id: &str) -> Result<String, String> {
|
||||||
|
let mut sessions = self.sessions.lock().await;
|
||||||
|
let mut session = sessions
|
||||||
|
.remove(session_id)
|
||||||
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||||
|
drop(sessions); // Release lock before awaiting
|
||||||
|
|
||||||
|
// Close stdin to signal EOF
|
||||||
|
if let Some(mut writer) = session.stdin_writer.take() {
|
||||||
|
let _ = writer.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kill and wait
|
||||||
|
let _ = session.child.start_kill();
|
||||||
|
let status = session.child.wait().await.ok();
|
||||||
|
let code = status.and_then(|s| s.code());
|
||||||
|
|
||||||
|
let stdout = session.stdout_buf.lock().await.clone();
|
||||||
|
let stderr = session.stderr_buf.lock().await.clone();
|
||||||
|
|
||||||
|
// Abort drain task
|
||||||
|
session._drain_task.abort();
|
||||||
|
|
||||||
|
Ok(Self::format_output(&stdout, &stderr, code))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove and kill all expired sessions (older than `SESSION_TIMEOUT_SECS`).
|
||||||
|
///
|
||||||
|
/// Designed to be called periodically from the owner (e.g. gateway).
|
||||||
|
pub async fn cleanup_expired(&self) {
|
||||||
|
let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS);
|
||||||
|
let mut expired_ids = Vec::new();
|
||||||
|
|
||||||
|
{
|
||||||
|
let sessions = self.sessions.lock().await;
|
||||||
|
for (id, session) in sessions.iter() {
|
||||||
|
if session.created_at.elapsed() > timeout {
|
||||||
|
expired_ids.push(id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id in expired_ids {
|
||||||
|
let _ = self.close_session(&id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gracefully shut down all active sessions.
|
||||||
|
pub async fn shutdown(&self) {
|
||||||
|
let ids: Vec<String> = {
|
||||||
|
let sessions = self.sessions.lock().await;
|
||||||
|
sessions.keys().cloned().collect()
|
||||||
|
};
|
||||||
|
for id in ids {
|
||||||
|
let _ = self.close_session(&id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of currently active sessions.
|
||||||
|
pub async fn active_count(&self) -> usize {
|
||||||
|
self.sessions.lock().await.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
|
||||||
|
let mut output = String::new();
|
||||||
|
if !stdout.is_empty() {
|
||||||
|
output.push_str(stdout);
|
||||||
|
}
|
||||||
|
if !stderr.trim().is_empty() {
|
||||||
|
if !output.is_empty() {
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
output.push_str("STDERR:\n");
|
||||||
|
output.push_str(stderr);
|
||||||
|
}
|
||||||
|
if let Some(code) = exit_code {
|
||||||
|
output.push_str(&format!("\nExit code: {}", code));
|
||||||
|
}
|
||||||
|
output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ShellSessionManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user