use std::path::Path; use async_trait::async_trait; use serde_json::json; use crate::tools::traits::{Tool, ToolResult}; pub struct FileWriteTool { allowed_dir: Option, } impl FileWriteTool { pub fn new() -> Self { Self { allowed_dir: None } } pub fn with_allowed_dir(dir: String) -> Self { Self { allowed_dir: Some(dir), } } fn resolve_path(&self, path: &str) -> Result { let p = Path::new(path); let resolved = if p.is_absolute() { p.to_path_buf() } else { std::env::current_dir() .map_err(|e| format!("Failed to get current directory: {}", e))? .join(p) }; // Check directory restriction if let Some(ref allowed) = self.allowed_dir { let allowed_path = Path::new(allowed); if !resolved.starts_with(allowed_path) { return Err(format!( "Path '{}' is outside allowed directory '{}'", path, allowed )); } } Ok(resolved) } } impl Default for FileWriteTool { fn default() -> Self { Self::new() } } #[async_trait] impl Tool for FileWriteTool { fn name(&self) -> &str { "write" } fn description(&self) -> &str { "Write content to a file at the given path. Creates parent directories if needed." } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "path": { "type": "string", "description": "The file path to write to" }, "content": { "type": "string", "description": "The content to write" } }, "required": ["path", "content"] }) } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let path = match args.get("path").and_then(|v| v.as_str()) { Some(p) => p, None => { return Ok(ToolResult { success: false, output: String::new(), error: Some("Missing required parameter: path".to_string()), }); } }; let content = match args.get("content").and_then(|v| v.as_str()) { Some(c) => c, None => { return Ok(ToolResult { success: false, output: String::new(), error: Some("Missing required parameter: content".to_string()), }); } }; let resolved = match self.resolve_path(path) { Ok(p) => p, Err(e) => { return Ok(ToolResult { success: false, output: String::new(), error: Some(e), }); } }; // Create parent directories if needed if let Some(parent) = resolved.parent() { if !parent.exists() { if let Err(e) = std::fs::create_dir_all(parent) { return Ok(ToolResult { success: false, output: String::new(), error: Some(format!("Failed to create parent directory: {}", e)), }); } } } match std::fs::write(&resolved, content) { Ok(_) => Ok(ToolResult { success: true, output: format!( "Successfully wrote {} bytes to {}", content.len(), resolved.display() ), error: None, }), Err(e) => Ok(ToolResult { success: false, output: String::new(), error: Some(format!("Failed to write file: {}", e)), }), } } } #[cfg(test)] mod tests { use super::*; use tempfile::TempDir; #[tokio::test] async fn test_write_simple_file() { let temp_dir = TempDir::new().unwrap(); let file_path = temp_dir.path().join("test.txt"); let tool = FileWriteTool::new(); let result = tool .execute(json!({ "path": file_path.to_str().unwrap(), "content": "Hello, World!" })) .await .unwrap(); assert!(result.success); assert!(result.output.contains("Successfully wrote")); // Verify content let read_content = std::fs::read_to_string(&file_path).unwrap(); assert_eq!(read_content, "Hello, World!"); } #[tokio::test] async fn test_write_creates_parent_dirs() { let temp_dir = TempDir::new().unwrap(); let file_path = temp_dir.path().join("subdir1/subdir2/test.txt"); let tool = FileWriteTool::new(); let result = tool .execute(json!({ "path": file_path.to_str().unwrap(), "content": "Nested content" })) .await .unwrap(); assert!(result.success); // Verify content let read_content = std::fs::read_to_string(&file_path).unwrap(); assert_eq!(read_content, "Nested content"); } #[tokio::test] async fn test_write_missing_path() { let tool = FileWriteTool::new(); let result = tool.execute(json!({ "content": "Hello" })).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("path")); } #[tokio::test] async fn test_write_missing_content() { let tool = FileWriteTool::new(); // 使用临时目录确保跨平台兼容 let temp_dir = tempfile::tempdir().unwrap(); let test_path = temp_dir.path().join("test.txt"); let result = tool .execute(json!({ "path": test_path.to_str().unwrap() })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("content")); } #[tokio::test] async fn test_overwrite_file() { let temp_dir = TempDir::new().unwrap(); let file_path = temp_dir.path().join("test.txt"); // Write initial content std::fs::write(&file_path, "Initial content").unwrap(); let tool = FileWriteTool::new(); let result = tool .execute(json!({ "path": file_path.to_str().unwrap(), "content": "New content" })) .await .unwrap(); assert!(result.success); // Verify overwritten let read_content = std::fs::read_to_string(&file_path).unwrap(); assert_eq!(read_content, "New content"); } }