PicoBot/src/tools/web_fetch.rs
ooodc 73dab09bfe Refactor code for improved readability and consistency
- Adjusted formatting and indentation in various files for better clarity.
- Consolidated multi-line statements into single lines where appropriate.
- Enhanced error handling messages for better debugging.
- Added a new InboundProcessor struct to handle inbound messages more effectively.
- Updated test cases to ensure they align with the new code structure.
2026-04-28 10:33:31 +08:00

387 lines
11 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
/// Tool that downloads a URL and returns its content as text.
///
/// Derives `Debug` and `Clone` so the tool can be logged and duplicated like
/// any other plain configuration value.
#[derive(Debug, Clone)]
pub struct WebFetchTool {
    /// Upper bound, in bytes, on the text returned to the caller; `0` disables truncation.
    max_response_size: usize,
    /// Request timeout, in seconds, handed to the HTTP client.
    timeout_secs: u64,
    /// Value sent in the `User-Agent` request header.
    user_agent: String,
}
impl WebFetchTool {
    /// Creates a tool with the given response-size cap and request timeout.
    ///
    /// * `max_response_size` - maximum number of bytes of fetched text kept by
    ///   `truncate_response`; `0` disables truncation.
    /// * `timeout_secs` - request timeout, in seconds, passed to the HTTP client.
    pub fn new(max_response_size: usize, timeout_secs: u64) -> Self {
        Self {
            max_response_size,
            timeout_secs,
            user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(),
        }
    }

    /// Checks that `url` is a non-empty, whitespace-free http(s) URL whose host
    /// is not classified as local/private.
    ///
    /// Returns the trimmed URL on success, or a human-readable reason on failure.
    fn validate_url(&self, url: &str) -> Result<String, String> {
        let url = url.trim();
        if url.is_empty() {
            return Err("URL cannot be empty".to_string());
        }
        if url.chars().any(char::is_whitespace) {
            return Err("URL cannot contain whitespace".to_string());
        }
        if !url.starts_with("http://") && !url.starts_with("https://") {
            return Err("Only http:// and https:// URLs are allowed".to_string());
        }
        let host = extract_host(url)?;
        if is_private_host(&host) {
            return Err(format!("Blocked local/private host: {}", host));
        }
        Ok(url.to_string())
    }

    /// Limits `text` to at most `max_response_size` bytes and appends a
    /// truncation notice when anything was cut off.
    ///
    /// A limit of `0` means "no limit". The cut position is moved back to the
    /// nearest UTF-8 character boundary, because slicing a `str` in the middle
    /// of a multi-byte character panics.
    fn truncate_response(&self, text: &str) -> String {
        if self.max_response_size == 0 || text.len() <= self.max_response_size {
            return text.to_string();
        }
        let mut cut = self.max_response_size;
        // Index 0 is always a boundary, so this loop terminates.
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!(
            "{}\n\n... [Response truncated due to size limit] ...",
            &text[..cut]
        )
    }

