清理代码问题

This commit is contained in:
xiaoxixi 2026-06-16 22:56:01 +08:00
parent ad7fa70a02
commit e707774175
11 changed files with 430 additions and 139 deletions

View File

@ -18,7 +18,9 @@ PicoBot 的总体架构方向是清晰的Gateway 负责装配Channel 只
- 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id` - 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id`
- 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。 - 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
- 待处理工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。 - 已修复Session 主处理路径不再在持有 session mutex 时执行 memory recall、上下文压缩、标题 LLM 生成、消息持久化、`/stop` sub-agent 取消或清历史存储操作;慢操作改为锁外执行并用 `state_version`/`worker_generation` 防止陈旧结果覆盖当前会话。
- 已修复Bash 超时清理、文件读取大文件限制、HTTP DNS 私网校验、Bus 关闭退出、Cron `from` 语义和 PTY 工具接入等中等级问题已完成清扫。
- 待处理:工具文件边界仍是后续质量风险。
## 主要发现 ## 主要发现
@ -108,7 +110,7 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。 - 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。 - shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
### 中高优先级Session 锁内执行过多异步操作 ### 已修复Session 锁内执行过多异步操作
位置: 位置:
@ -125,13 +127,17 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
- 当压缩或存储出现抖动时,用户感觉像“卡死”。 - 当压缩或存储出现抖动时,用户感觉像“卡死”。
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。 - 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
建议 已采取修复
- 锁内只做内存状态快照和必要的状态标记。 - 为 `Session` 增加 `state_version`,慢操作提交前检查会话是否已被 `/stop`、清历史或其它内存变更替换。
- 将 memory recall、压缩、LLM 摘要放到锁外执行。 - `/compact` 改为锁内取 history 快照,锁外压缩,锁内提交压缩结果,锁外持久化 meta。
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。 - agent worker Phase 1 改为锁内只创建用户消息、agent、cancel handle 和 history 快照memory recall 与 context compression 都在锁外执行。
- context overflow retry 的二次压缩移到锁外。
- 标题生成改为锁内取 prompt/provider 快照,锁外调用 LLM锁内应用标题锁外持久化。
- `add_message` 拆出内存更新和持久化快照,主消息路径在释放 session 锁后写入 SQLite。
- `/stop` 和清历史不再持有 session 锁等待 sub-agent 取消或 Storage 操作。
### 中优先级Bash 超时不会显式终止子进程 ### 已修复Bash 超时不会显式终止子进程
位置: 位置:
@ -146,14 +152,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。 长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
建议 已采取修复
- 使用 `tokio::process::Child``kill_on_drop(true)` - Bash 一次性命令改用 `wait_with_output()`,避免 stdout/stderr 顺序读取造成 pipe 阻塞
- 超时分支显式 kill child 并 wait - 子进程启用 `kill_on_drop(true)`,超时后丢弃等待 future 时会清理 child
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组 - 新增大 stderr 输出测试,覆盖不会因为 stderr pipe 填满而卡住
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义 - 持久/交互式进程通过已接入的 PTY 工具承载
### 中优先级:文件读取对大二进制文件没有输出上限 ### 已修复:文件读取对大二进制文件没有输出上限
位置: 位置:
@ -168,13 +174,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。 读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
建议 已采取修复
- 先检查 metadata size超过阈值直接返回提示。 - `file_read` 在读取前检查 metadata size超过安全阈值直接拒绝。
- 二进制文件默认只返回 mime、大小和建议操作需要内容时提供显式 `max_bytes` 参数。 - 二进制 inline base64 增加单独大小上限,超限只返回错误和文件信息。
- 对文本读取也改成流式按行读取,而不是整文件读入。 - 含 NUL 字节内容按二进制处理,避免全 0 文件被 UTF-8 路径误判为文本。
- 增加大文件和大二进制文件测试。
### 中优先级HTTP 私网防护只检查字面 host未做 DNS 解析校验 ### 已修复HTTP 私网防护只检查字面 host未做 DNS 解析校验
位置: 位置:
@ -188,13 +195,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
如果该工具暴露给非完全可信输入,存在 SSRF 风险。 如果该工具暴露给非完全可信输入,存在 SSRF 风险。
建议 已采取修复
- 请求前解析域名拒绝私网、loopback、link-local、multicast、unspecified 地址。 - `http_request``web_fetch` 在发送请求前通过 DNS 解析 host并拒绝解析到 loopback、private、link-local、multicast、unspecified 的地址。
- 禁止或限制重定向,重定向后的每个 URL 重新校验。 - IPv6 unique-local 和 link-local 地址也纳入私网判定。
- 对 `http_request``web_fetch` 复用同一套 URL 安全策略。 - 禁用 reqwest 自动重定向,避免跳转到未校验的内网地址。
- 增加端口解析和 IPv6 私网判断测试。
### 中优先级:后台任务和主循环缺少监督与优雅关闭 ### 已修复:后台任务和主循环缺少监督与优雅关闭
位置: 位置:
@ -212,13 +220,14 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
- 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。 - 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
- bus channel 关闭时更像崩溃,而不是可恢复状态。 - bus channel 关闭时更像崩溃,而不是可恢复状态。
建议 已采取修复
- 引入 runtime supervisor保存 JoinHandle 并集中处理退出原因。 - `MessageBus::consume_inbound/consume_outbound/consume_control` 不再在 channel 关闭时 `expect()` panic改为返回 `Option<T>`
- 用 `CancellationToken` 贯穿 Gateway 子任务。 - Gateway message processor 在 inbound/control bus 关闭时记录 warning 并退出 loop。
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。 - OutboundDispatcher 在 outbound bus 关闭时记录 warning 并退出 loop。
- 这不是完整 runtime supervisor但已消除 bus 关闭导致的 panic 崩溃路径,为后续集中 JoinHandle 管理留出接口。
### 中低优先级Cron 计算函数没有按入参 `from` 计算 cron 下一次时间 ### 已修复Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
位置: 位置:
@ -232,13 +241,13 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。 单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。
建议 已采取修复
- 使`cron_schedule.after(&from_dt).next()` 或等价 API - cron 分支改`cron_schedule.after(&from_dt).next()`
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。 - timezone 分支用 `from_dt.with_timezone(&tz)` 作为计算起点。
- 增加固定时间输入的单元测试,避免受系统时间影响 - 增加 UTC 和 Asia/Shanghai 固定时间输入测试。
### 中低优先级:存在未接入或半接入代码,增加维护噪音 ### 已修复:存在未接入或半接入代码,增加维护噪音
位置: 位置:
@ -254,10 +263,11 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。 维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
建议 已采取修复
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。 - `src/tools/pty.rs` 已接入 `tools/mod.rs`,导出 `PtyManager`/`PtyTool`
- 若暂不发布:移动到设计文档或 feature branch避免主干保留死代码。 - `create_default_tools()` 默认注册共享 `PtyManager``PtyTool`
- 修复 PTY 原本因未编译暴露不出的借用问题。
## 架构评价 ## 架构评价

