Compare commits

..

No commits in common. "fe2bc3dfd35ace3f35c9cec49f9d1297db216eec" and "48c8a51d9a47743fb027d1e123a0f013536f7fa1" have entirely different histories.

20 changed files with 402 additions and 1557 deletions

View File

@ -3,9 +3,7 @@
.gitignore .gitignore
# Build artifacts # Build artifacts
target/* target/
!target/release/
target/release/*
!target/release/picobot !target/release/picobot
# IDE # IDE

View File

@ -1,6 +1,6 @@
[package] [package]
name = "picobot" name = "picobot"
version = "1.1.2" version = "1.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]

View File

@ -55,8 +55,11 @@ RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
&& npm cache clean --force \ && npm cache clean --force \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Install himalaya (CLI email client) from the official pre-built binary release # Install himalaya (CLI email client) from local file
RUN curl -sSL https://raw.githubusercontent.com/pimalaya/himalaya/master/install.sh | sh COPY docker_build/himalaya.x86_64-linux.tgz /tmp/himalaya.tgz
RUN tar -xzf /tmp/himalaya.tgz -C /usr/local/bin \
&& chmod +x /usr/local/bin/himalaya \
&& rm -f /tmp/himalaya.tgz
# Install fd (alternative to find) # Install fd (alternative to find)
RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \ RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \

View File

@ -18,9 +18,7 @@ PicoBot 的总体架构方向是清晰的Gateway 负责装配Channel 只
- 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id` - 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id`
- 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。 - 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
- 已修复Session 主处理路径不再在持有 session mutex 时执行 memory recall、上下文压缩、标题 LLM 生成、消息持久化、`/stop` sub-agent 取消或清历史存储操作;慢操作改为锁外执行并用 `state_version`/`worker_generation` 防止陈旧结果覆盖当前会话。 - 待处理工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
- 已修复Bash 超时清理、文件读取大文件限制、HTTP DNS 私网校验、Bus 关闭退出、Cron `from` 语义和 PTY 工具接入等中等级问题已完成清扫。
- 待处理:工具文件边界仍是后续质量风险。
## 主要发现 ## 主要发现
@ -110,7 +108,7 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。 - 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。 - shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
### 已修复Session 锁内执行过多异步操作 ### 中高优先级Session 锁内执行过多异步操作
位置: 位置:
@ -127,17 +125,13 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
- 当压缩或存储出现抖动时,用户感觉像“卡死”。 - 当压缩或存储出现抖动时,用户感觉像“卡死”。
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。 - 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
已采取修复 建议
- 为 `Session` 增加 `state_version`,慢操作提交前检查会话是否已被 `/stop`、清历史或其它内存变更替换。 - 锁内只做内存状态快照和必要的状态标记。
- `/compact` 改为锁内取 history 快照,锁外压缩,锁内提交压缩结果,锁外持久化 meta。 - 将 memory recall、压缩、LLM 摘要放到锁外执行。
- agent worker Phase 1 改为锁内只创建用户消息、agent、cancel handle 和 history 快照memory recall 与 context compression 都在锁外执行。 - 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
- context overflow retry 的二次压缩移到锁外。
- 标题生成改为锁内取 prompt/provider 快照,锁外调用 LLM锁内应用标题锁外持久化。
- `add_message` 拆出内存更新和持久化快照,主消息路径在释放 session 锁后写入 SQLite。
- `/stop` 和清历史不再持有 session 锁等待 sub-agent 取消或 Storage 操作。
### 已修复Bash 超时不会显式终止子进程 ### 中优先级Bash 超时不会显式终止子进程
位置: 位置:
@ -152,14 +146,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。 长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
已采取修复 建议
- Bash 一次性命令改用 `wait_with_output()`,避免 stdout/stderr 顺序读取造成 pipe 阻塞 - 使用 `tokio::process::Child``kill_on_drop(true)`
- 子进程启用 `kill_on_drop(true)`,超时后丢弃等待 future 时会清理 child - 超时分支显式 kill child 并 wait
- 新增大 stderr 输出测试,覆盖不会因为 stderr pipe 填满而卡住 - 对 shell 子进程树使用进程组隔离,必要时杀整个进程组
- 持久/交互式进程通过已接入的 PTY 工具承载 - 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义
### 已修复:文件读取对大二进制文件没有输出上限 ### 中优先级:文件读取对大二进制文件没有输出上限
位置: 位置:
@ -174,14 +168,13 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。 读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
已采取修复 建议
- `file_read` 在读取前检查 metadata size超过安全阈值直接拒绝。 - 先检查 metadata size超过阈值直接返回提示。
- 二进制 inline base64 增加单独大小上限,超限只返回错误和文件信息。 - 二进制文件默认只返回 mime、大小和建议操作需要内容时提供显式 `max_bytes` 参数。
- 含 NUL 字节内容按二进制处理,避免全 0 文件被 UTF-8 路径误判为文本。 - 对文本读取也改成流式按行读取,而不是整文件读入。
- 增加大文件和大二进制文件测试。
### 已修复HTTP 私网防护只检查字面 host未做 DNS 解析校验 ### 中优先级HTTP 私网防护只检查字面 host未做 DNS 解析校验
位置: 位置:
@ -195,14 +188,13 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
如果该工具暴露给非完全可信输入,存在 SSRF 风险。 如果该工具暴露给非完全可信输入,存在 SSRF 风险。
已采取修复 建议
- `http_request``web_fetch` 在发送请求前通过 DNS 解析 host并拒绝解析到 loopback、private、link-local、multicast、unspecified 的地址。 - 请求前解析域名拒绝私网、loopback、link-local、multicast、unspecified 地址。
- IPv6 unique-local 和 link-local 地址也纳入私网判定。 - 禁止或限制重定向,重定向后的每个 URL 重新校验。
- 禁用 reqwest 自动重定向,避免跳转到未校验的内网地址。 - 对 `http_request``web_fetch` 复用同一套 URL 安全策略。
- 增加端口解析和 IPv6 私网判断测试。
### 已修复:后台任务和主循环缺少监督与优雅关闭 ### 中优先级:后台任务和主循环缺少监督与优雅关闭
位置: 位置:
@ -220,14 +212,13 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
- 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。 - 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
- bus channel 关闭时更像崩溃,而不是可恢复状态。 - bus channel 关闭时更像崩溃,而不是可恢复状态。
已采取修复 建议
- `MessageBus::consume_inbound/consume_outbound/consume_control` 不再在 channel 关闭时 `expect()` panic改为返回 `Option<T>` - 引入 runtime supervisor保存 JoinHandle 并集中处理退出原因。
- Gateway message processor 在 inbound/control bus 关闭时记录 warning 并退出 loop。 - 用 `CancellationToken` 贯穿 Gateway 子任务。
- OutboundDispatcher 在 outbound bus 关闭时记录 warning 并退出 loop。 - `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
- 这不是完整 runtime supervisor但已消除 bus 关闭导致的 panic 崩溃路径,为后续集中 JoinHandle 管理留出接口。
### 已修复Cron 计算函数没有按入参 `from` 计算 cron 下一次时间 ### 中低优先级Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
位置: 位置:
@ -241,13 +232,13 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。 单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。
已采取修复 建议
- cron 分支改`cron_schedule.after(&from_dt).next()` - 使`cron_schedule.after(&from_dt).next()` 或等价 API
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为计算起点。 - timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
- 增加 UTC 和 Asia/Shanghai 固定时间输入测试。 - 增加固定时间输入的单元测试,避免受系统时间影响
### 已修复:存在未接入或半接入代码,增加维护噪音 ### 中低优先级:存在未接入或半接入代码,增加维护噪音
位置: 位置:
@ -263,11 +254,10 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。 维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
已采取修复 建议
- `src/tools/pty.rs` 已接入 `tools/mod.rs`,导出 `PtyManager`/`PtyTool` - 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
- `create_default_tools()` 默认注册共享 `PtyManager``PtyTool` - 若暂不发布:移动到设计文档或 feature branch避免主干保留死代码。
- 修复 PTY 原本因未编译暴露不出的借用问题。
## 架构评价 ## 架构评价

View File

@ -1,496 +0,0 @@
# PicoBot 跨渠道交互式消息规划
规划日期2026-06-16
## 背景
飞书交互式卡片可以让用户直接在消息卡片上点击按钮、提交选择或触发回调。这个能力很适合用于工具调用审批、快捷回复、任务确认、表单收集等 agent 交互。
参考项目调研结果:
- `reference/zeroclaw` 已实现飞书/Lark 工具审批卡片:发送 Card JSON 2.0 按钮卡片,收到 `card.action.trigger` 后解析 `approval_id``decision`,唤醒等待中的 approval future并 PATCH 原卡片为已处理状态。
- `reference/nanobot` 主要使用飞书 CardKit 做 agent 输出展示和流式更新,适合参考消息渲染体验,但没有完整的按钮回调驱动 agent 流程。
- `reference/openlark` 是 SDK/API 封装,支持发送 interactive card 和 CardKit API不包含完整 agent channel 编排。
PicoBot 当前飞书渠道已经会把普通 markdown 回复发送成 interactive card但还缺少“用户在卡片上操作 -> 统一交互事件 -> Session/Agent/Tool 流程继续”的抽象。
## 目标
1. 支持飞书交互式卡片按钮回调。
2. 设计成跨渠道能力,后续 Slack、Telegram、Discord、CLI chat 等渠道可以复用同一套交互语义。
3. 支持渠道降级:不支持按钮的渠道也能用纯文本命令完成同样操作。
4. 保持 PicoBot 现有边界Channel 只做收发和渠道适配SessionManager 管会话AgentLoop 执行 LLM 和工具。
5. 为工具调用审批、快捷回复和未来表单交互预留扩展点。
非目标:
- 本阶段不立即实现完整功能。
- 不把飞书卡片细节泄漏到 AgentLoop 或工具层。
- 不要求所有渠道同时支持原生交互组件。
## 核心原则
交互语义和渠道渲染分离。
Agent、工具或 Session 层只表达“我要一个 approval/quick reply/form interaction”。具体是飞书卡片按钮、Slack Block Kit、Telegram inline keyboard还是 CLI 里显示编号选项,由 Channel 根据能力渲染。
回调也要统一。
飞书的 `card.action.trigger`、Telegram 的 `callback_query`、Slack 的 interaction payload 都应归一化成 PicoBot 内部的 `InteractionEvent`,再交给统一的处理器。
## 数据模型
建议新增一个 `interaction` 模块,定义渠道无关的数据结构。
```rust
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum InteractionKind {
QuickReply,
Approval,
FormSubmit,
Command,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum InteractionStyle {
Default,
Primary,
Danger,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InteractionAction {
pub id: String,
pub label: String,
pub value: String,
pub style: InteractionStyle,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InteractionPayload {
pub interaction_id: String,
pub kind: InteractionKind,
pub title: Option<String>,
pub body: String,
pub actions: Vec<InteractionAction>,
pub expires_at: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct InteractionEvent {
pub channel: String,
pub chat_id: String,
pub sender_id: String,
pub interaction_id: String,
pub action_id: String,
pub action_value: String,
pub timestamp: i64,
pub metadata: std::collections::HashMap<String, String>,
}
```
`InteractionPayload` 用于 outbound 渲染,`InteractionEvent` 用于 inbound 回调。
## OutboundMessage 扩展
短期兼容方案:
- 继续使用 `OutboundMessage.metadata` 携带交互描述。
- 例如:
- `interaction.kind = "approval"`
- `interaction.id = "<uuid>"`
- `interaction.actions = "<json>"`
长期推荐方案:
```rust
pub struct OutboundMessage {
pub channel: String,
pub chat_id: String,
pub content: String,
pub reply_to: Option<String>,
pub media: Vec<MediaItem>,
pub metadata: HashMap<String, String>,
pub interaction: Option<InteractionPayload>,
}
```
推荐长期方案。它能避免把结构化交互塞进字符串 metadata也让每个 Channel 的 `send()` 更清晰。
## Channel 能力声明
`Channel` 增加可选能力声明:
```rust
#[derive(Debug, Clone, Default)]
pub struct ChannelCapabilities {
pub interactive_buttons: bool,
pub forms: bool,
pub message_update: bool,
pub markdown_cards: bool,
}
pub trait Channel {
fn capabilities(&self) -> ChannelCapabilities {
ChannelCapabilities::default()
}
}
```
渠道能力示例:
| 渠道 | 原生按钮 | 表单 | 更新原消息 | 降级策略 |
|------|----------|------|------------|----------|
| Feishu | 是interactive card | 可后续支持 | 是PATCH message/card | 文本命令 |
| Slack | 是Block Kit | 是 | 是 | 文本命令 |
| Telegram | 是inline keyboard | 有限 | 可编辑消息 | 文本命令 |
| Discord | 是components | 有限 | 可编辑消息 | 文本命令 |
| CLI chat | 否 | 否 | 局部可模拟 | 编号/命令输入 |
| Webhook/Email/SMS | 否 | 否 | 通常否 | 纯文本命令或链接 |
## 渲染策略
每个 channel 实现一个渠道内的渲染函数:
```rust
async fn send_interaction(
&self,
chat_id: &str,
payload: &InteractionPayload,
) -> Result<(), ChannelError>;
```
也可以先不改 trait`send()` 内部判断 `msg.interaction`
飞书渲染:
- 使用 Card JSON 2.0。
- `schema = "2.0"`
- body 用 markdown 展示 `payload.body`
- actions 渲染为 button。
- 每个按钮的 callback value 写入:
```json
{
"interaction_id": "...",
"action_id": "...",
"action_value": "approve"
}
```
需要兼容飞书 Card 2.0 回调路径:
- `/action/value`
- `/action/behaviors/0/value`
CLI 降级渲染:
```text
需要确认:
Tool: bash
Args: cargo test --lib
可选操作:
1. Approve
2. Deny
3. Always approve
回复:
/_interaction <interaction_id> approve
/_interaction <interaction_id> deny
/_interaction <interaction_id> always
```
纯文本渠道都可以复用这个 fallback renderer。
## Inbound 回调归一化
飞书 WebSocket 当前在 `src/channels/feishu.rs` 里处理 `im.message.receive_v1`。需要新增对 `card.action.trigger` 的识别:
1. ACK 仍要尽快发送,飞书要求 3 秒内响应。
2. 如果 event type 是 `card.action.trigger`,不要走普通消息解析。
3. 从 event payload 中解析 `interaction_id``action_id``action_value`
4. 构造 `InteractionEvent` 发布给统一处理器。
5. 对未知、过期或重复 interaction 返回成功但记录日志,不应导致渠道重连或报错。
如果短期不新增 interaction bus可以把回调转成特殊 `InboundMessage`
```text
content = "/_interaction <interaction_id> <action_value>"
metadata["event.kind"] = "interaction"
metadata["interaction.id"] = "<interaction_id>"
metadata["interaction.action_id"] = "<action_id>"
metadata["interaction.action_value"] = "<action_value>"
```
但必须由 SessionManager 或 InteractionManager 先拦截,不能把 `/_interaction` 当普通用户文本直接送进 LLM。
长期推荐新增 bus 通道:
```rust
pub enum InboundEvent {
Message(InboundMessage),
Interaction(InteractionEvent),
}
```
或者在 `MessageBus` 上增加 `interaction_tx`
## InteractionManager
建议新增 `InteractionManager`,集中管理 pending 交互状态,而不是让每个 Channel 各自维护。
职责:
- 生成 `interaction_id`
- 保存 pending interaction。
- 处理超时和过期。
- 接收 `InteractionEvent` 并解析成业务结果。
- 对重复点击、未知 interaction、过期 interaction 做幂等处理。
- 必要时通知 channel 更新原消息。
内部状态示例:
```rust
pub struct PendingInteraction {
pub id: String,
pub kind: InteractionKind,
pub channel: String,
pub chat_id: String,
pub sender_id: Option<String>,
pub session_id: Option<String>,
pub created_at: i64,
pub expires_at: Option<i64>,
pub status: InteractionStatus,
pub responder: InteractionResponder,
pub message_ref: Option<InteractionMessageRef>,
}
pub struct InteractionMessageRef {
pub channel: String,
pub chat_id: String,
pub message_id: String,
pub metadata: HashMap<String, String>,
}
```
`InteractionResponder` 可以先支持 oneshot
```rust
pub enum InteractionResponder {
Approval(tokio::sync::oneshot::Sender<ApprovalDecision>),
InboundMessage,
}
```
后续如果需要持久化长期交互oneshot 不够,需要落库。
## 工具审批流程
工具审批是第一批最适合落地的交互类型。
推荐流程:
1. AgentLoop 准备执行需要审批的工具。
2. 调用 `InteractionManager::request_approval(...)`
3. InteractionManager 创建 `InteractionPayload`,通过 outbound 发送到原 channel/chat。
4. AgentLoop 等待 oneshot带 timeout。
5. 用户在飞书卡片上点击 Approve/Deny/Always。
6. FeishuChannel 收到 `card.action.trigger`,发布 `InteractionEvent`
7. InteractionManager resolve pending approval。
8. AgentLoop 收到结果,继续执行或拒绝工具。
9. 如果 channel 支持更新消息InteractionManager 或 Channel 把原卡片更新成 resolved 状态。
审批 action 建议:
```rust
approve -> ApprovalDecision::Approve
deny -> ApprovalDecision::Deny
always -> ApprovalDecision::AlwaysApprove
```
`DenyWithEdit` 可后续支持,适合 ACP/Web/CLI 这类能输入文本的渠道。
## 快捷回复流程
快捷回复不是阻塞工具执行,而是把用户点击转成新的用户输入。
示例:
```json
{
"kind": "QuickReply",
"body": "你想继续哪个操作?",
"actions": [
{ "label": "继续分析", "value": "继续分析" },
{ "label": "生成报告", "value": "生成报告" }
]
}
```
用户点击后:
- `InteractionEvent.action_value` 转成一条普通 `InboundMessage.content`
- `sender_id``chat_id` 保留原用户和会话。
- metadata 标记来源为 interaction供审计或 UI 使用。
## 消息更新
支持原消息更新的渠道应在交互完成后更新 UI避免重复点击。
飞书:
- 发送卡片后保存 `data.message_id`
- resolve 后 PATCH `/im/v1/messages/{message_id}`
- 卡片 schema 发送和更新都使用 Card JSON 2.0,参考项目指出跨版本 PATCH 可能返回成功但客户端不重渲染。
不支持更新的渠道:
- 发送一条新消息提示“已批准/已拒绝”。
- 或仅在后台幂等拒绝重复点击。
## 安全和权限
交互回调必须校验:
- `interaction_id` 是否存在。
- 是否已过期。
- 是否已处理。
- 点击用户是否允许处理该 interaction。
- 当前 channel/chat 是否匹配。
对于工具审批,默认建议只有触发该 agent turn 的用户或允许列表用户可以审批。群聊里要特别注意 `sender_id`,不能只看 `chat_id`
日志中避免记录原始飞书回调敏感字段:
- callback token
- operator open_id/union_id/user_id/tenant_key
- open_chat_id/open_message_id
可以记录脱敏后的 payload shape用于排查飞书回调字段变化。
## 持久化策略
第一阶段可以只做内存 pending map
- 适合短时工具审批。
- 进程重启后旧按钮点击会变成 unknown/expired。
- 实现简单。
后续如果要支持长期任务或跨重启交互,需要持久化:
- `interactions` 表保存 id、kind、channel、chat_id、sender_id、status、payload、created_at、expires_at。
- `interaction_actions` 可选,或直接 JSON 存在 payload 中。
- resolve 时事务更新 status防止重复点击竞态。
## 与现有架构的关系
现有数据流:
```text
Channel -> MessageBus -> SessionManager -> AgentLoop -> tools -> SessionManager -> MessageBus -> OutboundDispatcher -> Channel
```
加入 interaction 后建议:
```text
Outbound:
AgentLoop/Tool approval -> InteractionManager -> MessageBus outbound -> OutboundDispatcher -> Channel renderer
Inbound:
Channel callback -> InteractionEvent -> InteractionManager -> pending waiter / synthetic InboundMessage
```
Channel 仍然只做渠道协议适配:
- 飞书负责 Card JSON 和 `card.action.trigger`
- Slack 负责 Block Kit 和 signing secret。
- Telegram 负责 callback query。
- CLI 负责文本命令 fallback。
InteractionManager 负责语义:
- 这是 approval 还是 quick reply。
- 是否过期。
- 是否有权限。
- 应该唤醒哪个等待者。
SessionManager/AgentLoop 不需要知道飞书卡片格式。
## 分阶段实施计划
### 阶段 1模型和 fallback
- 新增 `src/interaction/` 模块。
- 定义 `InteractionPayload``InteractionAction``InteractionEvent`
- 增加 fallback text renderer。
- 为 `OutboundMessage` 增加 `interaction: Option<InteractionPayload>`,或短期使用 metadata。
- 增加单元测试覆盖序列化和 fallback 文本。
### 阶段 2Feishu card action
- Feishu outbound 支持把 `InteractionPayload` 渲染为 Card JSON 2.0。
- Feishu inbound 在 WebSocket frame 中识别 `card.action.trigger`
- 解析 `/action/value``/action/behaviors/0/value`
- 发布统一 `InteractionEvent`
- 保存 `message_id`,支持完成后 PATCH resolved card。
- 增加 fixtures 测试真实/模拟的 `card.action.trigger` payload。
### 阶段 3InteractionManager 和审批
- 新增内存版 `InteractionManager`
- 支持 request/resolve/timeout。
- 接入工具执行前审批点。
- 支持 Approve/Deny/AlwaysApprove。
- 未知、过期、重复点击保持幂等。
- 增加 agent/tool 审批单元测试。
### 阶段 4其他渠道兼容
- CLI chat 支持 `/_interaction <id> <value>` fallback。
- 其他不支持原生按钮的渠道使用纯文本 fallback。
- 后续按需实现 Slack/Telegram/Discord 原生按钮。
### 阶段 5持久化和高级交互
- 需要时落库 pending interaction。
- 支持 quick reply 生成 synthetic inbound message。
- 支持表单提交。
- 支持 `DenyWithEdit`
- 支持长期任务交互和重启恢复。
## 测试计划
单元测试:
- `InteractionPayload` 序列化。
- fallback text renderer 输出。
- Feishu Card JSON 包含正确 callback value。
- Feishu 回调同时支持 `/action/value``/action/behaviors/0/value`
- unknown/expired interaction 不报错。
- 重复点击只 resolve 一次。
集成测试:
- 模拟 Feishu `card.action.trigger`,验证 pending approval 被唤醒。
- 模拟超时,验证默认 deny。
- CLI fallback 输入 `/_interaction`,验证能 resolve。
- 不支持按钮的 channel 能收到可读 fallback 文本。
手工验证:
- 飞书群聊点击 Approve/Deny/Always。
- 私聊点击。
- 点击后原卡片更新为 resolved。
- 重复点击不会重复执行工具。
- 非触发用户点击时按权限策略处理。
## 开放问题
1. 工具审批应该由 AgentLoop 直接调用 InteractionManager还是通过 SessionManager 代理?
2. 群聊中是否只允许原始触发者审批,还是允许配置中的所有 allowed user 审批?
3. `AlwaysApprove` 的作用域是本次会话、本 dialog、本 chat还是全局工具策略
4. 是否需要第一阶段就修改 `OutboundMessage` 结构,还是先用 metadata 降低改动面?
5. 飞书 CardKit 流式输出是否要与交互卡片统一,还是继续保持普通回复卡片和交互卡片两套路径?

View File

@ -28,10 +28,6 @@ fn build_content_blocks(
) -> Vec<ContentBlock> { ) -> Vec<ContentBlock> {
let mut blocks = Vec::new(); let mut blocks = Vec::new();
if !text.is_empty() {
blocks.push(ContentBlock::text(text));
}
if !media_refs.is_empty() { if !media_refs.is_empty() {
for mr in media_refs { for mr in media_refs {
if input_types.contains(&mr.media_type) { if input_types.contains(&mr.media_type) {
@ -63,6 +59,8 @@ fn build_content_blocks(
))); )));
} }
} }
} else if !text.is_empty() {
blocks.push(ContentBlock::text(text));
} }
if blocks.is_empty() { if blocks.is_empty() {
@ -860,23 +858,6 @@ mod tests {
"calculator" "calculator"
); );
} }
#[test]
fn test_build_content_blocks_keeps_text_with_media() {
let registry = MediaHandlerRegistry::new();
let blocks = build_content_blocks(
"先看这段文字",
&[MediaRef {
path: "missing.png".to_string(),
media_type: "image".to_string(),
}],
&[],
&registry,
);
assert!(matches!(blocks.first(), Some(ContentBlock::Text { text }) if text == "先看这段文字"));
assert!(matches!(blocks.get(1), Some(ContentBlock::Text { text }) if text.contains("用户发来了一个文件")));
}
} }
#[derive(Debug)] #[derive(Debug)]

View File

@ -24,10 +24,7 @@ impl OutboundDispatcher {
tracing::info!("OutboundDispatcher started"); tracing::info!("OutboundDispatcher started");
loop { loop {
let Some(msg) = self.bus.consume_outbound().await else { let msg = self.bus.consume_outbound().await;
tracing::warn!("OutboundDispatcher stopping because outbound bus closed");
break;
};
let channel_name = msg.channel.clone(); let channel_name = msg.channel.clone();
let channel = self.channel_manager.get_channel(&channel_name).await; let channel = self.channel_manager.get_channel(&channel_name).await;

View File

@ -51,11 +51,17 @@ impl MessageBus {
} }
/// Consume an inbound message (Agent -> Bus) /// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> Option<InboundMessage> { pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self.inbound_rx.lock().await.recv().await?; let msg = self
.inbound_rx
.lock()
.await
.recv()
.await
.expect("bus inbound closed");
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message"); tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
Some(msg) msg
} }
/// Publish an outbound message (Agent -> Bus) /// Publish an outbound message (Agent -> Bus)
@ -69,8 +75,13 @@ impl MessageBus {
} }
/// Consume an outbound message (Dispatcher -> Bus) /// Consume an outbound message (Dispatcher -> Bus)
pub async fn consume_outbound(&self) -> Option<OutboundMessage> { pub async fn consume_outbound(&self) -> OutboundMessage {
self.outbound_rx.lock().await.recv().await self.outbound_rx
.lock()
.await
.recv()
.await
.expect("bus outbound closed")
} }
/// Publish a control message (Channel -> Bus for session management) /// Publish a control message (Channel -> Bus for session management)
@ -83,8 +94,13 @@ impl MessageBus {
} }
/// Consume a control message (ControlProcessor -> Bus) /// Consume a control message (ControlProcessor -> Bus)
pub async fn consume_control(&self) -> Option<ControlMessage> { pub async fn consume_control(&self) -> ControlMessage {
self.control_rx.lock().await.recv().await self.control_rx
.lock()
.await
.recv()
.await
.expect("bus control closed")
} }
} }

View File

@ -165,7 +165,7 @@ struct ParsedMessage {
open_id: String, open_id: String,
chat_id: String, chat_id: String,
content: String, content: String,
media: Vec<MediaItem>, media: Option<MediaItem>,
/// ID of the message this message is replying to (if any). /// ID of the message this message is replying to (if any).
/// Used to fetch quoted message content for display. /// Used to fetch quoted message content for display.
parent_id: Option<String>, parent_id: Option<String>,
@ -1007,7 +1007,7 @@ impl FeishuChannel {
} }
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
for m in &media { if let Some(ref m) = media {
tracing::debug!(media_type = %m.media_type, media_path = %m.path, "Media downloaded successfully"); tracing::debug!(media_type = %m.media_type, media_path = %m.path, "Media downloaded successfully");
} }
@ -1027,7 +1027,7 @@ impl FeishuChannel {
msg_type: &str, msg_type: &str,
content: &str, content: &str,
message_id: &str, message_id: &str,
) -> Result<(String, Vec<MediaItem>), ChannelError> { ) -> Result<(String, Option<MediaItem>), ChannelError> {
let (text, media) = match msg_type { let (text, media) = match msg_type {
"text" => { "text" => {
let text = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) { let text = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
@ -1039,40 +1039,20 @@ impl FeishuChannel {
} else { } else {
content.to_string() content.to_string()
}; };
(text, Vec::new()) (text, None)
}
"post" => {
let text = parse_post_content(content);
let mut media = Vec::new();
for image_key in collect_post_image_keys(content) {
let content_json = serde_json::json!({ "image_key": image_key });
match self
.download_media("image", &content_json, message_id)
.await
{
Ok((_text, Some(item))) => media.push(item),
Ok((_text, None)) => {}
Err(e) => {
tracing::warn!(error = %e, "Failed to download image from Feishu post message");
}
}
}
(text, media)
} }
"post" => (parse_post_content(content), None),
"image" | "audio" | "file" | "media" => { "image" | "audio" | "file" | "media" => {
if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) { if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) {
match self match self
.download_media(msg_type, &content_json, message_id) .download_media(msg_type, &content_json, message_id)
.await .await
{ {
Ok((text, Some(media))) => (text, vec![media]), Ok((text, media)) => (text, media),
Ok((text, None)) => (text, Vec::new()), Err(_) => (format!("[{}: content unavailable]", msg_type), None),
Err(_) => (format!("[{}: content unavailable]", msg_type), Vec::new()),
} }
} else { } else {
(format!("[{}: content unavailable]", msg_type), Vec::new()) (format!("[{}: content unavailable]", msg_type), None)
} }
} }
"share_chat" => { "share_chat" => {
@ -1082,9 +1062,9 @@ impl FeishuChannel {
.get("chat_id") .get("chat_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("unknown"); .unwrap_or("unknown");
(format!("[shared chat: {}]", chat_id), Vec::new()) (format!("[shared chat: {}]", chat_id), None)
} else { } else {
("[shared chat]".to_string(), Vec::new()) ("[shared chat]".to_string(), None)
} }
} }
"share_user" => { "share_user" => {
@ -1094,44 +1074,42 @@ impl FeishuChannel {
.get("user_id") .get("user_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("unknown"); .unwrap_or("unknown");
(format!("[shared user: {}]", user_id), Vec::new()) (format!("[shared user: {}]", user_id), None)
} else { } else {
("[shared user]".to_string(), Vec::new()) ("[shared user]".to_string(), None)
} }
} }
"interactive" => { "interactive" => {
// Interactive card messages - extract text content // Interactive card messages - extract text content
match extract_interactive_content(content) { match extract_interactive_content(content) {
Ok((text, Some(media))) => (text, vec![media]), Ok((text, media)) => (text, media),
Ok((text, None)) => (text, Vec::new()),
Err(e) => { Err(e) => {
tracing::warn!(error = %e, "Failed to extract interactive content"); tracing::warn!(error = %e, "Failed to extract interactive content");
(content.to_string(), Vec::new()) (content.to_string(), None)
} }
} }
} }
"list" => { "list" => {
// List/bullet messages // List/bullet messages
match parse_list_content(content) { match parse_list_content(content) {
Ok((text, Some(media))) => (text, vec![media]), Ok((text, media)) => (text, media),
Ok((text, None)) => (text, Vec::new()), Err(_) => (content.to_string(), None),
Err(_) => (content.to_string(), Vec::new()),
} }
} }
"merge_forward" => ("[merged forward messages]".to_string(), Vec::new()), "merge_forward" => ("[merged forward messages]".to_string(), None),
"share_calendar_event" => { "share_calendar_event" => {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) { if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
let event_key = parsed let event_key = parsed
.get("event_key") .get("event_key")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("unknown"); .unwrap_or("unknown");
(format!("[shared calendar event: {}]", event_key), Vec::new()) (format!("[shared calendar event: {}]", event_key), None)
} else { } else {
("[shared calendar event]".to_string(), Vec::new()) ("[shared calendar event]".to_string(), None)
} }
} }
"system" => ("[system message]".to_string(), Vec::new()), "system" => ("[system message]".to_string(), None),
_ => (content.to_string(), Vec::new()), _ => (content.to_string(), None),
}; };
// Strip @_user_N placeholders from group chat @mentions // Strip @_user_N placeholders from group chat @mentions
@ -1257,15 +1235,16 @@ impl FeishuChannel {
let channel = self.clone(); let channel = self.clone();
let bus = bus.clone(); let bus = bus.clone();
tokio::spawn(async move { tokio::spawn(async move {
let media_count = if parsed.media.is_some() { 1 } else { 0 };
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %parsed.media.len(), "Publishing message to bus"); tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
let msg = crate::bus::InboundMessage { let msg = crate::bus::InboundMessage {
channel: "feishu".to_string(), channel: "feishu".to_string(),
sender_id: parsed.open_id.clone(), sender_id: parsed.open_id.clone(),
chat_id: parsed.chat_id.clone(), chat_id: parsed.chat_id.clone(),
content: parsed.content.clone(), content: parsed.content.clone(),
timestamp: crate::bus::message::current_timestamp(), timestamp: crate::bus::message::current_timestamp(),
media: parsed.media.clone(), media: parsed.media.map(|m| vec![m]).unwrap_or_default(),
metadata: std::collections::HashMap::new(), metadata: std::collections::HashMap::new(),
forwarded_metadata, forwarded_metadata,
}; };
@ -1354,52 +1333,6 @@ impl FeishuChannel {
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn collect_post_image_keys_finds_nested_images() {
let content = serde_json::json!({
"zh_cn": {
"title": "",
"content": [[
{"tag": "img", "image_key": "img_v3_001"},
{"tag": "text", "text": "这是哪里?"},
{"tag": "img", "image_key": "img_v3_002"},
{"tag": "img", "image_key": "img_v3_001"}
]]
}
})
.to_string();
assert_eq!(
collect_post_image_keys(&content),
vec!["img_v3_001".to_string(), "img_v3_002".to_string()]
);
}
#[test]
fn parse_post_content_preserves_image_positions() {
let content = serde_json::json!({
"zh_cn": {
"title": "",
"content": [[
{"tag": "text", "text": "这是一张图:"},
{"tag": "img", "image_key": "img_v3_001"},
{"tag": "text", "text": "看完继续说"}
]]
}
})
.to_string();
assert_eq!(
parse_post_content(&content),
"这是一张图:[image]看完继续说"
);
}
}
fn parse_post_content(content: &str) -> String { fn parse_post_content(content: &str) -> String {
/// Extract text from a single post element (text, link, at-mention). /// Extract text from a single post element (text, link, at-mention).
fn extract_element(el: &serde_json::Value, out: &mut Vec<String>) { fn extract_element(el: &serde_json::Value, out: &mut Vec<String>) {
@ -1426,9 +1359,6 @@ fn parse_post_content(content: &str) -> String {
.unwrap_or("user"); .unwrap_or("user");
out.push(format!("@{}", name)); out.push(format!("@{}", name));
} }
"img" => {
out.push("[image]".to_string());
}
"code_block" => { "code_block" => {
let lang = el.get("language").and_then(|l| l.as_str()).unwrap_or(""); let lang = el.get("language").and_then(|l| l.as_str()).unwrap_or("");
let code_text = el.get("text").and_then(|t| t.as_str()).unwrap_or(""); let code_text = el.get("text").and_then(|t| t.as_str()).unwrap_or("");
@ -1519,38 +1449,6 @@ fn parse_post_content(content: &str) -> String {
content.to_string() content.to_string()
} }
fn collect_post_image_keys(content: &str) -> Vec<String> {
fn visit(value: &serde_json::Value, keys: &mut Vec<String>) {
match value {
serde_json::Value::Object(map) => {
if let Some(image_key) = map.get("image_key").and_then(|v| v.as_str())
&& !keys.iter().any(|k| k == image_key)
{
keys.push(image_key.to_string());
}
for child in map.values() {
visit(child, keys);
}
}
serde_json::Value::Array(items) => {
for item in items {
visit(item, keys);
}
}
_ => {}
}
}
let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) else {
return Vec::new();
};
let mut keys = Vec::new();
visit(&parsed, &mut keys);
keys
}
/// Extract text content from interactive card messages /// Extract text content from interactive card messages
fn extract_interactive_content(content: &str) -> Result<(String, Option<MediaItem>), ChannelError> { fn extract_interactive_content(content: &str) -> Result<(String, Option<MediaItem>), ChannelError> {
let parsed = match serde_json::from_str::<serde_json::Value>(content) { let parsed = match serde_json::from_str::<serde_json::Value>(content) {

View File

@ -205,10 +205,6 @@ impl GatewayState {
tokio::select! { tokio::select! {
// Inbound: AI message flow // Inbound: AI message flow
inbound = bus.consume_inbound() => { inbound = bus.consume_inbound() => {
let Some(inbound) = inbound else {
tracing::warn!("Message processor stopping because inbound bus closed");
break;
};
match session_manager.handle_message( match session_manager.handle_message(
&inbound.channel, &inbound.channel,
&inbound.sender_id, &inbound.sender_id,
@ -256,10 +252,6 @@ impl GatewayState {
// Control: session management operations // Control: session management operations
msg = bus.consume_control() => { msg = bus.consume_control() => {
let Some(msg) = msg else {
tracing::warn!("Message processor stopping because control bus closed");
break;
};
Self::handle_control_message(&session_manager, msg).await; Self::handle_control_message(&session_manager, msg).await;
} }
} }

View File

@ -3,7 +3,7 @@ use clap::{CommandFactory, Parser};
#[derive(Parser)] #[derive(Parser)]
#[command(name = "picobot")] #[command(name = "picobot")]
#[command(about = "A CLI chatbot", long_about = None)] #[command(about = "A CLI chatbot", long_about = None)]
#[command(version = "1.1.1")] #[command(version = "1.1.0")]
enum Command { enum Command {
/// Connect to gateway /// Connect to gateway
Chat { Chat {

View File

@ -150,20 +150,13 @@ struct OpenAIChoice {
message: OpenAIMessage, message: OpenAIMessage,
} }
fn null_or_missing_tool_calls<'de, D>(deserializer: D) -> Result<Vec<OpenAIToolCall>, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(Option::<Vec<OpenAIToolCall>>::deserialize(deserializer)?.unwrap_or_default())
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct OpenAIMessage { struct OpenAIMessage {
#[serde(default)] #[serde(default)]
content: Option<String>, content: Option<String>,
#[serde(default)] #[serde(default)]
reasoning_content: Option<String>, reasoning_content: Option<String>,
#[serde(default, deserialize_with = "null_or_missing_tool_calls")] #[serde(default)]
tool_calls: Vec<OpenAIToolCall>, tool_calls: Vec<OpenAIToolCall>,
} }
@ -425,42 +418,4 @@ mod tests {
"{\"expression\":\"1+1\"}" "{\"expression\":\"1+1\"}"
); );
} }
#[test]
fn test_decode_response_accepts_null_tool_calls() {
let text = r#"{
"id": "d21abaa6552741949e2aba76bde59359",
"choices": [{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "你好!",
"role": "assistant",
"tool_calls": null,
"reasoning_content": "The user sent a greeting."
}
}],
"created": 1781622889,
"model": "mimo-v2.5",
"object": "chat.completion",
"usage": {
"completion_tokens": 65,
"prompt_tokens": 11741,
"total_tokens": 11806,
"completion_tokens_details": {"reasoning_tokens": 40},
"prompt_tokens_details": {}
}
}"#;
let response: OpenAIResponse = serde_json::from_str(text).unwrap();
let message = &response.choices[0].message;
assert_eq!(message.content.as_deref(), Some("你好!"));
assert_eq!(
message.reasoning_content.as_deref(),
Some("The user sent a greeting.")
);
assert!(message.tool_calls.is_empty());
assert_eq!(response.usage.total_tokens, 11806);
}
} }

View File

@ -30,11 +30,11 @@ pub fn next_run_for_schedule(schedule: &Schedule, from: i64) -> Option<i64> {
let next_utc = if let Some(tz_str) = tz { let next_utc = if let Some(tz_str) = tz {
let tz: chrono_tz::Tz = tz_str.parse().ok()?; let tz: chrono_tz::Tz = tz_str.parse().ok()?;
let from_local = from_dt.with_timezone(&tz); let _from_local = from_dt.with_timezone(&tz);
let next_local = cron_schedule.after(&from_local).next()?; let next_local = cron_schedule.upcoming(tz).next()?;
next_local.with_timezone(&Utc) next_local.with_timezone(&Utc)
} else { } else {
cron_schedule.after(&from_dt).next()? cron_schedule.upcoming(Utc).next()?
}; };
Some(next_utc.timestamp_millis()) Some(next_utc.timestamp_millis())
@ -311,37 +311,4 @@ mod tests {
let next_ms = next.unwrap(); let next_ms = next.unwrap();
assert!(next_ms > now); assert!(next_ms > now);
} }
#[test]
fn test_next_run_cron_uses_from_argument() {
let expr = "0 * * * * *".to_string();
let schedule = Schedule::Cron { expr, tz: None };
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:34:20Z")
.unwrap()
.timestamp_millis();
let next = next_run_for_schedule(&schedule, from).unwrap();
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:35:00Z")
.unwrap()
.timestamp_millis();
assert_eq!(next, expected);
}
#[test]
fn test_next_run_cron_timezone_uses_from_argument() {
let expr = "0 0 9 * * *".to_string();
let schedule = Schedule::Cron {
expr,
tz: Some("Asia/Shanghai".to_string()),
};
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T00:30:00Z")
.unwrap()
.timestamp_millis();
let next = next_run_for_schedule(&schedule, from).unwrap();
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T01:00:00Z")
.unwrap()
.timestamp_millis();
assert_eq!(next, expected);
}
} }

View File

@ -8,13 +8,6 @@ use crate::mcp::get_mcp_status;
use crate::storage::{Storage, StorageError}; use crate::storage::{Storage, StorageError};
use std::sync::Arc as StdArc; use std::sync::Arc as StdArc;
type MessagePersistSnapshot = (
StdArc<Storage>,
String,
crate::storage::message::MessageMeta,
crate::storage::session::SessionMeta,
);
tokio::task_local! { tokio::task_local! {
static CURRENT_SOURCE_SESSION: Option<String>; static CURRENT_SOURCE_SESSION: Option<String>;
} }
@ -89,14 +82,6 @@ pub struct Session {
current_cancel: Option<oneshot::Sender<()>>, current_cancel: Option<oneshot::Sender<()>>,
/// Monotonic counter to detect stale workers /// Monotonic counter to detect stale workers
worker_generation: u64, worker_generation: u64,
/// Monotonic counter for in-memory session mutations.
///
/// Slow work such as memory recall, compression, and title generation runs
/// outside the session lock. Workers capture this version before starting
/// that work and verify it before committing results, so stale snapshots do
/// not overwrite a session that was changed by a command such as /clear or
/// /delete while the slow work was in flight.
state_version: u64,
} }
/// A task to be processed by the per-session agent worker /// A task to be processed by the per-session agent worker
@ -161,7 +146,6 @@ impl Session {
agent_tx: None, agent_tx: None,
current_cancel: None, current_cancel: None,
worker_generation: 0, worker_generation: 0,
state_version: 0,
}) })
} }
@ -338,7 +322,6 @@ impl Session {
agent_tx: None, agent_tx: None,
current_cancel: None, current_cancel: None,
worker_generation: 0, worker_generation: 0,
state_version: 0,
}) })
} }
@ -354,15 +337,6 @@ impl Session {
message: ChatMessage, message: ChatMessage,
persist: bool, persist: bool,
) -> Result<(), StorageError> { ) -> Result<(), StorageError> {
let snapshot = self.add_message_in_memory(message, persist);
persist_added_message(snapshot).await
}
fn add_message_in_memory(
&mut self,
message: ChatMessage,
persist: bool,
) -> Option<MessagePersistSnapshot> {
let is_user = message.role == "user"; let is_user = message.role == "user";
let now = chrono::Utc::now().timestamp_millis(); let now = chrono::Utc::now().timestamp_millis();
@ -370,37 +344,36 @@ impl Session {
let seq = self.seq_counter; let seq = self.seq_counter;
self.seq_counter += 1; self.seq_counter += 1;
let persist_snapshot = if persist { // Persist to Storage
self.storage.clone().map(|storage| { if persist && let Some(ref storage) = self.storage {
let msg_meta = crate::storage::message::MessageMeta { let msg_meta = crate::storage::message::MessageMeta {
id: message.id.clone(), id: message.id.clone(),
session_id: self.id.to_string(), session_id: self.id.to_string(),
seq, seq,
role: message.role.clone(), role: message.role.clone(),
content: message.content.clone(), content: message.content.clone(),
reasoning_content: message.reasoning_content.clone(), reasoning_content: message.reasoning_content.clone(),
media_refs: if message.media_refs.is_empty() { media_refs: if message.media_refs.is_empty() {
None None
} else { } else {
Some(serde_json::to_string(&message.media_refs).unwrap_or_default()) Some(serde_json::to_string(&message.media_refs).unwrap_or_default())
}, },
tool_call_id: message.tool_call_id.clone(), tool_call_id: message.tool_call_id.clone(),
tool_name: message.tool_name.clone(), tool_name: message.tool_name.clone(),
tool_calls: message tool_calls: message
.tool_calls .tool_calls
.as_ref() .as_ref()
.and_then(|tc| serde_json::to_string(tc).ok()), .and_then(|tc| serde_json::to_string(tc).ok()),
source: message source: message
.source .source
.as_ref() .as_ref()
.map(|s| serde_json::to_string(s).unwrap_or_default()), .map(|s| serde_json::to_string(s).unwrap_or_default()),
created_at: now, created_at: now,
}; };
(storage, self.id.to_string(), msg_meta) storage
}) .append_message_with_retry(&self.id.to_string(), &msg_meta)
} else { .await?;
None }
};
// Update in-memory state // Update in-memory state
self.messages.push(message); self.messages.push(message);
@ -409,30 +382,16 @@ impl Session {
self.message_count += 1; self.message_count += 1;
} }
self.last_active_at = now; self.last_active_at = now;
self.state_version = self.state_version.wrapping_add(1);
persist_snapshot.map(|(storage, session_id, msg_meta)| { // Sync message_count to Storage
let session_meta = crate::storage::session::SessionMeta { if persist {
id: session_id.clone(), tracing::debug!(session_id = %self.id, last_active_at = %now, message_count = %self.message_count, "Persisting session meta after add_message");
channel: self.id.channel.clone(), if let Err(e) = self.persist_session_meta().await {
chat_id: self.id.chat_id.clone(), tracing::warn!("failed to persist session meta: {}", e);
dialog_id: self.id.dialog_id.clone(), }
title: self.title.clone(), }
created_at: self.created_at,
last_active_at: self.last_active_at, Ok(())
message_count: self.message_count,
routing_info: if self.routing_info.is_empty() {
None
} else {
Some(self.routing_info.clone())
},
archived_at: self.archived_at,
deleted_at: None,
last_consolidated_at: self.last_consolidated_at,
last_compressed_message_at: self.last_compressed_message_at,
};
(storage, session_id, msg_meta, session_meta)
})
} }
/// 获取消息历史 /// 获取消息历史
@ -447,7 +406,6 @@ impl Session {
self.seq_counter = 1; self.seq_counter = 1;
self.total_message_count = 0; self.total_message_count = 0;
self.message_count = 0; self.message_count = 0;
self.state_version = self.state_version.wrapping_add(1);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared"); tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared");
} }
@ -459,7 +417,6 @@ impl Session {
self.seq_counter = 1; self.seq_counter = 1;
self.total_message_count = 0; self.total_message_count = 0;
self.message_count = 0; self.message_count = 0;
self.state_version = self.state_version.wrapping_add(1);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory"); tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory");
} }
@ -487,49 +444,43 @@ impl Session {
/// 将 session 元数据写回 Storage /// 将 session 元数据写回 Storage
pub async fn persist_session_meta(&self) -> Result<(), StorageError> { pub async fn persist_session_meta(&self) -> Result<(), StorageError> {
if let Some((storage, meta)) = self.session_meta_snapshot() { if let Some(ref storage) = self.storage {
let meta = crate::storage::session::SessionMeta {
id: self.id.to_string(),
channel: self.id.channel.clone(),
chat_id: self.id.chat_id.clone(),
dialog_id: self.id.dialog_id.clone(),
title: self.title.clone(),
created_at: self.created_at,
last_active_at: self.last_active_at,
message_count: self.message_count,
routing_info: if self.routing_info.is_empty() {
None
} else {
Some(self.routing_info.clone())
},
archived_at: self.archived_at,
deleted_at: None,
last_consolidated_at: self.last_consolidated_at,
last_compressed_message_at: self.last_compressed_message_at,
};
storage.upsert_session(&meta).await?; storage.upsert_session(&meta).await?;
} }
Ok(()) Ok(())
} }
fn session_meta_snapshot(
&self,
) -> Option<(StdArc<Storage>, crate::storage::session::SessionMeta)> {
let storage = self.storage.clone()?;
let meta = crate::storage::session::SessionMeta {
id: self.id.to_string(),
channel: self.id.channel.clone(),
chat_id: self.id.chat_id.clone(),
dialog_id: self.id.dialog_id.clone(),
title: self.title.clone(),
created_at: self.created_at,
last_active_at: self.last_active_at,
message_count: self.message_count,
routing_info: if self.routing_info.is_empty() {
None
} else {
Some(self.routing_info.clone())
},
archived_at: self.archived_at,
deleted_at: None,
last_consolidated_at: self.last_consolidated_at,
last_compressed_message_at: self.last_compressed_message_at,
};
Some((storage, meta))
}
/// 检查是否需要自动生成 title5 条用户消息后) /// 检查是否需要自动生成 title5 条用户消息后)
pub fn should_generate_title(&self) -> bool { pub fn should_generate_title(&self) -> bool {
self.title == "新对话" && self.message_count >= 5 self.title == "新对话" && self.message_count >= 5
} }
fn title_prompt_snapshot(&self) -> Option<String> { /// 生成标题(调用 LLM
pub async fn generate_title(&mut self) -> Result<(), AgentError> {
if !self.should_generate_title() { if !self.should_generate_title() {
return None; return Ok(());
} }
Some(format!( let prompt = format!(
r#"给定以下对话历史生成一个简短的会话标题5-15 个中文字符),概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。 r#"给定以下对话历史生成一个简短的会话标题5-15 个中文字符),概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
@ -541,41 +492,38 @@ impl Session {
.map(|m| format!("[{}]: {}", m.role, m.content)) .map(|m| format!("[{}]: {}", m.role, m.content))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n") .join("\n")
)) );
}
fn apply_generated_title(&mut self, title: String) -> bool { let title = self.call_llm_for_title(&prompt).await?;
if title.is_empty() || !self.should_generate_title() {
return false; if !title.is_empty() {
self.title = title.clone();
if let Err(e) = self.persist_session_meta().await {
tracing::warn!("failed to persist title: {}", e);
}
} }
self.title = title; Ok(())
self.state_version = self.state_version.wrapping_add(1);
true
} }
fn fresh_context_compressor(&self) -> ContextCompressor { /// 调用 LLM 生成标题
let compressor_config = ContextCompressionConfig { async fn call_llm_for_title(&self, prompt: &str) -> Result<String, AgentError> {
protect_first_n: 2, use crate::providers::{ChatCompletionRequest, ChatCompletionResponse, Message};
..Default::default()
let request = ChatCompletionRequest {
messages: vec![Message::user(prompt.to_string())],
temperature: Some(0.3),
max_tokens: Some(20),
tools: None,
}; };
let mut compressor = ContextCompressor::with_config(
self.provider.clone(),
self.provider_config.token_limit,
compressor_config,
self.memory_manager.clone(),
);
compressor.set_session_id(Some(self.id.to_string()));
compressor
}
fn replace_history_in_memory(&mut self, messages: Vec<ChatMessage>) { let response: ChatCompletionResponse = self
self.messages = messages; .provider
self.seq_counter = self.messages.len() as i64 + 1; .chat(request)
self.total_message_count = self.messages.len() as i64; .await
self.message_count = self.messages.iter().filter(|m| m.role == "user").count() as i64; .map_err(|e| AgentError::Other(format!("LLM call failed: {}", e)))?;
self.last_active_at = chrono::Utc::now().timestamp_millis();
self.state_version = self.state_version.wrapping_add(1); Ok(response.content.trim().to_string())
} }
/// 获取 provider_config 引用 /// 获取 provider_config 引用
@ -1127,39 +1075,24 @@ impl SessionManager {
"compact" => { "compact" => {
if let Some(sid) = current_session_id { if let Some(sid) = current_session_id {
let session = self.get_or_create_session(sid).await?; let session = self.get_or_create_session(sid).await?;
let (original_count, history, mut compressor, base_version) = { let mut session_guard = session.lock().await;
let session_guard = session.lock().await; let original_count = session_guard.get_history().len();
( let history = session_guard.get_history().to_vec();
session_guard.get_history().len(), let result = session_guard.compressor.compress_if_needed(history).await?;
session_guard.get_history().to_vec(),
session_guard.fresh_context_compressor(),
session_guard.state_version,
)
};
let result = compressor.compress_if_needed(history).await?;
let compressed_count = result.history.len(); let compressed_count = result.history.len();
let meta_snapshot = { if result.created_timelines {
let mut session_guard = session.lock().await; session_guard.last_compressed_message_at =
if session_guard.state_version != base_version { Some(chrono::Utc::now().timestamp_millis());
return Ok(( if let Err(e) = session_guard.persist_session_meta().await {
None, tracing::warn!(error = %e, "Failed to persist compression marker after /compact");
"Context changed while compacting; please run /compact again."
.to_string(),
));
} }
if result.created_timelines { }
session_guard.last_compressed_message_at = session_guard.clear_history();
Some(chrono::Utc::now().timestamp_millis()); for msg in result.history {
} session_guard
session_guard.replace_history_in_memory(result.history); .add_message(msg, false)
session_guard.session_meta_snapshot() .await
}; .map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
if let Some((storage, meta)) = meta_snapshot
&& let Err(e) = storage.upsert_session(&meta).await
{
tracing::warn!(error = %e, "Failed to persist compression marker after /compact");
} }
Ok(( Ok((
None, None,
@ -1371,22 +1304,16 @@ impl SessionManager {
let sid = current_session_id let sid = current_session_id
.ok_or_else(|| AgentError::Other("no active session".to_string()))?; .ok_or_else(|| AgentError::Other("no active session".to_string()))?;
let session = self.get_or_create_session(sid).await?; let session = self.get_or_create_session(sid).await?;
let msgs = { let mut guard = session.lock().await;
let mut guard = session.lock().await; let mut msgs: Vec<String> = Vec::new();
let mut msgs: Vec<String> = Vec::new(); if guard.current_cancel.take().is_some() {
if guard.current_cancel.take().is_some() { msgs.push("当前任务已发送停止信号。".to_string());
msgs.push("当前任务已发送停止信号。".to_string()); }
} if guard.agent_tx.take().is_some() {
if guard.agent_tx.take().is_some() { msgs.push("消息队列已清空。".to_string());
msgs.push("消息队列已清空。".to_string()); }
} guard.worker_generation = guard.worker_generation.wrapping_add(1);
guard.worker_generation = guard.worker_generation.wrapping_add(1);
guard.state_version = guard.state_version.wrapping_add(1);
msgs
};
// Cancel all running background sub-agent tasks for this session // Cancel all running background sub-agent tasks for this session
// after releasing the session lock.
self.sub_agent_manager self.sub_agent_manager
.cancel_by_session(&sid.to_string()) .cancel_by_session(&sid.to_string())
.await; .await;
@ -1743,7 +1670,7 @@ impl SessionManager {
) -> Result<(), AgentError> { ) -> Result<(), AgentError> {
let unified_id = self.resolve_dialog_id(channel, chat_id).await?; let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
let session = self.get_or_create_session(&unified_id).await?; let session = self.get_or_create_session(&unified_id).await?;
let persist_snapshot = { {
let mut guard = session.lock().await; let mut guard = session.lock().await;
let source = MessageSource { let source = MessageSource {
kind: SourceKind::SystemNotification, kind: SourceKind::SystemNotification,
@ -1754,11 +1681,11 @@ impl SessionManager {
task_id: task_id.map(|s| s.to_string()), task_id: task_id.map(|s| s.to_string()),
}; };
let msg = ChatMessage::assistant_with_source(content, source); let msg = ChatMessage::assistant_with_source(content, source);
guard.add_message_in_memory(msg, true) guard
}; .add_message(msg, true)
persist_added_message(persist_snapshot) .await
.await .map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?; }
let outbound = OutboundMessage { let outbound = OutboundMessage {
channel: channel.to_string(), channel: channel.to_string(),
@ -1878,63 +1805,6 @@ impl SessionManager {
} }
} }
async fn maybe_generate_title_outside_lock(session: Arc<Mutex<Session>>) -> Result<(), AgentError> {
use crate::providers::{ChatCompletionRequest, ChatCompletionResponse, Message};
let (provider, prompt) = {
let guard = session.lock().await;
let Some(prompt) = guard.title_prompt_snapshot() else {
return Ok(());
};
(guard.provider.clone(), prompt)
};
let request = ChatCompletionRequest {
messages: vec![Message::user(prompt)],
temperature: Some(0.3),
max_tokens: Some(20),
tools: None,
};
let response: ChatCompletionResponse = provider
.chat(request)
.await
.map_err(|e| AgentError::Other(format!("LLM call failed: {}", e)))?;
let title = response.content.trim().to_string();
let meta_snapshot = {
let mut guard = session.lock().await;
if guard.apply_generated_title(title) {
guard.session_meta_snapshot()
} else {
None
}
};
if let Some((storage, meta)) = meta_snapshot {
storage
.upsert_session(&meta)
.await
.map_err(|e| AgentError::Other(format!("failed to persist title: {}", e)))?;
}
Ok(())
}
async fn persist_added_message(
snapshot: Option<MessagePersistSnapshot>,
) -> Result<(), StorageError> {
let Some((storage, session_id, msg_meta, session_meta)) = snapshot else {
return Ok(());
};
storage
.append_message_with_retry(&session_id, &msg_meta)
.await?;
storage.upsert_session(&session_meta).await?;
Ok(())
}
fn spawn_agent_worker( fn spawn_agent_worker(
mut task_rx: mpsc::UnboundedReceiver<AgentTask>, mut task_rx: mpsc::UnboundedReceiver<AgentTask>,
session: Arc<Mutex<Session>>, session: Arc<Mutex<Session>>,
@ -1975,12 +1845,8 @@ fn spawn_agent_worker(
}); });
} }
// Phase 1: capture a stable session snapshot under lock. // Phase 1: prepare data under session lock
// Memory recall and compression happen outside this block so let (agent, history_out, system_prompt_out, cancel_rx) = {
// /stop and other commands are not blocked behind slow I/O or
// LLM-backed compaction.
let skills_prompt = skills_loader.build_skills_prompt();
let (agent, history_raw, mut compressor, base_version, cancel_rx) = {
let mut guard = session.lock().await; let mut guard = session.lock().await;
if guard.worker_generation != worker_gen { if guard.worker_generation != worker_gen {
@ -1991,9 +1857,7 @@ fn spawn_agent_worker(
task.media.iter().map(|m| m.to_media_ref()).collect(); task.media.iter().map(|m| m.to_media_ref()).collect();
let user_message = let user_message =
guard.create_user_message(&task.content, media_refs); guard.create_user_message(&task.content, media_refs);
let user_persist = guard.add_message_in_memory(user_message, true); if let Err(e) = guard.add_message(user_message, true).await {
drop(guard);
if let Err(e) = persist_added_message(user_persist).await {
tracing::error!(error = %e, "Failed to persist user message"); tracing::error!(error = %e, "Failed to persist user message");
let err_outbound = OutboundMessage { let err_outbound = OutboundMessage {
channel: task_chan.clone(), channel: task_chan.clone(),
@ -2007,12 +1871,61 @@ fn spawn_agent_worker(
let _ = bus.publish_outbound(err_outbound).await; let _ = bus.publish_outbound(err_outbound).await;
return; return;
} }
let mut guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
let history_raw = guard.get_history().to_vec(); let history_raw = guard.get_history().to_vec();
let skills_prompt = skills_loader.build_skills_prompt();
let memory_context = match memory_manager
.recall(
&task.content,
5,
Some(crate::memory::MemoryCategory::Knowledge),
None,
)
.await
{
Ok(entries) if !entries.is_empty() => Some(
entries
.iter()
.map(|e| format!("- {}: {}", e.key, e.content))
.collect::<Vec<_>>()
.join("\n"),
),
Err(e) => {
tracing::warn!(error = %e, "Failed to fetch memory context");
None
}
_ => None,
};
let system_prompt = guard
.build_system_prompt(&skills_prompt, memory_context.as_deref());
let result = guard
.compressor
.compress_if_needed(history_raw)
.await
.map(|r| {
if r.created_timelines {
guard.last_compressed_message_at =
Some(chrono::Utc::now().timestamp_millis());
}
r.history
})
.unwrap_or_else(|e| {
tracing::warn!(
error = %e,
"Context compression failed in worker"
);
guard.get_history().to_vec()
});
let mut history = result;
history.insert(0, ChatMessage::system(system_prompt.clone()));
let now = chrono::Utc::now().timestamp_millis();
guard.last_consolidated_at = Some(now);
let _ = guard.persist_session_meta().await;
let agent = match guard.create_agent_with_notify(notify_tx) { let agent = match guard.create_agent_with_notify(notify_tx) {
Ok(a) => a, Ok(a) => a,
@ -2039,87 +1952,9 @@ fn spawn_agent_worker(
} }
guard.current_cancel = Some(cancel_tx); guard.current_cancel = Some(cancel_tx);
( (agent, history, system_prompt, cancel_rx)
agent,
history_raw,
guard.fresh_context_compressor(),
guard.state_version,
cancel_rx,
)
}; // lock released }; // lock released
let memory_context = match memory_manager
.recall(
&task.content,
5,
Some(crate::memory::MemoryCategory::Knowledge),
None,
)
.await
{
Ok(entries) if !entries.is_empty() => Some(
entries
.iter()
.map(|e| format!("- {}: {}", e.key, e.content))
.collect::<Vec<_>>()
.join("\n"),
),
Err(e) => {
tracing::warn!(error = %e, "Failed to fetch memory context");
None
}
_ => None,
};
let system_prompt_out = {
let guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
guard.build_system_prompt(&skills_prompt, memory_context.as_deref())
};
let compression_result = compressor.compress_if_needed(history_raw).await;
let mut history_out = match compression_result {
Ok(result) => {
let meta_snapshot = {
let mut guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
if guard.state_version != base_version {
tracing::warn!(
session_id = %guard.id,
"Session changed while preparing agent history; dropping stale task"
);
return;
}
if result.created_timelines {
guard.last_compressed_message_at =
Some(chrono::Utc::now().timestamp_millis());
}
guard.last_consolidated_at =
Some(chrono::Utc::now().timestamp_millis());
guard.session_meta_snapshot()
};
if let Some((storage, meta)) = meta_snapshot
&& let Err(e) = storage.upsert_session(&meta).await
{
tracing::warn!(error = %e, "Failed to persist session meta after compression");
}
result.history
}
Err(e) => {
tracing::warn!(error = %e, "Context compression failed in worker");
let guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
guard.get_history().to_vec()
}
};
history_out.insert(0, ChatMessage::system(system_prompt_out.clone()));
// Phase 2 + 3: LLM call with cancellation // Phase 2 + 3: LLM call with cancellation
let session2 = session.clone(); let session2 = session.clone();
let bus2 = bus.clone(); let bus2 = bus.clone();
@ -2140,8 +1975,8 @@ fn spawn_agent_worker(
Err(AgentError::LlmError(ref msg)) Err(AgentError::LlmError(ref msg))
if is_context_overflow_error(msg) => if is_context_overflow_error(msg) =>
{ {
let (raw, mut retry_compressor, retry_base_version, new_window) = { let retry_history = {
let guard = session2.lock().await; let mut guard = session2.lock().await;
let new_window = let new_window =
crate::agent::ContextCompressor::parse_context_limit_from_error(msg) crate::agent::ContextCompressor::parse_context_limit_from_error(msg)
.unwrap_or(guard.compressor_threshold()); .unwrap_or(guard.compressor_threshold());
@ -2150,56 +1985,31 @@ fn spawn_agent_worker(
error = %msg, error = %msg,
"Context overflow in worker — retrying" "Context overflow in worker — retrying"
); );
(
guard.get_history().to_vec(),
guard.fresh_context_compressor(),
guard.state_version,
new_window,
)
};
retry_compressor.set_context_window(new_window);
let retry_result =
match retry_compressor.compress_if_needed(raw).await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "Retry compression failed");
let err_outbound = OutboundMessage {
channel: chan2,
chat_id: cid2,
content: "Context overflow handling failed."
.to_string(),
reply_to: None,
media: vec![],
metadata: HashMap::new(),
};
let _ = bus2.publish_outbound(err_outbound).await;
return;
}
};
let meta_snapshot = {
let mut guard = session2.lock().await;
if guard.state_version != retry_base_version {
tracing::warn!(
session_id = %guard.id,
"Session changed while retry-compressing after context overflow"
);
return;
}
guard.compressor.set_context_window(new_window); guard.compressor.set_context_window(new_window);
let raw = guard.get_history().to_vec();
let retry_result =
match guard.compressor.compress_if_needed(raw).await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "Retry compression failed");
let err_outbound = OutboundMessage {
channel: chan2,
chat_id: cid2,
content: "Context overflow handling failed."
.to_string(),
reply_to: None,
media: vec![],
metadata: HashMap::new(),
};
let _ = bus2.publish_outbound(err_outbound).await;
return;
}
};
if retry_result.created_timelines { if retry_result.created_timelines {
guard.last_compressed_message_at = guard.last_compressed_message_at =
Some(chrono::Utc::now().timestamp_millis()); Some(chrono::Utc::now().timestamp_millis());
let _ = guard.persist_session_meta().await;
} }
guard.session_meta_snapshot()
};
if let Some((storage, meta)) = meta_snapshot
&& let Err(e) = storage.upsert_session(&meta).await
{
tracing::warn!(error = %e, "Failed to persist session meta after retry compression");
}
let retry_history = {
let mut retry = retry_result.history; let mut retry = retry_result.history;
retry.insert( retry.insert(
0, 0,
@ -2245,24 +2055,20 @@ fn spawn_agent_worker(
let response = { let response = {
let mut guard = session2.lock().await; let mut guard = session2.lock().await;
let mut persist_snapshots = Vec::new();
for msg in result.emitted_messages { for msg in result.emitted_messages {
persist_snapshots.push(guard.add_message_in_memory(msg, true)); guard.add_message(msg, true).await.inspect_err(|e| {
tracing::error!(error = %e, "Failed to persist message")
}).ok();
}
if guard.should_generate_title()
&& let Err(e) = guard.generate_title().await
{
tracing::warn!("failed to generate title: {}", e);
} }
let sent_count = guard.messages.len(); let sent_count = guard.messages.len();
guard.compressor.set_last_api_info(sent_count, result.total_tokens); guard.compressor.set_last_api_info(sent_count, result.total_tokens);
(result.final_response.content, persist_snapshots) result.final_response.content
}; };
let (response, persist_snapshots) = response;
for snapshot in persist_snapshots {
if let Err(e) = persist_added_message(snapshot).await {
tracing::error!(error = %e, "Failed to persist message");
}
}
if let Err(e) = maybe_generate_title_outside_lock(session2.clone()).await {
tracing::warn!("failed to generate title: {}", e);
}
let outbound = OutboundMessage { let outbound = OutboundMessage {
channel: chan2, channel: chan2,
@ -2352,34 +2158,25 @@ impl SessionManager {
unified_id: &UnifiedSessionId, unified_id: &UnifiedSessionId,
) -> Result<(), AgentError> { ) -> Result<(), AgentError> {
let session = self.get_or_create_session(unified_id).await?; let session = self.get_or_create_session(unified_id).await?;
let (storage, session_id, meta_snapshot) = { let mut session_guard = session.lock().await;
let mut session_guard = session.lock().await; // Clear in-memory
// Clear in-memory session_guard.messages.clear();
session_guard.messages.clear(); session_guard.seq_counter = 1;
session_guard.seq_counter = 1; session_guard.total_message_count = 0;
session_guard.total_message_count = 0; session_guard.message_count = 0;
session_guard.message_count = 0; session_guard.last_consolidated_at = None;
session_guard.last_consolidated_at = None; session_guard.last_compressed_message_at = None;
session_guard.last_compressed_message_at = None; // Clear Storage
session_guard.state_version = session_guard.state_version.wrapping_add(1); if let Some(ref storage) = session_guard.storage {
(
session_guard.storage.clone(),
session_guard.id.to_string(),
session_guard.session_meta_snapshot(),
)
};
// Clear Storage outside the session lock.
if let Some(storage) = storage {
storage storage
.clear_messages(&session_id) .clear_messages(&session_guard.id.to_string())
.await .await
.map_err(|e| AgentError::Other(format!("failed to clear messages: {}", e)))?; .map_err(|e| AgentError::Other(format!("failed to clear messages: {}", e)))?;
} }
if let Some((storage, meta)) = meta_snapshot { session_guard
storage.upsert_session(&meta).await.map_err(|e| { .persist_session_meta()
AgentError::Other(format!("failed to persist cleared session: {}", e)) .await
})?; .map_err(|e| AgentError::Other(format!("failed to persist cleared session: {}", e)))?;
}
Ok(()) Ok(())
} }
} }
@ -2438,14 +2235,14 @@ impl OutboundMessenger for SessionManager {
}; };
// Write source-tagged assistant message to target session history // Write source-tagged assistant message to target session history
let persist_snapshot = { {
let mut guard = session.lock().await; let mut guard = session.lock().await;
let msg = ChatMessage::assistant_with_source(marked_content.clone(), source); let msg = ChatMessage::assistant_with_source(marked_content.clone(), source);
guard.add_message_in_memory(msg, true) guard
}; .add_message(msg, true)
persist_added_message(persist_snapshot) .await
.await .map_err(|e| e.to_string())?;
.map_err(|e| e.to_string())?; }
// Restore active dialog if source and target share channel:chat_id but differ in dialog_id // Restore active dialog if source and target share channel:chat_id but differ in dialog_id
if let Some(ref origin_id) = origin_id { if let Some(ref origin_id) = origin_id {

View File

@ -4,6 +4,7 @@ use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::json; use serde_json::json;
use tokio::io::AsyncReadExt;
use tokio::process::Command; use tokio::process::Command;
use tokio::time::timeout; use tokio::time::timeout;
@ -146,55 +147,71 @@ impl Tool for BashTool {
.map(Path::new) .map(Path::new)
.unwrap_or_else(|| Path::new(".")); .unwrap_or_else(|| Path::new("."));
match self.run_command(command, cwd, timeout_secs).await { let result = timeout(
Ok(output) => Ok(ToolResult { Duration::from_secs(timeout_secs),
self.run_command(command, cwd),
)
.await;
match result {
Ok(Ok(output)) => Ok(ToolResult {
success: true, success: true,
output, output,
error: None, error: None,
}), }),
Err(e) => Ok(ToolResult { Ok(Err(e)) => Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(e), error: Some(e),
}), }),
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
}),
} }
} }
} }
impl BashTool { impl BashTool {
async fn run_command( async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
&self,
command: &str,
cwd: &Path,
timeout_secs: u64,
) -> Result<String, String> {
let mut cmd = Command::new("bash"); let mut cmd = Command::new("bash");
cmd.args(["-c", command]) cmd.args(["-c", command])
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.current_dir(cwd) .current_dir(cwd);
.kill_on_drop(true);
let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
let process_output = let mut stdout = Vec::new();
match timeout(Duration::from_secs(timeout_secs), child.wait_with_output()).await { let mut stderr = Vec::new();
Ok(Ok(output)) => output,
Ok(Err(e)) => return Err(format!("Failed to wait: {}", e)), if let Some(ref mut out) = child.stdout {
Err(_) => { out.read_to_end(&mut stdout)
return Err(format!("Command timed out after {} seconds", timeout_secs)); .await
} .map_err(|e| format!("Failed to read stdout: {}", e))?;
}; }
if let Some(ref mut err) = child.stderr {
err.read_to_end(&mut stderr)
.await
.map_err(|e| format!("Failed to read stderr: {}", e))?;
}
let status = child
.wait()
.await
.map_err(|e| format!("Failed to wait: {}", e))?;
let mut output = String::new(); let mut output = String::new();
if !process_output.stdout.is_empty() { if !stdout.is_empty() {
let stdout_str = String::from_utf8_lossy(&process_output.stdout); let stdout_str = String::from_utf8_lossy(&stdout);
output.push_str(&stdout_str); output.push_str(&stdout_str);
} }
if !process_output.stderr.is_empty() { if !stderr.is_empty() {
let stderr_str = String::from_utf8_lossy(&process_output.stderr); let stderr_str = String::from_utf8_lossy(&stderr);
if !stderr_str.trim().is_empty() { if !stderr_str.trim().is_empty() {
if !output.is_empty() { if !output.is_empty() {
output.push('\n'); output.push('\n');
@ -204,10 +221,7 @@ impl BashTool {
} }
} }
output.push_str(&format!( output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
"\nExit code: {}",
process_output.status.code().unwrap_or(-1)
));
Ok(self.truncate_output(&output)) Ok(self.truncate_output(&output))
} }
@ -295,19 +309,4 @@ mod tests {
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("timed out")); assert!(result.error.unwrap().contains("timed out"));
} }
#[tokio::test]
async fn test_large_stderr_does_not_deadlock() {
let tool = BashTool::new().with_timeout(5);
let result = tool
.execute(json!({
"command": "for i in $(seq 1 2000); do echo noisy-error-line >&2; done; echo done"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("done"));
assert!(result.output.contains("STDERR"));
}
} }

View File

@ -6,8 +6,6 @@ use crate::tools::path_utils;
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
const MAX_CHARS: usize = 128_000; const MAX_CHARS: usize = 128_000;
const MAX_FILE_BYTES: u64 = 5 * 1024 * 1024;
const MAX_BINARY_BYTES: usize = 512 * 1024;
const DEFAULT_LIMIT: usize = 2000; const DEFAULT_LIMIT: usize = 2000;
pub struct FileReadTool { pub struct FileReadTool {
@ -120,29 +118,6 @@ impl Tool for FileReadTool {
}); });
} }
let metadata = match std::fs::metadata(&resolved) {
Ok(m) => m,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to inspect file: {}", e)),
});
}
};
if metadata.len() > MAX_FILE_BYTES {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"File too large to read safely: {} bytes (max {} bytes). Use a narrower tool or inspect a smaller excerpt.",
metadata.len(),
MAX_FILE_BYTES
)),
});
}
// Read raw bytes and try multiple encodings // Read raw bytes and try multiple encodings
let bytes = match std::fs::read(&resolved) { let bytes = match std::fs::read(&resolved) {
Ok(b) => b, Ok(b) => b,
@ -234,21 +209,6 @@ impl Tool for FileReadTool {
None => { None => {
// Truly binary file — base64 encode // Truly binary file — base64 encode
use base64::{Engine, engine::general_purpose::STANDARD}; use base64::{Engine, engine::general_purpose::STANDARD};
if bytes.len() > MAX_BINARY_BYTES {
let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream()
.to_string();
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Binary file too large to inline: {}, {} bytes (max {} bytes).",
mime,
bytes.len(),
MAX_BINARY_BYTES
)),
});
}
let encoded = STANDARD.encode(&bytes); let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved) let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream() .first_or_octet_stream()
@ -269,10 +229,6 @@ impl Tool for FileReadTool {
} }
fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) { fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
if bytes.contains(&0) {
return (None, None);
}
// Try UTF-8 first // Try UTF-8 first
if let Ok(text) = std::str::from_utf8(bytes) { if let Ok(text) = std::str::from_utf8(bytes) {
return (Some(text.to_string()), None); return (Some(text.to_string()), None);
@ -381,37 +337,4 @@ mod tests {
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file")); assert!(result.error.unwrap().contains("Not a file"));
} }
#[tokio::test]
async fn test_rejects_large_file_before_reading() {
let mut file = NamedTempFile::new().unwrap();
file.as_file_mut()
.set_len(MAX_FILE_BYTES + 1)
.expect("set large file length");
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": file.path().to_str().unwrap() }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("too large"));
}
#[tokio::test]
async fn test_rejects_large_binary_inline() {
let mut file = NamedTempFile::new().unwrap();
let bytes = vec![0_u8; MAX_BINARY_BYTES + 1];
file.write_all(&bytes).unwrap();
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": file.path().to_str().unwrap() }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Binary file too large"));
}
} }

View File

@ -3,7 +3,6 @@ use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use serde_json::json; use serde_json::json;
use tokio::net::lookup_host;
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
@ -57,34 +56,6 @@ impl HttpRequestTool {
Ok(url.to_string()) Ok(url.to_string())
} }
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
if self.allow_private_hosts {
return Ok(());
}
let host = extract_host(url)?;
if host.parse::<std::net::IpAddr>().is_ok() {
return Ok(());
}
let port = extract_port(url)?;
let addrs = lookup_host((host.as_str(), port))
.await
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
for addr in addrs {
let ip = addr.ip();
if is_private_ip(&ip) {
return Err(format!(
"Blocked host '{}' because DNS resolved to local/private IP {}",
host, ip
));
}
}
Ok(())
}
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> { fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
match method.to_uppercase().as_str() { match method.to_uppercase().as_str() {
"GET" => Ok(reqwest::Method::GET), "GET" => Ok(reqwest::Method::GET),
@ -209,34 +180,6 @@ fn extract_host(url: &str) -> Result<String, String> {
Ok(host) Ok(host)
} }
fn extract_port(url: &str) -> Result<u16, String> {
let scheme = if url.starts_with("https://") {
"https"
} else if url.starts_with("http://") {
"http"
} else {
return Err("Only http:// and https:// URLs are allowed".to_string());
};
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if let Some((_, port)) = authority.rsplit_once(':') {
port.parse::<u16>()
.map_err(|_| format!("Invalid URL port: {}", port))
} else if scheme == "https" {
Ok(443)
} else {
Ok(80)
}
}
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
if allowed_domains.iter().any(|domain| domain == "*") { if allowed_domains.iter().any(|domain| domain == "*") {
return true; return true;
@ -283,13 +226,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|| v4.is_broadcast() || v4.is_broadcast()
|| v4.is_multicast() || v4.is_multicast()
} }
std::net::IpAddr::V6(v6) => { std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
v6.is_loopback()
|| v6.is_unspecified()
|| v6.is_multicast()
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
}
} }
} }
@ -357,14 +294,6 @@ impl Tool for HttpRequestTool {
} }
}; };
if let Err(e) = self.validate_resolved_host(&url).await {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
let method = match self.validate_method(method_str) { let method = match self.validate_method(method_str) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
@ -380,7 +309,6 @@ impl Tool for HttpRequestTool {
let client = match reqwest::Client::builder() let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs)) .timeout(Duration::from_secs(self.timeout_secs))
.redirect(reqwest::redirect::Policy::none())
.build() .build()
{ {
Ok(c) => c, Ok(c) => c,
@ -508,19 +436,4 @@ mod tests {
async fn test_blocks_local_tld() { async fn test_blocks_local_tld() {
assert!(is_private_host("service.local")); assert!(is_private_host("service.local"));
} }
#[test]
fn test_extract_port_defaults_and_explicit() {
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
}
#[test]
fn test_private_ipv6_ranges_are_blocked() {
assert!(is_private_ip(&"::1".parse().unwrap()));
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
}
} }

View File

@ -13,7 +13,6 @@ pub mod get_skill;
pub mod http_request; pub mod http_request;
pub mod memory; pub mod memory;
pub mod path_utils; pub mod path_utils;
pub mod pty;
pub mod registry; pub mod registry;
pub mod schema; pub mod schema;
pub mod send_message; pub mod send_message;
@ -33,7 +32,6 @@ pub use file_write::FileWriteTool;
pub use get_skill::GetSkillTool; pub use get_skill::GetSkillTool;
pub use http_request::HttpRequestTool; pub use http_request::HttpRequestTool;
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool}; pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
pub use pty::{PtyManager, PtyTool};
pub use registry::ToolRegistry; pub use registry::ToolRegistry;
pub use send_message::SendMessageTool; pub use send_message::SendMessageTool;
pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use traits::{OutboundMessenger, Tool, ToolResult};
@ -62,7 +60,6 @@ pub fn create_default_tools(
registry.register(FileSearchTool::new()); registry.register(FileSearchTool::new());
registry.register(ContentSearchTool::new()); registry.register(ContentSearchTool::new());
registry.register(BashTool::new()); registry.register(BashTool::new());
registry.register(PtyTool::new(Arc::new(PtyManager::new())));
registry.register(HttpRequestTool::new( registry.register(HttpRequestTool::new(
vec!["*".to_string()], vec!["*".to_string()],
1_000_000, 1_000_000,

View File

@ -160,14 +160,11 @@ impl PtyManager {
}; };
for session in sessions { for session in sessions {
let mut guard = session.lock().unwrap(); let mut guard = session.lock().unwrap();
let child_handle = guard.child.clone(); let mut child_guard = guard.child.lock().unwrap();
{ if let Some(ref mut child) = *child_guard {
let mut child_guard = child_handle.lock().unwrap(); let _ = child.kill();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
}
*child_guard = None;
} }
*child_guard = None;
guard.status = SessionStatus::Killed; guard.status = SessionStatus::Killed;
} }
} }
@ -277,7 +274,7 @@ impl PtyManager {
let session = sessions let session = sessions
.get(session_id) .get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?; .ok_or_else(|| format!("Session not found: {}", session_id))?;
let guard = session.lock().unwrap(); let mut guard = session.lock().unwrap();
if guard.status != SessionStatus::Running { if guard.status != SessionStatus::Running {
return Err("Session is not running".to_string()); return Err("Session is not running".to_string());
} }
@ -299,7 +296,12 @@ impl PtyManager {
Ok(format!("OK, wrote {} bytes", byte_count)) Ok(format!("OK, wrote {} bytes", byte_count))
} }
fn read(&self, session_id: &str, offset: usize, limit: usize) -> Result<String, String> { fn read(
&self,
session_id: &str,
offset: usize,
limit: usize,
) -> Result<String, String> {
let sessions = self.sessions.lock().unwrap(); let sessions = self.sessions.lock().unwrap();
let session = sessions let session = sessions
.get(session_id) .get(session_id)
@ -350,16 +352,14 @@ impl PtyManager {
.ok_or_else(|| format!("Session not found: {}", session_id))?; .ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap(); let mut guard = session.lock().unwrap();
let child_handle = guard.child.clone(); let mut child_guard = guard.child.lock().unwrap();
{ if let Some(ref mut child) = *child_guard {
let mut child_guard = child_handle.lock().unwrap(); let _ = child.kill();
if let Some(ref mut child) = *child_guard { let _ = child.wait();
let _ = child.kill();
let _ = child.wait();
}
*child_guard = None;
} }
*child_guard = None;
guard.status = SessionStatus::Killed; guard.status = SessionStatus::Killed;
drop(child_guard);
drop(guard); drop(guard);
sessions.remove(session_id); sessions.remove(session_id);
@ -545,8 +545,14 @@ impl Tool for PtyTool {
}); });
} }
}; };
let offset = args.get("offset").and_then(|v| v.as_u64()).unwrap_or(0) as usize; let offset = args
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(500) as usize; .get("offset")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
.unwrap_or(500) as usize;
match self.pty_manager.read(session_id, offset, limit) { match self.pty_manager.read(session_id, offset, limit) {
Ok(output) => Ok(ToolResult { Ok(output) => Ok(ToolResult {
success: true, success: true,

View File

@ -3,7 +3,6 @@ use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use serde_json::json; use serde_json::json;
use tokio::net::lookup_host;
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
@ -61,34 +60,9 @@ impl WebFetchTool {
} }
} }
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
let host = extract_host(url)?;
if host.parse::<std::net::IpAddr>().is_ok() {
return Ok(());
}
let port = extract_port(url)?;
let addrs = lookup_host((host.as_str(), port))
.await
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
for addr in addrs {
let ip = addr.ip();
if is_private_ip(&ip) {
return Err(format!(
"Blocked host '{}' because DNS resolved to local/private IP {}",
host, ip
));
}
}
Ok(())
}
async fn fetch_content(&self, url: &str) -> Result<String, String> { async fn fetch_content(&self, url: &str) -> Result<String, String> {
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs)) .timeout(Duration::from_secs(self.timeout_secs))
.redirect(reqwest::redirect::Policy::none())
.build() .build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?; .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
@ -260,35 +234,6 @@ fn extract_host(url: &str) -> Result<String, String> {
Ok(host) Ok(host)
} }
fn extract_port(url: &str) -> Result<u16, String> {
let scheme = if url.starts_with("https://") {
"https"
} else if url.starts_with("http://") {
"http"
} else {
return Err("Only http:// and https:// URLs are allowed".to_string());
};
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if let Some((_, port)) = authority.rsplit_once(':') {
port.parse::<u16>()
.map_err(|_| format!("Invalid URL port: {}", port))
} else if scheme == "https" {
Ok(443)
} else {
Ok(80)
}
}
fn is_private_host(host: &str) -> bool { fn is_private_host(host: &str) -> bool {
if host == "localhost" || host.ends_with(".localhost") { if host == "localhost" || host.ends_with(".localhost") {
return true; return true;
@ -303,32 +248,19 @@ fn is_private_host(host: &str) -> bool {
} }
if let Ok(ip) = host.parse::<std::net::IpAddr>() { if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return is_private_ip(&ip); return match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
};
} }
false false
} }
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback()
|| v4.is_private()
|| v4.is_link_local()
|| v4.is_unspecified()
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback()
|| v6.is_unspecified()
|| v6.is_multicast()
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
}
}
}
#[async_trait] #[async_trait]
impl Tool for WebFetchTool { impl Tool for WebFetchTool {
fn name(&self) -> &str { fn name(&self) -> &str {
@ -379,14 +311,6 @@ impl Tool for WebFetchTool {
} }
}; };
if let Err(e) = self.validate_resolved_host(&url).await {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
match self.fetch_content(&url).await { match self.fetch_content(&url).await {
Ok(content) => Ok(ToolResult { Ok(content) => Ok(ToolResult {
success: true, success: true,
@ -433,21 +357,6 @@ mod tests {
assert!(result.unwrap_err().contains("whitespace")); assert!(result.unwrap_err().contains("whitespace"));
} }
#[test]
fn test_extract_port_defaults_and_explicit() {
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
}
#[test]
fn test_private_ipv6_ranges_are_blocked() {
assert!(is_private_ip(&"::1".parse().unwrap()));
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
}
#[tokio::test] #[tokio::test]
async fn test_extract_text_simple() { async fn test_extract_text_simple() {
let tool = test_tool(); let tool = test_tool();