PicoBot/src/tools/bash.rs

728 lines
23 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde_json::json;
use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::{Mutex, mpsc};
use tokio::time::{Instant, sleep_until};
use crate::platform::{ShellInfo, dangerous_command_patterns};
use crate::tools::traits::{Tool, ToolResult};
use crate::tools::{extract_u64, extract_bool, check_null_args};
/// Hard upper bound (in seconds) applied to any requested or default command timeout.
const MAX_TIMEOUT_SECS: u64 = 600;
/// Output longer than this many characters is cut down to its head and tail halves.
const MAX_OUTPUT_CHARS: usize = 50_000;
/// Marker placed on the first line of the output when a command appears to be
/// waiting for the user to finish an external action (e.g. browser authentication).
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
/// User-facing hint emitted right after the pending marker. In English: "This command
/// is waiting for you to complete an external action. When done, tell me to continue,
/// or re-run the follow-up check command."
const USER_ACTION_HINT: &str =
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
/// Shell flavour used to execute commands, with cross-platform support.
///
/// A compatibility wrapper around [`ShellInfo`] that offers a more convenient API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShellKind {
    Bash,
    PowerShell,
    Cmd,
}

impl ShellKind {
    /// Detect the platform's default shell from `ShellInfo::default()`.
    ///
    /// Uses the same mapping as [`ShellKind::from_info`], so an unrecognised
    /// executable falls back to [`ShellKind::Bash`].
    pub fn detect() -> Self {
        Self::from_info(&ShellInfo::default())
    }

    /// Map a [`ShellInfo`] to a `ShellKind` by its executable name
    /// (`"bash"`, `"powershell"`, `"cmd"`); anything else maps to `Bash`.
    pub fn from_info(info: &ShellInfo) -> Self {
        match info.executable {
            "bash" => ShellKind::Bash,
            "powershell" => ShellKind::PowerShell,
            "cmd" => ShellKind::Cmd,
            _ => ShellKind::Bash, // fallback
        }
    }

    /// The [`ShellInfo`] (tool name, executable, leading arguments) for this shell.
    ///
    /// PowerShell and cmd are both exposed under the tool name `"shell"`.
    pub fn to_info(&self) -> ShellInfo {
        match self {
            ShellKind::Bash => ShellInfo {
                name: "bash",
                executable: "bash",
                args: &["-c"],
            },
            ShellKind::PowerShell => ShellInfo {
                name: "shell",
                executable: "powershell",
                args: &["-Command"],
            },
            ShellKind::Cmd => ShellInfo {
                name: "shell",
                executable: "cmd",
                args: &["/C"],
            },
        }
    }

