feat: 添加平台抽象层,支持跨平台兼容性;更新多个模块以使用临时目录和平台特定路径

This commit is contained in:
oudecheng 2026-05-09 16:59:58 +08:00
parent e0a7f67dab
commit f4758f8513
11 changed files with 537 additions and 72 deletions

View File

@ -210,6 +210,12 @@ mod tests {
#[tokio::test]
async fn init_registers_wechat_channel_by_instance_name() {
let file = tempfile::NamedTempFile::new().unwrap();
// 使用临时目录确保跨平台兼容
let temp_dir = tempfile::tempdir().unwrap();
let cred_path = temp_dir.path().join("wechat-creds.json");
// JSON 中的路径需要转义反斜杠
let cred_path_json = cred_path.display().to_string().replace('\\', "\\\\");
std::fs::write(
file.path(),
r#"{
@ -236,10 +242,10 @@ mod tests {
"wechat_main": {
"type": "wechat",
"enabled": true,
"cred_path": "/tmp/wechat-creds.json"
"cred_path": "<CRED_PATH>"
}
}
}"#,
}"#.replace("<CRED_PATH>", &cred_path_json),
)
.unwrap();

View File

@ -1303,6 +1303,12 @@ mod tests {
#[test]
fn test_tagged_wechat_channel_config_loads() {
let file = tempfile::NamedTempFile::new().unwrap();
// 使用临时文件路径确保跨平台兼容
let temp_dir = tempfile::tempdir().unwrap();
let cred_path = temp_dir.path().join("wechat-creds.json");
// JSON 中的路径需要转义反斜杠
let cred_path_json = cred_path.display().to_string().replace('\\', "\\\\");
std::fs::write(
file.path(),
r#"{
@ -1330,12 +1336,12 @@ mod tests {
"type": "wechat",
"enabled": true,
"base_url": "https://ilinkai.weixin.qq.com",
"cred_path": "/tmp/wechat-creds.json",
"cred_path": "<CRED_PATH>",
"force_login": true,
"allow_from": ["wxid_1"]
}
}
}"#,
}"#.replace("<CRED_PATH>", &cred_path_json),
)
.unwrap();
@ -1344,7 +1350,7 @@ mod tests {
assert_eq!(config.channels["wechat_main"].kind(), "wechat");
assert!(config.channels["wechat_main"].enabled());
assert_eq!(wechat.cred_path, "/tmp/wechat-creds.json");
assert_eq!(wechat.cred_path, cred_path.display().to_string());
assert!(wechat.force_login);
assert_eq!(wechat.allow_from, vec!["wxid_1"]);
}

View File

@ -20,14 +20,18 @@ mod tests {
#[test]
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
// 使用临时目录确保跨平台兼容
let temp_dir = tempfile::tempdir().unwrap();
let media_a = temp_dir.path().join("a.png");
let media_b = temp_dir.path().join("b.pdf");
let media_refs = vec![media_a.display().to_string(), media_b.display().to_string()];
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
assert_eq!(
enriched,
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
);
// 验证 JSON 格式正确
assert!(enriched.starts_with("hello\n\nmedia_refs_json: "));
assert!(enriched.contains("a.png"));
assert!(enriched.contains("b.pdf"));
}
#[test]

View File

@ -2,6 +2,7 @@ use std::fs;
use std::path::{Path, PathBuf};
use crate::agent::AgentError;
use crate::platform::atomic_rename;
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
@ -85,7 +86,9 @@ fn write_prompt_file(path: &Path, content: &str) -> Result<(), AgentError> {
let temp_path = path.with_extension("md.tmp");
fs::write(&temp_path, normalized)
.map_err(|err| AgentError::Other(format!("write prompt temp file error: {}", err)))?;
fs::rename(&temp_path, path)
// 使用平台抽象的原子重命名
atomic_rename(&temp_path, path)
.map_err(|err| AgentError::Other(format!("replace prompt file error: {}", err)))?;
Ok(())
}

View File