View File

@ -24,7 +24,10 @@ impl OutboundDispatcher {
tracing::info!("OutboundDispatcher started"); tracing::info!("OutboundDispatcher started");
loop { loop {
let msg = self.bus.consume_outbound().await; let Some(msg) = self.bus.consume_outbound().await else {
tracing::warn!("OutboundDispatcher stopping because outbound bus closed");
break;
};
let channel_name = msg.channel.clone(); let channel_name = msg.channel.clone();
let channel = self.channel_manager.get_channel(&channel_name).await; let channel = self.channel_manager.get_channel(&channel_name).await;

View File

@ -51,17 +51,11 @@ impl MessageBus {
} }
/// Consume an inbound message (Agent -> Bus) /// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage { pub async fn consume_inbound(&self) -> Option<InboundMessage> {
let msg = self let msg = self.inbound_rx.lock().await.recv().await?;
.inbound_rx
.lock()
.await
.recv()
.await
.expect("bus inbound closed");
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message"); tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
msg Some(msg)
} }
/// Publish an outbound message (Agent -> Bus) /// Publish an outbound message (Agent -> Bus)
@ -75,13 +69,8 @@ impl MessageBus {
} }
/// Consume an outbound message (Dispatcher -> Bus) /// Consume an outbound message (Dispatcher -> Bus)
pub async fn consume_outbound(&self) -> OutboundMessage { pub async fn consume_outbound(&self) -> Option<OutboundMessage> {
self.outbound_rx self.outbound_rx.lock().await.recv().await
.lock()
.await
.recv()
.await
.expect("bus outbound closed")
} }
/// Publish a control message (Channel -> Bus for session management) /// Publish a control message (Channel -> Bus for session management)
@ -94,13 +83,8 @@ impl MessageBus {
} }
/// Consume a control message (ControlProcessor -> Bus) /// Consume a control message (ControlProcessor -> Bus)
pub async fn consume_control(&self) -> ControlMessage { pub async fn consume_control(&self) -> Option<ControlMessage> {
self.control_rx self.control_rx.lock().await.recv().await
.lock()
.await
.recv()
.await
.expect("bus control closed")
} }
} }

