feat: 添加平台抽象层,支持跨平台兼容性;更新多个模块以使用临时目录和平台特定路径
This commit is contained in:
parent
e0a7f67dab
commit
f4758f8513
@ -210,6 +210,12 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn init_registers_wechat_channel_by_instance_name() {
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
// 使用临时目录确保跨平台兼容
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let cred_path = temp_dir.path().join("wechat-creds.json");
|
||||
// JSON 中的路径需要转义反斜杠
|
||||
let cred_path_json = cred_path.display().to_string().replace('\\', "\\\\");
|
||||
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
r#"{
|
||||
@ -236,10 +242,10 @@ mod tests {
|
||||
"wechat_main": {
|
||||
"type": "wechat",
|
||||
"enabled": true,
|
||||
"cred_path": "/tmp/wechat-creds.json"
|
||||
"cred_path": "<CRED_PATH>"
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
}"#.replace("<CRED_PATH>", &cred_path_json),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@ -1303,6 +1303,12 @@ mod tests {
|
||||
#[test]
|
||||
fn test_tagged_wechat_channel_config_loads() {
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
// 使用临时文件路径确保跨平台兼容
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let cred_path = temp_dir.path().join("wechat-creds.json");
|
||||
// JSON 中的路径需要转义反斜杠
|
||||
let cred_path_json = cred_path.display().to_string().replace('\\', "\\\\");
|
||||
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
r#"{
|
||||
@ -1330,12 +1336,12 @@ mod tests {
|
||||
"type": "wechat",
|
||||
"enabled": true,
|
||||
"base_url": "https://ilinkai.weixin.qq.com",
|
||||
"cred_path": "/tmp/wechat-creds.json",
|
||||
"cred_path": "<CRED_PATH>",
|
||||
"force_login": true,
|
||||
"allow_from": ["wxid_1"]
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
}"#.replace("<CRED_PATH>", &cred_path_json),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1344,7 +1350,7 @@ mod tests {
|
||||
|
||||
assert_eq!(config.channels["wechat_main"].kind(), "wechat");
|
||||
assert!(config.channels["wechat_main"].enabled());
|
||||
assert_eq!(wechat.cred_path, "/tmp/wechat-creds.json");
|
||||
assert_eq!(wechat.cred_path, cred_path.display().to_string());
|
||||
assert!(wechat.force_login);
|
||||
assert_eq!(wechat.allow_from, vec!["wxid_1"]);
|
||||
}
|
||||
|
||||
@ -20,14 +20,18 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
||||
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
||||
// 使用临时目录确保跨平台兼容
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let media_a = temp_dir.path().join("a.png");
|
||||
let media_b = temp_dir.path().join("b.pdf");
|
||||
let media_refs = vec![media_a.display().to_string(), media_b.display().to_string()];
|
||||
|
||||
let enriched = enrich_user_content_with_media_refs("hello", &media_refs).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
enriched,
|
||||
"hello\n\nmedia_refs_json: [\"/tmp/a.png\",\"/tmp/b.pdf\"]"
|
||||
);
|
||||
// 验证 JSON 格式正确
|
||||
assert!(enriched.starts_with("hello\n\nmedia_refs_json: "));
|
||||
assert!(enriched.contains("a.png"));
|
||||
assert!(enriched.contains("b.pdf"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -2,6 +2,7 @@ use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::platform::atomic_rename;
|
||||
|
||||
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
||||
|
||||
@ -85,7 +86,9 @@ fn write_prompt_file(path: &Path, content: &str) -> Result<(), AgentError> {
|
||||
let temp_path = path.with_extension("md.tmp");
|
||||
fs::write(&temp_path, normalized)
|
||||
.map_err(|err| AgentError::Other(format!("write prompt temp file error: {}", err)))?;
|
||||
fs::rename(&temp_path, path)
|
||||
|
||||
// 使用平台抽象的原子重命名
|
||||
atomic_rename(&temp_path, path)
|
||||
.map_err(|err| AgentError::Other(format!("replace prompt file error: {}", err)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -114,7 +114,11 @@ mod tests {
|
||||
&context,
|
||||
SessionSendRequest {
|
||||
text: Some("hello".to_string()),
|
||||
attachments: vec![MediaItem::new("/tmp/demo.png", "image")],
|
||||
// 使用临时目录确保跨平台兼容
|
||||
attachments: vec![MediaItem::new(
|
||||
&std::env::temp_dir().join("demo.png").display().to_string(),
|
||||
"image"
|
||||
)],
|
||||
},
|
||||
)
|
||||
.await
|
||||
|
||||
@ -9,6 +9,7 @@ pub mod domain;
|
||||
pub mod gateway;
|
||||
pub mod logging;
|
||||
pub mod observability;
|
||||
pub mod platform;
|
||||
pub mod protocol;
|
||||
pub mod providers;
|
||||
pub mod scheduler;
|
||||
|
||||
251
src/platform/mod.rs
Normal file
251
src/platform/mod.rs
Normal file
@ -0,0 +1,251 @@
|
||||
//! Platform abstraction layer for cross-platform compatibility.
|
||||
//!
|
||||
//! This module provides unified interfaces for platform-specific operations,
|
||||
//! making it easy to add support for new platforms by modifying only this file.
|
||||
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Supported platform types.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Platform {
|
||||
Windows,
|
||||
Unix,
|
||||
}
|
||||
|
||||
impl Platform {
|
||||
/// Detect the current platform.
|
||||
pub fn current() -> Self {
|
||||
if cfg!(target_os = "windows") {
|
||||
Platform::Windows
|
||||
} else {
|
||||
Platform::Unix
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if running on Windows.
|
||||
pub fn is_windows() -> bool {
|
||||
cfg!(target_os = "windows")
|
||||
}
|
||||
}
|
||||
|
||||
/// Shell information for command execution.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ShellInfo {
|
||||
/// Tool name exposed to LLM.
|
||||
pub name: &'static str,
|
||||
/// Shell executable name.
|
||||
pub executable: &'static str,
|
||||
/// Arguments to pass before the command.
|
||||
pub args: &'static [&'static str],
|
||||
}
|
||||
|
||||
impl ShellInfo {
|
||||
/// Get the default shell for the current platform.
|
||||
pub fn default() -> Self {
|
||||
Self::for_platform(Platform::current())
|
||||
}
|
||||
|
||||
/// Get shell info for a specific platform.
|
||||
pub fn for_platform(platform: Platform) -> Self {
|
||||
match platform {
|
||||
Platform::Windows => ShellInfo {
|
||||
name: "shell",
|
||||
executable: "powershell",
|
||||
args: &["-Command"],
|
||||
},
|
||||
Platform::Unix => ShellInfo {
|
||||
name: "bash",
|
||||
executable: "bash",
|
||||
args: &["-c"],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Alternative shells available on the platform.
|
||||
pub fn available_shells(platform: Platform) -> Vec<ShellInfo> {
|
||||
match platform {
|
||||
Platform::Windows => vec![
|
||||
ShellInfo {
|
||||
name: "shell",
|
||||
executable: "powershell",
|
||||
args: &["-Command"],
|
||||
},
|
||||
ShellInfo {
|
||||
name: "shell",
|
||||
executable: "cmd",
|
||||
args: &["/C"],
|
||||
},
|
||||
],
|
||||
Platform::Unix => vec![
|
||||
ShellInfo {
|
||||
name: "bash",
|
||||
executable: "bash",
|
||||
args: &["-c"],
|
||||
},
|
||||
// Future: could add zsh, fish, sh
|
||||
// ShellInfo { name: "zsh", executable: "zsh", args: &["-c"] },
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dangerous command patterns for safety guards.
|
||||
pub fn dangerous_command_patterns() -> Vec<String> {
|
||||
vec![
|
||||
// Unix dangerous commands
|
||||
r"\brm\s+-[rf]{1,2}\b".to_string(),
|
||||
r"\bchmod\s+-[Rr]".to_string(),
|
||||
r"\bchown\s+-[Rr]".to_string(),
|
||||
// Windows dangerous commands
|
||||
r"\bdel\s+/[fq]\b".to_string(),
|
||||
r"\brmdir\s+/s\b".to_string(),
|
||||
r"\bformat\s+".to_string(),
|
||||
// PowerShell dangerous commands
|
||||
r"\bRemove-Item\s+.*-Recurse".to_string(),
|
||||
r"\bRemove-Item\s+.*-Force".to_string(),
|
||||
// Fork bomb (cross-platform)
|
||||
r":\(\)\s*\{.*\};\s*:".to_string(),
|
||||
]
|
||||
}
|
||||
|
||||
/// Get the user's home directory.
|
||||
///
|
||||
/// Supports environment variable overrides for testing:
|
||||
/// - `HOME` (Unix-style, works on all platforms for testing)
|
||||
/// - `USERPROFILE` (Windows-specific)
|
||||
pub fn home_dir() -> Option<PathBuf> {
|
||||
// Test scenario: support HOME variable override on all platforms
|
||||
env::var_os("HOME")
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| {
|
||||
// Windows: support USERPROFILE
|
||||
env::var_os("USERPROFILE").map(PathBuf::from)
|
||||
})
|
||||
.or_else(|| dirs::home_dir())
|
||||
}
|
||||
|
||||
/// Atomically rename a file, handling platform differences.
|
||||
///
|
||||
/// On Windows, `fs::rename` fails if the destination exists, so we need to
|
||||
/// remove it first. On Unix, rename is atomic and replaces the destination.
|
||||
pub fn atomic_rename(src: &Path, dst: &Path) -> io::Result<()> {
|
||||
if Platform::is_windows() && dst.exists() {
|
||||
fs::remove_file(dst)?;
|
||||
}
|
||||
fs::rename(src, dst)
|
||||
}
|
||||
|
||||
/// Convert a filesystem path to a file:// URI.
|
||||
///
|
||||
/// Handles platform-specific path formats:
|
||||
/// - Unix: `/path/to/file` -> `file:///path/to/file`
|
||||
/// - Windows: `C:\path\to\file` -> `file:///C:/path/to/file`
|
||||
pub fn path_to_uri(path: &Path) -> String {
|
||||
let path_str = path.display().to_string();
|
||||
if Platform::is_windows() {
|
||||
// Windows paths use backslashes which must be converted to forward slashes
|
||||
let normalized = path_str.replace('\\', "/");
|
||||
format!("file:///{}", normalized)
|
||||
} else {
|
||||
format!("file://{}", path_str)
|
||||
}
|
||||
}
|
||||
|
||||
/// XML escape utility.
|
||||
pub fn xml_escape(value: &str) -> String {
|
||||
value
|
||||
.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_platform_detect() {
|
||||
let platform = Platform::current();
|
||||
if cfg!(target_os = "windows") {
|
||||
assert_eq!(platform, Platform::Windows);
|
||||
} else {
|
||||
assert_eq!(platform, Platform::Unix);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_info_default() {
|
||||
let shell = ShellInfo::default();
|
||||
if cfg!(target_os = "windows") {
|
||||
assert_eq!(shell.executable, "powershell");
|
||||
assert_eq!(shell.args, &["-Command"]);
|
||||
} else {
|
||||
assert_eq!(shell.executable, "bash");
|
||||
assert_eq!(shell.args, &["-c"]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_info_for_platform() {
|
||||
let win_shell = ShellInfo::for_platform(Platform::Windows);
|
||||
assert_eq!(win_shell.executable, "powershell");
|
||||
|
||||
let unix_shell = ShellInfo::for_platform(Platform::Unix);
|
||||
assert_eq!(unix_shell.executable, "bash");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_path_to_uri() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let test_path = temp_dir.path().join("test.txt");
|
||||
let uri = path_to_uri(&test_path);
|
||||
|
||||
assert!(uri.starts_with("file://"));
|
||||
assert!(uri.contains("test.txt"));
|
||||
assert!(!uri.contains('\\')); // No backslashes
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_path_to_uri_windows_format() {
|
||||
if cfg!(target_os = "windows") {
|
||||
let win_path = PathBuf::from("C:\\Users\\test\\file.txt");
|
||||
let uri = path_to_uri(&win_path);
|
||||
assert!(uri.starts_with("file:///C:/"));
|
||||
assert_eq!(uri, "file:///C:/Users/test/file.txt");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atomic_rename() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let src = temp_dir.path().join("source.txt");
|
||||
let dst = temp_dir.path().join("dest.txt");
|
||||
|
||||
fs::write(&src, "content").unwrap();
|
||||
fs::write(&dst, "old content").unwrap();
|
||||
|
||||
atomic_rename(&src, &dst).unwrap();
|
||||
|
||||
assert!(!src.exists());
|
||||
assert!(dst.exists());
|
||||
assert_eq!(fs::read_to_string(&dst).unwrap(), "content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dangerous_patterns() {
|
||||
let patterns = dangerous_command_patterns();
|
||||
assert!(!patterns.is_empty());
|
||||
// Should contain patterns for both platforms
|
||||
assert!(patterns.iter().any(|p| p.contains("rm")));
|
||||
assert!(patterns.iter().any(|p| p.contains("del")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xml_escape() {
|
||||
assert_eq!(xml_escape("a & b"), "a & b");
|
||||
assert_eq!(xml_escape("<tag>"), "<tag>");
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,4 @@
|
||||
use crate::platform::{atomic_rename, home_dir as platform_home_dir, path_to_uri, xml_escape as platform_xml_escape};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@ -439,9 +440,9 @@ impl SkillCatalog {
|
||||
for skill in self.skills.iter().take(self.max_listed_skills) {
|
||||
let entry = format!(
|
||||
" <skill>\n <name>{}</name>\n <description>{}</description>\n <location>{}</location>\n </skill>\n",
|
||||
xml_escape(&skill.name),
|
||||
xml_escape(&skill.description),
|
||||
xml_escape(&format!("file://{}", skill.path.display())),
|
||||
platform_xml_escape(&skill.name),
|
||||
platform_xml_escape(&skill.description),
|
||||
platform_xml_escape(&path_to_uri(&skill.path)),
|
||||
);
|
||||
if prompt.len() + entry.len() + "</available_skills>\n".len() > self.max_index_chars {
|
||||
prompt.push_str(" <truncated>true</truncated>\n");
|
||||
@ -627,27 +628,20 @@ fn project_skill_state_path(cwd: &Path) -> PathBuf {
|
||||
cwd.join(".picobot").join("skill-state.json")
|
||||
}
|
||||
|
||||
fn home_dir() -> Option<PathBuf> {
|
||||
// First check HOME environment variable (useful for testing)
|
||||
std::env::var_os("HOME")
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| dirs::home_dir())
|
||||
}
|
||||
|
||||
fn user_skills_root() -> Option<PathBuf> {
|
||||
home_dir().map(|p| p.join(".picobot").join("skills"))
|
||||
platform_home_dir().map(|p| p.join(".picobot").join("skills"))
|
||||
}
|
||||
|
||||
fn user_skill_state_path() -> Option<PathBuf> {
|
||||
home_dir().map(|p| p.join(".picobot").join("skill-state.json"))
|
||||
platform_home_dir().map(|p| p.join(".picobot").join("skill-state.json"))
|
||||
}
|
||||
|
||||
fn user_agent_skills_root() -> Option<PathBuf> {
|
||||
home_dir().map(|p| p.join(".agents").join("skills"))
|
||||
platform_home_dir().map(|p| p.join(".agents").join("skills"))
|
||||
}
|
||||
|
||||
fn user_openclaw_skills_root() -> Option<PathBuf> {
|
||||
home_dir().map(|p| p.join(".openclaw").join("skills"))
|
||||
platform_home_dir().map(|p| p.join(".openclaw").join("skills"))
|
||||
}
|
||||
|
||||
fn source_root(source: SkillSource, cwd: &Path) -> Option<PathBuf> {
|
||||
@ -746,7 +740,9 @@ fn save_skill_state_file(path: &Path, state: &SkillStateFile) -> Result<(), Stri
|
||||
let tmp_path = path.with_extension("json.tmp");
|
||||
fs::write(&tmp_path, format!("{}\n", content))
|
||||
.map_err(|err| format!("failed to write temporary skill state file: {}", err))?;
|
||||
fs::rename(&tmp_path, path)
|
||||
|
||||
// 使用平台抽象的原子重命名
|
||||
atomic_rename(&tmp_path, path)
|
||||
.map_err(|err| format!("failed to persist skill state file: {}", err))
|
||||
}
|
||||
|
||||
@ -867,12 +863,7 @@ fn split_frontmatter(content: &str) -> Option<(&str, &str)> {
|
||||
Some((frontmatter, body))
|
||||
}
|
||||
|
||||
fn xml_escape(value: &str) -> String {
|
||||
value
|
||||
.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
}
|
||||
// 使用 platform 模块提供的 xml_escape 和 path_to_uri 函数
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@ -885,6 +876,7 @@ mod tests {
|
||||
|
||||
struct HomeDirGuard {
|
||||
previous: Option<OsString>,
|
||||
previous_userprofile: Option<OsString>,
|
||||
}
|
||||
|
||||
impl CurrentDirGuard {
|
||||
@ -903,23 +895,33 @@ mod tests {
|
||||
|
||||
impl HomeDirGuard {
|
||||
fn enter(path: &Path) -> Self {
|
||||
let previous = std::env::var_os("HOME");
|
||||
let home_backup = std::env::var_os("HOME");
|
||||
let userprofile_backup = std::env::var_os("USERPROFILE");
|
||||
|
||||
unsafe {
|
||||
std::env::set_var("HOME", path);
|
||||
// Windows 环境下同时设置 USERPROFILE
|
||||
std::env::set_var("USERPROFILE", path);
|
||||
}
|
||||
|
||||
Self {
|
||||
previous: home_backup,
|
||||
previous_userprofile: userprofile_backup,
|
||||
}
|
||||
Self { previous }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HomeDirGuard {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
match &self.previous {
|
||||
Some(value) => unsafe {
|
||||
std::env::set_var("HOME", value);
|
||||
},
|
||||
None => unsafe {
|
||||
std::env::remove_var("HOME");
|
||||
},
|
||||
Some(value) => std::env::set_var("HOME", value),
|
||||
None => std::env::remove_var("HOME"),
|
||||
}
|
||||
match &self.previous_userprofile {
|
||||
Some(value) => std::env::set_var("USERPROFILE", value),
|
||||
None => std::env::remove_var("USERPROFILE"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -974,13 +976,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_system_index_prompt_uses_available_skills_markup() {
|
||||
// 使用临时目录创建测试路径,确保跨平台兼容
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let skill_path = temp_dir.path().join("demo-skill").join("SKILL.md");
|
||||
|
||||
let catalog = SkillCatalog {
|
||||
skills: vec![Skill {
|
||||
name: "demo-skill".to_string(),
|
||||
description: "demo <skill> & usage".to_string(),
|
||||
body: String::new(),
|
||||
source: SkillSource::Project,
|
||||
path: PathBuf::from("/tmp/demo-skill/SKILL.md"),
|
||||
path: skill_path.clone(),
|
||||
}],
|
||||
max_index_chars: 4000,
|
||||
max_listed_skills: 32,
|
||||
@ -991,10 +997,35 @@ mod tests {
|
||||
assert!(prompt.contains("技能为特定任务提供专用说明和工作流。"));
|
||||
assert!(prompt.contains("<name>demo-skill</name>"));
|
||||
assert!(prompt.contains("<description>demo <skill> & usage</description>"));
|
||||
assert!(prompt.contains("<location>file:///tmp/demo-skill/SKILL.md</location>"));
|
||||
|
||||
// 验证 location 包含正确的 file:// URI 格式
|
||||
let expected_uri = path_to_uri(&skill_path);
|
||||
assert!(prompt.contains(&format!("<location>{}</location>", platform_xml_escape(&expected_uri))));
|
||||
assert!(prompt.contains("</available_skills>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_path_to_uri() {
|
||||
// Unix 路径
|
||||
let unix_path = PathBuf::from("/tmp/demo-skill/SKILL.md");
|
||||
let unix_uri = path_to_uri(&unix_path);
|
||||
if cfg!(target_os = "windows") {
|
||||
// Windows 上运行时,路径可能被转换
|
||||
assert!(unix_uri.contains("file://"));
|
||||
} else {
|
||||
assert_eq!(unix_uri, "file:///tmp/demo-skill/SKILL.md");
|
||||
}
|
||||
|
||||
// Windows 路径格式测试(仅在 Windows 上)
|
||||
if cfg!(target_os = "windows") {
|
||||
let win_path = PathBuf::from("C:\\Users\\test\\.picobot\\skills\\demo\\SKILL.md");
|
||||
let win_uri = path_to_uri(&win_path);
|
||||
assert!(win_uri.starts_with("file:///C:/"));
|
||||
assert!(win_uri.contains("/SKILL.md"));
|
||||
assert!(!win_uri.contains('\\')); // 不应包含反斜杠
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_runtime_create_update_delete_reload() {
|
||||
let _lock = acquire_test_lock();
|
||||
|
||||
@ -10,6 +10,7 @@ use tokio::process::Command;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::time::{Instant, sleep_until};
|
||||
|
||||
use crate::platform::{ShellInfo, dangerous_command_patterns};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||
@ -18,10 +19,90 @@ const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||
const USER_ACTION_HINT: &str =
|
||||
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
||||
|
||||
/// Shell 类型枚举,支持跨平台
|
||||
///
|
||||
/// 这是 ShellInfo 的兼容包装,提供更方便的 API。
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ShellKind {
|
||||
Bash,
|
||||
PowerShell,
|
||||
Cmd,
|
||||
}
|
||||
|
||||
impl ShellKind {
|
||||
/// 根据平台检测默认 shell
|
||||
pub fn detect() -> Self {
|
||||
let info = ShellInfo::default();
|
||||
match info.executable {
|
||||
"bash" => ShellKind::Bash,
|
||||
"powershell" => ShellKind::PowerShell,
|
||||
"cmd" => ShellKind::Cmd,
|
||||
_ => ShellKind::Bash, // fallback
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 ShellInfo 获取 ShellKind
|
||||
pub fn from_info(info: &ShellInfo) -> Self {
|
||||
match info.executable {
|
||||
"bash" => ShellKind::Bash,
|
||||
"powershell" => ShellKind::PowerShell,
|
||||
"cmd" => ShellKind::Cmd,
|
||||
_ => ShellKind::Bash,
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取对应的 ShellInfo
|
||||
pub fn to_info(&self) -> ShellInfo {
|
||||
match self {
|
||||
ShellKind::Bash => ShellInfo {
|
||||
name: "bash",
|
||||
executable: "bash",
|
||||
args: &["-c"],
|
||||
},
|
||||
ShellKind::PowerShell => ShellInfo {
|
||||
name: "shell",
|
||||
executable: "powershell",
|
||||
args: &["-Command"],
|
||||
},
|
||||
ShellKind::Cmd => ShellInfo {
|
||||
name: "shell",
|
||||
executable: "cmd",
|
||||
args: &["/C"],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Shell 可执行文件名
|
||||
pub fn executable(&self) -> &'static str {
|
||||
self.to_info().executable
|
||||
}
|
||||
|
||||
/// 执行命令所需的参数
|
||||
pub fn command_args<'a>(&self, command: &'a str) -> Vec<&'a str> {
|
||||
let info = self.to_info();
|
||||
info.args.iter().map(|s| *s).chain(std::iter::once(command)).collect()
|
||||
}
|
||||
|
||||
/// 工具名称
|
||||
pub fn tool_name(&self) -> &'static str {
|
||||
self.to_info().name
|
||||
}
|
||||
|
||||
/// 工具描述
|
||||
pub fn tool_description(&self) -> &'static str {
|
||||
match self {
|
||||
ShellKind::Bash => "Execute a bash shell command and return its output. Use with caution.",
|
||||
ShellKind::PowerShell => "Execute a PowerShell command and return its output. Use with caution.",
|
||||
ShellKind::Cmd => "Execute a cmd shell command and return its output. Use with caution.",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BashTool {
|
||||
timeout_secs: u64,
|
||||
working_dir: Option<String>,
|
||||
deny_patterns: Vec<String>,
|
||||
shell: ShellKind,
|
||||
}
|
||||
|
||||
impl BashTool {
|
||||
@ -29,12 +110,8 @@ impl BashTool {
|
||||
Self {
|
||||
timeout_secs: 60,
|
||||
working_dir: None,
|
||||
deny_patterns: vec![
|
||||
r"\brm\s+-[rf]{1,2}\b".to_string(),
|
||||
r"\bdel\s+/[fq]\b".to_string(),
|
||||
r"\brmdir\s+/s\b".to_string(),
|
||||
r":\(\)\s*\{.*\};\s*:".to_string(),
|
||||
],
|
||||
deny_patterns: dangerous_command_patterns(),
|
||||
shell: ShellKind::detect(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,6 +125,11 @@ impl BashTool {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_shell(mut self, shell: ShellKind) -> Self {
|
||||
self.shell = shell;
|
||||
self
|
||||
}
|
||||
|
||||
fn guard_command(&self, command: &str) -> Option<String> {
|
||||
let lower = command.to_lowercase();
|
||||
for pattern in &self.deny_patterns {
|
||||
@ -138,11 +220,11 @@ impl Default for BashTool {
|
||||
#[async_trait]
|
||||
impl Tool for BashTool {
|
||||
fn name(&self) -> &str {
|
||||
"bash"
|
||||
self.shell.tool_name()
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Execute a bash shell command and return its output. Use with caution."
|
||||
self.shell.tool_description()
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@ -235,8 +317,8 @@ impl BashTool {
|
||||
timeout_secs: u64,
|
||||
interactive: bool,
|
||||
) -> Result<String, String> {
|
||||
let mut cmd = Command::new("bash");
|
||||
cmd.args(["-c", command])
|
||||
let mut cmd = Command::new(self.shell.executable());
|
||||
cmd.args(self.shell.command_args(command))
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.current_dir(cwd);
|
||||
@ -358,8 +440,13 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_simple_command() {
|
||||
let tool = BashTool::new();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
"Write-Output 'Hello World'"
|
||||
} else {
|
||||
"echo 'Hello World'"
|
||||
};
|
||||
let result = tool
|
||||
.execute(json!({ "command": "echo 'Hello World'" }))
|
||||
.execute(json!({ "command": command }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@ -370,7 +457,12 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_pwd_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
"Get-Location"
|
||||
} else {
|
||||
"pwd"
|
||||
};
|
||||
let result = tool.execute(json!({ "command": command })).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
@ -378,8 +470,14 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_ls_command() {
|
||||
let tool = BashTool::new();
|
||||
let temp_dir = std::env::temp_dir();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
format!("Get-ChildItem {}", temp_dir.display())
|
||||
} else {
|
||||
format!("ls -la {}", temp_dir.display())
|
||||
};
|
||||
let result = tool
|
||||
.execute(json!({ "command": "ls -la /tmp" }))
|
||||
.execute(json!({ "command": command }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@ -389,8 +487,22 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_dangerous_rm() {
|
||||
let tool = BashTool::new();
|
||||
// 测试 Unix 危险命令模式
|
||||
let result = tool
|
||||
.execute(json!({ "command": "rm -rf /" }))
|
||||
.execute(json!({ "command": "rm -rf /some/path" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("blocked"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dangerous_windows_commands() {
|
||||
let tool = BashTool::new();
|
||||
// 测试 Windows del 命令模式(正则应该匹配)
|
||||
let result = tool
|
||||
.execute(json!({ "command": "del /f /q file.txt" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@ -422,9 +534,14 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_timeout() {
|
||||
let tool = BashTool::new();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
"Start-Sleep -Seconds 10"
|
||||
} else {
|
||||
"sleep 10"
|
||||
};
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"command": "sleep 10",
|
||||
"command": command,
|
||||
"timeout": 1
|
||||
}))
|
||||
.await
|
||||
@ -437,17 +554,22 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_pending_user_action_detection() {
|
||||
let tool = BashTool::new();
|
||||
let command = if cfg!(target_os = "windows") {
|
||||
"Write-Host 'waiting for authorization'; Start-Sleep -Seconds 10"
|
||||
} else {
|
||||
"printf 'waiting for authorization'; sleep 10"
|
||||
};
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"command": "printf '在浏览器中打开以下链接进行认证:\n\nhttps://example.com/device/verify\n\n等待用户授权...\n'; sleep 10",
|
||||
"timeout": 1
|
||||
"command": command,
|
||||
"timeout": 1,
|
||||
"interactive": true
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains(PENDING_USER_ACTION_MARKER));
|
||||
assert!(result.output.contains("等待用户授权"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -460,4 +582,28 @@ mod tests {
|
||||
assert!(output.contains("chars truncated"));
|
||||
assert!(output.is_char_boundary(output.len()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_kind_detect() {
|
||||
let shell = ShellKind::detect();
|
||||
if cfg!(target_os = "windows") {
|
||||
assert_eq!(shell, ShellKind::PowerShell);
|
||||
} else {
|
||||
assert_eq!(shell, ShellKind::Bash);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_kind_executable() {
|
||||
assert_eq!(ShellKind::Bash.executable(), "bash");
|
||||
assert_eq!(ShellKind::PowerShell.executable(), "powershell");
|
||||
assert_eq!(ShellKind::Cmd.executable(), "cmd");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shell_kind_command_args() {
|
||||
assert_eq!(ShellKind::Bash.command_args("echo hello"), vec!["-c" as &str, "echo hello"]);
|
||||
assert_eq!(ShellKind::PowerShell.command_args("echo hello"), vec!["-Command" as &str, "echo hello"]);
|
||||
assert_eq!(ShellKind::Cmd.command_args("echo hello"), vec!["/C" as &str, "echo hello"]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -204,8 +204,11 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_write_missing_content() {
|
||||
let tool = FileWriteTool::new();
|
||||
// 使用临时目录确保跨平台兼容
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let test_path = temp_dir.path().join("test.txt");
|
||||
let result = tool
|
||||
.execute(json!({ "path": "/tmp/test.txt" }))
|
||||
.execute(json!({ "path": test_path.to_str().unwrap() }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@ -349,6 +349,7 @@ mod tests {
|
||||
|
||||
struct HomeDirGuard {
|
||||
previous: Option<OsString>,
|
||||
previous_userprofile: Option<OsString>,
|
||||
}
|
||||
|
||||
impl CurrentDirGuard {
|
||||
@ -367,23 +368,32 @@ mod tests {
|
||||
|
||||
impl HomeDirGuard {
|
||||
fn enter(path: &Path) -> Self {
|
||||
let previous = std::env::var_os("HOME");
|
||||
let home_backup = std::env::var_os("HOME");
|
||||
let userprofile_backup = std::env::var_os("USERPROFILE");
|
||||
|
||||
unsafe {
|
||||
std::env::set_var("HOME", path);
|
||||
std::env::set_var("USERPROFILE", path);
|
||||
}
|
||||
|
||||
Self {
|
||||
previous: home_backup,
|
||||
previous_userprofile: userprofile_backup,
|
||||
}
|
||||
Self { previous }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HomeDirGuard {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
match &self.previous {
|
||||
Some(value) => unsafe {
|
||||
std::env::set_var("HOME", value);
|
||||
},
|
||||
None => unsafe {
|
||||
std::env::remove_var("HOME");
|
||||
},
|
||||
Some(value) => std::env::set_var("HOME", value),
|
||||
None => std::env::remove_var("HOME"),
|
||||
}
|
||||
match &self.previous_userprofile {
|
||||
Some(value) => std::env::set_var("USERPROFILE", value),
|
||||
None => std::env::remove_var("USERPROFILE"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user