From 16b052bd21cd79c7d229ed1ff0bb57d682e7d591 Mon Sep 17 00:00:00 2001 From: xiaoski Date: Tue, 7 Apr 2026 23:44:45 +0800 Subject: [PATCH] feat(tools): add file_write tool with directory creation - Write content to file, creating parent directories if needed - Overwrites existing files - Includes 5 unit tests --- src/tools/file_write.rs | 242 ++++++++++++++++++++++++++++++++++++++++ src/tools/mod.rs | 2 + 2 files changed, 244 insertions(+) create mode 100644 src/tools/file_write.rs diff --git a/src/tools/file_write.rs b/src/tools/file_write.rs new file mode 100644 index 0000000..3472c70 --- /dev/null +++ b/src/tools/file_write.rs @@ -0,0 +1,242 @@ +use std::path::Path; + +use async_trait::async_trait; +use serde_json::json; + +use crate::tools::traits::{Tool, ToolResult}; + +pub struct FileWriteTool { + allowed_dir: Option, +} + +impl FileWriteTool { + pub fn new() -> Self { + Self { allowed_dir: None } + } + + pub fn with_allowed_dir(dir: String) -> Self { + Self { + allowed_dir: Some(dir), + } + } + + fn resolve_path(&self, path: &str) -> Result { + let p = Path::new(path); + let resolved = if p.is_absolute() { + p.to_path_buf() + } else { + std::env::current_dir() + .map_err(|e| format!("Failed to get current directory: {}", e))? + .join(p) + }; + + // Check directory restriction + if let Some(ref allowed) = self.allowed_dir { + let allowed_path = Path::new(allowed); + if !resolved.starts_with(allowed_path) { + return Err(format!( + "Path '{}' is outside allowed directory '{}'", + path, allowed + )); + } + } + + Ok(resolved) + } +} + +impl Default for FileWriteTool { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl Tool for FileWriteTool { + fn name(&self) -> &str { + "file_write" + } + + fn description(&self) -> &str { + "Write content to a file at the given path. Creates parent directories if needed." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The file path to write to" + }, + "content": { + "type": "string", + "description": "The content to write" + } + }, + "required": ["path", "content"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let path = match args.get("path").and_then(|v| v.as_str()) { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: path".to_string()), + }); + } + }; + + let content = match args.get("content").and_then(|v| v.as_str()) { + Some(c) => c, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: content".to_string()), + }); + } + }; + + let resolved = match self.resolve_path(path) { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }); + } + }; + + // Create parent directories if needed + if let Some(parent) = resolved.parent() { + if !parent.exists() { + if let Err(e) = std::fs::create_dir_all(parent) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to create parent directory: {}", e)), + }); + } + } + } + + match std::fs::write(&resolved, content) { + Ok(_) => Ok(ToolResult { + success: true, + output: format!( + "Successfully wrote {} bytes to {}", + content.len(), + resolved.display() + ), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to write file: {}", e)), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_write_simple_file() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + + let tool = FileWriteTool::new(); + let result = tool + .execute(json!({ + "path": file_path.to_str().unwrap(), + "content": "Hello, World!" + })) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("Successfully wrote")); + + // Verify content + let read_content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(read_content, "Hello, World!"); + } + + #[tokio::test] + async fn test_write_creates_parent_dirs() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("subdir1/subdir2/test.txt"); + + let tool = FileWriteTool::new(); + let result = tool + .execute(json!({ + "path": file_path.to_str().unwrap(), + "content": "Nested content" + })) + .await + .unwrap(); + + assert!(result.success); + + // Verify content + let read_content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(read_content, "Nested content"); + } + + #[tokio::test] + async fn test_write_missing_path() { + let tool = FileWriteTool::new(); + let result = tool + .execute(json!({ "content": "Hello" })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("path")); + } + + #[tokio::test] + async fn test_write_missing_content() { + let tool = FileWriteTool::new(); + let result = tool + .execute(json!({ "path": "/tmp/test.txt" })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("content")); + } + + #[tokio::test] + async fn test_overwrite_file() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + + // Write initial content + std::fs::write(&file_path, "Initial content").unwrap(); + + let tool = FileWriteTool::new(); + let result = tool + .execute(json!({ + "path": file_path.to_str().unwrap(), + "content": "New content" + })) + .await + .unwrap(); + + assert!(result.success); + + // Verify overwritten + let read_content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(read_content, "New content"); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index a17c8c6..8ff4b68 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,11 +1,13 @@ pub mod calculator; pub mod file_read; +pub mod file_write; pub mod registry; pub mod schema; pub mod traits; pub use calculator::CalculatorTool; pub use file_read::FileReadTool; +pub use file_write::FileWriteTool; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr}; pub use traits::{Tool, ToolResult};