Compare commits

..

3 Commits

20 changed files with 1559 additions and 404 deletions

View File

@ -3,7 +3,9 @@
.gitignore
# Build artifacts
target/
target/*
!target/release/
target/release/*
!target/release/picobot
# IDE

View File

@ -1,6 +1,6 @@
[package]
name = "picobot"
version = "1.1.0"
version = "1.1.2"
edition = "2024"
[dependencies]

View File

@ -55,11 +55,8 @@ RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
&& npm cache clean --force \
&& rm -rf /var/lib/apt/lists/*
# Install himalaya (CLI email client) from local file
COPY docker_build/himalaya.x86_64-linux.tgz /tmp/himalaya.tgz
RUN tar -xzf /tmp/himalaya.tgz -C /usr/local/bin \
&& chmod +x /usr/local/bin/himalaya \
&& rm -f /tmp/himalaya.tgz
# Install himalaya (CLI email client) from the official pre-built binary release
RUN curl -sSL https://raw.githubusercontent.com/pimalaya/himalaya/master/install.sh | sh
# Install fd (alternative to find)
RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \

View File

@ -18,7 +18,9 @@ PicoBot 的总体架构方向是清晰的Gateway 负责装配Channel 只
- 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id`
- 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
- 待处理工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
- 已修复Session 主处理路径不再在持有 session mutex 时执行 memory recall、上下文压缩、标题 LLM 生成、消息持久化、`/stop` sub-agent 取消或清历史存储操作;慢操作改为锁外执行并用 `state_version`/`worker_generation` 防止陈旧结果覆盖当前会话。
- 已修复Bash 超时清理、文件读取大文件限制、HTTP DNS 私网校验、Bus 关闭退出、Cron `from` 语义和 PTY 工具接入等中等级问题已完成清扫。
- 待处理:工具文件边界仍是后续质量风险。
## 主要发现
@ -108,7 +110,7 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
### 中高优先级Session 锁内执行过多异步操作
### 已修复Session 锁内执行过多异步操作
位置:
@ -125,13 +127,17 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
建议
已采取修复
- 锁内只做内存状态快照和必要的状态标记。
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
- 为 `Session` 增加 `state_version`,慢操作提交前检查会话是否已被 `/stop`、清历史或其它内存变更替换。
- `/compact` 改为锁内取 history 快照,锁外压缩,锁内提交压缩结果,锁外持久化 meta。
- agent worker Phase 1 改为锁内只创建用户消息、agent、cancel handle 和 history 快照memory recall 与 context compression 都在锁外执行。
- context overflow retry 的二次压缩移到锁外。
- 标题生成改为锁内取 prompt/provider 快照,锁外调用 LLM锁内应用标题锁外持久化。
- `add_message` 拆出内存更新和持久化快照,主消息路径在释放 session 锁后写入 SQLite。
- `/stop` 和清历史不再持有 session 锁等待 sub-agent 取消或 Storage 操作。
### 中优先级Bash 超时不会显式终止子进程
### 已修复Bash 超时不会显式终止子进程
位置:
@ -146,14 +152,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
建议
已采取修复
- 使用 `tokio::process::Child``kill_on_drop(true)`
- 超时分支显式 kill child 并 wait
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义
- Bash 一次性命令改用 `wait_with_output()`,避免 stdout/stderr 顺序读取造成 pipe 阻塞
- 子进程启用 `kill_on_drop(true)`,超时后丢弃等待 future 时会清理 child
- 新增大 stderr 输出测试,覆盖不会因为 stderr pipe 填满而卡住
- 持久/交互式进程通过已接入的 PTY 工具承载
### 中优先级:文件读取对大二进制文件没有输出上限
### 已修复:文件读取对大二进制文件没有输出上限
位置:
@ -168,13 +174,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
建议
已采取修复
- 先检查 metadata size超过阈值直接返回提示。
- 二进制文件默认只返回 mime、大小和建议操作需要内容时提供显式 `max_bytes` 参数。
- 对文本读取也改成流式按行读取,而不是整文件读入。
- `file_read` 在读取前检查 metadata size超过安全阈值直接拒绝。
- 二进制 inline base64 增加单独大小上限,超限只返回错误和文件信息。
- 含 NUL 字节内容按二进制处理,避免全 0 文件被 UTF-8 路径误判为文本。
- 增加大文件和大二进制文件测试。
### 中优先级HTTP 私网防护只检查字面 host未做 DNS 解析校验
### 已修复HTTP 私网防护只检查字面 host未做 DNS 解析校验
位置:
@ -188,13 +195,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
建议
已采取修复
- 请求前解析域名拒绝私网、loopback、link-local、multicast、unspecified 地址。
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
- 对 `http_request``web_fetch` 复用同一套 URL 安全策略。
- `http_request``web_fetch` 在发送请求前通过 DNS 解析 host并拒绝解析到 loopback、private、link-local、multicast、unspecified 的地址。
- IPv6 unique-local 和 link-local 地址也纳入私网判定。
- 禁用 reqwest 自动重定向,避免跳转到未校验的内网地址。
- 增加端口解析和 IPv6 私网判断测试。
### 中优先级:后台任务和主循环缺少监督与优雅关闭
### 已修复:后台任务和主循环缺少监督与优雅关闭
位置:
@ -212,13 +220,14 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
- 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
- bus channel 关闭时更像崩溃,而不是可恢复状态。
建议
已采取修复
- 引入 runtime supervisor保存 JoinHandle 并集中处理退出原因。
- 用 `CancellationToken` 贯穿 Gateway 子任务。
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
- `MessageBus::consume_inbound/consume_outbound/consume_control` 不再在 channel 关闭时 `expect()` panic改为返回 `Option<T>`
- Gateway message processor 在 inbound/control bus 关闭时记录 warning 并退出 loop。
- OutboundDispatcher 在 outbound bus 关闭时记录 warning 并退出 loop。
- 这不是完整 runtime supervisor但已消除 bus 关闭导致的 panic 崩溃路径,为后续集中 JoinHandle 管理留出接口。
### 中低优先级Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
### 已修复Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
位置:
@ -232,13 +241,13 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。
建议
已采取修复
- 使`cron_schedule.after(&from_dt).next()` 或等价 API
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
- 增加固定时间输入的单元测试,避免受系统时间影响
- cron 分支改`cron_schedule.after(&from_dt).next()`
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为计算起点。
- 增加 UTC 和 Asia/Shanghai 固定时间输入测试。
### 中低优先级:存在未接入或半接入代码,增加维护噪音
### 已修复:存在未接入或半接入代码,增加维护噪音
位置:
@ -254,10 +263,11 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
建议
已采取修复
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
- 若暂不发布:移动到设计文档或 feature branch避免主干保留死代码。
- `src/tools/pty.rs` 已接入 `tools/mod.rs`,导出 `PtyManager`/`PtyTool`
- `create_default_tools()` 默认注册共享 `PtyManager``PtyTool`
- 修复 PTY 原本因未编译暴露不出的借用问题。
## 架构评价

View File

@ -0,0 +1,496 @@
# PicoBot 跨渠道交互式消息规划
规划日期2026-06-16
## 背景
飞书交互式卡片可以让用户直接在消息卡片上点击按钮、提交选择或触发回调。这个能力很适合用于工具调用审批、快捷回复、任务确认、表单收集等 agent 交互。
参考项目调研结果:
- `reference/zeroclaw` 已实现飞书/Lark 工具审批卡片:发送 Card JSON 2.0 按钮卡片,收到 `card.action.trigger` 后解析 `approval_id``decision`,唤醒等待中的 approval future并 PATCH 原卡片为已处理状态。
- `reference/nanobot` 主要使用飞书 CardKit 做 agent 输出展示和流式更新,适合参考消息渲染体验,但没有完整的按钮回调驱动 agent 流程。
- `reference/openlark` 是 SDK/API 封装,支持发送 interactive card 和 CardKit API不包含完整 agent channel 编排。
PicoBot 当前飞书渠道已经会把普通 markdown 回复发送成 interactive card但还缺少“用户在卡片上操作 -> 统一交互事件 -> Session/Agent/Tool 流程继续”的抽象。
## 目标
1. 支持飞书交互式卡片按钮回调。
2. 设计成跨渠道能力,后续 Slack、Telegram、Discord、CLI chat 等渠道可以复用同一套交互语义。
3. 支持渠道降级:不支持按钮的渠道也能用纯文本命令完成同样操作。
4. 保持 PicoBot 现有边界Channel 只做收发和渠道适配SessionManager 管会话AgentLoop 执行 LLM 和工具。
5. 为工具调用审批、快捷回复和未来表单交互预留扩展点。
非目标:
- 本阶段不立即实现完整功能。
- 不把飞书卡片细节泄漏到 AgentLoop 或工具层。
- 不要求所有渠道同时支持原生交互组件。
## 核心原则
交互语义和渠道渲染分离。
Agent、工具或 Session 层只表达“我要一个 approval/quick reply/form interaction”。具体是飞书卡片按钮、Slack Block Kit、Telegram inline keyboard还是 CLI 里显示编号选项,由 Channel 根据能力渲染。
回调也要统一。
飞书的 `card.action.trigger`、Telegram 的 `callback_query`、Slack 的 interaction payload 都应归一化成 PicoBot 内部的 `InteractionEvent`,再交给统一的处理器。
## 数据模型
建议新增一个 `interaction` 模块,定义渠道无关的数据结构。
```rust
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum InteractionKind {
QuickReply,
Approval,
FormSubmit,
Command,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum InteractionStyle {
Default,
Primary,
Danger,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InteractionAction {
pub id: String,
pub label: String,
pub value: String,
pub style: InteractionStyle,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InteractionPayload {
pub interaction_id: String,
pub kind: InteractionKind,
pub title: Option<String>,
pub body: String,
pub actions: Vec<InteractionAction>,
pub expires_at: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct InteractionEvent {
pub channel: String,
pub chat_id: String,
pub sender_id: String,
pub interaction_id: String,
pub action_id: String,
pub action_value: String,
pub timestamp: i64,
pub metadata: std::collections::HashMap<String, String>,
}
```
`InteractionPayload` 用于 outbound 渲染,`InteractionEvent` 用于 inbound 回调。
## OutboundMessage 扩展
短期兼容方案:
- 继续使用 `OutboundMessage.metadata` 携带交互描述。
- 例如:
- `interaction.kind = "approval"`
- `interaction.id = "<uuid>"`
- `interaction.actions = "<json>"`
长期推荐方案:
```rust
pub struct OutboundMessage {
pub channel: String,
pub chat_id: String,
pub content: String,
pub reply_to: Option<String>,
pub media: Vec<MediaItem>,
pub metadata: HashMap<String, String>,
pub interaction: Option<InteractionPayload>,
}
```
推荐长期方案。它能避免把结构化交互塞进字符串 metadata也让每个 Channel 的 `send()` 更清晰。
## Channel 能力声明
`Channel` 增加可选能力声明:
```rust
#[derive(Debug, Clone, Default)]
pub struct ChannelCapabilities {
pub interactive_buttons: bool,
pub forms: bool,
pub message_update: bool,
pub markdown_cards: bool,
}
pub trait Channel {
fn capabilities(&self) -> ChannelCapabilities {
ChannelCapabilities::default()
}
}
```
渠道能力示例:
| 渠道 | 原生按钮 | 表单 | 更新原消息 | 降级策略 |
|------|----------|------|------------|----------|
| Feishu | 是interactive card | 可后续支持 | 是PATCH message/card | 文本命令 |
| Slack | 是Block Kit | 是 | 是 | 文本命令 |
| Telegram | 是inline keyboard | 有限 | 可编辑消息 | 文本命令 |
| Discord | 是components | 有限 | 可编辑消息 | 文本命令 |
| CLI chat | 否 | 否 | 局部可模拟 | 编号/命令输入 |
| Webhook/Email/SMS | 否 | 否 | 通常否 | 纯文本命令或链接 |
## 渲染策略
每个 channel 实现一个渠道内的渲染函数:
```rust
async fn send_interaction(
&self,
chat_id: &str,
payload: &InteractionPayload,
) -> Result<(), ChannelError>;
```
也可以先不改 trait`send()` 内部判断 `msg.interaction`
飞书渲染:
- 使用 Card JSON 2.0。
- `schema = "2.0"`
- body 用 markdown 展示 `payload.body`
- actions 渲染为 button。
- 每个按钮的 callback value 写入:
```json
{
"interaction_id": "...",
"action_id": "...",
"action_value": "approve"
}
```
需要兼容飞书 Card 2.0 回调路径:
- `/action/value`
- `/action/behaviors/0/value`
CLI 降级渲染:
```text
需要确认:
Tool: bash
Args: cargo test --lib
可选操作:
1. Approve
2. Deny
3. Always approve
回复:
/_interaction <interaction_id> approve
/_interaction <interaction_id> deny
/_interaction <interaction_id> always
```
纯文本渠道都可以复用这个 fallback renderer。
## Inbound 回调归一化
飞书 WebSocket 当前在 `src/channels/feishu.rs` 里处理 `im.message.receive_v1`。需要新增对 `card.action.trigger` 的识别:
1. ACK 仍要尽快发送,飞书要求 3 秒内响应。
2. 如果 event type 是 `card.action.trigger`,不要走普通消息解析。
3. 从 event payload 中解析 `interaction_id``action_id``action_value`
4. 构造 `InteractionEvent` 发布给统一处理器。
5. 对未知、过期或重复 interaction 返回成功但记录日志,不应导致渠道重连或报错。
如果短期不新增 interaction bus可以把回调转成特殊 `InboundMessage`
```text
content = "/_interaction <interaction_id> <action_value>"
metadata["event.kind"] = "interaction"
metadata["interaction.id"] = "<interaction_id>"
metadata["interaction.action_id"] = "<action_id>"
metadata["interaction.action_value"] = "<action_value>"
```
但必须由 SessionManager 或 InteractionManager 先拦截,不能把 `/_interaction` 当普通用户文本直接送进 LLM。
长期推荐新增 bus 通道:
```rust
pub enum InboundEvent {
Message(InboundMessage),
Interaction(InteractionEvent),
}
```
或者在 `MessageBus` 上增加 `interaction_tx`
## InteractionManager
建议新增 `InteractionManager`,集中管理 pending 交互状态,而不是让每个 Channel 各自维护。
职责:
- 生成 `interaction_id`
- 保存 pending interaction。
- 处理超时和过期。
- 接收 `InteractionEvent` 并解析成业务结果。
- 对重复点击、未知 interaction、过期 interaction 做幂等处理。
- 必要时通知 channel 更新原消息。
内部状态示例:
```rust
pub struct PendingInteraction {
pub id: String,
pub kind: InteractionKind,
pub channel: String,
pub chat_id: String,
pub sender_id: Option<String>,
pub session_id: Option<String>,
pub created_at: i64,
pub expires_at: Option<i64>,
pub status: InteractionStatus,
pub responder: InteractionResponder,
pub message_ref: Option<InteractionMessageRef>,
}
pub struct InteractionMessageRef {
pub channel: String,
pub chat_id: String,
pub message_id: String,
pub metadata: HashMap<String, String>,
}
```
`InteractionResponder` 可以先支持 oneshot
```rust
pub enum InteractionResponder {
Approval(tokio::sync::oneshot::Sender<ApprovalDecision>),
InboundMessage,
}
```
后续如果需要持久化长期交互oneshot 不够,需要落库。
## 工具审批流程
工具审批是第一批最适合落地的交互类型。
推荐流程:
1. AgentLoop 准备执行需要审批的工具。
2. 调用 `InteractionManager::request_approval(...)`
3. InteractionManager 创建 `InteractionPayload`,通过 outbound 发送到原 channel/chat。
4. AgentLoop 等待 oneshot带 timeout。
5. 用户在飞书卡片上点击 Approve/Deny/Always。
6. FeishuChannel 收到 `card.action.trigger`,发布 `InteractionEvent`
7. InteractionManager resolve pending approval。
8. AgentLoop 收到结果,继续执行或拒绝工具。
9. 如果 channel 支持更新消息InteractionManager 或 Channel 把原卡片更新成 resolved 状态。
审批 action 建议:
```rust
approve -> ApprovalDecision::Approve
deny -> ApprovalDecision::Deny
always -> ApprovalDecision::AlwaysApprove
```
`DenyWithEdit` 可后续支持,适合 ACP/Web/CLI 这类能输入文本的渠道。
## 快捷回复流程
快捷回复不是阻塞工具执行,而是把用户点击转成新的用户输入。
示例:
```json
{
"kind": "QuickReply",
"body": "你想继续哪个操作?",
"actions": [
{ "label": "继续分析", "value": "继续分析" },
{ "label": "生成报告", "value": "生成报告" }
]
}
```
用户点击后:
- `InteractionEvent.action_value` 转成一条普通 `InboundMessage.content`
- `sender_id``chat_id` 保留原用户和会话。
- metadata 标记来源为 interaction供审计或 UI 使用。
## 消息更新
支持原消息更新的渠道应在交互完成后更新 UI避免重复点击。
飞书:
- 发送卡片后保存 `data.message_id`
- resolve 后 PATCH `/im/v1/messages/{message_id}`
- 卡片 schema 发送和更新都使用 Card JSON 2.0,参考项目指出跨版本 PATCH 可能返回成功但客户端不重渲染。
不支持更新的渠道:
- 发送一条新消息提示“已批准/已拒绝”。
- 或仅在后台幂等拒绝重复点击。
## 安全和权限
交互回调必须校验:
- `interaction_id` 是否存在。
- 是否已过期。
- 是否已处理。
- 点击用户是否允许处理该 interaction。
- 当前 channel/chat 是否匹配。
对于工具审批,默认建议只有触发该 agent turn 的用户或允许列表用户可以审批。群聊里要特别注意 `sender_id`,不能只看 `chat_id`
日志中避免记录原始飞书回调敏感字段:
- callback token
- operator open_id/union_id/user_id/tenant_key
- open_chat_id/open_message_id
可以记录脱敏后的 payload shape用于排查飞书回调字段变化。
## 持久化策略
第一阶段可以只做内存 pending map
- 适合短时工具审批。
- 进程重启后旧按钮点击会变成 unknown/expired。
- 实现简单。
后续如果要支持长期任务或跨重启交互,需要持久化:
- `interactions` 表保存 id、kind、channel、chat_id、sender_id、status、payload、created_at、expires_at。
- `interaction_actions` 可选,或直接 JSON 存在 payload 中。
- resolve 时事务更新 status防止重复点击竞态。
## 与现有架构的关系
现有数据流:
```text
Channel -> MessageBus -> SessionManager -> AgentLoop -> tools -> SessionManager -> MessageBus -> OutboundDispatcher -> Channel
```
加入 interaction 后建议:
```text
Outbound:
AgentLoop/Tool approval -> InteractionManager -> MessageBus outbound -> OutboundDispatcher -> Channel renderer
Inbound:
Channel callback -> InteractionEvent -> InteractionManager -> pending waiter / synthetic InboundMessage
```
Channel 仍然只做渠道协议适配:
- 飞书负责 Card JSON 和 `card.action.trigger`
- Slack 负责 Block Kit 和 signing secret。
- Telegram 负责 callback query。
- CLI 负责文本命令 fallback。
InteractionManager 负责语义:
- 这是 approval 还是 quick reply。
- 是否过期。
- 是否有权限。
- 应该唤醒哪个等待者。
SessionManager/AgentLoop 不需要知道飞书卡片格式。
## 分阶段实施计划
### 阶段 1模型和 fallback
- 新增 `src/interaction/` 模块。
- 定义 `InteractionPayload``InteractionAction``InteractionEvent`
- 增加 fallback text renderer。
- 为 `OutboundMessage` 增加 `interaction: Option<InteractionPayload>`,或短期使用 metadata。
- 增加单元测试覆盖序列化和 fallback 文本。
### 阶段 2Feishu card action
- Feishu outbound 支持把 `InteractionPayload` 渲染为 Card JSON 2.0。
- Feishu inbound 在 WebSocket frame 中识别 `card.action.trigger`
- 解析 `/action/value``/action/behaviors/0/value`
- 发布统一 `InteractionEvent`
- 保存 `message_id`,支持完成后 PATCH resolved card。
- 增加 fixtures 测试真实/模拟的 `card.action.trigger` payload。
### 阶段 3InteractionManager 和审批
- 新增内存版 `InteractionManager`
- 支持 request/resolve/timeout。
- 接入工具执行前审批点。
- 支持 Approve/Deny/AlwaysApprove。
- 未知、过期、重复点击保持幂等。
- 增加 agent/tool 审批单元测试。
### 阶段 4其他渠道兼容
- CLI chat 支持 `/_interaction <id> <value>` fallback。
- 其他不支持原生按钮的渠道使用纯文本 fallback。
- 后续按需实现 Slack/Telegram/Discord 原生按钮。
### 阶段 5持久化和高级交互
- 需要时落库 pending interaction。
- 支持 quick reply 生成 synthetic inbound message。
- 支持表单提交。
- 支持 `DenyWithEdit`
- 支持长期任务交互和重启恢复。
## 测试计划
单元测试:
- `InteractionPayload` 序列化。
- fallback text renderer 输出。
- Feishu Card JSON 包含正确 callback value。
- Feishu 回调同时支持 `/action/value``/action/behaviors/0/value`
- unknown/expired interaction 不报错。
- 重复点击只 resolve 一次。
集成测试:
- 模拟 Feishu `card.action.trigger`,验证 pending approval 被唤醒。
- 模拟超时,验证默认 deny。
- CLI fallback 输入 `/_interaction`,验证能 resolve。
- 不支持按钮的 channel 能收到可读 fallback 文本。
手工验证:
- 飞书群聊点击 Approve/Deny/Always。
- 私聊点击。
- 点击后原卡片更新为 resolved。
- 重复点击不会重复执行工具。
- 非触发用户点击时按权限策略处理。
## 开放问题
1. 工具审批应该由 AgentLoop 直接调用 InteractionManager还是通过 SessionManager 代理?
2. 群聊中是否只允许原始触发者审批,还是允许配置中的所有 allowed user 审批?
3. `AlwaysApprove` 的作用域是本次会话、本 dialog、本 chat还是全局工具策略
4. 是否需要第一阶段就修改 `OutboundMessage` 结构,还是先用 metadata 降低改动面?
5. 飞书 CardKit 流式输出是否要与交互卡片统一,还是继续保持普通回复卡片和交互卡片两套路径?

View File

@ -28,6 +28,10 @@ fn build_content_blocks(
) -> Vec<ContentBlock> {
let mut blocks = Vec::new();
if !text.is_empty() {
blocks.push(ContentBlock::text(text));
}
if !media_refs.is_empty() {
for mr in media_refs {
if input_types.contains(&mr.media_type) {
@ -59,8 +63,6 @@ fn build_content_blocks(
)));
}
}
} else if !text.is_empty() {
blocks.push(ContentBlock::text(text));
}
if blocks.is_empty() {
@ -858,6 +860,23 @@ mod tests {
"calculator"
);
}
#[test]
fn test_build_content_blocks_keeps_text_with_media() {
let registry = MediaHandlerRegistry::new();
let blocks = build_content_blocks(
"先看这段文字",
&[MediaRef {
path: "missing.png".to_string(),
media_type: "image".to_string(),
}],
&[],
&registry,
);
assert!(matches!(blocks.first(), Some(ContentBlock::Text { text }) if text == "先看这段文字"));
assert!(matches!(blocks.get(1), Some(ContentBlock::Text { text }) if text.contains("用户发来了一个文件")));
}
}
#[derive(Debug)]

View File

@ -24,7 +24,10 @@ impl OutboundDispatcher {
tracing::info!("OutboundDispatcher started");
loop {
let msg = self.bus.consume_outbound().await;
let Some(msg) = self.bus.consume_outbound().await else {
tracing::warn!("OutboundDispatcher stopping because outbound bus closed");
break;
};
let channel_name = msg.channel.clone();
let channel = self.channel_manager.get_channel(&channel_name).await;

View File

@ -51,17 +51,11 @@ impl MessageBus {
}
/// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self
.inbound_rx
.lock()
.await
.recv()
.await
.expect("bus inbound closed");
pub async fn consume_inbound(&self) -> Option<InboundMessage> {
let msg = self.inbound_rx.lock().await.recv().await?;
#[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
msg
Some(msg)
}
/// Publish an outbound message (Agent -> Bus)
@ -75,13 +69,8 @@ impl MessageBus {
}
/// Consume an outbound message (Dispatcher -> Bus)
pub async fn consume_outbound(&self) -> OutboundMessage {
self.outbound_rx
.lock()
.await
.recv()
.await
.expect("bus outbound closed")
pub async fn consume_outbound(&self) -> Option<OutboundMessage> {
self.outbound_rx.lock().await.recv().await
}
/// Publish a control message (Channel -> Bus for session management)
@ -94,13 +83,8 @@ impl MessageBus {
}
/// Consume a control message (ControlProcessor -> Bus)
pub async fn consume_control(&self) -> ControlMessage {
self.control_rx
.lock()
.await
.recv()
.await
.expect("bus control closed")
pub async fn consume_control(&self) -> Option<ControlMessage> {
self.control_rx.lock().await.recv().await
}
}

View File

@ -165,7 +165,7 @@ struct ParsedMessage {
open_id: String,
chat_id: String,
content: String,
media: Option<MediaItem>,
media: Vec<MediaItem>,
/// ID of the message this message is replying to (if any).
/// Used to fetch quoted message content for display.
parent_id: Option<String>,
@ -1007,7 +1007,7 @@ impl FeishuChannel {
}
#[cfg(debug_assertions)]
if let Some(ref m) = media {
for m in &media {
tracing::debug!(media_type = %m.media_type, media_path = %m.path, "Media downloaded successfully");
}
@ -1027,7 +1027,7 @@ impl FeishuChannel {
msg_type: &str,
content: &str,
message_id: &str,
) -> Result<(String, Option<MediaItem>), ChannelError> {
) -> Result<(String, Vec<MediaItem>), ChannelError> {
let (text, media) = match msg_type {
"text" => {
let text = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
@ -1039,20 +1039,40 @@ impl FeishuChannel {
} else {
content.to_string()
};
(text, None)
(text, Vec::new())
}
"post" => {
let text = parse_post_content(content);
let mut media = Vec::new();
for image_key in collect_post_image_keys(content) {
let content_json = serde_json::json!({ "image_key": image_key });
match self
.download_media("image", &content_json, message_id)
.await
{
Ok((_text, Some(item))) => media.push(item),
Ok((_text, None)) => {}
Err(e) => {
tracing::warn!(error = %e, "Failed to download image from Feishu post message");
}
}
}
(text, media)
}
"post" => (parse_post_content(content), None),
"image" | "audio" | "file" | "media" => {
if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) {
match self
.download_media(msg_type, &content_json, message_id)
.await
{
Ok((text, media)) => (text, media),
Err(_) => (format!("[{}: content unavailable]", msg_type), None),
Ok((text, Some(media))) => (text, vec![media]),
Ok((text, None)) => (text, Vec::new()),
Err(_) => (format!("[{}: content unavailable]", msg_type), Vec::new()),
}
} else {
(format!("[{}: content unavailable]", msg_type), None)
(format!("[{}: content unavailable]", msg_type), Vec::new())
}
}
"share_chat" => {
@ -1062,9 +1082,9 @@ impl FeishuChannel {
.get("chat_id")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
(format!("[shared chat: {}]", chat_id), None)
(format!("[shared chat: {}]", chat_id), Vec::new())
} else {
("[shared chat]".to_string(), None)
("[shared chat]".to_string(), Vec::new())
}
}
"share_user" => {
@ -1074,42 +1094,44 @@ impl FeishuChannel {
.get("user_id")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
(format!("[shared user: {}]", user_id), None)
(format!("[shared user: {}]", user_id), Vec::new())
} else {
("[shared user]".to_string(), None)
("[shared user]".to_string(), Vec::new())
}
}
"interactive" => {
// Interactive card messages - extract text content
match extract_interactive_content(content) {
Ok((text, media)) => (text, media),
Ok((text, Some(media))) => (text, vec![media]),
Ok((text, None)) => (text, Vec::new()),
Err(e) => {
tracing::warn!(error = %e, "Failed to extract interactive content");
(content.to_string(), None)
(content.to_string(), Vec::new())
}
}
}
"list" => {
// List/bullet messages
match parse_list_content(content) {
Ok((text, media)) => (text, media),
Err(_) => (content.to_string(), None),
Ok((text, Some(media))) => (text, vec![media]),
Ok((text, None)) => (text, Vec::new()),
Err(_) => (content.to_string(), Vec::new()),
}
}
"merge_forward" => ("[merged forward messages]".to_string(), None),
"merge_forward" => ("[merged forward messages]".to_string(), Vec::new()),
"share_calendar_event" => {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
let event_key = parsed
.get("event_key")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
(format!("[shared calendar event: {}]", event_key), None)
(format!("[shared calendar event: {}]", event_key), Vec::new())
} else {
("[shared calendar event]".to_string(), None)
("[shared calendar event]".to_string(), Vec::new())
}
}
"system" => ("[system message]".to_string(), None),
_ => (content.to_string(), None),
"system" => ("[system message]".to_string(), Vec::new()),
_ => (content.to_string(), Vec::new()),
};
// Strip @_user_N placeholders from group chat @mentions
@ -1235,16 +1257,15 @@ impl FeishuChannel {
let channel = self.clone();
let bus = bus.clone();
tokio::spawn(async move {
let media_count = if parsed.media.is_some() { 1 } else { 0 };
#[cfg(debug_assertions)]
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %parsed.media.len(), "Publishing message to bus");
let msg = crate::bus::InboundMessage {
channel: "feishu".to_string(),
sender_id: parsed.open_id.clone(),
chat_id: parsed.chat_id.clone(),
content: parsed.content.clone(),
timestamp: crate::bus::message::current_timestamp(),
media: parsed.media.map(|m| vec![m]).unwrap_or_default(),
media: parsed.media.clone(),
metadata: std::collections::HashMap::new(),
forwarded_metadata,
};
@ -1333,6 +1354,52 @@ impl FeishuChannel {
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn collect_post_image_keys_finds_nested_images() {
let content = serde_json::json!({
"zh_cn": {
"title": "",
"content": [[
{"tag": "img", "image_key": "img_v3_001"},
{"tag": "text", "text": "这是哪里?"},
{"tag": "img", "image_key": "img_v3_002"},
{"tag": "img", "image_key": "img_v3_001"}
]]
}
})
.to_string();
assert_eq!(
collect_post_image_keys(&content),
vec!["img_v3_001".to_string(), "img_v3_002".to_string()]
);
}
#[test]
fn parse_post_content_preserves_image_positions() {
let content = serde_json::json!({
"zh_cn": {
"title": "",
"content": [[
{"tag": "text", "text": "这是一张图:"},
{"tag": "img", "image_key": "img_v3_001"},
{"tag": "text", "text": "看完继续说"}
]]
}
})
.to_string();
assert_eq!(
parse_post_content(&content),
"这是一张图:[image]看完继续说"
);
}
}
fn parse_post_content(content: &str) -> String {
/// Extract text from a single post element (text, link, at-mention).
fn extract_element(el: &serde_json::Value, out: &mut Vec<String>) {
@ -1359,6 +1426,9 @@ fn parse_post_content(content: &str) -> String {
.unwrap_or("user");
out.push(format!("@{}", name));
}
"img" => {
out.push("[image]".to_string());
}
"code_block" => {
let lang = el.get("language").and_then(|l| l.as_str()).unwrap_or("");
let code_text = el.get("text").and_then(|t| t.as_str()).unwrap_or("");
@ -1449,6 +1519,38 @@ fn parse_post_content(content: &str) -> String {
content.to_string()
}
fn collect_post_image_keys(content: &str) -> Vec<String> {
fn visit(value: &serde_json::Value, keys: &mut Vec<String>) {
match value {
serde_json::Value::Object(map) => {
if let Some(image_key) = map.get("image_key").and_then(|v| v.as_str())
&& !keys.iter().any(|k| k == image_key)
{
keys.push(image_key.to_string());
}
for child in map.values() {
visit(child, keys);
}
}
serde_json::Value::Array(items) => {
for item in items {
visit(item, keys);
}
}
_ => {}
}
}
let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) else {
return Vec::new();
};
let mut keys = Vec::new();
visit(&parsed, &mut keys);
keys
}
/// Extract text content from interactive card messages
fn extract_interactive_content(content: &str) -> Result<(String, Option<MediaItem>), ChannelError> {
let parsed = match serde_json::from_str::<serde_json::Value>(content) {

View File

@ -205,6 +205,10 @@ impl GatewayState {
tokio::select! {
// Inbound: AI message flow
inbound = bus.consume_inbound() => {
let Some(inbound) = inbound else {
tracing::warn!("Message processor stopping because inbound bus closed");
break;
};
match session_manager.handle_message(
&inbound.channel,
&inbound.sender_id,
@ -252,6 +256,10 @@ impl GatewayState {
// Control: session management operations
msg = bus.consume_control() => {
let Some(msg) = msg else {
tracing::warn!("Message processor stopping because control bus closed");
break;
};
Self::handle_control_message(&session_manager, msg).await;
}
}

View File

@ -3,7 +3,7 @@ use clap::{CommandFactory, Parser};
#[derive(Parser)]
#[command(name = "picobot")]
#[command(about = "A CLI chatbot", long_about = None)]
#[command(version = "1.1.0")]
#[command(version = "1.1.1")]
enum Command {
/// Connect to gateway
Chat {

View File

@ -150,13 +150,20 @@ struct OpenAIChoice {
message: OpenAIMessage,
}
fn null_or_missing_tool_calls<'de, D>(deserializer: D) -> Result<Vec<OpenAIToolCall>, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(Option::<Vec<OpenAIToolCall>>::deserialize(deserializer)?.unwrap_or_default())
}
#[derive(Deserialize)]
struct OpenAIMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
#[serde(default)]
#[serde(default, deserialize_with = "null_or_missing_tool_calls")]
tool_calls: Vec<OpenAIToolCall>,
}
@ -418,4 +425,42 @@ mod tests {
"{\"expression\":\"1+1\"}"
);
}
#[test]
fn test_decode_response_accepts_null_tool_calls() {
let text = r#"{
"id": "d21abaa6552741949e2aba76bde59359",
"choices": [{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "你好!",
"role": "assistant",
"tool_calls": null,
"reasoning_content": "The user sent a greeting."
}
}],
"created": 1781622889,
"model": "mimo-v2.5",
"object": "chat.completion",
"usage": {
"completion_tokens": 65,
"prompt_tokens": 11741,
"total_tokens": 11806,
"completion_tokens_details": {"reasoning_tokens": 40},
"prompt_tokens_details": {}
}
}"#;
let response: OpenAIResponse = serde_json::from_str(text).unwrap();
let message = &response.choices[0].message;
assert_eq!(message.content.as_deref(), Some("你好!"));
assert_eq!(
message.reasoning_content.as_deref(),
Some("The user sent a greeting.")
);
assert!(message.tool_calls.is_empty());
assert_eq!(response.usage.total_tokens, 11806);
}
}

View File

@ -30,11 +30,11 @@ pub fn next_run_for_schedule(schedule: &Schedule, from: i64) -> Option<i64> {
let next_utc = if let Some(tz_str) = tz {
let tz: chrono_tz::Tz = tz_str.parse().ok()?;
let _from_local = from_dt.with_timezone(&tz);
let next_local = cron_schedule.upcoming(tz).next()?;
let from_local = from_dt.with_timezone(&tz);
let next_local = cron_schedule.after(&from_local).next()?;
next_local.with_timezone(&Utc)
} else {
cron_schedule.upcoming(Utc).next()?
cron_schedule.after(&from_dt).next()?
};
Some(next_utc.timestamp_millis())
@ -311,4 +311,37 @@ mod tests {
let next_ms = next.unwrap();
assert!(next_ms > now);
}
#[test]
fn test_next_run_cron_uses_from_argument() {
let expr = "0 * * * * *".to_string();
let schedule = Schedule::Cron { expr, tz: None };
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:34:20Z")
.unwrap()
.timestamp_millis();
let next = next_run_for_schedule(&schedule, from).unwrap();
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:35:00Z")
.unwrap()
.timestamp_millis();
assert_eq!(next, expected);
}
#[test]
fn test_next_run_cron_timezone_uses_from_argument() {
let expr = "0 0 9 * * *".to_string();
let schedule = Schedule::Cron {
expr,
tz: Some("Asia/Shanghai".to_string()),
};
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T00:30:00Z")
.unwrap()
.timestamp_millis();
let next = next_run_for_schedule(&schedule, from).unwrap();
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T01:00:00Z")
.unwrap()
.timestamp_millis();
assert_eq!(next, expected);
}
}

View File

@ -8,6 +8,13 @@ use crate::mcp::get_mcp_status;
use crate::storage::{Storage, StorageError};
use std::sync::Arc as StdArc;
type MessagePersistSnapshot = (
StdArc<Storage>,
String,
crate::storage::message::MessageMeta,
crate::storage::session::SessionMeta,
);
tokio::task_local! {
static CURRENT_SOURCE_SESSION: Option<String>;
}
@ -82,6 +89,14 @@ pub struct Session {
current_cancel: Option<oneshot::Sender<()>>,
/// Monotonic counter to detect stale workers
worker_generation: u64,
/// Monotonic counter for in-memory session mutations.
///
/// Slow work such as memory recall, compression, and title generation runs
/// outside the session lock. Workers capture this version before starting
/// that work and verify it before committing results, so stale snapshots do
/// not overwrite a session that was changed by a command such as /clear or
/// /delete while the slow work was in flight.
state_version: u64,
}
/// A task to be processed by the per-session agent worker
@ -146,6 +161,7 @@ impl Session {
agent_tx: None,
current_cancel: None,
worker_generation: 0,
state_version: 0,
})
}
@ -322,6 +338,7 @@ impl Session {
agent_tx: None,
current_cancel: None,
worker_generation: 0,
state_version: 0,
})
}
@ -337,6 +354,15 @@ impl Session {
message: ChatMessage,
persist: bool,
) -> Result<(), StorageError> {
let snapshot = self.add_message_in_memory(message, persist);
persist_added_message(snapshot).await
}
fn add_message_in_memory(
&mut self,
message: ChatMessage,
persist: bool,
) -> Option<MessagePersistSnapshot> {
let is_user = message.role == "user";
let now = chrono::Utc::now().timestamp_millis();
@ -344,36 +370,37 @@ impl Session {
let seq = self.seq_counter;
self.seq_counter += 1;
// Persist to Storage
if persist && let Some(ref storage) = self.storage {
let msg_meta = crate::storage::message::MessageMeta {
id: message.id.clone(),
session_id: self.id.to_string(),
seq,
role: message.role.clone(),
content: message.content.clone(),
reasoning_content: message.reasoning_content.clone(),
media_refs: if message.media_refs.is_empty() {
None
} else {
Some(serde_json::to_string(&message.media_refs).unwrap_or_default())
},
tool_call_id: message.tool_call_id.clone(),
tool_name: message.tool_name.clone(),
tool_calls: message
.tool_calls
.as_ref()
.and_then(|tc| serde_json::to_string(tc).ok()),
source: message
.source
.as_ref()
.map(|s| serde_json::to_string(s).unwrap_or_default()),
created_at: now,
};
storage
.append_message_with_retry(&self.id.to_string(), &msg_meta)
.await?;
}
let persist_snapshot = if persist {
self.storage.clone().map(|storage| {
let msg_meta = crate::storage::message::MessageMeta {
id: message.id.clone(),
session_id: self.id.to_string(),
seq,
role: message.role.clone(),
content: message.content.clone(),
reasoning_content: message.reasoning_content.clone(),
media_refs: if message.media_refs.is_empty() {
None
} else {
Some(serde_json::to_string(&message.media_refs).unwrap_or_default())
},
tool_call_id: message.tool_call_id.clone(),
tool_name: message.tool_name.clone(),
tool_calls: message
.tool_calls
.as_ref()
.and_then(|tc| serde_json::to_string(tc).ok()),
source: message
.source
.as_ref()
.map(|s| serde_json::to_string(s).unwrap_or_default()),
created_at: now,
};
(storage, self.id.to_string(), msg_meta)
})
} else {
None
};
// Update in-memory state
self.messages.push(message);
@ -382,16 +409,30 @@ impl Session {
self.message_count += 1;
}
self.last_active_at = now;
self.state_version = self.state_version.wrapping_add(1);
// Sync message_count to Storage
if persist {
tracing::debug!(session_id = %self.id, last_active_at = %now, message_count = %self.message_count, "Persisting session meta after add_message");
if let Err(e) = self.persist_session_meta().await {
tracing::warn!("failed to persist session meta: {}", e);
}
}
Ok(())
persist_snapshot.map(|(storage, session_id, msg_meta)| {
let session_meta = crate::storage::session::SessionMeta {
id: session_id.clone(),
channel: self.id.channel.clone(),
chat_id: self.id.chat_id.clone(),
dialog_id: self.id.dialog_id.clone(),
title: self.title.clone(),
created_at: self.created_at,
last_active_at: self.last_active_at,
message_count: self.message_count,
routing_info: if self.routing_info.is_empty() {
None
} else {
Some(self.routing_info.clone())
},
archived_at: self.archived_at,
deleted_at: None,
last_consolidated_at: self.last_consolidated_at,
last_compressed_message_at: self.last_compressed_message_at,
};
(storage, session_id, msg_meta, session_meta)
})
}
/// 获取消息历史
@ -406,6 +447,7 @@ impl Session {
self.seq_counter = 1;
self.total_message_count = 0;
self.message_count = 0;
self.state_version = self.state_version.wrapping_add(1);
#[cfg(debug_assertions)]
tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared");
}
@ -417,6 +459,7 @@ impl Session {
self.seq_counter = 1;
self.total_message_count = 0;
self.message_count = 0;
self.state_version = self.state_version.wrapping_add(1);
#[cfg(debug_assertions)]
tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory");
}
@ -444,43 +487,49 @@ impl Session {
/// 将 session 元数据写回 Storage
pub async fn persist_session_meta(&self) -> Result<(), StorageError> {
if let Some(ref storage) = self.storage {
let meta = crate::storage::session::SessionMeta {
id: self.id.to_string(),
channel: self.id.channel.clone(),
chat_id: self.id.chat_id.clone(),
dialog_id: self.id.dialog_id.clone(),
title: self.title.clone(),
created_at: self.created_at,
last_active_at: self.last_active_at,
message_count: self.message_count,
routing_info: if self.routing_info.is_empty() {
None
} else {
Some(self.routing_info.clone())
},
archived_at: self.archived_at,
deleted_at: None,
last_consolidated_at: self.last_consolidated_at,
last_compressed_message_at: self.last_compressed_message_at,
};
if let Some((storage, meta)) = self.session_meta_snapshot() {
storage.upsert_session(&meta).await?;
}
Ok(())
}
fn session_meta_snapshot(
&self,
) -> Option<(StdArc<Storage>, crate::storage::session::SessionMeta)> {
let storage = self.storage.clone()?;
let meta = crate::storage::session::SessionMeta {
id: self.id.to_string(),
channel: self.id.channel.clone(),
chat_id: self.id.chat_id.clone(),
dialog_id: self.id.dialog_id.clone(),
title: self.title.clone(),
created_at: self.created_at,
last_active_at: self.last_active_at,
message_count: self.message_count,
routing_info: if self.routing_info.is_empty() {
None
} else {
Some(self.routing_info.clone())
},
archived_at: self.archived_at,
deleted_at: None,
last_consolidated_at: self.last_consolidated_at,
last_compressed_message_at: self.last_compressed_message_at,
};
Some((storage, meta))
}
/// 检查是否需要自动生成 title5 条用户消息后)
pub fn should_generate_title(&self) -> bool {
self.title == "新对话" && self.message_count >= 5
}
/// 生成标题(调用 LLM
pub async fn generate_title(&mut self) -> Result<(), AgentError> {
fn title_prompt_snapshot(&self) -> Option<String> {
if !self.should_generate_title() {
return Ok(());
return None;
}
let prompt = format!(
Some(format!(
r#"给定以下对话历史生成一个简短的会话标题5-15 个中文字符),概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
@ -492,38 +541,41 @@ impl Session {
.map(|m| format!("[{}]: {}", m.role, m.content))
.collect::<Vec<_>>()
.join("\n")
);
let title = self.call_llm_for_title(&prompt).await?;
if !title.is_empty() {
self.title = title.clone();
if let Err(e) = self.persist_session_meta().await {
tracing::warn!("failed to persist title: {}", e);
}
}
Ok(())
))
}
/// 调用 LLM 生成标题
async fn call_llm_for_title(&self, prompt: &str) -> Result<String, AgentError> {
use crate::providers::{ChatCompletionRequest, ChatCompletionResponse, Message};
fn apply_generated_title(&mut self, title: String) -> bool {
if title.is_empty() || !self.should_generate_title() {
return false;
}
let request = ChatCompletionRequest {
messages: vec![Message::user(prompt.to_string())],
temperature: Some(0.3),
max_tokens: Some(20),
tools: None,
self.title = title;
self.state_version = self.state_version.wrapping_add(1);
true
}
fn fresh_context_compressor(&self) -> ContextCompressor {
let compressor_config = ContextCompressionConfig {
protect_first_n: 2,
..Default::default()
};
let mut compressor = ContextCompressor::with_config(
self.provider.clone(),
self.provider_config.token_limit,
compressor_config,
self.memory_manager.clone(),
);
compressor.set_session_id(Some(self.id.to_string()));
compressor
}
let response: ChatCompletionResponse = self
.provider
.chat(request)
.await
.map_err(|e| AgentError::Other(format!("LLM call failed: {}", e)))?;
Ok(response.content.trim().to_string())
fn replace_history_in_memory(&mut self, messages: Vec<ChatMessage>) {
self.messages = messages;
self.seq_counter = self.messages.len() as i64 + 1;
self.total_message_count = self.messages.len() as i64;
self.message_count = self.messages.iter().filter(|m| m.role == "user").count() as i64;
self.last_active_at = chrono::Utc::now().timestamp_millis();
self.state_version = self.state_version.wrapping_add(1);
}
/// 获取 provider_config 引用
@ -1075,24 +1127,39 @@ impl SessionManager {
"compact" => {
if let Some(sid) = current_session_id {
let session = self.get_or_create_session(sid).await?;
let mut session_guard = session.lock().await;
let original_count = session_guard.get_history().len();
let history = session_guard.get_history().to_vec();
let result = session_guard.compressor.compress_if_needed(history).await?;
let (original_count, history, mut compressor, base_version) = {
let session_guard = session.lock().await;
(
session_guard.get_history().len(),
session_guard.get_history().to_vec(),
session_guard.fresh_context_compressor(),
session_guard.state_version,
)
};
let result = compressor.compress_if_needed(history).await?;
let compressed_count = result.history.len();
if result.created_timelines {
session_guard.last_compressed_message_at =
Some(chrono::Utc::now().timestamp_millis());
if let Err(e) = session_guard.persist_session_meta().await {
tracing::warn!(error = %e, "Failed to persist compression marker after /compact");
let meta_snapshot = {
let mut session_guard = session.lock().await;
if session_guard.state_version != base_version {
return Ok((
None,
"Context changed while compacting; please run /compact again."
.to_string(),
));
}
}
session_guard.clear_history();
for msg in result.history {
session_guard
.add_message(msg, false)
.await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
if result.created_timelines {
session_guard.last_compressed_message_at =
Some(chrono::Utc::now().timestamp_millis());
}
session_guard.replace_history_in_memory(result.history);
session_guard.session_meta_snapshot()
};
if let Some((storage, meta)) = meta_snapshot
&& let Err(e) = storage.upsert_session(&meta).await
{
tracing::warn!(error = %e, "Failed to persist compression marker after /compact");
}
Ok((
None,
@ -1304,16 +1371,22 @@ impl SessionManager {
let sid = current_session_id
.ok_or_else(|| AgentError::Other("no active session".to_string()))?;
let session = self.get_or_create_session(sid).await?;
let mut guard = session.lock().await;
let mut msgs: Vec<String> = Vec::new();
if guard.current_cancel.take().is_some() {
msgs.push("当前任务已发送停止信号。".to_string());
}
if guard.agent_tx.take().is_some() {
msgs.push("消息队列已清空。".to_string());
}
guard.worker_generation = guard.worker_generation.wrapping_add(1);
let msgs = {
let mut guard = session.lock().await;
let mut msgs: Vec<String> = Vec::new();
if guard.current_cancel.take().is_some() {
msgs.push("当前任务已发送停止信号。".to_string());
}
if guard.agent_tx.take().is_some() {
msgs.push("消息队列已清空。".to_string());
}
guard.worker_generation = guard.worker_generation.wrapping_add(1);
guard.state_version = guard.state_version.wrapping_add(1);
msgs
};
// Cancel all running background sub-agent tasks for this session
// after releasing the session lock.
self.sub_agent_manager
.cancel_by_session(&sid.to_string())
.await;
@ -1670,7 +1743,7 @@ impl SessionManager {
) -> Result<(), AgentError> {
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
let session = self.get_or_create_session(&unified_id).await?;
{
let persist_snapshot = {
let mut guard = session.lock().await;
let source = MessageSource {
kind: SourceKind::SystemNotification,
@ -1681,11 +1754,11 @@ impl SessionManager {
task_id: task_id.map(|s| s.to_string()),
};
let msg = ChatMessage::assistant_with_source(content, source);
guard
.add_message(msg, true)
.await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
}
guard.add_message_in_memory(msg, true)
};
persist_added_message(persist_snapshot)
.await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
let outbound = OutboundMessage {
channel: channel.to_string(),
@ -1805,6 +1878,63 @@ impl SessionManager {
}
}
async fn maybe_generate_title_outside_lock(session: Arc<Mutex<Session>>) -> Result<(), AgentError> {
use crate::providers::{ChatCompletionRequest, ChatCompletionResponse, Message};
let (provider, prompt) = {
let guard = session.lock().await;
let Some(prompt) = guard.title_prompt_snapshot() else {
return Ok(());
};
(guard.provider.clone(), prompt)
};
let request = ChatCompletionRequest {
messages: vec![Message::user(prompt)],
temperature: Some(0.3),
max_tokens: Some(20),
tools: None,
};
let response: ChatCompletionResponse = provider
.chat(request)
.await
.map_err(|e| AgentError::Other(format!("LLM call failed: {}", e)))?;
let title = response.content.trim().to_string();
let meta_snapshot = {
let mut guard = session.lock().await;
if guard.apply_generated_title(title) {
guard.session_meta_snapshot()
} else {
None
}
};
if let Some((storage, meta)) = meta_snapshot {
storage
.upsert_session(&meta)
.await
.map_err(|e| AgentError::Other(format!("failed to persist title: {}", e)))?;
}
Ok(())
}
async fn persist_added_message(
snapshot: Option<MessagePersistSnapshot>,
) -> Result<(), StorageError> {
let Some((storage, session_id, msg_meta, session_meta)) = snapshot else {
return Ok(());
};
storage
.append_message_with_retry(&session_id, &msg_meta)
.await?;
storage.upsert_session(&session_meta).await?;
Ok(())
}
fn spawn_agent_worker(
mut task_rx: mpsc::UnboundedReceiver<AgentTask>,
session: Arc<Mutex<Session>>,
@ -1845,8 +1975,12 @@ fn spawn_agent_worker(
});
}
// Phase 1: prepare data under session lock
let (agent, history_out, system_prompt_out, cancel_rx) = {
// Phase 1: capture a stable session snapshot under lock.
// Memory recall and compression happen outside this block so
// /stop and other commands are not blocked behind slow I/O or
// LLM-backed compaction.
let skills_prompt = skills_loader.build_skills_prompt();
let (agent, history_raw, mut compressor, base_version, cancel_rx) = {
let mut guard = session.lock().await;
if guard.worker_generation != worker_gen {
@ -1857,7 +1991,9 @@ fn spawn_agent_worker(
task.media.iter().map(|m| m.to_media_ref()).collect();
let user_message =
guard.create_user_message(&task.content, media_refs);
if let Err(e) = guard.add_message(user_message, true).await {
let user_persist = guard.add_message_in_memory(user_message, true);
drop(guard);
if let Err(e) = persist_added_message(user_persist).await {
tracing::error!(error = %e, "Failed to persist user message");
let err_outbound = OutboundMessage {
channel: task_chan.clone(),
@ -1871,61 +2007,12 @@ fn spawn_agent_worker(
let _ = bus.publish_outbound(err_outbound).await;
return;
}
let mut guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
let history_raw = guard.get_history().to_vec();
let skills_prompt = skills_loader.build_skills_prompt();
let memory_context = match memory_manager
.recall(
&task.content,
5,
Some(crate::memory::MemoryCategory::Knowledge),
None,
)
.await
{
Ok(entries) if !entries.is_empty() => Some(
entries
.iter()
.map(|e| format!("- {}: {}", e.key, e.content))
.collect::<Vec<_>>()
.join("\n"),
),
Err(e) => {
tracing::warn!(error = %e, "Failed to fetch memory context");
None
}
_ => None,
};
let system_prompt = guard
.build_system_prompt(&skills_prompt, memory_context.as_deref());
let result = guard
.compressor
.compress_if_needed(history_raw)
.await
.map(|r| {
if r.created_timelines {
guard.last_compressed_message_at =
Some(chrono::Utc::now().timestamp_millis());
}
r.history
})
.unwrap_or_else(|e| {
tracing::warn!(
error = %e,
"Context compression failed in worker"
);
guard.get_history().to_vec()
});
let mut history = result;
history.insert(0, ChatMessage::system(system_prompt.clone()));
let now = chrono::Utc::now().timestamp_millis();
guard.last_consolidated_at = Some(now);
let _ = guard.persist_session_meta().await;
let agent = match guard.create_agent_with_notify(notify_tx) {
Ok(a) => a,
@ -1952,9 +2039,87 @@ fn spawn_agent_worker(
}
guard.current_cancel = Some(cancel_tx);
(agent, history, system_prompt, cancel_rx)
(
agent,
history_raw,
guard.fresh_context_compressor(),
guard.state_version,
cancel_rx,
)
}; // lock released
let memory_context = match memory_manager
.recall(
&task.content,
5,
Some(crate::memory::MemoryCategory::Knowledge),
None,
)
.await
{
Ok(entries) if !entries.is_empty() => Some(
entries
.iter()
.map(|e| format!("- {}: {}", e.key, e.content))
.collect::<Vec<_>>()
.join("\n"),
),
Err(e) => {
tracing::warn!(error = %e, "Failed to fetch memory context");
None
}
_ => None,
};
let system_prompt_out = {
let guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
guard.build_system_prompt(&skills_prompt, memory_context.as_deref())
};
let compression_result = compressor.compress_if_needed(history_raw).await;
let mut history_out = match compression_result {
Ok(result) => {
let meta_snapshot = {
let mut guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
if guard.state_version != base_version {
tracing::warn!(
session_id = %guard.id,
"Session changed while preparing agent history; dropping stale task"
);
return;
}
if result.created_timelines {
guard.last_compressed_message_at =
Some(chrono::Utc::now().timestamp_millis());
}
guard.last_consolidated_at =
Some(chrono::Utc::now().timestamp_millis());
guard.session_meta_snapshot()
};
if let Some((storage, meta)) = meta_snapshot
&& let Err(e) = storage.upsert_session(&meta).await
{
tracing::warn!(error = %e, "Failed to persist session meta after compression");
}
result.history
}
Err(e) => {
tracing::warn!(error = %e, "Context compression failed in worker");
let guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
guard.get_history().to_vec()
}
};
history_out.insert(0, ChatMessage::system(system_prompt_out.clone()));
// Phase 2 + 3: LLM call with cancellation
let session2 = session.clone();
let bus2 = bus.clone();
@ -1975,8 +2140,8 @@ fn spawn_agent_worker(
Err(AgentError::LlmError(ref msg))
if is_context_overflow_error(msg) =>
{
let retry_history = {
let mut guard = session2.lock().await;
let (raw, mut retry_compressor, retry_base_version, new_window) = {
let guard = session2.lock().await;
let new_window =
crate::agent::ContextCompressor::parse_context_limit_from_error(msg)
.unwrap_or(guard.compressor_threshold());
@ -1985,31 +2150,56 @@ fn spawn_agent_worker(
error = %msg,
"Context overflow in worker — retrying"
);
(
guard.get_history().to_vec(),
guard.fresh_context_compressor(),
guard.state_version,
new_window,
)
};
retry_compressor.set_context_window(new_window);
let retry_result =
match retry_compressor.compress_if_needed(raw).await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "Retry compression failed");
let err_outbound = OutboundMessage {
channel: chan2,
chat_id: cid2,
content: "Context overflow handling failed."
.to_string(),
reply_to: None,
media: vec![],
metadata: HashMap::new(),
};
let _ = bus2.publish_outbound(err_outbound).await;
return;
}
};
let meta_snapshot = {
let mut guard = session2.lock().await;
if guard.state_version != retry_base_version {
tracing::warn!(
session_id = %guard.id,
"Session changed while retry-compressing after context overflow"
);
return;
}
guard.compressor.set_context_window(new_window);
let raw = guard.get_history().to_vec();
let retry_result =
match guard.compressor.compress_if_needed(raw).await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "Retry compression failed");
let err_outbound = OutboundMessage {
channel: chan2,
chat_id: cid2,
content: "Context overflow handling failed."
.to_string(),
reply_to: None,
media: vec![],
metadata: HashMap::new(),
};
let _ = bus2.publish_outbound(err_outbound).await;
return;
}
};
if retry_result.created_timelines {
guard.last_compressed_message_at =
Some(chrono::Utc::now().timestamp_millis());
let _ = guard.persist_session_meta().await;
}
guard.session_meta_snapshot()
};
if let Some((storage, meta)) = meta_snapshot
&& let Err(e) = storage.upsert_session(&meta).await
{
tracing::warn!(error = %e, "Failed to persist session meta after retry compression");
}
let retry_history = {
let mut retry = retry_result.history;
retry.insert(
0,
@ -2055,20 +2245,24 @@ fn spawn_agent_worker(
let response = {
let mut guard = session2.lock().await;
let mut persist_snapshots = Vec::new();
for msg in result.emitted_messages {
guard.add_message(msg, true).await.inspect_err(|e| {
tracing::error!(error = %e, "Failed to persist message")
}).ok();
}
if guard.should_generate_title()
&& let Err(e) = guard.generate_title().await
{
tracing::warn!("failed to generate title: {}", e);
persist_snapshots.push(guard.add_message_in_memory(msg, true));
}
let sent_count = guard.messages.len();
guard.compressor.set_last_api_info(sent_count, result.total_tokens);
result.final_response.content
(result.final_response.content, persist_snapshots)
};
let (response, persist_snapshots) = response;
for snapshot in persist_snapshots {
if let Err(e) = persist_added_message(snapshot).await {
tracing::error!(error = %e, "Failed to persist message");
}
}
if let Err(e) = maybe_generate_title_outside_lock(session2.clone()).await {
tracing::warn!("failed to generate title: {}", e);
}
let outbound = OutboundMessage {
channel: chan2,
@ -2158,25 +2352,34 @@ impl SessionManager {
unified_id: &UnifiedSessionId,
) -> Result<(), AgentError> {
let session = self.get_or_create_session(unified_id).await?;
let mut session_guard = session.lock().await;
// Clear in-memory
session_guard.messages.clear();
session_guard.seq_counter = 1;
session_guard.total_message_count = 0;
session_guard.message_count = 0;
session_guard.last_consolidated_at = None;
session_guard.last_compressed_message_at = None;
// Clear Storage
if let Some(ref storage) = session_guard.storage {
let (storage, session_id, meta_snapshot) = {
let mut session_guard = session.lock().await;
// Clear in-memory
session_guard.messages.clear();
session_guard.seq_counter = 1;
session_guard.total_message_count = 0;
session_guard.message_count = 0;
session_guard.last_consolidated_at = None;
session_guard.last_compressed_message_at = None;
session_guard.state_version = session_guard.state_version.wrapping_add(1);
(
session_guard.storage.clone(),
session_guard.id.to_string(),
session_guard.session_meta_snapshot(),
)
};
// Clear Storage outside the session lock.
if let Some(storage) = storage {
storage
.clear_messages(&session_guard.id.to_string())
.clear_messages(&session_id)
.await
.map_err(|e| AgentError::Other(format!("failed to clear messages: {}", e)))?;
}
session_guard
.persist_session_meta()
.await
.map_err(|e| AgentError::Other(format!("failed to persist cleared session: {}", e)))?;
if let Some((storage, meta)) = meta_snapshot {
storage.upsert_session(&meta).await.map_err(|e| {
AgentError::Other(format!("failed to persist cleared session: {}", e))
})?;
}
Ok(())
}
}
@ -2235,14 +2438,14 @@ impl OutboundMessenger for SessionManager {
};
// Write source-tagged assistant message to target session history
{
let persist_snapshot = {
let mut guard = session.lock().await;
let msg = ChatMessage::assistant_with_source(marked_content.clone(), source);
guard
.add_message(msg, true)
.await
.map_err(|e| e.to_string())?;
}
guard.add_message_in_memory(msg, true)
};
persist_added_message(persist_snapshot)
.await
.map_err(|e| e.to_string())?;
// Restore active dialog if source and target share channel:chat_id but differ in dialog_id
if let Some(ref origin_id) = origin_id {

View File

@ -4,7 +4,6 @@ use std::time::Duration;
use async_trait::async_trait;
use serde_json::json;
use tokio::io::AsyncReadExt;
use tokio::process::Command;
use tokio::time::timeout;
@ -147,71 +146,55 @@ impl Tool for BashTool {
.map(Path::new)
.unwrap_or_else(|| Path::new("."));
let result = timeout(
Duration::from_secs(timeout_secs),
self.run_command(command, cwd),
)
.await;
match result {
Ok(Ok(output)) => Ok(ToolResult {
match self.run_command(command, cwd, timeout_secs).await {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Ok(Err(e)) => Ok(ToolResult {
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
}),
}
}
}
impl BashTool {
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
async fn run_command(
&self,
command: &str,
cwd: &Path,
timeout_secs: u64,
) -> Result<String, String> {
let mut cmd = Command::new("bash");
cmd.args(["-c", command])
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(cwd);
.current_dir(cwd)
.kill_on_drop(true);
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
let mut stdout = Vec::new();
let mut stderr = Vec::new();
if let Some(ref mut out) = child.stdout {
out.read_to_end(&mut stdout)
.await
.map_err(|e| format!("Failed to read stdout: {}", e))?;
}
if let Some(ref mut err) = child.stderr {
err.read_to_end(&mut stderr)
.await
.map_err(|e| format!("Failed to read stderr: {}", e))?;
}
let status = child
.wait()
.await
.map_err(|e| format!("Failed to wait: {}", e))?;
let process_output =
match timeout(Duration::from_secs(timeout_secs), child.wait_with_output()).await {
Ok(Ok(output)) => output,
Ok(Err(e)) => return Err(format!("Failed to wait: {}", e)),
Err(_) => {
return Err(format!("Command timed out after {} seconds", timeout_secs));
}
};
let mut output = String::new();
if !stdout.is_empty() {
let stdout_str = String::from_utf8_lossy(&stdout);
if !process_output.stdout.is_empty() {
let stdout_str = String::from_utf8_lossy(&process_output.stdout);
output.push_str(&stdout_str);
}
if !stderr.is_empty() {
let stderr_str = String::from_utf8_lossy(&stderr);
if !process_output.stderr.is_empty() {
let stderr_str = String::from_utf8_lossy(&process_output.stderr);
if !stderr_str.trim().is_empty() {
if !output.is_empty() {
output.push('\n');
@ -221,7 +204,10 @@ impl BashTool {
}
}
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
output.push_str(&format!(
"\nExit code: {}",
process_output.status.code().unwrap_or(-1)
));
Ok(self.truncate_output(&output))
}
@ -309,4 +295,19 @@ mod tests {
assert!(!result.success);
assert!(result.error.unwrap().contains("timed out"));
}
#[tokio::test]
async fn test_large_stderr_does_not_deadlock() {
let tool = BashTool::new().with_timeout(5);
let result = tool
.execute(json!({
"command": "for i in $(seq 1 2000); do echo noisy-error-line >&2; done; echo done"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("done"));
assert!(result.output.contains("STDERR"));
}
}

View File

@ -6,6 +6,8 @@ use crate::tools::path_utils;
use crate::tools::traits::{Tool, ToolResult};
const MAX_CHARS: usize = 128_000;
const MAX_FILE_BYTES: u64 = 5 * 1024 * 1024;
const MAX_BINARY_BYTES: usize = 512 * 1024;
const DEFAULT_LIMIT: usize = 2000;
pub struct FileReadTool {
@ -118,6 +120,29 @@ impl Tool for FileReadTool {
});
}
let metadata = match std::fs::metadata(&resolved) {
Ok(m) => m,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to inspect file: {}", e)),
});
}
};
if metadata.len() > MAX_FILE_BYTES {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"File too large to read safely: {} bytes (max {} bytes). Use a narrower tool or inspect a smaller excerpt.",
metadata.len(),
MAX_FILE_BYTES
)),
});
}
// Read raw bytes and try multiple encodings
let bytes = match std::fs::read(&resolved) {
Ok(b) => b,
@ -209,6 +234,21 @@ impl Tool for FileReadTool {
None => {
// Truly binary file — base64 encode
use base64::{Engine, engine::general_purpose::STANDARD};
if bytes.len() > MAX_BINARY_BYTES {
let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream()
.to_string();
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Binary file too large to inline: {}, {} bytes (max {} bytes).",
mime,
bytes.len(),
MAX_BINARY_BYTES
)),
});
}
let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream()
@ -229,6 +269,10 @@ impl Tool for FileReadTool {
}
fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
if bytes.contains(&0) {
return (None, None);
}
// Try UTF-8 first
if let Ok(text) = std::str::from_utf8(bytes) {
return (Some(text.to_string()), None);
@ -337,4 +381,37 @@ mod tests {
assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file"));
}
#[tokio::test]
async fn test_rejects_large_file_before_reading() {
let mut file = NamedTempFile::new().unwrap();
file.as_file_mut()
.set_len(MAX_FILE_BYTES + 1)
.expect("set large file length");
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": file.path().to_str().unwrap() }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("too large"));
}
#[tokio::test]
async fn test_rejects_large_binary_inline() {
let mut file = NamedTempFile::new().unwrap();
let bytes = vec![0_u8; MAX_BINARY_BYTES + 1];
file.write_all(&bytes).unwrap();
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": file.path().to_str().unwrap() }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Binary file too large"));
}
}

View File

@ -3,6 +3,7 @@ use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use serde_json::json;
use tokio::net::lookup_host;
use crate::tools::traits::{Tool, ToolResult};
@ -56,6 +57,34 @@ impl HttpRequestTool {
Ok(url.to_string())
}
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
if self.allow_private_hosts {
return Ok(());
}
let host = extract_host(url)?;
if host.parse::<std::net::IpAddr>().is_ok() {
return Ok(());
}
let port = extract_port(url)?;
let addrs = lookup_host((host.as_str(), port))
.await
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
for addr in addrs {
let ip = addr.ip();
if is_private_ip(&ip) {
return Err(format!(
"Blocked host '{}' because DNS resolved to local/private IP {}",
host, ip
));
}
}
Ok(())
}
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
match method.to_uppercase().as_str() {
"GET" => Ok(reqwest::Method::GET),
@ -180,6 +209,34 @@ fn extract_host(url: &str) -> Result<String, String> {
Ok(host)
}
fn extract_port(url: &str) -> Result<u16, String> {
let scheme = if url.starts_with("https://") {
"https"
} else if url.starts_with("http://") {
"http"
} else {
return Err("Only http:// and https:// URLs are allowed".to_string());
};
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if let Some((_, port)) = authority.rsplit_once(':') {
port.parse::<u16>()
.map_err(|_| format!("Invalid URL port: {}", port))
} else if scheme == "https" {
Ok(443)
} else {
Ok(80)
}
}
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
if allowed_domains.iter().any(|domain| domain == "*") {
return true;
@ -226,7 +283,13 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
std::net::IpAddr::V6(v6) => {
v6.is_loopback()
|| v6.is_unspecified()
|| v6.is_multicast()
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
}
}
}
@ -294,6 +357,14 @@ impl Tool for HttpRequestTool {
}
};
if let Err(e) = self.validate_resolved_host(&url).await {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
let method = match self.validate_method(method_str) {
Ok(m) => m,
Err(e) => {
@ -309,6 +380,7 @@ impl Tool for HttpRequestTool {
let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs))
.redirect(reqwest::redirect::Policy::none())
.build()
{
Ok(c) => c,
@ -436,4 +508,19 @@ mod tests {
async fn test_blocks_local_tld() {
assert!(is_private_host("service.local"));
}
#[test]
fn test_extract_port_defaults_and_explicit() {
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
}
#[test]
fn test_private_ipv6_ranges_are_blocked() {
assert!(is_private_ip(&"::1".parse().unwrap()));
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
}
}

View File

@ -13,6 +13,7 @@ pub mod get_skill;
pub mod http_request;
pub mod memory;
pub mod path_utils;
pub mod pty;
pub mod registry;
pub mod schema;
pub mod send_message;
@ -32,6 +33,7 @@ pub use file_write::FileWriteTool;
pub use get_skill::GetSkillTool;
pub use http_request::HttpRequestTool;
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
pub use pty::{PtyManager, PtyTool};
pub use registry::ToolRegistry;
pub use send_message::SendMessageTool;
pub use traits::{OutboundMessenger, Tool, ToolResult};
@ -60,6 +62,7 @@ pub fn create_default_tools(
registry.register(FileSearchTool::new());
registry.register(ContentSearchTool::new());
registry.register(BashTool::new());
registry.register(PtyTool::new(Arc::new(PtyManager::new())));
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,

View File

@ -160,11 +160,14 @@ impl PtyManager {
};
for session in sessions {
let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
let child_handle = guard.child.clone();
{
let mut child_guard = child_handle.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
}
*child_guard = None;
}
*child_guard = None;
guard.status = SessionStatus::Killed;
}
}
@ -274,7 +277,7 @@ impl PtyManager {
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap();
let guard = session.lock().unwrap();
if guard.status != SessionStatus::Running {
return Err("Session is not running".to_string());
}
@ -296,12 +299,7 @@ impl PtyManager {
Ok(format!("OK, wrote {} bytes", byte_count))
}
fn read(
&self,
session_id: &str,
offset: usize,
limit: usize,
) -> Result<String, String> {
fn read(&self, session_id: &str, offset: usize, limit: usize) -> Result<String, String> {
let sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
@ -352,14 +350,16 @@ impl PtyManager {
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
let _ = child.wait();
let child_handle = guard.child.clone();
{
let mut child_guard = child_handle.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
let _ = child.wait();
}
*child_guard = None;
}
*child_guard = None;
guard.status = SessionStatus::Killed;
drop(child_guard);
drop(guard);
sessions.remove(session_id);
@ -545,14 +545,8 @@ impl Tool for PtyTool {
});
}
};
let offset = args
.get("offset")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
.unwrap_or(500) as usize;
let offset = args.get("offset").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(500) as usize;
match self.pty_manager.read(session_id, offset, limit) {
Ok(output) => Ok(ToolResult {
success: true,

View File

@ -3,6 +3,7 @@ use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use serde_json::json;
use tokio::net::lookup_host;
use crate::tools::traits::{Tool, ToolResult};
@ -60,9 +61,34 @@ impl WebFetchTool {
}
}
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
let host = extract_host(url)?;
if host.parse::<std::net::IpAddr>().is_ok() {
return Ok(());
}
let port = extract_port(url)?;
let addrs = lookup_host((host.as_str(), port))
.await
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
for addr in addrs {
let ip = addr.ip();
if is_private_ip(&ip) {
return Err(format!(
"Blocked host '{}' because DNS resolved to local/private IP {}",
host, ip
));
}
}
Ok(())
}
async fn fetch_content(&self, url: &str) -> Result<String, String> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs))
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
@ -234,6 +260,35 @@ fn extract_host(url: &str) -> Result<String, String> {
Ok(host)
}
fn extract_port(url: &str) -> Result<u16, String> {
let scheme = if url.starts_with("https://") {
"https"
} else if url.starts_with("http://") {
"http"
} else {
return Err("Only http:// and https:// URLs are allowed".to_string());
};
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if let Some((_, port)) = authority.rsplit_once(':') {
port.parse::<u16>()
.map_err(|_| format!("Invalid URL port: {}", port))
} else if scheme == "https" {
Ok(443)
} else {
Ok(80)
}
}
fn is_private_host(host: &str) -> bool {
if host == "localhost" || host.ends_with(".localhost") {
return true;
@ -248,19 +303,32 @@ fn is_private_host(host: &str) -> bool {
}
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
};
return is_private_ip(&ip);
}
false
}
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback()
|| v4.is_private()
|| v4.is_link_local()
|| v4.is_unspecified()
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback()
|| v6.is_unspecified()
|| v6.is_multicast()
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
}
}
}
#[async_trait]
impl Tool for WebFetchTool {
fn name(&self) -> &str {
@ -311,6 +379,14 @@ impl Tool for WebFetchTool {
}
};
if let Err(e) = self.validate_resolved_host(&url).await {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
match self.fetch_content(&url).await {
Ok(content) => Ok(ToolResult {
success: true,
@ -357,6 +433,21 @@ mod tests {
assert!(result.unwrap_err().contains("whitespace"));
}
#[test]
fn test_extract_port_defaults_and_explicit() {
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
}
#[test]
fn test_private_ipv6_ranges_are_blocked() {
assert!(is_private_ip(&"::1".parse().unwrap()));
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
}
#[tokio::test]
async fn test_extract_text_simple() {
let tool = test_tool();