清理代码问题
This commit is contained in:
parent
ad7fa70a02
commit
e707774175
@ -18,7 +18,9 @@ PicoBot 的总体架构方向是清晰的:Gateway 负责装配,Channel 只
|
|||||||
|
|
||||||
- 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。
|
- 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。
|
||||||
- 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
|
- 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
|
||||||
- 待处理:工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
|
- 已修复:Session 主处理路径不再在持有 session mutex 时执行 memory recall、上下文压缩、标题 LLM 生成、消息持久化、`/stop` sub-agent 取消或清历史存储操作;慢操作改为锁外执行并用 `state_version`/`worker_generation` 防止陈旧结果覆盖当前会话。
|
||||||
|
- 已修复:Bash 超时清理、文件读取大文件限制、HTTP DNS 私网校验、Bus 关闭退出、Cron `from` 语义和 PTY 工具接入等中等级问题已完成清扫。
|
||||||
|
- 待处理:工具文件边界仍是后续质量风险。
|
||||||
|
|
||||||
## 主要发现
|
## 主要发现
|
||||||
|
|
||||||
@ -108,7 +110,7 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
|
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
|
||||||
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
|
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
|
||||||
|
|
||||||
### 中高优先级:Session 锁内执行过多异步操作
|
### 已修复:Session 锁内执行过多异步操作
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -125,13 +127,17 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
|
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
|
||||||
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
|
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 锁内只做内存状态快照和必要的状态标记。
|
- 为 `Session` 增加 `state_version`,慢操作提交前检查会话是否已被 `/stop`、清历史或其它内存变更替换。
|
||||||
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
|
- `/compact` 改为锁内取 history 快照,锁外压缩,锁内提交压缩结果,锁外持久化 meta。
|
||||||
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
|
- agent worker Phase 1 改为锁内只创建用户消息、agent、cancel handle 和 history 快照;memory recall 与 context compression 都在锁外执行。
|
||||||
|
- context overflow retry 的二次压缩移到锁外。
|
||||||
|
- 标题生成改为锁内取 prompt/provider 快照,锁外调用 LLM,锁内应用标题,锁外持久化。
|
||||||
|
- `add_message` 拆出内存更新和持久化快照,主消息路径在释放 session 锁后写入 SQLite。
|
||||||
|
- `/stop` 和清历史不再持有 session 锁等待 sub-agent 取消或 Storage 操作。
|
||||||
|
|
||||||
### 中优先级:Bash 超时不会显式终止子进程
|
### 已修复:Bash 超时不会显式终止子进程
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -146,14 +152,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
|
|
||||||
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
|
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 使用 `tokio::process::Child` 的 `kill_on_drop(true)`。
|
- Bash 一次性命令改用 `wait_with_output()`,避免 stdout/stderr 顺序读取造成 pipe 阻塞。
|
||||||
- 超时分支显式 kill child 并 wait。
|
- 子进程启用 `kill_on_drop(true)`,超时后丢弃等待 future 时会清理 child。
|
||||||
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
|
- 新增大 stderr 输出测试,覆盖不会因为 stderr pipe 填满而卡住。
|
||||||
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
|
- 持久/交互式进程通过已接入的 PTY 工具承载。
|
||||||
|
|
||||||
### 中优先级:文件读取对大二进制文件没有输出上限
|
### 已修复:文件读取对大二进制文件没有输出上限
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -168,13 +174,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
|
|
||||||
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
|
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 先检查 metadata size,超过阈值直接返回提示。
|
- `file_read` 在读取前检查 metadata size,超过安全阈值直接拒绝。
|
||||||
- 二进制文件默认只返回 mime、大小和建议操作;需要内容时提供显式 `max_bytes` 参数。
|
- 二进制 inline base64 增加单独大小上限,超限只返回错误和文件信息。
|
||||||
- 对文本读取也改成流式按行读取,而不是整文件读入。
|
- 含 NUL 字节内容按二进制处理,避免全 0 文件被 UTF-8 路径误判为文本。
|
||||||
|
- 增加大文件和大二进制文件测试。
|
||||||
|
|
||||||
### 中优先级:HTTP 私网防护只检查字面 host,未做 DNS 解析校验
|
### 已修复:HTTP 私网防护只检查字面 host,未做 DNS 解析校验
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -188,13 +195,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
|
|
||||||
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
|
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 请求前解析域名,拒绝私网、loopback、link-local、multicast、unspecified 地址。
|
- `http_request` 和 `web_fetch` 在发送请求前通过 DNS 解析 host,并拒绝解析到 loopback、private、link-local、multicast、unspecified 的地址。
|
||||||
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
|
- IPv6 unique-local 和 link-local 地址也纳入私网判定。
|
||||||
- 对 `http_request` 和 `web_fetch` 复用同一套 URL 安全策略。
|
- 禁用 reqwest 自动重定向,避免跳转到未校验的内网地址。
|
||||||
|
- 增加端口解析和 IPv6 私网判断测试。
|
||||||
|
|
||||||
### 中优先级:后台任务和主循环缺少监督与优雅关闭
|
### 已修复:后台任务和主循环缺少监督与优雅关闭
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -212,13 +220,14 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
|
|||||||
- 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
|
- 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
|
||||||
- bus channel 关闭时更像崩溃,而不是可恢复状态。
|
- bus channel 关闭时更像崩溃,而不是可恢复状态。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 引入 runtime supervisor,保存 JoinHandle 并集中处理退出原因。
|
- `MessageBus::consume_inbound/consume_outbound/consume_control` 不再在 channel 关闭时 `expect()` panic,改为返回 `Option<T>`。
|
||||||
- 用 `CancellationToken` 贯穿 Gateway 子任务。
|
- Gateway message processor 在 inbound/control bus 关闭时记录 warning 并退出 loop。
|
||||||
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
|
- OutboundDispatcher 在 outbound bus 关闭时记录 warning 并退出 loop。
|
||||||
|
- 这不是完整 runtime supervisor,但已消除 bus 关闭导致的 panic 崩溃路径,为后续集中 JoinHandle 管理留出接口。
|
||||||
|
|
||||||
### 中低优先级:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
|
### 已修复:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -232,13 +241,13 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
|
|||||||
|
|
||||||
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。
|
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
|
- cron 分支改用 `cron_schedule.after(&from_dt).next()`。
|
||||||
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
|
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为计算起点。
|
||||||
- 增加固定时间输入的单元测试,避免受系统时间影响。
|
- 增加 UTC 和 Asia/Shanghai 固定时间输入测试。
|
||||||
|
|
||||||
### 中低优先级:存在未接入或半接入代码,增加维护噪音
|
### 已修复:存在未接入或半接入代码,增加维护噪音
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -254,10 +263,11 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
|
|||||||
|
|
||||||
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
|
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
|
- `src/tools/pty.rs` 已接入 `tools/mod.rs`,导出 `PtyManager`/`PtyTool`。
|
||||||
- 若暂不发布:移动到设计文档或 feature branch,避免主干保留死代码。
|
- `create_default_tools()` 默认注册共享 `PtyManager` 的 `PtyTool`。
|
||||||
|
- 修复 PTY 原本因未编译暴露不出的借用问题。
|
||||||
|
|
||||||
## 架构评价
|
## 架构评价
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,10 @@ impl OutboundDispatcher {
|
|||||||
tracing::info!("OutboundDispatcher started");
|
tracing::info!("OutboundDispatcher started");
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let msg = self.bus.consume_outbound().await;
|
let Some(msg) = self.bus.consume_outbound().await else {
|
||||||
|
tracing::warn!("OutboundDispatcher stopping because outbound bus closed");
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
let channel_name = msg.channel.clone();
|
let channel_name = msg.channel.clone();
|
||||||
let channel = self.channel_manager.get_channel(&channel_name).await;
|
let channel = self.channel_manager.get_channel(&channel_name).await;
|
||||||
|
|||||||
@ -51,17 +51,11 @@ impl MessageBus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Consume an inbound message (Agent -> Bus)
|
/// Consume an inbound message (Agent -> Bus)
|
||||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
pub async fn consume_inbound(&self) -> Option<InboundMessage> {
|
||||||
let msg = self
|
let msg = self.inbound_rx.lock().await.recv().await?;
|
||||||
.inbound_rx
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.expect("bus inbound closed");
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
|
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
|
||||||
msg
|
Some(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Publish an outbound message (Agent -> Bus)
|
/// Publish an outbound message (Agent -> Bus)
|
||||||
@ -75,13 +69,8 @@ impl MessageBus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Consume an outbound message (Dispatcher -> Bus)
|
/// Consume an outbound message (Dispatcher -> Bus)
|
||||||
pub async fn consume_outbound(&self) -> OutboundMessage {
|
pub async fn consume_outbound(&self) -> Option<OutboundMessage> {
|
||||||
self.outbound_rx
|
self.outbound_rx.lock().await.recv().await
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.expect("bus outbound closed")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Publish a control message (Channel -> Bus for session management)
|
/// Publish a control message (Channel -> Bus for session management)
|
||||||
@ -94,13 +83,8 @@ impl MessageBus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Consume a control message (ControlProcessor -> Bus)
|
/// Consume a control message (ControlProcessor -> Bus)
|
||||||
pub async fn consume_control(&self) -> ControlMessage {
|
pub async fn consume_control(&self) -> Option<ControlMessage> {
|
||||||
self.control_rx
|
self.control_rx.lock().await.recv().await
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.expect("bus control closed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -205,6 +205,10 @@ impl GatewayState {
|
|||||||
tokio::select! {
|
tokio::select! {
|
||||||
// Inbound: AI message flow
|
// Inbound: AI message flow
|
||||||
inbound = bus.consume_inbound() => {
|
inbound = bus.consume_inbound() => {
|
||||||
|
let Some(inbound) = inbound else {
|
||||||
|
tracing::warn!("Message processor stopping because inbound bus closed");
|
||||||
|
break;
|
||||||
|
};
|
||||||
match session_manager.handle_message(
|
match session_manager.handle_message(
|
||||||
&inbound.channel,
|
&inbound.channel,
|
||||||
&inbound.sender_id,
|
&inbound.sender_id,
|
||||||
@ -252,6 +256,10 @@ impl GatewayState {
|
|||||||
|
|
||||||
// Control: session management operations
|
// Control: session management operations
|
||||||
msg = bus.consume_control() => {
|
msg = bus.consume_control() => {
|
||||||
|
let Some(msg) = msg else {
|
||||||
|
tracing::warn!("Message processor stopping because control bus closed");
|
||||||
|
break;
|
||||||
|
};
|
||||||
Self::handle_control_message(&session_manager, msg).await;
|
Self::handle_control_message(&session_manager, msg).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,11 +30,11 @@ pub fn next_run_for_schedule(schedule: &Schedule, from: i64) -> Option<i64> {
|
|||||||
|
|
||||||
let next_utc = if let Some(tz_str) = tz {
|
let next_utc = if let Some(tz_str) = tz {
|
||||||
let tz: chrono_tz::Tz = tz_str.parse().ok()?;
|
let tz: chrono_tz::Tz = tz_str.parse().ok()?;
|
||||||
let _from_local = from_dt.with_timezone(&tz);
|
let from_local = from_dt.with_timezone(&tz);
|
||||||
let next_local = cron_schedule.upcoming(tz).next()?;
|
let next_local = cron_schedule.after(&from_local).next()?;
|
||||||
next_local.with_timezone(&Utc)
|
next_local.with_timezone(&Utc)
|
||||||
} else {
|
} else {
|
||||||
cron_schedule.upcoming(Utc).next()?
|
cron_schedule.after(&from_dt).next()?
|
||||||
};
|
};
|
||||||
|
|
||||||
Some(next_utc.timestamp_millis())
|
Some(next_utc.timestamp_millis())
|
||||||
@ -311,4 +311,37 @@ mod tests {
|
|||||||
let next_ms = next.unwrap();
|
let next_ms = next.unwrap();
|
||||||
assert!(next_ms > now);
|
assert!(next_ms > now);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_cron_uses_from_argument() {
|
||||||
|
let expr = "0 * * * * *".to_string();
|
||||||
|
let schedule = Schedule::Cron { expr, tz: None };
|
||||||
|
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:34:20Z")
|
||||||
|
.unwrap()
|
||||||
|
.timestamp_millis();
|
||||||
|
|
||||||
|
let next = next_run_for_schedule(&schedule, from).unwrap();
|
||||||
|
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:35:00Z")
|
||||||
|
.unwrap()
|
||||||
|
.timestamp_millis();
|
||||||
|
assert_eq!(next, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_cron_timezone_uses_from_argument() {
|
||||||
|
let expr = "0 0 9 * * *".to_string();
|
||||||
|
let schedule = Schedule::Cron {
|
||||||
|
expr,
|
||||||
|
tz: Some("Asia/Shanghai".to_string()),
|
||||||
|
};
|
||||||
|
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T00:30:00Z")
|
||||||
|
.unwrap()
|
||||||
|
.timestamp_millis();
|
||||||
|
|
||||||
|
let next = next_run_for_schedule(&schedule, from).unwrap();
|
||||||
|
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T01:00:00Z")
|
||||||
|
.unwrap()
|
||||||
|
.timestamp_millis();
|
||||||
|
assert_eq!(next, expected);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,7 +4,6 @@ use std::time::Duration;
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::io::AsyncReadExt;
|
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
||||||
@ -147,71 +146,55 @@ impl Tool for BashTool {
|
|||||||
.map(Path::new)
|
.map(Path::new)
|
||||||
.unwrap_or_else(|| Path::new("."));
|
.unwrap_or_else(|| Path::new("."));
|
||||||
|
|
||||||
let result = timeout(
|
match self.run_command(command, cwd, timeout_secs).await {
|
||||||
Duration::from_secs(timeout_secs),
|
Ok(output) => Ok(ToolResult {
|
||||||
self.run_command(command, cwd),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(Ok(output)) => Ok(ToolResult {
|
|
||||||
success: true,
|
success: true,
|
||||||
output,
|
output,
|
||||||
error: None,
|
error: None,
|
||||||
}),
|
}),
|
||||||
Ok(Err(e)) => Ok(ToolResult {
|
Err(e) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(e),
|
error: Some(e),
|
||||||
}),
|
}),
|
||||||
Err(_) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
|
|
||||||
}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BashTool {
|
impl BashTool {
|
||||||
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
|
async fn run_command(
|
||||||
|
&self,
|
||||||
|
command: &str,
|
||||||
|
cwd: &Path,
|
||||||
|
timeout_secs: u64,
|
||||||
|
) -> Result<String, String> {
|
||||||
let mut cmd = Command::new("bash");
|
let mut cmd = Command::new("bash");
|
||||||
cmd.args(["-c", command])
|
cmd.args(["-c", command])
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped())
|
.stderr(Stdio::piped())
|
||||||
.current_dir(cwd);
|
.current_dir(cwd)
|
||||||
|
.kill_on_drop(true);
|
||||||
|
|
||||||
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||||
|
|
||||||
let mut stdout = Vec::new();
|
let process_output =
|
||||||
let mut stderr = Vec::new();
|
match timeout(Duration::from_secs(timeout_secs), child.wait_with_output()).await {
|
||||||
|
Ok(Ok(output)) => output,
|
||||||
if let Some(ref mut out) = child.stdout {
|
Ok(Err(e)) => return Err(format!("Failed to wait: {}", e)),
|
||||||
out.read_to_end(&mut stdout)
|
Err(_) => {
|
||||||
.await
|
return Err(format!("Command timed out after {} seconds", timeout_secs));
|
||||||
.map_err(|e| format!("Failed to read stdout: {}", e))?;
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
if let Some(ref mut err) = child.stderr {
|
|
||||||
err.read_to_end(&mut stderr)
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("Failed to read stderr: {}", e))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let status = child
|
|
||||||
.wait()
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("Failed to wait: {}", e))?;
|
|
||||||
|
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
|
|
||||||
if !stdout.is_empty() {
|
if !process_output.stdout.is_empty() {
|
||||||
let stdout_str = String::from_utf8_lossy(&stdout);
|
let stdout_str = String::from_utf8_lossy(&process_output.stdout);
|
||||||
output.push_str(&stdout_str);
|
output.push_str(&stdout_str);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stderr.is_empty() {
|
if !process_output.stderr.is_empty() {
|
||||||
let stderr_str = String::from_utf8_lossy(&stderr);
|
let stderr_str = String::from_utf8_lossy(&process_output.stderr);
|
||||||
if !stderr_str.trim().is_empty() {
|
if !stderr_str.trim().is_empty() {
|
||||||
if !output.is_empty() {
|
if !output.is_empty() {
|
||||||
output.push('\n');
|
output.push('\n');
|
||||||
@ -221,7 +204,10 @@ impl BashTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
|
output.push_str(&format!(
|
||||||
|
"\nExit code: {}",
|
||||||
|
process_output.status.code().unwrap_or(-1)
|
||||||
|
));
|
||||||
|
|
||||||
Ok(self.truncate_output(&output))
|
Ok(self.truncate_output(&output))
|
||||||
}
|
}
|
||||||
@ -309,4 +295,19 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("timed out"));
|
assert!(result.error.unwrap().contains("timed out"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_large_stderr_does_not_deadlock() {
|
||||||
|
let tool = BashTool::new().with_timeout(5);
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"command": "for i in $(seq 1 2000); do echo noisy-error-line >&2; done; echo done"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("done"));
|
||||||
|
assert!(result.output.contains("STDERR"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,8 @@ use crate::tools::path_utils;
|
|||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
const MAX_CHARS: usize = 128_000;
|
const MAX_CHARS: usize = 128_000;
|
||||||
|
const MAX_FILE_BYTES: u64 = 5 * 1024 * 1024;
|
||||||
|
const MAX_BINARY_BYTES: usize = 512 * 1024;
|
||||||
const DEFAULT_LIMIT: usize = 2000;
|
const DEFAULT_LIMIT: usize = 2000;
|
||||||
|
|
||||||
pub struct FileReadTool {
|
pub struct FileReadTool {
|
||||||
@ -118,6 +120,29 @@ impl Tool for FileReadTool {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let metadata = match std::fs::metadata(&resolved) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to inspect file: {}", e)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if metadata.len() > MAX_FILE_BYTES {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"File too large to read safely: {} bytes (max {} bytes). Use a narrower tool or inspect a smaller excerpt.",
|
||||||
|
metadata.len(),
|
||||||
|
MAX_FILE_BYTES
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Read raw bytes and try multiple encodings
|
// Read raw bytes and try multiple encodings
|
||||||
let bytes = match std::fs::read(&resolved) {
|
let bytes = match std::fs::read(&resolved) {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
@ -209,6 +234,21 @@ impl Tool for FileReadTool {
|
|||||||
None => {
|
None => {
|
||||||
// Truly binary file — base64 encode
|
// Truly binary file — base64 encode
|
||||||
use base64::{Engine, engine::general_purpose::STANDARD};
|
use base64::{Engine, engine::general_purpose::STANDARD};
|
||||||
|
if bytes.len() > MAX_BINARY_BYTES {
|
||||||
|
let mime = mime_guess::from_path(&resolved)
|
||||||
|
.first_or_octet_stream()
|
||||||
|
.to_string();
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"Binary file too large to inline: {}, {} bytes (max {} bytes).",
|
||||||
|
mime,
|
||||||
|
bytes.len(),
|
||||||
|
MAX_BINARY_BYTES
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
let encoded = STANDARD.encode(&bytes);
|
let encoded = STANDARD.encode(&bytes);
|
||||||
let mime = mime_guess::from_path(&resolved)
|
let mime = mime_guess::from_path(&resolved)
|
||||||
.first_or_octet_stream()
|
.first_or_octet_stream()
|
||||||
@ -229,6 +269,10 @@ impl Tool for FileReadTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
|
fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
|
||||||
|
if bytes.contains(&0) {
|
||||||
|
return (None, None);
|
||||||
|
}
|
||||||
|
|
||||||
// Try UTF-8 first
|
// Try UTF-8 first
|
||||||
if let Ok(text) = std::str::from_utf8(bytes) {
|
if let Ok(text) = std::str::from_utf8(bytes) {
|
||||||
return (Some(text.to_string()), None);
|
return (Some(text.to_string()), None);
|
||||||
@ -337,4 +381,37 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("Not a file"));
|
assert!(result.error.unwrap().contains("Not a file"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rejects_large_file_before_reading() {
|
||||||
|
let mut file = NamedTempFile::new().unwrap();
|
||||||
|
file.as_file_mut()
|
||||||
|
.set_len(MAX_FILE_BYTES + 1)
|
||||||
|
.expect("set large file length");
|
||||||
|
|
||||||
|
let tool = FileReadTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "path": file.path().to_str().unwrap() }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("too large"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rejects_large_binary_inline() {
|
||||||
|
let mut file = NamedTempFile::new().unwrap();
|
||||||
|
let bytes = vec![0_u8; MAX_BINARY_BYTES + 1];
|
||||||
|
file.write_all(&bytes).unwrap();
|
||||||
|
|
||||||
|
let tool = FileReadTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "path": file.path().to_str().unwrap() }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("Binary file too large"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@ use std::time::Duration;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::header::HeaderMap;
|
use reqwest::header::HeaderMap;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use tokio::net::lookup_host;
|
||||||
|
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
@ -56,6 +57,34 @@ impl HttpRequestTool {
|
|||||||
Ok(url.to_string())
|
Ok(url.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
|
||||||
|
if self.allow_private_hosts {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let host = extract_host(url)?;
|
||||||
|
if host.parse::<std::net::IpAddr>().is_ok() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let port = extract_port(url)?;
|
||||||
|
let addrs = lookup_host((host.as_str(), port))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
|
||||||
|
|
||||||
|
for addr in addrs {
|
||||||
|
let ip = addr.ip();
|
||||||
|
if is_private_ip(&ip) {
|
||||||
|
return Err(format!(
|
||||||
|
"Blocked host '{}' because DNS resolved to local/private IP {}",
|
||||||
|
host, ip
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
|
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
|
||||||
match method.to_uppercase().as_str() {
|
match method.to_uppercase().as_str() {
|
||||||
"GET" => Ok(reqwest::Method::GET),
|
"GET" => Ok(reqwest::Method::GET),
|
||||||
@ -180,6 +209,34 @@ fn extract_host(url: &str) -> Result<String, String> {
|
|||||||
Ok(host)
|
Ok(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn extract_port(url: &str) -> Result<u16, String> {
|
||||||
|
let scheme = if url.starts_with("https://") {
|
||||||
|
"https"
|
||||||
|
} else if url.starts_with("http://") {
|
||||||
|
"http"
|
||||||
|
} else {
|
||||||
|
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||||
|
};
|
||||||
|
|
||||||
|
let rest = url
|
||||||
|
.strip_prefix("http://")
|
||||||
|
.or_else(|| url.strip_prefix("https://"))
|
||||||
|
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
|
||||||
|
let authority = rest
|
||||||
|
.split(['/', '?', '#'])
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| "Invalid URL".to_string())?;
|
||||||
|
|
||||||
|
if let Some((_, port)) = authority.rsplit_once(':') {
|
||||||
|
port.parse::<u16>()
|
||||||
|
.map_err(|_| format!("Invalid URL port: {}", port))
|
||||||
|
} else if scheme == "https" {
|
||||||
|
Ok(443)
|
||||||
|
} else {
|
||||||
|
Ok(80)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
||||||
if allowed_domains.iter().any(|domain| domain == "*") {
|
if allowed_domains.iter().any(|domain| domain == "*") {
|
||||||
return true;
|
return true;
|
||||||
@ -226,7 +283,13 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
|||||||
|| v4.is_broadcast()
|
|| v4.is_broadcast()
|
||||||
|| v4.is_multicast()
|
|| v4.is_multicast()
|
||||||
}
|
}
|
||||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
std::net::IpAddr::V6(v6) => {
|
||||||
|
v6.is_loopback()
|
||||||
|
|| v6.is_unspecified()
|
||||||
|
|| v6.is_multicast()
|
||||||
|
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|
||||||
|
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -294,6 +357,14 @@ impl Tool for HttpRequestTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Err(e) = self.validate_resolved_host(&url).await {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let method = match self.validate_method(method_str) {
|
let method = match self.validate_method(method_str) {
|
||||||
Ok(m) => m,
|
Ok(m) => m,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@ -309,6 +380,7 @@ impl Tool for HttpRequestTool {
|
|||||||
|
|
||||||
let client = match reqwest::Client::builder()
|
let client = match reqwest::Client::builder()
|
||||||
.timeout(Duration::from_secs(self.timeout_secs))
|
.timeout(Duration::from_secs(self.timeout_secs))
|
||||||
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
.build()
|
.build()
|
||||||
{
|
{
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
@ -436,4 +508,19 @@ mod tests {
|
|||||||
async fn test_blocks_local_tld() {
|
async fn test_blocks_local_tld() {
|
||||||
assert!(is_private_host("service.local"));
|
assert!(is_private_host("service.local"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_port_defaults_and_explicit() {
|
||||||
|
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
|
||||||
|
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
|
||||||
|
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_private_ipv6_ranges_are_blocked() {
|
||||||
|
assert!(is_private_ip(&"::1".parse().unwrap()));
|
||||||
|
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
|
||||||
|
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
|
||||||
|
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,6 +13,7 @@ pub mod get_skill;
|
|||||||
pub mod http_request;
|
pub mod http_request;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
pub mod path_utils;
|
pub mod path_utils;
|
||||||
|
pub mod pty;
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod schema;
|
pub mod schema;
|
||||||
pub mod send_message;
|
pub mod send_message;
|
||||||
@ -32,6 +33,7 @@ pub use file_write::FileWriteTool;
|
|||||||
pub use get_skill::GetSkillTool;
|
pub use get_skill::GetSkillTool;
|
||||||
pub use http_request::HttpRequestTool;
|
pub use http_request::HttpRequestTool;
|
||||||
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
|
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
|
||||||
|
pub use pty::{PtyManager, PtyTool};
|
||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
pub use send_message::SendMessageTool;
|
pub use send_message::SendMessageTool;
|
||||||
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||||
@ -60,6 +62,7 @@ pub fn create_default_tools(
|
|||||||
registry.register(FileSearchTool::new());
|
registry.register(FileSearchTool::new());
|
||||||
registry.register(ContentSearchTool::new());
|
registry.register(ContentSearchTool::new());
|
||||||
registry.register(BashTool::new());
|
registry.register(BashTool::new());
|
||||||
|
registry.register(PtyTool::new(Arc::new(PtyManager::new())));
|
||||||
registry.register(HttpRequestTool::new(
|
registry.register(HttpRequestTool::new(
|
||||||
vec!["*".to_string()],
|
vec!["*".to_string()],
|
||||||
1_000_000,
|
1_000_000,
|
||||||
|
|||||||
@ -160,11 +160,14 @@ impl PtyManager {
|
|||||||
};
|
};
|
||||||
for session in sessions {
|
for session in sessions {
|
||||||
let mut guard = session.lock().unwrap();
|
let mut guard = session.lock().unwrap();
|
||||||
let mut child_guard = guard.child.lock().unwrap();
|
let child_handle = guard.child.clone();
|
||||||
if let Some(ref mut child) = *child_guard {
|
{
|
||||||
let _ = child.kill();
|
let mut child_guard = child_handle.lock().unwrap();
|
||||||
|
if let Some(ref mut child) = *child_guard {
|
||||||
|
let _ = child.kill();
|
||||||
|
}
|
||||||
|
*child_guard = None;
|
||||||
}
|
}
|
||||||
*child_guard = None;
|
|
||||||
guard.status = SessionStatus::Killed;
|
guard.status = SessionStatus::Killed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -274,7 +277,7 @@ impl PtyManager {
|
|||||||
let session = sessions
|
let session = sessions
|
||||||
.get(session_id)
|
.get(session_id)
|
||||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||||
let mut guard = session.lock().unwrap();
|
let guard = session.lock().unwrap();
|
||||||
if guard.status != SessionStatus::Running {
|
if guard.status != SessionStatus::Running {
|
||||||
return Err("Session is not running".to_string());
|
return Err("Session is not running".to_string());
|
||||||
}
|
}
|
||||||
@ -296,12 +299,7 @@ impl PtyManager {
|
|||||||
Ok(format!("OK, wrote {} bytes", byte_count))
|
Ok(format!("OK, wrote {} bytes", byte_count))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read(
|
fn read(&self, session_id: &str, offset: usize, limit: usize) -> Result<String, String> {
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
offset: usize,
|
|
||||||
limit: usize,
|
|
||||||
) -> Result<String, String> {
|
|
||||||
let sessions = self.sessions.lock().unwrap();
|
let sessions = self.sessions.lock().unwrap();
|
||||||
let session = sessions
|
let session = sessions
|
||||||
.get(session_id)
|
.get(session_id)
|
||||||
@ -352,14 +350,16 @@ impl PtyManager {
|
|||||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||||
let mut guard = session.lock().unwrap();
|
let mut guard = session.lock().unwrap();
|
||||||
|
|
||||||
let mut child_guard = guard.child.lock().unwrap();
|
let child_handle = guard.child.clone();
|
||||||
if let Some(ref mut child) = *child_guard {
|
{
|
||||||
let _ = child.kill();
|
let mut child_guard = child_handle.lock().unwrap();
|
||||||
let _ = child.wait();
|
if let Some(ref mut child) = *child_guard {
|
||||||
|
let _ = child.kill();
|
||||||
|
let _ = child.wait();
|
||||||
|
}
|
||||||
|
*child_guard = None;
|
||||||
}
|
}
|
||||||
*child_guard = None;
|
|
||||||
guard.status = SessionStatus::Killed;
|
guard.status = SessionStatus::Killed;
|
||||||
drop(child_guard);
|
|
||||||
drop(guard);
|
drop(guard);
|
||||||
sessions.remove(session_id);
|
sessions.remove(session_id);
|
||||||
|
|
||||||
@ -545,14 +545,8 @@ impl Tool for PtyTool {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let offset = args
|
let offset = args.get("offset").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
|
||||||
.get("offset")
|
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(500) as usize;
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(0) as usize;
|
|
||||||
let limit = args
|
|
||||||
.get("limit")
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(500) as usize;
|
|
||||||
match self.pty_manager.read(session_id, offset, limit) {
|
match self.pty_manager.read(session_id, offset, limit) {
|
||||||
Ok(output) => Ok(ToolResult {
|
Ok(output) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
|
|||||||
@ -3,6 +3,7 @@ use std::time::Duration;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::header::HeaderMap;
|
use reqwest::header::HeaderMap;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use tokio::net::lookup_host;
|
||||||
|
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
@ -60,9 +61,34 @@ impl WebFetchTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
|
||||||
|
let host = extract_host(url)?;
|
||||||
|
if host.parse::<std::net::IpAddr>().is_ok() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let port = extract_port(url)?;
|
||||||
|
let addrs = lookup_host((host.as_str(), port))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
|
||||||
|
|
||||||
|
for addr in addrs {
|
||||||
|
let ip = addr.ip();
|
||||||
|
if is_private_ip(&ip) {
|
||||||
|
return Err(format!(
|
||||||
|
"Blocked host '{}' because DNS resolved to local/private IP {}",
|
||||||
|
host, ip
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn fetch_content(&self, url: &str) -> Result<String, String> {
|
async fn fetch_content(&self, url: &str) -> Result<String, String> {
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(Duration::from_secs(self.timeout_secs))
|
.timeout(Duration::from_secs(self.timeout_secs))
|
||||||
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
@ -234,6 +260,35 @@ fn extract_host(url: &str) -> Result<String, String> {
|
|||||||
Ok(host)
|
Ok(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn extract_port(url: &str) -> Result<u16, String> {
|
||||||
|
let scheme = if url.starts_with("https://") {
|
||||||
|
"https"
|
||||||
|
} else if url.starts_with("http://") {
|
||||||
|
"http"
|
||||||
|
} else {
|
||||||
|
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||||
|
};
|
||||||
|
|
||||||
|
let rest = url
|
||||||
|
.strip_prefix("http://")
|
||||||
|
.or_else(|| url.strip_prefix("https://"))
|
||||||
|
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
|
||||||
|
|
||||||
|
let authority = rest
|
||||||
|
.split(['/', '?', '#'])
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| "Invalid URL".to_string())?;
|
||||||
|
|
||||||
|
if let Some((_, port)) = authority.rsplit_once(':') {
|
||||||
|
port.parse::<u16>()
|
||||||
|
.map_err(|_| format!("Invalid URL port: {}", port))
|
||||||
|
} else if scheme == "https" {
|
||||||
|
Ok(443)
|
||||||
|
} else {
|
||||||
|
Ok(80)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn is_private_host(host: &str) -> bool {
|
fn is_private_host(host: &str) -> bool {
|
||||||
if host == "localhost" || host.ends_with(".localhost") {
|
if host == "localhost" || host.ends_with(".localhost") {
|
||||||
return true;
|
return true;
|
||||||
@ -248,19 +303,32 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||||
return match ip {
|
return is_private_ip(&ip);
|
||||||
std::net::IpAddr::V4(v4) => {
|
|
||||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
|
||||||
}
|
|
||||||
std::net::IpAddr::V6(v6) => {
|
|
||||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
||||||
|
match ip {
|
||||||
|
std::net::IpAddr::V4(v4) => {
|
||||||
|
v4.is_loopback()
|
||||||
|
|| v4.is_private()
|
||||||
|
|| v4.is_link_local()
|
||||||
|
|| v4.is_unspecified()
|
||||||
|
|| v4.is_broadcast()
|
||||||
|
|| v4.is_multicast()
|
||||||
|
}
|
||||||
|
std::net::IpAddr::V6(v6) => {
|
||||||
|
v6.is_loopback()
|
||||||
|
|| v6.is_unspecified()
|
||||||
|
|| v6.is_multicast()
|
||||||
|
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|
||||||
|
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Tool for WebFetchTool {
|
impl Tool for WebFetchTool {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
@ -311,6 +379,14 @@ impl Tool for WebFetchTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Err(e) = self.validate_resolved_host(&url).await {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
match self.fetch_content(&url).await {
|
match self.fetch_content(&url).await {
|
||||||
Ok(content) => Ok(ToolResult {
|
Ok(content) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
@ -357,6 +433,21 @@ mod tests {
|
|||||||
assert!(result.unwrap_err().contains("whitespace"));
|
assert!(result.unwrap_err().contains("whitespace"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_port_defaults_and_explicit() {
|
||||||
|
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
|
||||||
|
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
|
||||||
|
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_private_ipv6_ranges_are_blocked() {
|
||||||
|
assert!(is_private_ip(&"::1".parse().unwrap()));
|
||||||
|
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
|
||||||
|
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
|
||||||
|
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_extract_text_simple() {
|
async fn test_extract_text_simple() {
|
||||||
let tool = test_tool();
|
let tool = test_tool();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user