feat(tools): add file_write tool with directory creation

- Write content to file, creating parent directories if needed
- Overwrites existing files
- Includes 5 unit tests
This commit is contained in:
xiaoski 2026-04-07 23:44:45 +08:00
parent a9e7aabed4
commit 16b052bd21
2 changed files with 244 additions and 0 deletions

242
src/tools/file_write.rs Normal file
View File

@ -0,0 +1,242 @@
use std::path::Path;
use async_trait::async_trait;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
pub struct FileWriteTool {
allowed_dir: Option<String>,
}
impl FileWriteTool {
pub fn new() -> Self {
Self { allowed_dir: None }
}
pub fn with_allowed_dir(dir: String) -> Self {
Self {
allowed_dir: Some(dir),
}
}
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
let p = Path::new(path);
let resolved = if p.is_absolute() {
p.to_path_buf()
} else {
std::env::current_dir()
.map_err(|e| format!("Failed to get current directory: {}", e))?
.join(p)
};
// Check directory restriction
if let Some(ref allowed) = self.allowed_dir {
let allowed_path = Path::new(allowed);
if !resolved.starts_with(allowed_path) {
return Err(format!(
"Path '{}' is outside allowed directory '{}'",
path, allowed
));
}
}
Ok(resolved)
}
}
impl Default for FileWriteTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for FileWriteTool {
fn name(&self) -> &str {
"file_write"
}
fn description(&self) -> &str {
"Write content to a file at the given path. Creates parent directories if needed."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The file path to write to"
},
"content": {
"type": "string",
"description": "The content to write"
}
},
"required": ["path", "content"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let path = match args.get("path").and_then(|v| v.as_str()) {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: path".to_string()),
});
}
};
let content = match args.get("content").and_then(|v| v.as_str()) {
Some(c) => c,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: content".to_string()),
});
}
};
let resolved = match self.resolve_path(path) {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
// Create parent directories if needed
if let Some(parent) = resolved.parent() {
if !parent.exists() {
if let Err(e) = std::fs::create_dir_all(parent) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to create parent directory: {}", e)),
});
}
}
}
match std::fs::write(&resolved, content) {
Ok(_) => Ok(ToolResult {
success: true,
output: format!(
"Successfully wrote {} bytes to {}",
content.len(),
resolved.display()
),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to write file: {}", e)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_write_simple_file() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.txt");
let tool = FileWriteTool::new();
let result = tool
.execute(json!({
"path": file_path.to_str().unwrap(),
"content": "Hello, World!"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Successfully wrote"));
// Verify content
let read_content = std::fs::read_to_string(&file_path).unwrap();
assert_eq!(read_content, "Hello, World!");
}
#[tokio::test]
async fn test_write_creates_parent_dirs() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("subdir1/subdir2/test.txt");
let tool = FileWriteTool::new();
let result = tool
.execute(json!({
"path": file_path.to_str().unwrap(),
"content": "Nested content"
}))
.await
.unwrap();
assert!(result.success);
// Verify content
let read_content = std::fs::read_to_string(&file_path).unwrap();
assert_eq!(read_content, "Nested content");
}
#[tokio::test]
async fn test_write_missing_path() {
let tool = FileWriteTool::new();
let result = tool
.execute(json!({ "content": "Hello" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("path"));
}
#[tokio::test]
async fn test_write_missing_content() {
let tool = FileWriteTool::new();
let result = tool
.execute(json!({ "path": "/tmp/test.txt" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("content"));
}
#[tokio::test]
async fn test_overwrite_file() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.txt");
// Write initial content
std::fs::write(&file_path, "Initial content").unwrap();
let tool = FileWriteTool::new();
let result = tool
.execute(json!({
"path": file_path.to_str().unwrap(),
"content": "New content"
}))
.await
.unwrap();
assert!(result.success);
// Verify overwritten
let read_content = std::fs::read_to_string(&file_path).unwrap();
assert_eq!(read_content, "New content");
}
}

View File

@ -1,11 +1,13 @@
pub mod calculator; pub mod calculator;
pub mod file_read; pub mod file_read;
pub mod file_write;
pub mod registry; pub mod registry;
pub mod schema; pub mod schema;
pub mod traits; pub mod traits;
pub use calculator::CalculatorTool; pub use calculator::CalculatorTool;
pub use file_read::FileReadTool; pub use file_read::FileReadTool;
pub use file_write::FileWriteTool;
pub use registry::ToolRegistry; pub use registry::ToolRegistry;
pub use schema::{CleaningStrategy, SchemaCleanr}; pub use schema::{CleaningStrategy, SchemaCleanr};
pub use traits::{Tool, ToolResult}; pub use traits::{Tool, ToolResult};