diff --git a/Cargo.toml b/Cargo.toml
index 82c9ba6..a435293 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -37,3 +37,4 @@ chrono = "0.4"
 hostname = "0.3"
 sqlx = { version = "0.8", features = ["sqlite", "macros", "chrono", "runtime-tokio"] }
 jieba-rs = "0.9"
+which = "7"
diff --git a/src/tools/content_search.rs b/src/tools/content_search.rs
new file mode 100644
index 0000000..22b68e4
--- /dev/null
+++ b/src/tools/content_search.rs
@@ -0,0 +1,557 @@
+//! Content search tool: finds lines matching a regex or text pattern.
+//!
+//! Backends are tried in order: ripgrep (`rg`), `grep`, then a built-in walker.
+
+use std::path::Path;
+use std::process::Stdio;
+
+use async_trait::async_trait;
+use serde_json::json;
+use tokio::process::Command;
+use tokio::time::timeout;
+
+use crate::tools::traits::{Tool, ToolResult};
+
+/// Default cap on the number of result lines.
+const MAX_RESULTS: usize = 100;
+/// Cap on the size of the rendered output, measured in bytes of UTF-8.
+const MAX_OUTPUT_CHARS: usize = 50_000;
+/// Time limit for one external `rg`/`grep` run.
+const TIMEOUT_SECS: u64 = 60;
+
+/// Tool that searches file contents below a directory.
+pub struct ContentSearchTool;
+
+impl ContentSearchTool {
+    /// Creates the tool; it holds no state.
+    pub fn new() -> Self {
+        Self
+    }
+
+    /// Returns `dir`, or `"."` when it is missing or empty.
+    fn resolve_dir(&self, dir: Option<&str>) -> String {
+        match dir {
+            Some(d) if !d.is_empty() => d.to_string(),
+            _ => ".".to_string(),
+        }
+    }
+
+    /// Joins `lines` with newlines, stopping before the text would exceed
+    /// `MAX_OUTPUT_CHARS` and appending a note about what was left out.
+    fn truncate_output(&self, lines: &[String]) -> String {
+        let mut output = String::new();
+        for (idx, line) in lines.iter().enumerate() {
+            if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
+                // Report what was dropped, not what was kept.
+                let omitted = &lines[idx..];
+                let omitted_chars: usize = omitted.iter().map(|l| l.len() + 1).sum();
+                output.push_str(&format!(
+                    "\n... ({} chars truncated, {} matches omitted) ...",
+                    omitted_chars,
+                    omitted.len()
+                ));
+                break;
+            }
+            if idx > 0 {
+                output.push('\n');
+            }
+            output.push_str(line);
+        }
+        output
+    }
+}
+
+impl Default for ContentSearchTool {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[async_trait]
+impl Tool for ContentSearchTool {
+    fn name(&self) -> &str {
+        "content_search"
+    }
+
+    fn description(&self) -> &str {
+        "Search file contents by regex or text pattern. Uses ripgrep (rg) if available, falls back to grep, then pure Rust."
+    }
+
+    fn parameters_schema(&self) -> serde_json::Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "pattern": {
+                    "type": "string",
+                    "description": "Regex or text pattern to search for in file contents"
+                },
+                "dir": {
+                    "type": "string",
+                    "description": "Directory to search in (default: current working directory)"
+                },
+                "file_pattern": {
+                    "type": "string",
+                    "description": "Optional glob to restrict which files to search (e.g. '*.rs', '*.{rs,toml}')"
+                },
+                "case_sensitive": {
+                    "type": "boolean",
+                    "description": "Whether to match case-sensitively (default: false)"
+                },
+                "context_lines": {
+                    "type": "integer",
+                    "description": "Number of context lines to show before and after each match (default: 0)"
+                },
+                "max_results": {
+                    "type": "integer",
+                    "description": "Maximum number of matching lines to return (default: 100)"
+                }
+            },
+            "required": ["pattern"]
+        })
+    }
+
+    fn read_only(&self) -> bool {
+        true
+    }
+
+    /// Validates `args`, runs the search and renders the matches as text.
+    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
+        let pattern = match args.get("pattern").and_then(|v| v.as_str()) {
+            Some(p) if !p.is_empty() => p,
+            _ => {
+                return Ok(ToolResult {
+                    success: false,
+                    output: String::new(),
+                    error: Some("Missing required parameter: pattern".to_string()),
+                });
+            }
+        };
+
+        let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
+        let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
+        let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false);
+        let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
+        let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
+
+        let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await;
+
+        match result {
+            Ok(lines) => {
+                let count = lines.len();
+                let mut output = self.truncate_output(&lines);
+                output.push_str(&format!("\n\n---\n共 {} 条匹配", count));
+                Ok(ToolResult { success: true, output, error: None })
+            }
+            Err(e) => Ok(ToolResult {
+                success: false,
+                output: String::new(),
+                error: Some(e.to_string()),
+            }),
+        }
+    }
+}
+
+impl ContentSearchTool {
+    /// Runs the first backend that works: `rg`, then `grep`, then the built-in walker.
+    async fn run_search(
+        &self,
+        pattern: &str,
+        dir: &str,
+        file_pattern: Option<&str>,
+        case_sensitive: bool,
+        context_lines: usize,
+        max_results: usize,
+    ) -> anyhow::Result<Vec<String>> {
+        let has_rg = which::which("rg").is_ok();
+        if has_rg {
+            match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
+                Ok(lines) => return Ok(lines),
+                Err(e) => tracing::warn!("rg failed: {}, falling back", e),
+            }
+        }
+
+        let has_grep = which::which("grep").is_ok();
+        if has_grep {
+            match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
+                Ok(lines) if !lines.is_empty() => return Ok(lines),
+                Ok(_) => {},
+                Err(e) => tracing::warn!("grep failed: {}, falling back", e),
+            }
+        }
+
+        // Only claim the tools are missing when they really are.
+        if has_rg || has_grep {
+            tracing::debug!("rg/grep gave no usable result, using built-in content search");
+        } else {
+            tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance.");
+        }
+        self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await
+    }
+
+    /// Searches with ripgrep and returns at most `max_results` output lines.
+    async fn search_with_rg(
+        &self,
+        pattern: &str,
+        dir: &str,
+        file_pattern: Option<&str>,
+        case_sensitive: bool,
+        context_lines: usize,
+        max_results: usize,
+    ) -> anyhow::Result<Vec<String>> {
+        let mut cmd = Command::new("rg");
+        cmd.arg("-n")
+            .arg("--no-heading")
+            .arg("--color").arg("never")
+            .arg("--max-count").arg(max_results.to_string())
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped())
+            // Do not leave rg running if the timeout below drops the future.
+            .kill_on_drop(true);
+
+        if !case_sensitive {
+            cmd.arg("-i");
+        }
+        if context_lines > 0 {
+            cmd.arg("-C").arg(context_lines.to_string());
+        }
+        if let Some(fp) = file_pattern {
+            cmd.arg(format!("--glob={}", fp));
+        }
+        // `-e` and `--` stop a pattern or directory that begins with `-`
+        // from being parsed as an rg option.
+        cmd.arg("-e").arg(pattern).arg("--").arg(dir);
+
+        let output = timeout(
+            std::time::Duration::from_secs(TIMEOUT_SECS),
+            cmd.output(),
+        )
+        .await
+        .map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
+
+        // Exit status 1 means "no matches"; anything else non-zero is an error.
+        if !output.status.success() && output.status.code() != Some(1) {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            return Err(anyhow::anyhow!("rg error: {}", stderr.trim()));
+        }
+
+        let text = String::from_utf8_lossy(&output.stdout);
+        let lines: Vec<String> = text.lines()
+            .take(max_results)
+            .map(|l| l.to_string())
+            .collect();
+        Ok(lines)
+    }
+
+    /// Searches with `grep -rn -E` and returns at most `max_results` output lines.
+    async fn search_with_grep(
+        &self,
+        pattern: &str,
+        dir: &str,
+        file_pattern: Option<&str>,
+        case_sensitive: bool,
+        context_lines: usize,
+        max_results: usize,
+    ) -> anyhow::Result<Vec<String>> {
+        let mut cmd = Command::new("grep");
+        cmd.arg("-rn")
+            .arg("-E")
+            .arg("--color=never")
+            .arg("--binary-files=without-match")
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped())
+            // Do not leave grep running if the timeout below drops the future.
+            .kill_on_drop(true);
+
+        if !case_sensitive {
+            cmd.arg("-i");
+        }
+        if context_lines > 0 {
+            cmd.arg("-C").arg(context_lines.to_string());
+        }
+        if let Some(fp) = file_pattern {
+            cmd.arg(format!("--include={}", fp));
+        }
+        // `-e` and `--` stop a pattern or directory that begins with `-`
+        // from being parsed as a grep option.
+        cmd.arg("-e").arg(pattern).arg("--").arg(dir);
+
+        let output = timeout(
+            std::time::Duration::from_secs(TIMEOUT_SECS),
+            cmd.output(),
+        )
+        .await
+        .map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
+
+        // Exit status 1 means "no matches". A higher status with empty output
+        // is a real failure (for example a pattern grep cannot parse); with
+        // output present the matches are kept.
+        let failed = !output.status.success() && output.status.code() != Some(1);
+        if failed && output.stdout.is_empty() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            return Err(anyhow::anyhow!("grep error: {}", stderr.trim()));
+        }
+
+        let text = String::from_utf8_lossy(&output.stdout);
+        let lines: Vec<String> = text.lines()
+            .take(max_results)
+            .map(|l| l.to_string())
+            .collect();
+        Ok(lines)
+    }
+
+    /// Built-in fallback using the `regex` crate. Context lines are not
+    /// supported here, so `_context_lines` is ignored.
+    async fn search_with_rust(
+        &self,
+        pattern: &str,
+        dir: &str,
+        file_pattern: Option<&str>,
+        case_sensitive: bool,
+        _context_lines: usize,
+        max_results: usize,
+    ) -> anyhow::Result<Vec<String>> {
+        let re = if case_sensitive {
+            regex::Regex::new(pattern)
+        } else {
+            regex::RegexBuilder::new(pattern)
+                .case_insensitive(true)
+                .build()
+        }
+        .map_err(|e| anyhow::anyhow!("Invalid regex pattern '{}': {}", pattern, e))?;
+
+        let file_re = file_pattern.map(|fp| {
+            let re_str = glob_to_regex(fp);
+            if case_sensitive {
+                regex::Regex::new(&re_str)
+            } else {
+                regex::RegexBuilder::new(&re_str).case_insensitive(true).build()
+            }
+        });
+
+        let file_re = match file_re {
+            Some(Ok(r)) => Some(r),
+            Some(Err(e)) => return Err(anyhow::anyhow!("Invalid file pattern: {}", e)),
+            None => None,
+        };
+
+        let mut results = Vec::new();
+        grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?;
+
+        Ok(results)
+    }
+}
+
+/// Translates a glob into an anchored regex source string.
+///
+/// Supports `*`, `**`, `**/`, `?` and `{a,b}` alternation; other regex
+/// metacharacters are escaped.
+fn glob_to_regex(glob: &str) -> String {
+    let mut regex = String::from("^");
+    let chars: Vec<char> = glob.chars().collect();
+    let mut brace_depth = 0usize;
+    let mut i = 0;
+    while i < chars.len() {
+        match chars[i] {
+            '*' => {
+                if i + 1 < chars.len() && chars[i + 1] == '*' {
+                    if i + 2 < chars.len() && chars[i + 2] == '/' {
+                        // `**/` also matches zero directories.
+                        regex.push_str("(?:.*/)?");
+                        i += 2;
+                    } else {
+                        regex.push_str(".*");
+                        i += 1;
+                    }
+                } else {
+                    regex.push_str("[^/]*");
+                }
+            }
+            '?' => regex.push_str("[^/]"),
+            // `{a,b}` alternation, only when a closing brace follows.
+            '{' if chars[i + 1..].contains(&'}') => {
+                regex.push_str("(?:");
+                brace_depth += 1;
+            }
+            '}' if brace_depth > 0 => {
+                regex.push(')');
+                brace_depth -= 1;
+            }
+            ',' if brace_depth > 0 => regex.push('|'),
+            '.' | '+' | '(' | ')' | '[' | ']' | '{' | '}' | '^' | '$' | '|' | '\\' => {
+                regex.push('\\');
+                regex.push(chars[i]);
+            }
+            c => regex.push(c),
+        }
+        i += 1;
+    }
+    // Close any group left open by unbalanced braces so the regex stays valid.
+    for _ in 0..brace_depth {
+        regex.push(')');
+    }
+    regex.push('$');
+    regex
+}
+
+/// Recursively collects `path:line:text` entries for lines matching `re`,
+/// stopping once `results` holds `max` entries.
+///
+/// Directories whose name starts with `.` are skipped, unreadable entries are
+/// ignored, and symlinked directories are not followed.
+fn grep_dir(
+    base: &Path,
+    current: &Path,
+    re: &regex::Regex,
+    file_re: Option<&regex::Regex>,
+    results: &mut Vec<String>,
+    max: usize,
+) -> anyhow::Result<()> {
+    let entries = match std::fs::read_dir(current) {
+        Ok(e) => e,
+        Err(_) => return Ok(()),
+    };
+
+    for entry in entries.flatten() {
+        // Checked per entry so a full result set is never exceeded after a
+        // recursive call returns.
+        if results.len() >= max {
+            return Ok(());
+        }
+
+        let path = entry.path();
+        let rel = match path.strip_prefix(base) {
+            Ok(r) => r,
+            Err(_) => continue,
+        };
+        // `DirEntry::file_type` does not follow symlinks, so link cycles
+        // cannot cause unbounded recursion.
+        let file_type = match entry.file_type() {
+            Ok(t) => t,
+            Err(_) => continue,
+        };
+
+        if file_type.is_dir() {
+            if let Some(name) = rel.file_name().and_then(|n| n.to_str()) {
+                if name.starts_with('.') && name.len() > 1 {
+                    continue;
+                }
+            }
+            grep_dir(base, &path, re, file_re, results, max)?;
+        } else if file_type.is_file() || (file_type.is_symlink() && path.is_file()) {
+            if let Some(file_re) = file_re {
+                if let Some(name) = rel.file_name().and_then(|n| n.to_str()) {
+                    if !file_re.is_match(name) {
+                        continue;
+                    }
+                }
+            }
+
+            if let Ok(content) = std::fs::read_to_string(&path) {
+                for (line_num, line) in content.lines().enumerate() {
+                    if re.is_match(line) {
+                        results.push(format!(
+                            "{}:{}:{}",
+                            rel.to_string_lossy(),
+                            line_num + 1,
+                            line
+                        ));
+                        if results.len() >= max {
+                            return Ok(());
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::fs;
+    use tempfile::TempDir;
+
+    #[tokio::test]
+    async fn test_content_search_rust_fallback() {
+        let dir = TempDir::new().unwrap();
+        fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap();
+        fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap();
+        fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
+
+        let tool = ContentSearchTool::new();
+        let result = tool
+            .execute(json!({
+                "pattern": "let.*=.*42",
+                "dir": dir.path().to_str().unwrap()
+            }))
+            .await
+            .unwrap();
+
+        assert!(result.success);
+        assert!(result.output.contains("main.rs"));
+        assert!(result.output.contains("lib.rs"));
+        assert!(!result.output.contains("README.md"));
+        assert!(result.output.contains("共 2 条匹配"));
+    }
+
+    #[tokio::test]
+    async fn test_content_search_file_filter() {
+        let dir = TempDir::new().unwrap();
+        fs::write(dir.path().join("main.rs"), "fn main() {}").unwrap();
+        fs::write(dir.path().join("config.toml"), "name = \"test\"").unwrap();
+
+        let tool = ContentSearchTool::new();
+        let result = tool
+            .execute(json!({
+                "pattern": "test",
+                "dir": dir.path().to_str().unwrap(),
+                "file_pattern": "*.toml"
+            }))
+            .await
+            .unwrap();
+
+        assert!(result.success);
+        assert!(result.output.contains("config.toml"));
+        assert!(!result.output.contains("main.rs"));
+    }
+
+    #[tokio::test]
+    async fn test_content_search_max_results() {
+        let dir = TempDir::new().unwrap();
+        let mut content = String::new();
+        for i in 0..10 {
+            content.push_str(&format!("match line {}\n", i));
+        }
+        fs::write(dir.path().join("data.txt"), &content).unwrap();
+
+        let tool = ContentSearchTool::new();
+        let result = tool
+            .execute(json!({
+                "pattern": "match line",
+                "dir": dir.path().to_str().unwrap(),
+                "max_results": 3
+            }))
+            .await
+            .unwrap();
+
+        assert!(result.success);
+        assert!(result.output.contains("共 3 条匹配"));
+    }
+
+    #[test]
+    fn test_glob_to_regex_braces() {
+        let re = regex::Regex::new(&glob_to_regex("*.{rs,toml}")).unwrap();
+        assert!(re.is_match("main.rs"));
+        assert!(re.is_match("Cargo.toml"));
+        assert!(!re.is_match("README.md"));
+    }
+
+    #[test]
+    fn test_truncate_output_reports_omitted() {
+        let lines = vec!["x".repeat(MAX_OUTPUT_CHARS - 1), "tail".to_string()];
+        let out = ContentSearchTool::new().truncate_output(&lines);
+        assert!(out.ends_with("(5 chars truncated, 1 matches omitted) ..."));
+    }
+}
diff --git a/src/tools/file_search.rs b/src/tools/file_search.rs
new file mode 100644
index 0000000..1e90bbe
--- /dev/null
+++ b/src/tools/file_search.rs
@@ -0,0 +1,491 @@
+//! File search tool: finds files whose name or relative path matches a glob.
+//!
+//! Backends are tried in order: `fd`, `find`, then a built-in walker.
+
+use std::path::Path;
+use std::process::Stdio;
+
+use async_trait::async_trait;
+use serde_json::json;
+use tokio::process::Command;
+use tokio::time::timeout;
+
+use crate::tools::traits::{Tool, ToolResult};
+
+/// Default cap on the number of returned paths.
+const MAX_RESULTS: usize = 200;
+/// Cap on the size of the rendered output, measured in bytes of UTF-8.
+const MAX_OUTPUT_CHARS: usize = 50_000;
+/// Time limit for one external `fd`/`find` run.
+const TIMEOUT_SECS: u64 = 60;
+
+/// Tool that searches for files below a directory by glob.
+pub struct FileSearchTool;
+
+impl FileSearchTool {
+    /// Creates the tool; it holds no state.
+    pub fn new() -> Self {
+        Self
+    }
+
+    /// Returns `dir`, or `"."` when it is missing or empty.
+    fn resolve_dir(&self, dir: Option<&str>) -> String {
+        match dir {
+            Some(d) if !d.is_empty() => d.to_string(),
+            _ => ".".to_string(),
+        }
+    }
+
+    /// Joins `lines` with newlines, stopping before the text would exceed
+    /// `MAX_OUTPUT_CHARS` and appending a note about what was left out.
+    fn truncate_output(&self, lines: &[String]) -> String {
+        let mut output = String::new();
+        for (idx, line) in lines.iter().enumerate() {
+            if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
+                // Report what was dropped, not what was kept.
+                let omitted_chars: usize = lines[idx..].iter().map(|l| l.len() + 1).sum();
+                output.push_str(&format!("\n... ({} chars truncated) ...", omitted_chars));
+                break;
+            }
+            if idx > 0 {
+                output.push('\n');
+            }
+            output.push_str(line);
+        }
+        output
+    }
+}
+
+impl Default for FileSearchTool {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[async_trait]
+impl Tool for FileSearchTool {
+    fn name(&self) -> &str {
+        "file_search"
+    }
+
+    fn description(&self) -> &str {
+        "Search for files by glob pattern (e.g. '*.rs', 'test_*.rs'). Uses fd if available, falls back to find, then pure Rust."
+    }
+
+    fn parameters_schema(&self) -> serde_json::Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "pattern": {
+                    "type": "string",
+                    "description": "File glob pattern to search for (e.g. *.rs, test_*.rs, src/**/*.py)"
+                },
+                "dir": {
+                    "type": "string",
+                    "description": "Directory to search in (default: current working directory)"
+                },
+                "case_sensitive": {
+                    "type": "boolean",
+                    "description": "Whether to match case-sensitively (default: true)"
+                },
+                "max_results": {
+                    "type": "integer",
+                    "description": "Maximum number of results to return (default: 200)"
+                }
+            },
+            "required": ["pattern"]
+        })
+    }
+
+    fn read_only(&self) -> bool {
+        true
+    }
+
+    /// Validates `args`, runs the search and renders the paths as text.
+    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
+        let pattern = match args.get("pattern").and_then(|v| v.as_str()) {
+            Some(p) if !p.is_empty() => p,
+            _ => {
+                return Ok(ToolResult {
+                    success: false,
+                    output: String::new(),
+                    error: Some("Missing required parameter: pattern".to_string()),
+                });
+            }
+        };
+
+        let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
+        let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true);
+        let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
+
+        let result = self.run_search(pattern, &dir, case_sensitive, max_results).await;
+
+        match result {
+            Ok(lines) => {
+                let count = lines.len();
+                let mut output = self.truncate_output(&lines);
+                output.push_str(&format!("\n\n---\n共 {} 个文件", count));
+                Ok(ToolResult { success: true, output, error: None })
+            }
+            Err(e) => Ok(ToolResult {
+                success: false,
+                output: String::new(),
+                error: Some(e.to_string()),
+            }),
+        }
+    }
+}
+
+impl FileSearchTool {
+    /// Runs the first backend that yields results: `fd`, `find`, then the
+    /// built-in walker.
+    async fn run_search(
+        &self,
+        pattern: &str,
+        dir: &str,
+        case_sensitive: bool,
+        max_results: usize,
+    ) -> anyhow::Result<Vec<String>> {
+        // The external tools are invoked in file-name matching mode, so a
+        // pattern containing `/` (e.g. `src/**/*.py`) goes to the walker.
+        if pattern.contains('/') {
+            return self.search_with_rust(pattern, dir, case_sensitive, max_results).await;
+        }
+
+        let has_fd = which::which("fd").is_ok();
+        if has_fd {
+            match self.search_with_fd(pattern, dir, case_sensitive, max_results).await {
+                Ok(lines) if !lines.is_empty() => return Ok(lines),
+                Ok(_) => {},
+                Err(e) => tracing::warn!("fd failed: {}, falling back", e),
+            }
+        }
+
+        let has_find = which::which("find").is_ok();
+        if has_find {
+            match self.search_with_find(pattern, dir, case_sensitive, max_results).await {
+                Ok(lines) if !lines.is_empty() => return Ok(lines),
+                Ok(_) => {},
+                Err(e) => tracing::warn!("find failed: {}, falling back", e),
+            }
+        }
+
+        // Only claim the tools are missing when they really are.
+        if has_fd || has_find {
+            tracing::debug!("fd/find gave no usable result, using built-in file search");
+        } else {
+            tracing::warn!("No fd/find available, using built-in file search (slower)");
+        }
+        self.search_with_rust(pattern, dir, case_sensitive, max_results).await
+    }
+
+    /// Searches with `fd --glob`, run inside `dir` so paths are relative to it.
+    async fn search_with_fd(
+        &self,
+        pattern: &str,
+        dir: &str,
+        case_sensitive: bool,
+        max_results: usize,
+    ) -> anyhow::Result<Vec<String>> {
+        let mut cmd = Command::new("fd");
+        // fd rejects `--strip-cwd-prefix` together with `--search-path`, so
+        // the search root is set through the working directory instead.
+        cmd.current_dir(dir)
+            .arg("--glob")
+            .arg("--color").arg("never")
+            .arg("--max-results").arg(max_results.to_string())
+            .arg(if case_sensitive { "--case-sensitive" } else { "--ignore-case" })
+            // `--` keeps a pattern that begins with `-` from being an option.
+            .arg("--").arg(pattern)
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped())
+            // Do not leave fd running if the timeout below drops the future.
+            .kill_on_drop(true);
+
+        let output = timeout(
+            std::time::Duration::from_secs(TIMEOUT_SECS),
+            cmd.output(),
+        )
+        .await
+        .map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            return Err(anyhow::anyhow!("fd error: {}", stderr.trim()));
+        }
+
+        let text = String::from_utf8_lossy(&output.stdout);
+        let lines: Vec<String> = text.lines()
+            .filter(|l| !l.is_empty())
+            // Some fd versions prefix piped output with `./`.
+            .map(|l| l.strip_prefix("./").unwrap_or(l).to_string())
+            .collect();
+        Ok(lines)
+    }
+
+    /// Searches with `find`, run inside `dir` so paths are relative to it.
+    ///
+    /// `find` is executed directly, never through a shell, so `pattern` and
+    /// `dir` cannot inject commands.
+    async fn search_with_find(
+        &self,
+        pattern: &str,
+        dir: &str,
+        case_sensitive: bool,
+        max_results: usize,
+    ) -> anyhow::Result<Vec<String>> {
+        let mut cmd = Command::new("find");
+        // Searching `.` keeps hidden components of `dir` itself from
+        // tripping the `*/.*` exclusion below.
+        cmd.current_dir(dir)
+            .arg(".")
+            .arg(if case_sensitive { "-name" } else { "-iname" })
+            .arg(pattern)
+            .arg("-not").arg("-path").arg("*/.*")
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped())
+            // Do not leave find running if the timeout below drops the future.
+            .kill_on_drop(true);
+
+        let output = timeout(
+            std::time::Duration::from_secs(TIMEOUT_SECS),
+            cmd.output(),
+        )
+        .await
+        .map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
+
+        let text = String::from_utf8_lossy(&output.stdout);
+        let lines: Vec<String> = text.lines()
+            .filter(|l| !l.is_empty())
+            .map(|l| l.strip_prefix("./").unwrap_or(l).to_string())
+            .take(max_results)
+            .collect();
+        Ok(lines)
+    }
+
+    /// Built-in fallback: walks `dir` and matches the glob with the `regex` crate.
+    async fn search_with_rust(
+        &self,
+        pattern: &str,
+        dir: &str,
+        case_sensitive: bool,
+        max_results: usize,
+    ) -> anyhow::Result<Vec<String>> {
+        let regex_str = glob_to_regex(pattern);
+        let re = if case_sensitive {
+            regex::Regex::new(&regex_str)
+        } else {
+            regex::RegexBuilder::new(&regex_str)
+                .case_insensitive(true)
+                .build()
+        }
+        .map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
+
+        // A glob with a path separator is matched against the relative path.
+        let match_path = pattern.contains('/');
+        let mut results = Vec::new();
+        walk_dir(Path::new(dir), Path::new(dir), &re, match_path, &mut results, max_results)?;
+        Ok(results)
+    }
+}
+
+/// Translates a glob into an anchored regex source string.
+///
+/// Supports `*`, `**`, `**/`, `?` and `{a,b}` alternation; other regex
+/// metacharacters are escaped.
+fn glob_to_regex(glob: &str) -> String {
+    let mut regex = String::from("^");
+    let chars: Vec<char> = glob.chars().collect();
+    let mut brace_depth = 0usize;
+    let mut i = 0;
+    while i < chars.len() {
+        match chars[i] {
+            '*' => {
+                if i + 1 < chars.len() && chars[i + 1] == '*' {
+                    if i + 2 < chars.len() && chars[i + 2] == '/' {
+                        // `**/` also matches zero directories.
+                        regex.push_str("(?:.*/)?");
+                        i += 2;
+                    } else {
+                        regex.push_str(".*");
+                        i += 1;
+                    }
+                } else {
+                    regex.push_str("[^/]*");
+                }
+            }
+            '?' => regex.push_str("[^/]"),
+            // `{a,b}` alternation, only when a closing brace follows.
+            '{' if chars[i + 1..].contains(&'}') => {
+                regex.push_str("(?:");
+                brace_depth += 1;
+            }
+            '}' if brace_depth > 0 => {
+                regex.push(')');
+                brace_depth -= 1;
+            }
+            ',' if brace_depth > 0 => regex.push('|'),
+            '.' | '+' | '(' | ')' | '[' | ']' | '{' | '}' | '^' | '$' | '|' | '\\' => {
+                regex.push('\\');
+                regex.push(chars[i]);
+            }
+            c => regex.push(c),
+        }
+        i += 1;
+    }
+    // Close any group left open by unbalanced braces so the regex stays valid.
+    for _ in 0..brace_depth {
+        regex.push(')');
+    }
+    regex.push('$');
+    regex
+}
+
+/// Recursively collects paths (relative to `base`) of files matching `re`,
+/// stopping once `results` holds `max` entries.
+///
+/// `re` is tested against the relative path when `match_path` is true and
+/// against the file name otherwise. Directories whose name starts with `.`
+/// are skipped and symlinked directories are not followed.
+fn walk_dir(
+    base: &Path,
+    current: &Path,
+    re: &regex::Regex,
+    match_path: bool,
+    results: &mut Vec<String>,
+    max: usize,
+) -> anyhow::Result<()> {
+    let entries = match std::fs::read_dir(current) {
+        Ok(e) => e,
+        Err(_) => return Ok(()),
+    };
+
+    for entry in entries.flatten() {
+        // Checked per entry so a full result set is never exceeded after a
+        // recursive call returns.
+        if results.len() >= max {
+            return Ok(());
+        }
+
+        let path = entry.path();
+        let rel = match path.strip_prefix(base) {
+            Ok(r) => r,
+            Err(_) => continue,
+        };
+        // `DirEntry::file_type` does not follow symlinks, so link cycles
+        // cannot cause unbounded recursion.
+        let file_type = match entry.file_type() {
+            Ok(t) => t,
+            Err(_) => continue,
+        };
+
+        if file_type.is_dir() {
+            if let Some(name) = rel.file_name().and_then(|n| n.to_str()) {
+                if name.starts_with('.') && name.len() > 1 {
+                    continue;
+                }
+            }
+            walk_dir(base, &path, re, match_path, results, max)?;
+        } else if file_type.is_file() || (file_type.is_symlink() && path.is_file()) {
+            let candidate = if match_path {
+                // Always use `/` so the glob behaves the same on every platform.
+                let parts: Vec<_> = rel.iter().map(|c| c.to_string_lossy()).collect();
+                Some(parts.join("/"))
+            } else {
+                rel.file_name().and_then(|n| n.to_str()).map(str::to_string)
+            };
+            if candidate.map_or(false, |c| re.is_match(&c)) {
+                results.push(rel.to_string_lossy().to_string());
+            }
+        }
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::fs;
+    use tempfile::TempDir;
+
+    #[tokio::test]
+    async fn test_file_search_rust_fallback() {
+        let dir = TempDir::new().unwrap();
+        fs::write(dir.path().join("main.rs"), "fn main() {}").unwrap();
+        fs::write(dir.path().join("lib.rs"), "pub fn foo() {}").unwrap();
+        fs::write(dir.path().join("test.rs"), "#[test] fn t() {}").unwrap();
+        fs::write(dir.path().join("README.md"), "# Readme").unwrap();
+        fs::create_dir(dir.path().join("src")).unwrap();
+        fs::write(dir.path().join("src/nested.rs"), "fn nested() {}").unwrap();
+
+        let tool = FileSearchTool::new();
+        let result = tool
+            .execute(json!({
+                "pattern": "*.rs",
+                "dir": dir.path().to_str().unwrap()
+            }))
+            .await
+            .unwrap();
+
+        assert!(result.success);
+        assert!(result.output.contains("main.rs"));
+        assert!(result.output.contains("lib.rs"));
+        assert!(result.output.contains("test.rs"));
+        assert!(result.output.contains("nested.rs"));
+        assert!(!result.output.contains("README.md"));
+        assert!(result.output.contains("共 4 个文件"));
+    }
+
+    #[tokio::test]
+    async fn test_file_search_max_results() {
+        let dir = TempDir::new().unwrap();
+        for i in 0..5 {
+            fs::write(dir.path().join(format!("file_{}.rs", i)), "").unwrap();
+        }
+
+        let tool = FileSearchTool::new();
+        let result = tool
+            .execute(json!({
+                "pattern": "*.rs",
+                "dir": dir.path().to_str().unwrap(),
+                "max_results": 3
+            }))
+            .await
+            .unwrap();
+
+        assert!(result.success);
+        assert!(result.output.contains("共 3 个文件"));
+    }
+
+    #[test]
+    fn test_glob_to_regex_braces_and_double_star() {
+        let re = regex::Regex::new(&glob_to_regex("*.{rs,toml}")).unwrap();
+        assert!(re.is_match("main.rs"));
+        assert!(!re.is_match("README.md"));
+
+        let re = regex::Regex::new(&glob_to_regex("src/**/*.py")).unwrap();
+        assert!(re.is_match("src/app.py"));
+        assert!(re.is_match("src/pkg/mod/app.py"));
+    }
+
+    #[tokio::test]
+    async fn test_file_search_path_glob() {
+        let dir = TempDir::new().unwrap();
+        fs::create_dir_all(dir.path().join("src/pkg")).unwrap();
+        fs::write(dir.path().join("src/pkg/app.py"), "").unwrap();
+        fs::write(dir.path().join("app.py"), "").unwrap();
+
+        let tool = FileSearchTool::new();
+        let result = tool
+            .execute(json!({
+                "pattern": "src/**/*.py",
+                "dir": dir.path().to_str().unwrap()
+            }))
+            .await
+            .unwrap();
+
+        assert!(result.success);
+        assert!(result.output.contains("共 1 个文件"));
+    }
+}
diff --git a/src/tools/mod.rs b/src/tools/mod.rs
index 6f40f83..91c2225 100644
--- a/src/tools/mod.rs
+++ b/src/tools/mod.rs
@@ -1,9 +1,11 @@
 pub mod bash;
 pub mod calculator;
 pub mod chat_manager;
+pub mod content_search;
 pub mod cron;
 pub mod file_edit;
 pub mod file_read;
+pub mod file_search;
 pub mod file_write;
 pub mod get_skill;
 pub mod http_request;
@@ -17,8 +19,10 @@ pub mod web_fetch;
 pub use bash::BashTool;
 pub use calculator::CalculatorTool;
 pub use chat_manager::ChatManagerTool;
+pub use content_search::ContentSearchTool;
 pub use file_edit::FileEditTool;
 pub use file_read::FileReadTool;
+pub use file_search::FileSearchTool;
 pub use file_write::FileWriteTool;
 pub use get_skill::GetSkillTool;
 pub use http_request::HttpRequestTool;
@@ -45,6 +49,8 @@ pub fn create_default_tools(
     registry.register(FileReadTool::new());
     registry.register(FileWriteTool::new());
     registry.register(FileEditTool::new());
+    registry.register(FileSearchTool::new());
+    registry.register(ContentSearchTool::new());
     registry.register(BashTool::new());
     registry.register(HttpRequestTool::new(
         vec!["*".to_string()],