    /// Downloads `url` and converts the body to text.
    ///
    /// HTML bodies are reduced to plain text, JSON bodies are pretty-printed
    /// when they parse, and every other content type is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when the client cannot be built, the
    /// request fails (including a blocked or excessive redirect), or the body
    /// cannot be read.
    async fn fetch_content(&self, url: &str) -> Result<String, String> {
        // `validate_url` only inspects the initial URL. Re-apply the host check
        // on every redirect hop so that a public URL cannot bounce the request
        // to a local/private address. Hosts that merely *resolve* to a private
        // address are not detected here.
        let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
            if attempt.previous().len() >= 10 {
                return attempt.error("too many redirects");
            }
            let blocked = match attempt.url().host_str() {
                Some(host) => {
                    // IPv6 literals are reported in brackets; strip them so the
                    // address can be parsed by the host check.
                    let host = host
                        .trim_start_matches('[')
                        .trim_end_matches(']')
                        .to_lowercase();
                    is_private_host(&host)
                }
                None => true,
            };
            if blocked {
                attempt.error("redirect to local/private host blocked")
            } else {
                attempt.follow()
            }
        });

        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(self.timeout_secs))
            .redirect(redirect_policy)
            .build()
            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

        // Report an unusable User-Agent as an error instead of panicking.
        let user_agent = reqwest::header::HeaderValue::from_str(&self.user_agent)
            .map_err(|e| format!("Invalid User-Agent header: {}", e))?;
        let mut headers = HeaderMap::new();
        headers.insert(reqwest::header::USER_AGENT, user_agent);

        let response = client
            .get(url)
            .headers(headers)
            .send()
            .await
            .map_err(|e| format!("Request failed: {}", e))?;

        // Media types are case-insensitive, so compare in lowercase. Taking an
        // owned copy also releases the borrow on `response` before it is consumed.
        let content_type = response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .map(str::to_ascii_lowercase)
            .unwrap_or_default();

        let body = response
            .text()
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if content_type.contains("text/html") {
            return Ok(self.extract_text_from_html(&body));
        }
        if content_type.contains("application/json") {
            // Pretty-print valid JSON; fall back to the raw body otherwise.
            return Ok(match serde_json::from_str::<serde_json::Value>(&body) {
                Ok(parsed) => serde_json::to_string_pretty(&parsed).unwrap_or(body),
                Err(_) => body,
            });
        }
        // Any other content type is returned as raw text.
        Ok(body)
    }

    /// Converts an HTML document to plain text.
    ///
    /// Script and style elements are removed first, then all remaining tags,
    /// then entities are decoded, and finally every run of whitespace is
    /// collapsed to a single space with leading/trailing whitespace removed.
    fn extract_text_from_html(&self, html: &str) -> String {
        let without_scripts = strip_tag(html, "script");
        let without_styles = strip_tag(&without_scripts, "style");
        let untagged = strip_all_tags(&without_styles);
        let decoded = self.decode_html_entities(&untagged);
        decoded.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    /// Replaces a small set of named HTML entities with their characters.
    ///
    /// `&amp;` is decoded last: decoding it earlier would turn escaped text
    /// such as `&amp;lt;` into `&lt;` and then, on a later pass, into `<`.
    fn decode_html_entities(&self, text: &str) -> String {
        const ENTITIES: [(&str, &str); 11] = [
            ("&nbsp;", " "),
            ("&lt;", "<"),
            ("&gt;", ">"),
            ("&quot;", "\""),
            ("&apos;", "'"),
            ("&mdash;", "\u{2014}"),
            ("&ndash;", "\u{2013}"),
            ("&copy;", "\u{00A9}"),
            ("&reg;", "\u{00AE}"),
            ("&trade;", "\u{2122}"),
            ("&amp;", "&"),
        ];
        // Nothing to decode without an ampersand.
        if !text.contains('&') {
            return text.to_string();
        }
        let mut result = text.to_string();
        for (entity, replacement) in ENTITIES {
            result = result.replace(entity, replacement);
        }
        result
    }
}
/// Removes every `<tag_name ...> ... </tag_name>` element, including its
/// content, from `s`.
///
/// Matching is ASCII case-insensitive and accepts attributes on the opening
/// tag (for example `<script type="text/javascript">`). A longer tag name that
/// merely starts with `tag_name` (for example `<scripted>`) is left alone. An
/// opening tag with no matching closing tag is left in place together with
/// everything after it.
fn strip_tag(s: &str, tag_name: &str) -> String {
    // ASCII lowercasing never changes byte lengths, so offsets found in
    // `lower` are valid offsets into `s`.
    let lower = s.to_ascii_lowercase();
    let name = tag_name.to_ascii_lowercase();
    let open = format!("<{}", name);
    let close = format!("</{}", name);

    let mut result = String::with_capacity(s.len());
    let mut pos = 0;
    while let Some(rel) = lower[pos..].find(&open) {
        let start = pos + rel;
        let after_name = start + open.len();
        // The tag name must end here; otherwise this is a different element.
        let is_tag = lower[after_name..]
            .chars()
            .next()
            .is_some_and(|c| c == '>' || c == '/' || c.is_ascii_whitespace());
        if !is_tag {
            result.push_str(&s[pos..after_name]);
            pos = after_name;
            continue;
        }
        let Some(close_rel) = lower[after_name..].find(&close) else {
            break;
        };
        let close_start = after_name + close_rel;
        let Some(gt_rel) = lower[close_start..].find('>') else {
            break;
        };
        // Keep the text before the element and resume after its closing `>`.
        result.push_str(&s[pos..start]);
        pos = close_start + gt_rel + 1;
    }
    result.push_str(&s[pos..]);
    result
}
/// Drops everything between `<` and `>` and emits one space in place of each
/// `>`, so text from adjacent elements does not run together.
///
/// This is a character-level scan, not an HTML parser: a `>` outside any tag
/// is also turned into a space, and a stray `<` suppresses output until the
/// next `>`.
fn strip_all_tags(s: &str) -> String {
    let mut inside_markup = false;
    let mut plain = String::new();
    for ch in s.chars() {
        match ch {
            '<' => inside_markup = true,
            '>' => {
                inside_markup = false;
                plain.push(' ');
            }
            _ if inside_markup => {}
            _ => plain.push(ch),
        }
    }
    plain
}
/// Extracts the lowercase host from an `http://` or `https://` URL.
///
/// The authority ends at the first `/`, `?`, `#` or `\` (URL parsers treat a
/// backslash like a slash for http(s) URLs). Any `userinfo@` prefix and any
/// `:port` suffix are discarded. A bracketed IPv6 literal is returned without
/// its brackets so it can be parsed as an IP address.
///
/// # Errors
///
/// Returns a message when the scheme is not http(s), the authority or host is
/// empty, or an IPv6 literal is missing its closing bracket.
fn extract_host(url: &str) -> Result<String, String> {
    let rest = url
        .strip_prefix("http://")
        .or_else(|| url.strip_prefix("https://"))
        .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
    let authority = rest
        .split(['/', '?', '#', '\\'])
        .next()
        .ok_or_else(|| "Invalid URL".to_string())?;
    if authority.is_empty() {
        return Err("URL must include a host".to_string());
    }
    // Everything up to the last '@' is userinfo, not the host. Without this,
    // `http://example.com@127.0.0.1/` would be checked as `example.com...`.
    let host_port = authority.rsplit('@').next().unwrap_or_default();
    let host = match host_port.strip_prefix('[') {
        // IPv6 literal: the host is the text between the brackets.
        Some(bracketed) => match bracketed.split_once(']') {
            Some((address, _port)) => address,
            None => return Err("URL must include a valid host".to_string()),
        },
        None => host_port.split(':').next().unwrap_or_default(),
    };
    let host = host.trim().to_lowercase();
    if host.is_empty() {
        return Err("URL must include a valid host".to_string());
    }
    Ok(host)
}
/// Returns `true` when `host` names a local or private destination.
///
/// Blocked names: `localhost`, any `*.localhost`, and any name whose last
/// label is `local`, with or without a trailing dot.
///
/// Blocked IPv4 addresses: loopback, private, link-local, unspecified,
/// broadcast and multicast.
///
/// Blocked IPv6 addresses: loopback, unspecified, multicast, unique-local
/// (`fc00::/7`), link-local (`fe80::/10`), and IPv4-mapped addresses whose
/// embedded IPv4 address is blocked.
///
/// `host` is compared as given; callers are expected to pass it lowercased
/// and, for IPv6, without brackets. Names that resolve to a private address
/// through DNS are not detected.
fn is_private_host(host: &str) -> bool {
    /// IPv4 ranges that must never be fetched.
    fn is_private_ipv4(v4: std::net::Ipv4Addr) -> bool {
        v4.is_loopback()
            || v4.is_private()
            || v4.is_link_local()
            || v4.is_unspecified()
            || v4.is_broadcast()
            || v4.is_multicast()
    }

    // A trailing dot only marks the DNS root; `localhost.` is still localhost.
    let host = host.strip_suffix('.').unwrap_or(host);

    if host == "localhost" || host.ends_with(".localhost") {
        return true;
    }
    if host
        .rsplit('.')
        .next()
        .is_some_and(|label| label == "local")
    {
        return true;
    }
    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => is_private_ipv4(v4),
        Ok(std::net::IpAddr::V6(v6)) => {
            // `::ffff:a.b.c.d` reaches the IPv4 address it embeds.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ipv4(mapped);
            }
            let first_segment = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || (first_segment & 0xfe00) == 0xfc00 // unique-local, fc00::/7
                || (first_segment & 0xffc0) == 0xfe80 // link-local, fe80::/10
        }
        Err(_) => false,
    }
}
/// Exposes [`WebFetchTool`] through the `Tool` trait under the name `web_fetch`.
#[async_trait]
impl Tool for WebFetchTool {
    /// Identifier under which the tool is registered.
    fn name(&self) -> &str {
        "web_fetch"
    }

