use std::path::Path; use std::process::Stdio; use std::time::Duration; use async_trait::async_trait; use serde_json::json; use tokio::io::AsyncReadExt; use tokio::process::Command; use tokio::time::timeout; use crate::tools::traits::{Tool, ToolResult}; const MAX_TIMEOUT_SECS: u64 = 600; const MAX_OUTPUT_CHARS: usize = 50_000; pub struct BashTool { timeout_secs: u64, working_dir: Option, deny_patterns: Vec, } impl BashTool { pub fn new() -> Self { Self { timeout_secs: 60, working_dir: None, deny_patterns: vec![ r"\brm\s+-[rf]{1,2}\b".to_string(), r"\bdel\s+/[fq]\b".to_string(), r"\brmdir\s+/s\b".to_string(), r":\(\)\s*\{.*\};\s*:".to_string(), ], } } pub fn with_timeout(mut self, timeout_secs: u64) -> Self { self.timeout_secs = timeout_secs; self } pub fn with_working_dir(mut self, dir: String) -> Self { self.working_dir = Some(dir); self } fn guard_command(&self, command: &str) -> Option { let lower = command.to_lowercase(); for pattern in &self.deny_patterns { if regex::Regex::new(pattern) .ok() .map(|re| re.is_match(&lower)) .unwrap_or(false) { return Some(format!( "Command blocked by safety guard (dangerous pattern: {})", pattern )); } } None } fn truncate_output(&self, output: &str) -> String { if output.len() <= MAX_OUTPUT_CHARS { return output.to_string(); } let half = MAX_OUTPUT_CHARS / 2; format!( "{}...\n\n(... {} chars truncated ...)\n\n{}", &output[..half], output.len() - MAX_OUTPUT_CHARS, &output[output.len() - half..] ) } } impl Default for BashTool { fn default() -> Self { Self::new() } } #[async_trait] impl Tool for BashTool { fn name(&self) -> &str { "bash" } fn description(&self) -> &str { "Execute a bash shell command and return its output. Use with caution." } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "command": { "type": "string", "description": "The shell command to execute" }, "timeout": { "type": "integer", "description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS), "minimum": 1, "maximum": MAX_TIMEOUT_SECS } }, "required": ["command"] }) } fn exclusive(&self) -> bool { true // Shell commands should not run concurrently } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let command = match args.get("command").and_then(|v| v.as_str()) { Some(c) => c, None => { return Ok(ToolResult { success: false, output: String::new(), error: Some("Missing required parameter: command".to_string()), }); } }; // Safety check if let Some(error) = self.guard_command(command) { return Ok(ToolResult { success: false, output: String::new(), error: Some(error), }); } let timeout_secs = args .get("timeout") .and_then(|v| v.as_u64()) .unwrap_or(self.timeout_secs) .min(MAX_TIMEOUT_SECS); let cwd = self .working_dir .as_ref() .map(|d| Path::new(d)) .unwrap_or_else(|| Path::new(".")); let result = timeout( Duration::from_secs(timeout_secs), self.run_command(command, cwd), ) .await; match result { Ok(Ok(output)) => Ok(ToolResult { success: true, output, error: None, }), Ok(Err(e)) => Ok(ToolResult { success: false, output: String::new(), error: Some(e), }), Err(_) => Ok(ToolResult { success: false, output: String::new(), error: Some(format!( "Command timed out after {} seconds", timeout_secs )), }), } } } impl BashTool { async fn run_command(&self, command: &str, cwd: &Path) -> Result { let mut cmd = Command::new("bash"); cmd.args(["-c", command]) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .current_dir(cwd); let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; let mut stdout = Vec::new(); let mut stderr = Vec::new(); if let Some(ref mut out) = child.stdout { out.read_to_end(&mut stdout) .await .map_err(|e| format!("Failed to read stdout: {}", e))?; } if let Some(ref mut err) = child.stderr { err.read_to_end(&mut stderr) .await .map_err(|e| format!("Failed to read stderr: {}", e))?; } let status = child .wait() .await .map_err(|e| format!("Failed to wait: {}", e))?; let mut output = String::new(); if !stdout.is_empty() { let stdout_str = String::from_utf8_lossy(&stdout); output.push_str(&stdout_str); } if !stderr.is_empty() { let stderr_str = String::from_utf8_lossy(&stderr); if !stderr_str.trim().is_empty() { if !output.is_empty() { output.push_str("\n"); } output.push_str("STDERR:\n"); output.push_str(&stderr_str); } } output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1))); Ok(self.truncate_output(&output)) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_simple_command() { let tool = BashTool::new(); let result = tool .execute(json!({ "command": "echo 'Hello World'" })) .await .unwrap(); assert!(result.success); assert!(result.output.contains("Hello World")); } #[tokio::test] async fn test_pwd_command() { let tool = BashTool::new(); let result = tool .execute(json!({ "command": "pwd" })) .await .unwrap(); assert!(result.success); } #[tokio::test] async fn test_ls_command() { let tool = BashTool::new(); let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap(); assert!(result.success); } #[tokio::test] async fn test_dangerous_rm() { let tool = BashTool::new(); let result = tool .execute(json!({ "command": "rm -rf /" })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("blocked")); } #[tokio::test] async fn test_dangerous_fork_bomb() { let tool = BashTool::new(); let result = tool .execute(json!({ "command": ":(){ :|:& };:" })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("blocked")); } #[tokio::test] async fn test_missing_command() { let tool = BashTool::new(); let result = tool.execute(json!({})).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("command")); } #[tokio::test] async fn test_timeout() { let tool = BashTool::new(); let result = tool .execute(json!({ "command": "sleep 10", "timeout": 1 })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("timed out")); } }