View File

@ -205,6 +205,10 @@ impl GatewayState {
tokio::select! { tokio::select! {
// Inbound: AI message flow // Inbound: AI message flow
inbound = bus.consume_inbound() => { inbound = bus.consume_inbound() => {
let Some(inbound) = inbound else {
tracing::warn!("Message processor stopping because inbound bus closed");
break;
};
match session_manager.handle_message( match session_manager.handle_message(
&inbound.channel, &inbound.channel,
&inbound.sender_id, &inbound.sender_id,
@ -252,6 +256,10 @@ impl GatewayState {
// Control: session management operations // Control: session management operations
msg = bus.consume_control() => { msg = bus.consume_control() => {
let Some(msg) = msg else {
tracing::warn!("Message processor stopping because control bus closed");
break;
};
Self::handle_control_message(&session_manager, msg).await; Self::handle_control_message(&session_manager, msg).await;
} }
} }

View File

@ -30,11 +30,11 @@ pub fn next_run_for_schedule(schedule: &Schedule, from: i64) -> Option<i64> {
let next_utc = if let Some(tz_str) = tz { let next_utc = if let Some(tz_str) = tz {
let tz: chrono_tz::Tz = tz_str.parse().ok()?; let tz: chrono_tz::Tz = tz_str.parse().ok()?;
let _from_local = from_dt.with_timezone(&tz); let from_local = from_dt.with_timezone(&tz);
let next_local = cron_schedule.upcoming(tz).next()?; let next_local = cron_schedule.after(&from_local).next()?;
next_local.with_timezone(&Utc) next_local.with_timezone(&Utc)
} else { } else {
cron_schedule.upcoming(Utc).next()? cron_schedule.after(&from_dt).next()?
}; };
Some(next_utc.timestamp_millis()) Some(next_utc.timestamp_millis())
@ -311,4 +311,37 @@ mod tests {
let next_ms = next.unwrap(); let next_ms = next.unwrap();
assert!(next_ms > now); assert!(next_ms > now);
} }
#[test]
fn test_next_run_cron_uses_from_argument() {
let expr = "0 * * * * *".to_string();
let schedule = Schedule::Cron { expr, tz: None };
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:34:20Z")
.unwrap()
.timestamp_millis();
let next = next_run_for_schedule(&schedule, from).unwrap();
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:35:00Z")
.unwrap()
.timestamp_millis();
assert_eq!(next, expected);
}
#[test]
fn test_next_run_cron_timezone_uses_from_argument() {
let expr = "0 0 9 * * *".to_string();
let schedule = Schedule::Cron {
expr,
tz: Some("Asia/Shanghai".to_string()),
};
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T00:30:00Z")
.unwrap()
.timestamp_millis();
let next = next_run_for_schedule(&schedule, from).unwrap();
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T01:00:00Z")
.unwrap()
.timestamp_millis();
assert_eq!(next, expected);
}
} }

View File

@ -4,7 +4,6 @@ use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::json; use serde_json::json;
use tokio::io::AsyncReadExt;
use tokio::process::Command; use tokio::process::Command;
use tokio::time::timeout; use tokio::time::timeout;
@ -147,71 +146,55 @@ impl Tool for BashTool {
.map(Path::new) .map(Path::new)
.unwrap_or_else(|| Path::new(".")); .unwrap_or_else(|| Path::new("."));
let result = timeout( match self.run_command(command, cwd, timeout_secs).await {
Duration::from_secs(timeout_secs), Ok(output) => Ok(ToolResult {
self.run_command(command, cwd),
)
.await;
match result {
Ok(Ok(output)) => Ok(ToolResult {
success: true, success: true,
output, output,
error: None, error: None,
}), }),
Ok(Err(e)) => Ok(ToolResult { Err(e) => Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(e), error: Some(e),
}), }),
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
}),
} }
} }
} }
impl BashTool { impl BashTool {
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> { async fn run_command(
&self,
command: &str,
cwd: &Path,
timeout_secs: u64,
) -> Result<String, String> {
let mut cmd = Command::new("bash"); let mut cmd = Command::new("bash");
cmd.args(["-c", command]) cmd.args(["-c", command])
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.current_dir(cwd); .current_dir(cwd)
.kill_on_drop(true);
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
let mut stdout = Vec::new(); let process_output =
let mut stderr = Vec::new(); match timeout(Duration::from_secs(timeout_secs), child.wait_with_output()).await {
Ok(Ok(output)) => output,
if let Some(ref mut out) = child.stdout { Ok(Err(e)) => return Err(format!("Failed to wait: {}", e)),
out.read_to_end(&mut stdout) Err(_) => {
.await return Err(format!("Command timed out after {} seconds", timeout_secs));
.map_err(|e| format!("Failed to read stdout: {}", e))?; }
} };
if let Some(ref mut err) = child.stderr {
err.read_to_end(&mut stderr)
.await
.map_err(|e| format!("Failed to read stderr: {}", e))?;
}
let status = child
.wait()
.await
.map_err(|e| format!("Failed to wait: {}", e))?;
let mut output = String::new(); let mut output = String::new();
if !stdout.is_empty() { if !process_output.stdout.is_empty() {
let stdout_str = String::from_utf8_lossy(&stdout); let stdout_str = String::from_utf8_lossy(&process_output.stdout);
output.push_str(&stdout_str); output.push_str(&stdout_str);
} }
if !stderr.is_empty() { if !process_output.stderr.is_empty() {
let stderr_str = String::from_utf8_lossy(&stderr); let stderr_str = String::from_utf8_lossy(&process_output.stderr);
if !stderr_str.trim().is_empty() { if !stderr_str.trim().is_empty() {
if !output.is_empty() { if !output.is_empty() {
output.push('\n'); output.push('\n');
@ -221,7 +204,10 @@ impl BashTool {
} }
} }
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1))); output.push_str(&format!(
"\nExit code: {}",
process_output.status.code().unwrap_or(-1)
));
Ok(self.truncate_output(&output)) Ok(self.truncate_output(&output))
} }
@ -309,4 +295,19 @@ mod tests {
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("timed out")); assert!(result.error.unwrap().contains("timed out"));
} }
#[tokio::test]
async fn test_large_stderr_does_not_deadlock() {
let tool = BashTool::new().with_timeout(5);
let result = tool
.execute(json!({
"command": "for i in $(seq 1 2000); do echo noisy-error-line >&2; done; echo done"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("done"));
assert!(result.output.contains("STDERR"));
}
} }