    /// Executable name of the shell.
    pub fn executable(&self) -> &'static str {
        self.to_info().executable
    }

    /// Full argument list used to run `command`: the shell's leading arguments
    /// followed by the command itself (e.g. `["-c", command]` for bash).
    pub fn command_args<'a>(&self, command: &'a str) -> Vec<&'a str> {
        let info = self.to_info();
        info.args
            .iter()
            .copied()
            .chain(std::iter::once(command))
            .collect()
    }

    /// Name under which the tool is registered.
    pub fn tool_name(&self) -> &'static str {
        self.to_info().name
    }

    /// Tool description shown to the model.
    pub fn tool_description(&self) -> &'static str {
        match self {
            ShellKind::Bash => "Execute a bash shell command and return its output. Use with caution.",
            ShellKind::PowerShell => "Execute a PowerShell command and return its output. Use with caution.",
            ShellKind::Cmd => "Execute a cmd shell command and return its output. Use with caution.",
        }
    }
}
/// Tool that runs one command per call through the configured shell.
pub struct BashTool {
    /// Shell flavour used to spawn commands (bash, PowerShell or cmd).
    shell: ShellKind,
    /// Default timeout in seconds, used when a call does not pass `timeout`.
    timeout_secs: u64,
    /// Directory commands run in; `None` means the current directory.
    working_dir: Option<String>,
    /// Regex patterns, matched against the lower-cased command, that block execution.
    deny_patterns: Vec<String>,
}
impl BashTool {
pub fn new() -> Self {
Self {
timeout_secs: 60,
working_dir: None,
deny_patterns: dangerous_command_patterns(),
shell: ShellKind::detect(),
}
}
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
self.timeout_secs = timeout_secs;
self
}
pub fn with_working_dir(mut self, dir: String) -> Self {
self.working_dir = Some(dir);
self
}
pub fn with_shell(mut self, shell: ShellKind) -> Self {
self.shell = shell;
self
}
fn guard_command(&self, command: &str) -> Option<String> {
let lower = command.to_lowercase();
for pattern in &self.deny_patterns {
if regex::Regex::new(pattern)
.ok()
.map(|re| re.is_match(&lower))
.unwrap_or(false)
{
return Some(format!(
"Command blocked by safety guard (dangerous pattern: {})",
pattern
));
}
}
None
}
fn truncate_output(&self, output: &str) -> String {
let char_count = output.chars().count();
if char_count <= MAX_OUTPUT_CHARS {
return output.to_string();
}
let half = MAX_OUTPUT_CHARS / 2;
let head: String = output.chars().take(half).collect();
let tail: String = output
.chars()
.skip(char_count.saturating_sub(half))
.collect();
format!(
"{}...\n\n(... {} chars truncated ...)\n\n{}",
head,
char_count - MAX_OUTPUT_CHARS,
tail
)
}
fn pending_output(&self, output: &str) -> String {
format!(
"{}\n{}\n\n{}",
PENDING_USER_ACTION_MARKER,
USER_ACTION_HINT,
self.truncate_output(output.trim())
)
}
fn should_return_pending(&self, interactive: bool, output: &str) -> bool {
let normalized = output.to_lowercase();
let has_auth_phrase = [
// 中文 — 原有
"等待用户授权",
"等待授权",
"等待你授权",
"在浏览器中打开以下链接进行认证",
// 中文 — 新增lark-cli 等工具的常见提示)
"请在浏览器中",
"请打开以下链接",
"打开以下链接",
"打开链接",
"访问以下",
"访问此链接",
"复制链接",
"输入验证码",
"输入授权码",
"完成认证",
"完成授权",
"请登录",
"正在等待",
"等待用户",
"手动授权",
// 英文 — 原有
"open the following link",
"waiting for authorization",
"waiting for user authorization",
"waiting for approval",
"device/verify",
"user_code=",
// 英文 — 新增
"visit the following url",
"visit this url",
"open the following url",
"browser to authenticate",
"browser to complete",
"enter the code",
"enter code",
"verification code",
"authorization code",
"one-time code",
"device code",
"oauth",
"go to the following",
"navigate to the following",
"paste the code",
]
.iter()
.any(|pattern| normalized.contains(pattern));
has_auth_phrase || (interactive && !output.trim().is_empty())
}
}
/// Move every chunk already queued in `rx` into the matching buffer without waiting.
///
/// Each message is `(is_stderr, text)`; stderr text goes to `stderr_buf`, everything
/// else to `stdout_buf`. Returns as soon as the channel has nothing ready (or is closed).
async fn drain_available_chunks(
    rx: &mut mpsc::UnboundedReceiver<(bool, String)>,
    stdout_buf: &Arc<Mutex<String>>,
    stderr_buf: &Arc<Mutex<String>>,
) {
    while let Ok((to_stderr, text)) = rx.try_recv() {
        let sink = if to_stderr { stderr_buf } else { stdout_buf };
        sink.lock().await.push_str(&text);
    }
}
impl Default for BashTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for BashTool {
    /// Registered tool name: `"bash"` for bash, `"shell"` for PowerShell/cmd.
    fn name(&self) -> &str {
        self.shell.tool_name()
    }

    /// Shell-specific description shown to the model.
    fn description(&self) -> &str {
        self.shell.tool_description()
    }

