Compare commits
3 Commits
48c8a51d9a
...
fe2bc3dfd3
| Author | SHA1 | Date | |
|---|---|---|---|
| fe2bc3dfd3 | |||
| e707774175 | |||
| ad7fa70a02 |
@ -3,7 +3,9 @@
|
|||||||
.gitignore
|
.gitignore
|
||||||
|
|
||||||
# Build artifacts
|
# Build artifacts
|
||||||
target/
|
target/*
|
||||||
|
!target/release/
|
||||||
|
target/release/*
|
||||||
!target/release/picobot
|
!target/release/picobot
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "picobot"
|
name = "picobot"
|
||||||
version = "1.1.0"
|
version = "1.1.2"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
|||||||
@ -55,11 +55,8 @@ RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
|
|||||||
&& npm cache clean --force \
|
&& npm cache clean --force \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Install himalaya (CLI email client) from local file
|
# Install himalaya (CLI email client) from the official pre-built binary release
|
||||||
COPY docker_build/himalaya.x86_64-linux.tgz /tmp/himalaya.tgz
|
RUN curl -sSL https://raw.githubusercontent.com/pimalaya/himalaya/master/install.sh | sh
|
||||||
RUN tar -xzf /tmp/himalaya.tgz -C /usr/local/bin \
|
|
||||||
&& chmod +x /usr/local/bin/himalaya \
|
|
||||||
&& rm -f /tmp/himalaya.tgz
|
|
||||||
|
|
||||||
# Install fd (alternative to find)
|
# Install fd (alternative to find)
|
||||||
RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \
|
RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \
|
||||||
|
|||||||
@ -18,7 +18,9 @@ PicoBot 的总体架构方向是清晰的:Gateway 负责装配,Channel 只
|
|||||||
|
|
||||||
- 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。
|
- 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。
|
||||||
- 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
|
- 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
|
||||||
- 待处理:工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
|
- 已修复:Session 主处理路径不再在持有 session mutex 时执行 memory recall、上下文压缩、标题 LLM 生成、消息持久化、`/stop` sub-agent 取消或清历史存储操作;慢操作改为锁外执行并用 `state_version`/`worker_generation` 防止陈旧结果覆盖当前会话。
|
||||||
|
- 已修复:Bash 超时清理、文件读取大文件限制、HTTP DNS 私网校验、Bus 关闭退出、Cron `from` 语义和 PTY 工具接入等中等级问题已完成清扫。
|
||||||
|
- 待处理:工具文件边界仍是后续质量风险。
|
||||||
|
|
||||||
## 主要发现
|
## 主要发现
|
||||||
|
|
||||||
@ -108,7 +110,7 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
|
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
|
||||||
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
|
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
|
||||||
|
|
||||||
### 中高优先级:Session 锁内执行过多异步操作
|
### 已修复:Session 锁内执行过多异步操作
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -125,13 +127,17 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
|
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
|
||||||
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
|
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 锁内只做内存状态快照和必要的状态标记。
|
- 为 `Session` 增加 `state_version`,慢操作提交前检查会话是否已被 `/stop`、清历史或其它内存变更替换。
|
||||||
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
|
- `/compact` 改为锁内取 history 快照,锁外压缩,锁内提交压缩结果,锁外持久化 meta。
|
||||||
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
|
- agent worker Phase 1 改为锁内只创建用户消息、agent、cancel handle 和 history 快照;memory recall 与 context compression 都在锁外执行。
|
||||||
|
- context overflow retry 的二次压缩移到锁外。
|
||||||
|
- 标题生成改为锁内取 prompt/provider 快照,锁外调用 LLM,锁内应用标题,锁外持久化。
|
||||||
|
- `add_message` 拆出内存更新和持久化快照,主消息路径在释放 session 锁后写入 SQLite。
|
||||||
|
- `/stop` 和清历史不再持有 session 锁等待 sub-agent 取消或 Storage 操作。
|
||||||
|
|
||||||
### 中优先级:Bash 超时不会显式终止子进程
|
### 已修复:Bash 超时不会显式终止子进程
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -146,14 +152,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
|
|
||||||
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
|
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 使用 `tokio::process::Child` 的 `kill_on_drop(true)`。
|
- Bash 一次性命令改用 `wait_with_output()`,避免 stdout/stderr 顺序读取造成 pipe 阻塞。
|
||||||
- 超时分支显式 kill child 并 wait。
|
- 子进程启用 `kill_on_drop(true)`,超时后丢弃等待 future 时会清理 child。
|
||||||
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
|
- 新增大 stderr 输出测试,覆盖不会因为 stderr pipe 填满而卡住。
|
||||||
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
|
- 持久/交互式进程通过已接入的 PTY 工具承载。
|
||||||
|
|
||||||
### 中优先级:文件读取对大二进制文件没有输出上限
|
### 已修复:文件读取对大二进制文件没有输出上限
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -168,13 +174,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
|
|
||||||
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
|
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 先检查 metadata size,超过阈值直接返回提示。
|
- `file_read` 在读取前检查 metadata size,超过安全阈值直接拒绝。
|
||||||
- 二进制文件默认只返回 mime、大小和建议操作;需要内容时提供显式 `max_bytes` 参数。
|
- 二进制 inline base64 增加单独大小上限,超限只返回错误和文件信息。
|
||||||
- 对文本读取也改成流式按行读取,而不是整文件读入。
|
- 含 NUL 字节内容按二进制处理,避免全 0 文件被 UTF-8 路径误判为文本。
|
||||||
|
- 增加大文件和大二进制文件测试。
|
||||||
|
|
||||||
### 中优先级:HTTP 私网防护只检查字面 host,未做 DNS 解析校验
|
### 已修复:HTTP 私网防护只检查字面 host,未做 DNS 解析校验
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -188,13 +195,14 @@ Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“
|
|||||||
|
|
||||||
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
|
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 请求前解析域名,拒绝私网、loopback、link-local、multicast、unspecified 地址。
|
- `http_request` 和 `web_fetch` 在发送请求前通过 DNS 解析 host,并拒绝解析到 loopback、private、link-local、multicast、unspecified 的地址。
|
||||||
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
|
- IPv6 unique-local 和 link-local 地址也纳入私网判定。
|
||||||
- 对 `http_request` 和 `web_fetch` 复用同一套 URL 安全策略。
|
- 禁用 reqwest 自动重定向,避免跳转到未校验的内网地址。
|
||||||
|
- 增加端口解析和 IPv6 私网判断测试。
|
||||||
|
|
||||||
### 中优先级:后台任务和主循环缺少监督与优雅关闭
|
### 已修复:后台任务和主循环缺少监督与优雅关闭
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -212,13 +220,14 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
|
|||||||
- 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
|
- 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
|
||||||
- bus channel 关闭时更像崩溃,而不是可恢复状态。
|
- bus channel 关闭时更像崩溃,而不是可恢复状态。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 引入 runtime supervisor,保存 JoinHandle 并集中处理退出原因。
|
- `MessageBus::consume_inbound/consume_outbound/consume_control` 不再在 channel 关闭时 `expect()` panic,改为返回 `Option<T>`。
|
||||||
- 用 `CancellationToken` 贯穿 Gateway 子任务。
|
- Gateway message processor 在 inbound/control bus 关闭时记录 warning 并退出 loop。
|
||||||
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
|
- OutboundDispatcher 在 outbound bus 关闭时记录 warning 并退出 loop。
|
||||||
|
- 这不是完整 runtime supervisor,但已消除 bus 关闭导致的 panic 崩溃路径,为后续集中 JoinHandle 管理留出接口。
|
||||||
|
|
||||||
### 中低优先级:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
|
### 已修复:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -232,13 +241,13 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
|
|||||||
|
|
||||||
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。
|
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
|
- cron 分支改用 `cron_schedule.after(&from_dt).next()`。
|
||||||
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
|
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为计算起点。
|
||||||
- 增加固定时间输入的单元测试,避免受系统时间影响。
|
- 增加 UTC 和 Asia/Shanghai 固定时间输入测试。
|
||||||
|
|
||||||
### 中低优先级:存在未接入或半接入代码,增加维护噪音
|
### 已修复:存在未接入或半接入代码,增加维护噪音
|
||||||
|
|
||||||
位置:
|
位置:
|
||||||
|
|
||||||
@ -254,10 +263,11 @@ Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHan
|
|||||||
|
|
||||||
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
|
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
|
||||||
|
|
||||||
建议:
|
已采取修复:
|
||||||
|
|
||||||
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
|
- `src/tools/pty.rs` 已接入 `tools/mod.rs`,导出 `PtyManager`/`PtyTool`。
|
||||||
- 若暂不发布:移动到设计文档或 feature branch,避免主干保留死代码。
|
- `create_default_tools()` 默认注册共享 `PtyManager` 的 `PtyTool`。
|
||||||
|
- 修复 PTY 原本因未编译暴露不出的借用问题。
|
||||||
|
|
||||||
## 架构评价
|
## 架构评价
|
||||||
|
|
||||||
|
|||||||
496
docs/INTERACTIVE_CHANNEL_DESIGN.md
Normal file
496
docs/INTERACTIVE_CHANNEL_DESIGN.md
Normal file
@ -0,0 +1,496 @@
|
|||||||
|
# PicoBot 跨渠道交互式消息规划
|
||||||
|
|
||||||
|
规划日期:2026-06-16
|
||||||
|
|
||||||
|
## 背景
|
||||||
|
|
||||||
|
飞书交互式卡片可以让用户直接在消息卡片上点击按钮、提交选择或触发回调。这个能力很适合用于工具调用审批、快捷回复、任务确认、表单收集等 agent 交互。
|
||||||
|
|
||||||
|
参考项目调研结果:
|
||||||
|
|
||||||
|
- `reference/zeroclaw` 已实现飞书/Lark 工具审批卡片:发送 Card JSON 2.0 按钮卡片,收到 `card.action.trigger` 后解析 `approval_id` 和 `decision`,唤醒等待中的 approval future,并 PATCH 原卡片为已处理状态。
|
||||||
|
- `reference/nanobot` 主要使用飞书 CardKit 做 agent 输出展示和流式更新,适合参考消息渲染体验,但没有完整的按钮回调驱动 agent 流程。
|
||||||
|
- `reference/openlark` 是 SDK/API 封装,支持发送 interactive card 和 CardKit API,不包含完整 agent channel 编排。
|
||||||
|
|
||||||
|
PicoBot 当前飞书渠道已经会把普通 markdown 回复发送成 interactive card,但还缺少“用户在卡片上操作 -> 统一交互事件 -> Session/Agent/Tool 流程继续”的抽象。
|
||||||
|
|
||||||
|
## 目标
|
||||||
|
|
||||||
|
1. 支持飞书交互式卡片按钮回调。
|
||||||
|
2. 设计成跨渠道能力,后续 Slack、Telegram、Discord、CLI chat 等渠道可以复用同一套交互语义。
|
||||||
|
3. 支持渠道降级:不支持按钮的渠道也能用纯文本命令完成同样操作。
|
||||||
|
4. 保持 PicoBot 现有边界:Channel 只做收发和渠道适配,SessionManager 管会话,AgentLoop 执行 LLM 和工具。
|
||||||
|
5. 为工具调用审批、快捷回复和未来表单交互预留扩展点。
|
||||||
|
|
||||||
|
非目标:
|
||||||
|
|
||||||
|
- 本阶段不立即实现完整功能。
|
||||||
|
- 不把飞书卡片细节泄漏到 AgentLoop 或工具层。
|
||||||
|
- 不要求所有渠道同时支持原生交互组件。
|
||||||
|
|
||||||
|
## 核心原则
|
||||||
|
|
||||||
|
交互语义和渠道渲染分离。
|
||||||
|
|
||||||
|
Agent、工具或 Session 层只表达“我要一个 approval/quick reply/form interaction”。具体是飞书卡片按钮、Slack Block Kit、Telegram inline keyboard,还是 CLI 里显示编号选项,由 Channel 根据能力渲染。
|
||||||
|
|
||||||
|
回调也要统一。
|
||||||
|
|
||||||
|
飞书的 `card.action.trigger`、Telegram 的 `callback_query`、Slack 的 interaction payload 都应归一化成 PicoBot 内部的 `InteractionEvent`,再交给统一的处理器。
|
||||||
|
|
||||||
|
## 数据模型
|
||||||
|
|
||||||
|
建议新增一个 `interaction` 模块,定义渠道无关的数据结构。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub enum InteractionKind {
|
||||||
|
QuickReply,
|
||||||
|
Approval,
|
||||||
|
FormSubmit,
|
||||||
|
Command,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub enum InteractionStyle {
|
||||||
|
Default,
|
||||||
|
Primary,
|
||||||
|
Danger,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct InteractionAction {
|
||||||
|
pub id: String,
|
||||||
|
pub label: String,
|
||||||
|
pub value: String,
|
||||||
|
pub style: InteractionStyle,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct InteractionPayload {
|
||||||
|
pub interaction_id: String,
|
||||||
|
pub kind: InteractionKind,
|
||||||
|
pub title: Option<String>,
|
||||||
|
pub body: String,
|
||||||
|
pub actions: Vec<InteractionAction>,
|
||||||
|
pub expires_at: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct InteractionEvent {
|
||||||
|
pub channel: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub sender_id: String,
|
||||||
|
pub interaction_id: String,
|
||||||
|
pub action_id: String,
|
||||||
|
pub action_value: String,
|
||||||
|
pub timestamp: i64,
|
||||||
|
pub metadata: std::collections::HashMap<String, String>,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`InteractionPayload` 用于 outbound 渲染,`InteractionEvent` 用于 inbound 回调。
|
||||||
|
|
||||||
|
## OutboundMessage 扩展
|
||||||
|
|
||||||
|
短期兼容方案:
|
||||||
|
|
||||||
|
- 继续使用 `OutboundMessage.metadata` 携带交互描述。
|
||||||
|
- 例如:
|
||||||
|
- `interaction.kind = "approval"`
|
||||||
|
- `interaction.id = "<uuid>"`
|
||||||
|
- `interaction.actions = "<json>"`
|
||||||
|
|
||||||
|
长期推荐方案:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct OutboundMessage {
|
||||||
|
pub channel: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub content: String,
|
||||||
|
pub reply_to: Option<String>,
|
||||||
|
pub media: Vec<MediaItem>,
|
||||||
|
pub metadata: HashMap<String, String>,
|
||||||
|
pub interaction: Option<InteractionPayload>,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
推荐长期方案。它能避免把结构化交互塞进字符串 metadata,也让每个 Channel 的 `send()` 更清晰。
|
||||||
|
|
||||||
|
## Channel 能力声明
|
||||||
|
|
||||||
|
给 `Channel` 增加可选能力声明:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub struct ChannelCapabilities {
|
||||||
|
pub interactive_buttons: bool,
|
||||||
|
pub forms: bool,
|
||||||
|
pub message_update: bool,
|
||||||
|
pub markdown_cards: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Channel {
|
||||||
|
fn capabilities(&self) -> ChannelCapabilities {
|
||||||
|
ChannelCapabilities::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
渠道能力示例:
|
||||||
|
|
||||||
|
| 渠道 | 原生按钮 | 表单 | 更新原消息 | 降级策略 |
|
||||||
|
|------|----------|------|------------|----------|
|
||||||
|
| Feishu | 是,interactive card | 可后续支持 | 是,PATCH message/card | 文本命令 |
|
||||||
|
| Slack | 是,Block Kit | 是 | 是 | 文本命令 |
|
||||||
|
| Telegram | 是,inline keyboard | 有限 | 可编辑消息 | 文本命令 |
|
||||||
|
| Discord | 是,components | 有限 | 可编辑消息 | 文本命令 |
|
||||||
|
| CLI chat | 否 | 否 | 局部可模拟 | 编号/命令输入 |
|
||||||
|
| Webhook/Email/SMS | 否 | 否 | 通常否 | 纯文本命令或链接 |
|
||||||
|
|
||||||
|
## 渲染策略
|
||||||
|
|
||||||
|
每个 channel 实现一个渠道内的渲染函数:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
async fn send_interaction(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
payload: &InteractionPayload,
|
||||||
|
) -> Result<(), ChannelError>;
|
||||||
|
```
|
||||||
|
|
||||||
|
也可以先不改 trait,在 `send()` 内部判断 `msg.interaction`。
|
||||||
|
|
||||||
|
飞书渲染:
|
||||||
|
|
||||||
|
- 使用 Card JSON 2.0。
|
||||||
|
- `schema = "2.0"`。
|
||||||
|
- body 用 markdown 展示 `payload.body`。
|
||||||
|
- actions 渲染为 button。
|
||||||
|
- 每个按钮的 callback value 写入:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"interaction_id": "...",
|
||||||
|
"action_id": "...",
|
||||||
|
"action_value": "approve"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
需要兼容飞书 Card 2.0 回调路径:
|
||||||
|
|
||||||
|
- `/action/value`
|
||||||
|
- `/action/behaviors/0/value`
|
||||||
|
|
||||||
|
CLI 降级渲染:
|
||||||
|
|
||||||
|
```text
|
||||||
|
需要确认:
|
||||||
|
|
||||||
|
Tool: bash
|
||||||
|
Args: cargo test --lib
|
||||||
|
|
||||||
|
可选操作:
|
||||||
|
1. Approve
|
||||||
|
2. Deny
|
||||||
|
3. Always approve
|
||||||
|
|
||||||
|
回复:
|
||||||
|
/_interaction <interaction_id> approve
|
||||||
|
/_interaction <interaction_id> deny
|
||||||
|
/_interaction <interaction_id> always
|
||||||
|
```
|
||||||
|
|
||||||
|
纯文本渠道都可以复用这个 fallback renderer。
|
||||||
|
|
||||||
|
## Inbound 回调归一化
|
||||||
|
|
||||||
|
飞书 WebSocket 当前在 `src/channels/feishu.rs` 里处理 `im.message.receive_v1`。需要新增对 `card.action.trigger` 的识别:
|
||||||
|
|
||||||
|
1. ACK 仍要尽快发送,飞书要求 3 秒内响应。
|
||||||
|
2. 如果 event type 是 `card.action.trigger`,不要走普通消息解析。
|
||||||
|
3. 从 event payload 中解析 `interaction_id`、`action_id`、`action_value`。
|
||||||
|
4. 构造 `InteractionEvent` 发布给统一处理器。
|
||||||
|
5. 对未知、过期或重复 interaction 返回成功但记录日志,不应导致渠道重连或报错。
|
||||||
|
|
||||||
|
如果短期不新增 interaction bus,可以把回调转成特殊 `InboundMessage`:
|
||||||
|
|
||||||
|
```text
|
||||||
|
content = "/_interaction <interaction_id> <action_value>"
|
||||||
|
metadata["event.kind"] = "interaction"
|
||||||
|
metadata["interaction.id"] = "<interaction_id>"
|
||||||
|
metadata["interaction.action_id"] = "<action_id>"
|
||||||
|
metadata["interaction.action_value"] = "<action_value>"
|
||||||
|
```
|
||||||
|
|
||||||
|
但必须由 SessionManager 或 InteractionManager 先拦截,不能把 `/_interaction` 当普通用户文本直接送进 LLM。
|
||||||
|
|
||||||
|
长期推荐新增 bus 通道:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub enum InboundEvent {
|
||||||
|
Message(InboundMessage),
|
||||||
|
Interaction(InteractionEvent),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
或者在 `MessageBus` 上增加 `interaction_tx`。
|
||||||
|
|
||||||
|
## InteractionManager
|
||||||
|
|
||||||
|
建议新增 `InteractionManager`,集中管理 pending 交互状态,而不是让每个 Channel 各自维护。
|
||||||
|
|
||||||
|
职责:
|
||||||
|
|
||||||
|
- 生成 `interaction_id`。
|
||||||
|
- 保存 pending interaction。
|
||||||
|
- 处理超时和过期。
|
||||||
|
- 接收 `InteractionEvent` 并解析成业务结果。
|
||||||
|
- 对重复点击、未知 interaction、过期 interaction 做幂等处理。
|
||||||
|
- 必要时通知 channel 更新原消息。
|
||||||
|
|
||||||
|
内部状态示例:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct PendingInteraction {
|
||||||
|
pub id: String,
|
||||||
|
pub kind: InteractionKind,
|
||||||
|
pub channel: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub sender_id: Option<String>,
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
pub created_at: i64,
|
||||||
|
pub expires_at: Option<i64>,
|
||||||
|
pub status: InteractionStatus,
|
||||||
|
pub responder: InteractionResponder,
|
||||||
|
pub message_ref: Option<InteractionMessageRef>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct InteractionMessageRef {
|
||||||
|
pub channel: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub message_id: String,
|
||||||
|
pub metadata: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`InteractionResponder` 可以先支持 oneshot:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub enum InteractionResponder {
|
||||||
|
Approval(tokio::sync::oneshot::Sender<ApprovalDecision>),
|
||||||
|
InboundMessage,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
后续如果需要持久化长期交互,oneshot 不够,需要落库。
|
||||||
|
|
||||||
|
## 工具审批流程
|
||||||
|
|
||||||
|
工具审批是第一批最适合落地的交互类型。
|
||||||
|
|
||||||
|
推荐流程:
|
||||||
|
|
||||||
|
1. AgentLoop 准备执行需要审批的工具。
|
||||||
|
2. 调用 `InteractionManager::request_approval(...)`。
|
||||||
|
3. InteractionManager 创建 `InteractionPayload`,通过 outbound 发送到原 channel/chat。
|
||||||
|
4. AgentLoop 等待 oneshot,带 timeout。
|
||||||
|
5. 用户在飞书卡片上点击 Approve/Deny/Always。
|
||||||
|
6. FeishuChannel 收到 `card.action.trigger`,发布 `InteractionEvent`。
|
||||||
|
7. InteractionManager resolve pending approval。
|
||||||
|
8. AgentLoop 收到结果,继续执行或拒绝工具。
|
||||||
|
9. 如果 channel 支持更新消息,InteractionManager 或 Channel 把原卡片更新成 resolved 状态。
|
||||||
|
|
||||||
|
审批 action 建议:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
approve -> ApprovalDecision::Approve
|
||||||
|
deny -> ApprovalDecision::Deny
|
||||||
|
always -> ApprovalDecision::AlwaysApprove
|
||||||
|
```
|
||||||
|
|
||||||
|
`DenyWithEdit` 可后续支持,适合 ACP/Web/CLI 这类能输入文本的渠道。
|
||||||
|
|
||||||
|
## 快捷回复流程
|
||||||
|
|
||||||
|
快捷回复不是阻塞工具执行,而是把用户点击转成新的用户输入。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"kind": "QuickReply",
|
||||||
|
"body": "你想继续哪个操作?",
|
||||||
|
"actions": [
|
||||||
|
{ "label": "继续分析", "value": "继续分析" },
|
||||||
|
{ "label": "生成报告", "value": "生成报告" }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
用户点击后:
|
||||||
|
|
||||||
|
- `InteractionEvent.action_value` 转成一条普通 `InboundMessage.content`。
|
||||||
|
- `sender_id` 和 `chat_id` 保留原用户和会话。
|
||||||
|
- metadata 标记来源为 interaction,供审计或 UI 使用。
|
||||||
|
|
||||||
|
## 消息更新
|
||||||
|
|
||||||
|
支持原消息更新的渠道应在交互完成后更新 UI,避免重复点击。
|
||||||
|
|
||||||
|
飞书:
|
||||||
|
|
||||||
|
- 发送卡片后保存 `data.message_id`。
|
||||||
|
- resolve 后 PATCH `/im/v1/messages/{message_id}`。
|
||||||
|
- 卡片 schema 发送和更新都使用 Card JSON 2.0,参考项目指出跨版本 PATCH 可能返回成功但客户端不重渲染。
|
||||||
|
|
||||||
|
不支持更新的渠道:
|
||||||
|
|
||||||
|
- 发送一条新消息提示“已批准/已拒绝”。
|
||||||
|
- 或仅在后台幂等拒绝重复点击。
|
||||||
|
|
||||||
|
## 安全和权限
|
||||||
|
|
||||||
|
交互回调必须校验:
|
||||||
|
|
||||||
|
- `interaction_id` 是否存在。
|
||||||
|
- 是否已过期。
|
||||||
|
- 是否已处理。
|
||||||
|
- 点击用户是否允许处理该 interaction。
|
||||||
|
- 当前 channel/chat 是否匹配。
|
||||||
|
|
||||||
|
对于工具审批,默认建议只有触发该 agent turn 的用户或允许列表用户可以审批。群聊里要特别注意 `sender_id`,不能只看 `chat_id`。
|
||||||
|
|
||||||
|
日志中避免记录原始飞书回调敏感字段:
|
||||||
|
|
||||||
|
- callback token
|
||||||
|
- operator open_id/union_id/user_id/tenant_key
|
||||||
|
- open_chat_id/open_message_id
|
||||||
|
|
||||||
|
可以记录脱敏后的 payload shape,用于排查飞书回调字段变化。
|
||||||
|
|
||||||
|
## 持久化策略
|
||||||
|
|
||||||
|
第一阶段可以只做内存 pending map:
|
||||||
|
|
||||||
|
- 适合短时工具审批。
|
||||||
|
- 进程重启后旧按钮点击会变成 unknown/expired。
|
||||||
|
- 实现简单。
|
||||||
|
|
||||||
|
后续如果要支持长期任务或跨重启交互,需要持久化:
|
||||||
|
|
||||||
|
- `interactions` 表保存 id、kind、channel、chat_id、sender_id、status、payload、created_at、expires_at。
|
||||||
|
- `interaction_actions` 可选,或直接 JSON 存在 payload 中。
|
||||||
|
- resolve 时事务更新 status,防止重复点击竞态。
|
||||||
|
|
||||||
|
## 与现有架构的关系
|
||||||
|
|
||||||
|
现有数据流:
|
||||||
|
|
||||||
|
```text
|
||||||
|
Channel -> MessageBus -> SessionManager -> AgentLoop -> tools -> SessionManager -> MessageBus -> OutboundDispatcher -> Channel
|
||||||
|
```
|
||||||
|
|
||||||
|
加入 interaction 后建议:
|
||||||
|
|
||||||
|
```text
|
||||||
|
Outbound:
|
||||||
|
AgentLoop/Tool approval -> InteractionManager -> MessageBus outbound -> OutboundDispatcher -> Channel renderer
|
||||||
|
|
||||||
|
Inbound:
|
||||||
|
Channel callback -> InteractionEvent -> InteractionManager -> pending waiter / synthetic InboundMessage
|
||||||
|
```
|
||||||
|
|
||||||
|
Channel 仍然只做渠道协议适配:
|
||||||
|
|
||||||
|
- 飞书负责 Card JSON 和 `card.action.trigger`。
|
||||||
|
- Slack 负责 Block Kit 和 signing secret。
|
||||||
|
- Telegram 负责 callback query。
|
||||||
|
- CLI 负责文本命令 fallback。
|
||||||
|
|
||||||
|
InteractionManager 负责语义:
|
||||||
|
|
||||||
|
- 这是 approval 还是 quick reply。
|
||||||
|
- 是否过期。
|
||||||
|
- 是否有权限。
|
||||||
|
- 应该唤醒哪个等待者。
|
||||||
|
|
||||||
|
SessionManager/AgentLoop 不需要知道飞书卡片格式。
|
||||||
|
|
||||||
|
## 分阶段实施计划
|
||||||
|
|
||||||
|
### 阶段 1:模型和 fallback
|
||||||
|
|
||||||
|
- 新增 `src/interaction/` 模块。
|
||||||
|
- 定义 `InteractionPayload`、`InteractionAction`、`InteractionEvent`。
|
||||||
|
- 增加 fallback text renderer。
|
||||||
|
- 为 `OutboundMessage` 增加 `interaction: Option<InteractionPayload>`,或短期使用 metadata。
|
||||||
|
- 增加单元测试覆盖序列化和 fallback 文本。
|
||||||
|
|
||||||
|
### 阶段 2:Feishu card action
|
||||||
|
|
||||||
|
- Feishu outbound 支持把 `InteractionPayload` 渲染为 Card JSON 2.0。
|
||||||
|
- Feishu inbound 在 WebSocket frame 中识别 `card.action.trigger`。
|
||||||
|
- 解析 `/action/value` 和 `/action/behaviors/0/value`。
|
||||||
|
- 发布统一 `InteractionEvent`。
|
||||||
|
- 保存 `message_id`,支持完成后 PATCH resolved card。
|
||||||
|
- 增加 fixtures 测试真实/模拟的 `card.action.trigger` payload。
|
||||||
|
|
||||||
|
### 阶段 3:InteractionManager 和审批
|
||||||
|
|
||||||
|
- 新增内存版 `InteractionManager`。
|
||||||
|
- 支持 request/resolve/timeout。
|
||||||
|
- 接入工具执行前审批点。
|
||||||
|
- 支持 Approve/Deny/AlwaysApprove。
|
||||||
|
- 未知、过期、重复点击保持幂等。
|
||||||
|
- 增加 agent/tool 审批单元测试。
|
||||||
|
|
||||||
|
### 阶段 4:其他渠道兼容
|
||||||
|
|
||||||
|
- CLI chat 支持 `/_interaction <id> <value>` fallback。
|
||||||
|
- 其他不支持原生按钮的渠道使用纯文本 fallback。
|
||||||
|
- 后续按需实现 Slack/Telegram/Discord 原生按钮。
|
||||||
|
|
||||||
|
### 阶段 5:持久化和高级交互
|
||||||
|
|
||||||
|
- 需要时落库 pending interaction。
|
||||||
|
- 支持 quick reply 生成 synthetic inbound message。
|
||||||
|
- 支持表单提交。
|
||||||
|
- 支持 `DenyWithEdit`。
|
||||||
|
- 支持长期任务交互和重启恢复。
|
||||||
|
|
||||||
|
## 测试计划
|
||||||
|
|
||||||
|
单元测试:
|
||||||
|
|
||||||
|
- `InteractionPayload` 序列化。
|
||||||
|
- fallback text renderer 输出。
|
||||||
|
- Feishu Card JSON 包含正确 callback value。
|
||||||
|
- Feishu 回调同时支持 `/action/value` 和 `/action/behaviors/0/value`。
|
||||||
|
- unknown/expired interaction 不报错。
|
||||||
|
- 重复点击只 resolve 一次。
|
||||||
|
|
||||||
|
集成测试:
|
||||||
|
|
||||||
|
- 模拟 Feishu `card.action.trigger`,验证 pending approval 被唤醒。
|
||||||
|
- 模拟超时,验证默认 deny。
|
||||||
|
- CLI fallback 输入 `/_interaction`,验证能 resolve。
|
||||||
|
- 不支持按钮的 channel 能收到可读 fallback 文本。
|
||||||
|
|
||||||
|
手工验证:
|
||||||
|
|
||||||
|
- 飞书群聊点击 Approve/Deny/Always。
|
||||||
|
- 私聊点击。
|
||||||
|
- 点击后原卡片更新为 resolved。
|
||||||
|
- 重复点击不会重复执行工具。
|
||||||
|
- 非触发用户点击时按权限策略处理。
|
||||||
|
|
||||||
|
## 开放问题
|
||||||
|
|
||||||
|
1. 工具审批应该由 AgentLoop 直接调用 InteractionManager,还是通过 SessionManager 代理?
|
||||||
|
2. 群聊中是否只允许原始触发者审批,还是允许配置中的所有 allowed user 审批?
|
||||||
|
3. `AlwaysApprove` 的作用域是本次会话、本 dialog、本 chat,还是全局工具策略?
|
||||||
|
4. 是否需要第一阶段就修改 `OutboundMessage` 结构,还是先用 metadata 降低改动面?
|
||||||
|
5. 飞书 CardKit 流式输出是否要与交互卡片统一,还是继续保持普通回复卡片和交互卡片两套路径?
|
||||||
|
|
||||||
@ -28,6 +28,10 @@ fn build_content_blocks(
|
|||||||
) -> Vec<ContentBlock> {
|
) -> Vec<ContentBlock> {
|
||||||
let mut blocks = Vec::new();
|
let mut blocks = Vec::new();
|
||||||
|
|
||||||
|
if !text.is_empty() {
|
||||||
|
blocks.push(ContentBlock::text(text));
|
||||||
|
}
|
||||||
|
|
||||||
if !media_refs.is_empty() {
|
if !media_refs.is_empty() {
|
||||||
for mr in media_refs {
|
for mr in media_refs {
|
||||||
if input_types.contains(&mr.media_type) {
|
if input_types.contains(&mr.media_type) {
|
||||||
@ -59,8 +63,6 @@ fn build_content_blocks(
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if !text.is_empty() {
|
|
||||||
blocks.push(ContentBlock::text(text));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if blocks.is_empty() {
|
if blocks.is_empty() {
|
||||||
@ -858,6 +860,23 @@ mod tests {
|
|||||||
"calculator"
|
"calculator"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_content_blocks_keeps_text_with_media() {
|
||||||
|
let registry = MediaHandlerRegistry::new();
|
||||||
|
let blocks = build_content_blocks(
|
||||||
|
"先看这段文字",
|
||||||
|
&[MediaRef {
|
||||||
|
path: "missing.png".to_string(),
|
||||||
|
media_type: "image".to_string(),
|
||||||
|
}],
|
||||||
|
&[],
|
||||||
|
®istry,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matches!(blocks.first(), Some(ContentBlock::Text { text }) if text == "先看这段文字"));
|
||||||
|
assert!(matches!(blocks.get(1), Some(ContentBlock::Text { text }) if text.contains("用户发来了一个文件")));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|||||||
@ -24,7 +24,10 @@ impl OutboundDispatcher {
|
|||||||
tracing::info!("OutboundDispatcher started");
|
tracing::info!("OutboundDispatcher started");
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let msg = self.bus.consume_outbound().await;
|
let Some(msg) = self.bus.consume_outbound().await else {
|
||||||
|
tracing::warn!("OutboundDispatcher stopping because outbound bus closed");
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
let channel_name = msg.channel.clone();
|
let channel_name = msg.channel.clone();
|
||||||
let channel = self.channel_manager.get_channel(&channel_name).await;
|
let channel = self.channel_manager.get_channel(&channel_name).await;
|
||||||
|
|||||||
@ -51,17 +51,11 @@ impl MessageBus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Consume an inbound message (Agent -> Bus)
|
/// Consume an inbound message (Agent -> Bus)
|
||||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
pub async fn consume_inbound(&self) -> Option<InboundMessage> {
|
||||||
let msg = self
|
let msg = self.inbound_rx.lock().await.recv().await?;
|
||||||
.inbound_rx
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.expect("bus inbound closed");
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
|
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
|
||||||
msg
|
Some(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Publish an outbound message (Agent -> Bus)
|
/// Publish an outbound message (Agent -> Bus)
|
||||||
@ -75,13 +69,8 @@ impl MessageBus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Consume an outbound message (Dispatcher -> Bus)
|
/// Consume an outbound message (Dispatcher -> Bus)
|
||||||
pub async fn consume_outbound(&self) -> OutboundMessage {
|
pub async fn consume_outbound(&self) -> Option<OutboundMessage> {
|
||||||
self.outbound_rx
|
self.outbound_rx.lock().await.recv().await
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.expect("bus outbound closed")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Publish a control message (Channel -> Bus for session management)
|
/// Publish a control message (Channel -> Bus for session management)
|
||||||
@ -94,13 +83,8 @@ impl MessageBus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Consume a control message (ControlProcessor -> Bus)
|
/// Consume a control message (ControlProcessor -> Bus)
|
||||||
pub async fn consume_control(&self) -> ControlMessage {
|
pub async fn consume_control(&self) -> Option<ControlMessage> {
|
||||||
self.control_rx
|
self.control_rx.lock().await.recv().await
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.expect("bus control closed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -165,7 +165,7 @@ struct ParsedMessage {
|
|||||||
open_id: String,
|
open_id: String,
|
||||||
chat_id: String,
|
chat_id: String,
|
||||||
content: String,
|
content: String,
|
||||||
media: Option<MediaItem>,
|
media: Vec<MediaItem>,
|
||||||
/// ID of the message this message is replying to (if any).
|
/// ID of the message this message is replying to (if any).
|
||||||
/// Used to fetch quoted message content for display.
|
/// Used to fetch quoted message content for display.
|
||||||
parent_id: Option<String>,
|
parent_id: Option<String>,
|
||||||
@ -1007,7 +1007,7 @@ impl FeishuChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
if let Some(ref m) = media {
|
for m in &media {
|
||||||
tracing::debug!(media_type = %m.media_type, media_path = %m.path, "Media downloaded successfully");
|
tracing::debug!(media_type = %m.media_type, media_path = %m.path, "Media downloaded successfully");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1027,7 +1027,7 @@ impl FeishuChannel {
|
|||||||
msg_type: &str,
|
msg_type: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
message_id: &str,
|
message_id: &str,
|
||||||
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
) -> Result<(String, Vec<MediaItem>), ChannelError> {
|
||||||
let (text, media) = match msg_type {
|
let (text, media) = match msg_type {
|
||||||
"text" => {
|
"text" => {
|
||||||
let text = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
let text = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
@ -1039,20 +1039,40 @@ impl FeishuChannel {
|
|||||||
} else {
|
} else {
|
||||||
content.to_string()
|
content.to_string()
|
||||||
};
|
};
|
||||||
(text, None)
|
(text, Vec::new())
|
||||||
|
}
|
||||||
|
"post" => {
|
||||||
|
let text = parse_post_content(content);
|
||||||
|
let mut media = Vec::new();
|
||||||
|
|
||||||
|
for image_key in collect_post_image_keys(content) {
|
||||||
|
let content_json = serde_json::json!({ "image_key": image_key });
|
||||||
|
match self
|
||||||
|
.download_media("image", &content_json, message_id)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok((_text, Some(item))) => media.push(item),
|
||||||
|
Ok((_text, None)) => {}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "Failed to download image from Feishu post message");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(text, media)
|
||||||
}
|
}
|
||||||
"post" => (parse_post_content(content), None),
|
|
||||||
"image" | "audio" | "file" | "media" => {
|
"image" | "audio" | "file" | "media" => {
|
||||||
if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) {
|
if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
match self
|
match self
|
||||||
.download_media(msg_type, &content_json, message_id)
|
.download_media(msg_type, &content_json, message_id)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok((text, media)) => (text, media),
|
Ok((text, Some(media))) => (text, vec![media]),
|
||||||
Err(_) => (format!("[{}: content unavailable]", msg_type), None),
|
Ok((text, None)) => (text, Vec::new()),
|
||||||
|
Err(_) => (format!("[{}: content unavailable]", msg_type), Vec::new()),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
(format!("[{}: content unavailable]", msg_type), None)
|
(format!("[{}: content unavailable]", msg_type), Vec::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"share_chat" => {
|
"share_chat" => {
|
||||||
@ -1062,9 +1082,9 @@ impl FeishuChannel {
|
|||||||
.get("chat_id")
|
.get("chat_id")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("unknown");
|
.unwrap_or("unknown");
|
||||||
(format!("[shared chat: {}]", chat_id), None)
|
(format!("[shared chat: {}]", chat_id), Vec::new())
|
||||||
} else {
|
} else {
|
||||||
("[shared chat]".to_string(), None)
|
("[shared chat]".to_string(), Vec::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"share_user" => {
|
"share_user" => {
|
||||||
@ -1074,42 +1094,44 @@ impl FeishuChannel {
|
|||||||
.get("user_id")
|
.get("user_id")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("unknown");
|
.unwrap_or("unknown");
|
||||||
(format!("[shared user: {}]", user_id), None)
|
(format!("[shared user: {}]", user_id), Vec::new())
|
||||||
} else {
|
} else {
|
||||||
("[shared user]".to_string(), None)
|
("[shared user]".to_string(), Vec::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"interactive" => {
|
"interactive" => {
|
||||||
// Interactive card messages - extract text content
|
// Interactive card messages - extract text content
|
||||||
match extract_interactive_content(content) {
|
match extract_interactive_content(content) {
|
||||||
Ok((text, media)) => (text, media),
|
Ok((text, Some(media))) => (text, vec![media]),
|
||||||
|
Ok((text, None)) => (text, Vec::new()),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to extract interactive content");
|
tracing::warn!(error = %e, "Failed to extract interactive content");
|
||||||
(content.to_string(), None)
|
(content.to_string(), Vec::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"list" => {
|
"list" => {
|
||||||
// List/bullet messages
|
// List/bullet messages
|
||||||
match parse_list_content(content) {
|
match parse_list_content(content) {
|
||||||
Ok((text, media)) => (text, media),
|
Ok((text, Some(media))) => (text, vec![media]),
|
||||||
Err(_) => (content.to_string(), None),
|
Ok((text, None)) => (text, Vec::new()),
|
||||||
|
Err(_) => (content.to_string(), Vec::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"merge_forward" => ("[merged forward messages]".to_string(), None),
|
"merge_forward" => ("[merged forward messages]".to_string(), Vec::new()),
|
||||||
"share_calendar_event" => {
|
"share_calendar_event" => {
|
||||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
let event_key = parsed
|
let event_key = parsed
|
||||||
.get("event_key")
|
.get("event_key")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("unknown");
|
.unwrap_or("unknown");
|
||||||
(format!("[shared calendar event: {}]", event_key), None)
|
(format!("[shared calendar event: {}]", event_key), Vec::new())
|
||||||
} else {
|
} else {
|
||||||
("[shared calendar event]".to_string(), None)
|
("[shared calendar event]".to_string(), Vec::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"system" => ("[system message]".to_string(), None),
|
"system" => ("[system message]".to_string(), Vec::new()),
|
||||||
_ => (content.to_string(), None),
|
_ => (content.to_string(), Vec::new()),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Strip @_user_N placeholders from group chat @mentions
|
// Strip @_user_N placeholders from group chat @mentions
|
||||||
@ -1235,16 +1257,15 @@ impl FeishuChannel {
|
|||||||
let channel = self.clone();
|
let channel = self.clone();
|
||||||
let bus = bus.clone();
|
let bus = bus.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let media_count = if parsed.media.is_some() { 1 } else { 0 };
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
|
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %parsed.media.len(), "Publishing message to bus");
|
||||||
let msg = crate::bus::InboundMessage {
|
let msg = crate::bus::InboundMessage {
|
||||||
channel: "feishu".to_string(),
|
channel: "feishu".to_string(),
|
||||||
sender_id: parsed.open_id.clone(),
|
sender_id: parsed.open_id.clone(),
|
||||||
chat_id: parsed.chat_id.clone(),
|
chat_id: parsed.chat_id.clone(),
|
||||||
content: parsed.content.clone(),
|
content: parsed.content.clone(),
|
||||||
timestamp: crate::bus::message::current_timestamp(),
|
timestamp: crate::bus::message::current_timestamp(),
|
||||||
media: parsed.media.map(|m| vec![m]).unwrap_or_default(),
|
media: parsed.media.clone(),
|
||||||
metadata: std::collections::HashMap::new(),
|
metadata: std::collections::HashMap::new(),
|
||||||
forwarded_metadata,
|
forwarded_metadata,
|
||||||
};
|
};
|
||||||
@ -1333,6 +1354,52 @@ impl FeishuChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn collect_post_image_keys_finds_nested_images() {
|
||||||
|
let content = serde_json::json!({
|
||||||
|
"zh_cn": {
|
||||||
|
"title": "",
|
||||||
|
"content": [[
|
||||||
|
{"tag": "img", "image_key": "img_v3_001"},
|
||||||
|
{"tag": "text", "text": "这是哪里?"},
|
||||||
|
{"tag": "img", "image_key": "img_v3_002"},
|
||||||
|
{"tag": "img", "image_key": "img_v3_001"}
|
||||||
|
]]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
collect_post_image_keys(&content),
|
||||||
|
vec!["img_v3_001".to_string(), "img_v3_002".to_string()]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_post_content_preserves_image_positions() {
|
||||||
|
let content = serde_json::json!({
|
||||||
|
"zh_cn": {
|
||||||
|
"title": "",
|
||||||
|
"content": [[
|
||||||
|
{"tag": "text", "text": "这是一张图:"},
|
||||||
|
{"tag": "img", "image_key": "img_v3_001"},
|
||||||
|
{"tag": "text", "text": "看完继续说"}
|
||||||
|
]]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
parse_post_content(&content),
|
||||||
|
"这是一张图:[image]看完继续说"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_post_content(content: &str) -> String {
|
fn parse_post_content(content: &str) -> String {
|
||||||
/// Extract text from a single post element (text, link, at-mention).
|
/// Extract text from a single post element (text, link, at-mention).
|
||||||
fn extract_element(el: &serde_json::Value, out: &mut Vec<String>) {
|
fn extract_element(el: &serde_json::Value, out: &mut Vec<String>) {
|
||||||
@ -1359,6 +1426,9 @@ fn parse_post_content(content: &str) -> String {
|
|||||||
.unwrap_or("user");
|
.unwrap_or("user");
|
||||||
out.push(format!("@{}", name));
|
out.push(format!("@{}", name));
|
||||||
}
|
}
|
||||||
|
"img" => {
|
||||||
|
out.push("[image]".to_string());
|
||||||
|
}
|
||||||
"code_block" => {
|
"code_block" => {
|
||||||
let lang = el.get("language").and_then(|l| l.as_str()).unwrap_or("");
|
let lang = el.get("language").and_then(|l| l.as_str()).unwrap_or("");
|
||||||
let code_text = el.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
let code_text = el.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||||
@ -1449,6 +1519,38 @@ fn parse_post_content(content: &str) -> String {
|
|||||||
content.to_string()
|
content.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn collect_post_image_keys(content: &str) -> Vec<String> {
|
||||||
|
fn visit(value: &serde_json::Value, keys: &mut Vec<String>) {
|
||||||
|
match value {
|
||||||
|
serde_json::Value::Object(map) => {
|
||||||
|
if let Some(image_key) = map.get("image_key").and_then(|v| v.as_str())
|
||||||
|
&& !keys.iter().any(|k| k == image_key)
|
||||||
|
{
|
||||||
|
keys.push(image_key.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
for child in map.values() {
|
||||||
|
visit(child, keys);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
serde_json::Value::Array(items) => {
|
||||||
|
for item in items {
|
||||||
|
visit(item, keys);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) else {
|
||||||
|
return Vec::new();
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut keys = Vec::new();
|
||||||
|
visit(&parsed, &mut keys);
|
||||||
|
keys
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract text content from interactive card messages
|
/// Extract text content from interactive card messages
|
||||||
fn extract_interactive_content(content: &str) -> Result<(String, Option<MediaItem>), ChannelError> {
|
fn extract_interactive_content(content: &str) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||||
let parsed = match serde_json::from_str::<serde_json::Value>(content) {
|
let parsed = match serde_json::from_str::<serde_json::Value>(content) {
|
||||||
|
|||||||
@ -205,6 +205,10 @@ impl GatewayState {
|
|||||||
tokio::select! {
|
tokio::select! {
|
||||||
// Inbound: AI message flow
|
// Inbound: AI message flow
|
||||||
inbound = bus.consume_inbound() => {
|
inbound = bus.consume_inbound() => {
|
||||||
|
let Some(inbound) = inbound else {
|
||||||
|
tracing::warn!("Message processor stopping because inbound bus closed");
|
||||||
|
break;
|
||||||
|
};
|
||||||
match session_manager.handle_message(
|
match session_manager.handle_message(
|
||||||
&inbound.channel,
|
&inbound.channel,
|
||||||
&inbound.sender_id,
|
&inbound.sender_id,
|
||||||
@ -252,6 +256,10 @@ impl GatewayState {
|
|||||||
|
|
||||||
// Control: session management operations
|
// Control: session management operations
|
||||||
msg = bus.consume_control() => {
|
msg = bus.consume_control() => {
|
||||||
|
let Some(msg) = msg else {
|
||||||
|
tracing::warn!("Message processor stopping because control bus closed");
|
||||||
|
break;
|
||||||
|
};
|
||||||
Self::handle_control_message(&session_manager, msg).await;
|
Self::handle_control_message(&session_manager, msg).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,7 +3,7 @@ use clap::{CommandFactory, Parser};
|
|||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "picobot")]
|
#[command(name = "picobot")]
|
||||||
#[command(about = "A CLI chatbot", long_about = None)]
|
#[command(about = "A CLI chatbot", long_about = None)]
|
||||||
#[command(version = "1.1.0")]
|
#[command(version = "1.1.1")]
|
||||||
enum Command {
|
enum Command {
|
||||||
/// Connect to gateway
|
/// Connect to gateway
|
||||||
Chat {
|
Chat {
|
||||||
|
|||||||
@ -150,13 +150,20 @@ struct OpenAIChoice {
|
|||||||
message: OpenAIMessage,
|
message: OpenAIMessage,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn null_or_missing_tool_calls<'de, D>(deserializer: D) -> Result<Vec<OpenAIToolCall>, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
Ok(Option::<Vec<OpenAIToolCall>>::deserialize(deserializer)?.unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct OpenAIMessage {
|
struct OpenAIMessage {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
reasoning_content: Option<String>,
|
reasoning_content: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default, deserialize_with = "null_or_missing_tool_calls")]
|
||||||
tool_calls: Vec<OpenAIToolCall>,
|
tool_calls: Vec<OpenAIToolCall>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -418,4 +425,42 @@ mod tests {
|
|||||||
"{\"expression\":\"1+1\"}"
|
"{\"expression\":\"1+1\"}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_response_accepts_null_tool_calls() {
|
||||||
|
let text = r#"{
|
||||||
|
"id": "d21abaa6552741949e2aba76bde59359",
|
||||||
|
"choices": [{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"content": "你好!",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null,
|
||||||
|
"reasoning_content": "The user sent a greeting."
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"created": 1781622889,
|
||||||
|
"model": "mimo-v2.5",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 65,
|
||||||
|
"prompt_tokens": 11741,
|
||||||
|
"total_tokens": 11806,
|
||||||
|
"completion_tokens_details": {"reasoning_tokens": 40},
|
||||||
|
"prompt_tokens_details": {}
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
|
||||||
|
let response: OpenAIResponse = serde_json::from_str(text).unwrap();
|
||||||
|
let message = &response.choices[0].message;
|
||||||
|
|
||||||
|
assert_eq!(message.content.as_deref(), Some("你好!"));
|
||||||
|
assert_eq!(
|
||||||
|
message.reasoning_content.as_deref(),
|
||||||
|
Some("The user sent a greeting.")
|
||||||
|
);
|
||||||
|
assert!(message.tool_calls.is_empty());
|
||||||
|
assert_eq!(response.usage.total_tokens, 11806);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,11 +30,11 @@ pub fn next_run_for_schedule(schedule: &Schedule, from: i64) -> Option<i64> {
|
|||||||
|
|
||||||
let next_utc = if let Some(tz_str) = tz {
|
let next_utc = if let Some(tz_str) = tz {
|
||||||
let tz: chrono_tz::Tz = tz_str.parse().ok()?;
|
let tz: chrono_tz::Tz = tz_str.parse().ok()?;
|
||||||
let _from_local = from_dt.with_timezone(&tz);
|
let from_local = from_dt.with_timezone(&tz);
|
||||||
let next_local = cron_schedule.upcoming(tz).next()?;
|
let next_local = cron_schedule.after(&from_local).next()?;
|
||||||
next_local.with_timezone(&Utc)
|
next_local.with_timezone(&Utc)
|
||||||
} else {
|
} else {
|
||||||
cron_schedule.upcoming(Utc).next()?
|
cron_schedule.after(&from_dt).next()?
|
||||||
};
|
};
|
||||||
|
|
||||||
Some(next_utc.timestamp_millis())
|
Some(next_utc.timestamp_millis())
|
||||||
@ -311,4 +311,37 @@ mod tests {
|
|||||||
let next_ms = next.unwrap();
|
let next_ms = next.unwrap();
|
||||||
assert!(next_ms > now);
|
assert!(next_ms > now);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_cron_uses_from_argument() {
|
||||||
|
let expr = "0 * * * * *".to_string();
|
||||||
|
let schedule = Schedule::Cron { expr, tz: None };
|
||||||
|
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:34:20Z")
|
||||||
|
.unwrap()
|
||||||
|
.timestamp_millis();
|
||||||
|
|
||||||
|
let next = next_run_for_schedule(&schedule, from).unwrap();
|
||||||
|
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T12:35:00Z")
|
||||||
|
.unwrap()
|
||||||
|
.timestamp_millis();
|
||||||
|
assert_eq!(next, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_next_run_cron_timezone_uses_from_argument() {
|
||||||
|
let expr = "0 0 9 * * *".to_string();
|
||||||
|
let schedule = Schedule::Cron {
|
||||||
|
expr,
|
||||||
|
tz: Some("Asia/Shanghai".to_string()),
|
||||||
|
};
|
||||||
|
let from = chrono::DateTime::parse_from_rfc3339("2026-06-16T00:30:00Z")
|
||||||
|
.unwrap()
|
||||||
|
.timestamp_millis();
|
||||||
|
|
||||||
|
let next = next_run_for_schedule(&schedule, from).unwrap();
|
||||||
|
let expected = chrono::DateTime::parse_from_rfc3339("2026-06-16T01:00:00Z")
|
||||||
|
.unwrap()
|
||||||
|
.timestamp_millis();
|
||||||
|
assert_eq!(next, expected);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,6 +8,13 @@ use crate::mcp::get_mcp_status;
|
|||||||
use crate::storage::{Storage, StorageError};
|
use crate::storage::{Storage, StorageError};
|
||||||
use std::sync::Arc as StdArc;
|
use std::sync::Arc as StdArc;
|
||||||
|
|
||||||
|
type MessagePersistSnapshot = (
|
||||||
|
StdArc<Storage>,
|
||||||
|
String,
|
||||||
|
crate::storage::message::MessageMeta,
|
||||||
|
crate::storage::session::SessionMeta,
|
||||||
|
);
|
||||||
|
|
||||||
tokio::task_local! {
|
tokio::task_local! {
|
||||||
static CURRENT_SOURCE_SESSION: Option<String>;
|
static CURRENT_SOURCE_SESSION: Option<String>;
|
||||||
}
|
}
|
||||||
@ -82,6 +89,14 @@ pub struct Session {
|
|||||||
current_cancel: Option<oneshot::Sender<()>>,
|
current_cancel: Option<oneshot::Sender<()>>,
|
||||||
/// Monotonic counter to detect stale workers
|
/// Monotonic counter to detect stale workers
|
||||||
worker_generation: u64,
|
worker_generation: u64,
|
||||||
|
/// Monotonic counter for in-memory session mutations.
|
||||||
|
///
|
||||||
|
/// Slow work such as memory recall, compression, and title generation runs
|
||||||
|
/// outside the session lock. Workers capture this version before starting
|
||||||
|
/// that work and verify it before committing results, so stale snapshots do
|
||||||
|
/// not overwrite a session that was changed by a command such as /clear or
|
||||||
|
/// /delete while the slow work was in flight.
|
||||||
|
state_version: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A task to be processed by the per-session agent worker
|
/// A task to be processed by the per-session agent worker
|
||||||
@ -146,6 +161,7 @@ impl Session {
|
|||||||
agent_tx: None,
|
agent_tx: None,
|
||||||
current_cancel: None,
|
current_cancel: None,
|
||||||
worker_generation: 0,
|
worker_generation: 0,
|
||||||
|
state_version: 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -322,6 +338,7 @@ impl Session {
|
|||||||
agent_tx: None,
|
agent_tx: None,
|
||||||
current_cancel: None,
|
current_cancel: None,
|
||||||
worker_generation: 0,
|
worker_generation: 0,
|
||||||
|
state_version: 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -337,6 +354,15 @@ impl Session {
|
|||||||
message: ChatMessage,
|
message: ChatMessage,
|
||||||
persist: bool,
|
persist: bool,
|
||||||
) -> Result<(), StorageError> {
|
) -> Result<(), StorageError> {
|
||||||
|
let snapshot = self.add_message_in_memory(message, persist);
|
||||||
|
persist_added_message(snapshot).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_message_in_memory(
|
||||||
|
&mut self,
|
||||||
|
message: ChatMessage,
|
||||||
|
persist: bool,
|
||||||
|
) -> Option<MessagePersistSnapshot> {
|
||||||
let is_user = message.role == "user";
|
let is_user = message.role == "user";
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
|
||||||
@ -344,36 +370,37 @@ impl Session {
|
|||||||
let seq = self.seq_counter;
|
let seq = self.seq_counter;
|
||||||
self.seq_counter += 1;
|
self.seq_counter += 1;
|
||||||
|
|
||||||
// Persist to Storage
|
let persist_snapshot = if persist {
|
||||||
if persist && let Some(ref storage) = self.storage {
|
self.storage.clone().map(|storage| {
|
||||||
let msg_meta = crate::storage::message::MessageMeta {
|
let msg_meta = crate::storage::message::MessageMeta {
|
||||||
id: message.id.clone(),
|
id: message.id.clone(),
|
||||||
session_id: self.id.to_string(),
|
session_id: self.id.to_string(),
|
||||||
seq,
|
seq,
|
||||||
role: message.role.clone(),
|
role: message.role.clone(),
|
||||||
content: message.content.clone(),
|
content: message.content.clone(),
|
||||||
reasoning_content: message.reasoning_content.clone(),
|
reasoning_content: message.reasoning_content.clone(),
|
||||||
media_refs: if message.media_refs.is_empty() {
|
media_refs: if message.media_refs.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(serde_json::to_string(&message.media_refs).unwrap_or_default())
|
Some(serde_json::to_string(&message.media_refs).unwrap_or_default())
|
||||||
},
|
},
|
||||||
tool_call_id: message.tool_call_id.clone(),
|
tool_call_id: message.tool_call_id.clone(),
|
||||||
tool_name: message.tool_name.clone(),
|
tool_name: message.tool_name.clone(),
|
||||||
tool_calls: message
|
tool_calls: message
|
||||||
.tool_calls
|
.tool_calls
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|tc| serde_json::to_string(tc).ok()),
|
.and_then(|tc| serde_json::to_string(tc).ok()),
|
||||||
source: message
|
source: message
|
||||||
.source
|
.source
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|s| serde_json::to_string(s).unwrap_or_default()),
|
.map(|s| serde_json::to_string(s).unwrap_or_default()),
|
||||||
created_at: now,
|
created_at: now,
|
||||||
};
|
};
|
||||||
storage
|
(storage, self.id.to_string(), msg_meta)
|
||||||
.append_message_with_retry(&self.id.to_string(), &msg_meta)
|
})
|
||||||
.await?;
|
} else {
|
||||||
}
|
None
|
||||||
|
};
|
||||||
|
|
||||||
// Update in-memory state
|
// Update in-memory state
|
||||||
self.messages.push(message);
|
self.messages.push(message);
|
||||||
@ -382,16 +409,30 @@ impl Session {
|
|||||||
self.message_count += 1;
|
self.message_count += 1;
|
||||||
}
|
}
|
||||||
self.last_active_at = now;
|
self.last_active_at = now;
|
||||||
|
self.state_version = self.state_version.wrapping_add(1);
|
||||||
|
|
||||||
// Sync message_count to Storage
|
persist_snapshot.map(|(storage, session_id, msg_meta)| {
|
||||||
if persist {
|
let session_meta = crate::storage::session::SessionMeta {
|
||||||
tracing::debug!(session_id = %self.id, last_active_at = %now, message_count = %self.message_count, "Persisting session meta after add_message");
|
id: session_id.clone(),
|
||||||
if let Err(e) = self.persist_session_meta().await {
|
channel: self.id.channel.clone(),
|
||||||
tracing::warn!("failed to persist session meta: {}", e);
|
chat_id: self.id.chat_id.clone(),
|
||||||
}
|
dialog_id: self.id.dialog_id.clone(),
|
||||||
}
|
title: self.title.clone(),
|
||||||
|
created_at: self.created_at,
|
||||||
Ok(())
|
last_active_at: self.last_active_at,
|
||||||
|
message_count: self.message_count,
|
||||||
|
routing_info: if self.routing_info.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.routing_info.clone())
|
||||||
|
},
|
||||||
|
archived_at: self.archived_at,
|
||||||
|
deleted_at: None,
|
||||||
|
last_consolidated_at: self.last_consolidated_at,
|
||||||
|
last_compressed_message_at: self.last_compressed_message_at,
|
||||||
|
};
|
||||||
|
(storage, session_id, msg_meta, session_meta)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取消息历史
|
/// 获取消息历史
|
||||||
@ -406,6 +447,7 @@ impl Session {
|
|||||||
self.seq_counter = 1;
|
self.seq_counter = 1;
|
||||||
self.total_message_count = 0;
|
self.total_message_count = 0;
|
||||||
self.message_count = 0;
|
self.message_count = 0;
|
||||||
|
self.state_version = self.state_version.wrapping_add(1);
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared");
|
tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared");
|
||||||
}
|
}
|
||||||
@ -417,6 +459,7 @@ impl Session {
|
|||||||
self.seq_counter = 1;
|
self.seq_counter = 1;
|
||||||
self.total_message_count = 0;
|
self.total_message_count = 0;
|
||||||
self.message_count = 0;
|
self.message_count = 0;
|
||||||
|
self.state_version = self.state_version.wrapping_add(1);
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory");
|
tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory");
|
||||||
}
|
}
|
||||||
@ -444,43 +487,49 @@ impl Session {
|
|||||||
|
|
||||||
/// 将 session 元数据写回 Storage
|
/// 将 session 元数据写回 Storage
|
||||||
pub async fn persist_session_meta(&self) -> Result<(), StorageError> {
|
pub async fn persist_session_meta(&self) -> Result<(), StorageError> {
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some((storage, meta)) = self.session_meta_snapshot() {
|
||||||
let meta = crate::storage::session::SessionMeta {
|
|
||||||
id: self.id.to_string(),
|
|
||||||
channel: self.id.channel.clone(),
|
|
||||||
chat_id: self.id.chat_id.clone(),
|
|
||||||
dialog_id: self.id.dialog_id.clone(),
|
|
||||||
title: self.title.clone(),
|
|
||||||
created_at: self.created_at,
|
|
||||||
last_active_at: self.last_active_at,
|
|
||||||
message_count: self.message_count,
|
|
||||||
routing_info: if self.routing_info.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(self.routing_info.clone())
|
|
||||||
},
|
|
||||||
archived_at: self.archived_at,
|
|
||||||
deleted_at: None,
|
|
||||||
last_consolidated_at: self.last_consolidated_at,
|
|
||||||
last_compressed_message_at: self.last_compressed_message_at,
|
|
||||||
};
|
|
||||||
storage.upsert_session(&meta).await?;
|
storage.upsert_session(&meta).await?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn session_meta_snapshot(
|
||||||
|
&self,
|
||||||
|
) -> Option<(StdArc<Storage>, crate::storage::session::SessionMeta)> {
|
||||||
|
let storage = self.storage.clone()?;
|
||||||
|
let meta = crate::storage::session::SessionMeta {
|
||||||
|
id: self.id.to_string(),
|
||||||
|
channel: self.id.channel.clone(),
|
||||||
|
chat_id: self.id.chat_id.clone(),
|
||||||
|
dialog_id: self.id.dialog_id.clone(),
|
||||||
|
title: self.title.clone(),
|
||||||
|
created_at: self.created_at,
|
||||||
|
last_active_at: self.last_active_at,
|
||||||
|
message_count: self.message_count,
|
||||||
|
routing_info: if self.routing_info.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.routing_info.clone())
|
||||||
|
},
|
||||||
|
archived_at: self.archived_at,
|
||||||
|
deleted_at: None,
|
||||||
|
last_consolidated_at: self.last_consolidated_at,
|
||||||
|
last_compressed_message_at: self.last_compressed_message_at,
|
||||||
|
};
|
||||||
|
Some((storage, meta))
|
||||||
|
}
|
||||||
|
|
||||||
/// 检查是否需要自动生成 title(5 条用户消息后)
|
/// 检查是否需要自动生成 title(5 条用户消息后)
|
||||||
pub fn should_generate_title(&self) -> bool {
|
pub fn should_generate_title(&self) -> bool {
|
||||||
self.title == "新对话" && self.message_count >= 5
|
self.title == "新对话" && self.message_count >= 5
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 生成标题(调用 LLM)
|
fn title_prompt_snapshot(&self) -> Option<String> {
|
||||||
pub async fn generate_title(&mut self) -> Result<(), AgentError> {
|
|
||||||
if !self.should_generate_title() {
|
if !self.should_generate_title() {
|
||||||
return Ok(());
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let prompt = format!(
|
Some(format!(
|
||||||
r#"给定以下对话历史,生成一个简短的会话标题(5-15 个中文字符),概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
|
r#"给定以下对话历史,生成一个简短的会话标题(5-15 个中文字符),概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
|
||||||
|
|
||||||
历史:
|
历史:
|
||||||
@ -492,38 +541,41 @@ impl Session {
|
|||||||
.map(|m| format!("[{}]: {}", m.role, m.content))
|
.map(|m| format!("[{}]: {}", m.role, m.content))
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join("\n")
|
.join("\n")
|
||||||
);
|
))
|
||||||
|
|
||||||
let title = self.call_llm_for_title(&prompt).await?;
|
|
||||||
|
|
||||||
if !title.is_empty() {
|
|
||||||
self.title = title.clone();
|
|
||||||
if let Err(e) = self.persist_session_meta().await {
|
|
||||||
tracing::warn!("failed to persist title: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 调用 LLM 生成标题
|
fn apply_generated_title(&mut self, title: String) -> bool {
|
||||||
async fn call_llm_for_title(&self, prompt: &str) -> Result<String, AgentError> {
|
if title.is_empty() || !self.should_generate_title() {
|
||||||
use crate::providers::{ChatCompletionRequest, ChatCompletionResponse, Message};
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
self.title = title;
|
||||||
messages: vec![Message::user(prompt.to_string())],
|
self.state_version = self.state_version.wrapping_add(1);
|
||||||
temperature: Some(0.3),
|
true
|
||||||
max_tokens: Some(20),
|
}
|
||||||
tools: None,
|
|
||||||
|
fn fresh_context_compressor(&self) -> ContextCompressor {
|
||||||
|
let compressor_config = ContextCompressionConfig {
|
||||||
|
protect_first_n: 2,
|
||||||
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
let mut compressor = ContextCompressor::with_config(
|
||||||
|
self.provider.clone(),
|
||||||
|
self.provider_config.token_limit,
|
||||||
|
compressor_config,
|
||||||
|
self.memory_manager.clone(),
|
||||||
|
);
|
||||||
|
compressor.set_session_id(Some(self.id.to_string()));
|
||||||
|
compressor
|
||||||
|
}
|
||||||
|
|
||||||
let response: ChatCompletionResponse = self
|
fn replace_history_in_memory(&mut self, messages: Vec<ChatMessage>) {
|
||||||
.provider
|
self.messages = messages;
|
||||||
.chat(request)
|
self.seq_counter = self.messages.len() as i64 + 1;
|
||||||
.await
|
self.total_message_count = self.messages.len() as i64;
|
||||||
.map_err(|e| AgentError::Other(format!("LLM call failed: {}", e)))?;
|
self.message_count = self.messages.iter().filter(|m| m.role == "user").count() as i64;
|
||||||
|
self.last_active_at = chrono::Utc::now().timestamp_millis();
|
||||||
Ok(response.content.trim().to_string())
|
self.state_version = self.state_version.wrapping_add(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取 provider_config 引用
|
/// 获取 provider_config 引用
|
||||||
@ -1075,24 +1127,39 @@ impl SessionManager {
|
|||||||
"compact" => {
|
"compact" => {
|
||||||
if let Some(sid) = current_session_id {
|
if let Some(sid) = current_session_id {
|
||||||
let session = self.get_or_create_session(sid).await?;
|
let session = self.get_or_create_session(sid).await?;
|
||||||
let mut session_guard = session.lock().await;
|
let (original_count, history, mut compressor, base_version) = {
|
||||||
let original_count = session_guard.get_history().len();
|
let session_guard = session.lock().await;
|
||||||
let history = session_guard.get_history().to_vec();
|
(
|
||||||
let result = session_guard.compressor.compress_if_needed(history).await?;
|
session_guard.get_history().len(),
|
||||||
|
session_guard.get_history().to_vec(),
|
||||||
|
session_guard.fresh_context_compressor(),
|
||||||
|
session_guard.state_version,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = compressor.compress_if_needed(history).await?;
|
||||||
let compressed_count = result.history.len();
|
let compressed_count = result.history.len();
|
||||||
if result.created_timelines {
|
let meta_snapshot = {
|
||||||
session_guard.last_compressed_message_at =
|
let mut session_guard = session.lock().await;
|
||||||
Some(chrono::Utc::now().timestamp_millis());
|
if session_guard.state_version != base_version {
|
||||||
if let Err(e) = session_guard.persist_session_meta().await {
|
return Ok((
|
||||||
tracing::warn!(error = %e, "Failed to persist compression marker after /compact");
|
None,
|
||||||
|
"Context changed while compacting; please run /compact again."
|
||||||
|
.to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
if result.created_timelines {
|
||||||
session_guard.clear_history();
|
session_guard.last_compressed_message_at =
|
||||||
for msg in result.history {
|
Some(chrono::Utc::now().timestamp_millis());
|
||||||
session_guard
|
}
|
||||||
.add_message(msg, false)
|
session_guard.replace_history_in_memory(result.history);
|
||||||
.await
|
session_guard.session_meta_snapshot()
|
||||||
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
};
|
||||||
|
|
||||||
|
if let Some((storage, meta)) = meta_snapshot
|
||||||
|
&& let Err(e) = storage.upsert_session(&meta).await
|
||||||
|
{
|
||||||
|
tracing::warn!(error = %e, "Failed to persist compression marker after /compact");
|
||||||
}
|
}
|
||||||
Ok((
|
Ok((
|
||||||
None,
|
None,
|
||||||
@ -1304,16 +1371,22 @@ impl SessionManager {
|
|||||||
let sid = current_session_id
|
let sid = current_session_id
|
||||||
.ok_or_else(|| AgentError::Other("no active session".to_string()))?;
|
.ok_or_else(|| AgentError::Other("no active session".to_string()))?;
|
||||||
let session = self.get_or_create_session(sid).await?;
|
let session = self.get_or_create_session(sid).await?;
|
||||||
let mut guard = session.lock().await;
|
let msgs = {
|
||||||
let mut msgs: Vec<String> = Vec::new();
|
let mut guard = session.lock().await;
|
||||||
if guard.current_cancel.take().is_some() {
|
let mut msgs: Vec<String> = Vec::new();
|
||||||
msgs.push("当前任务已发送停止信号。".to_string());
|
if guard.current_cancel.take().is_some() {
|
||||||
}
|
msgs.push("当前任务已发送停止信号。".to_string());
|
||||||
if guard.agent_tx.take().is_some() {
|
}
|
||||||
msgs.push("消息队列已清空。".to_string());
|
if guard.agent_tx.take().is_some() {
|
||||||
}
|
msgs.push("消息队列已清空。".to_string());
|
||||||
guard.worker_generation = guard.worker_generation.wrapping_add(1);
|
}
|
||||||
|
guard.worker_generation = guard.worker_generation.wrapping_add(1);
|
||||||
|
guard.state_version = guard.state_version.wrapping_add(1);
|
||||||
|
msgs
|
||||||
|
};
|
||||||
|
|
||||||
// Cancel all running background sub-agent tasks for this session
|
// Cancel all running background sub-agent tasks for this session
|
||||||
|
// after releasing the session lock.
|
||||||
self.sub_agent_manager
|
self.sub_agent_manager
|
||||||
.cancel_by_session(&sid.to_string())
|
.cancel_by_session(&sid.to_string())
|
||||||
.await;
|
.await;
|
||||||
@ -1670,7 +1743,7 @@ impl SessionManager {
|
|||||||
) -> Result<(), AgentError> {
|
) -> Result<(), AgentError> {
|
||||||
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
|
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
|
||||||
let session = self.get_or_create_session(&unified_id).await?;
|
let session = self.get_or_create_session(&unified_id).await?;
|
||||||
{
|
let persist_snapshot = {
|
||||||
let mut guard = session.lock().await;
|
let mut guard = session.lock().await;
|
||||||
let source = MessageSource {
|
let source = MessageSource {
|
||||||
kind: SourceKind::SystemNotification,
|
kind: SourceKind::SystemNotification,
|
||||||
@ -1681,11 +1754,11 @@ impl SessionManager {
|
|||||||
task_id: task_id.map(|s| s.to_string()),
|
task_id: task_id.map(|s| s.to_string()),
|
||||||
};
|
};
|
||||||
let msg = ChatMessage::assistant_with_source(content, source);
|
let msg = ChatMessage::assistant_with_source(content, source);
|
||||||
guard
|
guard.add_message_in_memory(msg, true)
|
||||||
.add_message(msg, true)
|
};
|
||||||
.await
|
persist_added_message(persist_snapshot)
|
||||||
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
.await
|
||||||
}
|
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||||
|
|
||||||
let outbound = OutboundMessage {
|
let outbound = OutboundMessage {
|
||||||
channel: channel.to_string(),
|
channel: channel.to_string(),
|
||||||
@ -1805,6 +1878,63 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn maybe_generate_title_outside_lock(session: Arc<Mutex<Session>>) -> Result<(), AgentError> {
|
||||||
|
use crate::providers::{ChatCompletionRequest, ChatCompletionResponse, Message};
|
||||||
|
|
||||||
|
let (provider, prompt) = {
|
||||||
|
let guard = session.lock().await;
|
||||||
|
let Some(prompt) = guard.title_prompt_snapshot() else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
(guard.provider.clone(), prompt)
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message::user(prompt)],
|
||||||
|
temperature: Some(0.3),
|
||||||
|
max_tokens: Some(20),
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response: ChatCompletionResponse = provider
|
||||||
|
.chat(request)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AgentError::Other(format!("LLM call failed: {}", e)))?;
|
||||||
|
let title = response.content.trim().to_string();
|
||||||
|
|
||||||
|
let meta_snapshot = {
|
||||||
|
let mut guard = session.lock().await;
|
||||||
|
if guard.apply_generated_title(title) {
|
||||||
|
guard.session_meta_snapshot()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some((storage, meta)) = meta_snapshot {
|
||||||
|
storage
|
||||||
|
.upsert_session(&meta)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AgentError::Other(format!("failed to persist title: {}", e)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn persist_added_message(
|
||||||
|
snapshot: Option<MessagePersistSnapshot>,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
let Some((storage, session_id, msg_meta, session_meta)) = snapshot else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
storage
|
||||||
|
.append_message_with_retry(&session_id, &msg_meta)
|
||||||
|
.await?;
|
||||||
|
storage.upsert_session(&session_meta).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn spawn_agent_worker(
|
fn spawn_agent_worker(
|
||||||
mut task_rx: mpsc::UnboundedReceiver<AgentTask>,
|
mut task_rx: mpsc::UnboundedReceiver<AgentTask>,
|
||||||
session: Arc<Mutex<Session>>,
|
session: Arc<Mutex<Session>>,
|
||||||
@ -1845,8 +1975,12 @@ fn spawn_agent_worker(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 1: prepare data under session lock
|
// Phase 1: capture a stable session snapshot under lock.
|
||||||
let (agent, history_out, system_prompt_out, cancel_rx) = {
|
// Memory recall and compression happen outside this block so
|
||||||
|
// /stop and other commands are not blocked behind slow I/O or
|
||||||
|
// LLM-backed compaction.
|
||||||
|
let skills_prompt = skills_loader.build_skills_prompt();
|
||||||
|
let (agent, history_raw, mut compressor, base_version, cancel_rx) = {
|
||||||
let mut guard = session.lock().await;
|
let mut guard = session.lock().await;
|
||||||
|
|
||||||
if guard.worker_generation != worker_gen {
|
if guard.worker_generation != worker_gen {
|
||||||
@ -1857,7 +1991,9 @@ fn spawn_agent_worker(
|
|||||||
task.media.iter().map(|m| m.to_media_ref()).collect();
|
task.media.iter().map(|m| m.to_media_ref()).collect();
|
||||||
let user_message =
|
let user_message =
|
||||||
guard.create_user_message(&task.content, media_refs);
|
guard.create_user_message(&task.content, media_refs);
|
||||||
if let Err(e) = guard.add_message(user_message, true).await {
|
let user_persist = guard.add_message_in_memory(user_message, true);
|
||||||
|
drop(guard);
|
||||||
|
if let Err(e) = persist_added_message(user_persist).await {
|
||||||
tracing::error!(error = %e, "Failed to persist user message");
|
tracing::error!(error = %e, "Failed to persist user message");
|
||||||
let err_outbound = OutboundMessage {
|
let err_outbound = OutboundMessage {
|
||||||
channel: task_chan.clone(),
|
channel: task_chan.clone(),
|
||||||
@ -1871,61 +2007,12 @@ fn spawn_agent_worker(
|
|||||||
let _ = bus.publish_outbound(err_outbound).await;
|
let _ = bus.publish_outbound(err_outbound).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
let mut guard = session.lock().await;
|
||||||
|
if guard.worker_generation != worker_gen {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let history_raw = guard.get_history().to_vec();
|
let history_raw = guard.get_history().to_vec();
|
||||||
let skills_prompt = skills_loader.build_skills_prompt();
|
|
||||||
|
|
||||||
let memory_context = match memory_manager
|
|
||||||
.recall(
|
|
||||||
&task.content,
|
|
||||||
5,
|
|
||||||
Some(crate::memory::MemoryCategory::Knowledge),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(entries) if !entries.is_empty() => Some(
|
|
||||||
entries
|
|
||||||
.iter()
|
|
||||||
.map(|e| format!("- {}: {}", e.key, e.content))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("\n"),
|
|
||||||
),
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(error = %e, "Failed to fetch memory context");
|
|
||||||
None
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let system_prompt = guard
|
|
||||||
.build_system_prompt(&skills_prompt, memory_context.as_deref());
|
|
||||||
|
|
||||||
let result = guard
|
|
||||||
.compressor
|
|
||||||
.compress_if_needed(history_raw)
|
|
||||||
.await
|
|
||||||
.map(|r| {
|
|
||||||
if r.created_timelines {
|
|
||||||
guard.last_compressed_message_at =
|
|
||||||
Some(chrono::Utc::now().timestamp_millis());
|
|
||||||
}
|
|
||||||
r.history
|
|
||||||
})
|
|
||||||
.unwrap_or_else(|e| {
|
|
||||||
tracing::warn!(
|
|
||||||
error = %e,
|
|
||||||
"Context compression failed in worker"
|
|
||||||
);
|
|
||||||
guard.get_history().to_vec()
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut history = result;
|
|
||||||
history.insert(0, ChatMessage::system(system_prompt.clone()));
|
|
||||||
|
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
|
||||||
guard.last_consolidated_at = Some(now);
|
|
||||||
let _ = guard.persist_session_meta().await;
|
|
||||||
|
|
||||||
let agent = match guard.create_agent_with_notify(notify_tx) {
|
let agent = match guard.create_agent_with_notify(notify_tx) {
|
||||||
Ok(a) => a,
|
Ok(a) => a,
|
||||||
@ -1952,9 +2039,87 @@ fn spawn_agent_worker(
|
|||||||
}
|
}
|
||||||
guard.current_cancel = Some(cancel_tx);
|
guard.current_cancel = Some(cancel_tx);
|
||||||
|
|
||||||
(agent, history, system_prompt, cancel_rx)
|
(
|
||||||
|
agent,
|
||||||
|
history_raw,
|
||||||
|
guard.fresh_context_compressor(),
|
||||||
|
guard.state_version,
|
||||||
|
cancel_rx,
|
||||||
|
)
|
||||||
}; // lock released
|
}; // lock released
|
||||||
|
|
||||||
|
let memory_context = match memory_manager
|
||||||
|
.recall(
|
||||||
|
&task.content,
|
||||||
|
5,
|
||||||
|
Some(crate::memory::MemoryCategory::Knowledge),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(entries) if !entries.is_empty() => Some(
|
||||||
|
entries
|
||||||
|
.iter()
|
||||||
|
.map(|e| format!("- {}: {}", e.key, e.content))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n"),
|
||||||
|
),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "Failed to fetch memory context");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let system_prompt_out = {
|
||||||
|
let guard = session.lock().await;
|
||||||
|
if guard.worker_generation != worker_gen {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
guard.build_system_prompt(&skills_prompt, memory_context.as_deref())
|
||||||
|
};
|
||||||
|
|
||||||
|
let compression_result = compressor.compress_if_needed(history_raw).await;
|
||||||
|
let mut history_out = match compression_result {
|
||||||
|
Ok(result) => {
|
||||||
|
let meta_snapshot = {
|
||||||
|
let mut guard = session.lock().await;
|
||||||
|
if guard.worker_generation != worker_gen {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if guard.state_version != base_version {
|
||||||
|
tracing::warn!(
|
||||||
|
session_id = %guard.id,
|
||||||
|
"Session changed while preparing agent history; dropping stale task"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if result.created_timelines {
|
||||||
|
guard.last_compressed_message_at =
|
||||||
|
Some(chrono::Utc::now().timestamp_millis());
|
||||||
|
}
|
||||||
|
guard.last_consolidated_at =
|
||||||
|
Some(chrono::Utc::now().timestamp_millis());
|
||||||
|
guard.session_meta_snapshot()
|
||||||
|
};
|
||||||
|
if let Some((storage, meta)) = meta_snapshot
|
||||||
|
&& let Err(e) = storage.upsert_session(&meta).await
|
||||||
|
{
|
||||||
|
tracing::warn!(error = %e, "Failed to persist session meta after compression");
|
||||||
|
}
|
||||||
|
result.history
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "Context compression failed in worker");
|
||||||
|
let guard = session.lock().await;
|
||||||
|
if guard.worker_generation != worker_gen {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
guard.get_history().to_vec()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
history_out.insert(0, ChatMessage::system(system_prompt_out.clone()));
|
||||||
|
|
||||||
// Phase 2 + 3: LLM call with cancellation
|
// Phase 2 + 3: LLM call with cancellation
|
||||||
let session2 = session.clone();
|
let session2 = session.clone();
|
||||||
let bus2 = bus.clone();
|
let bus2 = bus.clone();
|
||||||
@ -1975,8 +2140,8 @@ fn spawn_agent_worker(
|
|||||||
Err(AgentError::LlmError(ref msg))
|
Err(AgentError::LlmError(ref msg))
|
||||||
if is_context_overflow_error(msg) =>
|
if is_context_overflow_error(msg) =>
|
||||||
{
|
{
|
||||||
let retry_history = {
|
let (raw, mut retry_compressor, retry_base_version, new_window) = {
|
||||||
let mut guard = session2.lock().await;
|
let guard = session2.lock().await;
|
||||||
let new_window =
|
let new_window =
|
||||||
crate::agent::ContextCompressor::parse_context_limit_from_error(msg)
|
crate::agent::ContextCompressor::parse_context_limit_from_error(msg)
|
||||||
.unwrap_or(guard.compressor_threshold());
|
.unwrap_or(guard.compressor_threshold());
|
||||||
@ -1985,31 +2150,56 @@ fn spawn_agent_worker(
|
|||||||
error = %msg,
|
error = %msg,
|
||||||
"Context overflow in worker — retrying"
|
"Context overflow in worker — retrying"
|
||||||
);
|
);
|
||||||
|
(
|
||||||
|
guard.get_history().to_vec(),
|
||||||
|
guard.fresh_context_compressor(),
|
||||||
|
guard.state_version,
|
||||||
|
new_window,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
retry_compressor.set_context_window(new_window);
|
||||||
|
let retry_result =
|
||||||
|
match retry_compressor.compress_if_needed(raw).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "Retry compression failed");
|
||||||
|
let err_outbound = OutboundMessage {
|
||||||
|
channel: chan2,
|
||||||
|
chat_id: cid2,
|
||||||
|
content: "Context overflow handling failed."
|
||||||
|
.to_string(),
|
||||||
|
reply_to: None,
|
||||||
|
media: vec![],
|
||||||
|
metadata: HashMap::new(),
|
||||||
|
};
|
||||||
|
let _ = bus2.publish_outbound(err_outbound).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let meta_snapshot = {
|
||||||
|
let mut guard = session2.lock().await;
|
||||||
|
if guard.state_version != retry_base_version {
|
||||||
|
tracing::warn!(
|
||||||
|
session_id = %guard.id,
|
||||||
|
"Session changed while retry-compressing after context overflow"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
guard.compressor.set_context_window(new_window);
|
guard.compressor.set_context_window(new_window);
|
||||||
let raw = guard.get_history().to_vec();
|
|
||||||
let retry_result =
|
|
||||||
match guard.compressor.compress_if_needed(raw).await {
|
|
||||||
Ok(r) => r,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(error = %e, "Retry compression failed");
|
|
||||||
let err_outbound = OutboundMessage {
|
|
||||||
channel: chan2,
|
|
||||||
chat_id: cid2,
|
|
||||||
content: "Context overflow handling failed."
|
|
||||||
.to_string(),
|
|
||||||
reply_to: None,
|
|
||||||
media: vec![],
|
|
||||||
metadata: HashMap::new(),
|
|
||||||
};
|
|
||||||
let _ = bus2.publish_outbound(err_outbound).await;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if retry_result.created_timelines {
|
if retry_result.created_timelines {
|
||||||
guard.last_compressed_message_at =
|
guard.last_compressed_message_at =
|
||||||
Some(chrono::Utc::now().timestamp_millis());
|
Some(chrono::Utc::now().timestamp_millis());
|
||||||
let _ = guard.persist_session_meta().await;
|
|
||||||
}
|
}
|
||||||
|
guard.session_meta_snapshot()
|
||||||
|
};
|
||||||
|
if let Some((storage, meta)) = meta_snapshot
|
||||||
|
&& let Err(e) = storage.upsert_session(&meta).await
|
||||||
|
{
|
||||||
|
tracing::warn!(error = %e, "Failed to persist session meta after retry compression");
|
||||||
|
}
|
||||||
|
|
||||||
|
let retry_history = {
|
||||||
let mut retry = retry_result.history;
|
let mut retry = retry_result.history;
|
||||||
retry.insert(
|
retry.insert(
|
||||||
0,
|
0,
|
||||||
@ -2055,20 +2245,24 @@ fn spawn_agent_worker(
|
|||||||
|
|
||||||
let response = {
|
let response = {
|
||||||
let mut guard = session2.lock().await;
|
let mut guard = session2.lock().await;
|
||||||
|
let mut persist_snapshots = Vec::new();
|
||||||
for msg in result.emitted_messages {
|
for msg in result.emitted_messages {
|
||||||
guard.add_message(msg, true).await.inspect_err(|e| {
|
persist_snapshots.push(guard.add_message_in_memory(msg, true));
|
||||||
tracing::error!(error = %e, "Failed to persist message")
|
|
||||||
}).ok();
|
|
||||||
}
|
|
||||||
if guard.should_generate_title()
|
|
||||||
&& let Err(e) = guard.generate_title().await
|
|
||||||
{
|
|
||||||
tracing::warn!("failed to generate title: {}", e);
|
|
||||||
}
|
}
|
||||||
let sent_count = guard.messages.len();
|
let sent_count = guard.messages.len();
|
||||||
guard.compressor.set_last_api_info(sent_count, result.total_tokens);
|
guard.compressor.set_last_api_info(sent_count, result.total_tokens);
|
||||||
result.final_response.content
|
(result.final_response.content, persist_snapshots)
|
||||||
};
|
};
|
||||||
|
let (response, persist_snapshots) = response;
|
||||||
|
for snapshot in persist_snapshots {
|
||||||
|
if let Err(e) = persist_added_message(snapshot).await {
|
||||||
|
tracing::error!(error = %e, "Failed to persist message");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = maybe_generate_title_outside_lock(session2.clone()).await {
|
||||||
|
tracing::warn!("failed to generate title: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
let outbound = OutboundMessage {
|
let outbound = OutboundMessage {
|
||||||
channel: chan2,
|
channel: chan2,
|
||||||
@ -2158,25 +2352,34 @@ impl SessionManager {
|
|||||||
unified_id: &UnifiedSessionId,
|
unified_id: &UnifiedSessionId,
|
||||||
) -> Result<(), AgentError> {
|
) -> Result<(), AgentError> {
|
||||||
let session = self.get_or_create_session(unified_id).await?;
|
let session = self.get_or_create_session(unified_id).await?;
|
||||||
let mut session_guard = session.lock().await;
|
let (storage, session_id, meta_snapshot) = {
|
||||||
// Clear in-memory
|
let mut session_guard = session.lock().await;
|
||||||
session_guard.messages.clear();
|
// Clear in-memory
|
||||||
session_guard.seq_counter = 1;
|
session_guard.messages.clear();
|
||||||
session_guard.total_message_count = 0;
|
session_guard.seq_counter = 1;
|
||||||
session_guard.message_count = 0;
|
session_guard.total_message_count = 0;
|
||||||
session_guard.last_consolidated_at = None;
|
session_guard.message_count = 0;
|
||||||
session_guard.last_compressed_message_at = None;
|
session_guard.last_consolidated_at = None;
|
||||||
// Clear Storage
|
session_guard.last_compressed_message_at = None;
|
||||||
if let Some(ref storage) = session_guard.storage {
|
session_guard.state_version = session_guard.state_version.wrapping_add(1);
|
||||||
|
(
|
||||||
|
session_guard.storage.clone(),
|
||||||
|
session_guard.id.to_string(),
|
||||||
|
session_guard.session_meta_snapshot(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
// Clear Storage outside the session lock.
|
||||||
|
if let Some(storage) = storage {
|
||||||
storage
|
storage
|
||||||
.clear_messages(&session_guard.id.to_string())
|
.clear_messages(&session_id)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AgentError::Other(format!("failed to clear messages: {}", e)))?;
|
.map_err(|e| AgentError::Other(format!("failed to clear messages: {}", e)))?;
|
||||||
}
|
}
|
||||||
session_guard
|
if let Some((storage, meta)) = meta_snapshot {
|
||||||
.persist_session_meta()
|
storage.upsert_session(&meta).await.map_err(|e| {
|
||||||
.await
|
AgentError::Other(format!("failed to persist cleared session: {}", e))
|
||||||
.map_err(|e| AgentError::Other(format!("failed to persist cleared session: {}", e)))?;
|
})?;
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2235,14 +2438,14 @@ impl OutboundMessenger for SessionManager {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Write source-tagged assistant message to target session history
|
// Write source-tagged assistant message to target session history
|
||||||
{
|
let persist_snapshot = {
|
||||||
let mut guard = session.lock().await;
|
let mut guard = session.lock().await;
|
||||||
let msg = ChatMessage::assistant_with_source(marked_content.clone(), source);
|
let msg = ChatMessage::assistant_with_source(marked_content.clone(), source);
|
||||||
guard
|
guard.add_message_in_memory(msg, true)
|
||||||
.add_message(msg, true)
|
};
|
||||||
.await
|
persist_added_message(persist_snapshot)
|
||||||
.map_err(|e| e.to_string())?;
|
.await
|
||||||
}
|
.map_err(|e| e.to_string())?;
|
||||||
|
|
||||||
// Restore active dialog if source and target share channel:chat_id but differ in dialog_id
|
// Restore active dialog if source and target share channel:chat_id but differ in dialog_id
|
||||||
if let Some(ref origin_id) = origin_id {
|
if let Some(ref origin_id) = origin_id {
|
||||||
|
|||||||
@ -4,7 +4,6 @@ use std::time::Duration;
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::io::AsyncReadExt;
|
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
||||||
@ -147,71 +146,55 @@ impl Tool for BashTool {
|
|||||||
.map(Path::new)
|
.map(Path::new)
|
||||||
.unwrap_or_else(|| Path::new("."));
|
.unwrap_or_else(|| Path::new("."));
|
||||||
|
|
||||||
let result = timeout(
|
match self.run_command(command, cwd, timeout_secs).await {
|
||||||
Duration::from_secs(timeout_secs),
|
Ok(output) => Ok(ToolResult {
|
||||||
self.run_command(command, cwd),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(Ok(output)) => Ok(ToolResult {
|
|
||||||
success: true,
|
success: true,
|
||||||
output,
|
output,
|
||||||
error: None,
|
error: None,
|
||||||
}),
|
}),
|
||||||
Ok(Err(e)) => Ok(ToolResult {
|
Err(e) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(e),
|
error: Some(e),
|
||||||
}),
|
}),
|
||||||
Err(_) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
|
|
||||||
}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BashTool {
|
impl BashTool {
|
||||||
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
|
async fn run_command(
|
||||||
|
&self,
|
||||||
|
command: &str,
|
||||||
|
cwd: &Path,
|
||||||
|
timeout_secs: u64,
|
||||||
|
) -> Result<String, String> {
|
||||||
let mut cmd = Command::new("bash");
|
let mut cmd = Command::new("bash");
|
||||||
cmd.args(["-c", command])
|
cmd.args(["-c", command])
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped())
|
.stderr(Stdio::piped())
|
||||||
.current_dir(cwd);
|
.current_dir(cwd)
|
||||||
|
.kill_on_drop(true);
|
||||||
|
|
||||||
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||||
|
|
||||||
let mut stdout = Vec::new();
|
let process_output =
|
||||||
let mut stderr = Vec::new();
|
match timeout(Duration::from_secs(timeout_secs), child.wait_with_output()).await {
|
||||||
|
Ok(Ok(output)) => output,
|
||||||
if let Some(ref mut out) = child.stdout {
|
Ok(Err(e)) => return Err(format!("Failed to wait: {}", e)),
|
||||||
out.read_to_end(&mut stdout)
|
Err(_) => {
|
||||||
.await
|
return Err(format!("Command timed out after {} seconds", timeout_secs));
|
||||||
.map_err(|e| format!("Failed to read stdout: {}", e))?;
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
if let Some(ref mut err) = child.stderr {
|
|
||||||
err.read_to_end(&mut stderr)
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("Failed to read stderr: {}", e))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let status = child
|
|
||||||
.wait()
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("Failed to wait: {}", e))?;
|
|
||||||
|
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
|
|
||||||
if !stdout.is_empty() {
|
if !process_output.stdout.is_empty() {
|
||||||
let stdout_str = String::from_utf8_lossy(&stdout);
|
let stdout_str = String::from_utf8_lossy(&process_output.stdout);
|
||||||
output.push_str(&stdout_str);
|
output.push_str(&stdout_str);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stderr.is_empty() {
|
if !process_output.stderr.is_empty() {
|
||||||
let stderr_str = String::from_utf8_lossy(&stderr);
|
let stderr_str = String::from_utf8_lossy(&process_output.stderr);
|
||||||
if !stderr_str.trim().is_empty() {
|
if !stderr_str.trim().is_empty() {
|
||||||
if !output.is_empty() {
|
if !output.is_empty() {
|
||||||
output.push('\n');
|
output.push('\n');
|
||||||
@ -221,7 +204,10 @@ impl BashTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
|
output.push_str(&format!(
|
||||||
|
"\nExit code: {}",
|
||||||
|
process_output.status.code().unwrap_or(-1)
|
||||||
|
));
|
||||||
|
|
||||||
Ok(self.truncate_output(&output))
|
Ok(self.truncate_output(&output))
|
||||||
}
|
}
|
||||||
@ -309,4 +295,19 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("timed out"));
|
assert!(result.error.unwrap().contains("timed out"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_large_stderr_does_not_deadlock() {
|
||||||
|
let tool = BashTool::new().with_timeout(5);
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"command": "for i in $(seq 1 2000); do echo noisy-error-line >&2; done; echo done"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("done"));
|
||||||
|
assert!(result.output.contains("STDERR"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,8 @@ use crate::tools::path_utils;
|
|||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
const MAX_CHARS: usize = 128_000;
|
const MAX_CHARS: usize = 128_000;
|
||||||
|
const MAX_FILE_BYTES: u64 = 5 * 1024 * 1024;
|
||||||
|
const MAX_BINARY_BYTES: usize = 512 * 1024;
|
||||||
const DEFAULT_LIMIT: usize = 2000;
|
const DEFAULT_LIMIT: usize = 2000;
|
||||||
|
|
||||||
pub struct FileReadTool {
|
pub struct FileReadTool {
|
||||||
@ -118,6 +120,29 @@ impl Tool for FileReadTool {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let metadata = match std::fs::metadata(&resolved) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to inspect file: {}", e)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if metadata.len() > MAX_FILE_BYTES {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"File too large to read safely: {} bytes (max {} bytes). Use a narrower tool or inspect a smaller excerpt.",
|
||||||
|
metadata.len(),
|
||||||
|
MAX_FILE_BYTES
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Read raw bytes and try multiple encodings
|
// Read raw bytes and try multiple encodings
|
||||||
let bytes = match std::fs::read(&resolved) {
|
let bytes = match std::fs::read(&resolved) {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
@ -209,6 +234,21 @@ impl Tool for FileReadTool {
|
|||||||
None => {
|
None => {
|
||||||
// Truly binary file — base64 encode
|
// Truly binary file — base64 encode
|
||||||
use base64::{Engine, engine::general_purpose::STANDARD};
|
use base64::{Engine, engine::general_purpose::STANDARD};
|
||||||
|
if bytes.len() > MAX_BINARY_BYTES {
|
||||||
|
let mime = mime_guess::from_path(&resolved)
|
||||||
|
.first_or_octet_stream()
|
||||||
|
.to_string();
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"Binary file too large to inline: {}, {} bytes (max {} bytes).",
|
||||||
|
mime,
|
||||||
|
bytes.len(),
|
||||||
|
MAX_BINARY_BYTES
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
let encoded = STANDARD.encode(&bytes);
|
let encoded = STANDARD.encode(&bytes);
|
||||||
let mime = mime_guess::from_path(&resolved)
|
let mime = mime_guess::from_path(&resolved)
|
||||||
.first_or_octet_stream()
|
.first_or_octet_stream()
|
||||||
@ -229,6 +269,10 @@ impl Tool for FileReadTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
|
fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
|
||||||
|
if bytes.contains(&0) {
|
||||||
|
return (None, None);
|
||||||
|
}
|
||||||
|
|
||||||
// Try UTF-8 first
|
// Try UTF-8 first
|
||||||
if let Ok(text) = std::str::from_utf8(bytes) {
|
if let Ok(text) = std::str::from_utf8(bytes) {
|
||||||
return (Some(text.to_string()), None);
|
return (Some(text.to_string()), None);
|
||||||
@ -337,4 +381,37 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("Not a file"));
|
assert!(result.error.unwrap().contains("Not a file"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rejects_large_file_before_reading() {
|
||||||
|
let mut file = NamedTempFile::new().unwrap();
|
||||||
|
file.as_file_mut()
|
||||||
|
.set_len(MAX_FILE_BYTES + 1)
|
||||||
|
.expect("set large file length");
|
||||||
|
|
||||||
|
let tool = FileReadTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "path": file.path().to_str().unwrap() }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("too large"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rejects_large_binary_inline() {
|
||||||
|
let mut file = NamedTempFile::new().unwrap();
|
||||||
|
let bytes = vec![0_u8; MAX_BINARY_BYTES + 1];
|
||||||
|
file.write_all(&bytes).unwrap();
|
||||||
|
|
||||||
|
let tool = FileReadTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "path": file.path().to_str().unwrap() }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("Binary file too large"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@ use std::time::Duration;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::header::HeaderMap;
|
use reqwest::header::HeaderMap;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use tokio::net::lookup_host;
|
||||||
|
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
@ -56,6 +57,34 @@ impl HttpRequestTool {
|
|||||||
Ok(url.to_string())
|
Ok(url.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
|
||||||
|
if self.allow_private_hosts {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let host = extract_host(url)?;
|
||||||
|
if host.parse::<std::net::IpAddr>().is_ok() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let port = extract_port(url)?;
|
||||||
|
let addrs = lookup_host((host.as_str(), port))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
|
||||||
|
|
||||||
|
for addr in addrs {
|
||||||
|
let ip = addr.ip();
|
||||||
|
if is_private_ip(&ip) {
|
||||||
|
return Err(format!(
|
||||||
|
"Blocked host '{}' because DNS resolved to local/private IP {}",
|
||||||
|
host, ip
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
|
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
|
||||||
match method.to_uppercase().as_str() {
|
match method.to_uppercase().as_str() {
|
||||||
"GET" => Ok(reqwest::Method::GET),
|
"GET" => Ok(reqwest::Method::GET),
|
||||||
@ -180,6 +209,34 @@ fn extract_host(url: &str) -> Result<String, String> {
|
|||||||
Ok(host)
|
Ok(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn extract_port(url: &str) -> Result<u16, String> {
|
||||||
|
let scheme = if url.starts_with("https://") {
|
||||||
|
"https"
|
||||||
|
} else if url.starts_with("http://") {
|
||||||
|
"http"
|
||||||
|
} else {
|
||||||
|
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||||
|
};
|
||||||
|
|
||||||
|
let rest = url
|
||||||
|
.strip_prefix("http://")
|
||||||
|
.or_else(|| url.strip_prefix("https://"))
|
||||||
|
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
|
||||||
|
let authority = rest
|
||||||
|
.split(['/', '?', '#'])
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| "Invalid URL".to_string())?;
|
||||||
|
|
||||||
|
if let Some((_, port)) = authority.rsplit_once(':') {
|
||||||
|
port.parse::<u16>()
|
||||||
|
.map_err(|_| format!("Invalid URL port: {}", port))
|
||||||
|
} else if scheme == "https" {
|
||||||
|
Ok(443)
|
||||||
|
} else {
|
||||||
|
Ok(80)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
||||||
if allowed_domains.iter().any(|domain| domain == "*") {
|
if allowed_domains.iter().any(|domain| domain == "*") {
|
||||||
return true;
|
return true;
|
||||||
@ -226,7 +283,13 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
|||||||
|| v4.is_broadcast()
|
|| v4.is_broadcast()
|
||||||
|| v4.is_multicast()
|
|| v4.is_multicast()
|
||||||
}
|
}
|
||||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
std::net::IpAddr::V6(v6) => {
|
||||||
|
v6.is_loopback()
|
||||||
|
|| v6.is_unspecified()
|
||||||
|
|| v6.is_multicast()
|
||||||
|
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|
||||||
|
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -294,6 +357,14 @@ impl Tool for HttpRequestTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Err(e) = self.validate_resolved_host(&url).await {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let method = match self.validate_method(method_str) {
|
let method = match self.validate_method(method_str) {
|
||||||
Ok(m) => m,
|
Ok(m) => m,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@ -309,6 +380,7 @@ impl Tool for HttpRequestTool {
|
|||||||
|
|
||||||
let client = match reqwest::Client::builder()
|
let client = match reqwest::Client::builder()
|
||||||
.timeout(Duration::from_secs(self.timeout_secs))
|
.timeout(Duration::from_secs(self.timeout_secs))
|
||||||
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
.build()
|
.build()
|
||||||
{
|
{
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
@ -436,4 +508,19 @@ mod tests {
|
|||||||
async fn test_blocks_local_tld() {
|
async fn test_blocks_local_tld() {
|
||||||
assert!(is_private_host("service.local"));
|
assert!(is_private_host("service.local"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_port_defaults_and_explicit() {
|
||||||
|
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
|
||||||
|
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
|
||||||
|
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_private_ipv6_ranges_are_blocked() {
|
||||||
|
assert!(is_private_ip(&"::1".parse().unwrap()));
|
||||||
|
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
|
||||||
|
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
|
||||||
|
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,6 +13,7 @@ pub mod get_skill;
|
|||||||
pub mod http_request;
|
pub mod http_request;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
pub mod path_utils;
|
pub mod path_utils;
|
||||||
|
pub mod pty;
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod schema;
|
pub mod schema;
|
||||||
pub mod send_message;
|
pub mod send_message;
|
||||||
@ -32,6 +33,7 @@ pub use file_write::FileWriteTool;
|
|||||||
pub use get_skill::GetSkillTool;
|
pub use get_skill::GetSkillTool;
|
||||||
pub use http_request::HttpRequestTool;
|
pub use http_request::HttpRequestTool;
|
||||||
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
|
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
|
||||||
|
pub use pty::{PtyManager, PtyTool};
|
||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
pub use send_message::SendMessageTool;
|
pub use send_message::SendMessageTool;
|
||||||
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||||
@ -60,6 +62,7 @@ pub fn create_default_tools(
|
|||||||
registry.register(FileSearchTool::new());
|
registry.register(FileSearchTool::new());
|
||||||
registry.register(ContentSearchTool::new());
|
registry.register(ContentSearchTool::new());
|
||||||
registry.register(BashTool::new());
|
registry.register(BashTool::new());
|
||||||
|
registry.register(PtyTool::new(Arc::new(PtyManager::new())));
|
||||||
registry.register(HttpRequestTool::new(
|
registry.register(HttpRequestTool::new(
|
||||||
vec!["*".to_string()],
|
vec!["*".to_string()],
|
||||||
1_000_000,
|
1_000_000,
|
||||||
|
|||||||
@ -160,11 +160,14 @@ impl PtyManager {
|
|||||||
};
|
};
|
||||||
for session in sessions {
|
for session in sessions {
|
||||||
let mut guard = session.lock().unwrap();
|
let mut guard = session.lock().unwrap();
|
||||||
let mut child_guard = guard.child.lock().unwrap();
|
let child_handle = guard.child.clone();
|
||||||
if let Some(ref mut child) = *child_guard {
|
{
|
||||||
let _ = child.kill();
|
let mut child_guard = child_handle.lock().unwrap();
|
||||||
|
if let Some(ref mut child) = *child_guard {
|
||||||
|
let _ = child.kill();
|
||||||
|
}
|
||||||
|
*child_guard = None;
|
||||||
}
|
}
|
||||||
*child_guard = None;
|
|
||||||
guard.status = SessionStatus::Killed;
|
guard.status = SessionStatus::Killed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -274,7 +277,7 @@ impl PtyManager {
|
|||||||
let session = sessions
|
let session = sessions
|
||||||
.get(session_id)
|
.get(session_id)
|
||||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||||
let mut guard = session.lock().unwrap();
|
let guard = session.lock().unwrap();
|
||||||
if guard.status != SessionStatus::Running {
|
if guard.status != SessionStatus::Running {
|
||||||
return Err("Session is not running".to_string());
|
return Err("Session is not running".to_string());
|
||||||
}
|
}
|
||||||
@ -296,12 +299,7 @@ impl PtyManager {
|
|||||||
Ok(format!("OK, wrote {} bytes", byte_count))
|
Ok(format!("OK, wrote {} bytes", byte_count))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read(
|
fn read(&self, session_id: &str, offset: usize, limit: usize) -> Result<String, String> {
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
offset: usize,
|
|
||||||
limit: usize,
|
|
||||||
) -> Result<String, String> {
|
|
||||||
let sessions = self.sessions.lock().unwrap();
|
let sessions = self.sessions.lock().unwrap();
|
||||||
let session = sessions
|
let session = sessions
|
||||||
.get(session_id)
|
.get(session_id)
|
||||||
@ -352,14 +350,16 @@ impl PtyManager {
|
|||||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||||
let mut guard = session.lock().unwrap();
|
let mut guard = session.lock().unwrap();
|
||||||
|
|
||||||
let mut child_guard = guard.child.lock().unwrap();
|
let child_handle = guard.child.clone();
|
||||||
if let Some(ref mut child) = *child_guard {
|
{
|
||||||
let _ = child.kill();
|
let mut child_guard = child_handle.lock().unwrap();
|
||||||
let _ = child.wait();
|
if let Some(ref mut child) = *child_guard {
|
||||||
|
let _ = child.kill();
|
||||||
|
let _ = child.wait();
|
||||||
|
}
|
||||||
|
*child_guard = None;
|
||||||
}
|
}
|
||||||
*child_guard = None;
|
|
||||||
guard.status = SessionStatus::Killed;
|
guard.status = SessionStatus::Killed;
|
||||||
drop(child_guard);
|
|
||||||
drop(guard);
|
drop(guard);
|
||||||
sessions.remove(session_id);
|
sessions.remove(session_id);
|
||||||
|
|
||||||
@ -545,14 +545,8 @@ impl Tool for PtyTool {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let offset = args
|
let offset = args.get("offset").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
|
||||||
.get("offset")
|
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(500) as usize;
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(0) as usize;
|
|
||||||
let limit = args
|
|
||||||
.get("limit")
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(500) as usize;
|
|
||||||
match self.pty_manager.read(session_id, offset, limit) {
|
match self.pty_manager.read(session_id, offset, limit) {
|
||||||
Ok(output) => Ok(ToolResult {
|
Ok(output) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
|
|||||||
@ -3,6 +3,7 @@ use std::time::Duration;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::header::HeaderMap;
|
use reqwest::header::HeaderMap;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use tokio::net::lookup_host;
|
||||||
|
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
@ -60,9 +61,34 @@ impl WebFetchTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn validate_resolved_host(&self, url: &str) -> Result<(), String> {
|
||||||
|
let host = extract_host(url)?;
|
||||||
|
if host.parse::<std::net::IpAddr>().is_ok() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let port = extract_port(url)?;
|
||||||
|
let addrs = lookup_host((host.as_str(), port))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to resolve host '{}': {}", host, e))?;
|
||||||
|
|
||||||
|
for addr in addrs {
|
||||||
|
let ip = addr.ip();
|
||||||
|
if is_private_ip(&ip) {
|
||||||
|
return Err(format!(
|
||||||
|
"Blocked host '{}' because DNS resolved to local/private IP {}",
|
||||||
|
host, ip
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn fetch_content(&self, url: &str) -> Result<String, String> {
|
async fn fetch_content(&self, url: &str) -> Result<String, String> {
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(Duration::from_secs(self.timeout_secs))
|
.timeout(Duration::from_secs(self.timeout_secs))
|
||||||
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
@ -234,6 +260,35 @@ fn extract_host(url: &str) -> Result<String, String> {
|
|||||||
Ok(host)
|
Ok(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn extract_port(url: &str) -> Result<u16, String> {
|
||||||
|
let scheme = if url.starts_with("https://") {
|
||||||
|
"https"
|
||||||
|
} else if url.starts_with("http://") {
|
||||||
|
"http"
|
||||||
|
} else {
|
||||||
|
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||||
|
};
|
||||||
|
|
||||||
|
let rest = url
|
||||||
|
.strip_prefix("http://")
|
||||||
|
.or_else(|| url.strip_prefix("https://"))
|
||||||
|
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
|
||||||
|
|
||||||
|
let authority = rest
|
||||||
|
.split(['/', '?', '#'])
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| "Invalid URL".to_string())?;
|
||||||
|
|
||||||
|
if let Some((_, port)) = authority.rsplit_once(':') {
|
||||||
|
port.parse::<u16>()
|
||||||
|
.map_err(|_| format!("Invalid URL port: {}", port))
|
||||||
|
} else if scheme == "https" {
|
||||||
|
Ok(443)
|
||||||
|
} else {
|
||||||
|
Ok(80)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn is_private_host(host: &str) -> bool {
|
fn is_private_host(host: &str) -> bool {
|
||||||
if host == "localhost" || host.ends_with(".localhost") {
|
if host == "localhost" || host.ends_with(".localhost") {
|
||||||
return true;
|
return true;
|
||||||
@ -248,19 +303,32 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||||
return match ip {
|
return is_private_ip(&ip);
|
||||||
std::net::IpAddr::V4(v4) => {
|
|
||||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
|
||||||
}
|
|
||||||
std::net::IpAddr::V6(v6) => {
|
|
||||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
||||||
|
match ip {
|
||||||
|
std::net::IpAddr::V4(v4) => {
|
||||||
|
v4.is_loopback()
|
||||||
|
|| v4.is_private()
|
||||||
|
|| v4.is_link_local()
|
||||||
|
|| v4.is_unspecified()
|
||||||
|
|| v4.is_broadcast()
|
||||||
|
|| v4.is_multicast()
|
||||||
|
}
|
||||||
|
std::net::IpAddr::V6(v6) => {
|
||||||
|
v6.is_loopback()
|
||||||
|
|| v6.is_unspecified()
|
||||||
|
|| v6.is_multicast()
|
||||||
|
|| ((v6.segments()[0] & 0xfe00) == 0xfc00)
|
||||||
|
|| ((v6.segments()[0] & 0xffc0) == 0xfe80)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Tool for WebFetchTool {
|
impl Tool for WebFetchTool {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
@ -311,6 +379,14 @@ impl Tool for WebFetchTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Err(e) = self.validate_resolved_host(&url).await {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
match self.fetch_content(&url).await {
|
match self.fetch_content(&url).await {
|
||||||
Ok(content) => Ok(ToolResult {
|
Ok(content) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
@ -357,6 +433,21 @@ mod tests {
|
|||||||
assert!(result.unwrap_err().contains("whitespace"));
|
assert!(result.unwrap_err().contains("whitespace"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_port_defaults_and_explicit() {
|
||||||
|
assert_eq!(extract_port("https://example.com/path").unwrap(), 443);
|
||||||
|
assert_eq!(extract_port("http://example.com/path").unwrap(), 80);
|
||||||
|
assert_eq!(extract_port("https://example.com:8443/path").unwrap(), 8443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_private_ipv6_ranges_are_blocked() {
|
||||||
|
assert!(is_private_ip(&"::1".parse().unwrap()));
|
||||||
|
assert!(is_private_ip(&"fc00::1".parse().unwrap()));
|
||||||
|
assert!(is_private_ip(&"fe80::1".parse().unwrap()));
|
||||||
|
assert!(!is_private_ip(&"2606:4700:4700::1111".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_extract_text_simple() {
|
async fn test_extract_text_simple() {
|
||||||
let tool = test_tool();
|
let tool = test_tool();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user