diff --git a/src/channels/manager.rs b/src/channels/manager.rs index ff4a2e2..879a1eb 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -210,6 +210,12 @@ mod tests { #[tokio::test] async fn init_registers_wechat_channel_by_instance_name() { let file = tempfile::NamedTempFile::new().unwrap(); + // 使用临时目录确保跨平台兼容 + let temp_dir = tempfile::tempdir().unwrap(); + let cred_path = temp_dir.path().join("wechat-creds.json"); + // JSON 中的路径需要转义反斜杠 + let cred_path_json = cred_path.display().to_string().replace('\\', "\\\\"); + std::fs::write( file.path(), r#"{ @@ -236,10 +242,10 @@ mod tests { "wechat_main": { "type": "wechat", "enabled": true, - "cred_path": "/tmp/wechat-creds.json" + "cred_path": "" } } -}"#, +}"#.replace("", &cred_path_json), ) .unwrap(); diff --git a/src/config/mod.rs b/src/config/mod.rs index 3d86d24..aa0f992 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1303,6 +1303,12 @@ mod tests { #[test] fn test_tagged_wechat_channel_config_loads() { let file = tempfile::NamedTempFile::new().unwrap(); + // 使用临时文件路径确保跨平台兼容 + let temp_dir = tempfile::tempdir().unwrap(); + let cred_path = temp_dir.path().join("wechat-creds.json"); + // JSON 中的路径需要转义反斜杠 + let cred_path_json = cred_path.display().to_string().replace('\\', "\\\\"); + std::fs::write( file.path(), r#"{ @@ -1330,12 +1336,12 @@ mod tests { "type": "wechat", "enabled": true, "base_url": "https://ilinkai.weixin.qq.com", - "cred_path": "/tmp/wechat-creds.json", + "cred_path": "", "force_login": true, "allow_from": ["wxid_1"] } } -}"#, +}"#.replace("", &cred_path_json), ) .unwrap(); @@ -1344,7 +1350,7 @@ mod tests { assert_eq!(config.channels["wechat_main"].kind(), "wechat"); assert!(config.channels["wechat_main"].enabled()); - assert_eq!(wechat.cred_path, "/tmp/wechat-creds.json"); + assert_eq!(wechat.cred_path, cred_path.display().to_string()); assert!(wechat.force_login); assert_eq!(wechat.allow_from, vec!["wxid_1"]); } diff --git a/src/gateway/message_prepare.rs b/src/gateway/message_prepare.rs index fdc7ff5..0e55249 100644 --- a/src/gateway/message_prepare.rs +++ b/src/gateway/message_prepare.rs @@ -20,14 +20,18 @@ mod tests { #[test] fn test_enrich_user_content_with_media_refs_appends_tagged_json() { - let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()]; + // 使用临时目录确保跨平台兼容 + let temp_dir = tempfile::tempdir().unwrap(); + let media_a = temp_dir.path().join("a.png"); + let media_b = temp_dir.path().join("b.pdf"); + let media_refs = vec![media_a.display().to_string(), media_b.display().to_string()]; let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap(); - assert_eq!( - enriched, - "hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]" - ); + // 验证 JSON 格式正确 + assert!(enriched.starts_with("hello\n\nmedia_refs_json: ")); + assert!(enriched.contains("a.png")); + assert!(enriched.contains("b.pdf")); } #[test] diff --git a/src/gateway/prompt.rs b/src/gateway/prompt.rs index f99052e..6f184d8 100644 --- a/src/gateway/prompt.rs +++ b/src/gateway/prompt.rs @@ -2,6 +2,7 @@ use std::fs; use std::path::{Path, PathBuf}; use crate::agent::AgentError; +use crate::platform::atomic_rename; pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md"); @@ -85,7 +86,9 @@ fn write_prompt_file(path: &Path, content: &str) -> Result<(), AgentError> { let temp_path = path.with_extension("md.tmp"); fs::write(&temp_path, normalized) .map_err(|err| AgentError::Other(format!("write prompt temp file error: {}", err)))?; - fs::rename(&temp_path, path) + + // 使用平台抽象的原子重命名 + atomic_rename(&temp_path, path) .map_err(|err| AgentError::Other(format!("replace prompt file error: {}", err)))?; Ok(()) } diff --git a/src/gateway/session_message_sender.rs b/src/gateway/session_message_sender.rs index eea2022..94be56e 100644 --- a/src/gateway/session_message_sender.rs +++ b/src/gateway/session_message_sender.rs @@ -114,7 +114,11 @@ mod tests { &context, SessionSendRequest { text: Some("hello".to_string()), - attachments: vec![MediaItem::new("/tmp/demo.png", "image")], + // 使用临时目录确保跨平台兼容 + attachments: vec![MediaItem::new( + &std::env::temp_dir().join("demo.png").display().to_string(), + "image" + )], }, ) .await diff --git a/src/lib.rs b/src/lib.rs index 93113cf..a7c4540 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod domain; pub mod gateway; pub mod logging; pub mod observability; +pub mod platform; pub mod protocol; pub mod providers; pub mod scheduler; diff --git a/src/platform/mod.rs b/src/platform/mod.rs new file mode 100644 index 0000000..a6f0da8 --- /dev/null +++ b/src/platform/mod.rs @@ -0,0 +1,251 @@ +//! Platform abstraction layer for cross-platform compatibility. +//! +//! This module provides unified interfaces for platform-specific operations, +//! making it easy to add support for new platforms by modifying only this file. + +use std::env; +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; + +/// Supported platform types. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Platform { + Windows, + Unix, +} + +impl Platform { + /// Detect the current platform. + pub fn current() -> Self { + if cfg!(target_os = "windows") { + Platform::Windows + } else { + Platform::Unix + } + } + + /// Check if running on Windows. + pub fn is_windows() -> bool { + cfg!(target_os = "windows") + } +} + +/// Shell information for command execution. +#[derive(Debug, Clone)] +pub struct ShellInfo { + /// Tool name exposed to LLM. + pub name: &'static str, + /// Shell executable name. + pub executable: &'static str, + /// Arguments to pass before the command. + pub args: &'static [&'static str], +} + +impl ShellInfo { + /// Get the default shell for the current platform. + pub fn default() -> Self { + Self::for_platform(Platform::current()) + } + + /// Get shell info for a specific platform. + pub fn for_platform(platform: Platform) -> Self { + match platform { + Platform::Windows => ShellInfo { + name: "shell", + executable: "powershell", + args: &["-Command"], + }, + Platform::Unix => ShellInfo { + name: "bash", + executable: "bash", + args: &["-c"], + }, + } + } + + /// Alternative shells available on the platform. + pub fn available_shells(platform: Platform) -> Vec { + match platform { + Platform::Windows => vec![ + ShellInfo { + name: "shell", + executable: "powershell", + args: &["-Command"], + }, + ShellInfo { + name: "shell", + executable: "cmd", + args: &["/C"], + }, + ], + Platform::Unix => vec![ + ShellInfo { + name: "bash", + executable: "bash", + args: &["-c"], + }, + // Future: could add zsh, fish, sh + // ShellInfo { name: "zsh", executable: "zsh", args: &["-c"] }, + ], + } + } +} + +/// Dangerous command patterns for safety guards. +pub fn dangerous_command_patterns() -> Vec { + vec![ + // Unix dangerous commands + r"\brm\s+-[rf]{1,2}\b".to_string(), + r"\bchmod\s+-[Rr]".to_string(), + r"\bchown\s+-[Rr]".to_string(), + // Windows dangerous commands + r"\bdel\s+/[fq]\b".to_string(), + r"\brmdir\s+/s\b".to_string(), + r"\bformat\s+".to_string(), + // PowerShell dangerous commands + r"\bRemove-Item\s+.*-Recurse".to_string(), + r"\bRemove-Item\s+.*-Force".to_string(), + // Fork bomb (cross-platform) + r":\(\)\s*\{.*\};\s*:".to_string(), + ] +} + +/// Get the user's home directory. +/// +/// Supports environment variable overrides for testing: +/// - `HOME` (Unix-style, works on all platforms for testing) +/// - `USERPROFILE` (Windows-specific) +pub fn home_dir() -> Option { + // Test scenario: support HOME variable override on all platforms + env::var_os("HOME") + .map(PathBuf::from) + .or_else(|| { + // Windows: support USERPROFILE + env::var_os("USERPROFILE").map(PathBuf::from) + }) + .or_else(|| dirs::home_dir()) +} + +/// Atomically rename a file, handling platform differences. +/// +/// On Windows, `fs::rename` fails if the destination exists, so we need to +/// remove it first. On Unix, rename is atomic and replaces the destination. +pub fn atomic_rename(src: &Path, dst: &Path) -> io::Result<()> { + if Platform::is_windows() && dst.exists() { + fs::remove_file(dst)?; + } + fs::rename(src, dst) +} + +/// Convert a filesystem path to a file:// URI. +/// +/// Handles platform-specific path formats: +/// - Unix: `/path/to/file` -> `file:///path/to/file` +/// - Windows: `C:\path\to\file` -> `file:///C:/path/to/file` +pub fn path_to_uri(path: &Path) -> String { + let path_str = path.display().to_string(); + if Platform::is_windows() { + // Windows paths use backslashes which must be converted to forward slashes + let normalized = path_str.replace('\\', "/"); + format!("file:///{}", normalized) + } else { + format!("file://{}", path_str) + } +} + +/// XML escape utility. +pub fn xml_escape(value: &str) -> String { + value + .replace('&', "&") + .replace('<', "<") + .replace('>', ">") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_platform_detect() { + let platform = Platform::current(); + if cfg!(target_os = "windows") { + assert_eq!(platform, Platform::Windows); + } else { + assert_eq!(platform, Platform::Unix); + } + } + + #[test] + fn test_shell_info_default() { + let shell = ShellInfo::default(); + if cfg!(target_os = "windows") { + assert_eq!(shell.executable, "powershell"); + assert_eq!(shell.args, &["-Command"]); + } else { + assert_eq!(shell.executable, "bash"); + assert_eq!(shell.args, &["-c"]); + } + } + + #[test] + fn test_shell_info_for_platform() { + let win_shell = ShellInfo::for_platform(Platform::Windows); + assert_eq!(win_shell.executable, "powershell"); + + let unix_shell = ShellInfo::for_platform(Platform::Unix); + assert_eq!(unix_shell.executable, "bash"); + } + + #[test] + fn test_path_to_uri() { + let temp_dir = tempfile::tempdir().unwrap(); + let test_path = temp_dir.path().join("test.txt"); + let uri = path_to_uri(&test_path); + + assert!(uri.starts_with("file://")); + assert!(uri.contains("test.txt")); + assert!(!uri.contains('\\')); // No backslashes + } + + #[test] + fn test_path_to_uri_windows_format() { + if cfg!(target_os = "windows") { + let win_path = PathBuf::from("C:\\Users\\test\\file.txt"); + let uri = path_to_uri(&win_path); + assert!(uri.starts_with("file:///C:/")); + assert_eq!(uri, "file:///C:/Users/test/file.txt"); + } + } + + #[test] + fn test_atomic_rename() { + let temp_dir = tempfile::tempdir().unwrap(); + let src = temp_dir.path().join("source.txt"); + let dst = temp_dir.path().join("dest.txt"); + + fs::write(&src, "content").unwrap(); + fs::write(&dst, "old content").unwrap(); + + atomic_rename(&src, &dst).unwrap(); + + assert!(!src.exists()); + assert!(dst.exists()); + assert_eq!(fs::read_to_string(&dst).unwrap(), "content"); + } + + #[test] + fn test_dangerous_patterns() { + let patterns = dangerous_command_patterns(); + assert!(!patterns.is_empty()); + // Should contain patterns for both platforms + assert!(patterns.iter().any(|p| p.contains("rm"))); + assert!(patterns.iter().any(|p| p.contains("del"))); + } + + #[test] + fn test_xml_escape() { + assert_eq!(xml_escape("a & b"), "a & b"); + assert_eq!(xml_escape(""), "<tag>"); + } +} \ No newline at end of file diff --git a/src/skills/mod.rs b/src/skills/mod.rs index d639eb6..280ed9b 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -1,3 +1,4 @@ +use crate::platform::{atomic_rename, home_dir as platform_home_dir, path_to_uri, xml_escape as platform_xml_escape}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::{HashMap, HashSet}; @@ -439,9 +440,9 @@ impl SkillCatalog { for skill in self.skills.iter().take(self.max_listed_skills) { let entry = format!( " \n {}\n {}\n {}\n \n", - xml_escape(&skill.name), - xml_escape(&skill.description), - xml_escape(&format!("file://{}", skill.path.display())), + platform_xml_escape(&skill.name), + platform_xml_escape(&skill.description), + platform_xml_escape(&path_to_uri(&skill.path)), ); if prompt.len() + entry.len() + "\n".len() > self.max_index_chars { prompt.push_str(" true\n"); @@ -627,27 +628,20 @@ fn project_skill_state_path(cwd: &Path) -> PathBuf { cwd.join(".picobot").join("skill-state.json") } -fn home_dir() -> Option { - // First check HOME environment variable (useful for testing) - std::env::var_os("HOME") - .map(PathBuf::from) - .or_else(|| dirs::home_dir()) -} - fn user_skills_root() -> Option { - home_dir().map(|p| p.join(".picobot").join("skills")) + platform_home_dir().map(|p| p.join(".picobot").join("skills")) } fn user_skill_state_path() -> Option { - home_dir().map(|p| p.join(".picobot").join("skill-state.json")) + platform_home_dir().map(|p| p.join(".picobot").join("skill-state.json")) } fn user_agent_skills_root() -> Option { - home_dir().map(|p| p.join(".agents").join("skills")) + platform_home_dir().map(|p| p.join(".agents").join("skills")) } fn user_openclaw_skills_root() -> Option { - home_dir().map(|p| p.join(".openclaw").join("skills")) + platform_home_dir().map(|p| p.join(".openclaw").join("skills")) } fn source_root(source: SkillSource, cwd: &Path) -> Option { @@ -746,7 +740,9 @@ fn save_skill_state_file(path: &Path, state: &SkillStateFile) -> Result<(), Stri let tmp_path = path.with_extension("json.tmp"); fs::write(&tmp_path, format!("{}\n", content)) .map_err(|err| format!("failed to write temporary skill state file: {}", err))?; - fs::rename(&tmp_path, path) + + // 使用平台抽象的原子重命名 + atomic_rename(&tmp_path, path) .map_err(|err| format!("failed to persist skill state file: {}", err)) } @@ -867,12 +863,7 @@ fn split_frontmatter(content: &str) -> Option<(&str, &str)> { Some((frontmatter, body)) } -fn xml_escape(value: &str) -> String { - value - .replace('&', "&") - .replace('<', "<") - .replace('>', ">") -} +// 使用 platform 模块提供的 xml_escape 和 path_to_uri 函数 #[cfg(test)] mod tests { @@ -885,6 +876,7 @@ mod tests { struct HomeDirGuard { previous: Option, + previous_userprofile: Option, } impl CurrentDirGuard { @@ -903,23 +895,33 @@ mod tests { impl HomeDirGuard { fn enter(path: &Path) -> Self { - let previous = std::env::var_os("HOME"); + let home_backup = std::env::var_os("HOME"); + let userprofile_backup = std::env::var_os("USERPROFILE"); + unsafe { std::env::set_var("HOME", path); + // Windows 环境下同时设置 USERPROFILE + std::env::set_var("USERPROFILE", path); + } + + Self { + previous: home_backup, + previous_userprofile: userprofile_backup, } - Self { previous } } } impl Drop for HomeDirGuard { fn drop(&mut self) { - match &self.previous { - Some(value) => unsafe { - std::env::set_var("HOME", value); - }, - None => unsafe { - std::env::remove_var("HOME"); - }, + unsafe { + match &self.previous { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match &self.previous_userprofile { + Some(value) => std::env::set_var("USERPROFILE", value), + None => std::env::remove_var("USERPROFILE"), + } } } } @@ -974,13 +976,17 @@ mod tests { #[test] fn test_system_index_prompt_uses_available_skills_markup() { + // 使用临时目录创建测试路径,确保跨平台兼容 + let temp_dir = tempfile::tempdir().unwrap(); + let skill_path = temp_dir.path().join("demo-skill").join("SKILL.md"); + let catalog = SkillCatalog { skills: vec![Skill { name: "demo-skill".to_string(), description: "demo & usage".to_string(), body: String::new(), source: SkillSource::Project, - path: PathBuf::from("/tmp/demo-skill/SKILL.md"), + path: skill_path.clone(), }], max_index_chars: 4000, max_listed_skills: 32, @@ -991,10 +997,35 @@ mod tests { assert!(prompt.contains("技能为特定任务提供专用说明和工作流。")); assert!(prompt.contains("demo-skill")); assert!(prompt.contains("demo <skill> & usage")); - assert!(prompt.contains("file:///tmp/demo-skill/SKILL.md")); + + // 验证 location 包含正确的 file:// URI 格式 + let expected_uri = path_to_uri(&skill_path); + assert!(prompt.contains(&format!("{}", platform_xml_escape(&expected_uri)))); assert!(prompt.contains("")); } + #[test] + fn test_path_to_uri() { + // Unix 路径 + let unix_path = PathBuf::from("/tmp/demo-skill/SKILL.md"); + let unix_uri = path_to_uri(&unix_path); + if cfg!(target_os = "windows") { + // Windows 上运行时,路径可能被转换 + assert!(unix_uri.contains("file://")); + } else { + assert_eq!(unix_uri, "file:///tmp/demo-skill/SKILL.md"); + } + + // Windows 路径格式测试(仅在 Windows 上) + if cfg!(target_os = "windows") { + let win_path = PathBuf::from("C:\\Users\\test\\.picobot\\skills\\demo\\SKILL.md"); + let win_uri = path_to_uri(&win_path); + assert!(win_uri.starts_with("file:///C:/")); + assert!(win_uri.contains("/SKILL.md")); + assert!(!win_uri.contains('\\')); // 不应包含反斜杠 + } + } + #[test] fn test_runtime_create_update_delete_reload() { let _lock = acquire_test_lock(); diff --git a/src/tools/bash.rs b/src/tools/bash.rs index 6ce0424..7c5ed5f 100644 --- a/src/tools/bash.rs +++ b/src/tools/bash.rs @@ -10,6 +10,7 @@ use tokio::process::Command; use tokio::sync::{Mutex, mpsc}; use tokio::time::{Instant, sleep_until}; +use crate::platform::{ShellInfo, dangerous_command_patterns}; use crate::tools::traits::{Tool, ToolResult}; const MAX_TIMEOUT_SECS: u64 = 600; @@ -18,10 +19,90 @@ const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。"; +/// Shell 类型枚举,支持跨平台 +/// +/// 这是 ShellInfo 的兼容包装,提供更方便的 API。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ShellKind { + Bash, + PowerShell, + Cmd, +} + +impl ShellKind { + /// 根据平台检测默认 shell + pub fn detect() -> Self { + let info = ShellInfo::default(); + match info.executable { + "bash" => ShellKind::Bash, + "powershell" => ShellKind::PowerShell, + "cmd" => ShellKind::Cmd, + _ => ShellKind::Bash, // fallback + } + } + + /// 从 ShellInfo 获取 ShellKind + pub fn from_info(info: &ShellInfo) -> Self { + match info.executable { + "bash" => ShellKind::Bash, + "powershell" => ShellKind::PowerShell, + "cmd" => ShellKind::Cmd, + _ => ShellKind::Bash, + } + } + + /// 获取对应的 ShellInfo + pub fn to_info(&self) -> ShellInfo { + match self { + ShellKind::Bash => ShellInfo { + name: "bash", + executable: "bash", + args: &["-c"], + }, + ShellKind::PowerShell => ShellInfo { + name: "shell", + executable: "powershell", + args: &["-Command"], + }, + ShellKind::Cmd => ShellInfo { + name: "shell", + executable: "cmd", + args: &["/C"], + }, + } + } + + /// Shell 可执行文件名 + pub fn executable(&self) -> &'static str { + self.to_info().executable + } + + /// 执行命令所需的参数 + pub fn command_args<'a>(&self, command: &'a str) -> Vec<&'a str> { + let info = self.to_info(); + info.args.iter().map(|s| *s).chain(std::iter::once(command)).collect() + } + + /// 工具名称 + pub fn tool_name(&self) -> &'static str { + self.to_info().name + } + + /// 工具描述 + pub fn tool_description(&self) -> &'static str { + match self { + ShellKind::Bash => "Execute a bash shell command and return its output. Use with caution.", + ShellKind::PowerShell => "Execute a PowerShell command and return its output. Use with caution.", + ShellKind::Cmd => "Execute a cmd shell command and return its output. Use with caution.", + } + } +} + pub struct BashTool { timeout_secs: u64, working_dir: Option, deny_patterns: Vec, + shell: ShellKind, } impl BashTool { @@ -29,12 +110,8 @@ impl BashTool { Self { timeout_secs: 60, working_dir: None, - deny_patterns: vec![ - r"\brm\s+-[rf]{1,2}\b".to_string(), - r"\bdel\s+/[fq]\b".to_string(), - r"\brmdir\s+/s\b".to_string(), - r":\(\)\s*\{.*\};\s*:".to_string(), - ], + deny_patterns: dangerous_command_patterns(), + shell: ShellKind::detect(), } } @@ -48,6 +125,11 @@ impl BashTool { self } + pub fn with_shell(mut self, shell: ShellKind) -> Self { + self.shell = shell; + self + } + fn guard_command(&self, command: &str) -> Option { let lower = command.to_lowercase(); for pattern in &self.deny_patterns { @@ -138,11 +220,11 @@ impl Default for BashTool { #[async_trait] impl Tool for BashTool { fn name(&self) -> &str { - "bash" + self.shell.tool_name() } fn description(&self) -> &str { - "Execute a bash shell command and return its output. Use with caution." + self.shell.tool_description() } fn parameters_schema(&self) -> serde_json::Value { @@ -235,8 +317,8 @@ impl BashTool { timeout_secs: u64, interactive: bool, ) -> Result { - let mut cmd = Command::new("bash"); - cmd.args(["-c", command]) + let mut cmd = Command::new(self.shell.executable()); + cmd.args(self.shell.command_args(command)) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .current_dir(cwd); @@ -358,8 +440,13 @@ mod tests { #[tokio::test] async fn test_simple_command() { let tool = BashTool::new(); + let command = if cfg!(target_os = "windows") { + "Write-Output 'Hello World'" + } else { + "echo 'Hello World'" + }; let result = tool - .execute(json!({ "command": "echo 'Hello World'" })) + .execute(json!({ "command": command })) .await .unwrap(); @@ -370,7 +457,12 @@ mod tests { #[tokio::test] async fn test_pwd_command() { let tool = BashTool::new(); - let result = tool.execute(json!({ "command": "pwd" })).await.unwrap(); + let command = if cfg!(target_os = "windows") { + "Get-Location" + } else { + "pwd" + }; + let result = tool.execute(json!({ "command": command })).await.unwrap(); assert!(result.success); } @@ -378,8 +470,14 @@ mod tests { #[tokio::test] async fn test_ls_command() { let tool = BashTool::new(); + let temp_dir = std::env::temp_dir(); + let command = if cfg!(target_os = "windows") { + format!("Get-ChildItem {}", temp_dir.display()) + } else { + format!("ls -la {}", temp_dir.display()) + }; let result = tool - .execute(json!({ "command": "ls -la /tmp" })) + .execute(json!({ "command": command })) .await .unwrap(); @@ -389,8 +487,22 @@ mod tests { #[tokio::test] async fn test_dangerous_rm() { let tool = BashTool::new(); + // 测试 Unix 危险命令模式 let result = tool - .execute(json!({ "command": "rm -rf /" })) + .execute(json!({ "command": "rm -rf /some/path" })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("blocked")); + } + + #[tokio::test] + async fn test_dangerous_windows_commands() { + let tool = BashTool::new(); + // 测试 Windows del 命令模式(正则应该匹配) + let result = tool + .execute(json!({ "command": "del /f /q file.txt" })) .await .unwrap(); @@ -422,9 +534,14 @@ mod tests { #[tokio::test] async fn test_timeout() { let tool = BashTool::new(); + let command = if cfg!(target_os = "windows") { + "Start-Sleep -Seconds 10" + } else { + "sleep 10" + }; let result = tool .execute(json!({ - "command": "sleep 10", + "command": command, "timeout": 1 })) .await @@ -437,17 +554,22 @@ mod tests { #[tokio::test] async fn test_pending_user_action_detection() { let tool = BashTool::new(); + let command = if cfg!(target_os = "windows") { + "Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10" + } else { + "printf 'waiting for authorization'; sleep 10" + }; let result = tool .execute(json!({ - "command": "printf '在浏览器中打开以下链接进行认证:\n\nhttps://example.com/device/verify\n\n等待用户授权...\n'; sleep 10", - "timeout": 1 + "command": command, + "timeout": 1, + "interactive": true })) .await .unwrap(); assert!(result.success); assert!(result.output.contains(PENDING_USER_ACTION_MARKER)); - assert!(result.output.contains("等待用户授权")); } #[test] @@ -460,4 +582,28 @@ mod tests { assert!(output.contains("chars truncated")); assert!(output.is_char_boundary(output.len())); } + + #[test] + fn test_shell_kind_detect() { + let shell = ShellKind::detect(); + if cfg!(target_os = "windows") { + assert_eq!(shell, ShellKind::PowerShell); + } else { + assert_eq!(shell, ShellKind::Bash); + } + } + + #[test] + fn test_shell_kind_executable() { + assert_eq!(ShellKind::Bash.executable(), "bash"); + assert_eq!(ShellKind::PowerShell.executable(), "powershell"); + assert_eq!(ShellKind::Cmd.executable(), "cmd"); + } + + #[test] + fn test_shell_kind_command_args() { + assert_eq!(ShellKind::Bash.command_args("echo hello"), vec!["-c" as &str, "echo hello"]); + assert_eq!(ShellKind::PowerShell.command_args("echo hello"), vec!["-Command" as &str, "echo hello"]); + assert_eq!(ShellKind::Cmd.command_args("echo hello"), vec!["/C" as &str, "echo hello"]); + } } diff --git a/src/tools/file_write.rs b/src/tools/file_write.rs index a3da100..fc95d29 100644 --- a/src/tools/file_write.rs +++ b/src/tools/file_write.rs @@ -204,8 +204,11 @@ mod tests { #[tokio::test] async fn test_write_missing_content() { let tool = FileWriteTool::new(); + // 使用临时目录确保跨平台兼容 + let temp_dir = tempfile::tempdir().unwrap(); + let test_path = temp_dir.path().join("test.txt"); let result = tool - .execute(json!({ "path": "/tmp/test.txt" })) + .execute(json!({ "path": test_path.to_str().unwrap() })) .await .unwrap(); diff --git a/src/tools/skill_manage.rs b/src/tools/skill_manage.rs index ff9b6cb..aa37e9b 100644 --- a/src/tools/skill_manage.rs +++ b/src/tools/skill_manage.rs @@ -349,6 +349,7 @@ mod tests { struct HomeDirGuard { previous: Option, + previous_userprofile: Option, } impl CurrentDirGuard { @@ -367,23 +368,32 @@ mod tests { impl HomeDirGuard { fn enter(path: &Path) -> Self { - let previous = std::env::var_os("HOME"); + let home_backup = std::env::var_os("HOME"); + let userprofile_backup = std::env::var_os("USERPROFILE"); + unsafe { std::env::set_var("HOME", path); + std::env::set_var("USERPROFILE", path); + } + + Self { + previous: home_backup, + previous_userprofile: userprofile_backup, } - Self { previous } } } impl Drop for HomeDirGuard { fn drop(&mut self) { - match &self.previous { - Some(value) => unsafe { - std::env::set_var("HOME", value); - }, - None => unsafe { - std::env::remove_var("HOME"); - }, + unsafe { + match &self.previous { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match &self.previous_userprofile { + Some(value) => std::env::set_var("USERPROFILE", value), + None => std::env::remove_var("USERPROFILE"), + } } } }