use std::time::Duration; use async_trait::async_trait; use reqwest::header::HeaderMap; use serde_json::json; use crate::tools::traits::{Tool, ToolResult}; pub struct HttpRequestTool { allowed_domains: Vec, max_response_size: usize, timeout_secs: u64, allow_private_hosts: bool, } impl HttpRequestTool { pub fn new( allowed_domains: Vec, max_response_size: usize, timeout_secs: u64, allow_private_hosts: bool, ) -> Self { Self { allowed_domains: normalize_domains(allowed_domains), max_response_size, timeout_secs, allow_private_hosts, } } fn validate_url(&self, url: &str) -> Result { let url = url.trim(); if url.is_empty() { return Err("URL cannot be empty".to_string()); } if url.chars().any(char::is_whitespace) { return Err("URL cannot contain whitespace".to_string()); } if !url.starts_with("http://") && !url.starts_with("https://") { return Err("Only http:// and https:// URLs are allowed".to_string()); } let host = extract_host(url)?; if !self.allow_private_hosts && is_private_host(&host) { return Err(format!("Blocked local/private host: {}", host)); } if !host_matches_allowlist(&host, &self.allowed_domains) { return Err(format!( "Host '{}' is not in allowed_domains", host )); } Ok(url.to_string()) } fn validate_method(&self, method: &str) -> Result { match method.to_uppercase().as_str() { "GET" => Ok(reqwest::Method::GET), "POST" => Ok(reqwest::Method::POST), "PUT" => Ok(reqwest::Method::PUT), "DELETE" => Ok(reqwest::Method::DELETE), "PATCH" => Ok(reqwest::Method::PATCH), _ => Err(format!( "Unsupported HTTP method: {}. Supported: GET, POST, PUT, DELETE, PATCH", method )), } } fn parse_headers(&self, headers: &serde_json::Value) -> HeaderMap { let mut header_map = HeaderMap::new(); if let Some(obj) = headers.as_object() { for (key, value) in obj { if let Some(str_val) = value.as_str() { if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) { if let Ok(val) = reqwest::header::HeaderValue::from_str(str_val) { header_map.insert(name, val); } } } } } header_map } fn truncate_response(&self, text: &str) -> String { if self.max_response_size == 0 { return text.to_string(); } if text.len() > self.max_response_size { format!( "{}\n\n... [Response truncated due to size limit] ...", &text[..self.max_response_size] ) } else { text.to_string() } } } fn normalize_domains(domains: Vec) -> Vec { let mut normalized: Vec = domains .into_iter() .filter_map(|d| normalize_domain(&d)) .collect(); normalized.sort_unstable(); normalized.dedup(); normalized } fn normalize_domain(raw: &str) -> Option { let mut d = raw.trim().to_lowercase(); if d.is_empty() { return None; } if let Some(stripped) = d.strip_prefix("https://") { d = stripped.to_string(); } else if let Some(stripped) = d.strip_prefix("http://") { d = stripped.to_string(); } if let Some((host, _)) = d.split_once('/') { d = host.to_string(); } d = d.trim_start_matches('.').trim_end_matches('.').to_string(); if let Some((host, _)) = d.split_once(':') { d = host.to_string(); } if d.is_empty() || d.chars().any(char::is_whitespace) { return None; } Some(d) } fn extract_host(url: &str) -> Result { let rest = url .strip_prefix("http://") .or_else(|| url.strip_prefix("https://")) .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?; let authority = rest .split(['/', '?', '#']) .next() .ok_or_else(|| "Invalid URL".to_string())?; if authority.is_empty() { return Err("URL must include a host".to_string()); } if authority.contains('@') { return Err("URL userinfo is not allowed".to_string()); } if authority.starts_with('[') { return Err("IPv6 hosts are not supported".to_string()); } let host = authority .split(':') .next() .unwrap_or_default() .trim() .trim_end_matches('.') .to_lowercase(); if host.is_empty() { return Err("URL must include a valid host".to_string()); } Ok(host) } fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { if allowed_domains.iter().any(|domain| domain == "*") { return true; } allowed_domains.iter().any(|domain| { host == domain || host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.')) }) } fn is_private_host(host: &str) -> bool { // Check localhost if host == "localhost" || host.ends_with(".localhost") { return true; } // Check .local TLD if host.rsplit('.').next().is_some_and(|label| label == "local") { return true; } // Try to parse as IP if let Ok(ip) = host.parse::() { return is_private_ip(&ip); } false } fn is_private_ip(ip: &std::net::IpAddr) -> bool { match ip { std::net::IpAddr::V4(v4) => { v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() || v4.is_broadcast() || v4.is_multicast() } std::net::IpAddr::V6(v6) => { v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() } } } #[async_trait] impl Tool for HttpRequestTool { fn name(&self) -> &str { "http_request" } fn description(&self) -> &str { "Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH methods. Security: domain allowlist, no local/private hosts." } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "url": { "type": "string", "description": "HTTP or HTTPS URL to request" }, "method": { "type": "string", "description": "HTTP method (GET, POST, PUT, DELETE, PATCH)", "default": "GET" }, "headers": { "type": "object", "description": "Optional HTTP headers as key-value pairs" }, "body": { "type": "string", "description": "Optional request body" } }, "required": ["url"] }) } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let url = match args.get("url").and_then(|v| v.as_str()) { Some(u) => u, None => { return Ok(ToolResult { success: false, output: String::new(), error: Some("Missing required parameter: url".to_string()), }); } }; let method_str = args .get("method") .and_then(|v| v.as_str()) .unwrap_or("GET"); let headers_val = args.get("headers").cloned().unwrap_or(json!({})); let body = args.get("body").and_then(|v| v.as_str()); let url = match self.validate_url(url) { Ok(u) => u, Err(e) => { return Ok(ToolResult { success: false, output: String::new(), error: Some(e), }); } }; let method = match self.validate_method(method_str) { Ok(m) => m, Err(e) => { return Ok(ToolResult { success: false, output: String::new(), error: Some(e), }); } }; let headers = self.parse_headers(&headers_val); let client = match reqwest::Client::builder() .timeout(Duration::from_secs(self.timeout_secs)) .build() { Ok(c) => c, Err(e) => { return Ok(ToolResult { success: false, output: String::new(), error: Some(format!("Failed to create HTTP client: {}", e)), }); } }; let mut request = client.request(method, &url).headers(headers); if let Some(body_str) = body { request = request.body(body_str.to_string()); } match request.send().await { Ok(response) => { let status = response.status(); let status_code = status.as_u16(); let response_text = response .text() .await .map(|t| self.truncate_response(&t)) .unwrap_or_else(|_| "[Failed to read response body]".to_string()); let output = format!( "Status: {} {}\n\nResponse Body:\n{}", status_code, status.canonical_reason().unwrap_or("Unknown"), response_text ); Ok(ToolResult { success: status.is_success(), output, error: if status.is_client_error() || status.is_server_error() { Some(format!("HTTP {}", status_code)) } else { None }, }) } Err(e) => Ok(ToolResult { success: false, output: String::new(), error: Some(format!("HTTP request failed: {}", e)), }), } } } #[cfg(test)] mod tests { use super::*; fn test_tool(domains: Vec<&str>) -> HttpRequestTool { HttpRequestTool::new( domains.into_iter().map(String::from).collect(), 1_000_000, 30, false, ) } #[tokio::test] async fn test_validate_url_success() { let tool = test_tool(vec!["example.com"]); let result = tool.validate_url("https://example.com/docs"); assert!(result.is_ok()); } #[tokio::test] async fn test_validate_url_rejects_private() { let tool = test_tool(vec!["example.com"]); let result = tool.validate_url("https://localhost:8080"); assert!(result.is_err()); assert!(result.unwrap_err().contains("local/private")); } #[tokio::test] async fn test_validate_url_rejects_whitespace() { let tool = test_tool(vec!["example.com"]); let result = tool.validate_url("https://example.com/hello world"); assert!(result.is_err()); assert!(result.unwrap_err().contains("whitespace")); } #[tokio::test] async fn test_validate_url_requires_allowlist() { let tool = HttpRequestTool::new(vec![], 1_000_000, 30, false); let result = tool.validate_url("https://example.com"); assert!(result.is_err()); assert!(result.unwrap_err().contains("allowed_domains")); } #[tokio::test] async fn test_validate_method() { let tool = test_tool(vec!["example.com"]); assert!(tool.validate_method("GET").is_ok()); assert!(tool.validate_method("POST").is_ok()); assert!(tool.validate_method("PUT").is_ok()); assert!(tool.validate_method("DELETE").is_ok()); assert!(tool.validate_method("PATCH").is_ok()); assert!(tool.validate_method("INVALID").is_err()); } #[tokio::test] async fn test_blocks_loopback() { assert!(is_private_host("127.0.0.1")); assert!(is_private_host("localhost")); } #[tokio::test] async fn test_blocks_private_ranges() { assert!(is_private_host("10.0.0.1")); assert!(is_private_host("172.16.0.1")); assert!(is_private_host("192.168.1.1")); } #[tokio::test] async fn test_blocks_local_tld() { assert!(is_private_host("service.local")); } }