feat(tools): add PtyTool with PtyManager for persistent PTY sessions

This commit is contained in:
xiaoski 2026-05-29 17:03:19 +08:00
parent 2a69021e27
commit 014538eedc

608
src/tools/pty.rs Normal file
View File

@ -0,0 +1,608 @@
use std::collections::{HashMap, VecDeque};
use std::io::Write;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use async_trait::async_trait;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
const MAX_OUTPUT_LINES: usize = 50_000;
const MAX_CHARS_PER_LINE: usize = 2_000;
const MAX_SESSIONS: usize = 10;
fn guard_command(command: &str) -> Option<String> {
let deny_patterns: &[&str] = &[
r"\brm\s+-[rf]{1,2}\b",
r"\bdel\s+/[fq]\b",
r"\brmdir\s+/s\b",
r":\(\)\s*\{.*\};\s*:",
];
let lower = command.to_lowercase();
for pattern in deny_patterns {
if regex::Regex::new(pattern)
.ok()
.map(|re| re.is_match(&lower))
.unwrap_or(false)
{
return Some(format!(
"Command blocked by safety guard (dangerous pattern: {})",
pattern
));
}
}
None
}
fn unescape_control_chars(s: &str) -> String {
let mut result = String::with_capacity(s.len());
let mut chars = s.chars().peekable();
while let Some(c) = chars.next() {
if c == '\\' {
match chars.next() {
Some('n') => result.push('\n'),
Some('r') => result.push('\r'),
Some('t') => result.push('\t'),
Some('x') => {
let hex: String = chars.by_ref().take(2).collect();
if hex.len() == 2 {
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
result.push(byte as char);
} else {
result.push_str(&format!("\\x{}", hex));
}
} else {
result.push_str(&format!("\\x{}", hex));
}
}
Some('\\') => result.push('\\'),
Some(other) => {
result.push('\\');
result.push(other);
}
None => result.push('\\'),
}
} else {
result.push(c);
}
}
result
}
fn truncate_command(cmd: &str, max_len: usize) -> String {
let cmd = cmd.trim();
let first_arg = cmd.split_whitespace().next().unwrap_or(cmd);
if first_arg.len() > max_len {
format!("{}...", &first_arg[..max_len])
} else {
first_arg.to_string()
}
}
fn status_str(status: &SessionStatus) -> &str {
match status {
SessionStatus::Running => "running",
SessionStatus::Exited(_) => "exited",
SessionStatus::Killed => "killed",
}
}
#[derive(Debug, Clone, PartialEq)]
enum SessionStatus {
Running,
Exited(i32),
Killed,
}
struct PtySession {
#[allow(dead_code)]
id: String,
#[allow(dead_code)]
command: String,
#[allow(dead_code)]
started_at: Instant,
status: SessionStatus,
child: Arc<Mutex<Option<Box<dyn portable_pty::Child + Send + Sync>>>>,
writer: Arc<Mutex<Option<Box<dyn Write + Send>>>>,
output_buffer: VecDeque<String>,
output_total_lines: usize,
}
impl PtySession {
fn new(
id: String,
command: String,
child: Box<dyn portable_pty::Child + Send + Sync>,
writer: Box<dyn Write + Send>,
) -> Self {
Self {
id,
command,
started_at: Instant::now(),
status: SessionStatus::Running,
child: Arc::new(Mutex::new(Some(child))),
writer: Arc::new(Mutex::new(Some(writer))),
output_buffer: VecDeque::new(),
output_total_lines: 0,
}
}
fn push_line(&mut self, line: String) {
let line = if line.len() > MAX_CHARS_PER_LINE {
format!("{}...<truncated>", &line[..MAX_CHARS_PER_LINE])
} else {
line
};
self.output_total_lines += 1;
if self.output_buffer.len() >= MAX_OUTPUT_LINES {
self.output_buffer.pop_front();
}
self.output_buffer.push_back(line);
}
}
pub struct PtyManager {
sessions: Mutex<HashMap<String, Arc<Mutex<PtySession>>>>,
}
impl PtyManager {
pub fn new() -> Self {
Self {
sessions: Mutex::new(HashMap::new()),
}
}
pub fn cleanup_all(&self) {
let sessions: Vec<Arc<Mutex<PtySession>>> = {
let mut guard = self.sessions.lock().unwrap();
guard.drain().map(|(_, s)| s).collect()
};
for session in sessions {
let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
}
*child_guard = None;
guard.status = SessionStatus::Killed;
}
}
fn spawn(&self, command: &str) -> Result<String, String> {
if let Some(reason) = guard_command(command) {
return Err(reason);
}
let mut sessions = self.sessions.lock().unwrap();
if sessions.len() >= MAX_SESSIONS {
return Err(format!(
"Max sessions ({}) reached, kill some sessions first",
MAX_SESSIONS
));
}
let session_id = format!("pty_{}", crate::util::short_id());
let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
let pty_system = portable_pty::native_pty_system();
let pty_pair = pty_system
.openpty(portable_pty::PtySize {
rows: 24,
cols: 80,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|e| format!("Failed to open PTY: {}", e))?;
let mut cmd = portable_pty::CommandBuilder::new("bash");
cmd.args(&["-c", command]);
cmd.cwd(cwd);
let child = pty_pair
.slave
.spawn_command(cmd)
.map_err(|e| format!("Failed to spawn: {}", e))?;
let pid = child.process_id().unwrap_or(0);
let writer = pty_pair
.master
.take_writer()
.map_err(|e| format!("Failed to take writer: {}", e))?;
let reader = pty_pair
.master
.try_clone_reader()
.map_err(|e| format!("Failed to clone reader: {}", e))?;
let session_id_clone = session_id.clone();
let session = PtySession::new(session_id_clone, command.to_string(), child, writer);
let session = Arc::new(Mutex::new(session));
sessions.insert(session_id.clone(), session.clone());
let session_for_reader = session.clone();
let child_for_reader = session_for_reader.lock().unwrap().child.clone();
tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; 4096];
let mut partial = String::new();
loop {
use std::io::Read;
match reader.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
let s = String::from_utf8_lossy(&buf[..n]);
partial.push_str(&s);
let lines: Vec<&str> = partial.split('\n').collect();
if lines.len() <= 1 {
continue;
}
let complete = lines.len() - 1;
let mut guard = session_for_reader.lock().unwrap();
for line in lines[..complete].iter() {
guard.push_line(line.to_string());
}
partial = lines[complete].to_string();
}
Err(_) => break,
}
}
let mut guard = session_for_reader.lock().unwrap();
if !partial.is_empty() {
guard.push_line(partial);
}
let exit_code = {
let mut cg = child_for_reader.lock().unwrap();
if let Some(ref mut c) = *cg {
c.wait().map(|s| s.exit_code() as i32).ok()
} else {
None
}
};
guard.status = SessionStatus::Exited(exit_code.unwrap_or(-1));
});
Ok(format!("session_id: {}, pid: {}", session_id, pid))
}
fn write(&self, session_id: &str, data: &str) -> Result<String, String> {
let unescaped = unescape_control_chars(data);
let byte_count = unescaped.len();
let sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap();
if guard.status != SessionStatus::Running {
return Err("Session is not running".to_string());
}
let writer = guard.writer.clone();
drop(guard);
drop(sessions);
let mut writer_guard = writer.lock().unwrap();
match *writer_guard {
Some(ref mut w) => {
w.write_all(unescaped.as_bytes())
.map_err(|e| format!("Write error: {}", e))?;
w.flush().map_err(|e| format!("Flush error: {}", e))?;
}
None => return Err("Writer not available".to_string()),
}
Ok(format!("OK, wrote {} bytes", byte_count))
}
fn read(
&self,
session_id: &str,
offset: usize,
limit: usize,
) -> Result<String, String> {
let sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let guard = session.lock().unwrap();
let total = guard.output_total_lines;
let buffer_len = guard.output_buffer.len();
let start = 0_usize.max(offset);
let skip_old = total.saturating_sub(buffer_len);
let view_start = start.saturating_sub(skip_old);
let lines: Vec<String> = guard
.output_buffer
.iter()
.skip(view_start)
.take(limit)
.enumerate()
.map(|(i, line)| format!("{}: {}", start + i, line))
.collect();
let displayed = start + lines.len();
let has_more = displayed < total;
let mut output = format!(
"# Lines {}-{} (共 {} 行{})\n",
start,
displayed.saturating_sub(1),
total,
if has_more { ",还有更多" } else { "" }
);
output.push_str(&lines.join("\n"));
if has_more {
output.push_str(&format!(
"\n[还有 {} 行未显示,用 offset={} 继续读取]",
total.saturating_sub(displayed),
displayed
));
}
Ok(output)
}
fn kill(&self, session_id: &str) -> Result<String, String> {
let mut sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
let _ = child.wait();
}
*child_guard = None;
guard.status = SessionStatus::Killed;
drop(child_guard);
drop(guard);
sessions.remove(session_id);
Ok(format!("Session {} killed", session_id))
}
fn list(&self) -> String {
let sessions = self.sessions.lock().unwrap();
if sessions.is_empty() {
return "No active PTY sessions".to_string();
}
let mut lines: Vec<String> = sessions
.iter()
.map(|(id, session)| {
let guard = session.lock().unwrap();
let age = Instant::now().duration_since(guard.started_at);
let age_str = if age.as_secs() < 60 {
format!("{}s ago", age.as_secs())
} else {
format!("{}m ago", age.as_secs() / 60)
};
format!(
"{:<14} {:<12} {:<10} {:<6} lines {}",
id,
truncate_command(&guard.command, 10),
status_str(&guard.status),
guard.output_total_lines,
age_str,
)
})
.collect();
lines.sort();
lines.join("\n")
}
}
impl Drop for PtyManager {
fn drop(&mut self) {
self.cleanup_all();
}
}
pub struct PtyTool {
pty_manager: Arc<PtyManager>,
}
impl PtyTool {
pub fn new(pty_manager: Arc<PtyManager>) -> Self {
Self { pty_manager }
}
}
#[async_trait]
impl Tool for PtyTool {
fn name(&self) -> &str {
"pty"
}
fn description(&self) -> &str {
"管理持久伪终端(PTY)会话。用于交互式程序、长运行服务、多步骤命令等需要保持终端状态的场景。支持操作: spawn(创建)/write(写入输入)/read(读取输出)/kill(终止)/list(列出所有会话)。"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["spawn", "write", "read", "kill", "list"],
"description": "操作类型: spawn=创建会话, write=写入输入, read=读取输出, kill=终止会话, list=列出所有会话"
},
"session_id": {
"type": "string",
"description": "会话ID (write/read/kill 需要)"
},
"command": {
"type": "string",
"description": "要执行的命令 (spawn 需要)"
},
"data": {
"type": "string",
"description": "写入终端的数据,支持转义序列: \\n(换行) \\x03(Ctrl+C) \\x04(Ctrl+D) \\x1a(Ctrl+Z) (write 需要)"
},
"offset": {
"type": "integer",
"description": "输出读取起始行号,从 0 开始 (read 可选,默认 0)",
"minimum": 0
},
"limit": {
"type": "integer",
"description": "读取的最大行数 (read 可选,默认 500)",
"minimum": 1
}
},
"required": ["action"]
})
}
fn exclusive(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = match args.get("action").and_then(|v| v.as_str()) {
Some(a) => a,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: action".to_string()),
});
}
};
match action {
"spawn" => {
let command = match args.get("command").and_then(|v| v.as_str()) {
Some(c) => c,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: command".to_string()),
});
}
};
match self.pty_manager.spawn(command) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"write" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
let data = match args.get("data").and_then(|v| v.as_str()) {
Some(d) => d,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: data".to_string()),
});
}
};
match self.pty_manager.write(session_id, data) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"read" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
let offset = args
.get("offset")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
.unwrap_or(500) as usize;
match self.pty_manager.read(session_id, offset, limit) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"kill" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
match self.pty_manager.kill(session_id) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"list" => {
let output = self.pty_manager.list();
Ok(ToolResult {
success: true,
output,
error: None,
})
}
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Unknown action: {}", action)),
}),
}
}
}