diff --git a/src/tools/mod.rs b/src/tools/mod.rs
index 59f3df6..035c8f7 100644
--- a/src/tools/mod.rs
+++ b/src/tools/mod.rs
@@ -7,6 +7,7 @@ pub mod http_request;
 pub mod registry;
 pub mod schema;
 pub mod traits;
+pub mod web_fetch;
 
 pub use bash::BashTool;
 pub use calculator::CalculatorTool;
@@ -17,3 +18,4 @@ pub use http_request::HttpRequestTool;
 pub use registry::ToolRegistry;
 pub use schema::{CleaningStrategy, SchemaCleanr};
 pub use traits::{Tool, ToolResult};
+pub use web_fetch::WebFetchTool;
diff --git a/src/tools/web_fetch.rs b/src/tools/web_fetch.rs
new file mode 100644
index 0000000..3982bde
--- /dev/null
+++ b/src/tools/web_fetch.rs
@@ -0,0 +1,411 @@
+use std::time::Duration;
+
+use async_trait::async_trait;
+use reqwest::header::HeaderMap;
+use serde_json::json;
+
+use crate::tools::traits::{Tool, ToolResult};
+
+pub struct WebFetchTool {
+    max_response_size: usize,
+    timeout_secs: u64,
+    user_agent: String,
+}
+
+impl WebFetchTool {
+    pub fn new(max_response_size: usize, timeout_secs: u64) -> Self {
+        Self {
+            max_response_size,
+            timeout_secs,
+            user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(),
+        }
+    }
+
+    fn validate_url(&self, url: &str) -> Result<String, String> {
+        let url = url.trim();
+
+        if url.is_empty() {
+            return Err("URL cannot be empty".to_string());
+        }
+
+        if url.chars().any(char::is_whitespace) {
+            return Err("URL cannot contain whitespace".to_string());
+        }
+
+        if !url.starts_with("http://") && !url.starts_with("https://") {
+            return Err("Only http:// and https:// URLs are allowed".to_string());
+        }
+
+        let host = extract_host(url)?;
+
+        if is_private_host(&host) {
+            return Err(format!("Blocked local/private host: {}", host));
+        }
+
+        Ok(url.to_string())
+    }
+
+    fn truncate_response(&self, text: &str) -> String {
+        if self.max_response_size == 0 {
+            return text.to_string();
+        }
+
+        if text.len() > self.max_response_size {
+            format!(
+                "{}\n\n... [Response truncated due to size limit] ...",
+                &text[..self.max_response_size]
+            )
+        } else {
+            text.to_string()
+        }
+    }
+
+    async fn fetch_content(&self, url: &str) -> Result<String, String> {
+        let client = reqwest::Client::builder()
+            .timeout(Duration::from_secs(self.timeout_secs))
+            .build()
+            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
+
+        let mut headers = HeaderMap::new();
+        headers.insert(
+            reqwest::header::USER_AGENT,
+            self.user_agent.parse().unwrap(),
+        );
+
+        let response = client
+            .get(url)
+            .headers(headers)
+            .send()
+            .await
+            .map_err(|e| format!("Request failed: {}", e))?;
+
+        let content_type = response
+            .headers()
+            .get(reqwest::header::CONTENT_TYPE)
+            .and_then(|v| v.to_str().ok())
+            .unwrap_or("");
+
+        // Handle HTML content
+        if content_type.contains("text/html") {
+            let html = response
+                .text()
+                .await
+                .map_err(|e| format!("Failed to read response: {}", e))?;
+            return Ok(self.extract_text_from_html(&html));
+        }
+
+        // Handle JSON content
+        if content_type.contains("application/json") {
+            let text = response
+                .text()
+                .await
+                .map_err(|e| format!("Failed to read response: {}", e))?;
+            // Pretty print JSON
+            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
+                return Ok(serde_json::to_string_pretty(&parsed).unwrap_or(text));
+            }
+            return Ok(text);
+        }
+
+        // For other content types, return raw text
+        response
+            .text()
+            .await
+            .map_err(|e| format!("Failed to read response: {}", e))
+    }
+
+    fn extract_text_from_html(&self, html: &str) -> String {
+        let mut text = html.to_string();
+
+        // Remove script and style tags with content using simple replacements
+        text = strip_tag(&text, "script");
+        text = strip_tag(&text, "style");
+
+        // Remove all HTML tags
+        text = strip_all_tags(&text);
+
+        // Decode HTML entities
+        text = self.decode_html_entities(&text);
+
+        // Clean up whitespace
+        let mut cleaned = String::new();
+        let mut last_was_space = true;
+        for c in text.chars() {
+            if c.is_whitespace() {
+                if !last_was_space {
+                    cleaned.push(' ');
+                    last_was_space = true;
+                }
+            } else {
+                cleaned.push(c);
+                last_was_space = false;
+            }
+        }
+
+        cleaned.trim().to_string()
+    }
+
+    fn decode_html_entities(&self, text: &str) -> String {
+        let entities = [
+            ("&nbsp;", " "),
+            ("&lt;", "<"),
+            ("&gt;", ">"),
+            ("&amp;", "&"),
+            ("&quot;", "\""),
+            ("&#39;", "'"),
+            ("&mdash;", "—"),
+            ("&ndash;", "–"),
+            ("&copy;", "©"),
+            ("&reg;", "®"),
+            ("&trade;", "™"),
+        ];
+
+        let mut result = text.to_string();
+        for (entity, replacement) in entities {
+            result = result.replace(entity, replacement);
+        }
+
+        result
+    }
+}
+
+fn strip_tag(s: &str, tag_name: &str) -> String {
+    let open = format!("<{}>", tag_name);
+    let close = format!("</{}>", tag_name);
+    let mut result = s.to_string();
+
+    // Keep removing until no more found (simple approach)
+    while let Some(start) = result.to_lowercase().find(&open) {
+        if let Some(end) = result.to_lowercase()[start..].find(&close) {
+            let end_pos = start + end + close.len();
+            result = format!("{}{}", &result[..start], &result[end_pos..]);
+        } else {
+            break;
+        }
+    }
+
+    result
+}
+
+fn strip_all_tags(s: &str) -> String {
+    let mut result = String::new();
+    let mut in_tag = false;
+
+    for c in s.chars() {
+        if c == '<' {
+            in_tag = true;
+        } else if c == '>' {
+            in_tag = false;
+            result.push(' ');
+        } else if !in_tag {
+            result.push(c);
+        }
+    }
+
+    result
+}
+
+fn extract_html_entity(s: &str) -> Option<(char, usize)> {
+    let s_lower = s.to_lowercase();
+
+    let entities = [
+        ("&nbsp;", ' '),
+        ("&lt;", '<'),
+        ("&gt;", '>'),
+        ("&amp;", '&'),
+        ("&quot;", '"'),
+        ("&#39;", '\''),
+        ("&mdash;", '—'),
+        ("&ndash;", '–'),
+        ("&copy;", '©'),
+        ("&reg;", '®'),
+        ("&trade;", '™'),
+    ];
+
+    for (entity, replacement) in entities {
+        if s_lower.starts_with(&entity.to_lowercase()) {
+            return Some((replacement, entity.len()));
+        }
+    }
+
+    // Handle numeric entities
+    if s_lower.starts_with("&#x") || s_lower.starts_with("&#") {
+        // Skip for now
+    }
+
+    None
+}
+
+fn extract_host(url: &str) -> Result<String, String> {
+    let rest = url
+        .strip_prefix("http://")
+        .or_else(|| url.strip_prefix("https://"))
+        .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
+
+    let authority = rest
+        .split(['/', '?', '#'])
+        .next()
+        .ok_or_else(|| "Invalid URL".to_string())?;
+
+    if authority.is_empty() {
+        return Err("URL must include a host".to_string());
+    }
+
+    let host = authority
+        .split(':')
+        .next()
+        .unwrap_or_default()
+        .trim()
+        .to_lowercase();
+
+    if host.is_empty() {
+        return Err("URL must include a valid host".to_string());
+    }
+
+    Ok(host)
+}
+
+fn is_private_host(host: &str) -> bool {
+    if host == "localhost" || host.ends_with(".localhost") {
+        return true;
+    }
+
+    if host.rsplit('.').next().is_some_and(|label| label == "local") {
+        return true;
+    }
+
+    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
+        return match ip {
+            std::net::IpAddr::V4(v4) => {
+                v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
+            }
+            std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
+        };
+    }
+
+    false
+}
+
+#[async_trait]
+impl Tool for WebFetchTool {
+    fn name(&self) -> &str {
+        "web_fetch"
+    }
+
+    fn description(&self) -> &str {
+        "Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles."
+    }
+
+    fn parameters_schema(&self) -> serde_json::Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "url": {
+                    "type": "string",
+                    "description": "URL to fetch"
+                }
+            },
+            "required": ["url"]
+        })
+    }
+
+    fn read_only(&self) -> bool {
+        true
+    }
+
+    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
+        let url = match args.get("url").and_then(|v| v.as_str()) {
+            Some(u) => u,
+            None => {
+                return Ok(ToolResult {
+                    success: false,
+                    output: String::new(),
+                    error: Some("Missing required parameter: url".to_string()),
+                });
+            }
+        };
+
+        let url = match self.validate_url(url) {
+            Ok(u) => u,
+            Err(e) => {
+                return Ok(ToolResult {
+                    success: false,
+                    output: String::new(),
+                    error: Some(e),
+                });
+            }
+        };
+
+        match self.fetch_content(&url).await {
+            Ok(content) => Ok(ToolResult {
+                success: true,
+                output: self.truncate_response(&content),
+                error: None,
+            }),
+            Err(e) => Ok(ToolResult {
+                success: false,
+                output: String::new(),
+                error: Some(e),
+            }),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn test_tool() -> WebFetchTool {
+        WebFetchTool::new(50_000, 30)
+    }
+
+    #[tokio::test]
+    async fn test_validate_url_success() {
+        let tool = test_tool();
+        let result = tool.validate_url("https://example.com");
+        assert!(result.is_ok());
+    }
+
+    #[tokio::test]
+    async fn test_validate_url_rejects_private() {
+        let tool = test_tool();
+        let result = tool.validate_url("https://localhost:8080");
+        assert!(result.is_err());
+        assert!(result.unwrap_err().contains("local/private"));
+    }
+
+    #[tokio::test]
+    async fn test_validate_url_rejects_whitespace() {
+        let tool = test_tool();
+        let result = tool.validate_url("https://example.com/hello world");
+        assert!(result.is_err());
+        assert!(result.unwrap_err().contains("whitespace"));
+    }
+
+    #[tokio::test]
+    async fn test_extract_text_simple() {
+        let tool = test_tool();
+        let html = "<html><body><h1>Hello World</h1></body></html>";
+        let text = tool.extract_text_from_html(html);
+        assert!(text.contains("Hello World"));
+        assert!(!text.contains("<"));
+    }
+
+    #[tokio::test]
+    async fn test_extract_text_removes_scripts() {
+        let tool = test_tool();
+        let html = "<html><body><script>alert('hi');</script><p>Good</p></body></html>";
+        let text = tool.extract_text_from_html(html);
+        assert!(text.contains("Good"));
+        assert!(!text.contains("alert"));
+    }
+
+    #[tokio::test]
+    async fn test_extract_text_removes_styles() {
+        let tool = test_tool();
+        let html = "<html><head><style>body { color: red; }</style></head><body><p>Content</p></body></html>";
+        let text = tool.extract_text_from_html(html);
+        assert!(text.contains("Content"));
+        assert!(!text.contains("color"));
+    }
+}