    /// JSON schema: required `command`, optional `timeout` (seconds) and `interactive`.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "The shell command to execute"
                },
                "timeout": {
                    "type": "integer",
                    "description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
                    "minimum": 1,
                    "maximum": MAX_TIMEOUT_SECS
                },
                "interactive": {
                    "type": "boolean",
                    "description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication"
                }
            },
            "required": ["command"]
        })
    }

    fn exclusive(&self) -> bool {
        true // Shell commands should not run concurrently
    }

    /// Validate the arguments, apply the safety guard, and run the command.
    ///
    /// Argument and guard failures, as well as spawn/timeout errors, are reported
    /// as an unsuccessful `ToolResult` rather than an `Err`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Report argument errors under the name this tool is actually registered
        // with ("bash" or "shell") rather than a hard-coded "bash".
        if let Some(result) = check_null_args(&args, self.shell.tool_name()) {
            return Ok(result);
        }

        let command = match args.get("command").and_then(|v| v.as_str()) {
            Some(c) => c,
            None => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("Missing required parameter: command".to_string()),
                });
            }
        };

        // Safety check
        if let Some(error) = self.guard_command(command) {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(error),
            });
        }

        // A timeout of 0 would put the deadline at "now" and report a spurious
        // timeout, so clamp into the schema's documented [1, MAX_TIMEOUT_SECS] range.
        let timeout_secs = extract_u64(&args, "timeout")
            .unwrap_or(self.timeout_secs)
            .clamp(1, MAX_TIMEOUT_SECS);
        let interactive = extract_bool(&args, "interactive").unwrap_or(false);
        let cwd = Path::new(self.working_dir.as_deref().unwrap_or("."));

        let result = match self
            .run_command(command, cwd, timeout_secs, interactive)
            .await
        {
            Ok(output) => ToolResult {
                success: true,
                output,
                error: None,
            },
            Err(e) => ToolResult {
                success: false,
                output: String::new(),
                error: Some(e),
            },
        };
        Ok(result)
    }
}
impl BashTool {
async fn run_command(
&self,
command: &str,
cwd: &Path,
timeout_secs: u64,
interactive: bool,
) -> Result<String, String> {
let mut cmd = Command::new(self.shell.executable());
cmd.args(self.shell.command_args(command))
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(cwd);
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
let stdout = child.stdout.take();
let stderr = child.stderr.take();
let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>();
if let Some(stdout) = stdout {
tokio::spawn(read_stream(stdout, false, tx.clone()));
}
if let Some(stderr) = stderr {
tokio::spawn(read_stream(stderr, true, tx.clone()));
}
drop(tx);
let stdout_buf = Arc::new(Mutex::new(String::new()));
let stderr_buf = Arc::new(Mutex::new(String::new()));
let deadline = Instant::now() + Duration::from_secs(timeout_secs);
loop {
tokio::select! {
status = child.wait() => {
let status = status.map_err(|e| format!("Failed to wait: {}", e))?;
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
while let Some((is_stderr, chunk)) = rx.recv().await {
if is_stderr {
stderr_buf.lock().await.push_str(&chunk);
} else {
stdout_buf.lock().await.push_str(&chunk);
}
}
let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1)));
return Ok(self.truncate_output(&output));
}
Some((is_stderr, chunk)) = rx.recv() => {
if is_stderr {
stderr_buf.lock().await.push_str(&chunk);
} else {
stdout_buf.lock().await.push_str(&chunk);
}
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if self.should_return_pending(interactive, &combined) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let _ = child.start_kill();
let _ = child.wait().await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
return Ok(self.pending_output(&combined));
}
}
_ = tokio::time::sleep(Duration::from_secs(2)) => {
// Periodic safety net: when output has been silent for 2s,
// check OS-level process state to see if the child is
// genuinely blocked on stdin. Also re-run keyword detection
// in case read_stream flushed a partial line since the last
// rx.recv() iteration.
if let Some(pid) = child.id() {
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if !combined.trim().is_empty() {
let _ = child.start_kill();
let _ = child.wait().await;
return Ok(self.pending_output(&combined));
}
}
}
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if self.should_return_pending(interactive, &combined) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
let _ = child.start_kill();
let _ = child.wait().await;
return Ok(self.pending_output(&combined));
}
}
_ = sleep_until(deadline) => {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
let _ = child.start_kill();
let _ = child.wait().await;
// OS-level process state check: if the child was blocked on
// stdin, treat it as pending rather than a hard timeout error.
if let Some(pid) = child.id() {
if crate::platform::is_process_waiting_on_stdin(pid) == Some(true)
&& !combined.trim().is_empty()
{
return Ok(self.pending_output(&combined));
}
}
if self.should_return_pending(interactive, &combined) {
return Ok(self.pending_output(&combined));
}
return Err(format!("Command timed out after {} seconds", timeout_secs));
}
}
}
}
}
/// Flush delay: if no new data arrives within this window, send any buffered
/// partial line immediately. This ensures that prompts and URLs printed
/// without a trailing newline are still visible to the detection logic.
const STREAM_FLUSH_MS: u64 = 500;

