diff --git a/src/tools/bash.rs b/src/tools/bash.rs new file mode 100644 index 0000000..d2c1e85 --- /dev/null +++ b/src/tools/bash.rs @@ -0,0 +1,315 @@ +use std::path::Path; +use std::process::Stdio; +use std::time::Duration; + +use async_trait::async_trait; +use serde_json::json; +use tokio::io::AsyncReadExt; +use tokio::process::Command; +use tokio::time::timeout; + +use crate::tools::traits::{Tool, ToolResult}; + +const MAX_TIMEOUT_SECS: u64 = 600; +const MAX_OUTPUT_CHARS: usize = 50_000; + +pub struct BashTool { + timeout_secs: u64, + working_dir: Option, + deny_patterns: Vec, +} + +impl BashTool { + pub fn new() -> Self { + Self { + timeout_secs: 60, + working_dir: None, + deny_patterns: vec![ + r"\brm\s+-[rf]{1,2}\b".to_string(), + r"\bdel\s+/[fq]\b".to_string(), + r"\brmdir\s+/s\b".to_string(), + r":\(\)\s*\{.*\};\s*:".to_string(), + ], + } + } + + pub fn with_timeout(mut self, timeout_secs: u64) -> Self { + self.timeout_secs = timeout_secs; + self + } + + pub fn with_working_dir(mut self, dir: String) -> Self { + self.working_dir = Some(dir); + self + } + + fn guard_command(&self, command: &str) -> Option { + let lower = command.to_lowercase(); + for pattern in &self.deny_patterns { + if regex::Regex::new(pattern) + .ok() + .map(|re| re.is_match(&lower)) + .unwrap_or(false) + { + return Some(format!( + "Command blocked by safety guard (dangerous pattern: {})", + pattern + )); + } + } + None + } + + fn truncate_output(&self, output: &str) -> String { + if output.len() <= MAX_OUTPUT_CHARS { + return output.to_string(); + } + + let half = MAX_OUTPUT_CHARS / 2; + format!( + "{}...\n\n(... {} chars truncated ...)\n\n{}", + &output[..half], + output.len() - MAX_OUTPUT_CHARS, + &output[output.len() - half..] + ) + } +} + +impl Default for BashTool { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl Tool for BashTool { + fn name(&self) -> &str { + "bash" + } + + fn description(&self) -> &str { + "Execute a bash shell command and return its output. Use with caution." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute" + }, + "timeout": { + "type": "integer", + "description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS), + "minimum": 1, + "maximum": MAX_TIMEOUT_SECS + } + }, + "required": ["command"] + }) + } + + fn exclusive(&self) -> bool { + true // Shell commands should not run concurrently + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let command = match args.get("command").and_then(|v| v.as_str()) { + Some(c) => c, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: command".to_string()), + }); + } + }; + + // Safety check + if let Some(error) = self.guard_command(command) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + + let timeout_secs = args + .get("timeout") + .and_then(|v| v.as_u64()) + .unwrap_or(self.timeout_secs) + .min(MAX_TIMEOUT_SECS); + + let cwd = self + .working_dir + .as_ref() + .map(|d| Path::new(d)) + .unwrap_or_else(|| Path::new(".")); + + let result = timeout( + Duration::from_secs(timeout_secs), + self.run_command(command, cwd), + ) + .await; + + match result { + Ok(Ok(output)) => Ok(ToolResult { + success: true, + output, + error: None, + }), + Ok(Err(e)) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }), + Err(_) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Command timed out after {} seconds", + timeout_secs + )), + }), + } + } +} + +impl BashTool { + async fn run_command(&self, command: &str, cwd: &Path) -> Result { + let mut cmd = Command::new("bash"); + cmd.args(["-c", command]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .current_dir(cwd); + + let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; + + let mut stdout = Vec::new(); + let mut stderr = Vec::new(); + + if let Some(ref mut out) = child.stdout { + out.read_to_end(&mut stdout) + .await + .map_err(|e| format!("Failed to read stdout: {}", e))?; + } + + if let Some(ref mut err) = child.stderr { + err.read_to_end(&mut stderr) + .await + .map_err(|e| format!("Failed to read stderr: {}", e))?; + } + + let status = child + .wait() + .await + .map_err(|e| format!("Failed to wait: {}", e))?; + + let mut output = String::new(); + + if !stdout.is_empty() { + let stdout_str = String::from_utf8_lossy(&stdout); + output.push_str(&stdout_str); + } + + if !stderr.is_empty() { + let stderr_str = String::from_utf8_lossy(&stderr); + if !stderr_str.trim().is_empty() { + if !output.is_empty() { + output.push_str("\n"); + } + output.push_str("STDERR:\n"); + output.push_str(&stderr_str); + } + } + + output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1))); + + Ok(self.truncate_output(&output)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_simple_command() { + let tool = BashTool::new(); + let result = tool + .execute(json!({ "command": "echo 'Hello World'" })) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("Hello World")); + } + + #[tokio::test] + async fn test_pwd_command() { + let tool = BashTool::new(); + let result = tool + .execute(json!({ "command": "pwd" })) + .await + .unwrap(); + + assert!(result.success); + } + + #[tokio::test] + async fn test_ls_command() { + let tool = BashTool::new(); + let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap(); + + assert!(result.success); + } + + #[tokio::test] + async fn test_dangerous_rm() { + let tool = BashTool::new(); + let result = tool + .execute(json!({ "command": "rm -rf /" })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("blocked")); + } + + #[tokio::test] + async fn test_dangerous_fork_bomb() { + let tool = BashTool::new(); + let result = tool + .execute(json!({ "command": ":(){ :|:& };:" })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("blocked")); + } + + #[tokio::test] + async fn test_missing_command() { + let tool = BashTool::new(); + let result = tool.execute(json!({})).await.unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("command")); + } + + #[tokio::test] + async fn test_timeout() { + let tool = BashTool::new(); + let result = tool + .execute(json!({ + "command": "sleep 10", + "timeout": 1 + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("timed out")); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index cd94be7..ecb4849 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,3 +1,4 @@ +pub mod bash; pub mod calculator; pub mod file_edit; pub mod file_read; @@ -6,6 +7,7 @@ pub mod registry; pub mod schema; pub mod traits; +pub use bash::BashTool; pub use calculator::CalculatorTool; pub use file_edit::FileEditTool; pub use file_read::FileReadTool;