feat(tools): add web_fetch tool for HTML content extraction
- Fetch URL and extract readable text - HTML to plain text conversion - Removes scripts, styles, and HTML tags - Decodes HTML entities - JSON pretty printing - SSRF protection - Includes 6 unit tests
This commit is contained in:
parent
1581732ef9
commit
8936e70a12
@ -7,6 +7,7 @@ pub mod http_request;
|
||||
pub mod registry;
|
||||
pub mod schema;
|
||||
pub mod traits;
|
||||
pub mod web_fetch;
|
||||
|
||||
pub use bash::BashTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
@ -17,3 +18,4 @@ pub use http_request::HttpRequestTool;
|
||||
pub use registry::ToolRegistry;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use traits::{Tool, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
411
src/tools/web_fetch.rs
Normal file
411
src/tools/web_fetch.rs
Normal file
@ -0,0 +1,411 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct WebFetchTool {
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
user_agent: String,
|
||||
}
|
||||
|
||||
impl WebFetchTool {
|
||||
pub fn new(max_response_size: usize, timeout_secs: u64) -> Self {
|
||||
Self {
|
||||
max_response_size,
|
||||
timeout_secs,
|
||||
user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_url(&self, url: &str) -> Result<String, String> {
|
||||
let url = url.trim();
|
||||
|
||||
if url.is_empty() {
|
||||
return Err("URL cannot be empty".to_string());
|
||||
}
|
||||
|
||||
if url.chars().any(char::is_whitespace) {
|
||||
return Err("URL cannot contain whitespace".to_string());
|
||||
}
|
||||
|
||||
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||
}
|
||||
|
||||
let host = extract_host(url)?;
|
||||
|
||||
if is_private_host(&host) {
|
||||
return Err(format!("Blocked local/private host: {}", host));
|
||||
}
|
||||
|
||||
Ok(url.to_string())
|
||||
}
|
||||
|
||||
fn truncate_response(&self, text: &str) -> String {
|
||||
if self.max_response_size == 0 {
|
||||
return text.to_string();
|
||||
}
|
||||
|
||||
if text.len() > self.max_response_size {
|
||||
format!(
|
||||
"{}\n\n... [Response truncated due to size limit] ...",
|
||||
&text[..self.max_response_size]
|
||||
)
|
||||
} else {
|
||||
text.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_content(&self, url: &str) -> Result<String, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::USER_AGENT,
|
||||
self.user_agent.parse().unwrap(),
|
||||
);
|
||||
|
||||
let response = client
|
||||
.get(url)
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Request failed: {}", e))?;
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
// Handle HTML content
|
||||
if content_type.contains("text/html") {
|
||||
let html = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read response: {}", e))?;
|
||||
return Ok(self.extract_text_from_html(&html));
|
||||
}
|
||||
|
||||
// Handle JSON content
|
||||
if content_type.contains("application/json") {
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read response: {}", e))?;
|
||||
// Pretty print JSON
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||
return Ok(serde_json::to_string_pretty(&parsed).unwrap_or(text));
|
||||
}
|
||||
return Ok(text);
|
||||
}
|
||||
|
||||
// For other content types, return raw text
|
||||
response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read response: {}", e))
|
||||
}
|
||||
|
||||
fn extract_text_from_html(&self, html: &str) -> String {
|
||||
let mut text = html.to_string();
|
||||
|
||||
// Remove script and style tags with content using simple replacements
|
||||
text = strip_tag(&text, "script");
|
||||
text = strip_tag(&text, "style");
|
||||
|
||||
// Remove all HTML tags
|
||||
text = strip_all_tags(&text);
|
||||
|
||||
// Decode HTML entities
|
||||
text = self.decode_html_entities(&text);
|
||||
|
||||
// Clean up whitespace
|
||||
let mut cleaned = String::new();
|
||||
let mut last_was_space = true;
|
||||
for c in text.chars() {
|
||||
if c.is_whitespace() {
|
||||
if !last_was_space {
|
||||
cleaned.push(' ');
|
||||
last_was_space = true;
|
||||
}
|
||||
} else {
|
||||
cleaned.push(c);
|
||||
last_was_space = false;
|
||||
}
|
||||
}
|
||||
|
||||
cleaned.trim().to_string()
|
||||
}
|
||||
|
||||
fn decode_html_entities(&self, text: &str) -> String {
|
||||
let entities = [
|
||||
(" ", " "),
|
||||
("<", "<"),
|
||||
(">", ">"),
|
||||
("&", "&"),
|
||||
(""", "\""),
|
||||
("'", "'"),
|
||||
("—", "—"),
|
||||
("–", "–"),
|
||||
("©", "©"),
|
||||
("®", "®"),
|
||||
("™", "™"),
|
||||
];
|
||||
|
||||
let mut result = text.to_string();
|
||||
for (entity, replacement) in entities {
|
||||
result = result.replace(entity, replacement);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
fn strip_tag(s: &str, tag_name: &str) -> String {
|
||||
let open = format!("<{}>", tag_name);
|
||||
let close = format!("</{}>", tag_name);
|
||||
let mut result = s.to_string();
|
||||
|
||||
// Keep removing until no more found (simple approach)
|
||||
while let Some(start) = result.to_lowercase().find(&open) {
|
||||
if let Some(end) = result.to_lowercase()[start..].find(&close) {
|
||||
let end_pos = start + end + close.len();
|
||||
result = format!("{}{}", &result[..start], &result[end_pos..]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn strip_all_tags(s: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut in_tag = false;
|
||||
|
||||
for c in s.chars() {
|
||||
if c == '<' {
|
||||
in_tag = true;
|
||||
} else if c == '>' {
|
||||
in_tag = false;
|
||||
result.push(' ');
|
||||
} else if !in_tag {
|
||||
result.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn extract_html_entity(s: &str) -> Option<(char, usize)> {
|
||||
let s_lower = s.to_lowercase();
|
||||
|
||||
let entities = [
|
||||
(" ", ' '),
|
||||
("<", '<'),
|
||||
(">", '>'),
|
||||
("&", '&'),
|
||||
(""", '"'),
|
||||
("'", '\''),
|
||||
("—", '—'),
|
||||
("–", '–'),
|
||||
("©", '©'),
|
||||
("®", '®'),
|
||||
("™", '™'),
|
||||
];
|
||||
|
||||
for (entity, replacement) in entities {
|
||||
if s_lower.starts_with(&entity.to_lowercase()) {
|
||||
return Some((replacement, entity.len()));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle numeric entities
|
||||
if s_lower.starts_with("&#x") || s_lower.starts_with("&#") {
|
||||
// Skip for now
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_host(url: &str) -> Result<String, String> {
|
||||
let rest = url
|
||||
.strip_prefix("http://")
|
||||
.or_else(|| url.strip_prefix("https://"))
|
||||
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
|
||||
|
||||
let authority = rest
|
||||
.split(['/', '?', '#'])
|
||||
.next()
|
||||
.ok_or_else(|| "Invalid URL".to_string())?;
|
||||
|
||||
if authority.is_empty() {
|
||||
return Err("URL must include a host".to_string());
|
||||
}
|
||||
|
||||
let host = authority
|
||||
.split(':')
|
||||
.next()
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.to_lowercase();
|
||||
|
||||
if host.is_empty() {
|
||||
return Err("URL must include a valid host".to_string());
|
||||
}
|
||||
|
||||
Ok(host)
|
||||
}
|
||||
|
||||
fn is_private_host(host: &str) -> bool {
|
||||
if host == "localhost" || host.ends_with(".localhost") {
|
||||
return true;
|
||||
}
|
||||
|
||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||
return match ip {
|
||||
std::net::IpAddr::V4(v4) => {
|
||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||
};
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for WebFetchTool {
|
||||
fn name(&self) -> &str {
|
||||
"web_fetch"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to fetch"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let url = match args.get("url").and_then(|v| v.as_str()) {
|
||||
Some(u) => u,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: url".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let url = match self.validate_url(url) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match self.fetch_content(&url).await {
|
||||
Ok(content) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: self.truncate_response(&content),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_tool() -> WebFetchTool {
|
||||
WebFetchTool::new(50_000, 30)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_url_success() {
|
||||
let tool = test_tool();
|
||||
let result = tool.validate_url("https://example.com");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_url_rejects_private() {
|
||||
let tool = test_tool();
|
||||
let result = tool.validate_url("https://localhost:8080");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("local/private"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_url_rejects_whitespace() {
|
||||
let tool = test_tool();
|
||||
let result = tool.validate_url("https://example.com/hello world");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("whitespace"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_text_simple() {
|
||||
let tool = test_tool();
|
||||
let html = "<html><body><p>Hello World</p></body></html>";
|
||||
let text = tool.extract_text_from_html(html);
|
||||
assert!(text.contains("Hello World"));
|
||||
assert!(!text.contains("<"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_text_removes_scripts() {
|
||||
let tool = test_tool();
|
||||
let html = "<html><body><script>alert('bad');</script><p>Good</p></body></html>";
|
||||
let text = tool.extract_text_from_html(html);
|
||||
assert!(text.contains("Good"));
|
||||
assert!(!text.contains("alert"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_text_removes_styles() {
|
||||
let tool = test_tool();
|
||||
let html = "<html><head><style>.x { color: red; }</style></head><body><p>Content</p></body></html>";
|
||||
let text = tool.extract_text_from_html(html);
|
||||
assert!(text.contains("Content"));
|
||||
assert!(!text.contains("color"));
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user