636 lines
19 KiB
Rust
636 lines
19 KiB
Rust
use std::path::Path;
|
|
use std::process::Stdio;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use async_trait::async_trait;
|
|
use serde_json::json;
|
|
use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
|
|
use tokio::process::Command;
|
|
use tokio::sync::{Mutex, mpsc};
|
|
use tokio::time::{Instant, sleep_until};
|
|
|
|
use crate::platform::{ShellInfo, dangerous_command_patterns};
|
|
use crate::tools::traits::{Tool, ToolResult};
|
|
use crate::tools::{extract_u64, extract_bool, check_null_args};
|
|
|
|
/// Upper bound, in seconds, applied to every command timeout (requested or configured).
const MAX_TIMEOUT_SECS: u64 = 600;
/// Maximum number of characters of command output returned before the middle is cut out.
const MAX_OUTPUT_CHARS: usize = 50_000;
/// Marker placed on the first line of the output when a command is reported as
/// waiting for a user action instead of having finished.
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
/// User-facing hint (in Chinese) emitted right after the pending marker.
const USER_ACTION_HINT: &str =
    "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
|
|
|
/// Kind of shell used to run commands, for cross-platform support.
///
/// This is a compatibility wrapper around `ShellInfo` that offers a more
/// convenient API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShellKind {
    /// `bash`, invoked with `-c`.
    Bash,
    /// `powershell`, invoked with `-Command`.
    PowerShell,
    /// `cmd`, invoked with `/C`.
    Cmd,
}
|
|
|
|
impl ShellKind {
|
|
/// 根据平台检测默认 shell
|
|
pub fn detect() -> Self {
|
|
let info = ShellInfo::default();
|
|
match info.executable {
|
|
"bash" => ShellKind::Bash,
|
|
"powershell" => ShellKind::PowerShell,
|
|
"cmd" => ShellKind::Cmd,
|
|
_ => ShellKind::Bash, // fallback
|
|
}
|
|
}
|
|
|
|
/// 从 ShellInfo 获取 ShellKind
|
|
pub fn from_info(info: &ShellInfo) -> Self {
|
|
match info.executable {
|
|
"bash" => ShellKind::Bash,
|
|
"powershell" => ShellKind::PowerShell,
|
|
"cmd" => ShellKind::Cmd,
|
|
_ => ShellKind::Bash,
|
|
}
|
|
}
|
|
|
|
/// 获取对应的 ShellInfo
|
|
pub fn to_info(&self) -> ShellInfo {
|
|
match self {
|
|
ShellKind::Bash => ShellInfo {
|
|
name: "bash",
|
|
executable: "bash",
|
|
args: &["-c"],
|
|
},
|
|
ShellKind::PowerShell => ShellInfo {
|
|
name: "shell",
|
|
executable: "powershell",
|
|
args: &["-Command"],
|
|
},
|
|
ShellKind::Cmd => ShellInfo {
|
|
name: "shell",
|
|
executable: "cmd",
|
|
args: &["/C"],
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Shell 可执行文件名
|
|
pub fn executable(&self) -> &'static str {
|
|
self.to_info().executable
|
|
}
|
|
|
|
/// 执行命令所需的参数
|
|
pub fn command_args<'a>(&self, command: &'a str) -> Vec<&'a str> {
|
|
let info = self.to_info();
|
|
info.args.iter().map(|s| *s).chain(std::iter::once(command)).collect()
|
|
}
|
|
|
|
/// 工具名称
|
|
pub fn tool_name(&self) -> &'static str {
|
|
self.to_info().name
|
|
}
|
|
|
|
/// 工具描述
|
|
pub fn tool_description(&self) -> &'static str {
|
|
match self {
|
|
ShellKind::Bash => "Execute a bash shell command and return its output. Use with caution.",
|
|
ShellKind::PowerShell => "Execute a PowerShell command and return its output. Use with caution.",
|
|
ShellKind::Cmd => "Execute a cmd shell command and return its output. Use with caution.",
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Tool that runs a shell command and returns its captured output.
pub struct BashTool {
    /// Timeout in seconds used when a call does not supply its own `timeout`.
    timeout_secs: u64,
    /// Directory commands are run in; when `None`, "." is used.
    working_dir: Option<String>,
    /// Regex patterns; a command matching any of them is refused before it runs.
    deny_patterns: Vec<String>,
    /// Shell used to execute commands.
    shell: ShellKind,
}
|
|
|
|
impl BashTool {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
timeout_secs: 60,
|
|
working_dir: None,
|
|
deny_patterns: dangerous_command_patterns(),
|
|
shell: ShellKind::detect(),
|
|
}
|
|
}
|
|
|
|
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
|
|
self.timeout_secs = timeout_secs;
|
|
self
|
|
}
|
|
|
|
pub fn with_working_dir(mut self, dir: String) -> Self {
|
|
self.working_dir = Some(dir);
|
|
self
|
|
}
|
|
|
|
pub fn with_shell(mut self, shell: ShellKind) -> Self {
|
|
self.shell = shell;
|
|
self
|
|
}
|
|
|
|
fn guard_command(&self, command: &str) -> Option<String> {
|
|
let lower = command.to_lowercase();
|
|
for pattern in &self.deny_patterns {
|
|
if regex::Regex::new(pattern)
|
|
.ok()
|
|
.map(|re| re.is_match(&lower))
|
|
.unwrap_or(false)
|
|
{
|
|
return Some(format!(
|
|
"Command blocked by safety guard (dangerous pattern: {})",
|
|
pattern
|
|
));
|
|
}
|
|
}
|
|
None
|
|
}
|
|
|
|
fn truncate_output(&self, output: &str) -> String {
|
|
let char_count = output.chars().count();
|
|
if char_count <= MAX_OUTPUT_CHARS {
|
|
return output.to_string();
|
|
}
|
|
|
|
let half = MAX_OUTPUT_CHARS / 2;
|
|
let head: String = output.chars().take(half).collect();
|
|
let tail: String = output
|
|
.chars()
|
|
.skip(char_count.saturating_sub(half))
|
|
.collect();
|
|
format!(
|
|
"{}...\n\n(... {} chars truncated ...)\n\n{}",
|
|
head,
|
|
char_count - MAX_OUTPUT_CHARS,
|
|
tail
|
|
)
|
|
}
|
|
|
|
fn pending_output(&self, output: &str) -> String {
|
|
format!(
|
|
"{}\n{}\n\n{}",
|
|
PENDING_USER_ACTION_MARKER,
|
|
USER_ACTION_HINT,
|
|
self.truncate_output(output.trim())
|
|
)
|
|
}
|
|
|
|
fn should_return_pending(&self, interactive: bool, output: &str) -> bool {
|
|
let normalized = output.to_lowercase();
|
|
let has_auth_phrase = [
|
|
"等待用户授权",
|
|
"等待授权",
|
|
"等待你授权",
|
|
"在浏览器中打开以下链接进行认证",
|
|
"open the following link",
|
|
"waiting for authorization",
|
|
"waiting for user authorization",
|
|
"waiting for approval",
|
|
"device/verify",
|
|
"user_code=",
|
|
]
|
|
.iter()
|
|
.any(|pattern| normalized.contains(pattern));
|
|
|
|
has_auth_phrase || (interactive && !output.trim().is_empty())
|
|
}
|
|
}
|
|
|
|
async fn drain_available_chunks(
|
|
rx: &mut mpsc::UnboundedReceiver<(bool, String)>,
|
|
stdout_buf: &Arc<Mutex<String>>,
|
|
stderr_buf: &Arc<Mutex<String>>,
|
|
) {
|
|
while let Ok((is_stderr, chunk)) = rx.try_recv() {
|
|
if is_stderr {
|
|
stderr_buf.lock().await.push_str(&chunk);
|
|
} else {
|
|
stdout_buf.lock().await.push_str(&chunk);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// `BashTool::default()` is the same as `BashTool::new()`.
impl Default for BashTool {
    fn default() -> Self {
        Self::new()
    }
}
|
|
|
|
#[async_trait]
|
|
impl Tool for BashTool {
|
|
fn name(&self) -> &str {
|
|
self.shell.tool_name()
|
|
}
|
|
|
|
fn description(&self) -> &str {
|
|
self.shell.tool_description()
|
|
}
|
|
|
|
fn parameters_schema(&self) -> serde_json::Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"command": {
|
|
"type": "string",
|
|
"description": "The shell command to execute"
|
|
},
|
|
"timeout": {
|
|
"type": "integer",
|
|
"description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
|
|
"minimum": 1,
|
|
"maximum": MAX_TIMEOUT_SECS
|
|
},
|
|
"interactive": {
|
|
"type": "boolean",
|
|
"description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication"
|
|
}
|
|
},
|
|
"required": ["command"]
|
|
})
|
|
}
|
|
|
|
fn exclusive(&self) -> bool {
|
|
true // Shell commands should not run concurrently
|
|
}
|
|
|
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
if let Some(result) = check_null_args(&args, "bash") {
|
|
return Ok(result);
|
|
}
|
|
|
|
let command = match args.get("command").and_then(|v| v.as_str()) {
|
|
Some(c) => c,
|
|
None => {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some("Missing required parameter: command".to_string()),
|
|
});
|
|
}
|
|
};
|
|
|
|
// Safety check
|
|
if let Some(error) = self.guard_command(command) {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(error),
|
|
});
|
|
}
|
|
|
|
let timeout_secs = extract_u64(&args, "timeout")
|
|
.unwrap_or(self.timeout_secs)
|
|
.min(MAX_TIMEOUT_SECS);
|
|
let interactive = extract_bool(&args, "interactive").unwrap_or(false);
|
|
|
|
let cwd = self
|
|
.working_dir
|
|
.as_ref()
|
|
.map(|d| Path::new(d))
|
|
.unwrap_or_else(|| Path::new("."));
|
|
|
|
match self
|
|
.run_command(command, cwd, timeout_secs, interactive)
|
|
.await
|
|
{
|
|
Ok(output) => Ok(ToolResult {
|
|
success: true,
|
|
output,
|
|
error: None,
|
|
}),
|
|
Err(e) => Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(e),
|
|
}),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl BashTool {
|
|
async fn run_command(
|
|
&self,
|
|
command: &str,
|
|
cwd: &Path,
|
|
timeout_secs: u64,
|
|
interactive: bool,
|
|
) -> Result<String, String> {
|
|
let mut cmd = Command::new(self.shell.executable());
|
|
cmd.args(self.shell.command_args(command))
|
|
.stdout(Stdio::piped())
|
|
.stderr(Stdio::piped())
|
|
.current_dir(cwd);
|
|
|
|
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
|
|
|
let stdout = child.stdout.take();
|
|
let stderr = child.stderr.take();
|
|
let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>();
|
|
|
|
if let Some(stdout) = stdout {
|
|
tokio::spawn(read_stream(stdout, false, tx.clone()));
|
|
}
|
|
if let Some(stderr) = stderr {
|
|
tokio::spawn(read_stream(stderr, true, tx.clone()));
|
|
}
|
|
drop(tx);
|
|
|
|
let stdout_buf = Arc::new(Mutex::new(String::new()));
|
|
let stderr_buf = Arc::new(Mutex::new(String::new()));
|
|
let deadline = Instant::now() + Duration::from_secs(timeout_secs);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
status = child.wait() => {
|
|
let status = status.map_err(|e| format!("Failed to wait: {}", e))?;
|
|
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
|
while let Some((is_stderr, chunk)) = rx.recv().await {
|
|
if is_stderr {
|
|
stderr_buf.lock().await.push_str(&chunk);
|
|
} else {
|
|
stdout_buf.lock().await.push_str(&chunk);
|
|
}
|
|
}
|
|
let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1)));
|
|
return Ok(self.truncate_output(&output));
|
|
}
|
|
Some((is_stderr, chunk)) = rx.recv() => {
|
|
if is_stderr {
|
|
stderr_buf.lock().await.push_str(&chunk);
|
|
} else {
|
|
stdout_buf.lock().await.push_str(&chunk);
|
|
}
|
|
|
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
|
if self.should_return_pending(interactive, &combined) {
|
|
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
|
let _ = child.start_kill();
|
|
let _ = child.wait().await;
|
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
|
return Ok(self.pending_output(&combined));
|
|
}
|
|
}
|
|
_ = sleep_until(deadline) => {
|
|
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
|
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
|
let _ = child.start_kill();
|
|
let _ = child.wait().await;
|
|
if self.should_return_pending(interactive, &combined) {
|
|
return Ok(self.pending_output(&combined));
|
|
}
|
|
return Err(format!("Command timed out after {} seconds", timeout_secs));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn read_stream<R>(stream: R, is_stderr: bool, tx: mpsc::UnboundedSender<(bool, String)>)
|
|
where
|
|
R: AsyncRead + Unpin + Send + 'static,
|
|
{
|
|
let mut reader = BufReader::new(stream);
|
|
let mut buffer = Vec::new();
|
|
|
|
loop {
|
|
let mut chunk = [0u8; 4096];
|
|
match reader.read(&mut chunk).await {
|
|
Ok(0) => break,
|
|
Ok(n) => {
|
|
buffer.extend_from_slice(&chunk[..n]);
|
|
|
|
// 处理完整的行
|
|
while let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
|
|
let line_bytes = &buffer[..pos + 1];
|
|
let line = decode_bytes(line_bytes);
|
|
let _ = tx.send((is_stderr, line));
|
|
buffer.drain(..pos + 1);
|
|
}
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
|
|
// 处理剩余的字节
|
|
if !buffer.is_empty() {
|
|
let remainder = decode_bytes(&buffer);
|
|
let _ = tx.send((is_stderr, remainder));
|
|
}
|
|
}
|
|
|
|
/// 尝试 UTF-8 解码,失败则尝试 GBK 解码
|
|
fn decode_bytes(bytes: &[u8]) -> String {
|
|
// 首先尝试 UTF-8
|
|
if let Ok(s) = std::str::from_utf8(bytes) {
|
|
return s.to_string();
|
|
}
|
|
|
|
// 尝试 GBK 解码
|
|
let (cow, _, had_errors) = encoding_rs::GBK.decode(bytes);
|
|
if !had_errors {
|
|
return cow.to_string();
|
|
}
|
|
|
|
// 如果 GBK 也失败,使用 lossy 转换
|
|
String::from_utf8_lossy(bytes).to_string()
|
|
}
|
|
|
|
/// Combines captured stdout, stderr and an optional exit code into one string.
///
/// stdout comes first, unchanged. When stderr contains anything besides
/// whitespace it follows under a `STDERR:` heading, separated from stdout by a
/// newline if stdout was non-empty. When `exit_code` is given, a final
/// `Exit code: N` line is appended.
fn format_command_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
    let mut rendered = String::from(stdout);

    let stderr_has_content = !stderr.trim().is_empty();
    if stderr_has_content {
        if !rendered.is_empty() {
            rendered.push('\n');
        }
        rendered += "STDERR:\n";
        rendered += stderr;
    }

    match exit_code {
        Some(code) => rendered + &format!("\nExit code: {}", code),
        None => rendered,
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A trivial echo succeeds and its text shows up in the output.
    #[tokio::test]
    async fn test_simple_command() {
        let tool = BashTool::new();
        let command = if cfg!(target_os = "windows") {
            "Write-Output 'Hello World'"
        } else {
            "echo 'Hello World'"
        };
        let result = tool
            .execute(json!({ "command": command }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Hello World"));
    }

    #[tokio::test]
    async fn test_pwd_command() {
        let tool = BashTool::new();
        let command = if cfg!(target_os = "windows") {
            "Get-Location"
        } else {
            "pwd"
        };
        let result = tool.execute(json!({ "command": command })).await.unwrap();

        assert!(result.success);
    }

    #[tokio::test]
    async fn test_ls_command() {
        let tool = BashTool::new();
        let temp_dir = std::env::temp_dir();
        let command = if cfg!(target_os = "windows") {
            format!("Get-ChildItem {}", temp_dir.display())
        } else {
            format!("ls -la {}", temp_dir.display())
        };
        let result = tool
            .execute(json!({ "command": command }))
            .await
            .unwrap();

        assert!(result.success);
    }

    #[tokio::test]
    async fn test_dangerous_rm() {
        let tool = BashTool::new();
        // Exercise the Unix dangerous-command pattern.
        let result = tool
            .execute(json!({ "command": "rm -rf /some/path" }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("blocked"));
    }

    #[tokio::test]
    async fn test_dangerous_windows_commands() {
        let tool = BashTool::new();
        // Exercise the Windows `del` pattern (the regex is expected to match).
        let result = tool
            .execute(json!({ "command": "del /f /q file.txt" }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("blocked"));
    }

    #[tokio::test]
    async fn test_dangerous_fork_bomb() {
        let tool = BashTool::new();
        let result = tool
            .execute(json!({ "command": ":(){ :|:& };:" }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("blocked"));
    }

    #[tokio::test]
    async fn test_missing_command() {
        let tool = BashTool::new();
        let result = tool.execute(json!({})).await.unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("command"));
    }

    #[tokio::test]
    async fn test_timeout() {
        let tool = BashTool::new();
        let command = if cfg!(target_os = "windows") {
            "Start-Sleep -Seconds 10"
        } else {
            "sleep 10"
        };
        let result = tool
            .execute(json!({
                "command": command,
                "timeout": 1
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("timed out"));
    }

    #[tokio::test]
    async fn test_pending_user_action_detection() {
        let tool = BashTool::new();
        let command = if cfg!(target_os = "windows") {
            "Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
        } else {
            "printf 'waiting for authorization'; sleep 10"
        };
        let result = tool
            .execute(json!({
                "command": command,
                "timeout": 1,
                "interactive": true
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains(PENDING_USER_ACTION_MARKER));
    }

    /// Truncation of multi-byte text must cut on character boundaries and keep
    /// exactly `MAX_OUTPUT_CHARS` characters of the original content.
    #[test]
    fn test_truncate_output_handles_utf8_char_boundaries() {
        let tool = BashTool::new();
        let input = "全".repeat(MAX_OUTPUT_CHARS + 100);

        let output = tool.truncate_output(&input);

        // The note reports exactly the number of dropped characters.
        assert!(output.contains("(... 100 chars truncated ...)"));
        // Head and tail are intact characters, not split byte sequences.
        assert!(output.starts_with('全'));
        assert!(output.ends_with('全'));
        assert_eq!(
            output.chars().filter(|&c| c == '全').count(),
            MAX_OUTPUT_CHARS
        );
    }

    #[test]
    fn test_shell_kind_detect() {
        let shell = ShellKind::detect();
        if cfg!(target_os = "windows") {
            assert_eq!(shell, ShellKind::PowerShell);
        } else {
            assert_eq!(shell, ShellKind::Bash);
        }
    }

    #[test]
    fn test_shell_kind_executable() {
        assert_eq!(ShellKind::Bash.executable(), "bash");
        assert_eq!(ShellKind::PowerShell.executable(), "powershell");
        assert_eq!(ShellKind::Cmd.executable(), "cmd");
    }

    #[test]
    fn test_shell_kind_command_args() {
        assert_eq!(
            ShellKind::Bash.command_args("echo hello"),
            vec!["-c", "echo hello"]
        );
        assert_eq!(
            ShellKind::PowerShell.command_args("echo hello"),
            vec!["-Command", "echo hello"]
        );
        assert_eq!(
            ShellKind::Cmd.command_args("echo hello"),
            vec!["/C", "echo hello"]
        );
    }
}
|