@ -114,7 +114,11 @@ mod tests {
&context,
SessionSendRequest {
text: Some("hello".to_string()),
attachments: vec![MediaItem::new("/tmp/demo.png", "image")],
// 使用临时目录确保跨平台兼容
attachments: vec![MediaItem::new(
&std::env::temp_dir().join("demo.png").display().to_string(),
"image"
)],
},
)
.await

View File

@ -9,6 +9,7 @@ pub mod domain;
pub mod gateway;
pub mod logging;
pub mod observability;
pub mod platform;
pub mod protocol;
pub mod providers;
pub mod scheduler;

251
src/platform/mod.rs Normal file
View File

@ -0,0 +1,251 @@
//! Platform abstraction layer for cross-platform compatibility.
//!
//! This module provides unified interfaces for platform-specific operations,
//! making it easy to add support for new platforms by modifying only this file.
use std::env;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
/// Supported platform types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Platform {
Windows,
Unix,
}
impl Platform {
/// Detect the current platform.
pub fn current() -> Self {
if cfg!(target_os = "windows") {
Platform::Windows
} else {
Platform::Unix
}
}
/// Check if running on Windows.
pub fn is_windows() -> bool {
cfg!(target_os = "windows")
}
}
/// Shell information for command execution.
#[derive(Debug, Clone)]
pub struct ShellInfo {
/// Tool name exposed to LLM.
pub name: &'static str,
/// Shell executable name.
pub executable: &'static str,
/// Arguments to pass before the command.
pub args: &'static [&'static str],
}
impl ShellInfo {
/// Get the default shell for the current platform.
pub fn default() -> Self {
Self::for_platform(Platform::current())
}
/// Get shell info for a specific platform.
pub fn for_platform(platform: Platform) -> Self {
match platform {
Platform::Windows => ShellInfo {
name: "shell",
executable: "powershell",
args: &["-Command"],
},
Platform::Unix => ShellInfo {
name: "bash",
executable: "bash",
args: &["-c"],
},
}
}
/// Alternative shells available on the platform.
pub fn available_shells(platform: Platform) -> Vec<ShellInfo> {
match platform {
Platform::Windows => vec![
ShellInfo {
name: "shell",
executable: "powershell",
args: &["-Command"],
},
ShellInfo {
name: "shell",
executable: "cmd",
args: &["/C"],
},
],
Platform::Unix => vec![
ShellInfo {
name: "bash",
executable: "bash",
args: &["-c"],
},
// Future: could add zsh, fish, sh
// ShellInfo { name: "zsh", executable: "zsh", args: &["-c"] },
],
}
}
}
/// Dangerous command patterns for safety guards.
pub fn dangerous_command_patterns() -> Vec<String> {
vec![
// Unix dangerous commands
r"\brm\s+-[rf]{1,2}\b".to_string(),
r"\bchmod\s+-[Rr]".to_string(),
r"\bchown\s+-[Rr]".to_string(),
// Windows dangerous commands
r"\bdel\s+/[fq]\b".to_string(),
r"\brmdir\s+/s\b".to_string(),
r"\bformat\s+".to_string(),
// PowerShell dangerous commands
r"\bRemove-Item\s+.*-Recurse".to_string(),
r"\bRemove-Item\s+.*-Force".to_string(),
// Fork bomb (cross-platform)
r":\(\)\s*\{.*\};\s*:".to_string(),
]
}
/// Get the user's home directory.
///
/// Supports environment variable overrides for testing:
/// - `HOME` (Unix-style, works on all platforms for testing)
/// - `USERPROFILE` (Windows-specific)
pub fn home_dir() -> Option<PathBuf> {
// Test scenario: support HOME variable override on all platforms
env::var_os("HOME")
.map(PathBuf::from)
.or_else(|| {
// Windows: support USERPROFILE
env::var_os("USERPROFILE").map(PathBuf::from)
})
.or_else(|| dirs::home_dir())
}
/// Atomically rename a file, handling platform differences.
///
/// On Windows, `fs::rename` fails if the destination exists, so we need to
/// remove it first. On Unix, rename is atomic and replaces the destination.
pub fn atomic_rename(src: &Path, dst: &Path) -> io::Result<()> {
if Platform::is_windows() && dst.exists() {
fs::remove_file(dst)?;
}
fs::rename(src, dst)
}
/// Convert a filesystem path to a file:// URI.
///
/// Handles platform-specific path formats:
/// - Unix: `/path/to/file` -> `file:///path/to/file`
/// - Windows: `C:\path\to\file` -> `file:///C:/path/to/file`
pub fn path_to_uri(path: &Path) -> String {
let path_str = path.display().to_string();
if Platform::is_windows() {
// Windows paths use backslashes which must be converted to forward slashes
let normalized = path_str.replace('\\', "/");
format!("file:///{}", normalized)
} else {
format!("file://{}", path_str)
}
}
/// XML escape utility.
pub fn xml_escape(value: &str) -> String {
value
.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_platform_detect() {
let platform = Platform::current();
if cfg!(target_os = "windows") {
assert_eq!(platform, Platform::Windows);
} else {
assert_eq!(platform, Platform::Unix);
}
}
#[test]
fn test_shell_info_default() {
let shell = ShellInfo::default();
if cfg!(target_os = "windows") {
assert_eq!(shell.executable, "powershell");
assert_eq!(shell.args, &["-Command"]);
} else {
assert_eq!(shell.executable, "bash");
assert_eq!(shell.args, &["-c"]);
}
}
#[test]
fn test_shell_info_for_platform() {
let win_shell = ShellInfo::for_platform(Platform::Windows);
assert_eq!(win_shell.executable, "powershell");
let unix_shell = ShellInfo::for_platform(Platform::Unix);
assert_eq!(unix_shell.executable, "bash");
}
#[test]
fn test_path_to_uri() {
let temp_dir = tempfile::tempdir().unwrap();
let test_path = temp_dir.path().join("test.txt");
let uri = path_to_uri(&test_path);
assert!(uri.starts_with("file://"));
assert!(uri.contains("test.txt"));
assert!(!uri.contains('\\')); // No backslashes
}
#[test]
fn test_path_to_uri_windows_format() {
if cfg!(target_os = "windows") {
let win_path = PathBuf::from("C:\\Users\\test\\file.txt");
let uri = path_to_uri(&win_path);
assert!(uri.starts_with("file:///C:/"));
assert_eq!(uri, "file:///C:/Users/test/file.txt");
}
}
#[test]
fn test_atomic_rename() {
let temp_dir = tempfile::tempdir().unwrap();
let src = temp_dir.path().join("source.txt");
let dst = temp_dir.path().join("dest.txt");
fs::write(&src, "content").unwrap();
fs::write(&dst, "old content").unwrap();
atomic_rename(&src, &dst).unwrap();
assert!(!src.exists());
assert!(dst.exists());
assert_eq!(fs::read_to_string(&dst).unwrap(), "content");
}
#[test]
fn test_dangerous_patterns() {
let patterns = dangerous_command_patterns();
assert!(!patterns.is_empty());
// Should contain patterns for both platforms
assert!(patterns.iter().any(|p| p.contains("rm")));
assert!(patterns.iter().any(|p| p.contains("del")));
}
#[test]
fn test_xml_escape() {
assert_eq!(xml_escape("a & b"), "a &amp; b");
assert_eq!(xml_escape("<tag>"), "&lt;tag&gt;");
}
}