/// Forward everything read from `stream` to `tx`, tagged with `is_stderr`.
///
/// Complete lines (including their trailing `\n`) are decoded and sent one message
/// per line. A trailing partial line is sent after the stream has been idle for
/// `STREAM_FLUSH_MS`, and whatever remains is sent at EOF or on a read error.
/// Send failures (receiver dropped) are ignored.
async fn read_stream<R>(stream: R, is_stderr: bool, tx: mpsc::UnboundedSender<(bool, String)>)
where
    R: AsyncRead + Unpin + Send + 'static,
{
    let mut reader = BufReader::new(stream);
    let mut buffer: Vec<u8> = Vec::new();
    let mut chunk = [0u8; 4096];
    loop {
        tokio::select! {
            result = reader.read(&mut chunk) => {
                match result {
                    Ok(0) | Err(_) => break,
                    Ok(n) => {
                        let scan_from = buffer.len();
                        buffer.extend_from_slice(&chunk[..n]);
                        // Only the freshly read bytes can contain a newline we have not
                        // seen; rescanning the whole buffer is quadratic on long lines.
                        if let Some(rel) = buffer[scan_from..].iter().rposition(|&b| b == b'\n') {
                            let rest = buffer.split_off(scan_from + rel + 1);
                            let complete = std::mem::replace(&mut buffer, rest);
                            // Send complete lines one by one, each keeping its '\n'.
                            for line in complete.split_inclusive(|&b| b == b'\n') {
                                let _ = tx.send((is_stderr, decode_bytes(line)));
                            }
                        }
                    }
                }
            }
            _ = tokio::time::sleep(Duration::from_millis(STREAM_FLUSH_MS)) => {
                // No new data within the window: flush the incomplete line so that
                // output such as the URL printed by `lark-cli auth login` (no trailing
                // newline) is still visible to run_command's detection logic.
                let flushable = utf8_flush_len(&buffer);
                if flushable > 0 {
                    let rest = buffer.split_off(flushable);
                    let remainder = std::mem::replace(&mut buffer, rest);
                    let _ = tx.send((is_stderr, decode_bytes(&remainder)));
                }
            }
        }
    }
    // Send whatever bytes remain at EOF / error.
    if !buffer.is_empty() {
        let _ = tx.send((is_stderr, decode_bytes(&buffer)));
    }
}

