feat(tools): add web_fetch tool for HTML content extraction
- Fetch URL and extract readable text - HTML to plain text conversion - Removes scripts, styles, and HTML tags - Decodes HTML entities - JSON pretty printing - SSRF protection - Includes 6 unit tests
This commit is contained in:
parent
1581732ef9
commit
8936e70a12
@ -7,6 +7,7 @@ pub mod http_request;
|
|||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod schema;
|
pub mod schema;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
pub mod web_fetch;
|
||||||
|
|
||||||
pub use bash::BashTool;
|
pub use bash::BashTool;
|
||||||
pub use calculator::CalculatorTool;
|
pub use calculator::CalculatorTool;
|
||||||
@ -17,3 +18,4 @@ pub use http_request::HttpRequestTool;
|
|||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
pub use traits::{Tool, ToolResult};
|
pub use traits::{Tool, ToolResult};
|
||||||
|
pub use web_fetch::WebFetchTool;
|
||||||
|
|||||||
411
src/tools/web_fetch.rs
Normal file
411
src/tools/web_fetch.rs
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::header::HeaderMap;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
pub struct WebFetchTool {
|
||||||
|
max_response_size: usize,
|
||||||
|
timeout_secs: u64,
|
||||||
|
user_agent: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WebFetchTool {
|
||||||
|
pub fn new(max_response_size: usize, timeout_secs: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
max_response_size,
|
||||||
|
timeout_secs,
|
||||||
|
user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_url(&self, url: &str) -> Result<String, String> {
|
||||||
|
let url = url.trim();
|
||||||
|
|
||||||
|
if url.is_empty() {
|
||||||
|
return Err("URL cannot be empty".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.chars().any(char::is_whitespace) {
|
||||||
|
return Err("URL cannot contain whitespace".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||||
|
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let host = extract_host(url)?;
|
||||||
|
|
||||||
|
if is_private_host(&host) {
|
||||||
|
return Err(format!("Blocked local/private host: {}", host));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(url.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_response(&self, text: &str) -> String {
|
||||||
|
if self.max_response_size == 0 {
|
||||||
|
return text.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
if text.len() > self.max_response_size {
|
||||||
|
format!(
|
||||||
|
"{}\n\n... [Response truncated due to size limit] ...",
|
||||||
|
&text[..self.max_response_size]
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
text.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_content(&self, url: &str) -> Result<String, String> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(self.timeout_secs))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
reqwest::header::USER_AGENT,
|
||||||
|
self.user_agent.parse().unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.get(url)
|
||||||
|
.headers(headers)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Request failed: {}", e))?;
|
||||||
|
|
||||||
|
let content_type = response
|
||||||
|
.headers()
|
||||||
|
.get(reqwest::header::CONTENT_TYPE)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("");
|
||||||
|
|
||||||
|
// Handle HTML content
|
||||||
|
if content_type.contains("text/html") {
|
||||||
|
let html = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to read response: {}", e))?;
|
||||||
|
return Ok(self.extract_text_from_html(&html));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle JSON content
|
||||||
|
if content_type.contains("application/json") {
|
||||||
|
let text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to read response: {}", e))?;
|
||||||
|
// Pretty print JSON
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||||
|
return Ok(serde_json::to_string_pretty(&parsed).unwrap_or(text));
|
||||||
|
}
|
||||||
|
return Ok(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For other content types, return raw text
|
||||||
|
response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to read response: {}", e))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text_from_html(&self, html: &str) -> String {
|
||||||
|
let mut text = html.to_string();
|
||||||
|
|
||||||
|
// Remove script and style tags with content using simple replacements
|
||||||
|
text = strip_tag(&text, "script");
|
||||||
|
text = strip_tag(&text, "style");
|
||||||
|
|
||||||
|
// Remove all HTML tags
|
||||||
|
text = strip_all_tags(&text);
|
||||||
|
|
||||||
|
// Decode HTML entities
|
||||||
|
text = self.decode_html_entities(&text);
|
||||||
|
|
||||||
|
// Clean up whitespace
|
||||||
|
let mut cleaned = String::new();
|
||||||
|
let mut last_was_space = true;
|
||||||
|
for c in text.chars() {
|
||||||
|
if c.is_whitespace() {
|
||||||
|
if !last_was_space {
|
||||||
|
cleaned.push(' ');
|
||||||
|
last_was_space = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cleaned.push(c);
|
||||||
|
last_was_space = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned.trim().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_html_entities(&self, text: &str) -> String {
|
||||||
|
let entities = [
|
||||||
|
(" ", " "),
|
||||||
|
("<", "<"),
|
||||||
|
(">", ">"),
|
||||||
|
("&", "&"),
|
||||||
|
(""", "\""),
|
||||||
|
("'", "'"),
|
||||||
|
("—", "—"),
|
||||||
|
("–", "–"),
|
||||||
|
("©", "©"),
|
||||||
|
("®", "®"),
|
||||||
|
("™", "™"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut result = text.to_string();
|
||||||
|
for (entity, replacement) in entities {
|
||||||
|
result = result.replace(entity, replacement);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn strip_tag(s: &str, tag_name: &str) -> String {
|
||||||
|
let open = format!("<{}>", tag_name);
|
||||||
|
let close = format!("</{}>", tag_name);
|
||||||
|
let mut result = s.to_string();
|
||||||
|
|
||||||
|
// Keep removing until no more found (simple approach)
|
||||||
|
while let Some(start) = result.to_lowercase().find(&open) {
|
||||||
|
if let Some(end) = result.to_lowercase()[start..].find(&close) {
|
||||||
|
let end_pos = start + end + close.len();
|
||||||
|
result = format!("{}{}", &result[..start], &result[end_pos..]);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn strip_all_tags(s: &str) -> String {
|
||||||
|
let mut result = String::new();
|
||||||
|
let mut in_tag = false;
|
||||||
|
|
||||||
|
for c in s.chars() {
|
||||||
|
if c == '<' {
|
||||||
|
in_tag = true;
|
||||||
|
} else if c == '>' {
|
||||||
|
in_tag = false;
|
||||||
|
result.push(' ');
|
||||||
|
} else if !in_tag {
|
||||||
|
result.push(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_html_entity(s: &str) -> Option<(char, usize)> {
|
||||||
|
let s_lower = s.to_lowercase();
|
||||||
|
|
||||||
|
let entities = [
|
||||||
|
(" ", ' '),
|
||||||
|
("<", '<'),
|
||||||
|
(">", '>'),
|
||||||
|
("&", '&'),
|
||||||
|
(""", '"'),
|
||||||
|
("'", '\''),
|
||||||
|
("—", '—'),
|
||||||
|
("–", '–'),
|
||||||
|
("©", '©'),
|
||||||
|
("®", '®'),
|
||||||
|
("™", '™'),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (entity, replacement) in entities {
|
||||||
|
if s_lower.starts_with(&entity.to_lowercase()) {
|
||||||
|
return Some((replacement, entity.len()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle numeric entities
|
||||||
|
if s_lower.starts_with("&#x") || s_lower.starts_with("&#") {
|
||||||
|
// Skip for now
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_host(url: &str) -> Result<String, String> {
|
||||||
|
let rest = url
|
||||||
|
.strip_prefix("http://")
|
||||||
|
.or_else(|| url.strip_prefix("https://"))
|
||||||
|
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
|
||||||
|
|
||||||
|
let authority = rest
|
||||||
|
.split(['/', '?', '#'])
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| "Invalid URL".to_string())?;
|
||||||
|
|
||||||
|
if authority.is_empty() {
|
||||||
|
return Err("URL must include a host".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let host = authority
|
||||||
|
.split(':')
|
||||||
|
.next()
|
||||||
|
.unwrap_or_default()
|
||||||
|
.trim()
|
||||||
|
.to_lowercase();
|
||||||
|
|
||||||
|
if host.is_empty() {
|
||||||
|
return Err("URL must include a valid host".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_private_host(host: &str) -> bool {
|
||||||
|
if host == "localhost" || host.ends_with(".localhost") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||||
|
return match ip {
|
||||||
|
std::net::IpAddr::V4(v4) => {
|
||||||
|
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
||||||
|
}
|
||||||
|
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for WebFetchTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"web_fetch"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "URL to fetch"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let url = match args.get("url").and_then(|v| v.as_str()) {
|
||||||
|
Some(u) => u,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: url".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = match self.validate_url(url) {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.fetch_content(&url).await {
|
||||||
|
Ok(content) => Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: self.truncate_response(&content),
|
||||||
|
error: None,
|
||||||
|
}),
|
||||||
|
Err(e) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn test_tool() -> WebFetchTool {
|
||||||
|
WebFetchTool::new(50_000, 30)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_validate_url_success() {
|
||||||
|
let tool = test_tool();
|
||||||
|
let result = tool.validate_url("https://example.com");
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_validate_url_rejects_private() {
|
||||||
|
let tool = test_tool();
|
||||||
|
let result = tool.validate_url("https://localhost:8080");
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("local/private"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_validate_url_rejects_whitespace() {
|
||||||
|
let tool = test_tool();
|
||||||
|
let result = tool.validate_url("https://example.com/hello world");
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("whitespace"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_extract_text_simple() {
|
||||||
|
let tool = test_tool();
|
||||||
|
let html = "<html><body><p>Hello World</p></body></html>";
|
||||||
|
let text = tool.extract_text_from_html(html);
|
||||||
|
assert!(text.contains("Hello World"));
|
||||||
|
assert!(!text.contains("<"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_extract_text_removes_scripts() {
|
||||||
|
let tool = test_tool();
|
||||||
|
let html = "<html><body><script>alert('bad');</script><p>Good</p></body></html>";
|
||||||
|
let text = tool.extract_text_from_html(html);
|
||||||
|
assert!(text.contains("Good"));
|
||||||
|
assert!(!text.contains("alert"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_extract_text_removes_styles() {
|
||||||
|
let tool = test_tool();
|
||||||
|
let html = "<html><head><style>.x { color: red; }</style></head><body><p>Content</p></body></html>";
|
||||||
|
let text = tool.extract_text_from_html(html);
|
||||||
|
assert!(text.contains("Content"));
|
||||||
|
assert!(!text.contains("color"));
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user