View File

@ -1,3 +1,4 @@
use crate::platform::{atomic_rename, home_dir as platform_home_dir, path_to_uri, xml_escape as platform_xml_escape};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::{HashMap, HashSet};
@ -439,9 +440,9 @@ impl SkillCatalog {
for skill in self.skills.iter().take(self.max_listed_skills) {
let entry = format!(
" <skill>\n <name>{}</name>\n <description>{}</description>\n <location>{}</location>\n </skill>\n",
xml_escape(&skill.name),
xml_escape(&skill.description),
xml_escape(&format!("file://{}", skill.path.display())),
platform_xml_escape(&skill.name),
platform_xml_escape(&skill.description),
platform_xml_escape(&path_to_uri(&skill.path)),
);
if prompt.len() + entry.len() + "</available_skills>\n".len() > self.max_index_chars {
prompt.push_str(" <truncated>true</truncated>\n");
@ -627,27 +628,20 @@ fn project_skill_state_path(cwd: &Path) -> PathBuf {
cwd.join(".picobot").join("skill-state.json")
}
fn home_dir() -> Option<PathBuf> {
// First check HOME environment variable (useful for testing)
std::env::var_os("HOME")
.map(PathBuf::from)
.or_else(|| dirs::home_dir())
}
fn user_skills_root() -> Option<PathBuf> {
home_dir().map(|p| p.join(".picobot").join("skills"))
platform_home_dir().map(|p| p.join(".picobot").join("skills"))
}
fn user_skill_state_path() -> Option<PathBuf> {
home_dir().map(|p| p.join(".picobot").join("skill-state.json"))
platform_home_dir().map(|p| p.join(".picobot").join("skill-state.json"))
}
fn user_agent_skills_root() -> Option<PathBuf> {
home_dir().map(|p| p.join(".agents").join("skills"))
platform_home_dir().map(|p| p.join(".agents").join("skills"))
}
fn user_openclaw_skills_root() -> Option<PathBuf> {
home_dir().map(|p| p.join(".openclaw").join("skills"))
platform_home_dir().map(|p| p.join(".openclaw").join("skills"))
}
fn source_root(source: SkillSource, cwd: &Path) -> Option<PathBuf> {
@ -746,7 +740,9 @@ fn save_skill_state_file(path: &Path, state: &SkillStateFile) -> Result<(), Stri
let tmp_path = path.with_extension("json.tmp");
fs::write(&tmp_path, format!("{}\n", content))
.map_err(|err| format!("failed to write temporary skill state file: {}", err))?;
fs::rename(&tmp_path, path)
// 使用平台抽象的原子重命名
atomic_rename(&tmp_path, path)
.map_err(|err| format!("failed to persist skill state file: {}", err))
}
@ -867,12 +863,7 @@ fn split_frontmatter(content: &str) -> Option<(&str, &str)> {
Some((frontmatter, body))
}
fn xml_escape(value: &str) -> String {
value
.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
}
// 使用 platform 模块提供的 xml_escape 和 path_to_uri 函数
#[cfg(test)]
mod tests {
@ -885,6 +876,7 @@ mod tests {
struct HomeDirGuard {
previous: Option<OsString>,
previous_userprofile: Option<OsString>,
}
impl CurrentDirGuard {
@ -903,23 +895,33 @@ mod tests {
impl HomeDirGuard {
fn enter(path: &Path) -> Self {
let previous = std::env::var_os("HOME");
let home_backup = std::env::var_os("HOME");
let userprofile_backup = std::env::var_os("USERPROFILE");
unsafe {
std::env::set_var("HOME", path);
// Windows 环境下同时设置 USERPROFILE
std::env::set_var("USERPROFILE", path);
}
Self {
previous: home_backup,
previous_userprofile: userprofile_backup,
}
Self { previous }
}
}
impl Drop for HomeDirGuard {
fn drop(&mut self) {
unsafe {
match &self.previous {
Some(value) => unsafe {
std::env::set_var("HOME", value);
},
None => unsafe {
std::env::remove_var("HOME");
},
Some(value) => std::env::set_var("HOME", value),
None => std::env::remove_var("HOME"),
}
match &self.previous_userprofile {
Some(value) => std::env::set_var("USERPROFILE", value),
None => std::env::remove_var("USERPROFILE"),
}
}
}
}
@ -974,13 +976,17 @@ mod tests {
#[test]
fn test_system_index_prompt_uses_available_skills_markup() {
// 使用临时目录创建测试路径,确保跨平台兼容
let temp_dir = tempfile::tempdir().unwrap();
let skill_path = temp_dir.path().join("demo-skill").join("SKILL.md");
let catalog = SkillCatalog {
skills: vec![Skill {
name: "demo-skill".to_string(),
description: "demo <skill> & usage".to_string(),
body: String::new(),
source: SkillSource::Project,
path: PathBuf::from("/tmp/demo-skill/SKILL.md"),
path: skill_path.clone(),
}],
max_index_chars: 4000,
max_listed_skills: 32,
@ -991,10 +997,35 @@ mod tests {
assert!(prompt.contains("技能为特定任务提供专用说明和工作流。"));
assert!(prompt.contains("<name>demo-skill</name>"));
assert!(prompt.contains("<description>demo &lt;skill&gt; &amp; usage</description>"));
assert!(prompt.contains("<location>file:///tmp/demo-skill/SKILL.md</location>"));
// 验证 location 包含正确的 file:// URI 格式
let expected_uri = path_to_uri(&skill_path);
assert!(prompt.contains(&format!("<location>{}</location>", platform_xml_escape(&expected_uri))));
assert!(prompt.contains("</available_skills>"));
}
#[test]
fn test_path_to_uri() {
// Unix 路径
let unix_path = PathBuf::from("/tmp/demo-skill/SKILL.md");
let unix_uri = path_to_uri(&unix_path);
if cfg!(target_os = "windows") {
// Windows 上运行时,路径可能被转换
assert!(unix_uri.contains("file://"));
} else {
assert_eq!(unix_uri, "file:///tmp/demo-skill/SKILL.md");
}
// Windows 路径格式测试(仅在 Windows 上)
if cfg!(target_os = "windows") {
let win_path = PathBuf::from("C:\\Users\\test\\.picobot\\skills\\demo\\SKILL.md");
let win_uri = path_to_uri(&win_path);
assert!(win_uri.starts_with("file:///C:/"));
assert!(win_uri.contains("/SKILL.md"));
assert!(!win_uri.contains('\\')); // 不应包含反斜杠
}
}
#[test]
fn test_runtime_create_update_delete_reload() {
let _lock = acquire_test_lock();

View File

@ -10,6 +10,7 @@ use tokio::process::Command;
use tokio::sync::{Mutex, mpsc};
use tokio::time::{Instant, sleep_until};
use crate::platform::{ShellInfo, dangerous_command_patterns};
use crate::tools::traits::{Tool, ToolResult};
const MAX_TIMEOUT_SECS: u64 = 600;
@ -18,10 +19,90 @@ const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
const USER_ACTION_HINT: &str =
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
/// Shell 类型枚举,支持跨平台
///
/// 这是 ShellInfo 的兼容包装,提供更方便的 API。
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShellKind {
Bash,
PowerShell,
Cmd,
}
impl ShellKind {
/// 根据平台检测默认 shell
pub fn detect() -> Self {
let info = ShellInfo::default();
match info.executable {
"bash" => ShellKind::Bash,
"powershell" => ShellKind::PowerShell,
"cmd" => ShellKind::Cmd,
_ => ShellKind::Bash, // fallback
}
}
/// 从 ShellInfo 获取 ShellKind
pub fn from_info(info: &ShellInfo) -> Self {
match info.executable {
"bash" => ShellKind::Bash,
"powershell" => ShellKind::PowerShell,
"cmd" => ShellKind::Cmd,
_ => ShellKind::Bash,
}
}
/// 获取对应的 ShellInfo
pub fn to_info(&self) -> ShellInfo {
match self {
ShellKind::Bash => ShellInfo {
name: "bash",
executable: "bash",
args: &["-c"],
},
ShellKind::PowerShell => ShellInfo {
name: "shell",
executable: "powershell",
args: &["-Command"],
},
ShellKind::Cmd => ShellInfo {
name: "shell",
executable: "cmd",
args: &["/C"],
},
}
}
/// Shell 可执行文件名
pub fn executable(&self) -> &'static str {
self.to_info().executable
}
/// 执行命令所需的参数
pub fn command_args<'a>(&self, command: &'a str) -> Vec<&'a str> {
let info = self.to_info();
info.args.iter().map(|s| *s).chain(std::iter::once(command)).collect()
}
/// 工具名称
pub fn tool_name(&self) -> &'static str {
self.to_info().name
}
/// 工具描述
pub fn tool_description(&self) -> &'static str {
match self {
ShellKind::Bash => "Execute a bash shell command and return its output. Use with caution.",
ShellKind::PowerShell => "Execute a PowerShell command and return its output. Use with caution.",
ShellKind::Cmd => "Execute a cmd shell command and return its output. Use with caution.",
}
}
}
pub struct BashTool {
timeout_secs: u64,
working_dir: Option<String>,
deny_patterns: Vec<String>,
shell: ShellKind,
}
impl BashTool {
@ -29,12 +110,8 @@ impl BashTool {
Self {
timeout_secs: 60,
working_dir: None,
deny_patterns: vec![
r"\brm\s+-[rf]{1,2}\b".to_string(),
r"\bdel\s+/[fq]\b".to_string(),
r"\brmdir\s+/s\b".to_string(),
r":\(\)\s*\{.*\};\s*:".to_string(),
],
deny_patterns: dangerous_command_patterns(),
shell: ShellKind::detect(),
}
}
@ -48,6 +125,11 @@ impl BashTool {
self
}
pub fn with_shell(mut self, shell: ShellKind) -> Self {
self.shell = shell;
self
}
fn guard_command(&self, command: &str) -> Option<String> {
let lower = command.to_lowercase();
for pattern in &self.deny_patterns {
@ -138,11 +220,11 @@ impl Default for BashTool {
#[async_trait]
impl Tool for BashTool {
fn name(&self) -> &str {
"bash"
self.shell.tool_name()
}
fn description(&self) -> &str {
"Execute a bash shell command and return its output. Use with caution."
self.shell.tool_description()
}
fn parameters_schema(&self) -> serde_json::Value {
@ -235,8 +317,8 @@ impl BashTool {
timeout_secs: u64,
interactive: bool,
) -> Result<String, String> {
let mut cmd = Command::new("bash");
cmd.args(["-c", command])
let mut cmd = Command::new(self.shell.executable());
cmd.args(self.shell.command_args(command))
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(cwd);
@ -358,8 +440,13 @@ mod tests {
#[tokio::test]
async fn test_simple_command() {
let tool = BashTool::new();
let command = if cfg!(target_os = "windows") {
"Write-Output 'Hello World'"
} else {
"echo 'Hello World'"
};
let result = tool
.execute(json!({ "command": "echo 'Hello World'" }))
.execute(json!({ "command": command }))
.await
.unwrap();
@ -370,7 +457,12 @@ mod tests {
#[tokio::test]
async fn test_pwd_command() {
let tool = BashTool::new();
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
let command = if cfg!(target_os = "windows") {
"Get-Location"
} else {
"pwd"
};
let result = tool.execute(json!({ "command": command })).await.unwrap();
assert!(result.success);
}
@ -378,8 +470,14 @@ mod tests {
#[tokio::test]
async fn test_ls_command() {
let tool = BashTool::new();
let temp_dir = std::env::temp_dir();
let command = if cfg!(target_os = "windows") {
format!("Get-ChildItem {}", temp_dir.display())
} else {
format!("ls -la {}", temp_dir.display())
};
let result = tool
.execute(json!({ "command": "ls -la /tmp" }))
.execute(json!({ "command": command }))
.await
.unwrap();
@ -389,8 +487,22 @@ mod tests {
#[tokio::test]
async fn test_dangerous_rm() {
let tool = BashTool::new();
// 测试 Unix 危险命令模式
let result = tool
.execute(json!({ "command": "rm -rf /" }))
.execute(json!({ "command": "rm -rf /some/path" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("blocked"));
}
#[tokio::test]
async fn test_dangerous_windows_commands() {
let tool = BashTool::new();
// 测试 Windows del 命令模式(正则应该匹配)
let result = tool
.execute(json!({ "command": "del /f /q file.txt" }))
.await
.unwrap();
@ -422,9 +534,14 @@ mod tests {
#[tokio::test]
async fn test_timeout() {
let tool = BashTool::new();
let command = if cfg!(target_os = "windows") {
"Start-Sleep -Seconds 10"
} else {
"sleep 10"
};
let result = tool
.execute(json!({
"command": "sleep 10",
"command": command,
"timeout": 1
}))
.await
@ -437,17 +554,22 @@ mod tests {
#[tokio::test]
async fn test_pending_user_action_detection() {
let tool = BashTool::new();
let command = if cfg!(target_os = "windows") {
"Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
} else {
"printf 'waiting for authorization'; sleep 10"
};
let result = tool
.execute(json!({
"command": "printf '在浏览器中打开以下链接进行认证:\n\nhttps://example.com/device/verify\n\n等待用户授权...\n'; sleep 10",
"timeout": 1
"command": command,
"timeout": 1,
"interactive": true
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains(PENDING_USER_ACTION_MARKER));
assert!(result.output.contains("等待用户授权"));
}
#[test]
@ -460,4 +582,28 @@ mod tests {
assert!(output.contains("chars truncated"));
assert!(output.is_char_boundary(output.len()));
}
#[test]
fn test_shell_kind_detect() {
let shell = ShellKind::detect();
if cfg!(target_os = "windows") {
assert_eq!(shell, ShellKind::PowerShell);
} else {
assert_eq!(shell, ShellKind::Bash);
}
}
#[test]
fn test_shell_kind_executable() {
assert_eq!(ShellKind::Bash.executable(), "bash");
assert_eq!(ShellKind::PowerShell.executable(), "powershell");
assert_eq!(ShellKind::Cmd.executable(), "cmd");
}
#[test]
fn test_shell_kind_command_args() {
assert_eq!(ShellKind::Bash.command_args("echo hello"), vec!["-c" as &str, "echo hello"]);
assert_eq!(ShellKind::PowerShell.command_args("echo hello"), vec!["-Command" as &str, "echo hello"]);
assert_eq!(ShellKind::Cmd.command_args("echo hello"), vec!["/C" as &str, "echo hello"]);
}
}

View File

@ -204,8 +204,11 @@ mod tests {
#[tokio::test]
async fn test_write_missing_content() {
let tool = FileWriteTool::new();
// 使用临时目录确保跨平台兼容
let temp_dir = tempfile::tempdir().unwrap();
let test_path = temp_dir.path().join("test.txt");
let result = tool
.execute(json!({ "path": "/tmp/test.txt" }))
.execute(json!({ "path": test_path.to_str().unwrap() }))
.await
.unwrap();

View File

@ -349,6 +349,7 @@ mod tests {
struct HomeDirGuard {
previous: Option<OsString>,
previous_userprofile: Option<OsString>,
}
impl CurrentDirGuard {
@ -367,23 +368,32 @@ mod tests {
impl HomeDirGuard {
fn enter(path: &Path) -> Self {
let previous = std::env::var_os("HOME");
let home_backup = std::env::var_os("HOME");
let userprofile_backup = std::env::var_os("USERPROFILE");
unsafe {
std::env::set_var("HOME", path);
std::env::set_var("USERPROFILE", path);
}
Self {
previous: home_backup,
previous_userprofile: userprofile_backup,
}
Self { previous }
}
}
impl Drop for HomeDirGuard {
fn drop(&mut self) {
unsafe {
match &self.previous {
Some(value) => unsafe {
std::env::set_var("HOME", value);
},
None => unsafe {
std::env::remove_var("HOME");
},
Some(value) => std::env::set_var("HOME", value),
None => std::env::remove_var("HOME"),
}
match &self.previous_userprofile {
Some(value) => std::env::set_var("USERPROFILE", value),
None => std::env::remove_var("USERPROFILE"),
}
}
}
}