use std::path::Path; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; use serde_json::json; use tokio::io::{AsyncRead, AsyncReadExt, BufReader}; use tokio::process::Command; use tokio::sync::{Mutex, mpsc}; use tokio::time::{Instant, sleep_until}; use crate::platform::{ShellInfo, dangerous_command_patterns}; use crate::tools::traits::{Tool, ToolResult}; use crate::tools::{extract_u64, extract_bool, check_null_args}; const MAX_TIMEOUT_SECS: u64 = 600; const MAX_OUTPUT_CHARS: usize = 50_000; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。"; /// Shell 类型枚举,支持跨平台 /// /// 这是 ShellInfo 的兼容包装,提供更方便的 API。 #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ShellKind { Bash, PowerShell, Cmd, } impl ShellKind { /// 根据平台检测默认 shell pub fn detect() -> Self { let info = ShellInfo::default(); match info.executable { "bash" => ShellKind::Bash, "powershell" => ShellKind::PowerShell, "cmd" => ShellKind::Cmd, _ => ShellKind::Bash, // fallback } } /// 从 ShellInfo 获取 ShellKind pub fn from_info(info: &ShellInfo) -> Self { match info.executable { "bash" => ShellKind::Bash, "powershell" => ShellKind::PowerShell, "cmd" => ShellKind::Cmd, _ => ShellKind::Bash, } } /// 获取对应的 ShellInfo pub fn to_info(&self) -> ShellInfo { match self { ShellKind::Bash => ShellInfo { name: "bash", executable: "bash", args: &["-c"], }, ShellKind::PowerShell => ShellInfo { name: "shell", executable: "powershell", args: &["-Command"], }, ShellKind::Cmd => ShellInfo { name: "shell", executable: "cmd", args: &["/C"], }, } } /// Shell 可执行文件名 pub fn executable(&self) -> &'static str { self.to_info().executable } /// 执行命令所需的参数 pub fn command_args<'a>(&self, command: &'a str) -> Vec<&'a str> { let info = self.to_info(); info.args.iter().map(|s| *s).chain(std::iter::once(command)).collect() } /// 工具名称 pub fn tool_name(&self) -> &'static str { self.to_info().name } /// 工具描述 pub fn tool_description(&self) -> &'static str { match self { ShellKind::Bash => "Execute a bash shell command and return its output. Use with caution.", ShellKind::PowerShell => "Execute a PowerShell command and return its output. Use with caution.", ShellKind::Cmd => "Execute a cmd shell command and return its output. Use with caution.", } } } pub struct BashTool { timeout_secs: u64, working_dir: Option, deny_patterns: Vec, shell: ShellKind, } impl BashTool { pub fn new() -> Self { Self { timeout_secs: 60, working_dir: None, deny_patterns: dangerous_command_patterns(), shell: ShellKind::detect(), } } pub fn with_timeout(mut self, timeout_secs: u64) -> Self { self.timeout_secs = timeout_secs; self } pub fn with_working_dir(mut self, dir: String) -> Self { self.working_dir = Some(dir); self } pub fn with_shell(mut self, shell: ShellKind) -> Self { self.shell = shell; self } fn guard_command(&self, command: &str) -> Option { let lower = command.to_lowercase(); for pattern in &self.deny_patterns { if regex::Regex::new(pattern) .ok() .map(|re| re.is_match(&lower)) .unwrap_or(false) { return Some(format!( "Command blocked by safety guard (dangerous pattern: {})", pattern )); } } None } fn truncate_output(&self, output: &str) -> String { let char_count = output.chars().count(); if char_count <= MAX_OUTPUT_CHARS { return output.to_string(); } let half = MAX_OUTPUT_CHARS / 2; let head: String = output.chars().take(half).collect(); let tail: String = output .chars() .skip(char_count.saturating_sub(half)) .collect(); format!( "{}...\n\n(... {} chars truncated ...)\n\n{}", head, char_count - MAX_OUTPUT_CHARS, tail ) } fn pending_output(&self, output: &str) -> String { format!( "{}\n{}\n\n{}", PENDING_USER_ACTION_MARKER, USER_ACTION_HINT, self.truncate_output(output.trim()) ) } fn should_return_pending(&self, interactive: bool, output: &str) -> bool { let normalized = output.to_lowercase(); let has_auth_phrase = [ // 中文 — 原有 "等待用户授权", "等待授权", "等待你授权", "在浏览器中打开以下链接进行认证", // 中文 — 新增(lark-cli 等工具的常见提示) "请在浏览器中", "请打开以下链接", "打开以下链接", "打开链接", "访问以下", "访问此链接", "复制链接", "输入验证码", "输入授权码", "完成认证", "完成授权", "请登录", "正在等待", "等待用户", "手动授权", // 英文 — 原有 "open the following link", "waiting for authorization", "waiting for user authorization", "waiting for approval", "device/verify", "user_code=", // 英文 — 新增 "visit the following url", "visit this url", "open the following url", "browser to authenticate", "browser to complete", "enter the code", "enter code", "verification code", "authorization code", "one-time code", "device code", "oauth", "go to the following", "navigate to the following", "paste the code", ] .iter() .any(|pattern| normalized.contains(pattern)); has_auth_phrase || (interactive && !output.trim().is_empty()) } } async fn drain_available_chunks( rx: &mut mpsc::UnboundedReceiver<(bool, String)>, stdout_buf: &Arc>, stderr_buf: &Arc>, ) { while let Ok((is_stderr, chunk)) = rx.try_recv() { if is_stderr { stderr_buf.lock().await.push_str(&chunk); } else { stdout_buf.lock().await.push_str(&chunk); } } } impl Default for BashTool { fn default() -> Self { Self::new() } } #[async_trait] impl Tool for BashTool { fn name(&self) -> &str { self.shell.tool_name() } fn description(&self) -> &str { self.shell.tool_description() } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "command": { "type": "string", "description": "The shell command to execute" }, "timeout": { "type": "integer", "description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS), "minimum": 1, "maximum": MAX_TIMEOUT_SECS }, "interactive": { "type": "boolean", "description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication" } }, "required": ["command"] }) } fn exclusive(&self) -> bool { true // Shell commands should not run concurrently } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { if let Some(result) = check_null_args(&args, "bash") { return Ok(result); } let command = match args.get("command").and_then(|v| v.as_str()) { Some(c) => c, None => { return Ok(ToolResult { success: false, output: String::new(), error: Some("Missing required parameter: command".to_string()), }); } }; // Safety check if let Some(error) = self.guard_command(command) { return Ok(ToolResult { success: false, output: String::new(), error: Some(error), }); } let timeout_secs = extract_u64(&args, "timeout") .unwrap_or(self.timeout_secs) .min(MAX_TIMEOUT_SECS); let interactive = extract_bool(&args, "interactive").unwrap_or(false); let cwd = self .working_dir .as_ref() .map(|d| Path::new(d)) .unwrap_or_else(|| Path::new(".")); match self .run_command(command, cwd, timeout_secs, interactive) .await { Ok(output) => Ok(ToolResult { success: true, output, error: None, }), Err(e) => Ok(ToolResult { success: false, output: String::new(), error: Some(e), }), } } } impl BashTool { async fn run_command( &self, command: &str, cwd: &Path, timeout_secs: u64, interactive: bool, ) -> Result { let mut cmd = Command::new(self.shell.executable()); cmd.args(self.shell.command_args(command)) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .current_dir(cwd); let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; let stdout = child.stdout.take(); let stderr = child.stderr.take(); let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>(); if let Some(stdout) = stdout { tokio::spawn(read_stream(stdout, false, tx.clone())); } if let Some(stderr) = stderr { tokio::spawn(read_stream(stderr, true, tx.clone())); } drop(tx); let stdout_buf = Arc::new(Mutex::new(String::new())); let stderr_buf = Arc::new(Mutex::new(String::new())); let deadline = Instant::now() + Duration::from_secs(timeout_secs); loop { tokio::select! { status = child.wait() => { let status = status.map_err(|e| format!("Failed to wait: {}", e))?; drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; while let Some((is_stderr, chunk)) = rx.recv().await { if is_stderr { stderr_buf.lock().await.push_str(&chunk); } else { stdout_buf.lock().await.push_str(&chunk); } } let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1))); return Ok(self.truncate_output(&output)); } Some((is_stderr, chunk)) = rx.recv() => { if is_stderr { stderr_buf.lock().await.push_str(&chunk); } else { stdout_buf.lock().await.push_str(&chunk); } let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); if self.should_return_pending(interactive, &combined) { drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; let _ = child.start_kill(); let _ = child.wait().await; let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); return Ok(self.pending_output(&combined)); } } _ = tokio::time::sleep(Duration::from_secs(2)) => { // Periodic safety net: when output has been silent for 2s, // check OS-level process state to see if the child is // genuinely blocked on stdin. Also re-run keyword detection // in case read_stream flushed a partial line since the last // rx.recv() iteration. if let Some(pid) = child.id() { if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) { drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); if !combined.trim().is_empty() { let _ = child.start_kill(); let _ = child.wait().await; return Ok(self.pending_output(&combined)); } } } let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); if self.should_return_pending(interactive, &combined) { drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); let _ = child.start_kill(); let _ = child.wait().await; return Ok(self.pending_output(&combined)); } } _ = sleep_until(deadline) => { drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); let _ = child.start_kill(); let _ = child.wait().await; // OS-level process state check: if the child was blocked on // stdin, treat it as pending rather than a hard timeout error. if let Some(pid) = child.id() { if crate::platform::is_process_waiting_on_stdin(pid) == Some(true) && !combined.trim().is_empty() { return Ok(self.pending_output(&combined)); } } if self.should_return_pending(interactive, &combined) { return Ok(self.pending_output(&combined)); } return Err(format!("Command timed out after {} seconds", timeout_secs)); } } } } } /// Flush delay: if no new data arrives within this window, send any buffered /// partial line immediately. This ensures that prompts and URLs printed /// without a trailing newline are still visible to the detection logic. const STREAM_FLUSH_MS: u64 = 500; async fn read_stream(stream: R, is_stderr: bool, tx: mpsc::UnboundedSender<(bool, String)>) where R: AsyncRead + Unpin + Send + 'static, { let mut reader = BufReader::new(stream); let mut buffer = Vec::new(); loop { let mut chunk = [0u8; 4096]; tokio::select! { result = reader.read(&mut chunk) => { match result { Ok(0) => break, Ok(n) => { buffer.extend_from_slice(&chunk[..n]); // 发送完整行(逻辑不变) while let Some(pos) = buffer.iter().position(|&b| b == b'\n') { let line_bytes = &buffer[..pos + 1]; let line = decode_bytes(line_bytes); let _ = tx.send((is_stderr, line)); buffer.drain(..pos + 1); } } Err(_) => break, } } _ = tokio::time::sleep(Duration::from_millis(STREAM_FLUSH_MS)) => { // 超时未收到新数据:flush 不完整的行 // 这确保像 lark-cli auth login 打印的 URL(不带换行符) // 也能被 run_command 的检测逻辑看到 if !buffer.is_empty() { let remainder = decode_bytes(&buffer); let _ = tx.send((is_stderr, remainder)); buffer.clear(); } } } } // 处理剩余的字节 if !buffer.is_empty() { let remainder = decode_bytes(&buffer); let _ = tx.send((is_stderr, remainder)); } } /// 尝试 UTF-8 解码,失败则尝试 GBK 解码 fn decode_bytes(bytes: &[u8]) -> String { // 首先尝试 UTF-8 if let Ok(s) = std::str::from_utf8(bytes) { return s.to_string(); } // 尝试 GBK 解码 let (cow, _, had_errors) = encoding_rs::GBK.decode(bytes); if !had_errors { return cow.to_string(); } // 如果 GBK 也失败,使用 lossy 转换 String::from_utf8_lossy(bytes).to_string() } fn format_command_output(stdout: &str, stderr: &str, exit_code: Option) -> String { let mut output = String::new(); if !stdout.is_empty() { output.push_str(stdout); } if !stderr.trim().is_empty() { if !output.is_empty() { output.push_str("\n"); } output.push_str("STDERR:\n"); output.push_str(stderr); } if let Some(code) = exit_code { output.push_str(&format!("\nExit code: {}", code)); } output } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_simple_command() { let tool = BashTool::new(); let command = if cfg!(target_os = "windows") { "Write-Output 'Hello World'" } else { "echo 'Hello World'" }; let result = tool .execute(json!({ "command": command })) .await .unwrap(); assert!(result.success); assert!(result.output.contains("Hello World")); } #[tokio::test] async fn test_pwd_command() { let tool = BashTool::new(); let command = if cfg!(target_os = "windows") { "Get-Location" } else { "pwd" }; let result = tool.execute(json!({ "command": command })).await.unwrap(); assert!(result.success); } #[tokio::test] async fn test_ls_command() { let tool = BashTool::new(); let temp_dir = std::env::temp_dir(); let command = if cfg!(target_os = "windows") { format!("Get-ChildItem {}", temp_dir.display()) } else { format!("ls -la {}", temp_dir.display()) }; let result = tool .execute(json!({ "command": command })) .await .unwrap(); assert!(result.success); } #[tokio::test] async fn test_dangerous_rm() { let tool = BashTool::new(); // 测试 Unix 危险命令模式 let result = tool .execute(json!({ "command": "rm -rf /some/path" })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("blocked")); } #[tokio::test] async fn test_dangerous_windows_commands() { let tool = BashTool::new(); // 测试 Windows del 命令模式(正则应该匹配) let result = tool .execute(json!({ "command": "del /f /q file.txt" })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("blocked")); } #[tokio::test] async fn test_dangerous_fork_bomb() { let tool = BashTool::new(); let result = tool .execute(json!({ "command": ":(){ :|:& };:" })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("blocked")); } #[tokio::test] async fn test_missing_command() { let tool = BashTool::new(); let result = tool.execute(json!({})).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("command")); } #[tokio::test] async fn test_timeout() { let tool = BashTool::new(); let command = if cfg!(target_os = "windows") { "Start-Sleep -Seconds 10" } else { "sleep 10" }; let result = tool .execute(json!({ "command": command, "timeout": 1 })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("timed out")); } #[tokio::test] async fn test_pending_user_action_detection() { let tool = BashTool::new(); let command = if cfg!(target_os = "windows") { "Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10" } else { "printf 'waiting for authorization'; sleep 10" }; let result = tool .execute(json!({ "command": command, "timeout": 5, "interactive": true })) .await .unwrap(); assert!(result.success); assert!(result.output.contains(PENDING_USER_ACTION_MARKER)); } #[test] fn test_truncate_output_handles_utf8_char_boundaries() { let tool = BashTool::new(); let input = "全".repeat(MAX_OUTPUT_CHARS + 100); let output = tool.truncate_output(&input); assert!(output.contains("chars truncated")); assert!(output.is_char_boundary(output.len())); } #[test] fn test_shell_kind_detect() { let shell = ShellKind::detect(); if cfg!(target_os = "windows") { assert_eq!(shell, ShellKind::PowerShell); } else { assert_eq!(shell, ShellKind::Bash); } } #[test] fn test_shell_kind_executable() { assert_eq!(ShellKind::Bash.executable(), "bash"); assert_eq!(ShellKind::PowerShell.executable(), "powershell"); assert_eq!(ShellKind::Cmd.executable(), "cmd"); } #[test] fn test_shell_kind_command_args() { assert_eq!(ShellKind::Bash.command_args("echo hello"), vec!["-c" as &str, "echo hello"]); assert_eq!(ShellKind::PowerShell.command_args("echo hello"), vec!["-Command" as &str, "echo hello"]); assert_eq!(ShellKind::Cmd.command_args("echo hello"), vec!["/C" as &str, "echo hello"]); } }