From e707774175c2e01c4b95308858c88932fc290f4d Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Tue, 16 Jun 2026 22:56:01 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B8=85=E7=90=86=E4=BB=A3=E7=A0=81=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/CODE_QUALITY_ANALYSIS.md | 82 ++++++++++++++------------ src/bus/dispatcher.rs | 5 +- src/bus/mod.rs | 30 +++------- src/gateway/mod.rs | 8 +++ src/scheduler/mod.rs | 39 ++++++++++++- src/tools/bash.rs | 85 ++++++++++++++------------- src/tools/file_read.rs | 77 ++++++++++++++++++++++++ src/tools/http_request.rs | 89 +++++++++++++++++++++++++++- src/tools/mod.rs | 3 + src/tools/pty.rs | 44 ++++++-------- src/tools/web_fetch.rs | 107 +++++++++++++++++++++++++++++++--- 11 files changed, 430 insertions(+), 139 deletions(-) diff --git a/docs/CODE_QUALITY_ANALYSIS.md b/docs/CODE_QUALITY_ANALYSIS.md index 41ef4b7..19cd791 100644 --- a/docs/CODE_QUALITY_ANALYSIS.md +++ b/docs/CODE_QUALITY_ANALYSIS.md @@ -18,7 +18,9 @@ PicoBot 的总体架构方向是清晰的:Gateway 负责装配,Channel 只 - 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。 - 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。 -- 待处理:工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。 +- 已修复:Session 主处理路径不再在持有 session mutex 时执行 memory recall、上下文压缩、标题 LLM 生成、消息持久化、`/stop` sub-agent 取消或清历史存储操作;慢操作改为锁外执行并用 `state_version`/`worker_generation` 防止陈旧结果覆盖当前会话。 +- 已修复:Bash 超时清理、文件读取大文件限制、HTTP DNS 私网校验、Bus 关闭退出、Cron `from` 语义和 PTY 工具接入等中等级问题已完成清扫。 +- 待处理:工具文件边界仍是后续质量风险。 ## 主要发现 @@ -108,7 +110,7 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“ - 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。 - shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。 -### 中高优先级:Session 锁内执行过多异步操作 +### 已修复:Session 锁内执行过多异步操作 位置: @@ -125,13 +127,17 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“ - 当压缩或存储出现抖动时,用户感觉像“卡死”。 - 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。 -建议: +已采取修复: -- 锁内只做内存状态快照和必要的状态标记。 -- 将 memory recall、压缩、LLM 摘要放到锁外执行。 -- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。 +- 为 `Session` 增加 `state_version`,慢操作提交前检查会话是否已被 `/stop`、清历史或其它内存变更替换。 +- `/compact` 改为锁内取 history 快照,锁外压缩,锁内提交压缩结果,锁外持久化 meta。 +- agent worker Phase 1 改为锁内只创建用户消息、agent、cancel handle 和 history 快照;memory recall 与 context compression 都在锁外执行。 +- context overflow retry 的二次压缩移到锁外。 +- 标题生成改为锁内取 prompt/provider 快照,锁外调用 LLM,锁内应用标题,锁外持久化。 +- `add_message` 拆出内存更新和持久化快照,主消息路径在释放 session 锁后写入 SQLite。 +- `/stop` 和清历史不再持有 session 锁等待 sub-agent 取消或 Storage 操作。 -### 中优先级:Bash 超时不会显式终止子进程 +### 已修复:Bash 超时不会显式终止子进程 位置: @@ -146,14 +152,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“ 长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。 -建议: +已采取修复: -- 使用 `tokio::process::Child` 的 `kill_on_drop(true)`。 -- 超时分支显式 kill child 并 wait。 -- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。 -- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。 +- Bash 一次性命令改用 `wait_with_output()`,避免 stdout/stderr 顺序读取造成 pipe 阻塞。 +- 子进程启用 `kill_on_drop(true)`,超时后丢弃等待 future 时会清理 child。 +- 新增大 stderr 输出测试,覆盖不会因为 stderr pipe 填满而卡住。 +- 持久/交互式进程通过已接入的 PTY 工具承载。 -### 中优先级:文件读取对大二进制文件没有输出上限 +### 已修复:文件读取对大二进制文件没有输出上限 位置: @@ -168,13 +174,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“ 读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。 -建议: +已采取修复: -- 先检查 metadata size,超过阈值直接返回提示。 -- 二进制文件默认只返回 mime、大小和建议操作;需要内容时提供显式 `max_bytes` 参数。 -- 对文本读取也改成流式按行读取,而不是整文件读入。 +- `file_read` 在读取前检查 metadata size,超过安全阈值直接拒绝。 +- 二进制 inline base64 增加单独大小上限,超限只返回错误和文件信息。 +- 含 NUL 字节内容按二进制处理,避免全 0 文件被 UTF-8 路径误判为文本。 +- 增加大文件和大二进制文件测试。 -### 中优先级:HTTP 私网防护只检查字面 host,未做 DNS 解析校验 +### 已修复:HTTP 私网防护只检查字面 host,未做 DNS 解析校验 位置: @@ -188,13 +195,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“ 如果该工具暴露给非完全可信输入,存在 SSRF 风险。 -建议: +已采取修复: -- 请求前解析域名,拒绝私网、loopback、link-local、multicast、unspecified 地址。 -- 禁止或限制重定向,重定向后的每个 URL 重新校验。 -- 对 `http_request` 和 `web_fetch` 复用同一套 URL 安全策略。 +- `http_request` 和 `web_fetch` 在发送请求前通过 DNS 解析 host,并拒绝解析到 loopback、private、link-local、multicast、unspecified 的地址。 +- IPv6 unique-local 和 link-local 地址也纳入私网判定。 +- 禁用 reqwest 自动重定向,避免跳转到未校验的内网地址。 +- 增加端口解析和 IPv6 私网判断测试。 -### 中优先级:后台任务和主循环缺少监督与优雅关闭 +### 已修复:后台任务和主循环缺少监督与优雅关闭 位置: @@ -212,13 +220,14 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan - 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。 - bus channel 关闭时更像崩溃,而不是可恢复状态。 -建议: +已采取修复: -- 引入 runtime supervisor,保存 JoinHandle 并集中处理退出原因。 -- 用 `CancellationToken` 贯穿 Gateway 子任务。 -- `consume_*()` 返回 `Result>`,由调用方决定退出或重启。 +- `MessageBus::consume_inbound/consume_outbound/consume_control` 不再在 channel 关闭时 `expect()` panic,改为返回 `Option`。 +- Gateway message processor 在 inbound/control bus 关闭时记录 warning 并退出 loop。 +- OutboundDispatcher 在 outbound bus 关闭时记录 warning 并退出 loop。 +- 这不是完整 runtime supervisor,但已消除 bus 关闭导致的 panic 崩溃路径,为后续集中 JoinHandle 管理留出接口。 -### 中低优先级:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间 +### 已修复:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间 位置: @@ -232,13 +241,13 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan 单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。 -建议: +已采取修复: -- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。 -- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。 -- 增加固定时间输入的单元测试,避免受系统时间影响。 +- cron 分支改用 `cron_schedule.after(&from_dt).next()`。 +- timezone 分支用 `from_dt.with_timezone(&tz)` 作为计算起点。 +- 增加 UTC 和 Asia/Shanghai 固定时间输入测试。 -### 中低优先级:存在未接入或半接入代码,增加维护噪音 +### 已修复:存在未接入或半接入代码,增加维护噪音 位置: @@ -254,10 +263,11 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan 维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。 -建议: +已采取修复: -- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。 -- 若暂不发布:移动到设计文档或 feature branch,避免主干保留死代码。 +- `src/tools/pty.rs` 已接入 `tools/mod.rs`,导出 `PtyManager`/`PtyTool`。 +- `create_default_tools()` 默认注册共享 `PtyManager` 的 `PtyTool`。 +- 修复 PTY 原本因未编译暴露不出的借用问题。 ## 架构评价 diff --git a/src/bus/dispatcher.rs b/src/bus/dispatcher.rs index b10099f..82b607c 100644 --- a/src/bus/dispatcher.rs +++ b/src/bus/dispatcher.rs @@ -24,7 +24,10 @@ impl OutboundDispatcher { tracing::info!("OutboundDispatcher started"); loop { - let msg = self.bus.consume_outbound().await; + let Some(msg) = self.bus.consume_outbound().await else { + tracing::warn!("OutboundDispatcher stopping because outbound bus closed"); + break; + }; let channel_name = msg.channel.clone(); let channel = self.channel_manager.get_channel(&channel_name).await; diff --git a/src/bus/mod.rs b/src/bus/mod.rs index 2f3011c..f22829e 100644 --- a/src/bus/mod.rs +++ b/src/bus/mod.rs @@ -51,17 +51,11 @@ impl MessageBus { } /// Consume an inbound message (Agent -> Bus) - pub async fn consume_inbound(&self) -> InboundMessage { - let msg = self - .inbound_rx - .lock() - .await - .recv() - .await - .expect("bus inbound closed"); + pub async fn consume_inbound(&self) -> Option { + let msg = self.inbound_rx.lock().await.recv().await?; #[cfg(debug_assertions)] tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message"); - msg + Some(msg) } /// Publish an outbound message (Agent -> Bus) @@ -75,13 +69,8 @@ impl MessageBus { } /// Consume an outbound message (Dispatcher -> Bus) - pub async fn consume_outbound(&self) -> OutboundMessage { - self.outbound_rx - .lock() - .await - .recv() - .await - .expect("bus outbound closed") + pub async fn consume_outbound(&self) -> Option { + self.outbound_rx.lock().await.recv().await } /// Publish a control message (Channel -> Bus for session management) @@ -94,13 +83,8 @@ impl MessageBus { } /// Consume a control message (ControlProcessor -> Bus) - pub async fn consume_control(&self) -> ControlMessage { - self.control_rx - .lock() - .await - .recv() - .await - .expect("bus control closed") + pub async fn consume_control(&self) -> Option { + self.control_rx.lock().await.recv().await } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 65ad775..2ecd6f2 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -205,6 +205,10 @@ impl GatewayState { tokio::select! { // Inbound: AI message flow inbound = bus.consume_inbound() => { + let Some(inbound) = inbound else { + tracing::warn!("Message processor stopping because inbound bus closed"); + break; + }; match session_manager.handle_message( &inbound.channel, &inbound.sender_id, @@ -252,6 +256,10 @@ impl GatewayState { // Control: session management operations msg = bus.consume_control() => { + let Some(msg) = msg else { + tracing::warn!("Message processor stopping because control bus closed"); + break; + }; Self::handle_control_message(&session_manager, msg).await; } } diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index cf0a70c..502b1b1 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -30,11 +30,11 @@ pub fn next_run_for_schedule(schedule: &Schedule, from: i64) -> Option { let next_utc = if let Some(tz_str) = tz { let tz: chrono_tz::Tz = tz_str.parse().ok()?; - let _from_local = from_dt.with_timezone(&tz); - let next_local = cron_schedule.upcoming(tz).next()?; + let from_local = from_dt.with_timezone(&tz); + let next_local = cron_schedule.after(&from_local).next()?; next_local.with_timezone(&Utc) } else { - cron_schedule.upcoming(Utc).next()? + cron_schedule.after(&from_dt).next()? }; Some(next_utc.timestamp_millis()) @@ -311,4 +311,37 @@ mod tests { let next_ms = next.unwrap(); assert!(next_ms > now); } + + #[test] + fn test_next_run_cron_uses_from_argument() { + let expr = "0 * * * * *".to_string(); + let schedule = Schedule::Cron { expr, tz: None }; + let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:34:20Z") + .unwrap() + .timestamp_millis(); + + let next = next_run_for_schedule(&schedule, from).unwrap(); + let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:35:00Z") + .unwrap() + .timestamp_millis(); + assert_eq!(next, expected); + } + + #[test] + fn test_next_run_cron_timezone_uses_from_argument() { + let expr = "0 0 9 * * *".to_string(); + let schedule = Schedule::Cron { + expr, + tz: Some("Asia/Shanghai".to_string()), + }; + let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T00:30:00Z") + .unwrap() + .timestamp_millis(); + + let next = next_run_for_schedule(&schedule, from).unwrap(); + let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T01:00:00Z") + .unwrap() + .timestamp_millis(); + assert_eq!(next, expected); + } } diff --git a/src/tools/bash.rs b/src/tools/bash.rs index 714cbc4..74f6d89 100644 --- a/src/tools/bash.rs +++ b/src/tools/bash.rs @@ -4,7 +4,6 @@ use std::time::Duration; use async_trait::async_trait; use serde_json::json; -use tokio::io::AsyncReadExt; use tokio::process::Command; use tokio::time::timeout; @@ -147,71 +146,55 @@ impl Tool for BashTool { .map(Path::new) .unwrap_or_else(|| Path::new(".")); - let result = timeout( - Duration::from_secs(timeout_secs), - self.run_command(command, cwd), - ) - .await; - - match result { - Ok(Ok(output)) => Ok(ToolResult { + match self.run_command(command, cwd, timeout_secs).await { + Ok(output) => Ok(ToolResult { success: true, output, error: None, }), - Ok(Err(e)) => Ok(ToolResult { + Err(e) => Ok(ToolResult { success: false, output: String::new(), error: Some(e), }), - Err(_) => Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Command timed out after {} seconds", timeout_secs)), - }), } } } impl BashTool { - async fn run_command(&self, command: &str, cwd: &Path) -> Result { + async fn run_command( + &self, + command: &str, + cwd: &Path, + timeout_secs: u64, + ) -> Result { let mut cmd = Command::new("bash"); cmd.args(["-c", command]) .stdout(Stdio::piped()) .stderr(Stdio::piped()) - .current_dir(cwd); + .current_dir(cwd) + .kill_on_drop(true); - let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; + let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; - let mut stdout = Vec::new(); - let mut stderr = Vec::new(); - - if let Some(ref mut out) = child.stdout { - out.read_to_end(&mut stdout) - .await - .map_err(|e| format!("Failed to read stdout: {}", e))?; - } - - if let Some(ref mut err) = child.stderr { - err.read_to_end(&mut stderr) - .await - .map_err(|e| format!("Failed to read stderr: {}", e))?; - } - - let status = child - .wait() - .await - .map_err(|e| format!("Failed to wait: {}", e))?; + let process_output = + match timeout(Duration::from_secs(timeout_secs), child.wait_with_output()).await { + Ok(Ok(output)) => output, + Ok(Err(e)) => return Err(format!("Failed to wait: {}", e)), + Err(_) => { + return Err(format!("Command timed out after {} seconds", timeout_secs)); + } + }; let mut output = String::new(); - if !stdout.is_empty() { - let stdout_str = String::from_utf8_lossy(&stdout); + if !process_output.stdout.is_empty() { + let stdout_str = String::from_utf8_lossy(&process_output.stdout); output.push_str(&stdout_str); } - if !stderr.is_empty() { - let stderr_str = String::from_utf8_lossy(&stderr); + if !process_output.stderr.is_empty() { + let stderr_str = String::from_utf8_lossy(&process_output.stderr); if !stderr_str.trim().is_empty() { if !output.is_empty() { output.push('\n'); @@ -221,7 +204,10 @@ impl BashTool { } } - output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1))); + output.push_str(&format!( + "\nExit code: {}", + process_output.status.code().unwrap_or(-1) + )); Ok(self.truncate_output(&output)) } @@ -309,4 +295,19 @@ mod tests { assert!(!result.success); assert!(result.error.unwrap().contains("timed out")); } + + #[tokio::test] + async fn test_large_stderr_does_not_deadlock() { + let tool = BashTool::new().with_timeout(5); + let result = tool + .execute(json!({ + "command": "for i in $(seq 1 2000); do echo noisy-error-line >&2; done; echo done" + })) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("done")); + assert!(result.output.contains("STDERR")); + } } diff --git a/src/tools/file_read.rs b/src/tools/file_read.rs index 3115de2..b0e5bf3 100644 --- a/src/tools/file_read.rs +++ b/src/tools/file_read.rs @@ -6,6 +6,8 @@ use crate::tools::path_utils; use crate::tools::traits::{Tool, ToolResult}; const MAX_CHARS: usize = 128_000; +const MAX_FILE_BYTES: u64 = 5 * 1024 * 1024; +const MAX_BINARY_BYTES: usize = 512 * 1024; const DEFAULT_LIMIT: usize = 2000; pub struct FileReadTool { @@ -118,6 +120,29 @@ impl Tool for FileReadTool { }); } + let metadata = match std::fs::metadata(&resolved) { + Ok(m) => m, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to inspect file: {}", e)), + }); + } + }; + + if metadata.len() > MAX_FILE_BYTES { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "File too large to read safely: {} bytes (max {} bytes). Use a narrower tool or inspect a smaller excerpt.", + metadata.len(), + MAX_FILE_BYTES + )), + }); + } + // Read raw bytes and try multiple encodings let bytes = match std::fs::read(&resolved) { Ok(b) => b, @@ -209,6 +234,21 @@ impl Tool for FileReadTool { None => { // Truly binary file — base64 encode use base64::{Engine, engine::general_purpose::STANDARD}; + if bytes.len() > MAX_BINARY_BYTES { + let mime = mime_guess::from_path(&resolved) + .first_or_octet_stream() + .to_string(); + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Binary file too large to inline: {}, {} bytes (max {} bytes).", + mime, + bytes.len(), + MAX_BINARY_BYTES + )), + }); + } let encoded = STANDARD.encode(&bytes); let mime = mime_guess::from_path(&resolved) .first_or_octet_stream() @@ -229,6 +269,10 @@ impl Tool for FileReadTool { } fn decode_text(bytes: &[u8]) -> (Option, Option<&'static str>) { + if bytes.contains(&0) { + return (None, None); + } + // Try UTF-8 first if let Ok(text) = std::str::from_utf8(bytes) { return (Some(text.to_string()), None); @@ -337,4 +381,37 @@ mod tests { assert!(!result.success); assert!(result.error.unwrap().contains("Not a file")); } + + #[tokio::test] + async fn test_rejects_large_file_before_reading() { + let mut file = NamedTempFile::new().unwrap(); + file.as_file_mut() + .set_len(MAX_FILE_BYTES + 1) + .expect("set large file length"); + + let tool = FileReadTool::new(); + let result = tool + .execute(json!({ "path": file.path().to_str().unwrap() })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("too large")); + } + + #[tokio::test] + async fn test_rejects_large_binary_inline() { + let mut file = NamedTempFile::new().unwrap(); + let bytes = vec![0_u8; MAX_BINARY_BYTES + 1]; + file.write_all(&bytes).unwrap(); + + let tool = FileReadTool::new(); + let result = tool + .execute(json!({ "path": file.path().to_str().unwrap() })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("Binary file too large")); + } } diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs index 96b039b..f88db38 100644 --- a/src/tools/http_request.rs +++ b/src/tools/http_request.rs @@ -3,6 +3,7 @@ use std::time::Duration; use async_trait::async_trait; use reqwest::header::HeaderMap; use serde_json::json; +use tokio::net::lookup_host; use crate::tools::traits::{Tool, ToolResult}; @@ -56,6 +57,34 @@ impl HttpRequestTool { Ok(url.to_string()) } + async fn validate_resolved_host(&self, url: &str) -> Result<(), String> { + if self.allow_private_hosts { + return Ok(()); + } + + let host = extract_host(url)?; + if host.parse::().is_ok() { + return Ok(()); + } + + let port = extract_port(url)?; + let addrs = lookup_host((host.as_str(), port)) + .await + .map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?; + + for addr in addrs { + let ip = addr.ip(); + if is_private_ip(&ip) { + return Err(format!( + "Blocked host '{}' because DNS resolved to local/private IP {}", + host, ip + )); + } + } + + Ok(()) + } + fn validate_method(&self, method: &str) -> Result { match method.to_uppercase().as_str() { "GET" => Ok(reqwest::Method::GET), @@ -180,6 +209,34 @@ fn extract_host(url: &str) -> Result { Ok(host) } +fn extract_port(url: &str) -> Result { + let scheme = if url.starts_with("https://") { + "https" + } else if url.starts_with("http://") { + "http" + } else { + return Err("Only http:// and https:// URLs are allowed".to_string()); + }; + + let rest = url + .strip_prefix("http://") + .or_else(|| url.strip_prefix("https://")) + .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?; + let authority = rest + .split(['/', '?', '#']) + .next() + .ok_or_else(|| "Invalid URL".to_string())?; + + if let Some((_, port)) = authority.rsplit_once(':') { + port.parse::() + .map_err(|_| format!("Invalid URL port: {}", port)) + } else if scheme == "https" { + Ok(443) + } else { + Ok(80) + } +} + fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { if allowed_domains.iter().any(|domain| domain == "*") { return true; @@ -226,7 +283,13 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool { || v4.is_broadcast() || v4.is_multicast() } - std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(), + std::net::IpAddr::V6(v6) => { + v6.is_loopback() + || v6.is_unspecified() + || v6.is_multicast() + || ((v6.segments()[0] & 0xfe00) == 0xfc00) + || ((v6.segments()[0] & 0xffc0) == 0xfe80) + } } } @@ -294,6 +357,14 @@ impl Tool for HttpRequestTool { } }; + if let Err(e) = self.validate_resolved_host(&url).await { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }); + } + let method = match self.validate_method(method_str) { Ok(m) => m, Err(e) => { @@ -309,6 +380,7 @@ impl Tool for HttpRequestTool { let client = match reqwest::Client::builder() .timeout(Duration::from_secs(self.timeout_secs)) + .redirect(reqwest::redirect::Policy::none()) .build() { Ok(c) => c, @@ -436,4 +508,19 @@ mod tests { async fn test_blocks_local_tld() { assert!(is_private_host("service.local")); } + + #[test] + fn test_extract_port_defaults_and_explicit() { + assert_eq!(extract_port("https://example.com/path").unwrap(), 443); + assert_eq!(extract_port("http://example.com/path").unwrap(), 80); + assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443); + } + + #[test] + fn test_private_ipv6_ranges_are_blocked() { + assert!(is_private_ip(&"::1".parse().unwrap())); + assert!(is_private_ip(&"fc00::1".parse().unwrap())); + assert!(is_private_ip(&"fe80::1".parse().unwrap())); + assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap())); + } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 1d3dcf0..cbed3f2 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -13,6 +13,7 @@ pub mod get_skill; pub mod http_request; pub mod memory; pub mod path_utils; +pub mod pty; pub mod registry; pub mod schema; pub mod send_message; @@ -32,6 +33,7 @@ pub use file_write::FileWriteTool; pub use get_skill::GetSkillTool; pub use http_request::HttpRequestTool; pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool}; +pub use pty::{PtyManager, PtyTool}; pub use registry::ToolRegistry; pub use send_message::SendMessageTool; pub use traits::{OutboundMessenger, Tool, ToolResult}; @@ -60,6 +62,7 @@ pub fn create_default_tools( registry.register(FileSearchTool::new()); registry.register(ContentSearchTool::new()); registry.register(BashTool::new()); + registry.register(PtyTool::new(Arc::new(PtyManager::new()))); registry.register(HttpRequestTool::new( vec!["*".to_string()], 1_000_000, diff --git a/src/tools/pty.rs b/src/tools/pty.rs index b3a5ee3..c3e93f4 100644 --- a/src/tools/pty.rs +++ b/src/tools/pty.rs @@ -160,11 +160,14 @@ impl PtyManager { }; for session in sessions { let mut guard = session.lock().unwrap(); - let mut child_guard = guard.child.lock().unwrap(); - if let Some(ref mut child) = *child_guard { - let _ = child.kill(); + let child_handle = guard.child.clone(); + { + let mut child_guard = child_handle.lock().unwrap(); + if let Some(ref mut child) = *child_guard { + let _ = child.kill(); + } + *child_guard = None; } - *child_guard = None; guard.status = SessionStatus::Killed; } } @@ -274,7 +277,7 @@ impl PtyManager { let session = sessions .get(session_id) .ok_or_else(|| format!("Session not found: {}", session_id))?; - let mut guard = session.lock().unwrap(); + let guard = session.lock().unwrap(); if guard.status != SessionStatus::Running { return Err("Session is not running".to_string()); } @@ -296,12 +299,7 @@ impl PtyManager { Ok(format!("OK, wrote {} bytes", byte_count)) } - fn read( - &self, - session_id: &str, - offset: usize, - limit: usize, - ) -> Result { + fn read(&self, session_id: &str, offset: usize, limit: usize) -> Result { let sessions = self.sessions.lock().unwrap(); let session = sessions .get(session_id) @@ -352,14 +350,16 @@ impl PtyManager { .ok_or_else(|| format!("Session not found: {}", session_id))?; let mut guard = session.lock().unwrap(); - let mut child_guard = guard.child.lock().unwrap(); - if let Some(ref mut child) = *child_guard { - let _ = child.kill(); - let _ = child.wait(); + let child_handle = guard.child.clone(); + { + let mut child_guard = child_handle.lock().unwrap(); + if let Some(ref mut child) = *child_guard { + let _ = child.kill(); + let _ = child.wait(); + } + *child_guard = None; } - *child_guard = None; guard.status = SessionStatus::Killed; - drop(child_guard); drop(guard); sessions.remove(session_id); @@ -545,14 +545,8 @@ impl Tool for PtyTool { }); } }; - let offset = args - .get("offset") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - let limit = args - .get("limit") - .and_then(|v| v.as_u64()) - .unwrap_or(500) as usize; + let offset = args.get("offset").and_then(|v| v.as_u64()).unwrap_or(0) as usize; + let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(500) as usize; match self.pty_manager.read(session_id, offset, limit) { Ok(output) => Ok(ToolResult { success: true, diff --git a/src/tools/web_fetch.rs b/src/tools/web_fetch.rs index 5f2f650..afe7bda 100644 --- a/src/tools/web_fetch.rs +++ b/src/tools/web_fetch.rs @@ -3,6 +3,7 @@ use std::time::Duration; use async_trait::async_trait; use reqwest::header::HeaderMap; use serde_json::json; +use tokio::net::lookup_host; use crate::tools::traits::{Tool, ToolResult}; @@ -60,9 +61,34 @@ impl WebFetchTool { } } + async fn validate_resolved_host(&self, url: &str) -> Result<(), String> { + let host = extract_host(url)?; + if host.parse::().is_ok() { + return Ok(()); + } + + let port = extract_port(url)?; + let addrs = lookup_host((host.as_str(), port)) + .await + .map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?; + + for addr in addrs { + let ip = addr.ip(); + if is_private_ip(&ip) { + return Err(format!( + "Blocked host '{}' because DNS resolved to local/private IP {}", + host, ip + )); + } + } + + Ok(()) + } + async fn fetch_content(&self, url: &str) -> Result { let client = reqwest::Client::builder() .timeout(Duration::from_secs(self.timeout_secs)) + .redirect(reqwest::redirect::Policy::none()) .build() .map_err(|e| format!("Failed to create HTTP client: {}", e))?; @@ -234,6 +260,35 @@ fn extract_host(url: &str) -> Result { Ok(host) } +fn extract_port(url: &str) -> Result { + let scheme = if url.starts_with("https://") { + "https" + } else if url.starts_with("http://") { + "http" + } else { + return Err("Only http:// and https:// URLs are allowed".to_string()); + }; + + let rest = url + .strip_prefix("http://") + .or_else(|| url.strip_prefix("https://")) + .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?; + + let authority = rest + .split(['/', '?', '#']) + .next() + .ok_or_else(|| "Invalid URL".to_string())?; + + if let Some((_, port)) = authority.rsplit_once(':') { + port.parse::() + .map_err(|_| format!("Invalid URL port: {}", port)) + } else if scheme == "https" { + Ok(443) + } else { + Ok(80) + } +} + fn is_private_host(host: &str) -> bool { if host == "localhost" || host.ends_with(".localhost") { return true; @@ -248,19 +303,32 @@ fn is_private_host(host: &str) -> bool { } if let Ok(ip) = host.parse::() { - return match ip { - std::net::IpAddr::V4(v4) => { - v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() - } - std::net::IpAddr::V6(v6) => { - v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() - } - }; + return is_private_ip(&ip); } false } +fn is_private_ip(ip: &std::net::IpAddr) -> bool { + match ip { + std::net::IpAddr::V4(v4) => { + v4.is_loopback() + || v4.is_private() + || v4.is_link_local() + || v4.is_unspecified() + || v4.is_broadcast() + || v4.is_multicast() + } + std::net::IpAddr::V6(v6) => { + v6.is_loopback() + || v6.is_unspecified() + || v6.is_multicast() + || ((v6.segments()[0] & 0xfe00) == 0xfc00) + || ((v6.segments()[0] & 0xffc0) == 0xfe80) + } + } +} + #[async_trait] impl Tool for WebFetchTool { fn name(&self) -> &str { @@ -311,6 +379,14 @@ impl Tool for WebFetchTool { } }; + if let Err(e) = self.validate_resolved_host(&url).await { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }); + } + match self.fetch_content(&url).await { Ok(content) => Ok(ToolResult { success: true, @@ -357,6 +433,21 @@ mod tests { assert!(result.unwrap_err().contains("whitespace")); } + #[test] + fn test_extract_port_defaults_and_explicit() { + assert_eq!(extract_port("https://example.com/path").unwrap(), 443); + assert_eq!(extract_port("http://example.com/path").unwrap(), 80); + assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443); + } + + #[test] + fn test_private_ipv6_ranges_are_blocked() { + assert!(is_private_ip(&"::1".parse().unwrap())); + assert!(is_private_ip(&"fc00::1".parse().unwrap())); + assert!(is_private_ip(&"fe80::1".parse().unwrap())); + assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap())); + } + #[tokio::test] async fn test_extract_text_simple() { let tool = test_tool();