feat: 添加平台抽象层,支持跨平台兼容性;更新多个模块以使用临时目录和平台特定路径
This commit is contained in:
parent
e0a7f67dab
commit
f4758f8513
@ -210,6 +210,12 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn init_registers_wechat_channel_by_instance_name() {
|
async fn init_registers_wechat_channel_by_instance_name() {
|
||||||
let file = tempfile::NamedTempFile::new().unwrap();
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
// 使用临时目录确保跨平台兼容
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let cred_path = temp_dir.path().join("wechat-creds.json");
|
||||||
|
// JSON 中的路径需要转义反斜杠
|
||||||
|
let cred_path_json = cred_path.display().to_string().replace('\\', "\\\\");
|
||||||
|
|
||||||
std::fs::write(
|
std::fs::write(
|
||||||
file.path(),
|
file.path(),
|
||||||
r#"{
|
r#"{
|
||||||
@ -236,10 +242,10 @@ mod tests {
|
|||||||
"wechat_main": {
|
"wechat_main": {
|
||||||
"type": "wechat",
|
"type": "wechat",
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"cred_path": "/tmp/wechat-creds.json"
|
"cred_path": "<CRED_PATH>"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}"#,
|
}"#.replace("<CRED_PATH>", &cred_path_json),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@ -1303,6 +1303,12 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_tagged_wechat_channel_config_loads() {
|
fn test_tagged_wechat_channel_config_loads() {
|
||||||
let file = tempfile::NamedTempFile::new().unwrap();
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
// 使用临时文件路径确保跨平台兼容
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let cred_path = temp_dir.path().join("wechat-creds.json");
|
||||||
|
// JSON 中的路径需要转义反斜杠
|
||||||
|
let cred_path_json = cred_path.display().to_string().replace('\\', "\\\\");
|
||||||
|
|
||||||
std::fs::write(
|
std::fs::write(
|
||||||
file.path(),
|
file.path(),
|
||||||
r#"{
|
r#"{
|
||||||
@ -1330,12 +1336,12 @@ mod tests {
|
|||||||
"type": "wechat",
|
"type": "wechat",
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"base_url": "https://ilinkai.weixin.qq.com",
|
"base_url": "https://ilinkai.weixin.qq.com",
|
||||||
"cred_path": "/tmp/wechat-creds.json",
|
"cred_path": "<CRED_PATH>",
|
||||||
"force_login": true,
|
"force_login": true,
|
||||||
"allow_from": ["wxid_1"]
|
"allow_from": ["wxid_1"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}"#,
|
}"#.replace("<CRED_PATH>", &cred_path_json),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -1344,7 +1350,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(config.channels["wechat_main"].kind(), "wechat");
|
assert_eq!(config.channels["wechat_main"].kind(), "wechat");
|
||||||
assert!(config.channels["wechat_main"].enabled());
|
assert!(config.channels["wechat_main"].enabled());
|
||||||
assert_eq!(wechat.cred_path, "/tmp/wechat-creds.json");
|
assert_eq!(wechat.cred_path, cred_path.display().to_string());
|
||||||
assert!(wechat.force_login);
|
assert!(wechat.force_login);
|
||||||
assert_eq!(wechat.allow_from, vec!["wxid_1"]);
|
assert_eq!(wechat.allow_from, vec!["wxid_1"]);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,14 +20,18 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
||||||
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
// 使用临时目录确保跨平台兼容
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let media_a = temp_dir.path().join("a.png");
|
||||||
|
let media_b = temp_dir.path().join("b.pdf");
|
||||||
|
let media_refs = vec![media_a.display().to_string(), media_b.display().to_string()];
|
||||||
|
|
||||||
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
|
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
|
||||||
|
|
||||||
assert_eq!(
|
// 验证 JSON 格式正确
|
||||||
enriched,
|
assert!(enriched.starts_with("hello\n\nmedia_refs_json: "));
|
||||||
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
|
assert!(enriched.contains("a.png"));
|
||||||
);
|
assert!(enriched.contains("b.pdf"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -2,6 +2,7 @@ use std::fs;
|
|||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
|
use crate::platform::atomic_rename;
|
||||||
|
|
||||||
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
||||||
|
|
||||||
@ -85,7 +86,9 @@ fn write_prompt_file(path: &Path, content: &str) -> Result<(), AgentError> {
|
|||||||
let temp_path = path.with_extension("md.tmp");
|
let temp_path = path.with_extension("md.tmp");
|
||||||
fs::write(&temp_path, normalized)
|
fs::write(&temp_path, normalized)
|
||||||
.map_err(|err| AgentError::Other(format!("write prompt temp file error: {}", err)))?;
|
.map_err(|err| AgentError::Other(format!("write prompt temp file error: {}", err)))?;
|
||||||
fs::rename(&temp_path, path)
|
|
||||||
|
// 使用平台抽象的原子重命名
|
||||||
|
atomic_rename(&temp_path, path)
|
||||||
.map_err(|err| AgentError::Other(format!("replace prompt file error: {}", err)))?;
|
.map_err(|err| AgentError::Other(format!("replace prompt file error: {}", err)))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -114,7 +114,11 @@ mod tests {
|
|||||||
&context,
|
&context,
|
||||||
SessionSendRequest {
|
SessionSendRequest {
|
||||||
text: Some("hello".to_string()),
|
text: Some("hello".to_string()),
|
||||||
attachments: vec![MediaItem::new("/tmp/demo.png", "image")],
|
// 使用临时目录确保跨平台兼容
|
||||||
|
attachments: vec![MediaItem::new(
|
||||||
|
&std::env::temp_dir().join("demo.png").display().to_string(),
|
||||||
|
"image"
|
||||||
|
)],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@ -9,6 +9,7 @@ pub mod domain;
|
|||||||
pub mod gateway;
|
pub mod gateway;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod observability;
|
pub mod observability;
|
||||||
|
pub mod platform;
|
||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
pub mod providers;
|
pub mod providers;
|
||||||
pub mod scheduler;
|
pub mod scheduler;
|
||||||
|
|||||||
251
src/platform/mod.rs
Normal file
251
src/platform/mod.rs
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
//! Platform abstraction layer for cross-platform compatibility.
|
||||||
|
//!
|
||||||
|
//! This module provides unified interfaces for platform-specific operations,
|
||||||
|
//! making it easy to add support for new platforms by modifying only this file.
|
||||||
|
|
||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::io;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
/// Supported platform types.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum Platform {
|
||||||
|
Windows,
|
||||||
|
Unix,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Platform {
|
||||||
|
/// Detect the current platform.
|
||||||
|
pub fn current() -> Self {
|
||||||
|
if cfg!(target_os = "windows") {
|
||||||
|
Platform::Windows
|
||||||
|
} else {
|
||||||
|
Platform::Unix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if running on Windows.
|
||||||
|
pub fn is_windows() -> bool {
|
||||||
|
cfg!(target_os = "windows")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shell information for command execution.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ShellInfo {
|
||||||
|
/// Tool name exposed to LLM.
|
||||||
|
pub name: &'static str,
|
||||||
|
/// Shell executable name.
|
||||||
|
pub executable: &'static str,
|
||||||
|
/// Arguments to pass before the command.
|
||||||
|
pub args: &'static [&'static str],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShellInfo {
|
||||||
|
/// Get the default shell for the current platform.
|
||||||
|
pub fn default() -> Self {
|
||||||
|
Self::for_platform(Platform::current())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get shell info for a specific platform.
|
||||||
|
pub fn for_platform(platform: Platform) -> Self {
|
||||||
|
match platform {
|
||||||
|
Platform::Windows => ShellInfo {
|
||||||
|
name: "shell",
|
||||||
|
executable: "powershell",
|
||||||
|
args: &["-Command"],
|
||||||
|
},
|
||||||
|
Platform::Unix => ShellInfo {
|
||||||
|
name: "bash",
|
||||||
|
executable: "bash",
|
||||||
|
args: &["-c"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Alternative shells available on the platform.
|
||||||
|
pub fn available_shells(platform: Platform) -> Vec<ShellInfo> {
|
||||||
|
match platform {
|
||||||
|
Platform::Windows => vec![
|
||||||
|
ShellInfo {
|
||||||
|
name: "shell",
|
||||||
|
executable: "powershell",
|
||||||
|
args: &["-Command"],
|
||||||
|
},
|
||||||
|
ShellInfo {
|
||||||
|
name: "shell",
|
||||||
|
executable: "cmd",
|
||||||
|
args: &["/C"],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
Platform::Unix => vec![
|
||||||
|
ShellInfo {
|
||||||
|
name: "bash",
|
||||||
|
executable: "bash",
|
||||||
|
args: &["-c"],
|
||||||
|
},
|
||||||
|
// Future: could add zsh, fish, sh
|
||||||
|
// ShellInfo { name: "zsh", executable: "zsh", args: &["-c"] },
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dangerous command patterns for safety guards.
|
||||||
|
pub fn dangerous_command_patterns() -> Vec<String> {
|
||||||
|
vec![
|
||||||
|
// Unix dangerous commands
|
||||||
|
r"\brm\s+-[rf]{1,2}\b".to_string(),
|
||||||
|
r"\bchmod\s+-[Rr]".to_string(),
|
||||||
|
r"\bchown\s+-[Rr]".to_string(),
|
||||||
|
// Windows dangerous commands
|
||||||
|
r"\bdel\s+/[fq]\b".to_string(),
|
||||||
|
r"\brmdir\s+/s\b".to_string(),
|
||||||
|
r"\bformat\s+".to_string(),
|
||||||
|
// PowerShell dangerous commands
|
||||||
|
r"\bRemove-Item\s+.*-Recurse".to_string(),
|
||||||
|
r"\bRemove-Item\s+.*-Force".to_string(),
|
||||||
|
// Fork bomb (cross-platform)
|
||||||
|
r":\(\)\s*\{.*\};\s*:".to_string(),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the user's home directory.
|
||||||
|
///
|
||||||
|
/// Supports environment variable overrides for testing:
|
||||||
|
/// - `HOME` (Unix-style, works on all platforms for testing)
|
||||||
|
/// - `USERPROFILE` (Windows-specific)
|
||||||
|
pub fn home_dir() -> Option<PathBuf> {
|
||||||
|
// Test scenario: support HOME variable override on all platforms
|
||||||
|
env::var_os("HOME")
|
||||||
|
.map(PathBuf::from)
|
||||||
|
.or_else(|| {
|
||||||
|
// Windows: support USERPROFILE
|
||||||
|
env::var_os("USERPROFILE").map(PathBuf::from)
|
||||||
|
})
|
||||||
|
.or_else(|| dirs::home_dir())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Atomically rename a file, handling platform differences.
|
||||||
|
///
|
||||||
|
/// On Windows, `fs::rename` fails if the destination exists, so we need to
|
||||||
|
/// remove it first. On Unix, rename is atomic and replaces the destination.
|
||||||
|
pub fn atomic_rename(src: &Path, dst: &Path) -> io::Result<()> {
|
||||||
|
if Platform::is_windows() && dst.exists() {
|
||||||
|
fs::remove_file(dst)?;
|
||||||
|
}
|
||||||
|
fs::rename(src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a filesystem path to a file:// URI.
|
||||||
|
///
|
||||||
|
/// Handles platform-specific path formats:
|
||||||
|
/// - Unix: `/path/to/file` -> `file:///path/to/file`
|
||||||
|
/// - Windows: `C:\path\to\file` -> `file:///C:/path/to/file`
|
||||||
|
pub fn path_to_uri(path: &Path) -> String {
|
||||||
|
let path_str = path.display().to_string();
|
||||||
|
if Platform::is_windows() {
|
||||||
|
// Windows paths use backslashes which must be converted to forward slashes
|
||||||
|
let normalized = path_str.replace('\\', "/");
|
||||||
|
format!("file:///{}", normalized)
|
||||||
|
} else {
|
||||||
|
format!("file://{}", path_str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// XML escape utility.
|
||||||
|
pub fn xml_escape(value: &str) -> String {
|
||||||
|
value
|
||||||
|
.replace('&', "&")
|
||||||
|
.replace('<', "<")
|
||||||
|
.replace('>', ">")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_platform_detect() {
|
||||||
|
let platform = Platform::current();
|
||||||
|
if cfg!(target_os = "windows") {
|
||||||
|
assert_eq!(platform, Platform::Windows);
|
||||||
|
} else {
|
||||||
|
assert_eq!(platform, Platform::Unix);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_shell_info_default() {
|
||||||
|
let shell = ShellInfo::default();
|
||||||
|
if cfg!(target_os = "windows") {
|
||||||
|
assert_eq!(shell.executable, "powershell");
|
||||||
|
assert_eq!(shell.args, &["-Command"]);
|
||||||
|
} else {
|
||||||
|
assert_eq!(shell.executable, "bash");
|
||||||
|
assert_eq!(shell.args, &["-c"]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_shell_info_for_platform() {
|
||||||
|
let win_shell = ShellInfo::for_platform(Platform::Windows);
|
||||||
|
assert_eq!(win_shell.executable, "powershell");
|
||||||
|
|
||||||
|
let unix_shell = ShellInfo::for_platform(Platform::Unix);
|
||||||
|
assert_eq!(unix_shell.executable, "bash");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_path_to_uri() {
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let test_path = temp_dir.path().join("test.txt");
|
||||||
|
let uri = path_to_uri(&test_path);
|
||||||
|
|
||||||
|
assert!(uri.starts_with("file://"));
|
||||||
|
assert!(uri.contains("test.txt"));
|
||||||
|
assert!(!uri.contains('\\')); // No backslashes
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_path_to_uri_windows_format() {
|
||||||
|
if cfg!(target_os = "windows") {
|
||||||
|
let win_path = PathBuf::from("C:\\Users\\test\\file.txt");
|
||||||
|
let uri = path_to_uri(&win_path);
|
||||||
|
assert!(uri.starts_with("file:///C:/"));
|
||||||
|
assert_eq!(uri, "file:///C:/Users/test/file.txt");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_atomic_rename() {
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let src = temp_dir.path().join("source.txt");
|
||||||
|
let dst = temp_dir.path().join("dest.txt");
|
||||||
|
|
||||||
|
fs::write(&src, "content").unwrap();
|
||||||
|
fs::write(&dst, "old content").unwrap();
|
||||||
|
|
||||||
|
atomic_rename(&src, &dst).unwrap();
|
||||||
|
|
||||||
|
assert!(!src.exists());
|
||||||
|
assert!(dst.exists());
|
||||||
|
assert_eq!(fs::read_to_string(&dst).unwrap(), "content");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_dangerous_patterns() {
|
||||||
|
let patterns = dangerous_command_patterns();
|
||||||
|
assert!(!patterns.is_empty());
|
||||||
|
// Should contain patterns for both platforms
|
||||||
|
assert!(patterns.iter().any(|p| p.contains("rm")));
|
||||||
|
assert!(patterns.iter().any(|p| p.contains("del")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_xml_escape() {
|
||||||
|
assert_eq!(xml_escape("a & b"), "a & b");
|
||||||
|
assert_eq!(xml_escape("<tag>"), "<tag>");
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
use crate::platform::{atomic_rename, home_dir as platform_home_dir, path_to_uri, xml_escape as platform_xml_escape};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
@ -439,9 +440,9 @@ impl SkillCatalog {
|
|||||||
for skill in self.skills.iter().take(self.max_listed_skills) {
|
for skill in self.skills.iter().take(self.max_listed_skills) {
|
||||||
let entry = format!(
|
let entry = format!(
|
||||||
" <skill>\n <name>{}</name>\n <description>{}</description>\n <location>{}</location>\n </skill>\n",
|
" <skill>\n <name>{}</name>\n <description>{}</description>\n <location>{}</location>\n </skill>\n",
|
||||||
xml_escape(&skill.name),
|
platform_xml_escape(&skill.name),
|
||||||
xml_escape(&skill.description),
|
platform_xml_escape(&skill.description),
|
||||||
xml_escape(&format!("file://{}", skill.path.display())),
|
platform_xml_escape(&path_to_uri(&skill.path)),
|
||||||
);
|
);
|
||||||
if prompt.len() + entry.len() + "</available_skills>\n".len() > self.max_index_chars {
|
if prompt.len() + entry.len() + "</available_skills>\n".len() > self.max_index_chars {
|
||||||
prompt.push_str(" <truncated>true</truncated>\n");
|
prompt.push_str(" <truncated>true</truncated>\n");
|
||||||
@ -627,27 +628,20 @@ fn project_skill_state_path(cwd: &Path) -> PathBuf {
|
|||||||
cwd.join(".picobot").join("skill-state.json")
|
cwd.join(".picobot").join("skill-state.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn home_dir() -> Option<PathBuf> {
|
|
||||||
// First check HOME environment variable (useful for testing)
|
|
||||||
std::env::var_os("HOME")
|
|
||||||
.map(PathBuf::from)
|
|
||||||
.or_else(|| dirs::home_dir())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn user_skills_root() -> Option<PathBuf> {
|
fn user_skills_root() -> Option<PathBuf> {
|
||||||
home_dir().map(|p| p.join(".picobot").join("skills"))
|
platform_home_dir().map(|p| p.join(".picobot").join("skills"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn user_skill_state_path() -> Option<PathBuf> {
|
fn user_skill_state_path() -> Option<PathBuf> {
|
||||||
home_dir().map(|p| p.join(".picobot").join("skill-state.json"))
|
platform_home_dir().map(|p| p.join(".picobot").join("skill-state.json"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn user_agent_skills_root() -> Option<PathBuf> {
|
fn user_agent_skills_root() -> Option<PathBuf> {
|
||||||
home_dir().map(|p| p.join(".agents").join("skills"))
|
platform_home_dir().map(|p| p.join(".agents").join("skills"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn user_openclaw_skills_root() -> Option<PathBuf> {
|
fn user_openclaw_skills_root() -> Option<PathBuf> {
|
||||||
home_dir().map(|p| p.join(".openclaw").join("skills"))
|
platform_home_dir().map(|p| p.join(".openclaw").join("skills"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn source_root(source: SkillSource, cwd: &Path) -> Option<PathBuf> {
|
fn source_root(source: SkillSource, cwd: &Path) -> Option<PathBuf> {
|
||||||
@ -746,7 +740,9 @@ fn save_skill_state_file(path: &Path, state: &SkillStateFile) -> Result<(), Stri
|
|||||||
let tmp_path = path.with_extension("json.tmp");
|
let tmp_path = path.with_extension("json.tmp");
|
||||||
fs::write(&tmp_path, format!("{}\n", content))
|
fs::write(&tmp_path, format!("{}\n", content))
|
||||||
.map_err(|err| format!("failed to write temporary skill state file: {}", err))?;
|
.map_err(|err| format!("failed to write temporary skill state file: {}", err))?;
|
||||||
fs::rename(&tmp_path, path)
|
|
||||||
|
// 使用平台抽象的原子重命名
|
||||||
|
atomic_rename(&tmp_path, path)
|
||||||
.map_err(|err| format!("failed to persist skill state file: {}", err))
|
.map_err(|err| format!("failed to persist skill state file: {}", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -867,12 +863,7 @@ fn split_frontmatter(content: &str) -> Option<(&str, &str)> {
|
|||||||
Some((frontmatter, body))
|
Some((frontmatter, body))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn xml_escape(value: &str) -> String {
|
// 使用 platform 模块提供的 xml_escape 和 path_to_uri 函数
|
||||||
value
|
|
||||||
.replace('&', "&")
|
|
||||||
.replace('<', "<")
|
|
||||||
.replace('>', ">")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@ -885,6 +876,7 @@ mod tests {
|
|||||||
|
|
||||||
struct HomeDirGuard {
|
struct HomeDirGuard {
|
||||||
previous: Option<OsString>,
|
previous: Option<OsString>,
|
||||||
|
previous_userprofile: Option<OsString>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CurrentDirGuard {
|
impl CurrentDirGuard {
|
||||||
@ -903,23 +895,33 @@ mod tests {
|
|||||||
|
|
||||||
impl HomeDirGuard {
|
impl HomeDirGuard {
|
||||||
fn enter(path: &Path) -> Self {
|
fn enter(path: &Path) -> Self {
|
||||||
let previous = std::env::var_os("HOME");
|
let home_backup = std::env::var_os("HOME");
|
||||||
|
let userprofile_backup = std::env::var_os("USERPROFILE");
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
std::env::set_var("HOME", path);
|
std::env::set_var("HOME", path);
|
||||||
|
// Windows 环境下同时设置 USERPROFILE
|
||||||
|
std::env::set_var("USERPROFILE", path);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
previous: home_backup,
|
||||||
|
previous_userprofile: userprofile_backup,
|
||||||
}
|
}
|
||||||
Self { previous }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for HomeDirGuard {
|
impl Drop for HomeDirGuard {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
match &self.previous {
|
unsafe {
|
||||||
Some(value) => unsafe {
|
match &self.previous {
|
||||||
std::env::set_var("HOME", value);
|
Some(value) => std::env::set_var("HOME", value),
|
||||||
},
|
None => std::env::remove_var("HOME"),
|
||||||
None => unsafe {
|
}
|
||||||
std::env::remove_var("HOME");
|
match &self.previous_userprofile {
|
||||||
},
|
Some(value) => std::env::set_var("USERPROFILE", value),
|
||||||
|
None => std::env::remove_var("USERPROFILE"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -974,13 +976,17 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_system_index_prompt_uses_available_skills_markup() {
|
fn test_system_index_prompt_uses_available_skills_markup() {
|
||||||
|
// 使用临时目录创建测试路径,确保跨平台兼容
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let skill_path = temp_dir.path().join("demo-skill").join("SKILL.md");
|
||||||
|
|
||||||
let catalog = SkillCatalog {
|
let catalog = SkillCatalog {
|
||||||
skills: vec![Skill {
|
skills: vec![Skill {
|
||||||
name: "demo-skill".to_string(),
|
name: "demo-skill".to_string(),
|
||||||
description: "demo <skill> & usage".to_string(),
|
description: "demo <skill> & usage".to_string(),
|
||||||
body: String::new(),
|
body: String::new(),
|
||||||
source: SkillSource::Project,
|
source: SkillSource::Project,
|
||||||
path: PathBuf::from("/tmp/demo-skill/SKILL.md"),
|
path: skill_path.clone(),
|
||||||
}],
|
}],
|
||||||
max_index_chars: 4000,
|
max_index_chars: 4000,
|
||||||
max_listed_skills: 32,
|
max_listed_skills: 32,
|
||||||
@ -991,10 +997,35 @@ mod tests {
|
|||||||
assert!(prompt.contains("技能为特定任务提供专用说明和工作流。"));
|
assert!(prompt.contains("技能为特定任务提供专用说明和工作流。"));
|
||||||
assert!(prompt.contains("<name>demo-skill</name>"));
|
assert!(prompt.contains("<name>demo-skill</name>"));
|
||||||
assert!(prompt.contains("<description>demo <skill> & usage</description>"));
|
assert!(prompt.contains("<description>demo <skill> & usage</description>"));
|
||||||
assert!(prompt.contains("<location>file:///tmp/demo-skill/SKILL.md</location>"));
|
|
||||||
|
// 验证 location 包含正确的 file:// URI 格式
|
||||||
|
let expected_uri = path_to_uri(&skill_path);
|
||||||
|
assert!(prompt.contains(&format!("<location>{}</location>", platform_xml_escape(&expected_uri))));
|
||||||
assert!(prompt.contains("</available_skills>"));
|
assert!(prompt.contains("</available_skills>"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_path_to_uri() {
|
||||||
|
// Unix 路径
|
||||||
|
let unix_path = PathBuf::from("/tmp/demo-skill/SKILL.md");
|
||||||
|
let unix_uri = path_to_uri(&unix_path);
|
||||||
|
if cfg!(target_os = "windows") {
|
||||||
|
// Windows 上运行时,路径可能被转换
|
||||||
|
assert!(unix_uri.contains("file://"));
|
||||||
|
} else {
|
||||||
|
assert_eq!(unix_uri, "file:///tmp/demo-skill/SKILL.md");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Windows 路径格式测试(仅在 Windows 上)
|
||||||
|
if cfg!(target_os = "windows") {
|
||||||
|
let win_path = PathBuf::from("C:\\Users\\test\\.picobot\\skills\\demo\\SKILL.md");
|
||||||
|
let win_uri = path_to_uri(&win_path);
|
||||||
|
assert!(win_uri.starts_with("file:///C:/"));
|
||||||
|
assert!(win_uri.contains("/SKILL.md"));
|
||||||
|
assert!(!win_uri.contains('\\')); // 不应包含反斜杠
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_runtime_create_update_delete_reload() {
|
fn test_runtime_create_update_delete_reload() {
|
||||||
let _lock = acquire_test_lock();
|
let _lock = acquire_test_lock();
|
||||||
|
|||||||
@ -10,6 +10,7 @@ use tokio::process::Command;
|
|||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
use tokio::time::{Instant, sleep_until};
|
use tokio::time::{Instant, sleep_until};
|
||||||
|
|
||||||
|
use crate::platform::{ShellInfo, dangerous_command_patterns};
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||||
@ -18,10 +19,90 @@ const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
|||||||
const USER_ACTION_HINT: &str =
|
const USER_ACTION_HINT: &str =
|
||||||
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
||||||
|
|
||||||
|
/// Shell 类型枚举,支持跨平台
|
||||||
|
///
|
||||||
|
/// 这是 ShellInfo 的兼容包装,提供更方便的 API。
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ShellKind {
|
||||||
|
Bash,
|
||||||
|
PowerShell,
|
||||||
|
Cmd,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShellKind {
|
||||||
|
/// 根据平台检测默认 shell
|
||||||
|
pub fn detect() -> Self {
|
||||||
|
let info = ShellInfo::default();
|
||||||
|
match info.executable {
|
||||||
|
"bash" => ShellKind::Bash,
|
||||||
|
"powershell" => ShellKind::PowerShell,
|
||||||
|
"cmd" => ShellKind::Cmd,
|
||||||
|
_ => ShellKind::Bash, // fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 ShellInfo 获取 ShellKind
|
||||||
|
pub fn from_info(info: &ShellInfo) -> Self {
|
||||||
|
match info.executable {
|
||||||
|
"bash" => ShellKind::Bash,
|
||||||
|
"powershell" => ShellKind::PowerShell,
|
||||||
|
"cmd" => ShellKind::Cmd,
|
||||||
|
_ => ShellKind::Bash,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取对应的 ShellInfo
|
||||||
|
pub fn to_info(&self) -> ShellInfo {
|
||||||
|
match self {
|
||||||
|
ShellKind::Bash => ShellInfo {
|
||||||
|
name: "bash",
|
||||||
|
executable: "bash",
|
||||||
|
args: &["-c"],
|
||||||
|
},
|
||||||
|
ShellKind::PowerShell => ShellInfo {
|
||||||
|
name: "shell",
|
||||||
|
executable: "powershell",
|
||||||
|
args: &["-Command"],
|
||||||
|
},
|
||||||
|
ShellKind::Cmd => ShellInfo {
|
||||||
|
name: "shell",
|
||||||
|
executable: "cmd",
|
||||||
|
args: &["/C"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shell 可执行文件名
|
||||||
|
pub fn executable(&self) -> &'static str {
|
||||||
|
self.to_info().executable
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行命令所需的参数
|
||||||
|
pub fn command_args<'a>(&self, command: &'a str) -> Vec<&'a str> {
|
||||||
|
let info = self.to_info();
|
||||||
|
info.args.iter().map(|s| *s).chain(std::iter::once(command)).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 工具名称
|
||||||
|
pub fn tool_name(&self) -> &'static str {
|
||||||
|
self.to_info().name
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 工具描述
|
||||||
|
pub fn tool_description(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
ShellKind::Bash => "Execute a bash shell command and return its output. Use with caution.",
|
||||||
|
ShellKind::PowerShell => "Execute a PowerShell command and return its output. Use with caution.",
|
||||||
|
ShellKind::Cmd => "Execute a cmd shell command and return its output. Use with caution.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct BashTool {
|
pub struct BashTool {
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
working_dir: Option<String>,
|
working_dir: Option<String>,
|
||||||
deny_patterns: Vec<String>,
|
deny_patterns: Vec<String>,
|
||||||
|
shell: ShellKind,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BashTool {
|
impl BashTool {
|
||||||
@ -29,12 +110,8 @@ impl BashTool {
|
|||||||
Self {
|
Self {
|
||||||
timeout_secs: 60,
|
timeout_secs: 60,
|
||||||
working_dir: None,
|
working_dir: None,
|
||||||
deny_patterns: vec![
|
deny_patterns: dangerous_command_patterns(),
|
||||||
r"\brm\s+-[rf]{1,2}\b".to_string(),
|
shell: ShellKind::detect(),
|
||||||
r"\bdel\s+/[fq]\b".to_string(),
|
|
||||||
r"\brmdir\s+/s\b".to_string(),
|
|
||||||
r":\(\)\s*\{.*\};\s*:".to_string(),
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,6 +125,11 @@ impl BashTool {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_shell(mut self, shell: ShellKind) -> Self {
|
||||||
|
self.shell = shell;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
fn guard_command(&self, command: &str) -> Option<String> {
|
fn guard_command(&self, command: &str) -> Option<String> {
|
||||||
let lower = command.to_lowercase();
|
let lower = command.to_lowercase();
|
||||||
for pattern in &self.deny_patterns {
|
for pattern in &self.deny_patterns {
|
||||||
@ -138,11 +220,11 @@ impl Default for BashTool {
|
|||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Tool for BashTool {
|
impl Tool for BashTool {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
"bash"
|
self.shell.tool_name()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
"Execute a bash shell command and return its output. Use with caution."
|
self.shell.tool_description()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
@ -235,8 +317,8 @@ impl BashTool {
|
|||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interactive: bool,
|
interactive: bool,
|
||||||
) -> Result<String, String> {
|
) -> Result<String, String> {
|
||||||
let mut cmd = Command::new("bash");
|
let mut cmd = Command::new(self.shell.executable());
|
||||||
cmd.args(["-c", command])
|
cmd.args(self.shell.command_args(command))
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped())
|
.stderr(Stdio::piped())
|
||||||
.current_dir(cwd);
|
.current_dir(cwd);
|
||||||
@ -358,8 +440,13 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_simple_command() {
|
async fn test_simple_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
|
let command = if cfg!(target_os = "windows") {
|
||||||
|
"Write-Output 'Hello World'"
|
||||||
|
} else {
|
||||||
|
"echo 'Hello World'"
|
||||||
|
};
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({ "command": "echo 'Hello World'" }))
|
.execute(json!({ "command": command }))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -370,7 +457,12 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_pwd_command() {
|
async fn test_pwd_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
|
let command = if cfg!(target_os = "windows") {
|
||||||
|
"Get-Location"
|
||||||
|
} else {
|
||||||
|
"pwd"
|
||||||
|
};
|
||||||
|
let result = tool.execute(json!({ "command": command })).await.unwrap();
|
||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
}
|
}
|
||||||
@ -378,8 +470,14 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_ls_command() {
|
async fn test_ls_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
|
let temp_dir = std::env::temp_dir();
|
||||||
|
let command = if cfg!(target_os = "windows") {
|
||||||
|
format!("Get-ChildItem {}", temp_dir.display())
|
||||||
|
} else {
|
||||||
|
format!("ls -la {}", temp_dir.display())
|
||||||
|
};
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({ "command": "ls -la /tmp" }))
|
.execute(json!({ "command": command }))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -389,8 +487,22 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_dangerous_rm() {
|
async fn test_dangerous_rm() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
|
// 测试 Unix 危险命令模式
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({ "command": "rm -rf /" }))
|
.execute(json!({ "command": "rm -rf /some/path" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("blocked"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_dangerous_windows_commands() {
|
||||||
|
let tool = BashTool::new();
|
||||||
|
// 测试 Windows del 命令模式(正则应该匹配)
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "command": "del /f /q file.txt" }))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -422,9 +534,14 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timeout() {
|
async fn test_timeout() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
|
let command = if cfg!(target_os = "windows") {
|
||||||
|
"Start-Sleep -Seconds 10"
|
||||||
|
} else {
|
||||||
|
"sleep 10"
|
||||||
|
};
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({
|
.execute(json!({
|
||||||
"command": "sleep 10",
|
"command": command,
|
||||||
"timeout": 1
|
"timeout": 1
|
||||||
}))
|
}))
|
||||||
.await
|
.await
|
||||||
@ -437,17 +554,22 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_pending_user_action_detection() {
|
async fn test_pending_user_action_detection() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
|
let command = if cfg!(target_os = "windows") {
|
||||||
|
"Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
|
||||||
|
} else {
|
||||||
|
"printf 'waiting for authorization'; sleep 10"
|
||||||
|
};
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({
|
.execute(json!({
|
||||||
"command": "printf '在浏览器中打开以下链接进行认证:\n\nhttps://example.com/device/verify\n\n等待用户授权...\n'; sleep 10",
|
"command": command,
|
||||||
"timeout": 1
|
"timeout": 1,
|
||||||
|
"interactive": true
|
||||||
}))
|
}))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
assert!(result.output.contains(PENDING_USER_ACTION_MARKER));
|
assert!(result.output.contains(PENDING_USER_ACTION_MARKER));
|
||||||
assert!(result.output.contains("等待用户授权"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -460,4 +582,28 @@ mod tests {
|
|||||||
assert!(output.contains("chars truncated"));
|
assert!(output.contains("chars truncated"));
|
||||||
assert!(output.is_char_boundary(output.len()));
|
assert!(output.is_char_boundary(output.len()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_shell_kind_detect() {
|
||||||
|
let shell = ShellKind::detect();
|
||||||
|
if cfg!(target_os = "windows") {
|
||||||
|
assert_eq!(shell, ShellKind::PowerShell);
|
||||||
|
} else {
|
||||||
|
assert_eq!(shell, ShellKind::Bash);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_shell_kind_executable() {
|
||||||
|
assert_eq!(ShellKind::Bash.executable(), "bash");
|
||||||
|
assert_eq!(ShellKind::PowerShell.executable(), "powershell");
|
||||||
|
assert_eq!(ShellKind::Cmd.executable(), "cmd");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_shell_kind_command_args() {
|
||||||
|
assert_eq!(ShellKind::Bash.command_args("echo hello"), vec!["-c" as &str, "echo hello"]);
|
||||||
|
assert_eq!(ShellKind::PowerShell.command_args("echo hello"), vec!["-Command" as &str, "echo hello"]);
|
||||||
|
assert_eq!(ShellKind::Cmd.command_args("echo hello"), vec!["/C" as &str, "echo hello"]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -204,8 +204,11 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_write_missing_content() {
|
async fn test_write_missing_content() {
|
||||||
let tool = FileWriteTool::new();
|
let tool = FileWriteTool::new();
|
||||||
|
// 使用临时目录确保跨平台兼容
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let test_path = temp_dir.path().join("test.txt");
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({ "path": "/tmp/test.txt" }))
|
.execute(json!({ "path": test_path.to_str().unwrap() }))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@ -349,6 +349,7 @@ mod tests {
|
|||||||
|
|
||||||
struct HomeDirGuard {
|
struct HomeDirGuard {
|
||||||
previous: Option<OsString>,
|
previous: Option<OsString>,
|
||||||
|
previous_userprofile: Option<OsString>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CurrentDirGuard {
|
impl CurrentDirGuard {
|
||||||
@ -367,23 +368,32 @@ mod tests {
|
|||||||
|
|
||||||
impl HomeDirGuard {
|
impl HomeDirGuard {
|
||||||
fn enter(path: &Path) -> Self {
|
fn enter(path: &Path) -> Self {
|
||||||
let previous = std::env::var_os("HOME");
|
let home_backup = std::env::var_os("HOME");
|
||||||
|
let userprofile_backup = std::env::var_os("USERPROFILE");
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
std::env::set_var("HOME", path);
|
std::env::set_var("HOME", path);
|
||||||
|
std::env::set_var("USERPROFILE", path);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
previous: home_backup,
|
||||||
|
previous_userprofile: userprofile_backup,
|
||||||
}
|
}
|
||||||
Self { previous }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for HomeDirGuard {
|
impl Drop for HomeDirGuard {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
match &self.previous {
|
unsafe {
|
||||||
Some(value) => unsafe {
|
match &self.previous {
|
||||||
std::env::set_var("HOME", value);
|
Some(value) => std::env::set_var("HOME", value),
|
||||||
},
|
None => std::env::remove_var("HOME"),
|
||||||
None => unsafe {
|
}
|
||||||
std::env::remove_var("HOME");
|
match &self.previous_userprofile {
|
||||||
},
|
Some(value) => std::env::set_var("USERPROFILE", value),
|
||||||
|
None => std::env::remove_var("USERPROFILE"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user