feat(tools): add bash tool with safety guards
- Execute shell commands with timeout - Safety guards block dangerous commands (rm -rf, fork bombs) - Output truncation for large outputs - Working directory support - Includes 7 unit tests
This commit is contained in:
parent
f3187ceddd
commit
68e3663c2f
315
src/tools/bash.rs
Normal file
315
src/tools/bash.rs
Normal file
@ -0,0 +1,315 @@
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||
const MAX_OUTPUT_CHARS: usize = 50_000;
|
||||
|
||||
pub struct BashTool {
|
||||
timeout_secs: u64,
|
||||
working_dir: Option<String>,
|
||||
deny_patterns: Vec<String>,
|
||||
}
|
||||
|
||||
impl BashTool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
timeout_secs: 60,
|
||||
working_dir: None,
|
||||
deny_patterns: vec![
|
||||
r"\brm\s+-[rf]{1,2}\b".to_string(),
|
||||
r"\bdel\s+/[fq]\b".to_string(),
|
||||
r"\brmdir\s+/s\b".to_string(),
|
||||
r":\(\)\s*\{.*\};\s*:".to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
|
||||
self.timeout_secs = timeout_secs;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_working_dir(mut self, dir: String) -> Self {
|
||||
self.working_dir = Some(dir);
|
||||
self
|
||||
}
|
||||
|
||||
fn guard_command(&self, command: &str) -> Option<String> {
|
||||
let lower = command.to_lowercase();
|
||||
for pattern in &self.deny_patterns {
|
||||
if regex::Regex::new(pattern)
|
||||
.ok()
|
||||
.map(|re| re.is_match(&lower))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Some(format!(
|
||||
"Command blocked by safety guard (dangerous pattern: {})",
|
||||
pattern
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn truncate_output(&self, output: &str) -> String {
|
||||
if output.len() <= MAX_OUTPUT_CHARS {
|
||||
return output.to_string();
|
||||
}
|
||||
|
||||
let half = MAX_OUTPUT_CHARS / 2;
|
||||
format!(
|
||||
"{}...\n\n(... {} chars truncated ...)\n\n{}",
|
||||
&output[..half],
|
||||
output.len() - MAX_OUTPUT_CHARS,
|
||||
&output[output.len() - half..]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BashTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for BashTool {
|
||||
fn name(&self) -> &str {
|
||||
"bash"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Execute a bash shell command and return its output. Use with caution."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
|
||||
"minimum": 1,
|
||||
"maximum": MAX_TIMEOUT_SECS
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
fn exclusive(&self) -> bool {
|
||||
true // Shell commands should not run concurrently
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: command".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Safety check
|
||||
if let Some(error) = self.guard_command(command) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let timeout_secs = args
|
||||
.get("timeout")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(self.timeout_secs)
|
||||
.min(MAX_TIMEOUT_SECS);
|
||||
|
||||
let cwd = self
|
||||
.working_dir
|
||||
.as_ref()
|
||||
.map(|d| Path::new(d))
|
||||
.unwrap_or_else(|| Path::new("."));
|
||||
|
||||
let result = timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
self.run_command(command, cwd),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Ok(Err(e)) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
Err(_) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Command timed out after {} seconds",
|
||||
timeout_secs
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BashTool {
|
||||
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
|
||||
let mut cmd = Command::new("bash");
|
||||
cmd.args(["-c", command])
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.current_dir(cwd);
|
||||
|
||||
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||
|
||||
let mut stdout = Vec::new();
|
||||
let mut stderr = Vec::new();
|
||||
|
||||
if let Some(ref mut out) = child.stdout {
|
||||
out.read_to_end(&mut stdout)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read stdout: {}", e))?;
|
||||
}
|
||||
|
||||
if let Some(ref mut err) = child.stderr {
|
||||
err.read_to_end(&mut stderr)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read stderr: {}", e))?;
|
||||
}
|
||||
|
||||
let status = child
|
||||
.wait()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to wait: {}", e))?;
|
||||
|
||||
let mut output = String::new();
|
||||
|
||||
if !stdout.is_empty() {
|
||||
let stdout_str = String::from_utf8_lossy(&stdout);
|
||||
output.push_str(&stdout_str);
|
||||
}
|
||||
|
||||
if !stderr.is_empty() {
|
||||
let stderr_str = String::from_utf8_lossy(&stderr);
|
||||
if !stderr_str.trim().is_empty() {
|
||||
if !output.is_empty() {
|
||||
output.push_str("\n");
|
||||
}
|
||||
output.push_str("STDERR:\n");
|
||||
output.push_str(&stderr_str);
|
||||
}
|
||||
}
|
||||
|
||||
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
|
||||
|
||||
Ok(self.truncate_output(&output))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simple_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "echo 'Hello World'" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Hello World"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pwd_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "pwd" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ls_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dangerous_rm() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "rm -rf /" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("blocked"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dangerous_fork_bomb() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "command": ":(){ :|:& };:" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("blocked"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_missing_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("command"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"command": "sleep 10",
|
||||
"timeout": 1
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("timed out"));
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,4 @@
|
||||
pub mod bash;
|
||||
pub mod calculator;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
@ -6,6 +7,7 @@ pub mod registry;
|
||||
pub mod schema;
|
||||
pub mod traits;
|
||||
|
||||
pub use bash::BashTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user