use std::time::Duration; use async_trait::async_trait; use reqwest::header::HeaderMap; use serde_json::json; use crate::text::take_prefix_chars; use crate::tools::traits::{Tool, ToolResult}; pub struct WebFetchTool { max_response_size: usize, timeout_secs: u64, user_agent: String, } impl WebFetchTool { pub fn new(max_response_size: usize, timeout_secs: u64) -> Self { Self { max_response_size, timeout_secs, user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(), } } fn validate_url(&self, url: &str) -> Result { let url = url.trim(); if url.is_empty() { return Err("URL cannot be empty".to_string()); } if url.chars().any(char::is_whitespace) { return Err("URL cannot contain whitespace".to_string()); } if !url.starts_with("http://") && !url.starts_with("https://") { return Err("Only http:// and https:// URLs are allowed".to_string()); } let host = extract_host(url)?; if is_private_host(&host) { return Err(format!("Blocked local/private host: {}", host)); } Ok(url.to_string()) } fn truncate_response(&self, text: &str) -> String { if self.max_response_size == 0 { return text.to_string(); } if text.chars().count() > self.max_response_size { format!( "{}\n\n... [Response truncated due to size limit] ...", take_prefix_chars(text, self.max_response_size) ) } else { text.to_string() } } async fn fetch_content(&self, url: &str) -> Result { let client = reqwest::Client::builder() .timeout(Duration::from_secs(self.timeout_secs)) .build() .map_err(|e| format!("Failed to create HTTP client: {}", e))?; let mut headers = HeaderMap::new(); headers.insert( reqwest::header::USER_AGENT, self.user_agent.parse().unwrap(), ); let response = client .get(url) .headers(headers) .send() .await .map_err(|e| format!("Request failed: {}", e))?; let content_type = response .headers() .get(reqwest::header::CONTENT_TYPE) .and_then(|v| v.to_str().ok()) .unwrap_or(""); // Handle HTML content if content_type.contains("text/html") { let html = response .text() .await .map_err(|e| format!("Failed to read response: {}", e))?; return Ok(self.extract_text_from_html(&html)); } // Handle JSON content if content_type.contains("application/json") { let text = response .text() .await .map_err(|e| format!("Failed to read response: {}", e))?; // Pretty print JSON if let Ok(parsed) = serde_json::from_str::(&text) { return Ok(serde_json::to_string_pretty(&parsed).unwrap_or(text)); } return Ok(text); } // For other content types, return raw text response .text() .await .map_err(|e| format!("Failed to read response: {}", e)) } fn extract_text_from_html(&self, html: &str) -> String { let mut text = html.to_string(); // Remove script and style tags with content using simple replacements text = strip_tag(&text, "script"); text = strip_tag(&text, "style"); // Remove all HTML tags text = strip_all_tags(&text); // Decode HTML entities text = self.decode_html_entities(&text); // Clean up whitespace let mut cleaned = String::new(); let mut last_was_space = true; for c in text.chars() { if c.is_whitespace() { if !last_was_space { cleaned.push(' '); last_was_space = true; } } else { cleaned.push(c); last_was_space = false; } } cleaned.trim().to_string() } fn decode_html_entities(&self, text: &str) -> String { let entities = [ (" ", " "), ("<", "<"), (">", ">"), ("&", "&"), (""", "\""), ("'", "'"), ("—", "—"), ("–", "–"), ("©", "©"), ("®", "®"), ("™", "™"), ]; let mut result = text.to_string(); for (entity, replacement) in entities { result = result.replace(entity, replacement); } result } } fn strip_tag(s: &str, tag_name: &str) -> String { let open = format!("<{}>", tag_name); let close = format!("", tag_name); let mut result = s.to_string(); // Keep removing until no more found (simple approach) while let Some(start) = result.to_lowercase().find(&open) { if let Some(end) = result.to_lowercase()[start..].find(&close) { let end_pos = start + end + close.len(); result = format!("{}{}", &result[..start], &result[end_pos..]); } else { break; } } result } fn strip_all_tags(s: &str) -> String { let mut result = String::new(); let mut in_tag = false; for c in s.chars() { if c == '<' { in_tag = true; } else if c == '>' { in_tag = false; result.push(' '); } else if !in_tag { result.push(c); } } result } fn extract_host(url: &str) -> Result { let rest = url .strip_prefix("http://") .or_else(|| url.strip_prefix("https://")) .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?; let authority = rest .split(['/', '?', '#']) .next() .ok_or_else(|| "Invalid URL".to_string())?; if authority.is_empty() { return Err("URL must include a host".to_string()); } let host = authority .split(':') .next() .unwrap_or_default() .trim() .to_lowercase(); if host.is_empty() { return Err("URL must include a valid host".to_string()); } Ok(host) } fn is_private_host(host: &str) -> bool { if host == "localhost" || host.ends_with(".localhost") { return true; } if host .rsplit('.') .next() .is_some_and(|label| label == "local") { return true; } if let Ok(ip) = host.parse::() { return match ip { std::net::IpAddr::V4(v4) => { v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() } std::net::IpAddr::V6(v6) => { v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() } }; } false } #[async_trait] impl Tool for WebFetchTool { fn name(&self) -> &str { "web_fetch" } fn description(&self) -> &str { "Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles." } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "url": { "type": "string", "description": "URL to fetch" } }, "required": ["url"] }) } fn read_only(&self) -> bool { true } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let url = match args.get("url").and_then(|v| v.as_str()) { Some(u) => u, None => { return Ok(ToolResult { success: false, output: String::new(), error: Some("Missing required parameter: url".to_string()), }); } }; let url = match self.validate_url(url) { Ok(u) => u, Err(e) => { return Ok(ToolResult { success: false, output: String::new(), error: Some(e), }); } }; match self.fetch_content(&url).await { Ok(content) => Ok(ToolResult { success: true, output: self.truncate_response(&content), error: None, }), Err(e) => Ok(ToolResult { success: false, output: String::new(), error: Some(e), }), } } } #[cfg(test)] mod tests { use super::*; fn test_tool() -> WebFetchTool { WebFetchTool::new(50_000, 30) } #[tokio::test] async fn test_validate_url_success() { let tool = test_tool(); let result = tool.validate_url("https://example.com"); assert!(result.is_ok()); } #[tokio::test] async fn test_validate_url_rejects_private() { let tool = test_tool(); let result = tool.validate_url("https://localhost:8080"); assert!(result.is_err()); assert!(result.unwrap_err().contains("local/private")); } #[tokio::test] async fn test_validate_url_rejects_whitespace() { let tool = test_tool(); let result = tool.validate_url("https://example.com/hello world"); assert!(result.is_err()); assert!(result.unwrap_err().contains("whitespace")); } #[tokio::test] async fn test_extract_text_simple() { let tool = test_tool(); let html = "

Hello World

"; let text = tool.extract_text_from_html(html); assert!(text.contains("Hello World")); assert!(!text.contains("<")); } #[tokio::test] async fn test_extract_text_removes_scripts() { let tool = test_tool(); let html = "

Good

"; let text = tool.extract_text_from_html(html); assert!(text.contains("Good")); assert!(!text.contains("alert")); } #[tokio::test] async fn test_extract_text_removes_styles() { let tool = test_tool(); let html = "

Content

"; let text = tool.extract_text_from_html(html); assert!(text.contains("Content")); assert!(!text.contains("color")); } #[tokio::test] async fn test_truncate_response_handles_multibyte_boundary() { let tool = WebFetchTool::new(3, 30); let text = "a\u{1F642}bc"; let truncated = tool.truncate_response(text); assert_eq!( truncated, "a\u{1F642}b\n\n... [Response truncated due to size limit] ..." ); } }