    /// Human-readable summary of what the tool does.
    fn description(&self) -> &str {
        "Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles."
    }

    /// JSON schema of the accepted arguments: an object with one required
    /// string property, `url`.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "URL to fetch"
                }
            },
            "required": ["url"]
        })
    }

    /// The tool reports itself as read-only.
    fn read_only(&self) -> bool {
        true
    }

    /// Validates the `url` argument, fetches it, and returns the (possibly
    /// truncated) text.
    ///
    /// Every failure (missing argument, rejected URL, failed fetch) is
    /// reported as `Ok` with `success: false` and the reason in `error`; this
    /// method itself never returns `Err`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Shape shared by all unsuccessful outcomes.
        let failure = |reason: String| ToolResult {
            success: false,
            output: String::new(),
            error: Some(reason),
        };

        let Some(requested) = args.get("url").and_then(serde_json::Value::as_str) else {
            return Ok(failure("Missing required parameter: url".to_string()));
        };

        let checked = match self.validate_url(requested) {
            Ok(accepted) => accepted,
            Err(reason) => return Ok(failure(reason)),
        };

        let fetched = self.fetch_content(&checked).await;
        Ok(match fetched {
            Ok(body) => ToolResult {
                success: true,
                output: self.truncate_response(&body),
                error: None,
            },
            Err(reason) => failure(reason),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool instance shared by all tests: 50 000-byte cap, 30-second timeout.
    fn test_tool() -> WebFetchTool {
        WebFetchTool::new(50_000, 30)
    }

    /// A plain public https URL is accepted.
    #[tokio::test]
    async fn test_validate_url_success() {
        assert!(test_tool().validate_url("https://example.com").is_ok());
    }

    /// `localhost` is rejected with the local/private message.
    #[tokio::test]
    async fn test_validate_url_rejects_private() {
        let outcome = test_tool().validate_url("https://localhost:8080");
        match outcome {
            Ok(accepted) => panic!("expected rejection, got {}", accepted),
            Err(reason) => assert!(reason.contains("local/private")),
        }
    }

    /// A URL with an embedded space is rejected with the whitespace message.
    #[tokio::test]
    async fn test_validate_url_rejects_whitespace() {
        let outcome = test_tool().validate_url("https://example.com/hello world");
        match outcome {
            Ok(accepted) => panic!("expected rejection, got {}", accepted),
            Err(reason) => assert!(reason.contains("whitespace")),
        }
    }

    /// Text survives extraction and no markup remains.
    #[tokio::test]
    async fn test_extract_text_simple() {
        let page = "<html><body><p>Hello World</p></body></html>";
        let extracted = test_tool().extract_text_from_html(page);
        assert!(extracted.contains("Hello World"));
        assert!(!extracted.contains('<'));
    }

    /// Script bodies are dropped while surrounding text is kept.
    #[tokio::test]
    async fn test_extract_text_removes_scripts() {
        let page = "<html><body><script>alert('bad');</script><p>Good</p></body></html>";
        let extracted = test_tool().extract_text_from_html(page);
        assert!(extracted.contains("Good"));
        assert!(!extracted.contains("alert"));
    }

    /// Style bodies are dropped while surrounding text is kept.
    #[tokio::test]
    async fn test_extract_text_removes_styles() {
        let page = "<html><head><style>.x { color: red; }</style></head><body><p>Content</p></body></html>";
        let extracted = test_tool().extract_text_from_html(page);
        assert!(extracted.contains("Content"));
        assert!(!extracted.contains("color"));
    }
}