feat(tools): add http_request tool with security features

- HTTP client with GET/POST/PUT/DELETE/PATCH support
- Domain allowlist for security
- SSRF protection (blocks private IPs, localhost)
- Response size limit and truncation
- Timeout control
- Includes 8 unit tests
This commit is contained in:
xiaoski 2026-04-07 23:49:15 +08:00
parent 68e3663c2f
commit 1581732ef9
2 changed files with 446 additions and 0 deletions

444
src/tools/http_request.rs Normal file
View File

@ -0,0 +1,444 @@
use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
pub struct HttpRequestTool {
allowed_domains: Vec<String>,
max_response_size: usize,
timeout_secs: u64,
allow_private_hosts: bool,
}
impl HttpRequestTool {
pub fn new(
allowed_domains: Vec<String>,
max_response_size: usize,
timeout_secs: u64,
allow_private_hosts: bool,
) -> Self {
Self {
allowed_domains: normalize_domains(allowed_domains),
max_response_size,
timeout_secs,
allow_private_hosts,
}
}
fn validate_url(&self, url: &str) -> Result<String, String> {
let url = url.trim();
if url.is_empty() {
return Err("URL cannot be empty".to_string());
}
if url.chars().any(char::is_whitespace) {
return Err("URL cannot contain whitespace".to_string());
}
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err("Only http:// and https:// URLs are allowed".to_string());
}
let host = extract_host(url)?;
if !self.allow_private_hosts && is_private_host(&host) {
return Err(format!("Blocked local/private host: {}", host));
}
if !host_matches_allowlist(&host, &self.allowed_domains) {
return Err(format!(
"Host '{}' is not in allowed_domains",
host
));
}
Ok(url.to_string())
}
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
match method.to_uppercase().as_str() {
"GET" => Ok(reqwest::Method::GET),
"POST" => Ok(reqwest::Method::POST),
"PUT" => Ok(reqwest::Method::PUT),
"DELETE" => Ok(reqwest::Method::DELETE),
"PATCH" => Ok(reqwest::Method::PATCH),
_ => Err(format!(
"Unsupported HTTP method: {}. Supported: GET, POST, PUT, DELETE, PATCH",
method
)),
}
}
fn parse_headers(&self, headers: &serde_json::Value) -> HeaderMap {
let mut header_map = HeaderMap::new();
if let Some(obj) = headers.as_object() {
for (key, value) in obj {
if let Some(str_val) = value.as_str() {
if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) {
if let Ok(val) =
reqwest::header::HeaderValue::from_str(str_val)
{
header_map.insert(name, val);
}
}
}
}
}
header_map
}
fn truncate_response(&self, text: &str) -> String {
if self.max_response_size == 0 {
return text.to_string();
}
if text.len() > self.max_response_size {
format!(
"{}\n\n... [Response truncated due to size limit] ...",
&text[..self.max_response_size]
)
} else {
text.to_string()
}
}
}
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
let mut normalized: Vec<String> = domains
.into_iter()
.filter_map(|d| normalize_domain(&d))
.collect();
normalized.sort_unstable();
normalized.dedup();
normalized
}
fn normalize_domain(raw: &str) -> Option<String> {
let mut d = raw.trim().to_lowercase();
if d.is_empty() {
return None;
}
if let Some(stripped) = d.strip_prefix("https://") {
d = stripped.to_string();
} else if let Some(stripped) = d.strip_prefix("http://") {
d = stripped.to_string();
}
if let Some((host, _)) = d.split_once('/') {
d = host.to_string();
}
d = d.trim_start_matches('.').trim_end_matches('.').to_string();
if let Some((host, _)) = d.split_once(':') {
d = host.to_string();
}
if d.is_empty() || d.chars().any(char::is_whitespace) {
return None;
}
Some(d)
}
fn extract_host(url: &str) -> Result<String, String> {
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if authority.is_empty() {
return Err("URL must include a host".to_string());
}
if authority.contains('@') {
return Err("URL userinfo is not allowed".to_string());
}
if authority.starts_with('[') {
return Err("IPv6 hosts are not supported".to_string());
}
let host = authority
.split(':')
.next()
.unwrap_or_default()
.trim()
.trim_end_matches('.')
.to_lowercase();
if host.is_empty() {
return Err("URL must include a valid host".to_string());
}
Ok(host)
}
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
if allowed_domains.iter().any(|domain| domain == "*") {
return true;
}
allowed_domains.iter().any(|domain| {
host == domain
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
})
}
fn is_private_host(host: &str) -> bool {
// Check localhost
if host == "localhost" || host.ends_with(".localhost") {
return true;
}
// Check .local TLD
if host.rsplit('.').next().is_some_and(|label| label == "local") {
return true;
}
// Try to parse as IP
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return is_private_ip(&ip);
}
false
}
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback()
|| v4.is_private()
|| v4.is_link_local()
|| v4.is_unspecified()
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
}
}
#[async_trait]
impl Tool for HttpRequestTool {
fn name(&self) -> &str {
"http_request"
}
fn description(&self) -> &str {
"Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH methods. Security: domain allowlist, no local/private hosts."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "HTTP or HTTPS URL to request"
},
"method": {
"type": "string",
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
"default": "GET"
},
"headers": {
"type": "object",
"description": "Optional HTTP headers as key-value pairs"
},
"body": {
"type": "string",
"description": "Optional request body"
}
},
"required": ["url"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let url = match args.get("url").and_then(|v| v.as_str()) {
Some(u) => u,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: url".to_string()),
});
}
};
let method_str = args
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
let body = args.get("body").and_then(|v| v.as_str());
let url = match self.validate_url(url) {
Ok(u) => u,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
let method = match self.validate_method(method_str) {
Ok(m) => m,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
let headers = self.parse_headers(&headers_val);
let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs))
.build()
{
Ok(c) => c,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to create HTTP client: {}", e)),
});
}
};
let mut request = client.request(method, &url).headers(headers);
if let Some(body_str) = body {
request = request.body(body_str.to_string());
}
match request.send().await {
Ok(response) => {
let status = response.status();
let status_code = status.as_u16();
let response_text = response
.text()
.await
.map(|t| self.truncate_response(&t))
.unwrap_or_else(|_| "[Failed to read response body]".to_string());
let output = format!(
"Status: {} {}\n\nResponse Body:\n{}",
status_code,
status.canonical_reason().unwrap_or("Unknown"),
response_text
);
Ok(ToolResult {
success: status.is_success(),
output,
error: if status.is_client_error() || status.is_server_error() {
Some(format!("HTTP {}", status_code))
} else {
None
},
})
}
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("HTTP request failed: {}", e)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_tool(domains: Vec<&str>) -> HttpRequestTool {
HttpRequestTool::new(
domains.into_iter().map(String::from).collect(),
1_000_000,
30,
false,
)
}
#[tokio::test]
async fn test_validate_url_success() {
let tool = test_tool(vec!["example.com"]);
let result = tool.validate_url("https://example.com/docs");
assert!(result.is_ok());
}
#[tokio::test]
async fn test_validate_url_rejects_private() {
let tool = test_tool(vec!["example.com"]);
let result = tool.validate_url("https://localhost:8080");
assert!(result.is_err());
assert!(result.unwrap_err().contains("local/private"));
}
#[tokio::test]
async fn test_validate_url_rejects_whitespace() {
let tool = test_tool(vec!["example.com"]);
let result = tool.validate_url("https://example.com/hello world");
assert!(result.is_err());
assert!(result.unwrap_err().contains("whitespace"));
}
#[tokio::test]
async fn test_validate_url_requires_allowlist() {
let tool = HttpRequestTool::new(vec![], 1_000_000, 30, false);
let result = tool.validate_url("https://example.com");
assert!(result.is_err());
assert!(result.unwrap_err().contains("allowed_domains"));
}
#[tokio::test]
async fn test_validate_method() {
let tool = test_tool(vec!["example.com"]);
assert!(tool.validate_method("GET").is_ok());
assert!(tool.validate_method("POST").is_ok());
assert!(tool.validate_method("PUT").is_ok());
assert!(tool.validate_method("DELETE").is_ok());
assert!(tool.validate_method("PATCH").is_ok());
assert!(tool.validate_method("INVALID").is_err());
}
#[tokio::test]
async fn test_blocks_loopback() {
assert!(is_private_host("127.0.0.1"));
assert!(is_private_host("localhost"));
}
#[tokio::test]
async fn test_blocks_private_ranges() {
assert!(is_private_host("10.0.0.1"));
assert!(is_private_host("172.16.0.1"));
assert!(is_private_host("192.168.1.1"));
}
#[tokio::test]
async fn test_blocks_local_tld() {
assert!(is_private_host("service.local"));
}
}

View File

@ -3,6 +3,7 @@ pub mod calculator;
pub mod file_edit; pub mod file_edit;
pub mod file_read; pub mod file_read;
pub mod file_write; pub mod file_write;
pub mod http_request;
pub mod registry; pub mod registry;
pub mod schema; pub mod schema;
pub mod traits; pub mod traits;
@ -12,6 +13,7 @@ pub use calculator::CalculatorTool;
pub use file_edit::FileEditTool; pub use file_edit::FileEditTool;
pub use file_read::FileReadTool; pub use file_read::FileReadTool;
pub use file_write::FileWriteTool; pub use file_write::FileWriteTool;
pub use http_request::HttpRequestTool;
pub use registry::ToolRegistry; pub use registry::ToolRegistry;
pub use schema::{CleaningStrategy, SchemaCleanr}; pub use schema::{CleaningStrategy, SchemaCleanr};
pub use traits::{Tool, ToolResult}; pub use traits::{Tool, ToolResult};