feat(tools): add web_fetch tool for HTML content extraction

- Fetch URL and extract readable text
- HTML to plain text conversion
- Removes scripts, styles, and HTML tags
- Decodes HTML entities
- JSON pretty printing
- SSRF protection
- Includes 6 unit tests
This commit is contained in:
xiaoski 2026-04-07 23:52:06 +08:00
parent 1581732ef9
commit 8936e70a12
2 changed files with 413 additions and 0 deletions

View File

@ -7,6 +7,7 @@ pub mod http_request;
pub mod registry;
pub mod schema;
pub mod traits;
pub mod web_fetch;
pub use bash::BashTool;
pub use calculator::CalculatorTool;
@ -17,3 +18,4 @@ pub use http_request::HttpRequestTool;
pub use registry::ToolRegistry;
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use traits::{Tool, ToolResult};
pub use web_fetch::WebFetchTool;

411
src/tools/web_fetch.rs Normal file
View File

@ -0,0 +1,411 @@
use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
pub struct WebFetchTool {
max_response_size: usize,
timeout_secs: u64,
user_agent: String,
}
impl WebFetchTool {
pub fn new(max_response_size: usize, timeout_secs: u64) -> Self {
Self {
max_response_size,
timeout_secs,
user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(),
}
}
fn validate_url(&self, url: &str) -> Result<String, String> {
let url = url.trim();
if url.is_empty() {
return Err("URL cannot be empty".to_string());
}
if url.chars().any(char::is_whitespace) {
return Err("URL cannot contain whitespace".to_string());
}
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err("Only http:// and https:// URLs are allowed".to_string());
}
let host = extract_host(url)?;
if is_private_host(&host) {
return Err(format!("Blocked local/private host: {}", host));
}
Ok(url.to_string())
}
fn truncate_response(&self, text: &str) -> String {
if self.max_response_size == 0 {
return text.to_string();
}
if text.len() > self.max_response_size {
format!(
"{}\n\n... [Response truncated due to size limit] ...",
&text[..self.max_response_size]
)
} else {
text.to_string()
}
}
async fn fetch_content(&self, url: &str) -> Result<String, String> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::USER_AGENT,
self.user_agent.parse().unwrap(),
);
let response = client
.get(url)
.headers(headers)
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
// Handle HTML content
if content_type.contains("text/html") {
let html = response
.text()
.await
.map_err(|e| format!("Failed to read response: {}", e))?;
return Ok(self.extract_text_from_html(&html));
}
// Handle JSON content
if content_type.contains("application/json") {
let text = response
.text()
.await
.map_err(|e| format!("Failed to read response: {}", e))?;
// Pretty print JSON
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
return Ok(serde_json::to_string_pretty(&parsed).unwrap_or(text));
}
return Ok(text);
}
// For other content types, return raw text
response
.text()
.await
.map_err(|e| format!("Failed to read response: {}", e))
}
fn extract_text_from_html(&self, html: &str) -> String {
let mut text = html.to_string();
// Remove script and style tags with content using simple replacements
text = strip_tag(&text, "script");
text = strip_tag(&text, "style");
// Remove all HTML tags
text = strip_all_tags(&text);
// Decode HTML entities
text = self.decode_html_entities(&text);
// Clean up whitespace
let mut cleaned = String::new();
let mut last_was_space = true;
for c in text.chars() {
if c.is_whitespace() {
if !last_was_space {
cleaned.push(' ');
last_was_space = true;
}
} else {
cleaned.push(c);
last_was_space = false;
}
}
cleaned.trim().to_string()
}
fn decode_html_entities(&self, text: &str) -> String {
let entities = [
("&nbsp;", " "),
("&lt;", "<"),
("&gt;", ">"),
("&amp;", "&"),
("&quot;", "\""),
("&apos;", "'"),
("&mdash;", ""),
("&ndash;", ""),
("&copy;", "©"),
("&reg;", "®"),
("&trade;", ""),
];
let mut result = text.to_string();
for (entity, replacement) in entities {
result = result.replace(entity, replacement);
}
result
}
}
fn strip_tag(s: &str, tag_name: &str) -> String {
let open = format!("<{}>", tag_name);
let close = format!("</{}>", tag_name);
let mut result = s.to_string();
// Keep removing until no more found (simple approach)
while let Some(start) = result.to_lowercase().find(&open) {
if let Some(end) = result.to_lowercase()[start..].find(&close) {
let end_pos = start + end + close.len();
result = format!("{}{}", &result[..start], &result[end_pos..]);
} else {
break;
}
}
result
}
fn strip_all_tags(s: &str) -> String {
let mut result = String::new();
let mut in_tag = false;
for c in s.chars() {
if c == '<' {
in_tag = true;
} else if c == '>' {
in_tag = false;
result.push(' ');
} else if !in_tag {
result.push(c);
}
}
result
}
fn extract_html_entity(s: &str) -> Option<(char, usize)> {
let s_lower = s.to_lowercase();
let entities = [
("&nbsp;", ' '),
("&lt;", '<'),
("&gt;", '>'),
("&amp;", '&'),
("&quot;", '"'),
("&apos;", '\''),
("&mdash;", '—'),
("&ndash;", ''),
("&copy;", '©'),
("&reg;", '®'),
("&trade;", '™'),
];
for (entity, replacement) in entities {
if s_lower.starts_with(&entity.to_lowercase()) {
return Some((replacement, entity.len()));
}
}
// Handle numeric entities
if s_lower.starts_with("&#x") || s_lower.starts_with("&#") {
// Skip for now
}
None
}
fn extract_host(url: &str) -> Result<String, String> {
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if authority.is_empty() {
return Err("URL must include a host".to_string());
}
let host = authority
.split(':')
.next()
.unwrap_or_default()
.trim()
.to_lowercase();
if host.is_empty() {
return Err("URL must include a valid host".to_string());
}
Ok(host)
}
fn is_private_host(host: &str) -> bool {
if host == "localhost" || host.ends_with(".localhost") {
return true;
}
if host.rsplit('.').next().is_some_and(|label| label == "local") {
return true;
}
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
}
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
};
}
false
}
#[async_trait]
impl Tool for WebFetchTool {
fn name(&self) -> &str {
"web_fetch"
}
fn description(&self) -> &str {
"Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "URL to fetch"
}
},
"required": ["url"]
})
}
fn read_only(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let url = match args.get("url").and_then(|v| v.as_str()) {
Some(u) => u,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: url".to_string()),
});
}
};
let url = match self.validate_url(url) {
Ok(u) => u,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
match self.fetch_content(&url).await {
Ok(content) => Ok(ToolResult {
success: true,
output: self.truncate_response(&content),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_tool() -> WebFetchTool {
WebFetchTool::new(50_000, 30)
}
#[tokio::test]
async fn test_validate_url_success() {
let tool = test_tool();
let result = tool.validate_url("https://example.com");
assert!(result.is_ok());
}
#[tokio::test]
async fn test_validate_url_rejects_private() {
let tool = test_tool();
let result = tool.validate_url("https://localhost:8080");
assert!(result.is_err());
assert!(result.unwrap_err().contains("local/private"));
}
#[tokio::test]
async fn test_validate_url_rejects_whitespace() {
let tool = test_tool();
let result = tool.validate_url("https://example.com/hello world");
assert!(result.is_err());
assert!(result.unwrap_err().contains("whitespace"));
}
#[tokio::test]
async fn test_extract_text_simple() {
let tool = test_tool();
let html = "<html><body><p>Hello World</p></body></html>";
let text = tool.extract_text_from_html(html);
assert!(text.contains("Hello World"));
assert!(!text.contains("<"));
}
#[tokio::test]
async fn test_extract_text_removes_scripts() {
let tool = test_tool();
let html = "<html><body><script>alert('bad');</script><p>Good</p></body></html>";
let text = tool.extract_text_from_html(html);
assert!(text.contains("Good"));
assert!(!text.contains("alert"));
}
#[tokio::test]
async fn test_extract_text_removes_styles() {
let tool = test_tool();
let html = "<html><head><style>.x { color: red; }</style></head><body><p>Content</p></body></html>";
let text = tool.extract_text_from_html(html);
assert!(text.contains("Content"));
assert!(!text.contains("color"));
}
}