399 lines
11 KiB
Rust
399 lines
11 KiB
Rust
use std::time::Duration;
|
||
|
||
use async_trait::async_trait;
|
||
use reqwest::header::HeaderMap;
|
||
use serde_json::json;
|
||
|
||
use crate::text::take_prefix_chars;
|
||
use crate::tools::traits::{Tool, ToolResult};
|
||
|
||
/// Tool that downloads a web page over HTTP(S) and returns its readable text.
#[derive(Debug, Clone)]
pub struct WebFetchTool {
    /// Maximum number of characters returned to the caller; `0` disables truncation.
    max_response_size: usize,
    /// Overall request timeout, in seconds.
    timeout_secs: u64,
    /// Value sent in the `User-Agent` request header.
    user_agent: String,
}
|
||
|
||
impl WebFetchTool {
|
||
pub fn new(max_response_size: usize, timeout_secs: u64) -> Self {
|
||
Self {
|
||
max_response_size,
|
||
timeout_secs,
|
||
user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(),
|
||
}
|
||
}
|
||
|
||
fn validate_url(&self, url: &str) -> Result<String, String> {
|
||
let url = url.trim();
|
||
|
||
if url.is_empty() {
|
||
return Err("URL cannot be empty".to_string());
|
||
}
|
||
|
||
if url.chars().any(char::is_whitespace) {
|
||
return Err("URL cannot contain whitespace".to_string());
|
||
}
|
||
|
||
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||
}
|
||
|
||
let host = extract_host(url)?;
|
||
|
||
if is_private_host(&host) {
|
||
return Err(format!("Blocked local/private host: {}", host));
|
||
}
|
||
|
||
Ok(url.to_string())
|
||
}
|
||
|
||
fn truncate_response(&self, text: &str) -> String {
|
||
if self.max_response_size == 0 {
|
||
return text.to_string();
|
||
}
|
||
|
||
if text.chars().count() > self.max_response_size {
|
||
format!(
|
||
"{}\n\n... [Response truncated due to size limit] ...",
|
||
take_prefix_chars(text, self.max_response_size)
|
||
)
|
||
} else {
|
||
text.to_string()
|
||
}
|
||
}
|
||
|
||
async fn fetch_content(&self, url: &str) -> Result<String, String> {
|
||
let client = reqwest::Client::builder()
|
||
.timeout(Duration::from_secs(self.timeout_secs))
|
||
.build()
|
||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||
|
||
let mut headers = HeaderMap::new();
|
||
headers.insert(
|
||
reqwest::header::USER_AGENT,
|
||
self.user_agent.parse().unwrap(),
|
||
);
|
||
|
||
let response = client
|
||
.get(url)
|
||
.headers(headers)
|
||
.send()
|
||
.await
|
||
.map_err(|e| format!("Request failed: {}", e))?;
|
||
|
||
let content_type = response
|
||
.headers()
|
||
.get(reqwest::header::CONTENT_TYPE)
|
||
.and_then(|v| v.to_str().ok())
|
||
.unwrap_or("");
|
||
|
||
// Handle HTML content
|
||
if content_type.contains("text/html") {
|
||
let html = response
|
||
.text()
|
||
.await
|
||
.map_err(|e| format!("Failed to read response: {}", e))?;
|
||
return Ok(self.extract_text_from_html(&html));
|
||
}
|
||
|
||
// Handle JSON content
|
||
if content_type.contains("application/json") {
|
||
let text = response
|
||
.text()
|
||
.await
|
||
.map_err(|e| format!("Failed to read response: {}", e))?;
|
||
// Pretty print JSON
|
||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
|
||
return Ok(serde_json::to_string_pretty(&parsed).unwrap_or(text));
|
||
}
|
||
return Ok(text);
|
||
}
|
||
|
||
// For other content types, return raw text
|
||
response
|
||
.text()
|
||
.await
|
||
.map_err(|e| format!("Failed to read response: {}", e))
|
||
}
|
||
|
||
fn extract_text_from_html(&self, html: &str) -> String {
|
||
let mut text = html.to_string();
|
||
|
||
// Remove script and style tags with content using simple replacements
|
||
text = strip_tag(&text, "script");
|
||
text = strip_tag(&text, "style");
|
||
|
||
// Remove all HTML tags
|
||
text = strip_all_tags(&text);
|
||
|
||
// Decode HTML entities
|
||
text = self.decode_html_entities(&text);
|
||
|
||
// Clean up whitespace
|
||
let mut cleaned = String::new();
|
||
let mut last_was_space = true;
|
||
for c in text.chars() {
|
||
if c.is_whitespace() {
|
||
if !last_was_space {
|
||
cleaned.push(' ');
|
||
last_was_space = true;
|
||
}
|
||
} else {
|
||
cleaned.push(c);
|
||
last_was_space = false;
|
||
}
|
||
}
|
||
|
||
cleaned.trim().to_string()
|
||
}
|
||
|
||
fn decode_html_entities(&self, text: &str) -> String {
|
||
let entities = [
|
||
(" ", " "),
|
||
("<", "<"),
|
||
(">", ">"),
|
||
("&", "&"),
|
||
(""", "\""),
|
||
("'", "'"),
|
||
("—", "—"),
|
||
("–", "–"),
|
||
("©", "©"),
|
||
("®", "®"),
|
||
("™", "™"),
|
||
];
|
||
|
||
let mut result = text.to_string();
|
||
for (entity, replacement) in entities {
|
||
result = result.replace(entity, replacement);
|
||
}
|
||
|
||
result
|
||
}
|
||
}
|
||
|
||
/// Removes every `<tag_name ...>...</tag_name>` element, including its
/// content, from `s`. The tag name is matched ASCII-case-insensitively.
///
/// Opening tags may carry attributes (`<script src="...">`). An element with
/// no closing tag is treated as running to the end of the input, which is how
/// browsers handle an unterminated `<script>`/`<style>`.
fn strip_tag(s: &str, tag_name: &str) -> String {
    // ASCII lowercasing never changes byte lengths, so every offset found in
    // `lower` is also a valid char boundary in `s`. (Unicode `to_lowercase`
    // can change lengths, e.g. for 'İ', and would desynchronize the two.)
    let lower = s.to_ascii_lowercase();
    let tag = tag_name.to_ascii_lowercase();
    let open = format!("<{}", tag);
    let close = format!("</{}", tag);

    let mut result = String::with_capacity(s.len());
    let mut cursor = 0;
    while let Some(start) = find_tag_start(&lower, &open, cursor) {
        result.push_str(&s[cursor..start]);
        // Resume just past the closing tag's '>', or drop the rest if unclosed.
        cursor = find_tag_start(&lower, &close, start + open.len())
            .and_then(|close_start| {
                lower[close_start..]
                    .find('>')
                    .map(|gt| close_start + gt + 1)
            })
            .unwrap_or(s.len());
    }
    result.push_str(&s[cursor..]);
    result
}

/// Finds the first `needle` (e.g. `<script` or `</script`) at or after byte
/// offset `from` that is followed by a tag-name boundary (`>`, `/`, ASCII
/// whitespace, or end of input), so `<script` does not match `<scripts>`.
fn find_tag_start(haystack: &str, needle: &str, from: usize) -> Option<usize> {
    let mut search_from = from;
    while let Some(offset) = haystack.get(search_from..)?.find(needle) {
        let idx = search_from + offset;
        let end = idx + needle.len();
        match haystack.as_bytes().get(end) {
            None | Some(b'>' | b'/') => return Some(idx),
            Some(b) if b.is_ascii_whitespace() => return Some(idx),
            Some(_) => search_from = end,
        }
    }
    None
}
|
||
|
||
/// Removes markup from `s`, replacing each tag with a single space so words
/// on either side of a tag stay separated.
///
/// Following the HTML tokenizer's "tag open" rule, `<` only starts a tag when
/// it is followed by an ASCII letter, `/`, `!` or `?`. Any other `<` (as in
/// `5 < 6`) and any `>` outside a tag are kept as ordinary text.
fn strip_all_tags(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut in_tag = false;
    let mut chars = s.chars().peekable();

    while let Some(c) = chars.next() {
        if in_tag {
            if c == '>' {
                in_tag = false;
                result.push(' ');
            }
        } else if c == '<'
            && chars
                .peek()
                .is_some_and(|&next| next.is_ascii_alphabetic() || matches!(next, '/' | '!' | '?'))
        {
            in_tag = true;
        } else {
            result.push(c);
        }
    }

    result
}
|
||
|
||
/// Extracts the lowercase host from an `http://` or `https://` URL.
///
/// This is a lightweight string parse used for early validation. It splits
/// the authority the way WHATWG URL parsers do for http(s):
/// - `\` acts as a path separator;
/// - any `user:pass@` userinfo is dropped (the host follows the last `@`);
/// - bracketed IPv6 literals are unwrapped (`[::1]:8080` -> `::1`);
/// - a trailing root dot is removed (`localhost.` -> `localhost`).
///
/// # Errors
/// Returns an error when the scheme is not http/https or no host is present.
fn extract_host(url: &str) -> Result<String, String> {
    let rest = url
        .strip_prefix("http://")
        .or_else(|| url.strip_prefix("https://"))
        .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;

    // `split` always yields at least one item.
    let authority = rest
        .split(['/', '\\', '?', '#'])
        .next()
        .unwrap_or_default();
    if authority.is_empty() {
        return Err("URL must include a host".to_string());
    }

    // Userinfo ends at the last '@'; what follows is host[:port].
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port);

    let host = match host_port.strip_prefix('[') {
        // IPv6 literal: the port (if any) follows the closing bracket.
        Some(bracketed) => bracketed.split(']').next().unwrap_or_default(),
        None => host_port.split(':').next().unwrap_or_default(),
    };

    let host = host.trim().trim_end_matches('.').to_lowercase();
    if host.is_empty() {
        return Err("URL must include a valid host".to_string());
    }

    Ok(host)
}
|
||
|
||
/// Returns `true` when `host` names a loopback, private, link-local, or
/// otherwise non-public destination that the fetch tool must not contact.
///
/// Accepts bare hosts as well as bracketed IPv6 literals and names with a
/// trailing root dot, and compares case-insensitively. Only the literal host
/// string is inspected; a public-looking domain that *resolves* to a private
/// address is not detected here.
fn is_private_host(host: &str) -> bool {
    use std::net::{IpAddr, Ipv4Addr};

    /// Non-public IPv4 ranges (also applied to IPv4-mapped IPv6 addresses).
    fn is_private_v4(ip: Ipv4Addr) -> bool {
        let [first, second, ..] = ip.octets();
        ip.is_loopback()
            || ip.is_private()
            || ip.is_link_local()
            || ip.is_unspecified()
            || ip.is_broadcast()
            || ip.is_multicast()
            || first == 0 // 0.0.0.0/8 "this network"
            || (first == 100 && (second & 0xC0) == 64) // 100.64.0.0/10 shared (CGNAT)
    }

    let host = host
        .trim_start_matches('[')
        .trim_end_matches(']')
        .trim_end_matches('.')
        .to_ascii_lowercase();

    // `localhost`, `*.localhost`, `local` and `*.local` (mDNS) by final label.
    if matches!(host.rsplit('.').next(), Some("localhost" | "local")) {
        return true;
    }

    match host.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => is_private_v4(v4),
        Ok(IpAddr::V6(v6)) => {
            let first_segment = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || (first_segment & 0xFE00) == 0xFC00 // fc00::/7 unique local
                || (first_segment & 0xFFC0) == 0xFE80 // fe80::/10 link-local
                || v6.to_ipv4_mapped().is_some_and(is_private_v4)
        }
        Err(_) => false,
    }
}
|
||
|
||
#[async_trait]
|
||
impl Tool for WebFetchTool {
|
||
fn name(&self) -> &str {
|
||
"web_fetch"
|
||
}
|
||
|
||
fn description(&self) -> &str {
|
||
"Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles."
|
||
}
|
||
|
||
fn parameters_schema(&self) -> serde_json::Value {
|
||
json!({
|
||
"type": "object",
|
||
"properties": {
|
||
"url": {
|
||
"type": "string",
|
||
"description": "URL to fetch"
|
||
}
|
||
},
|
||
"required": ["url"]
|
||
})
|
||
}
|
||
|
||
fn read_only(&self) -> bool {
|
||
true
|
||
}
|
||
|
||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||
let url = match args.get("url").and_then(|v| v.as_str()) {
|
||
Some(u) => u,
|
||
None => {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some("Missing required parameter: url".to_string()),
|
||
});
|
||
}
|
||
};
|
||
|
||
let url = match self.validate_url(url) {
|
||
Ok(u) => u,
|
||
Err(e) => {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some(e),
|
||
});
|
||
}
|
||
};
|
||
|
||
match self.fetch_content(&url).await {
|
||
Ok(content) => Ok(ToolResult {
|
||
success: true,
|
||
output: self.truncate_response(&content),
|
||
error: None,
|
||
}),
|
||
Err(e) => Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some(e),
|
||
}),
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    fn test_tool() -> WebFetchTool {
        WebFetchTool::new(50_000, 30)
    }

    // Validation, extraction and truncation are synchronous, so plain #[test]
    // suffices; only the `execute` tests need an async runtime.

    #[test]
    fn test_validate_url_success() {
        assert!(test_tool().validate_url("https://example.com").is_ok());
    }

    #[test]
    fn test_validate_url_rejects_private() {
        let err = test_tool()
            .validate_url("https://localhost:8080")
            .unwrap_err();
        assert!(err.contains("local/private"));
    }

    #[test]
    fn test_validate_url_rejects_private_ipv4() {
        let err = test_tool()
            .validate_url("http://192.168.1.1/admin")
            .unwrap_err();
        assert!(err.contains("local/private"));
    }

    #[test]
    fn test_validate_url_rejects_whitespace() {
        let err = test_tool()
            .validate_url("https://example.com/hello world")
            .unwrap_err();
        assert!(err.contains("whitespace"));
    }

    #[test]
    fn test_validate_url_rejects_empty() {
        let err = test_tool().validate_url("   ").unwrap_err();
        assert!(err.contains("empty"));
    }

    #[test]
    fn test_validate_url_rejects_non_http_scheme() {
        let err = test_tool()
            .validate_url("ftp://example.com/file")
            .unwrap_err();
        assert!(err.contains("http://"));
    }

    #[test]
    fn test_extract_text_simple() {
        let text = test_tool().extract_text_from_html("<html><body><p>Hello World</p></body></html>");
        assert!(text.contains("Hello World"));
        assert!(!text.contains('<'));
    }

    #[test]
    fn test_extract_text_removes_scripts() {
        let html = "<html><body><script>alert('bad');</script><p>Good</p></body></html>";
        let text = test_tool().extract_text_from_html(html);
        assert!(text.contains("Good"));
        assert!(!text.contains("alert"));
    }

    #[test]
    fn test_extract_text_removes_styles() {
        let html = "<html><head><style>.x { color: red; }</style></head><body><p>Content</p></body></html>";
        let text = test_tool().extract_text_from_html(html);
        assert!(text.contains("Content"));
        assert!(!text.contains("color"));
    }

    #[test]
    fn test_extract_text_collapses_whitespace() {
        let text = test_tool().extract_text_from_html("<p>  Hello \n\n  World </p>");
        assert_eq!(text, "Hello World");
    }

    #[test]
    fn test_truncate_response_handles_multibyte_boundary() {
        let tool = WebFetchTool::new(3, 30);
        assert_eq!(
            tool.truncate_response("a\u{1F642}bc"),
            "a\u{1F642}b\n\n... [Response truncated due to size limit] ..."
        );
    }

    #[test]
    fn test_truncate_response_zero_disables_limit() {
        let tool = WebFetchTool::new(0, 30);
        assert_eq!(tool.truncate_response("abcdef"), "abcdef");
    }

    #[test]
    fn test_truncate_response_short_text_unchanged() {
        let tool = WebFetchTool::new(10, 30);
        assert_eq!(tool.truncate_response("abc"), "abc");
    }

    #[tokio::test]
    async fn test_execute_reports_missing_url() {
        let result = test_tool().execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(
            result.error.as_deref(),
            Some("Missing required parameter: url")
        );
    }

    #[tokio::test]
    async fn test_execute_rejects_private_url_without_fetching() {
        let result = test_tool()
            .execute(json!({ "url": "http://127.0.0.1/" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert!(result.error.unwrap_or_default().contains("local/private"));
    }
}
|