View File

@ -6,6 +6,8 @@ use crate::tools::path_utils;
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
const MAX_CHARS: usize = 128_000; const MAX_CHARS: usize = 128_000;
const MAX_FILE_BYTES: u64 = 5 * 1024 * 1024;
const MAX_BINARY_BYTES: usize = 512 * 1024;
const DEFAULT_LIMIT: usize = 2000; const DEFAULT_LIMIT: usize = 2000;
pub struct FileReadTool { pub struct FileReadTool {
@ -118,6 +120,29 @@ impl Tool for FileReadTool {
}); });
} }
let metadata = match std::fs::metadata(&resolved) {
Ok(m) => m,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to inspect file: {}", e)),
});
}
};
if metadata.len() > MAX_FILE_BYTES {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"File too large to read safely: {} bytes (max {} bytes). Use a narrower tool or inspect a smaller excerpt.",
metadata.len(),
MAX_FILE_BYTES
)),
});
}
// Read raw bytes and try multiple encodings // Read raw bytes and try multiple encodings
let bytes = match std::fs::read(&resolved) { let bytes = match std::fs::read(&resolved) {
Ok(b) => b, Ok(b) => b,
@ -209,6 +234,21 @@ impl Tool for FileReadTool {
None => { None => {
// Truly binary file — base64 encode // Truly binary file — base64 encode
use base64::{Engine, engine::general_purpose::STANDARD}; use base64::{Engine, engine::general_purpose::STANDARD};
if bytes.len() > MAX_BINARY_BYTES {
let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream()
.to_string();
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Binary file too large to inline: {}, {} bytes (max {} bytes).",
mime,
bytes.len(),
MAX_BINARY_BYTES
)),
});
}
let encoded = STANDARD.encode(&bytes); let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved) let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream() .first_or_octet_stream()
@ -229,6 +269,10 @@ impl Tool for FileReadTool {
} }
fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) { fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
if bytes.contains(&0) {
return (None, None);
}
// Try UTF-8 first // Try UTF-8 first
if let Ok(text) = std::str::from_utf8(bytes) { if let Ok(text) = std::str::from_utf8(bytes) {
return (Some(text.to_string()), None); return (Some(text.to_string()), None);
@ -337,4 +381,37 @@ mod tests {
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file")); assert!(result.error.unwrap().contains("Not a file"));
} }
#[tokio::test]
async fn test_rejects_large_file_before_reading() {
let mut file = NamedTempFile::new().unwrap();
file.as_file_mut()
.set_len(MAX_FILE_BYTES + 1)
.expect("set large file length");
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": file.path().to_str().unwrap() }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("too large"));
}
#[tokio::test]
async fn test_rejects_large_binary_inline() {
let mut file = NamedTempFile::new().unwrap();
let bytes = vec![0_u8; MAX_BINARY_BYTES + 1];
file.write_all(&bytes).unwrap();
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": file.path().to_str().unwrap() }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Binary file too large"));
}
} }

