feat: 添加平台特定的进程输入等待检查功能,增强 Bash 工具的用户授权检测

This commit is contained in:
oudecheng 2026-06-02 10:59:13 +08:00
parent 9b6cae0803
commit 1541dd7c10
3 changed files with 170 additions and 12 deletions

View File

@ -37,6 +37,7 @@ rusqlite = { version = "0.32", features = ["bundled"] }
rustls = { version = "0.23", features = ["ring"] }
wechatbot = { path = "vendor/wechatbot" }
encoding_rs = "0.8"
libc = "0.2"
# MCP (Model Context Protocol) support
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
"client",

View File

@ -111,6 +111,71 @@ pub fn dangerous_command_patterns() -> Vec<String> {
]
}
/// Check whether a child process appears to be blocked waiting for input.
///
/// This is a best-effort, platform-specific heuristic. Callers should combine
/// it with other evidence (such as output silence) before acting on it.
///
/// * **Linux** – reads `/proc/<pid>/wchan` (the kernel symbol the task is
///   sleeping in) and reports `true` when that symbol is a TTY read or a
///   pipe read/wait.
/// * **macOS** – asks `proc_pidinfo` for the task info and reports `true`
///   when no thread is currently running. This cannot tell a stdin read apart
///   from any other kind of sleep or blocking I/O, so `true` only means
///   "idle".
/// * **Windows / other platforms** – not implemented; always `None`.
///
/// # Returns
///
/// * `Some(true)`  – the process looks blocked on input (see caveats above).
/// * `Some(false)` – the process state was read and does not look like an
///   input wait.
/// * `None`        – the platform is unsupported, the pid cannot be
///   represented for the platform API, or the information could not be read.
pub fn is_process_waiting_on_stdin(pid: u32) -> Option<bool> {
    #[cfg(target_os = "linux")]
    {
        let wchan = std::fs::read_to_string(format!("/proc/{}/wchan", pid)).ok()?;
        let wchan = wchan.trim();
        if wchan.is_empty() {
            return None;
        }
        // `tty_read` also covers `n_tty_read` (substring match).
        // Pipe symbols are compared exactly so that unrelated helpers such as
        // `pipe_wait_writable` are not mistaken for a blocked reader.
        // `pipe_wait` is the legacy symbol; newer kernels block directly in
        // `pipe_read`.
        Some(wchan.contains("tty_read") || matches!(wchan, "pipe_read" | "pipe_wait"))
    }
    #[cfg(target_os = "macos")]
    {
        use std::convert::TryFrom;
        use std::mem;

        // A pid above `i32::MAX` cannot name a real process; refuse it rather
        // than letting an `as` cast wrap it into a negative pid.
        let pid = i32::try_from(pid).ok()?;
        let size = i32::try_from(mem::size_of::<libc::proc_taskinfo>()).ok()?;

        // SAFETY: `proc_taskinfo` consists solely of integer fields, so the
        // all-zero bit pattern is a valid value.
        let mut task_info: libc::proc_taskinfo = unsafe { mem::zeroed() };
        // SAFETY: the buffer pointer refers to a live, properly aligned
        // `proc_taskinfo` and `size` is exactly its size in bytes, so the
        // kernel cannot write out of bounds.
        let ret = unsafe {
            libc::proc_pidinfo(
                pid,
                libc::PROC_PIDTASKINFO,
                0,
                &mut task_info as *mut _ as *mut libc::c_void,
                size,
            )
        };
        // Errors (<= 0) and short results both leave `task_info` unusable.
        if ret < size {
            return None;
        }
        // pti_numrunning == 0 means no thread is actively on CPU.
        // Combined with output silence this suggests the process is blocked
        // (possibly on a stdin read), but it is equally true of a process
        // that is merely sleeping.
        Some(task_info.pti_numrunning == 0)
    }
    #[cfg(target_os = "windows")]
    {
        // Windows: a full implementation would use:
        //   1. OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, pid)
        //   2. NtQuerySystemInformation(SystemProcessInformation) to enumerate
        //      threads and check if the main thread's WaitReason is Executive
        //      (indicating a wait on a kernel handle, e.g. console input).
        //   3. CloseHandle(hProcess)
        //
        // This requires the undocumented NtQuerySystemInformation API from
        // ntdll.dll. Until that is in place, the keyword-matching and
        // periodic output-staleness checks in bash.rs handle Windows detection.
        let _ = pid;
        None
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    {
        let _ = pid;
        None
    }
}
/// Get the user's home directory.
///
/// Supports environment variable overrides for testing:

View File

@ -180,16 +180,50 @@ impl BashTool {
fn should_return_pending(&self, interactive: bool, output: &str) -> bool {
let normalized = output.to_lowercase();
let has_auth_phrase = [
// 中文 — 原有
"等待用户授权",
"等待授权",
"等待你授权",
"在浏览器中打开以下链接进行认证",
// 中文 — 新增（lark-cli 等工具的常见提示）
"请在浏览器中",
"请打开以下链接",
"打开以下链接",
"打开链接",
"访问以下",
"访问此链接",
"复制链接",
"输入验证码",
"输入授权码",
"完成认证",
"完成授权",
"请登录",
"正在等待",
"等待用户",
"手动授权",
// 英文 — 原有
"open the following link",
"waiting for authorization",
"waiting for user authorization",
"waiting for approval",
"device/verify",
"user_code=",
// 英文 — 新增
"visit the following url",
"visit this url",
"open the following url",
"browser to authenticate",
"browser to complete",
"enter the code",
"enter code",
"verification code",
"authorization code",
"one-time code",
"device code",
"oauth",
"go to the following",
"navigate to the following",
"paste the code",
]
.iter()
.any(|pattern| normalized.contains(pattern));
@ -372,11 +406,49 @@ impl BashTool {
return Ok(self.pending_output(&combined));
}
}
_ = tokio::time::sleep(Duration::from_secs(2)) => {
// Periodic safety net: when output has been silent for 2s,
// check OS-level process state to see if the child is
// genuinely blocked on stdin. Also re-run keyword detection
// in case read_stream flushed a partial line since the last
// rx.recv() iteration.
if let Some(pid) = child.id() {
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if !combined.trim().is_empty() {
let _ = child.start_kill();
let _ = child.wait().await;
return Ok(self.pending_output(&combined));
}
}
}
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if self.should_return_pending(interactive, &combined) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
let _ = child.start_kill();
let _ = child.wait().await;
return Ok(self.pending_output(&combined));
}
}
_ = sleep_until(deadline) => {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
let _ = child.start_kill();
let _ = child.wait().await;
// OS-level process state check: if the child was blocked on
// stdin, treat it as pending rather than a hard timeout error.
if let Some(pid) = child.id() {
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true)
&& !combined.trim().is_empty()
{
return Ok(self.pending_output(&combined));
}
}
if self.should_return_pending(interactive, &combined) {
return Ok(self.pending_output(&combined));
}
@ -387,6 +459,11 @@ impl BashTool {
}
}
/// Idle window, in milliseconds, used as the read timeout in `read_stream`.
///
/// If no new data arrives within this window, any buffered partial line is
/// sent immediately. This ensures that prompts and URLs printed without a
/// trailing newline are still visible to the detection logic.
const STREAM_FLUSH_MS: u64 = 500;
async fn read_stream<R>(stream: R, is_stderr: bool, tx: mpsc::UnboundedSender<(bool, String)>)
where
R: AsyncRead + Unpin + Send + 'static,
@ -396,12 +473,15 @@ where
loop {
let mut chunk = [0u8; 4096];
match reader.read(&mut chunk).await {
tokio::select! {
result = reader.read(&mut chunk) => {
match result {
Ok(0) => break,
Ok(n) => {
buffer.extend_from_slice(&chunk[..n]);
// 处理完整的行
// 发送完整行(逻辑不变)
while let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
let line_bytes = &buffer[..pos + 1];
let line = decode_bytes(line_bytes);
@ -412,6 +492,18 @@ where
Err(_) => break,
}
}
_ = tokio::time::sleep(Duration::from_millis(STREAM_FLUSH_MS)) => {
// 超时未收到新数据，flush 不完整的行
// 这确保像 lark-cli auth login 打印的 URL（不带换行符）
// 也能被 run_command 的检测逻辑看到
if !buffer.is_empty() {
let remainder = decode_bytes(&buffer);
let _ = tx.send((is_stderr, remainder));
buffer.clear();
}
}
}
}
// 处理剩余的字节
if !buffer.is_empty() {
@ -588,7 +680,7 @@ mod tests {
let result = tool
.execute(json!({
"command": command,
"timeout": 1,
"timeout": 5,
"interactive": true
}))
.await