feat(tools): add PtyTool with PtyManager for persistent PTY sessions
This commit is contained in:
parent
2a69021e27
commit
014538eedc
608
src/tools/pty.rs
Normal file
608
src/tools/pty.rs
Normal file
@ -0,0 +1,608 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::io::Write;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
const MAX_OUTPUT_LINES: usize = 50_000;
|
||||
const MAX_CHARS_PER_LINE: usize = 2_000;
|
||||
const MAX_SESSIONS: usize = 10;
|
||||
|
||||
fn guard_command(command: &str) -> Option<String> {
|
||||
let deny_patterns: &[&str] = &[
|
||||
r"\brm\s+-[rf]{1,2}\b",
|
||||
r"\bdel\s+/[fq]\b",
|
||||
r"\brmdir\s+/s\b",
|
||||
r":\(\)\s*\{.*\};\s*:",
|
||||
];
|
||||
let lower = command.to_lowercase();
|
||||
for pattern in deny_patterns {
|
||||
if regex::Regex::new(pattern)
|
||||
.ok()
|
||||
.map(|re| re.is_match(&lower))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Some(format!(
|
||||
"Command blocked by safety guard (dangerous pattern: {})",
|
||||
pattern
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn unescape_control_chars(s: &str) -> String {
|
||||
let mut result = String::with_capacity(s.len());
|
||||
let mut chars = s.chars().peekable();
|
||||
while let Some(c) = chars.next() {
|
||||
if c == '\\' {
|
||||
match chars.next() {
|
||||
Some('n') => result.push('\n'),
|
||||
Some('r') => result.push('\r'),
|
||||
Some('t') => result.push('\t'),
|
||||
Some('x') => {
|
||||
let hex: String = chars.by_ref().take(2).collect();
|
||||
if hex.len() == 2 {
|
||||
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
|
||||
result.push(byte as char);
|
||||
} else {
|
||||
result.push_str(&format!("\\x{}", hex));
|
||||
}
|
||||
} else {
|
||||
result.push_str(&format!("\\x{}", hex));
|
||||
}
|
||||
}
|
||||
Some('\\') => result.push('\\'),
|
||||
Some(other) => {
|
||||
result.push('\\');
|
||||
result.push(other);
|
||||
}
|
||||
None => result.push('\\'),
|
||||
}
|
||||
} else {
|
||||
result.push(c);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn truncate_command(cmd: &str, max_len: usize) -> String {
|
||||
let cmd = cmd.trim();
|
||||
let first_arg = cmd.split_whitespace().next().unwrap_or(cmd);
|
||||
if first_arg.len() > max_len {
|
||||
format!("{}...", &first_arg[..max_len])
|
||||
} else {
|
||||
first_arg.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn status_str(status: &SessionStatus) -> &str {
|
||||
match status {
|
||||
SessionStatus::Running => "running",
|
||||
SessionStatus::Exited(_) => "exited",
|
||||
SessionStatus::Killed => "killed",
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum SessionStatus {
|
||||
Running,
|
||||
Exited(i32),
|
||||
Killed,
|
||||
}
|
||||
|
||||
struct PtySession {
|
||||
#[allow(dead_code)]
|
||||
id: String,
|
||||
#[allow(dead_code)]
|
||||
command: String,
|
||||
#[allow(dead_code)]
|
||||
started_at: Instant,
|
||||
status: SessionStatus,
|
||||
child: Arc<Mutex<Option<Box<dyn portable_pty::Child + Send + Sync>>>>,
|
||||
writer: Arc<Mutex<Option<Box<dyn Write + Send>>>>,
|
||||
output_buffer: VecDeque<String>,
|
||||
output_total_lines: usize,
|
||||
}
|
||||
|
||||
impl PtySession {
|
||||
fn new(
|
||||
id: String,
|
||||
command: String,
|
||||
child: Box<dyn portable_pty::Child + Send + Sync>,
|
||||
writer: Box<dyn Write + Send>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
command,
|
||||
started_at: Instant::now(),
|
||||
status: SessionStatus::Running,
|
||||
child: Arc::new(Mutex::new(Some(child))),
|
||||
writer: Arc::new(Mutex::new(Some(writer))),
|
||||
output_buffer: VecDeque::new(),
|
||||
output_total_lines: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn push_line(&mut self, line: String) {
|
||||
let line = if line.len() > MAX_CHARS_PER_LINE {
|
||||
format!("{}...<truncated>", &line[..MAX_CHARS_PER_LINE])
|
||||
} else {
|
||||
line
|
||||
};
|
||||
self.output_total_lines += 1;
|
||||
if self.output_buffer.len() >= MAX_OUTPUT_LINES {
|
||||
self.output_buffer.pop_front();
|
||||
}
|
||||
self.output_buffer.push_back(line);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PtyManager {
|
||||
sessions: Mutex<HashMap<String, Arc<Mutex<PtySession>>>>,
|
||||
}
|
||||
|
||||
impl PtyManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cleanup_all(&self) {
|
||||
let sessions: Vec<Arc<Mutex<PtySession>>> = {
|
||||
let mut guard = self.sessions.lock().unwrap();
|
||||
guard.drain().map(|(_, s)| s).collect()
|
||||
};
|
||||
for session in sessions {
|
||||
let mut guard = session.lock().unwrap();
|
||||
let mut child_guard = guard.child.lock().unwrap();
|
||||
if let Some(ref mut child) = *child_guard {
|
||||
let _ = child.kill();
|
||||
}
|
||||
*child_guard = None;
|
||||
guard.status = SessionStatus::Killed;
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn(&self, command: &str) -> Result<String, String> {
|
||||
if let Some(reason) = guard_command(command) {
|
||||
return Err(reason);
|
||||
}
|
||||
|
||||
let mut sessions = self.sessions.lock().unwrap();
|
||||
if sessions.len() >= MAX_SESSIONS {
|
||||
return Err(format!(
|
||||
"Max sessions ({}) reached, kill some sessions first",
|
||||
MAX_SESSIONS
|
||||
));
|
||||
}
|
||||
|
||||
let session_id = format!("pty_{}", crate::util::short_id());
|
||||
let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
|
||||
|
||||
let pty_system = portable_pty::native_pty_system();
|
||||
let pty_pair = pty_system
|
||||
.openpty(portable_pty::PtySize {
|
||||
rows: 24,
|
||||
cols: 80,
|
||||
pixel_width: 0,
|
||||
pixel_height: 0,
|
||||
})
|
||||
.map_err(|e| format!("Failed to open PTY: {}", e))?;
|
||||
|
||||
let mut cmd = portable_pty::CommandBuilder::new("bash");
|
||||
cmd.args(&["-c", command]);
|
||||
cmd.cwd(cwd);
|
||||
|
||||
let child = pty_pair
|
||||
.slave
|
||||
.spawn_command(cmd)
|
||||
.map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||
|
||||
let pid = child.process_id().unwrap_or(0);
|
||||
|
||||
let writer = pty_pair
|
||||
.master
|
||||
.take_writer()
|
||||
.map_err(|e| format!("Failed to take writer: {}", e))?;
|
||||
|
||||
let reader = pty_pair
|
||||
.master
|
||||
.try_clone_reader()
|
||||
.map_err(|e| format!("Failed to clone reader: {}", e))?;
|
||||
|
||||
let session_id_clone = session_id.clone();
|
||||
let session = PtySession::new(session_id_clone, command.to_string(), child, writer);
|
||||
let session = Arc::new(Mutex::new(session));
|
||||
sessions.insert(session_id.clone(), session.clone());
|
||||
|
||||
let session_for_reader = session.clone();
|
||||
let child_for_reader = session_for_reader.lock().unwrap().child.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut reader = reader;
|
||||
let mut buf = [0u8; 4096];
|
||||
let mut partial = String::new();
|
||||
loop {
|
||||
use std::io::Read;
|
||||
match reader.read(&mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let s = String::from_utf8_lossy(&buf[..n]);
|
||||
partial.push_str(&s);
|
||||
let lines: Vec<&str> = partial.split('\n').collect();
|
||||
if lines.len() <= 1 {
|
||||
continue;
|
||||
}
|
||||
let complete = lines.len() - 1;
|
||||
let mut guard = session_for_reader.lock().unwrap();
|
||||
for line in lines[..complete].iter() {
|
||||
guard.push_line(line.to_string());
|
||||
}
|
||||
partial = lines[complete].to_string();
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
let mut guard = session_for_reader.lock().unwrap();
|
||||
if !partial.is_empty() {
|
||||
guard.push_line(partial);
|
||||
}
|
||||
let exit_code = {
|
||||
let mut cg = child_for_reader.lock().unwrap();
|
||||
if let Some(ref mut c) = *cg {
|
||||
c.wait().map(|s| s.exit_code() as i32).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
guard.status = SessionStatus::Exited(exit_code.unwrap_or(-1));
|
||||
});
|
||||
|
||||
Ok(format!("session_id: {}, pid: {}", session_id, pid))
|
||||
}
|
||||
|
||||
fn write(&self, session_id: &str, data: &str) -> Result<String, String> {
|
||||
let unescaped = unescape_control_chars(data);
|
||||
let byte_count = unescaped.len();
|
||||
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
let mut guard = session.lock().unwrap();
|
||||
if guard.status != SessionStatus::Running {
|
||||
return Err("Session is not running".to_string());
|
||||
}
|
||||
|
||||
let writer = guard.writer.clone();
|
||||
drop(guard);
|
||||
drop(sessions);
|
||||
|
||||
let mut writer_guard = writer.lock().unwrap();
|
||||
match *writer_guard {
|
||||
Some(ref mut w) => {
|
||||
w.write_all(unescaped.as_bytes())
|
||||
.map_err(|e| format!("Write error: {}", e))?;
|
||||
w.flush().map_err(|e| format!("Flush error: {}", e))?;
|
||||
}
|
||||
None => return Err("Writer not available".to_string()),
|
||||
}
|
||||
|
||||
Ok(format!("OK, wrote {} bytes", byte_count))
|
||||
}
|
||||
|
||||
fn read(
|
||||
&self,
|
||||
session_id: &str,
|
||||
offset: usize,
|
||||
limit: usize,
|
||||
) -> Result<String, String> {
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
let guard = session.lock().unwrap();
|
||||
|
||||
let total = guard.output_total_lines;
|
||||
let buffer_len = guard.output_buffer.len();
|
||||
let start = 0_usize.max(offset);
|
||||
let skip_old = total.saturating_sub(buffer_len);
|
||||
let view_start = start.saturating_sub(skip_old);
|
||||
|
||||
let lines: Vec<String> = guard
|
||||
.output_buffer
|
||||
.iter()
|
||||
.skip(view_start)
|
||||
.take(limit)
|
||||
.enumerate()
|
||||
.map(|(i, line)| format!("{}: {}", start + i, line))
|
||||
.collect();
|
||||
|
||||
let displayed = start + lines.len();
|
||||
let has_more = displayed < total;
|
||||
|
||||
let mut output = format!(
|
||||
"# Lines {}-{} (共 {} 行{})\n",
|
||||
start,
|
||||
displayed.saturating_sub(1),
|
||||
total,
|
||||
if has_more { ",还有更多" } else { "" }
|
||||
);
|
||||
output.push_str(&lines.join("\n"));
|
||||
if has_more {
|
||||
output.push_str(&format!(
|
||||
"\n[还有 {} 行未显示,用 offset={} 继续读取]",
|
||||
total.saturating_sub(displayed),
|
||||
displayed
|
||||
));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn kill(&self, session_id: &str) -> Result<String, String> {
|
||||
let mut sessions = self.sessions.lock().unwrap();
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
let mut guard = session.lock().unwrap();
|
||||
|
||||
let mut child_guard = guard.child.lock().unwrap();
|
||||
if let Some(ref mut child) = *child_guard {
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
}
|
||||
*child_guard = None;
|
||||
guard.status = SessionStatus::Killed;
|
||||
drop(child_guard);
|
||||
drop(guard);
|
||||
sessions.remove(session_id);
|
||||
|
||||
Ok(format!("Session {} killed", session_id))
|
||||
}
|
||||
|
||||
fn list(&self) -> String {
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
if sessions.is_empty() {
|
||||
return "No active PTY sessions".to_string();
|
||||
}
|
||||
|
||||
let mut lines: Vec<String> = sessions
|
||||
.iter()
|
||||
.map(|(id, session)| {
|
||||
let guard = session.lock().unwrap();
|
||||
let age = Instant::now().duration_since(guard.started_at);
|
||||
let age_str = if age.as_secs() < 60 {
|
||||
format!("{}s ago", age.as_secs())
|
||||
} else {
|
||||
format!("{}m ago", age.as_secs() / 60)
|
||||
};
|
||||
format!(
|
||||
"{:<14} {:<12} {:<10} {:<6} lines {}",
|
||||
id,
|
||||
truncate_command(&guard.command, 10),
|
||||
status_str(&guard.status),
|
||||
guard.output_total_lines,
|
||||
age_str,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
lines.sort();
|
||||
lines.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PtyManager {
|
||||
fn drop(&mut self) {
|
||||
self.cleanup_all();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PtyTool {
|
||||
pty_manager: Arc<PtyManager>,
|
||||
}
|
||||
|
||||
impl PtyTool {
|
||||
pub fn new(pty_manager: Arc<PtyManager>) -> Self {
|
||||
Self { pty_manager }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for PtyTool {
|
||||
fn name(&self) -> &str {
|
||||
"pty"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"管理持久伪终端(PTY)会话。用于交互式程序、长运行服务、多步骤命令等需要保持终端状态的场景。支持操作: spawn(创建)/write(写入输入)/read(读取输出)/kill(终止)/list(列出所有会话)。"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["spawn", "write", "read", "kill", "list"],
|
||||
"description": "操作类型: spawn=创建会话, write=写入输入, read=读取输出, kill=终止会话, list=列出所有会话"
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "会话ID (write/read/kill 需要)"
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的命令 (spawn 需要)"
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"description": "写入终端的数据,支持转义序列: \\n(换行) \\x03(Ctrl+C) \\x04(Ctrl+D) \\x1a(Ctrl+Z) (write 需要)"
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "输出读取起始行号,从 0 开始 (read 可选,默认 0)",
|
||||
"minimum": 0
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "读取的最大行数 (read 可选,默认 500)",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
fn exclusive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = match args.get("action").and_then(|v| v.as_str()) {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: action".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match action {
|
||||
"spawn" => {
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: command".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
match self.pty_manager.spawn(command) {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
"write" => {
|
||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: session_id".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let data = match args.get("data").and_then(|v| v.as_str()) {
|
||||
Some(d) => d,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: data".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
match self.pty_manager.write(session_id, data) {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
"read" => {
|
||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: session_id".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let offset = args
|
||||
.get("offset")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as usize;
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(500) as usize;
|
||||
match self.pty_manager.read(session_id, offset, limit) {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
"kill" => {
|
||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: session_id".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
match self.pty_manager.kill(session_id) {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
"list" => {
|
||||
let output = self.pty_manager.list();
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown action: {}", action)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user