View File

@ -3,6 +3,7 @@ use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use serde_json::json; use serde_json::json;
use tokio::net::lookup_host;
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
@ -56,6 +57,34 @@ impl HttpRequestTool {
Ok(url.to_string()) Ok(url.to_string())
} }
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
if self.allow_private_hosts {
return Ok(());
}
let host = extract_host(url)?;
if host.parse::<std::net::IpAddr>().is_ok() {
return Ok(());
}
let port = extract_port(url)?;
let addrs = lookup_host((host.as_str(), port))
.await
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
for addr in addrs {
let ip = addr.ip();
if is_private_ip(&ip) {
return Err(format!(
"Blocked host '{}' because DNS resolved to local/private IP {}",
host, ip
));
}
}
Ok(())
}
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> { fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
match method.to_uppercase().as_str() { match method.to_uppercase().as_str() {
"GET" => Ok(reqwest::Method::GET), "GET" => Ok(reqwest::Method::GET),
@ -180,6 +209,34 @@ fn extract_host(url: &str) -> Result<String, String> {
Ok(host) Ok(host)
} }
fn extract_port(url: &str) -> Result<u16, String> {
let scheme = if url.starts_with("https://") {
"https"
} else if url.starts_with("http://") {
"http"
} else {
return Err("Only http:// and https:// URLs are allowed".to_string());
};
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if let Some((_, port)) = authority.rsplit_once(':') {
port.parse::<u16>()
.map_err(|_| format!("Invalid URL port: {}", port))
} else if scheme == "https" {
Ok(443)
} else {
Ok(80)
}
}
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
if allowed_domains.iter().any(|domain| domain == "*") { if allowed_domains.iter().any(|domain| domain == "*") {
return true; return true;
@ -226,7 +283,13 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|| v4.is_broadcast() || v4.is_broadcast()
|| v4.is_multicast() || v4.is_multicast()
} }
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(), std::net::IpAddr::V6(v6) => {
v6.is_loopback()
|| v6.is_unspecified()
|| v6.is_multicast()
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
}
} }
} }
@ -294,6 +357,14 @@ impl Tool for HttpRequestTool {
} }
}; };
if let Err(e) = self.validate_resolved_host(&url).await {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
let method = match self.validate_method(method_str) { let method = match self.validate_method(method_str) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
@ -309,6 +380,7 @@ impl Tool for HttpRequestTool {
let client = match reqwest::Client::builder() let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs)) .timeout(Duration::from_secs(self.timeout_secs))
.redirect(reqwest::redirect::Policy::none())
.build() .build()
{ {
Ok(c) => c, Ok(c) => c,
@ -436,4 +508,19 @@ mod tests {
async fn test_blocks_local_tld() { async fn test_blocks_local_tld() {
assert!(is_private_host("service.local")); assert!(is_private_host("service.local"));
} }
#[test]
fn test_extract_port_defaults_and_explicit() {
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
}
#[test]
fn test_private_ipv6_ranges_are_blocked() {
assert!(is_private_ip(&"::1".parse().unwrap()));
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
}
} }

View File

@ -13,6 +13,7 @@ pub mod get_skill;
pub mod http_request; pub mod http_request;
pub mod memory; pub mod memory;
pub mod path_utils; pub mod path_utils;
pub mod pty;
pub mod registry; pub mod registry;
pub mod schema; pub mod schema;
pub mod send_message; pub mod send_message;
@ -32,6 +33,7 @@ pub use file_write::FileWriteTool;
pub use get_skill::GetSkillTool; pub use get_skill::GetSkillTool;
pub use http_request::HttpRequestTool; pub use http_request::HttpRequestTool;
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool}; pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
pub use pty::{PtyManager, PtyTool};
pub use registry::ToolRegistry; pub use registry::ToolRegistry;
pub use send_message::SendMessageTool; pub use send_message::SendMessageTool;
pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use traits::{OutboundMessenger, Tool, ToolResult};
@ -60,6 +62,7 @@ pub fn create_default_tools(
registry.register(FileSearchTool::new()); registry.register(FileSearchTool::new());
registry.register(ContentSearchTool::new()); registry.register(ContentSearchTool::new());
registry.register(BashTool::new()); registry.register(BashTool::new());
registry.register(PtyTool::new(Arc::new(PtyManager::new())));
registry.register(HttpRequestTool::new( registry.register(HttpRequestTool::new(
vec!["*".to_string()], vec!["*".to_string()],
1_000_000, 1_000_000,

View File

@ -160,11 +160,14 @@ impl PtyManager {
}; };
for session in sessions { for session in sessions {
let mut guard = session.lock().unwrap(); let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap(); let child_handle = guard.child.clone();
if let Some(ref mut child) = *child_guard { {
let _ = child.kill(); let mut child_guard = child_handle.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
}
*child_guard = None;
} }
*child_guard = None;
guard.status = SessionStatus::Killed; guard.status = SessionStatus::Killed;
} }
} }
@ -274,7 +277,7 @@ impl PtyManager {
let session = sessions let session = sessions
.get(session_id) .get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?; .ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap(); let guard = session.lock().unwrap();
if guard.status != SessionStatus::Running { if guard.status != SessionStatus::Running {
return Err("Session is not running".to_string()); return Err("Session is not running".to_string());
} }
@ -296,12 +299,7 @@ impl PtyManager {
Ok(format!("OK, wrote {} bytes", byte_count)) Ok(format!("OK, wrote {} bytes", byte_count))
} }
fn read( fn read(&self, session_id: &str, offset: usize, limit: usize) -> Result<String, String> {
&self,
session_id: &str,
offset: usize,
limit: usize,
) -> Result<String, String> {
let sessions = self.sessions.lock().unwrap(); let sessions = self.sessions.lock().unwrap();
let session = sessions let session = sessions
.get(session_id) .get(session_id)
@ -352,14 +350,16 @@ impl PtyManager {
.ok_or_else(|| format!("Session not found: {}", session_id))?; .ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap(); let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap(); let child_handle = guard.child.clone();
if let Some(ref mut child) = *child_guard { {
let _ = child.kill(); let mut child_guard = child_handle.lock().unwrap();
let _ = child.wait(); if let Some(ref mut child) = *child_guard {
let _ = child.kill();
let _ = child.wait();
}
*child_guard = None;
} }
*child_guard = None;
guard.status = SessionStatus::Killed; guard.status = SessionStatus::Killed;
drop(child_guard);
drop(guard); drop(guard);
sessions.remove(session_id); sessions.remove(session_id);
@ -545,14 +545,8 @@ impl Tool for PtyTool {
}); });
} }
}; };
let offset = args let offset = args.get("offset").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
.get("offset") let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(500) as usize;
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
.unwrap_or(500) as usize;
match self.pty_manager.read(session_id, offset, limit) { match self.pty_manager.read(session_id, offset, limit) {
Ok(output) => Ok(ToolResult { Ok(output) => Ok(ToolResult {
success: true, success: true,

View File

@ -3,6 +3,7 @@ use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use serde_json::json; use serde_json::json;
use tokio::net::lookup_host;
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
@ -60,9 +61,34 @@ impl WebFetchTool {
} }
} }
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
let host = extract_host(url)?;
if host.parse::<std::net::IpAddr>().is_ok() {
return Ok(());
}
let port = extract_port(url)?;
let addrs = lookup_host((host.as_str(), port))
.await
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
for addr in addrs {
let ip = addr.ip();
if is_private_ip(&ip) {
return Err(format!(
"Blocked host '{}' because DNS resolved to local/private IP {}",
host, ip
));
}
}
Ok(())
}
async fn fetch_content(&self, url: &str) -> Result<String, String> { async fn fetch_content(&self, url: &str) -> Result<String, String> {
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs)) .timeout(Duration::from_secs(self.timeout_secs))
.redirect(reqwest::redirect::Policy::none())
.build() .build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?; .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
@ -234,6 +260,35 @@ fn extract_host(url: &str) -> Result<String, String> {
Ok(host) Ok(host)
} }
fn extract_port(url: &str) -> Result<u16, String> {
let scheme = if url.starts_with("https://") {
"https"
} else if url.starts_with("http://") {
"http"
} else {
return Err("Only http:// and https:// URLs are allowed".to_string());
};
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if let Some((_, port)) = authority.rsplit_once(':') {
port.parse::<u16>()
.map_err(|_| format!("Invalid URL port: {}", port))
} else if scheme == "https" {
Ok(443)
} else {
Ok(80)
}
}
fn is_private_host(host: &str) -> bool { fn is_private_host(host: &str) -> bool {
if host == "localhost" || host.ends_with(".localhost") { if host == "localhost" || host.ends_with(".localhost") {
return true; return true;
@ -248,19 +303,32 @@ fn is_private_host(host: &str) -> bool {
} }
if let Ok(ip) = host.parse::<std::net::IpAddr>() { if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return match ip { return is_private_ip(&ip);
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
};
} }
false false
} }
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback()
|| v4.is_private()
|| v4.is_link_local()
|| v4.is_unspecified()
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback()
|| v6.is_unspecified()
|| v6.is_multicast()
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
}
}
}
#[async_trait] #[async_trait]
impl Tool for WebFetchTool { impl Tool for WebFetchTool {
fn name(&self) -> &str { fn name(&self) -> &str {
@ -311,6 +379,14 @@ impl Tool for WebFetchTool {
} }
}; };
if let Err(e) = self.validate_resolved_host(&url).await {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
match self.fetch_content(&url).await { match self.fetch_content(&url).await {
Ok(content) => Ok(ToolResult { Ok(content) => Ok(ToolResult {
success: true, success: true,
@ -357,6 +433,21 @@ mod tests {
assert!(result.unwrap_err().contains("whitespace")); assert!(result.unwrap_err().contains("whitespace"));
} }
#[test]
fn test_extract_port_defaults_and_explicit() {
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
}
#[test]
fn test_private_ipv6_ranges_are_blocked() {
assert!(is_private_ip(&"::1".parse().unwrap()));
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
}
#[tokio::test] #[tokio::test]
async fn test_extract_text_simple() { async fn test_extract_text_simple() {
let tool = test_tool(); let tool = test_tool();