diff --git a/src/tools/pty.rs b/src/tools/pty.rs new file mode 100644 index 0000000..b3a5ee3 --- /dev/null +++ b/src/tools/pty.rs @@ -0,0 +1,608 @@ +use std::collections::{HashMap, VecDeque}; +use std::io::Write; +use std::sync::{Arc, Mutex}; +use std::time::Instant; + +use async_trait::async_trait; +use serde_json::json; + +use crate::tools::traits::{Tool, ToolResult}; + +const MAX_OUTPUT_LINES: usize = 50_000; +const MAX_CHARS_PER_LINE: usize = 2_000; +const MAX_SESSIONS: usize = 10; + +fn guard_command(command: &str) -> Option { + let deny_patterns: &[&str] = &[ + r"\brm\s+-[rf]{1,2}\b", + r"\bdel\s+/[fq]\b", + r"\brmdir\s+/s\b", + r":\(\)\s*\{.*\};\s*:", + ]; + let lower = command.to_lowercase(); + for pattern in deny_patterns { + if regex::Regex::new(pattern) + .ok() + .map(|re| re.is_match(&lower)) + .unwrap_or(false) + { + return Some(format!( + "Command blocked by safety guard (dangerous pattern: {})", + pattern + )); + } + } + None +} + +fn unescape_control_chars(s: &str) -> String { + let mut result = String::with_capacity(s.len()); + let mut chars = s.chars().peekable(); + while let Some(c) = chars.next() { + if c == '\\' { + match chars.next() { + Some('n') => result.push('\n'), + Some('r') => result.push('\r'), + Some('t') => result.push('\t'), + Some('x') => { + let hex: String = chars.by_ref().take(2).collect(); + if hex.len() == 2 { + if let Ok(byte) = u8::from_str_radix(&hex, 16) { + result.push(byte as char); + } else { + result.push_str(&format!("\\x{}", hex)); + } + } else { + result.push_str(&format!("\\x{}", hex)); + } + } + Some('\\') => result.push('\\'), + Some(other) => { + result.push('\\'); + result.push(other); + } + None => result.push('\\'), + } + } else { + result.push(c); + } + } + result +} + +fn truncate_command(cmd: &str, max_len: usize) -> String { + let cmd = cmd.trim(); + let first_arg = cmd.split_whitespace().next().unwrap_or(cmd); + if first_arg.len() > max_len { + format!("{}...", &first_arg[..max_len]) + } else { + first_arg.to_string() + } +} + +fn status_str(status: &SessionStatus) -> &str { + match status { + SessionStatus::Running => "running", + SessionStatus::Exited(_) => "exited", + SessionStatus::Killed => "killed", + } +} + +#[derive(Debug, Clone, PartialEq)] +enum SessionStatus { + Running, + Exited(i32), + Killed, +} + +struct PtySession { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + command: String, + #[allow(dead_code)] + started_at: Instant, + status: SessionStatus, + child: Arc>>>, + writer: Arc>>>, + output_buffer: VecDeque, + output_total_lines: usize, +} + +impl PtySession { + fn new( + id: String, + command: String, + child: Box, + writer: Box, + ) -> Self { + Self { + id, + command, + started_at: Instant::now(), + status: SessionStatus::Running, + child: Arc::new(Mutex::new(Some(child))), + writer: Arc::new(Mutex::new(Some(writer))), + output_buffer: VecDeque::new(), + output_total_lines: 0, + } + } + + fn push_line(&mut self, line: String) { + let line = if line.len() > MAX_CHARS_PER_LINE { + format!("{}...", &line[..MAX_CHARS_PER_LINE]) + } else { + line + }; + self.output_total_lines += 1; + if self.output_buffer.len() >= MAX_OUTPUT_LINES { + self.output_buffer.pop_front(); + } + self.output_buffer.push_back(line); + } +} + +pub struct PtyManager { + sessions: Mutex>>>, +} + +impl PtyManager { + pub fn new() -> Self { + Self { + sessions: Mutex::new(HashMap::new()), + } + } + + pub fn cleanup_all(&self) { + let sessions: Vec>> = { + let mut guard = self.sessions.lock().unwrap(); + guard.drain().map(|(_, s)| s).collect() + }; + for session in sessions { + let mut guard = session.lock().unwrap(); + let mut child_guard = guard.child.lock().unwrap(); + if let Some(ref mut child) = *child_guard { + let _ = child.kill(); + } + *child_guard = None; + guard.status = SessionStatus::Killed; + } + } + + fn spawn(&self, command: &str) -> Result { + if let Some(reason) = guard_command(command) { + return Err(reason); + } + + let mut sessions = self.sessions.lock().unwrap(); + if sessions.len() >= MAX_SESSIONS { + return Err(format!( + "Max sessions ({}) reached, kill some sessions first", + MAX_SESSIONS + )); + } + + let session_id = format!("pty_{}", crate::util::short_id()); + let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")); + + let pty_system = portable_pty::native_pty_system(); + let pty_pair = pty_system + .openpty(portable_pty::PtySize { + rows: 24, + cols: 80, + pixel_width: 0, + pixel_height: 0, + }) + .map_err(|e| format!("Failed to open PTY: {}", e))?; + + let mut cmd = portable_pty::CommandBuilder::new("bash"); + cmd.args(&["-c", command]); + cmd.cwd(cwd); + + let child = pty_pair + .slave + .spawn_command(cmd) + .map_err(|e| format!("Failed to spawn: {}", e))?; + + let pid = child.process_id().unwrap_or(0); + + let writer = pty_pair + .master + .take_writer() + .map_err(|e| format!("Failed to take writer: {}", e))?; + + let reader = pty_pair + .master + .try_clone_reader() + .map_err(|e| format!("Failed to clone reader: {}", e))?; + + let session_id_clone = session_id.clone(); + let session = PtySession::new(session_id_clone, command.to_string(), child, writer); + let session = Arc::new(Mutex::new(session)); + sessions.insert(session_id.clone(), session.clone()); + + let session_for_reader = session.clone(); + let child_for_reader = session_for_reader.lock().unwrap().child.clone(); + tokio::task::spawn_blocking(move || { + let mut reader = reader; + let mut buf = [0u8; 4096]; + let mut partial = String::new(); + loop { + use std::io::Read; + match reader.read(&mut buf) { + Ok(0) => break, + Ok(n) => { + let s = String::from_utf8_lossy(&buf[..n]); + partial.push_str(&s); + let lines: Vec<&str> = partial.split('\n').collect(); + if lines.len() <= 1 { + continue; + } + let complete = lines.len() - 1; + let mut guard = session_for_reader.lock().unwrap(); + for line in lines[..complete].iter() { + guard.push_line(line.to_string()); + } + partial = lines[complete].to_string(); + } + Err(_) => break, + } + } + let mut guard = session_for_reader.lock().unwrap(); + if !partial.is_empty() { + guard.push_line(partial); + } + let exit_code = { + let mut cg = child_for_reader.lock().unwrap(); + if let Some(ref mut c) = *cg { + c.wait().map(|s| s.exit_code() as i32).ok() + } else { + None + } + }; + guard.status = SessionStatus::Exited(exit_code.unwrap_or(-1)); + }); + + Ok(format!("session_id: {}, pid: {}", session_id, pid)) + } + + fn write(&self, session_id: &str, data: &str) -> Result { + let unescaped = unescape_control_chars(data); + let byte_count = unescaped.len(); + + let sessions = self.sessions.lock().unwrap(); + let session = sessions + .get(session_id) + .ok_or_else(|| format!("Session not found: {}", session_id))?; + let mut guard = session.lock().unwrap(); + if guard.status != SessionStatus::Running { + return Err("Session is not running".to_string()); + } + + let writer = guard.writer.clone(); + drop(guard); + drop(sessions); + + let mut writer_guard = writer.lock().unwrap(); + match *writer_guard { + Some(ref mut w) => { + w.write_all(unescaped.as_bytes()) + .map_err(|e| format!("Write error: {}", e))?; + w.flush().map_err(|e| format!("Flush error: {}", e))?; + } + None => return Err("Writer not available".to_string()), + } + + Ok(format!("OK, wrote {} bytes", byte_count)) + } + + fn read( + &self, + session_id: &str, + offset: usize, + limit: usize, + ) -> Result { + let sessions = self.sessions.lock().unwrap(); + let session = sessions + .get(session_id) + .ok_or_else(|| format!("Session not found: {}", session_id))?; + let guard = session.lock().unwrap(); + + let total = guard.output_total_lines; + let buffer_len = guard.output_buffer.len(); + let start = 0_usize.max(offset); + let skip_old = total.saturating_sub(buffer_len); + let view_start = start.saturating_sub(skip_old); + + let lines: Vec = guard + .output_buffer + .iter() + .skip(view_start) + .take(limit) + .enumerate() + .map(|(i, line)| format!("{}: {}", start + i, line)) + .collect(); + + let displayed = start + lines.len(); + let has_more = displayed < total; + + let mut output = format!( + "# Lines {}-{} (共 {} 行{})\n", + start, + displayed.saturating_sub(1), + total, + if has_more { ",还有更多" } else { "" } + ); + output.push_str(&lines.join("\n")); + if has_more { + output.push_str(&format!( + "\n[还有 {} 行未显示,用 offset={} 继续读取]", + total.saturating_sub(displayed), + displayed + )); + } + + Ok(output) + } + + fn kill(&self, session_id: &str) -> Result { + let mut sessions = self.sessions.lock().unwrap(); + let session = sessions + .get(session_id) + .ok_or_else(|| format!("Session not found: {}", session_id))?; + let mut guard = session.lock().unwrap(); + + let mut child_guard = guard.child.lock().unwrap(); + if let Some(ref mut child) = *child_guard { + let _ = child.kill(); + let _ = child.wait(); + } + *child_guard = None; + guard.status = SessionStatus::Killed; + drop(child_guard); + drop(guard); + sessions.remove(session_id); + + Ok(format!("Session {} killed", session_id)) + } + + fn list(&self) -> String { + let sessions = self.sessions.lock().unwrap(); + if sessions.is_empty() { + return "No active PTY sessions".to_string(); + } + + let mut lines: Vec = sessions + .iter() + .map(|(id, session)| { + let guard = session.lock().unwrap(); + let age = Instant::now().duration_since(guard.started_at); + let age_str = if age.as_secs() < 60 { + format!("{}s ago", age.as_secs()) + } else { + format!("{}m ago", age.as_secs() / 60) + }; + format!( + "{:<14} {:<12} {:<10} {:<6} lines {}", + id, + truncate_command(&guard.command, 10), + status_str(&guard.status), + guard.output_total_lines, + age_str, + ) + }) + .collect(); + lines.sort(); + lines.join("\n") + } +} + +impl Drop for PtyManager { + fn drop(&mut self) { + self.cleanup_all(); + } +} + +pub struct PtyTool { + pty_manager: Arc, +} + +impl PtyTool { + pub fn new(pty_manager: Arc) -> Self { + Self { pty_manager } + } +} + +#[async_trait] +impl Tool for PtyTool { + fn name(&self) -> &str { + "pty" + } + + fn description(&self) -> &str { + "管理持久伪终端(PTY)会话。用于交互式程序、长运行服务、多步骤命令等需要保持终端状态的场景。支持操作: spawn(创建)/write(写入输入)/read(读取输出)/kill(终止)/list(列出所有会话)。" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["spawn", "write", "read", "kill", "list"], + "description": "操作类型: spawn=创建会话, write=写入输入, read=读取输出, kill=终止会话, list=列出所有会话" + }, + "session_id": { + "type": "string", + "description": "会话ID (write/read/kill 需要)" + }, + "command": { + "type": "string", + "description": "要执行的命令 (spawn 需要)" + }, + "data": { + "type": "string", + "description": "写入终端的数据,支持转义序列: \\n(换行) \\x03(Ctrl+C) \\x04(Ctrl+D) \\x1a(Ctrl+Z) (write 需要)" + }, + "offset": { + "type": "integer", + "description": "输出读取起始行号,从 0 开始 (read 可选,默认 0)", + "minimum": 0 + }, + "limit": { + "type": "integer", + "description": "读取的最大行数 (read 可选,默认 500)", + "minimum": 1 + } + }, + "required": ["action"] + }) + } + + fn exclusive(&self) -> bool { + true + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = match args.get("action").and_then(|v| v.as_str()) { + Some(a) => a, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: action".to_string()), + }); + } + }; + + match action { + "spawn" => { + let command = match args.get("command").and_then(|v| v.as_str()) { + Some(c) => c, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: command".to_string()), + }); + } + }; + match self.pty_manager.spawn(command) { + Ok(output) => Ok(ToolResult { + success: true, + output, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }), + } + } + "write" => { + let session_id = match args.get("session_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: session_id".to_string()), + }); + } + }; + let data = match args.get("data").and_then(|v| v.as_str()) { + Some(d) => d, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: data".to_string()), + }); + } + }; + match self.pty_manager.write(session_id, data) { + Ok(output) => Ok(ToolResult { + success: true, + output, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }), + } + } + "read" => { + let session_id = match args.get("session_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: session_id".to_string()), + }); + } + }; + let offset = args + .get("offset") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let limit = args + .get("limit") + .and_then(|v| v.as_u64()) + .unwrap_or(500) as usize; + match self.pty_manager.read(session_id, offset, limit) { + Ok(output) => Ok(ToolResult { + success: true, + output, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }), + } + } + "kill" => { + let session_id = match args.get("session_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: session_id".to_string()), + }); + } + }; + match self.pty_manager.kill(session_id) { + Ok(output) => Ok(ToolResult { + success: true, + output, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }), + } + } + "list" => { + let output = self.pty_manager.list(); + Ok(ToolResult { + success: true, + output, + error: None, + }) + } + _ => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown action: {}", action)), + }), + } + } +}