/// Number of leading bytes of `buf` that can be flushed without cutting a UTF-8
/// multi-byte sequence in half.
///
/// Flushing half a character would make the UTF-8 decode fail and the whole
/// partial line would then be decoded as GBK (mojibake).
fn utf8_flush_len(buf: &[u8]) -> usize {
    match std::str::from_utf8(buf) {
        Ok(_) => buf.len(),
        // `error_len() == None`: the input merely ends inside a (so far valid)
        // sequence; hold those bytes back until the rest arrives.
        Err(e) if e.error_len().is_none() => e.valid_up_to(),
        // Genuinely non-UTF-8 data: flush it all and let `decode_bytes` decide.
        Err(_) => buf.len(),
    }
}
/// Decode raw process output into a `String`.
///
/// Tries strict UTF-8 first, then GBK; if GBK decoding also reports errors, falls
/// back to lossy UTF-8 (invalid sequences become U+FFFD).
fn decode_bytes(bytes: &[u8]) -> String {
    match std::str::from_utf8(bytes) {
        Ok(text) => text.to_owned(),
        Err(_) => {
            // NOTE: `Encoding::decode` also performs BOM sniffing.
            let (decoded, _, had_errors) = encoding_rs::GBK.decode(bytes);
            if had_errors {
                String::from_utf8_lossy(bytes).into_owned()
            } else {
                decoded.into_owned()
            }
        }
    }
}
/// Render captured output as text for the model.
///
/// Layout: stdout as-is, then (only if stderr has non-whitespace content) an
/// `STDERR:` section separated from stdout by a newline, then `Exit code: N` on
/// its own line when `exit_code` is given.
fn format_command_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
    let mut rendered = String::from(stdout);

    // Whitespace-only stderr is treated as absent.
    let stderr_section = if stderr.trim().is_empty() {
        None
    } else {
        Some(stderr)
    };
    if let Some(err_text) = stderr_section {
        if !rendered.is_empty() {
            rendered.push('\n');
        }
        rendered += "STDERR:\n";
        rendered += err_text;
    }

    match exit_code {
        Some(code) => rendered + &format!("\nExit code: {}", code),
        None => rendered,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_simple_command() {
        let tool = BashTool::new();
        let command = if cfg!(target_os = "windows") {
            "Write-Output 'Hello World'"
        } else {
            "echo 'Hello World'"
        };
        let result = tool
            .execute(json!({ "command": command }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Hello World"));
    }

    #[tokio::test]
    async fn test_pwd_command() {
        let tool = BashTool::new();
        let command = if cfg!(target_os = "windows") {
            "Get-Location"
        } else {
            "pwd"
        };
        let result = tool.execute(json!({ "command": command })).await.unwrap();
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_ls_command() {
        let tool = BashTool::new();
        let temp_dir = std::env::temp_dir();
        let command = if cfg!(target_os = "windows") {
            format!("Get-ChildItem {}", temp_dir.display())
        } else {
            format!("ls -la {}", temp_dir.display())
        };
        let result = tool
            .execute(json!({ "command": command }))
            .await
            .unwrap();
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_dangerous_rm() {
        let tool = BashTool::new();
        // Unix dangerous-command pattern
        let result = tool
            .execute(json!({ "command": "rm -rf /some/path" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("blocked"));
    }

    #[tokio::test]
    async fn test_dangerous_windows_commands() {
        let tool = BashTool::new();
        // Windows `del` pattern (the regex should match)
        let result = tool
            .execute(json!({ "command": "del /f /q file.txt" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("blocked"));
    }

    #[tokio::test]
    async fn test_dangerous_fork_bomb() {
        let tool = BashTool::new();
        let result = tool
            .execute(json!({ "command": ":(){ :|:& };:" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("blocked"));
    }

    #[tokio::test]
    async fn test_missing_command() {
        let tool = BashTool::new();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("command"));
    }

    #[tokio::test]
    async fn test_timeout() {
        let tool = BashTool::new();
        let command = if cfg!(target_os = "windows") {
            "Start-Sleep -Seconds 10"
        } else {
            "sleep 10"
        };
        let result = tool
            .execute(json!({
                "command": command,
                "timeout": 1
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("timed out"));
    }

    #[tokio::test]
    async fn test_pending_user_action_detection() {
        let tool = BashTool::new();
        let command = if cfg!(target_os = "windows") {
            "Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
        } else {
            "printf 'waiting for authorization'; sleep 10"
        };
        let result = tool
            .execute(json!({
                "command": command,
                "timeout": 5,
                "interactive": true
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains(PENDING_USER_ACTION_MARKER));
    }

    #[test]
    fn test_truncate_output_handles_utf8_char_boundaries() {
        let tool = BashTool::new();
        // U+4E2D is 3 bytes in UTF-8, so a byte-based cut would split characters.
        // Written as an escape so the test input cannot silently become empty.
        let input = "\u{4e2d}".repeat(MAX_OUTPUT_CHARS + 100);
        let output = tool.truncate_output(&input);
        assert!(output.contains("(... 100 chars truncated ...)"));
        // Half of the budget is kept from the head and half from the tail, intact.
        assert_eq!(output.matches('\u{4e2d}').count(), MAX_OUTPUT_CHARS);
    }

    #[test]
    fn test_format_command_output_sections() {
        assert_eq!(format_command_output("out", "", Some(0)), "out\nExit code: 0");
        assert_eq!(format_command_output("out", "  \n", None), "out");
        assert_eq!(format_command_output("", "err", None), "STDERR:\nerr");
        assert_eq!(
            format_command_output("out", "err", Some(1)),
            "out\nSTDERR:\nerr\nExit code: 1"
        );
    }

    #[test]
    fn test_shell_kind_detect() {
        let shell = ShellKind::detect();
        if cfg!(target_os = "windows") {
            assert_eq!(shell, ShellKind::PowerShell);
        } else {
            assert_eq!(shell, ShellKind::Bash);
        }
    }

    #[test]
    fn test_shell_kind_from_info_round_trip() {
        for &kind in &[ShellKind::Bash, ShellKind::PowerShell, ShellKind::Cmd] {
            assert_eq!(ShellKind::from_info(&kind.to_info()), kind);
        }
    }

    #[test]
    fn test_shell_kind_executable() {
        assert_eq!(ShellKind::Bash.executable(), "bash");
        assert_eq!(ShellKind::PowerShell.executable(), "powershell");
        assert_eq!(ShellKind::Cmd.executable(), "cmd");
    }

    #[test]
    fn test_shell_kind_command_args() {
        assert_eq!(ShellKind::Bash.command_args("echo hello"), vec!["-c" as &str, "echo hello"]);
        assert_eq!(ShellKind::PowerShell.command_args("echo hello"), vec!["-Command" as &str, "echo hello"]);
        assert_eq!(ShellKind::Cmd.command_args("echo hello"), vec!["/C" as &str, "echo hello"]);
    }
}