- Execute shell commands with a configurable timeout
- Safety guards block dangerous commands (rm -rf, fork bombs)
- Truncate very large outputs
- Support a configurable working directory
- Include 7 unit tests
316 lines
8.4 KiB
Rust
316 lines
8.4 KiB
Rust
use std::path::Path;
|
|
use std::process::Stdio;
|
|
use std::time::Duration;
|
|
|
|
use async_trait::async_trait;
|
|
use serde_json::json;
|
|
use tokio::io::AsyncReadExt;
|
|
use tokio::process::Command;
|
|
use tokio::time::timeout;
|
|
|
|
use crate::tools::traits::{Tool, ToolResult};
|
|
|
|
/// Upper bound, in seconds, applied to any per-call timeout (10 minutes).
const MAX_TIMEOUT_SECS: u64 = 10 * 60;

/// Output length above which the middle of a command's output is elided.
const MAX_OUTPUT_CHARS: usize = 50 * 1_000;
|
|
|
|
/// Tool that runs a command through `bash -c` and reports its combined output.
#[derive(Debug, Clone)]
pub struct BashTool {
    /// Default timeout, in seconds, used when a call does not supply one.
    timeout_secs: u64,
    /// Directory commands run in; `None` means the current directory (`"."`).
    working_dir: Option<String>,
    /// Regex sources; a command whose lower-cased text matches any of them is refused.
    deny_patterns: Vec<String>,
}
|
|
|
|
impl BashTool {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
timeout_secs: 60,
|
|
working_dir: None,
|
|
deny_patterns: vec![
|
|
r"\brm\s+-[rf]{1,2}\b".to_string(),
|
|
r"\bdel\s+/[fq]\b".to_string(),
|
|
r"\brmdir\s+/s\b".to_string(),
|
|
r":\(\)\s*\{.*\};\s*:".to_string(),
|
|
],
|
|
}
|
|
}
|
|
|
|
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
|
|
self.timeout_secs = timeout_secs;
|
|
self
|
|
}
|
|
|
|
pub fn with_working_dir(mut self, dir: String) -> Self {
|
|
self.working_dir = Some(dir);
|
|
self
|
|
}
|
|
|
|
fn guard_command(&self, command: &str) -> Option<String> {
|
|
let lower = command.to_lowercase();
|
|
for pattern in &self.deny_patterns {
|
|
if regex::Regex::new(pattern)
|
|
.ok()
|
|
.map(|re| re.is_match(&lower))
|
|
.unwrap_or(false)
|
|
{
|
|
return Some(format!(
|
|
"Command blocked by safety guard (dangerous pattern: {})",
|
|
pattern
|
|
));
|
|
}
|
|
}
|
|
None
|
|
}
|
|
|
|
fn truncate_output(&self, output: &str) -> String {
|
|
if output.len() <= MAX_OUTPUT_CHARS {
|
|
return output.to_string();
|
|
}
|
|
|
|
let half = MAX_OUTPUT_CHARS / 2;
|
|
format!(
|
|
"{}...\n\n(... {} chars truncated ...)\n\n{}",
|
|
&output[..half],
|
|
output.len() - MAX_OUTPUT_CHARS,
|
|
&output[output.len() - half..]
|
|
)
|
|
}
|
|
}
|
|
|
|
impl Default for BashTool {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Tool for BashTool {
|
|
fn name(&self) -> &str {
|
|
"bash"
|
|
}
|
|
|
|
fn description(&self) -> &str {
|
|
"Execute a bash shell command and return its output. Use with caution."
|
|
}
|
|
|
|
fn parameters_schema(&self) -> serde_json::Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"command": {
|
|
"type": "string",
|
|
"description": "The shell command to execute"
|
|
},
|
|
"timeout": {
|
|
"type": "integer",
|
|
"description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
|
|
"minimum": 1,
|
|
"maximum": MAX_TIMEOUT_SECS
|
|
}
|
|
},
|
|
"required": ["command"]
|
|
})
|
|
}
|
|
|
|
fn exclusive(&self) -> bool {
|
|
true // Shell commands should not run concurrently
|
|
}
|
|
|
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
let command = match args.get("command").and_then(|v| v.as_str()) {
|
|
Some(c) => c,
|
|
None => {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some("Missing required parameter: command".to_string()),
|
|
});
|
|
}
|
|
};
|
|
|
|
// Safety check
|
|
if let Some(error) = self.guard_command(command) {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(error),
|
|
});
|
|
}
|
|
|
|
let timeout_secs = args
|
|
.get("timeout")
|
|
.and_then(|v| v.as_u64())
|
|
.unwrap_or(self.timeout_secs)
|
|
.min(MAX_TIMEOUT_SECS);
|
|
|
|
let cwd = self
|
|
.working_dir
|
|
.as_ref()
|
|
.map(|d| Path::new(d))
|
|
.unwrap_or_else(|| Path::new("."));
|
|
|
|
let result = timeout(
|
|
Duration::from_secs(timeout_secs),
|
|
self.run_command(command, cwd),
|
|
)
|
|
.await;
|
|
|
|
match result {
|
|
Ok(Ok(output)) => Ok(ToolResult {
|
|
success: true,
|
|
output,
|
|
error: None,
|
|
}),
|
|
Ok(Err(e)) => Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(e),
|
|
}),
|
|
Err(_) => Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(format!(
|
|
"Command timed out after {} seconds",
|
|
timeout_secs
|
|
)),
|
|
}),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl BashTool {
|
|
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
|
|
let mut cmd = Command::new("bash");
|
|
cmd.args(["-c", command])
|
|
.stdout(Stdio::piped())
|
|
.stderr(Stdio::piped())
|
|
.current_dir(cwd);
|
|
|
|
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
|
|
|
let mut stdout = Vec::new();
|
|
let mut stderr = Vec::new();
|
|
|
|
if let Some(ref mut out) = child.stdout {
|
|
out.read_to_end(&mut stdout)
|
|
.await
|
|
.map_err(|e| format!("Failed to read stdout: {}", e))?;
|
|
}
|
|
|
|
if let Some(ref mut err) = child.stderr {
|
|
err.read_to_end(&mut stderr)
|
|
.await
|
|
.map_err(|e| format!("Failed to read stderr: {}", e))?;
|
|
}
|
|
|
|
let status = child
|
|
.wait()
|
|
.await
|
|
.map_err(|e| format!("Failed to wait: {}", e))?;
|
|
|
|
let mut output = String::new();
|
|
|
|
if !stdout.is_empty() {
|
|
let stdout_str = String::from_utf8_lossy(&stdout);
|
|
output.push_str(&stdout_str);
|
|
}
|
|
|
|
if !stderr.is_empty() {
|
|
let stderr_str = String::from_utf8_lossy(&stderr);
|
|
if !stderr_str.trim().is_empty() {
|
|
if !output.is_empty() {
|
|
output.push_str("\n");
|
|
}
|
|
output.push_str("STDERR:\n");
|
|
output.push_str(&stderr_str);
|
|
}
|
|
}
|
|
|
|
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
|
|
|
|
Ok(self.truncate_output(&output))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes `args` with a default-configured tool. `execute` is expected
    /// never to return `Err`, so the unwrap is part of each test's check.
    async fn run(args: serde_json::Value) -> ToolResult {
        BashTool::new().execute(args).await.unwrap()
    }

    /// Asserts that `result` failed with an error message containing `needle`.
    fn assert_failed_with(result: ToolResult, needle: &str) {
        assert!(!result.success);
        let message = result
            .error
            .expect("a failed result should carry an error message");
        assert!(
            message.contains(needle),
            "unexpected error message: {}",
            message
        );
    }

    #[tokio::test]
    async fn test_simple_command() {
        let outcome = run(json!({ "command": "echo 'Hello World'" })).await;
        assert!(outcome.success);
        assert!(outcome.output.contains("Hello World"));
    }

    #[tokio::test]
    async fn test_pwd_command() {
        assert!(run(json!({ "command": "pwd" })).await.success);
    }

    #[tokio::test]
    async fn test_ls_command() {
        assert!(run(json!({ "command": "ls -la /tmp" })).await.success);
    }

    #[tokio::test]
    async fn test_dangerous_rm() {
        assert_failed_with(run(json!({ "command": "rm -rf /" })).await, "blocked");
    }

    #[tokio::test]
    async fn test_dangerous_fork_bomb() {
        assert_failed_with(
            run(json!({ "command": ":(){ :|:& };:" })).await,
            "blocked",
        );
    }

    #[tokio::test]
    async fn test_missing_command() {
        assert_failed_with(run(json!({})).await, "command");
    }

    #[tokio::test]
    async fn test_timeout() {
        let args = json!({
            "command": "sleep 10",
            "timeout": 1
        });
        assert_failed_with(run(args).await, "timed out");
    }
}
|