feat(tools): add http_request tool with security features
- HTTP client with GET/POST/PUT/DELETE/PATCH support - Domain allowlist for security - SSRF protection (blocks private IPs, localhost) - Response size limit and truncation - Timeout control - Includes 8 unit tests
This commit is contained in:
parent
68e3663c2f
commit
1581732ef9
444
src/tools/http_request.rs
Normal file
444
src/tools/http_request.rs
Normal file
@ -0,0 +1,444 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct HttpRequestTool {
|
||||
allowed_domains: Vec<String>,
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
allow_private_hosts: bool,
|
||||
}
|
||||
|
||||
impl HttpRequestTool {
|
||||
pub fn new(
|
||||
allowed_domains: Vec<String>,
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
allow_private_hosts: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
allowed_domains: normalize_domains(allowed_domains),
|
||||
max_response_size,
|
||||
timeout_secs,
|
||||
allow_private_hosts,
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_url(&self, url: &str) -> Result<String, String> {
|
||||
let url = url.trim();
|
||||
|
||||
if url.is_empty() {
|
||||
return Err("URL cannot be empty".to_string());
|
||||
}
|
||||
|
||||
if url.chars().any(char::is_whitespace) {
|
||||
return Err("URL cannot contain whitespace".to_string());
|
||||
}
|
||||
|
||||
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||
}
|
||||
|
||||
let host = extract_host(url)?;
|
||||
|
||||
if !self.allow_private_hosts && is_private_host(&host) {
|
||||
return Err(format!("Blocked local/private host: {}", host));
|
||||
}
|
||||
|
||||
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
||||
return Err(format!(
|
||||
"Host '{}' is not in allowed_domains",
|
||||
host
|
||||
));
|
||||
}
|
||||
|
||||
Ok(url.to_string())
|
||||
}
|
||||
|
||||
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
|
||||
match method.to_uppercase().as_str() {
|
||||
"GET" => Ok(reqwest::Method::GET),
|
||||
"POST" => Ok(reqwest::Method::POST),
|
||||
"PUT" => Ok(reqwest::Method::PUT),
|
||||
"DELETE" => Ok(reqwest::Method::DELETE),
|
||||
"PATCH" => Ok(reqwest::Method::PATCH),
|
||||
_ => Err(format!(
|
||||
"Unsupported HTTP method: {}. Supported: GET, POST, PUT, DELETE, PATCH",
|
||||
method
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_headers(&self, headers: &serde_json::Value) -> HeaderMap {
|
||||
let mut header_map = HeaderMap::new();
|
||||
|
||||
if let Some(obj) = headers.as_object() {
|
||||
for (key, value) in obj {
|
||||
if let Some(str_val) = value.as_str() {
|
||||
if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) =
|
||||
reqwest::header::HeaderValue::from_str(str_val)
|
||||
{
|
||||
header_map.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
header_map
|
||||
}
|
||||
|
||||
fn truncate_response(&self, text: &str) -> String {
|
||||
if self.max_response_size == 0 {
|
||||
return text.to_string();
|
||||
}
|
||||
|
||||
if text.len() > self.max_response_size {
|
||||
format!(
|
||||
"{}\n\n... [Response truncated due to size limit] ...",
|
||||
&text[..self.max_response_size]
|
||||
)
|
||||
} else {
|
||||
text.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
|
||||
let mut normalized: Vec<String> = domains
|
||||
.into_iter()
|
||||
.filter_map(|d| normalize_domain(&d))
|
||||
.collect();
|
||||
normalized.sort_unstable();
|
||||
normalized.dedup();
|
||||
normalized
|
||||
}
|
||||
|
||||
fn normalize_domain(raw: &str) -> Option<String> {
|
||||
let mut d = raw.trim().to_lowercase();
|
||||
if d.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(stripped) = d.strip_prefix("https://") {
|
||||
d = stripped.to_string();
|
||||
} else if let Some(stripped) = d.strip_prefix("http://") {
|
||||
d = stripped.to_string();
|
||||
}
|
||||
|
||||
if let Some((host, _)) = d.split_once('/') {
|
||||
d = host.to_string();
|
||||
}
|
||||
|
||||
d = d.trim_start_matches('.').trim_end_matches('.').to_string();
|
||||
|
||||
if let Some((host, _)) = d.split_once(':') {
|
||||
d = host.to_string();
|
||||
}
|
||||
|
||||
if d.is_empty() || d.chars().any(char::is_whitespace) {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(d)
|
||||
}
|
||||
|
||||
fn extract_host(url: &str) -> Result<String, String> {
|
||||
let rest = url
|
||||
.strip_prefix("http://")
|
||||
.or_else(|| url.strip_prefix("https://"))
|
||||
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
|
||||
|
||||
let authority = rest
|
||||
.split(['/', '?', '#'])
|
||||
.next()
|
||||
.ok_or_else(|| "Invalid URL".to_string())?;
|
||||
|
||||
if authority.is_empty() {
|
||||
return Err("URL must include a host".to_string());
|
||||
}
|
||||
|
||||
if authority.contains('@') {
|
||||
return Err("URL userinfo is not allowed".to_string());
|
||||
}
|
||||
|
||||
if authority.starts_with('[') {
|
||||
return Err("IPv6 hosts are not supported".to_string());
|
||||
}
|
||||
|
||||
let host = authority
|
||||
.split(':')
|
||||
.next()
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.trim_end_matches('.')
|
||||
.to_lowercase();
|
||||
|
||||
if host.is_empty() {
|
||||
return Err("URL must include a valid host".to_string());
|
||||
}
|
||||
|
||||
Ok(host)
|
||||
}
|
||||
|
||||
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
||||
if allowed_domains.iter().any(|domain| domain == "*") {
|
||||
return true;
|
||||
}
|
||||
|
||||
allowed_domains.iter().any(|domain| {
|
||||
host == domain
|
||||
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|
||||
})
|
||||
}
|
||||
|
||||
fn is_private_host(host: &str) -> bool {
|
||||
// Check localhost
|
||||
if host == "localhost" || host.ends_with(".localhost") {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check .local TLD
|
||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try to parse as IP
|
||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||
return is_private_ip(&ip);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
||||
match ip {
|
||||
std::net::IpAddr::V4(v4) => {
|
||||
v4.is_loopback()
|
||||
|| v4.is_private()
|
||||
|| v4.is_link_local()
|
||||
|| v4.is_unspecified()
|
||||
|| v4.is_broadcast()
|
||||
|| v4.is_multicast()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for HttpRequestTool {
|
||||
fn name(&self) -> &str {
|
||||
"http_request"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH methods. Security: domain allowlist, no local/private hosts."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "HTTP or HTTPS URL to request"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
|
||||
"default": "GET"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Optional HTTP headers as key-value pairs"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Optional request body"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let url = match args.get("url").and_then(|v| v.as_str()) {
|
||||
Some(u) => u,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: url".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let method_str = args
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET");
|
||||
|
||||
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
||||
let body = args.get("body").and_then(|v| v.as_str());
|
||||
|
||||
let url = match self.validate_url(url) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let method = match self.validate_method(method_str) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let headers = self.parse_headers(&headers_val);
|
||||
|
||||
let client = match reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.timeout_secs))
|
||||
.build()
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to create HTTP client: {}", e)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let mut request = client.request(method, &url).headers(headers);
|
||||
|
||||
if let Some(body_str) = body {
|
||||
request = request.body(body_str.to_string());
|
||||
}
|
||||
|
||||
match request.send().await {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
let status_code = status.as_u16();
|
||||
|
||||
let response_text = response
|
||||
.text()
|
||||
.await
|
||||
.map(|t| self.truncate_response(&t))
|
||||
.unwrap_or_else(|_| "[Failed to read response body]".to_string());
|
||||
|
||||
let output = format!(
|
||||
"Status: {} {}\n\nResponse Body:\n{}",
|
||||
status_code,
|
||||
status.canonical_reason().unwrap_or("Unknown"),
|
||||
response_text
|
||||
);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: status.is_success(),
|
||||
output,
|
||||
error: if status.is_client_error() || status.is_server_error() {
|
||||
Some(format!("HTTP {}", status_code))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("HTTP request failed: {}", e)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_tool(domains: Vec<&str>) -> HttpRequestTool {
|
||||
HttpRequestTool::new(
|
||||
domains.into_iter().map(String::from).collect(),
|
||||
1_000_000,
|
||||
30,
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_url_success() {
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
let result = tool.validate_url("https://example.com/docs");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_url_rejects_private() {
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
let result = tool.validate_url("https://localhost:8080");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("local/private"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_url_rejects_whitespace() {
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
let result = tool.validate_url("https://example.com/hello world");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("whitespace"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_url_requires_allowlist() {
|
||||
let tool = HttpRequestTool::new(vec![], 1_000_000, 30, false);
|
||||
let result = tool.validate_url("https://example.com");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("allowed_domains"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_method() {
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
assert!(tool.validate_method("GET").is_ok());
|
||||
assert!(tool.validate_method("POST").is_ok());
|
||||
assert!(tool.validate_method("PUT").is_ok());
|
||||
assert!(tool.validate_method("DELETE").is_ok());
|
||||
assert!(tool.validate_method("PATCH").is_ok());
|
||||
assert!(tool.validate_method("INVALID").is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_blocks_loopback() {
|
||||
assert!(is_private_host("127.0.0.1"));
|
||||
assert!(is_private_host("localhost"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_blocks_private_ranges() {
|
||||
assert!(is_private_host("10.0.0.1"));
|
||||
assert!(is_private_host("172.16.0.1"));
|
||||
assert!(is_private_host("192.168.1.1"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_blocks_local_tld() {
|
||||
assert!(is_private_host("service.local"));
|
||||
}
|
||||
}
|
||||
@ -3,6 +3,7 @@ pub mod calculator;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_write;
|
||||
pub mod http_request;
|
||||
pub mod registry;
|
||||
pub mod schema;
|
||||
pub mod traits;
|
||||
@ -12,6 +13,7 @@ pub use calculator::CalculatorTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use http_request::HttpRequestTool;
|
||||
pub use registry::ToolRegistry;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use traits::{Tool, ToolResult};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user