diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs new file mode 100644 index 0000000..cc70cce --- /dev/null +++ b/src/tools/http_request.rs @@ -0,0 +1,444 @@ +use std::time::Duration; + +use async_trait::async_trait; +use reqwest::header::HeaderMap; +use serde_json::json; + +use crate::tools::traits::{Tool, ToolResult}; + +pub struct HttpRequestTool { + allowed_domains: Vec, + max_response_size: usize, + timeout_secs: u64, + allow_private_hosts: bool, +} + +impl HttpRequestTool { + pub fn new( + allowed_domains: Vec, + max_response_size: usize, + timeout_secs: u64, + allow_private_hosts: bool, + ) -> Self { + Self { + allowed_domains: normalize_domains(allowed_domains), + max_response_size, + timeout_secs, + allow_private_hosts, + } + } + + fn validate_url(&self, url: &str) -> Result { + let url = url.trim(); + + if url.is_empty() { + return Err("URL cannot be empty".to_string()); + } + + if url.chars().any(char::is_whitespace) { + return Err("URL cannot contain whitespace".to_string()); + } + + if !url.starts_with("http://") && !url.starts_with("https://") { + return Err("Only http:// and https:// URLs are allowed".to_string()); + } + + let host = extract_host(url)?; + + if !self.allow_private_hosts && is_private_host(&host) { + return Err(format!("Blocked local/private host: {}", host)); + } + + if !host_matches_allowlist(&host, &self.allowed_domains) { + return Err(format!( + "Host '{}' is not in allowed_domains", + host + )); + } + + Ok(url.to_string()) + } + + fn validate_method(&self, method: &str) -> Result { + match method.to_uppercase().as_str() { + "GET" => Ok(reqwest::Method::GET), + "POST" => Ok(reqwest::Method::POST), + "PUT" => Ok(reqwest::Method::PUT), + "DELETE" => Ok(reqwest::Method::DELETE), + "PATCH" => Ok(reqwest::Method::PATCH), + _ => Err(format!( + "Unsupported HTTP method: {}. Supported: GET, POST, PUT, DELETE, PATCH", + method + )), + } + } + + fn parse_headers(&self, headers: &serde_json::Value) -> HeaderMap { + let mut header_map = HeaderMap::new(); + + if let Some(obj) = headers.as_object() { + for (key, value) in obj { + if let Some(str_val) = value.as_str() { + if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) { + if let Ok(val) = + reqwest::header::HeaderValue::from_str(str_val) + { + header_map.insert(name, val); + } + } + } + } + } + + header_map + } + + fn truncate_response(&self, text: &str) -> String { + if self.max_response_size == 0 { + return text.to_string(); + } + + if text.len() > self.max_response_size { + format!( + "{}\n\n... [Response truncated due to size limit] ...", + &text[..self.max_response_size] + ) + } else { + text.to_string() + } + } +} + +fn normalize_domains(domains: Vec) -> Vec { + let mut normalized: Vec = domains + .into_iter() + .filter_map(|d| normalize_domain(&d)) + .collect(); + normalized.sort_unstable(); + normalized.dedup(); + normalized +} + +fn normalize_domain(raw: &str) -> Option { + let mut d = raw.trim().to_lowercase(); + if d.is_empty() { + return None; + } + + if let Some(stripped) = d.strip_prefix("https://") { + d = stripped.to_string(); + } else if let Some(stripped) = d.strip_prefix("http://") { + d = stripped.to_string(); + } + + if let Some((host, _)) = d.split_once('/') { + d = host.to_string(); + } + + d = d.trim_start_matches('.').trim_end_matches('.').to_string(); + + if let Some((host, _)) = d.split_once(':') { + d = host.to_string(); + } + + if d.is_empty() || d.chars().any(char::is_whitespace) { + return None; + } + + Some(d) +} + +fn extract_host(url: &str) -> Result { + let rest = url + .strip_prefix("http://") + .or_else(|| url.strip_prefix("https://")) + .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?; + + let authority = rest + .split(['/', '?', '#']) + .next() + .ok_or_else(|| "Invalid URL".to_string())?; + + if authority.is_empty() { + return Err("URL must include a host".to_string()); + } + + if authority.contains('@') { + return Err("URL userinfo is not allowed".to_string()); + } + + if authority.starts_with('[') { + return Err("IPv6 hosts are not supported".to_string()); + } + + let host = authority + .split(':') + .next() + .unwrap_or_default() + .trim() + .trim_end_matches('.') + .to_lowercase(); + + if host.is_empty() { + return Err("URL must include a valid host".to_string()); + } + + Ok(host) +} + +fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { + if allowed_domains.iter().any(|domain| domain == "*") { + return true; + } + + allowed_domains.iter().any(|domain| { + host == domain + || host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.')) + }) +} + +fn is_private_host(host: &str) -> bool { + // Check localhost + if host == "localhost" || host.ends_with(".localhost") { + return true; + } + + // Check .local TLD + if host.rsplit('.').next().is_some_and(|label| label == "local") { + return true; + } + + // Try to parse as IP + if let Ok(ip) = host.parse::() { + return is_private_ip(&ip); + } + + false +} + +fn is_private_ip(ip: &std::net::IpAddr) -> bool { + match ip { + std::net::IpAddr::V4(v4) => { + v4.is_loopback() + || v4.is_private() + || v4.is_link_local() + || v4.is_unspecified() + || v4.is_broadcast() + || v4.is_multicast() + } + std::net::IpAddr::V6(v6) => { + v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() + } + } +} + +#[async_trait] +impl Tool for HttpRequestTool { + fn name(&self) -> &str { + "http_request" + } + + fn description(&self) -> &str { + "Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH methods. Security: domain allowlist, no local/private hosts." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "HTTP or HTTPS URL to request" + }, + "method": { + "type": "string", + "description": "HTTP method (GET, POST, PUT, DELETE, PATCH)", + "default": "GET" + }, + "headers": { + "type": "object", + "description": "Optional HTTP headers as key-value pairs" + }, + "body": { + "type": "string", + "description": "Optional request body" + } + }, + "required": ["url"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let url = match args.get("url").and_then(|v| v.as_str()) { + Some(u) => u, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: url".to_string()), + }); + } + }; + + let method_str = args + .get("method") + .and_then(|v| v.as_str()) + .unwrap_or("GET"); + + let headers_val = args.get("headers").cloned().unwrap_or(json!({})); + let body = args.get("body").and_then(|v| v.as_str()); + + let url = match self.validate_url(url) { + Ok(u) => u, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }); + } + }; + + let method = match self.validate_method(method_str) { + Ok(m) => m, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }); + } + }; + + let headers = self.parse_headers(&headers_val); + + let client = match reqwest::Client::builder() + .timeout(Duration::from_secs(self.timeout_secs)) + .build() + { + Ok(c) => c, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to create HTTP client: {}", e)), + }); + } + }; + + let mut request = client.request(method, &url).headers(headers); + + if let Some(body_str) = body { + request = request.body(body_str.to_string()); + } + + match request.send().await { + Ok(response) => { + let status = response.status(); + let status_code = status.as_u16(); + + let response_text = response + .text() + .await + .map(|t| self.truncate_response(&t)) + .unwrap_or_else(|_| "[Failed to read response body]".to_string()); + + let output = format!( + "Status: {} {}\n\nResponse Body:\n{}", + status_code, + status.canonical_reason().unwrap_or("Unknown"), + response_text + ); + + Ok(ToolResult { + success: status.is_success(), + output, + error: if status.is_client_error() || status.is_server_error() { + Some(format!("HTTP {}", status_code)) + } else { + None + }, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("HTTP request failed: {}", e)), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_tool(domains: Vec<&str>) -> HttpRequestTool { + HttpRequestTool::new( + domains.into_iter().map(String::from).collect(), + 1_000_000, + 30, + false, + ) + } + + #[tokio::test] + async fn test_validate_url_success() { + let tool = test_tool(vec!["example.com"]); + let result = tool.validate_url("https://example.com/docs"); + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_validate_url_rejects_private() { + let tool = test_tool(vec!["example.com"]); + let result = tool.validate_url("https://localhost:8080"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("local/private")); + } + + #[tokio::test] + async fn test_validate_url_rejects_whitespace() { + let tool = test_tool(vec!["example.com"]); + let result = tool.validate_url("https://example.com/hello world"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("whitespace")); + } + + #[tokio::test] + async fn test_validate_url_requires_allowlist() { + let tool = HttpRequestTool::new(vec![], 1_000_000, 30, false); + let result = tool.validate_url("https://example.com"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("allowed_domains")); + } + + #[tokio::test] + async fn test_validate_method() { + let tool = test_tool(vec!["example.com"]); + assert!(tool.validate_method("GET").is_ok()); + assert!(tool.validate_method("POST").is_ok()); + assert!(tool.validate_method("PUT").is_ok()); + assert!(tool.validate_method("DELETE").is_ok()); + assert!(tool.validate_method("PATCH").is_ok()); + assert!(tool.validate_method("INVALID").is_err()); + } + + #[tokio::test] + async fn test_blocks_loopback() { + assert!(is_private_host("127.0.0.1")); + assert!(is_private_host("localhost")); + } + + #[tokio::test] + async fn test_blocks_private_ranges() { + assert!(is_private_host("10.0.0.1")); + assert!(is_private_host("172.16.0.1")); + assert!(is_private_host("192.168.1.1")); + } + + #[tokio::test] + async fn test_blocks_local_tld() { + assert!(is_private_host("service.local")); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index ecb4849..59f3df6 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -3,6 +3,7 @@ pub mod calculator; pub mod file_edit; pub mod file_read; pub mod file_write; +pub mod http_request; pub mod registry; pub mod schema; pub mod traits; @@ -12,6 +13,7 @@ pub use calculator::CalculatorTool; pub use file_edit::FileEditTool; pub use file_read::FileReadTool; pub use file_write::FileWriteTool; +pub use http_request::HttpRequestTool; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr}; pub use traits::{Tool, ToolResult};